feat(experiment): 添加模型保存功能及因子信息持久化

- 新增 SAVE_MODEL 配置控制是否保存模型
- 新增 get_model_save_path() 生成模型保存路径
- 新增 save_model_with_factors() 保存模型及关联因子信息
- 新增 load_model_factors() 加载因子信息用于模型复现
- 更新训练脚本使用新的模型保存方式
- 清理 data/sync.py 中的废弃代码
This commit is contained in:
2026-03-16 22:50:47 +08:00
parent 5ed06d20d2
commit 16f82d3458
5 changed files with 163 additions and 119 deletions

1
.gitignore vendored
View File

@@ -83,3 +83,4 @@ src/training/output/*
# AI Agent 工作目录 # AI Agent 工作目录
/.sisyphus/ /.sisyphus/
/src/experiment/output/ /src/experiment/output/
/src/experiment/models/

View File

@@ -55,104 +55,6 @@ import pandas as pd
from src.data import api_wrappers # noqa: F401 from src.data import api_wrappers # noqa: F401
from src.data.sync_registry import sync_registry from src.data.sync_registry import sync_registry
from src.data.api_wrappers import sync_all_stocks from src.data.api_wrappers import sync_all_stocks
from src.data.api_wrappers.api_daily import sync_daily, preview_daily_sync
def preview_sync(
    force_full: bool = False,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    sample_size: int = 3,
    max_workers: Optional[int] = None,
) -> dict[str, Any]:
    """Preview the daily-bar sync (volume and sample rows) without syncing.

    Thin convenience wrapper: every argument except ``max_workers`` is
    forwarded unchanged to ``preview_daily_sync`` and its result is returned
    as-is. The parameter semantics below are carried over from this
    wrapper's original documentation; the actual handling lives in
    ``preview_daily_sync``, which is not visible from here.

    Args:
        force_full: If True, preview a full sync (documented as starting
            from 20180101).
        start_date: Manually chosen start date (documented as overriding
            auto-detection).
        end_date: Manually chosen end date (documented as defaulting to
            today).
        sample_size: Number of sample stocks to include in the preview.
        max_workers: Worker thread count.
            NOTE(review): accepted by this wrapper but NOT forwarded to
            ``preview_daily_sync`` -- confirm whether that is intentional.

    Returns:
        The dictionary produced by ``preview_daily_sync``. It is documented
        as containing the keys ``sync_needed`` (bool), ``stock_count`` (int),
        ``start_date`` (str), ``end_date`` (str), ``estimated_records``
        (int), ``sample_data`` (pd.DataFrame) and ``mode`` (one of
        ``'full'``, ``'incremental'``, ``'partial'`` or ``'none'``).

    Example:
        >>> # Preview what would be synced
        >>> preview = preview_sync()
        >>>
        >>> # Preview a full sync
        >>> preview = preview_sync(force_full=True)
        >>>
        >>> # Preview with more sample stocks
        >>> preview = preview_sync(sample_size=5)
    """
    # Gather the forwarded options first; max_workers is deliberately absent
    # because the original wrapper never passed it on.
    preview_options = {
        "force_full": force_full,
        "start_date": start_date,
        "end_date": end_date,
        "sample_size": sample_size,
    }
    return preview_daily_sync(**preview_options)
def sync_all(
    force_full: bool = False,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    max_workers: Optional[int] = None,
    dry_run: bool = False,
) -> dict[str, pd.DataFrame]:
    """Sync daily-bar data for all stocks.

    Thin convenience wrapper: all arguments are forwarded unchanged to
    ``sync_daily`` and its result is returned as-is. The parameter semantics
    below are carried over from this wrapper's original documentation; the
    actual handling lives in ``sync_daily``, which is not visible from here.

    Args:
        force_full: If True, force a complete reload (documented as starting
            from 20180101).
        start_date: Manually chosen start date in YYYYMMDD form.
        end_date: Manually chosen end date (documented as defaulting to
            today).
        max_workers: Worker thread count; ``None`` leaves the choice to
            ``sync_daily``.
        dry_run: If True, only preview what would be synced without writing
            any data.

    Returns:
        The mapping returned by ``sync_daily``, documented as
        ``ts_code -> DataFrame``.

    Example:
        >>> # First sync (documented as a full load from 20180101)
        >>> result = sync_all()
        >>>
        >>> # Subsequent sync (incremental - new data only)
        >>> result = sync_all()
        >>>
        >>> # Force a complete reload
        >>> result = sync_all(force_full=True)
        >>>
        >>> # Explicit date range
        >>> result = sync_all(start_date='20240101', end_date='20240131')
        >>>
        >>> # Custom thread count
        >>> result = sync_all(max_workers=20)
        >>>
        >>> # Dry run: preview only
        >>> result = sync_all(dry_run=True)
    """
    # Collect the options in one place, then hand them to the real worker.
    sync_options = {
        "force_full": force_full,
        "start_date": start_date,
        "end_date": end_date,
        "max_workers": max_workers,
        "dry_run": dry_run,
    }
    return sync_daily(**sync_options)
def sync_all_data( def sync_all_data(

View File

@@ -5,7 +5,7 @@
""" """
from datetime import datetime from datetime import datetime
from typing import List from typing import List, Optional
import polars as pl import polars as pl
@@ -255,8 +255,7 @@ SELECTED_FACTORS = [
] ]
# 因子定义字典完整因子库用于存放尚未注册到metadata的因子 # 因子定义字典完整因子库用于存放尚未注册到metadata的因子
FACTOR_DEFINITIONS = { FACTOR_DEFINITIONS = {}
}
def get_label_factor(label_name: str) -> dict: def get_label_factor(label_name: str) -> dict:
@@ -417,7 +416,10 @@ STOCK_FILTER_REQUIRED_COLUMNS = ["total_mv"]
# ============================================================================= # =============================================================================
OUTPUT_DIR = "output" OUTPUT_DIR = "output"
SAVE_PREDICTIONS = True SAVE_PREDICTIONS = True
PERSIST_MODEL = False
# 模型保存配置
SAVE_MODEL = True # 是否保存模型
MODEL_SAVE_DIR = "models" # 模型保存目录
# Top N 配置:每日推荐股票数量 # Top N 配置:每日推荐股票数量
TOP_N = 5 # 可调整为 10, 20 等 TOP_N = 5 # 可调整为 10, 20 等
@@ -446,3 +448,101 @@ def get_output_path(model_type: str, test_start: str, test_end: str) -> str:
filename = f"{model_type}_output.csv" filename = f"{model_type}_output.csv"
return os.path.join(OUTPUT_DIR, filename) return os.path.join(OUTPUT_DIR, filename)
def get_model_save_path(
    model_type: str, model_name: Optional[str] = None
) -> Optional[str]:
    """Build the file path a trained model should be saved to.

    When saving is enabled, the directory ``MODEL_SAVE_DIR`` is created as a
    side effect (if it does not exist yet) so the returned path is
    immediately writable.

    Args:
        model_type: Model type identifier (the training scripts pass
            ``"regression"`` or ``"rank"``).
        model_name: Base file name to use; when empty or ``None`` the
            ``model_type`` is used instead.

    Returns:
        ``<MODEL_SAVE_DIR>/<name>.pkl``, or ``None`` when ``SAVE_MODEL`` is
        falsy.
    """
    import os

    if SAVE_MODEL:
        # Make sure the target directory exists before handing out a path.
        os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
        # An empty/None model_name falls back to the model type.
        stem = model_name or model_type
        return os.path.join(MODEL_SAVE_DIR, f"{stem}.pkl")
    return None
def _factors_path_for(model_path: str) -> str:
    """Return the path of the ``.factors.json`` sidecar for ``model_path``.

    A trailing ``.pkl`` extension is swapped for ``.factors.json``; any other
    path simply gets ``.factors.json`` appended. Only the final extension is
    inspected, so a ``.pkl`` occurring elsewhere in the path (for example in
    a directory name) is left untouched, and the result can never equal
    ``model_path`` itself.
    """
    import os

    stem, suffix = os.path.splitext(model_path)
    if suffix == ".pkl":
        return f"{stem}.factors.json"
    return f"{model_path}.factors.json"


def save_model_with_factors(
    model,
    model_path: str,
    selected_factors: List[str],
    factor_definitions: dict,
) -> None:
    """Save a model together with the factor information it was trained on.

    Besides the model file itself, a sidecar ``.factors.json`` file is
    written next to it containing the selected factor names, the
    expression-defined factors and their counts, so that whoever loads the
    model later knows which factors it used.

    Args:
        model: Trained model instance; must expose ``save(path)``.
        model_path: Destination path of the model file.
        selected_factors: Factor names selected from metadata.
        factor_definitions: Expression-defined factors; must be
            JSON-serializable because it is dumped with ``json.dump``.
    """
    import json

    # 1. Persist the model itself.
    model.save(model_path)
    print(f"[模型保存] 模型已保存至: {model_path}")

    # 2. Persist the factor information to the sidecar file. The path is
    #    derived from the final extension only, so a model path without
    #    ".pkl" can no longer resolve to the model file and overwrite it.
    factors_path = _factors_path_for(model_path)
    factors_info = {
        "selected_factors": selected_factors,
        "factor_definitions": factor_definitions,
        "total_feature_count": len(selected_factors) + len(factor_definitions),
        "selected_factors_count": len(selected_factors),
        "factor_definitions_count": len(factor_definitions),
    }
    with open(factors_path, "w", encoding="utf-8") as f:
        json.dump(factors_info, f, ensure_ascii=False, indent=2)

    print(f"[模型保存] 因子信息已保存至: {factors_path}")
    print(f"[模型保存] 总计 {factors_info['total_feature_count']} 个因子")
    print(f" - 来自 metadata: {factors_info['selected_factors_count']}")
    print(f" - 来自表达式定义: {factors_info['factor_definitions_count']}")


def load_model_factors(model_path: str) -> Optional[dict]:
    """Load the factor information stored alongside a saved model.

    Args:
        model_path: Path of the model file (not of the sidecar file).

    Returns:
        The dictionary read from the sidecar ``.factors.json`` file, or
        ``None`` (after printing a warning) when that file does not exist.
    """
    import json

    factors_path = _factors_path_for(model_path)
    # EAFP: open directly instead of exists()-then-open, which avoids the
    # window in which the file could disappear between the two calls.
    try:
        with open(factors_path, "r", encoding="utf-8") as f:
            factors_info = json.load(f)
    except FileNotFoundError:
        print(f"[警告] 未找到因子信息文件: {factors_path}")
        return None

    print(
        f"[模型加载] 已加载因子信息,总计 {factors_info.get('total_feature_count', 'N/A')} 个因子"
    )
    return factors_info

View File

@@ -32,6 +32,7 @@ from src.training import (
NullFiller, NullFiller,
StandardScaler, StandardScaler,
check_data_quality, check_data_quality,
CrossSectionalStandardScaler,
) )
from src.training.components.models import LightGBMLambdaRankModel from src.training.components.models import LightGBMLambdaRankModel
from src.training.config import TrainingConfig from src.training.config import TrainingConfig
@@ -53,10 +54,15 @@ from src.experiment.common import (
STOCK_FILTER_REQUIRED_COLUMNS, STOCK_FILTER_REQUIRED_COLUMNS,
OUTPUT_DIR, OUTPUT_DIR,
SAVE_PREDICTIONS, SAVE_PREDICTIONS,
PERSIST_MODEL, SAVE_MODEL,
get_model_save_path,
save_model_with_factors,
TOP_N, TOP_N,
) )
# 训练类型标识
TRAINING_TYPE = "rank"
# %% md # %% md
# ## 2. 本地辅助函数 # ## 2. 本地辅助函数
@@ -219,21 +225,23 @@ N_QUANTILES = 20 # 将 label 分为 20 组
MODEL_PARAMS = { MODEL_PARAMS = {
"objective": "lambdarank", "objective": "lambdarank",
"metric": "ndcg", "metric": "ndcg",
"ndcg_at": 10, # 评估 NDCG@k "ndcg_at": 15,
"learning_rate": 0.01, "learning_rate": 0.002,
"num_leaves": 31, "num_leaves": 63,
"max_depth": 4, "max_depth": 5,
"min_data_in_leaf": 20, "min_data_in_leaf": 63,
"n_estimators": 2000, "n_estimators": 1000,
"early_stopping_round": 100, "early_stopping_round": 150,
"subsample": 0.8, "subsample": 0.8,
"colsample_bytree": 0.8, "colsample_bytree": 0.8,
"reg_alpha": 0.1, "reg_alpha": 0.5,
"reg_lambda": 1.0, "reg_lambda": 1.0,
"verbose": -1, "verbose": -1,
"random_state": 42, "random_state": 42,
"lambdarank_truncation_level": 10, "lambdarank_truncation_level": 30,
"label_gain": [i for i in range(1, N_QUANTILES + 1)], "label_gain": [
i for i in range(1, N_QUANTILES + 1)
], # 如果收益率被分为了比如 5 档,建议用[0, 1, 3, 7, 15] 这种指数型 gain
} }
# 注意stock_pool_filter, STOCK_FILTER_REQUIRED_COLUMNS, OUTPUT_DIR 等配置 # 注意stock_pool_filter, STOCK_FILTER_REQUIRED_COLUMNS, OUTPUT_DIR 等配置
@@ -287,7 +295,7 @@ model = LightGBMLambdaRankModel(params=MODEL_PARAMS)
processors = [ processors = [
NullFiller(feature_cols=feature_cols, strategy="mean"), NullFiller(feature_cols=feature_cols, strategy="mean"),
Winsorizer(feature_cols=feature_cols, lower=0.01, upper=0.99), Winsorizer(feature_cols=feature_cols, lower=0.01, upper=0.99),
StandardScaler(feature_cols=feature_cols), CrossSectionalStandardScaler(feature_cols=feature_cols),
] ]
# 8. 创建数据划分器 # 8. 创建数据划分器
@@ -310,7 +318,7 @@ pool_manager = StockPoolManager(
# 10. 创建 ST 过滤器 # 10. 创建 ST 过滤器
st_filter = STFilter(data_router=engine.router) st_filter = STFilter(data_router=engine.router)
# 11. 创建训练器 # 11. 创建训练器(禁用自动保存,我们将在训练后手动保存以包含因子信息)
trainer = Trainer( trainer = Trainer(
model=model, model=model,
pool_manager=pool_manager, pool_manager=pool_manager,
@@ -319,7 +327,7 @@ trainer = Trainer(
splitter=splitter, splitter=splitter,
target_col=target_col, target_col=target_col,
feature_cols=feature_cols, feature_cols=feature_cols,
persist_model=PERSIST_MODEL, persist_model=False, # 禁用自动保存,手动保存以包含因子信息
) )
# %% md # %% md
# ### 4.1 股票池筛选 # ### 4.1 股票池筛选
@@ -582,6 +590,20 @@ print(
print(f"\n 预览前15行:") print(f"\n 预览前15行:")
print(topn_to_save.head(15)) print(topn_to_save.head(15))
# 保存模型和因子信息(如果启用)
if SAVE_MODEL:
print("\n" + "=" * 80)
print("保存模型和因子信息")
print("=" * 80)
model_save_path = get_model_save_path(TRAINING_TYPE)
if model_save_path:
save_model_with_factors(
model=model,
model_path=model_save_path,
selected_factors=SELECTED_FACTORS,
factor_definitions=FACTOR_DEFINITIONS,
)
print("\n训练流程完成!") print("\n训练流程完成!")
# %% md # %% md
# ## 5. 总结 # ## 5. 总结

View File

@@ -37,10 +37,15 @@ from src.experiment.common import (
STOCK_FILTER_REQUIRED_COLUMNS, STOCK_FILTER_REQUIRED_COLUMNS,
OUTPUT_DIR, OUTPUT_DIR,
SAVE_PREDICTIONS, SAVE_PREDICTIONS,
PERSIST_MODEL, SAVE_MODEL,
get_model_save_path,
save_model_with_factors,
TOP_N, TOP_N,
) )
# 训练类型标识
TRAINING_TYPE = "regression"
# %% md # %% md
# ## 2. 配置参数 # ## 2. 配置参数
@@ -153,7 +158,7 @@ st_filter = STFilter(
data_router=engine.router, data_router=engine.router,
) )
# 10. 创建训练器 # 10. 创建训练器(禁用自动保存,我们将在训练后手动保存以包含因子信息)
trainer = Trainer( trainer = Trainer(
model=model, model=model,
pool_manager=pool_manager, pool_manager=pool_manager,
@@ -162,7 +167,7 @@ trainer = Trainer(
splitter=splitter, splitter=splitter,
target_col=target_col, target_col=target_col,
feature_cols=feature_cols, feature_cols=feature_cols,
persist_model=PERSIST_MODEL, persist_model=False, # 禁用自动保存,手动保存以包含因子信息
) )
# %% md # %% md
# ### 4.2 执行训练 # ### 4.2 执行训练
@@ -577,3 +582,17 @@ if zero_importance:
print(f" - {feat}") print(f" - {feat}")
else: else:
print("\n所有特征都有一定重要性") print("\n所有特征都有一定重要性")
# 保存模型和因子信息(如果启用)
if SAVE_MODEL:
print("\n" + "=" * 80)
print("保存模型和因子信息")
print("=" * 80)
model_save_path = get_model_save_path(TRAINING_TYPE)
if model_save_path:
save_model_with_factors(
model=model,
model_path=model_save_path,
selected_factors=SELECTED_FACTORS,
factor_definitions=FACTOR_DEFINITIONS,
)