Compare commits
4 Commits
5a1f278df8
...
8b85a02003
| Author | SHA1 | Date | |
|---|---|---|---|
| 8b85a02003 | |||
| 555cb00276 | |||
| 7b935b0fa3 | |||
| aefe6d06cf |
@@ -88,6 +88,7 @@ __all__ = [
|
|||||||
# Historical stock list
|
# Historical stock list
|
||||||
"get_bak_basic",
|
"get_bak_basic",
|
||||||
"sync_bak_basic",
|
"sync_bak_basic",
|
||||||
|
"BakBasicSync",
|
||||||
# Namechange
|
# Namechange
|
||||||
"get_namechange",
|
"get_namechange",
|
||||||
"sync_namechange",
|
"sync_namechange",
|
||||||
@@ -105,3 +106,77 @@ __all__ = [
|
|||||||
"sync_stock_st",
|
"sync_stock_st",
|
||||||
"StockSTSync",
|
"StockSTSync",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# 自动注册同步任务到 SyncRegistry
|
||||||
|
# 这样 sync.py 不需要手动罗列各个接口
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
try:
    from src.data.sync_registry import sync_registry

    # 1. Trade calendar - highest priority; other tasks may depend on it.
    sync_registry.register_func(
        name="trade_cal",
        sync_func=sync_trade_cal_cache,
        display_name="交易日历",
        description="交易日期缓存",
        order=1,
    )

    # 2. Stock basic - base reference data.
    sync_registry.register_func(
        name="stock_basic",
        sync_func=sync_all_stocks,
        display_name="股票基本信息",
        description="所有上市/退市股票的基础信息",
        order=2,
    )

    # 3. Pro Bar - general market data (recommended replacement for daily bars).
    from src.data.api_wrappers.api_pro_bar import ProBarSync

    sync_registry.register_class(
        name="pro_bar",
        sync_class=ProBarSync,
        display_name="Pro Bar 数据",
        description="包含复权因子、换手率、量比的数据",
        order=10,
    )

    # 4. Daily basic - daily indicators.
    from src.data.api_wrappers.api_daily_basic import DailyBasicSync

    sync_registry.register_class(
        name="daily_basic",
        sync_class=DailyBasicSync,
        display_name="每日指标",
        description="市盈率、市净率、换手率、市值等指标",
        order=20,
    )

    # 5. Bak basic - historical stock list.
    from src.data.api_wrappers.api_bak_basic import BakBasicSync

    sync_registry.register_class(
        name="bak_basic",
        sync_class=BakBasicSync,
        display_name="历史股票列表",
        description="历史股票列表(包含退市股票)",
        order=30,
    )

    # 6. ST stock - ST stock list.
    from src.data.api_wrappers.api_stock_st import StockSTSync

    sync_registry.register_class(
        name="stock_st",
        sync_class=StockSTSync,
        display_name="ST股票列表",
        description="ST股票历史记录",
        order=40,
    )

except ImportError as exc:
    # Registration stays best-effort (the original note says sync_registry may
    # be unavailable on first import), but a failed import is no longer
    # swallowed silently: any task after the failing import is NOT registered,
    # so surface the cause instead of leaving a partially filled registry
    # with no diagnostics.
    import warnings

    warnings.warn(
        f"[api_wrappers] Sync task auto-registration incomplete: {exc}",
        RuntimeWarning,
        stacklevel=2,
    )
|
||||||
|
|||||||
232
src/data/sync.py
232
src/data/sync.py
@@ -24,6 +24,10 @@
|
|||||||
from src.data.api_wrappers import sync_namechange
|
from src.data.api_wrappers import sync_namechange
|
||||||
sync_namechange(force=True)
|
sync_namechange(force=True)
|
||||||
|
|
||||||
|
【架构说明】
|
||||||
|
本模块使用 SyncRegistry 注册表模式管理同步任务,避免手动罗列各个接口。
|
||||||
|
同步任务在 api_wrappers/__init__.py 中自动注册,新增接口无需修改 sync.py。
|
||||||
|
|
||||||
使用方式:
|
使用方式:
|
||||||
# 预览同步(检查数据量,不写入)
|
# 预览同步(检查数据量,不写入)
|
||||||
from src.data.sync import preview_sync
|
from src.data.sync import preview_sync
|
||||||
@@ -35,18 +39,23 @@
|
|||||||
|
|
||||||
# 强制全量重载
|
# 强制全量重载
|
||||||
result = sync_all_data(force_full=True)
|
result = sync_all_data(force_full=True)
|
||||||
|
|
||||||
|
# 查看已注册的所有同步任务
|
||||||
|
from src.data.sync_registry import sync_registry
|
||||||
|
tasks = sync_registry.list_tasks()
|
||||||
|
for task in tasks:
|
||||||
|
print(f"{task.name}: {task.display_name}")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional, Dict, Union, Any
|
from typing import Optional, Dict, Union, Any
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
|
# 导入以触发自动注册
|
||||||
|
from src.data import api_wrappers # noqa: F401
|
||||||
|
from src.data.sync_registry import sync_registry
|
||||||
from src.data.api_wrappers import sync_all_stocks
|
from src.data.api_wrappers import sync_all_stocks
|
||||||
from src.data.api_wrappers.api_daily import sync_daily, preview_daily_sync
|
from src.data.api_wrappers.api_daily import sync_daily, preview_daily_sync
|
||||||
from src.data.api_wrappers.api_pro_bar import sync_pro_bar
|
|
||||||
from src.data.api_wrappers.api_bak_basic import sync_bak_basic
|
|
||||||
from src.data.api_wrappers.api_daily_basic import sync_daily_basic
|
|
||||||
from src.data.api_wrappers.api_stock_st import sync_stock_st
|
|
||||||
|
|
||||||
|
|
||||||
def preview_sync(
|
def preview_sync(
|
||||||
@@ -150,19 +159,24 @@ def sync_all_data(
|
|||||||
force_full: bool = False,
|
force_full: bool = False,
|
||||||
max_workers: Optional[int] = None,
|
max_workers: Optional[int] = None,
|
||||||
dry_run: bool = False,
|
dry_run: bool = False,
|
||||||
|
selected: Optional[list[str]] = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""同步所有每日更新的数据类型。
|
"""同步所有每日更新的数据类型。
|
||||||
|
|
||||||
【重要】本函数仅同步每日更新的数据,不包含季度/低频数据。
|
【重要】本函数仅同步每日更新的数据,不包含季度/低频数据。
|
||||||
|
|
||||||
该函数按顺序同步以下每日更新的数据类型:
|
【自动注册机制】
|
||||||
1. 交易日历 (sync_trade_cal_cache)
|
同步任务在 api_wrappers/__init__.py 中自动注册到 SyncRegistry。
|
||||||
2. 股票基本信息 (sync_all_stocks)
|
当前注册的同步任务(按执行顺序):
|
||||||
3. 日线数据 (sync_daily)
|
1. trade_cal: 交易日历缓存
|
||||||
4. Pro Bar 数据 (sync_pro_bar)
|
2. stock_basic: 股票基本信息
|
||||||
5. 每日指标数据 (sync_daily_basic)
|
3. pro_bar: Pro Bar 数据(复权、换手率、量比)
|
||||||
6. 历史股票列表 (sync_bak_basic)
|
4. daily_basic: 每日指标(PE、PB、换手率、市值)
|
||||||
7. ST股票列表 (sync_stock_st)
|
5. bak_basic: 历史股票列表
|
||||||
|
6. stock_st: ST股票列表
|
||||||
|
|
||||||
|
新增接口时,只需在 api_wrappers/__init__.py 中添加注册代码,
|
||||||
|
无需修改本函数。
|
||||||
|
|
||||||
【不包含的同步(需单独调用)】
|
【不包含的同步(需单独调用)】
|
||||||
- 财务数据: 利润表、资产负债表、现金流量表(季度更新)
|
- 财务数据: 利润表、资产负债表、现金流量表(季度更新)
|
||||||
@@ -177,6 +191,8 @@ def sync_all_data(
|
|||||||
force_full: 若为 True,强制所有数据类型完整重载
|
force_full: 若为 True,强制所有数据类型完整重载
|
||||||
max_workers: 日线数据同步的工作线程数(默认: 10)
|
max_workers: 日线数据同步的工作线程数(默认: 10)
|
||||||
dry_run: 若为 True,仅显示将要同步的内容,不写入数据
|
dry_run: 若为 True,仅显示将要同步的内容,不写入数据
|
||||||
|
selected: 只同步指定的任务列表,None表示同步所有
|
||||||
|
例如: selected=["trade_cal", "stock_basic"] 只同步交易日历和股票基本信息
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
映射数据类型到同步结果的字典
|
映射数据类型到同步结果的字典
|
||||||
@@ -189,163 +205,73 @@ def sync_all_data(
|
|||||||
>>>
|
>>>
|
||||||
>>> # Dry run
|
>>> # Dry run
|
||||||
>>> result = sync_all_data(dry_run=True)
|
>>> result = sync_all_data(dry_run=True)
|
||||||
|
>>>
|
||||||
|
>>> # 只同步特定任务
|
||||||
|
>>> result = sync_all_data(selected=["trade_cal", "stock_basic"])
|
||||||
|
>>>
|
||||||
|
>>> # 查看所有可用任务
|
||||||
|
>>> from src.data.sync_registry import sync_registry
|
||||||
|
>>> tasks = sync_registry.list_tasks()
|
||||||
|
>>> for t in tasks:
|
||||||
|
... print(f"{t.name}: {t.display_name}")
|
||||||
"""
|
"""
|
||||||
results: dict[str, Any] = {}
|
return sync_registry.sync_all(
|
||||||
|
force_full=force_full,
|
||||||
print("\n" + "=" * 60)
|
max_workers=max_workers,
|
||||||
print("[sync_all_data] Starting full data synchronization...")
|
dry_run=dry_run,
|
||||||
print("=" * 60)
|
selected=selected,
|
||||||
|
|
||||||
# 1. Sync trade calendar (always needed first)
|
|
||||||
print("\n[1/7] Syncing trade calendar cache...")
|
|
||||||
try:
|
|
||||||
from src.data.api_wrappers import sync_trade_cal_cache
|
|
||||||
|
|
||||||
sync_trade_cal_cache()
|
|
||||||
results["trade_cal"] = pd.DataFrame()
|
|
||||||
print("[1/7] Trade calendar: OK")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[1/7] Trade calendar: FAILED - {e}")
|
|
||||||
results["trade_cal"] = pd.DataFrame()
|
|
||||||
|
|
||||||
# 2. Sync stock basic info
|
|
||||||
print("\n[2/7] Syncing stock basic info...")
|
|
||||||
try:
|
|
||||||
sync_all_stocks()
|
|
||||||
results["stock_basic"] = pd.DataFrame()
|
|
||||||
print("[2/7] Stock basic: OK")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[2/7] Stock basic: FAILED - {e}")
|
|
||||||
results["stock_basic"] = pd.DataFrame()
|
|
||||||
|
|
||||||
# 3. Sync daily market data
|
|
||||||
# print("\n[3/7] Syncing daily market data...")
|
|
||||||
# try:
|
|
||||||
# # 确保表存在
|
|
||||||
# from src.data.api_wrappers.api_daily import DailySync
|
|
||||||
#
|
|
||||||
# DailySync().ensure_table_exists()
|
|
||||||
#
|
|
||||||
# daily_result = sync_daily(
|
|
||||||
# force_full=force_full,
|
|
||||||
# max_workers=max_workers,
|
|
||||||
# dry_run=dry_run,
|
|
||||||
# )
|
|
||||||
# results["daily"] = daily_result
|
|
||||||
# total_daily_records = (
|
|
||||||
# sum(len(df) for df in daily_result.values()) if daily_result else 0
|
|
||||||
# )
|
|
||||||
# print(
|
|
||||||
# f"[3/7] Daily data: OK ({total_daily_records} records from {len(daily_result)} stocks)"
|
|
||||||
# )
|
|
||||||
# except Exception as e:
|
|
||||||
# print(f"[3/7] Daily data: FAILED - {e}")
|
|
||||||
# results["daily"] = pd.DataFrame()
|
|
||||||
|
|
||||||
# 4. Sync Pro Bar data
|
|
||||||
print("\n[4/7] Syncing Pro Bar data (with adj, tor, vr)...")
|
|
||||||
try:
|
|
||||||
# 确保表存在
|
|
||||||
from src.data.api_wrappers.api_pro_bar import ProBarSync
|
|
||||||
|
|
||||||
ProBarSync().ensure_table_exists()
|
|
||||||
|
|
||||||
pro_bar_result = sync_pro_bar(
|
|
||||||
force_full=force_full,
|
|
||||||
max_workers=max_workers,
|
|
||||||
dry_run=dry_run,
|
|
||||||
)
|
|
||||||
results["pro_bar"] = pro_bar_result
|
|
||||||
total_pro_bar_records = (
|
|
||||||
sum(len(df) for df in pro_bar_result.values()) if pro_bar_result else 0
|
|
||||||
)
|
|
||||||
print(
|
|
||||||
f"[4/7] Pro Bar data: OK ({total_pro_bar_records} records from {len(pro_bar_result)} stocks)"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[4/7] Pro Bar data: FAILED - {e}")
|
|
||||||
results["pro_bar"] = pd.DataFrame()
|
|
||||||
|
|
||||||
# 5. Sync daily basic indicators
|
|
||||||
print(
|
|
||||||
"\n[5/7] Syncing daily basic indicators (PE, PB, turnover rate, market value)..."
|
|
||||||
)
|
)
|
||||||
try:
|
|
||||||
# 确保表存在
|
|
||||||
from src.data.api_wrappers.api_daily_basic import DailyBasicSync
|
|
||||||
|
|
||||||
DailyBasicSync().ensure_table_exists()
|
|
||||||
|
|
||||||
daily_basic_result = sync_daily_basic(force_full=force_full, dry_run=dry_run)
|
def list_sync_tasks() -> list[dict[str, Any]]:
|
||||||
results["daily_basic"] = daily_basic_result
|
"""列出所有已注册的同步任务。
|
||||||
print(f"[5/7] Daily basic: OK ({len(daily_basic_result)} records)")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[5/7] Daily basic: FAILED - {e}")
|
|
||||||
results["daily_basic"] = pd.DataFrame()
|
|
||||||
|
|
||||||
# 6. Sync stock historical list (bak_basic)
|
Returns:
|
||||||
print("\n[6/7] Syncing stock historical list (bak_basic)...")
|
任务信息列表,每个任务包含 name, display_name, description, order, enabled
|
||||||
try:
|
|
||||||
# 确保表存在
|
|
||||||
from src.data.api_wrappers.api_bak_basic import BakBasicSync
|
|
||||||
|
|
||||||
BakBasicSync().ensure_table_exists()
|
Example:
|
||||||
|
>>> tasks = list_sync_tasks()
|
||||||
bak_basic_result = sync_bak_basic(force_full=force_full)
|
>>> for task in tasks:
|
||||||
results["bak_basic"] = bak_basic_result
|
... print(f"{task['order']:2d}. {task['name']}: {task['display_name']}")
|
||||||
print(f"[6/7] Bak basic: OK ({len(bak_basic_result)} records)")
|
"""
|
||||||
except Exception as e:
|
tasks = sync_registry.list_tasks()
|
||||||
print(f"[6/7] Bak basic: FAILED - {e}")
|
return [
|
||||||
results["bak_basic"] = pd.DataFrame()
|
{
|
||||||
|
"name": t.name,
|
||||||
# 7. Sync ST stock list
|
"display_name": t.display_name,
|
||||||
print("\n[7/7] Syncing ST stock list...")
|
"description": t.description,
|
||||||
try:
|
"order": t.order,
|
||||||
# 确保表存在
|
"enabled": t.enabled,
|
||||||
from src.data.api_wrappers.api_stock_st import StockSTSync
|
}
|
||||||
|
for t in tasks
|
||||||
StockSTSync().ensure_table_exists()
|
]
|
||||||
|
|
||||||
stock_st_result = sync_stock_st(force_full=force_full)
|
|
||||||
results["stock_st"] = stock_st_result
|
|
||||||
print(f"[7/7] ST stock list: OK ({len(stock_st_result)} records)")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[7/7] ST stock list: FAILED - {e}")
|
|
||||||
results["stock_st"] = pd.DataFrame()
|
|
||||||
|
|
||||||
# Summary
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("[sync_all_data] Sync Summary")
|
|
||||||
print("=" * 60)
|
|
||||||
for data_type, data in results.items():
|
|
||||||
if isinstance(data, dict):
|
|
||||||
# 日线和 Pro Bar 返回的是 dict[str, DataFrame]
|
|
||||||
total_records = sum(len(df) for df in data.values())
|
|
||||||
print(f" {data_type}: {len(data)} stocks, {total_records} total records")
|
|
||||||
else:
|
|
||||||
# daily_basic 和 bak_basic 返回的是 DataFrame
|
|
||||||
print(f" {data_type}: {len(data)} records")
|
|
||||||
print("=" * 60)
|
|
||||||
print("\nNote: namechange is NOT in auto-sync. To sync manually:")
|
|
||||||
print(" from src.data.api_wrappers import sync_namechange")
|
|
||||||
print(" sync_namechange(force=True)")
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print("Data Sync Module")
|
print("Data Sync Module")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
print("\nRegistered sync tasks:")
|
||||||
|
print("-" * 60)
|
||||||
|
|
||||||
|
tasks = list_sync_tasks()
|
||||||
|
for task in tasks:
|
||||||
|
status = "[启用]" if task["enabled"] else "[禁用]"
|
||||||
|
print(f" {status} {task['order']:2d}. {task['name']}: {task['display_name']}")
|
||||||
|
|
||||||
|
print("-" * 60)
|
||||||
|
print(f"\nTotal: {len(tasks)} tasks")
|
||||||
print("\nUsage:")
|
print("\nUsage:")
|
||||||
print(" # Sync all data types at once (RECOMMENDED)")
|
print(" # Sync all data types at once (RECOMMENDED)")
|
||||||
print(" from src.data.sync import sync_all_data")
|
print(" from src.data.sync import sync_all_data")
|
||||||
print(" result = sync_all_data() # Incremental sync all")
|
print(" result = sync_all_data() # Incremental sync all")
|
||||||
print(" result = sync_all_data(force_full=True) # Full reload")
|
print(" result = sync_all_data(force_full=True) # Full reload")
|
||||||
print("")
|
print("")
|
||||||
print(" # Or sync individual data types:")
|
print(" # Sync selected data types only")
|
||||||
print(" from src.data.sync import sync_all, preview_sync")
|
print(" result = sync_all_data(selected=['trade_cal', 'pro_bar'])")
|
||||||
print(" from src.data.api_wrappers import sync_daily_basic, sync_bak_basic")
|
print("")
|
||||||
|
print(" # List all available sync tasks")
|
||||||
|
print(" tasks = list_sync_tasks()")
|
||||||
print("")
|
print("")
|
||||||
print(" # Preview before sync (recommended)")
|
print(" # Preview before sync (recommended)")
|
||||||
print(" preview = preview_sync()")
|
print(" preview = preview_sync()")
|
||||||
@@ -356,10 +282,6 @@ if __name__ == "__main__":
|
|||||||
print(" # Actual sync")
|
print(" # Actual sync")
|
||||||
print(" result = sync_all() # Incremental sync")
|
print(" result = sync_all() # Incremental sync")
|
||||||
print(" result = sync_all(force_full=True) # Full reload")
|
print(" result = sync_all(force_full=True) # Full reload")
|
||||||
print("")
|
|
||||||
print(" # bak_basic sync")
|
|
||||||
print(" result = sync_bak_basic() # Incremental sync")
|
|
||||||
print(" result = sync_bak_basic(force_full=True) # Full reload")
|
|
||||||
print("\n" + "=" * 60)
|
print("\n" + "=" * 60)
|
||||||
|
|
||||||
# Run sync_all_data by default
|
# Run sync_all_data by default
|
||||||
|
|||||||
333
src/data/sync_registry.py
Normal file
333
src/data/sync_registry.py
Normal file
@@ -0,0 +1,333 @@
|
|||||||
|
"""数据同步注册表模块。
|
||||||
|
|
||||||
|
该模块提供统一的同步任务注册和管理机制,避免在 sync.py 中手动罗列各个接口。
|
||||||
|
|
||||||
|
使用方式:
|
||||||
|
# 在 api_xxx.py 文件中注册同步任务
|
||||||
|
from src.data.sync_registry import sync_registry, SyncTask
|
||||||
|
|
||||||
|
sync_registry.register(
|
||||||
|
SyncTask(
|
||||||
|
name="pro_bar",
|
||||||
|
display_name="Pro Bar 数据",
|
||||||
|
description="包含复权因子、换手率、量比的数据",
|
||||||
|
sync_func=lambda **kwargs: ProBarSync().sync_all(**kwargs),
|
||||||
|
preview_func=lambda **kwargs: ProBarSync().preview_sync(**kwargs),
|
||||||
|
order=10, # 执行顺序
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 在 sync.py 中统一执行
|
||||||
|
from src.data.sync_registry import sync_registry
|
||||||
|
results = sync_registry.sync_all(force_full=False, dry_run=False)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Callable, Optional, Any
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SyncTask:
|
||||||
|
"""同步任务定义。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
name: 任务唯一标识名
|
||||||
|
display_name: 显示名称(用于日志)
|
||||||
|
description: 任务描述
|
||||||
|
sync_func: 同步函数,接收 force_full, dry_run 等参数
|
||||||
|
preview_func: 预览函数(可选)
|
||||||
|
order: 执行顺序(数字越小越先执行,默认100)
|
||||||
|
enabled: 是否启用(默认True)
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
display_name: str
|
||||||
|
description: str
|
||||||
|
sync_func: Callable[..., pd.DataFrame | dict[str, pd.DataFrame]]
|
||||||
|
preview_func: Optional[Callable[..., Any]] = None
|
||||||
|
order: int = 100
|
||||||
|
enabled: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class SyncRegistry:
|
||||||
|
"""同步任务注册表。
|
||||||
|
|
||||||
|
统一管理所有数据同步任务,支持自动发现和批量执行。
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> registry = SyncRegistry()
|
||||||
|
>>>
|
||||||
|
>>> # 注册类方式同步器
|
||||||
|
>>> registry.register_class("daily", DailySync, "日线数据", "股票日线行情", order=10)
|
||||||
|
>>>
|
||||||
|
>>> # 注册函数方式同步器
|
||||||
|
>>> registry.register_func("stock_basic", sync_all_stocks, "股票基本信息", order=5)
|
||||||
|
>>>
|
||||||
|
>>> # 批量执行所有同步
|
||||||
|
>>> results = registry.sync_all()
|
||||||
|
>>>
|
||||||
|
>>> # 只执行特定任务
|
||||||
|
>>> results = registry.sync_selected(["stock_basic", "daily"])
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._tasks: OrderedDict[str, SyncTask] = OrderedDict()
|
||||||
|
|
||||||
|
def register(self, task: SyncTask) -> "SyncRegistry":
|
||||||
|
"""注册同步任务。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: 同步任务定义
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
self,支持链式调用
|
||||||
|
"""
|
||||||
|
if task.name in self._tasks:
|
||||||
|
print(
|
||||||
|
f"[SyncRegistry] Warning: Task '{task.name}' already registered, overwriting"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._tasks[task.name] = task
|
||||||
|
return self
|
||||||
|
|
||||||
|
def register_class(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
sync_class: type,
|
||||||
|
display_name: str,
|
||||||
|
description: str,
|
||||||
|
order: int = 100,
|
||||||
|
) -> "SyncRegistry":
|
||||||
|
"""注册基于类的同步器。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 任务名
|
||||||
|
sync_class: 同步器类(必须有 sync_all() 和 ensure_table_exists() 方法)
|
||||||
|
display_name: 显示名称
|
||||||
|
description: 描述
|
||||||
|
order: 执行顺序
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
self,支持链式调用
|
||||||
|
"""
|
||||||
|
|
||||||
|
def sync_func(**kwargs) -> dict[str, pd.DataFrame]:
|
||||||
|
instance = sync_class()
|
||||||
|
instance.ensure_table_exists()
|
||||||
|
return instance.sync_all(**kwargs)
|
||||||
|
|
||||||
|
def preview_func(**kwargs) -> Any:
|
||||||
|
instance = sync_class()
|
||||||
|
return instance.preview_sync(**kwargs)
|
||||||
|
|
||||||
|
return self.register(
|
||||||
|
SyncTask(
|
||||||
|
name=name,
|
||||||
|
display_name=display_name,
|
||||||
|
description=description,
|
||||||
|
sync_func=sync_func,
|
||||||
|
preview_func=preview_func,
|
||||||
|
order=order,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def register_func(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
sync_func: Callable[..., pd.DataFrame],
|
||||||
|
display_name: str,
|
||||||
|
description: str = "",
|
||||||
|
order: int = 100,
|
||||||
|
) -> "SyncRegistry":
|
||||||
|
"""注册基于函数的同步器。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 任务名
|
||||||
|
sync_func: 同步函数
|
||||||
|
display_name: 显示名称
|
||||||
|
description: 描述
|
||||||
|
order: 执行顺序
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
self,支持链式调用
|
||||||
|
"""
|
||||||
|
return self.register(
|
||||||
|
SyncTask(
|
||||||
|
name=name,
|
||||||
|
display_name=display_name,
|
||||||
|
description=description,
|
||||||
|
sync_func=sync_func,
|
||||||
|
order=order,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_task(self, name: str) -> Optional[SyncTask]:
|
||||||
|
"""获取指定任务。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 任务名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SyncTask 或 None
|
||||||
|
"""
|
||||||
|
return self._tasks.get(name)
|
||||||
|
|
||||||
|
def list_tasks(self, enabled_only: bool = True) -> list[SyncTask]:
|
||||||
|
"""获取所有任务列表(按 order 排序)。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
enabled_only: 是否只返回启用的任务
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
排序后的任务列表
|
||||||
|
"""
|
||||||
|
tasks = self._tasks.values()
|
||||||
|
if enabled_only:
|
||||||
|
tasks = [t for t in tasks if t.enabled]
|
||||||
|
return sorted(tasks, key=lambda t: t.order)
|
||||||
|
|
||||||
|
def sync_all(
|
||||||
|
self,
|
||||||
|
force_full: bool = False,
|
||||||
|
dry_run: bool = False,
|
||||||
|
max_workers: Optional[int] = None,
|
||||||
|
selected: Optional[list[str]] = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""执行所有同步任务。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
force_full: 是否强制完整重载
|
||||||
|
dry_run: 是否仅预览
|
||||||
|
max_workers: 工作线程数(传递给支持的任务)
|
||||||
|
selected: 只执行指定的任务列表,None表示执行所有
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
每个任务的执行结果字典
|
||||||
|
"""
|
||||||
|
tasks = self.list_tasks(enabled_only=True)
|
||||||
|
|
||||||
|
if selected:
|
||||||
|
tasks = [t for t in tasks if t.name in selected]
|
||||||
|
|
||||||
|
total = len(tasks)
|
||||||
|
results: dict[str, Any] = {}
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("[SyncRegistry] Starting data synchronization...")
|
||||||
|
print(f"[SyncRegistry] Total tasks: {total}")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
for idx, task in enumerate(tasks, 1):
|
||||||
|
print(f"\n[{idx}/{total}] Syncing {task.display_name}...")
|
||||||
|
if task.description:
|
||||||
|
print(f" Description: {task.description}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 构建参数
|
||||||
|
kwargs: dict[str, Any] = {
|
||||||
|
"force_full": force_full,
|
||||||
|
"dry_run": dry_run,
|
||||||
|
}
|
||||||
|
if max_workers is not None:
|
||||||
|
kwargs["max_workers"] = max_workers
|
||||||
|
|
||||||
|
# 执行同步
|
||||||
|
result = task.sync_func(**kwargs)
|
||||||
|
results[task.name] = result
|
||||||
|
|
||||||
|
# 输出统计信息
|
||||||
|
if isinstance(result, dict):
|
||||||
|
# 返回 dict[str, DataFrame],如日线数据
|
||||||
|
total_records = sum(len(df) for df in result.values())
|
||||||
|
print(
|
||||||
|
f"[{idx}/{total}] {task.display_name}: OK ({len(result)} items, {total_records} records)"
|
||||||
|
)
|
||||||
|
elif isinstance(result, pd.DataFrame):
|
||||||
|
# 返回 DataFrame
|
||||||
|
print(
|
||||||
|
f"[{idx}/{total}] {task.display_name}: OK ({len(result)} records)"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(f"[{idx}/{total}] {task.display_name}: OK")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[{idx}/{total}] {task.display_name}: FAILED - {e}")
|
||||||
|
results[task.name] = pd.DataFrame()
|
||||||
|
|
||||||
|
# Summary
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("[SyncRegistry] Sync Summary")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
success_count = sum(
|
||||||
|
1
|
||||||
|
for name in results
|
||||||
|
if not (isinstance(results[name], pd.DataFrame) and results[name].empty)
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"Total tasks: {total}, Success: {success_count}, Failed: {total - success_count}"
|
||||||
|
)
|
||||||
|
|
||||||
|
for name, data in results.items():
|
||||||
|
task = self._tasks.get(name)
|
||||||
|
display_name = task.display_name if task else name
|
||||||
|
|
||||||
|
if isinstance(data, dict):
|
||||||
|
total_records = sum(len(df) for df in data.values())
|
||||||
|
print(f" {display_name}: {len(data)} items, {total_records} records")
|
||||||
|
elif isinstance(data, pd.DataFrame):
|
||||||
|
status = "OK" if not data.empty else "EMPTY/FAILED"
|
||||||
|
print(f" {display_name}: {len(data)} records ({status})")
|
||||||
|
else:
|
||||||
|
print(f" {display_name}: Completed")
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def preview_all(
|
||||||
|
self,
|
||||||
|
force_full: bool = False,
|
||||||
|
selected: Optional[list[str]] = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""预览所有启用的任务。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
force_full: 是否预览完整重载
|
||||||
|
selected: 只预览指定的任务列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
每个任务的预览结果
|
||||||
|
"""
|
||||||
|
tasks = self.list_tasks(enabled_only=True)
|
||||||
|
|
||||||
|
if selected:
|
||||||
|
tasks = [t for t in tasks if t.name in selected]
|
||||||
|
|
||||||
|
results: dict[str, Any] = {}
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("[SyncRegistry] Previewing sync tasks...")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
for task in tasks:
|
||||||
|
if task.preview_func is None:
|
||||||
|
print(f"\n[{task.display_name}] Preview not supported")
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"\n[{task.display_name}] Previewing...")
|
||||||
|
try:
|
||||||
|
result = task.preview_func(force_full=force_full)
|
||||||
|
results[task.name] = result
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[{task.display_name}] Preview failed: {e}")
|
||||||
|
results[task.name] = None
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
# Global registry instance
|
||||||
|
sync_registry = SyncRegistry()
|
||||||
1257
src/experiment/regression.ipynb
Normal file
1257
src/experiment/regression.ipynb
Normal file
File diff suppressed because one or more lines are too long
@@ -1,7 +1,7 @@
|
|||||||
"""LightGBM 回归训练示例 - 使用因子字符串表达式
|
"""LightGBM 回归训练示例 - 使用因子字符串表达式
|
||||||
|
|
||||||
使用字符串表达式定义因子,训练 LightGBM 回归模型预测未来5日收益率。
|
使用字符串表达式定义因子,训练 LightGBM 回归模型预测未来5日收益率。
|
||||||
Label: return_5 = (close / ts_delay(close, 5)) - 1
|
Label: return_5 = (ts_delay(close, -5) / close) - 1 # 未来5日收益率
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@@ -20,6 +20,7 @@ from src.training import (
|
|||||||
StockPoolManager,
|
StockPoolManager,
|
||||||
Trainer,
|
Trainer,
|
||||||
Winsorizer,
|
Winsorizer,
|
||||||
|
NullFiller,
|
||||||
)
|
)
|
||||||
from src.training.config import TrainingConfig
|
from src.training.config import TrainingConfig
|
||||||
|
|
||||||
@@ -56,7 +57,7 @@ FACTOR_DEFINITIONS = {
|
|||||||
|
|
||||||
# Label 因子定义(不参与训练,用于计算目标)
|
# Label 因子定义(不参与训练,用于计算目标)
|
||||||
LABEL_FACTOR = {
|
LABEL_FACTOR = {
|
||||||
"return_5": "(close / ts_delay(close, 5)) - 1",
|
"return_5": "(ts_delay(close, -5) / close) - 1", # 未来5日收益率
|
||||||
}
|
}
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -224,6 +225,7 @@ def train_regression_model():
|
|||||||
|
|
||||||
# 6. 创建数据处理器(从 PROCESSOR_CONFIGS 解析)
|
# 6. 创建数据处理器(从 PROCESSOR_CONFIGS 解析)
|
||||||
processors = [
|
processors = [
|
||||||
|
NullFiller(strategy="mean"),
|
||||||
Winsorizer(**PROCESSOR_CONFIGS[0]["params"]), # type: ignore[arg-type]
|
Winsorizer(**PROCESSOR_CONFIGS[0]["params"]), # type: ignore[arg-type]
|
||||||
StandardScaler(exclude_cols=["ts_code", "trade_date", target_col]), # type: ignore[call-arg]
|
StandardScaler(exclude_cols=["ts_code", "trade_date", target_col]), # type: ignore[call-arg]
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from src.training.components.selectors import (
|
|||||||
# 数据处理器
|
# 数据处理器
|
||||||
from src.training.components.processors import (
|
from src.training.components.processors import (
|
||||||
CrossSectionalStandardScaler,
|
CrossSectionalStandardScaler,
|
||||||
|
NullFiller,
|
||||||
StandardScaler,
|
StandardScaler,
|
||||||
Winsorizer,
|
Winsorizer,
|
||||||
)
|
)
|
||||||
@@ -57,6 +58,7 @@ __all__ = [
|
|||||||
"StockFilterConfig",
|
"StockFilterConfig",
|
||||||
"MarketCapSelectorConfig",
|
"MarketCapSelectorConfig",
|
||||||
# 数据处理器
|
# 数据处理器
|
||||||
|
"NullFiller",
|
||||||
"StandardScaler",
|
"StandardScaler",
|
||||||
"CrossSectionalStandardScaler",
|
"CrossSectionalStandardScaler",
|
||||||
"Winsorizer",
|
"Winsorizer",
|
||||||
|
|||||||
@@ -5,11 +5,13 @@
|
|||||||
|
|
||||||
from src.training.components.processors.transforms import (
|
from src.training.components.processors.transforms import (
|
||||||
CrossSectionalStandardScaler,
|
CrossSectionalStandardScaler,
|
||||||
|
NullFiller,
|
||||||
StandardScaler,
|
StandardScaler,
|
||||||
Winsorizer,
|
Winsorizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"NullFiller",
|
||||||
"StandardScaler",
|
"StandardScaler",
|
||||||
"CrossSectionalStandardScaler",
|
"CrossSectionalStandardScaler",
|
||||||
"Winsorizer",
|
"Winsorizer",
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
"""数据处理器实现
|
"""数据处理器实现
|
||||||
|
|
||||||
包含标准化、缩尾等数据处理器。
|
包含标准化、缩尾、缺失值填充等数据处理器。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Literal, Optional, Union
|
||||||
|
|
||||||
import polars as pl
|
import polars as pl
|
||||||
|
|
||||||
@@ -11,6 +11,204 @@ from src.training.components.base import BaseProcessor
|
|||||||
from src.training.registry import register_processor
|
from src.training.registry import register_processor
|
||||||
|
|
||||||
|
|
||||||
|
@register_processor("null_filler")
|
||||||
|
class NullFiller(BaseProcessor):
|
||||||
|
"""缺失值填充处理器
|
||||||
|
|
||||||
|
支持多种填充策略:固定值、0、均值、中值。
|
||||||
|
可以全局填充或使用当天截面统计量填充。
|
||||||
|
|
||||||
|
填充策略:
|
||||||
|
- "zero": 填充0
|
||||||
|
- "mean": 填充均值(全局或当天截面)
|
||||||
|
- "median": 填充中值(全局或当天截面)
|
||||||
|
- "value": 填充指定数值
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
strategy: 填充策略,可选 "zero", "mean", "median", "value"
|
||||||
|
fill_value: 当 strategy="value" 时使用的填充值
|
||||||
|
by_date: 是否按日期独立计算统计量(仅对 mean/median 有效)
|
||||||
|
date_col: 日期列名
|
||||||
|
exclude_cols: 不参与填充的列名列表
|
||||||
|
stats_: 存储学习到的统计量(全局模式)
|
||||||
|
"""
|
||||||
|
|
||||||
|
name = "null_filler"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
strategy: Literal["zero", "mean", "median", "value"] = "zero",
|
||||||
|
fill_value: Optional[float] = None,
|
||||||
|
by_date: bool = True,
|
||||||
|
date_col: str = "trade_date",
|
||||||
|
exclude_cols: Optional[List[str]] = None,
|
||||||
|
):
|
||||||
|
"""初始化缺失值填充处理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
strategy: 填充策略,默认 "zero"
|
||||||
|
- "zero": 填充0
|
||||||
|
- "mean": 填充均值
|
||||||
|
- "median": 填充中值
|
||||||
|
- "value": 填充指定数值(需配合 fill_value)
|
||||||
|
fill_value: 当 strategy="value" 时的填充值,默认为 None
|
||||||
|
by_date: 是否每天独立计算统计量,默认 True(按当天截面计算);设为 False 时使用全局统计量
|
||||||
|
date_col: 日期列名,默认 "trade_date"
|
||||||
|
exclude_cols: 不参与填充的列名列表,默认为 ["ts_code", "trade_date"]
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 策略无效或 fill_value 未提供时
|
||||||
|
"""
|
||||||
|
if strategy not in ("zero", "mean", "median", "value"):
|
||||||
|
raise ValueError(
|
||||||
|
f"无效的填充策略: {strategy},必须是 'zero', 'mean', 'median', 'value' 之一"
|
||||||
|
)
|
||||||
|
|
||||||
|
if strategy == "value" and fill_value is None:
|
||||||
|
raise ValueError("当 strategy='value' 时,必须提供 fill_value")
|
||||||
|
|
||||||
|
self.strategy = strategy
|
||||||
|
self.fill_value = fill_value
|
||||||
|
self.by_date = by_date
|
||||||
|
self.date_col = date_col
|
||||||
|
self.exclude_cols = exclude_cols or ["ts_code", "trade_date"]
|
||||||
|
self.stats_: dict = {}
|
||||||
|
|
||||||
|
def fit(self, X: pl.DataFrame) -> "NullFiller":
    """Learn the per-column fill values (global mode only).

    With ``by_date=False`` and a "mean"/"median" strategy, the mean or median
    of every numeric column not listed in ``exclude_cols`` is computed and
    stored in ``stats_``. In every other configuration nothing is learned:
    "zero"/"value" need no statistics, and the per-date mode computes them
    at transform time.

    Args:
        X: Training data.

    Returns:
        self
    """
    # Start from a clean slate so that re-fitting never keeps statistics for
    # columns that only existed in a previously seen training frame.
    self.stats_ = {}

    if self.by_date or self.strategy not in ("mean", "median"):
        return self

    use_mean = self.strategy == "mean"
    for col in X.columns:
        if col in self.exclude_cols or not X[col].dtype.is_numeric():
            continue

        series = X[col]
        stat = series.mean() if use_mean else series.median()
        # An empty or all-null column has no mean/median (polars returns
        # None). Storing None would make the later fill_null(None) call
        # raise, so such columns are simply left without a statistic.
        if stat is not None:
            self.stats_[col] = stat

    return self
|
||||||
|
|
||||||
|
def transform(self, X: pl.DataFrame) -> pl.DataFrame:
    """Fill missing values according to the configured strategy.

    Args:
        X: Data to transform.

    Returns:
        The data with nulls filled.

    Raises:
        ValueError: If ``self.strategy`` is not a known strategy. The
            constructor validates it, so this only triggers when the
            attribute was changed afterwards.
    """
    strategy = self.strategy

    if strategy == "zero":
        return self._fill_with_zero(X)
    if strategy == "value":
        return self._fill_with_value(X)
    if strategy not in ("mean", "median"):
        raise ValueError(f"未知的填充策略: {self.strategy}")

    # mean / median: per-date cross-section or globally learned statistic.
    filler = self._fill_by_date if self.by_date else self._fill_global
    return filler(X)
|
||||||
|
|
||||||
|
def _fill_with_zero(self, X: pl.DataFrame) -> pl.DataFrame:
    """Replace nulls with 0 in every numeric column that is not excluded."""
    skipped = self.exclude_cols
    targets = {
        name
        for name in X.columns
        if name not in skipped and X[name].dtype.is_numeric()
    }

    # Select every column in its original order; only target columns are
    # rewritten, the rest pass through unchanged.
    return X.select(
        [
            pl.col(name).fill_null(0).alias(name)
            if name in targets
            else pl.col(name)
            for name in X.columns
        ]
    )
|
||||||
|
|
||||||
|
def _fill_with_value(self, X: pl.DataFrame) -> pl.DataFrame:
    """Replace nulls with ``self.fill_value`` in every non-excluded numeric column."""
    skipped = self.exclude_cols
    targets = {
        name
        for name in X.columns
        if name not in skipped and X[name].dtype.is_numeric()
    }
    constant = self.fill_value

    # Select every column in its original order; only target columns are
    # rewritten, the rest pass through unchanged.
    return X.select(
        [
            pl.col(name).fill_null(constant).alias(name)
            if name in targets
            else pl.col(name)
            for name in X.columns
        ]
    )
|
||||||
|
|
||||||
|
def _fill_global(self, X: pl.DataFrame) -> pl.DataFrame:
    """Fill nulls with the global statistics stored in ``self.stats_``.

    Columns without a stored statistic are passed through unchanged.

    Args:
        X: Data to transform.

    Returns:
        The data with nulls filled, columns in their original order.
    """
    expressions = []
    for col in X.columns:
        fill_val = self.stats_.get(col)
        if fill_val is None:
            # Either the column was never fitted, or its statistic is None
            # (empty / all-null training column). fill_null(None) raises in
            # polars, so the column is left as it is.
            expressions.append(pl.col(col))
        else:
            expressions.append(pl.col(col).fill_null(fill_val).alias(col))

    return X.select(expressions)
|
||||||
|
|
||||||
|
def _fill_by_date(self, X: pl.DataFrame) -> pl.DataFrame:
    """Fill nulls with the per-date cross-sectional mean or median.

    For every numeric column not listed in ``exclude_cols``, nulls are
    replaced by that column's mean (or median, depending on
    ``self.strategy``) computed over the rows sharing the same
    ``self.date_col`` value. A group whose values are all null has no
    statistic, so its nulls remain.

    Args:
        X: Data to transform; must contain ``self.date_col``.

    Returns:
        The data with nulls filled, columns in their original order.
    """
    numeric_cols = {
        c
        for c in X.columns
        if c not in self.exclude_cols and X[c].dtype.is_numeric()
    }
    use_mean = self.strategy == "mean"

    def _daily_stat(col: str) -> pl.Expr:
        """Window expression: mean/median of ``col`` within each date."""
        base = pl.col(col)
        stat = base.mean() if use_mean else base.median()
        return stat.over(self.date_col)

    # The statistic is computed inline instead of being materialised as a
    # temporary "<col>_stat" column: a temporary column would overwrite (and
    # thereby corrupt) any real input column that happens to carry that name.
    return X.select(
        [
            pl.col(c).fill_null(_daily_stat(c)).alias(c)
            if c in numeric_cols
            else pl.col(c)
            for c in X.columns
        ]
    )
|
||||||
|
|
||||||
|
|
||||||
@register_processor("standard_scaler")
|
@register_processor("standard_scaler")
|
||||||
class StandardScaler(BaseProcessor):
|
class StandardScaler(BaseProcessor):
|
||||||
"""标准化处理器(全局标准化)
|
"""标准化处理器(全局标准化)
|
||||||
|
|||||||
Reference in New Issue
Block a user