feat(experiment): 添加模型保存功能及因子信息持久化
- 新增 SAVE_MODEL 配置控制是否保存模型 - 新增 get_model_save_path() 生成模型保存路径 - 新增 save_model_with_factors() 保存模型及关联因子信息 - 新增 load_model_factors() 加载因子信息用于模型复现 - 更新训练脚本使用新的模型保存方式 - 清理 data/sync.py 中的废弃代码
This commit is contained in:
@@ -55,104 +55,6 @@ import pandas as pd
|
||||
from src.data import api_wrappers # noqa: F401
|
||||
from src.data.sync_registry import sync_registry
|
||||
from src.data.api_wrappers import sync_all_stocks
|
||||
from src.data.api_wrappers.api_daily import sync_daily, preview_daily_sync
|
||||
|
||||
|
||||
def preview_sync(
|
||||
force_full: bool = False,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
sample_size: int = 3,
|
||||
max_workers: Optional[int] = None,
|
||||
) -> dict[str, Any]:
|
||||
"""预览日线同步数据量和样本(不实际同步)。
|
||||
|
||||
这是推荐的方式,可在实际同步前检查将要同步的内容。
|
||||
|
||||
Args:
|
||||
force_full: 若为 True,预览全量同步(从 20180101)
|
||||
start_date: 手动指定起始日期(覆盖自动检测)
|
||||
end_date: 手动指定结束日期(默认为今天)
|
||||
sample_size: 预览用样本股票数量(默认: 3)
|
||||
max_workers: 工作线程数(默认: 10)
|
||||
|
||||
Returns:
|
||||
包含预览信息的字典:
|
||||
{
|
||||
'sync_needed': bool,
|
||||
'stock_count': int,
|
||||
'start_date': str,
|
||||
'end_date': str,
|
||||
'estimated_records': int,
|
||||
'sample_data': pd.DataFrame,
|
||||
'mode': str, # 'full', 'incremental', 'partial', 或 'none'
|
||||
}
|
||||
|
||||
Example:
|
||||
>>> # 预览将要同步的内容
|
||||
>>> preview = preview_sync()
|
||||
>>>
|
||||
>>> # 预览全量同步
|
||||
>>> preview = preview_sync(force_full=True)
|
||||
>>>
|
||||
>>> # 预览更多样本
|
||||
>>> preview = preview_sync(sample_size=5)
|
||||
"""
|
||||
return preview_daily_sync(
|
||||
force_full=force_full,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
sample_size=sample_size,
|
||||
)
|
||||
|
||||
|
||||
def sync_all(
|
||||
force_full: bool = False,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
max_workers: Optional[int] = None,
|
||||
dry_run: bool = False,
|
||||
) -> dict[str, pd.DataFrame]:
|
||||
"""同步所有股票的日线数据。
|
||||
|
||||
这是日线数据同步的主要入口点。
|
||||
|
||||
Args:
|
||||
force_full: 若为 True,强制从 20180101 完整重载
|
||||
start_date: 手动指定起始日期(YYYYMMDD)
|
||||
end_date: 手动指定结束日期(默认为今天)
|
||||
max_workers: 工作线程数(默认: 10)
|
||||
dry_run: 若为 True,仅预览将要同步的内容,不写入数据
|
||||
|
||||
Returns:
|
||||
映射 ts_code 到 DataFrame 的字典
|
||||
|
||||
Example:
|
||||
>>> # 首次同步(从 20180101 全量加载)
|
||||
>>> result = sync_all()
|
||||
>>>
|
||||
>>> # 后续同步(增量 - 仅新数据)
|
||||
>>> result = sync_all()
|
||||
>>>
|
||||
>>> # 强制完整重载
|
||||
>>> result = sync_all(force_full=True)
|
||||
>>>
|
||||
>>> # 手动指定日期范围
|
||||
>>> result = sync_all(start_date='20240101', end_date='20240131')
|
||||
>>>
|
||||
>>> # 自定义线程数
|
||||
>>> result = sync_all(max_workers=20)
|
||||
>>>
|
||||
>>> # Dry run(仅预览)
|
||||
>>> result = sync_all(dry_run=True)
|
||||
"""
|
||||
return sync_daily(
|
||||
force_full=force_full,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
max_workers=max_workers,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
|
||||
|
||||
def sync_all_data(
|
||||
|
||||
Reference in New Issue
Block a user