feat(experiment): 添加模型保存功能及因子信息持久化

- 新增 SAVE_MODEL 配置控制是否保存模型
- 新增 get_model_save_path() 生成模型保存路径
- 新增 save_model_with_factors() 保存模型及关联因子信息
- 新增 load_model_factors() 加载因子信息用于模型复现
- 更新训练脚本使用新的模型保存方式
- 清理 data/sync.py 中的废弃代码
This commit is contained in:
2026-03-16 22:50:47 +08:00
parent 5ed06d20d2
commit 16f82d3458
5 changed files with 163 additions and 119 deletions

View File

@@ -55,104 +55,6 @@ import pandas as pd
from src.data import api_wrappers # noqa: F401
from src.data.sync_registry import sync_registry
from src.data.api_wrappers import sync_all_stocks
from src.data.api_wrappers.api_daily import sync_daily, preview_daily_sync
def preview_sync(
force_full: bool = False,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
sample_size: int = 3,
max_workers: Optional[int] = None,
) -> dict[str, Any]:
"""预览日线同步数据量和样本(不实际同步)。
这是推荐的方式,可在实际同步前检查将要同步的内容。
Args:
force_full: 若为 True预览全量同步从 20180101
start_date: 手动指定起始日期(覆盖自动检测)
end_date: 手动指定结束日期(默认为今天)
sample_size: 预览用样本股票数量(默认: 3
max_workers: 工作线程数(默认: 10
Returns:
包含预览信息的字典:
{
'sync_needed': bool,
'stock_count': int,
'start_date': str,
'end_date': str,
'estimated_records': int,
'sample_data': pd.DataFrame,
'mode': str, # 'full', 'incremental', 'partial', 或 'none'
}
Example:
>>> # 预览将要同步的内容
>>> preview = preview_sync()
>>>
>>> # 预览全量同步
>>> preview = preview_sync(force_full=True)
>>>
>>> # 预览更多样本
>>> preview = preview_sync(sample_size=5)
"""
return preview_daily_sync(
force_full=force_full,
start_date=start_date,
end_date=end_date,
sample_size=sample_size,
)
def sync_all(
force_full: bool = False,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
max_workers: Optional[int] = None,
dry_run: bool = False,
) -> dict[str, pd.DataFrame]:
"""同步所有股票的日线数据。
这是日线数据同步的主要入口点。
Args:
force_full: 若为 True强制从 20180101 完整重载
start_date: 手动指定起始日期YYYYMMDD
end_date: 手动指定结束日期(默认为今天)
max_workers: 工作线程数(默认: 10
dry_run: 若为 True仅预览将要同步的内容不写入数据
Returns:
映射 ts_code 到 DataFrame 的字典
Example:
>>> # 首次同步(从 20180101 全量加载)
>>> result = sync_all()
>>>
>>> # 后续同步(增量 - 仅新数据)
>>> result = sync_all()
>>>
>>> # 强制完整重载
>>> result = sync_all(force_full=True)
>>>
>>> # 手动指定日期范围
>>> result = sync_all(start_date='20240101', end_date='20240131')
>>>
>>> # 自定义线程数
>>> result = sync_all(max_workers=20)
>>>
>>> # Dry run仅预览
>>> result = sync_all(dry_run=True)
"""
return sync_daily(
force_full=force_full,
start_date=start_date,
end_date=end_date,
max_workers=max_workers,
dry_run=dry_run,
)
def sync_all_data(