feat(experiment): 添加模型保存功能及因子信息持久化
- 新增 SAVE_MODEL 配置控制是否保存模型 - 新增 get_model_save_path() 生成模型保存路径 - 新增 save_model_with_factors() 保存模型及关联因子信息 - 新增 load_model_factors() 加载因子信息用于模型复现 - 更新训练脚本使用新的模型保存方式 - 清理 data/sync.py 中的废弃代码
This commit is contained in:
@@ -32,6 +32,7 @@ from src.training import (
|
||||
NullFiller,
|
||||
StandardScaler,
|
||||
check_data_quality,
|
||||
CrossSectionalStandardScaler,
|
||||
)
|
||||
from src.training.components.models import LightGBMLambdaRankModel
|
||||
from src.training.config import TrainingConfig
|
||||
@@ -53,10 +54,15 @@ from src.experiment.common import (
|
||||
STOCK_FILTER_REQUIRED_COLUMNS,
|
||||
OUTPUT_DIR,
|
||||
SAVE_PREDICTIONS,
|
||||
PERSIST_MODEL,
|
||||
SAVE_MODEL,
|
||||
get_model_save_path,
|
||||
save_model_with_factors,
|
||||
TOP_N,
|
||||
)
|
||||
|
||||
# 训练类型标识
|
||||
TRAINING_TYPE = "rank"
|
||||
|
||||
|
||||
# %% md
|
||||
# ## 2. 本地辅助函数
|
||||
@@ -219,21 +225,23 @@ N_QUANTILES = 20 # 将 label 分为 20 组
|
||||
MODEL_PARAMS = {
|
||||
"objective": "lambdarank",
|
||||
"metric": "ndcg",
|
||||
"ndcg_at": 10, # 评估 NDCG@k
|
||||
"learning_rate": 0.01,
|
||||
"num_leaves": 31,
|
||||
"max_depth": 4,
|
||||
"min_data_in_leaf": 20,
|
||||
"n_estimators": 2000,
|
||||
"early_stopping_round": 100,
|
||||
"ndcg_at": 15,
|
||||
"learning_rate": 0.002,
|
||||
"num_leaves": 63,
|
||||
"max_depth": 5,
|
||||
"min_data_in_leaf": 63,
|
||||
"n_estimators": 1000,
|
||||
"early_stopping_round": 150,
|
||||
"subsample": 0.8,
|
||||
"colsample_bytree": 0.8,
|
||||
"reg_alpha": 0.1,
|
||||
"reg_alpha": 0.5,
|
||||
"reg_lambda": 1.0,
|
||||
"verbose": -1,
|
||||
"random_state": 42,
|
||||
"lambdarank_truncation_level": 10,
|
||||
"label_gain": [i for i in range(1, N_QUANTILES + 1)],
|
||||
"lambdarank_truncation_level": 30,
|
||||
"label_gain": [
|
||||
i for i in range(1, N_QUANTILES + 1)
|
||||
], # 如果收益率被分为了比如 5 档,建议用[0, 1, 3, 7, 15] 这种指数型 gain
|
||||
}
|
||||
|
||||
# 注意:stock_pool_filter, STOCK_FILTER_REQUIRED_COLUMNS, OUTPUT_DIR 等配置
|
||||
@@ -287,7 +295,7 @@ model = LightGBMLambdaRankModel(params=MODEL_PARAMS)
|
||||
processors = [
|
||||
NullFiller(feature_cols=feature_cols, strategy="mean"),
|
||||
Winsorizer(feature_cols=feature_cols, lower=0.01, upper=0.99),
|
||||
StandardScaler(feature_cols=feature_cols),
|
||||
CrossSectionalStandardScaler(feature_cols=feature_cols),
|
||||
]
|
||||
|
||||
# 8. 创建数据划分器
|
||||
@@ -310,7 +318,7 @@ pool_manager = StockPoolManager(
|
||||
# 10. 创建 ST 过滤器
|
||||
st_filter = STFilter(data_router=engine.router)
|
||||
|
||||
# 11. 创建训练器
|
||||
# 11. 创建训练器(禁用自动保存,我们将在训练后手动保存以包含因子信息)
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
pool_manager=pool_manager,
|
||||
@@ -319,7 +327,7 @@ trainer = Trainer(
|
||||
splitter=splitter,
|
||||
target_col=target_col,
|
||||
feature_cols=feature_cols,
|
||||
persist_model=PERSIST_MODEL,
|
||||
persist_model=False, # 禁用自动保存,手动保存以包含因子信息
|
||||
)
|
||||
# %% md
|
||||
# ### 4.1 股票池筛选
|
||||
@@ -582,6 +590,20 @@ print(
|
||||
print(f"\n 预览(前15行):")
|
||||
print(topn_to_save.head(15))
|
||||
|
||||
# 保存模型和因子信息(如果启用)
|
||||
if SAVE_MODEL:
|
||||
print("\n" + "=" * 80)
|
||||
print("保存模型和因子信息")
|
||||
print("=" * 80)
|
||||
model_save_path = get_model_save_path(TRAINING_TYPE)
|
||||
if model_save_path:
|
||||
save_model_with_factors(
|
||||
model=model,
|
||||
model_path=model_save_path,
|
||||
selected_factors=SELECTED_FACTORS,
|
||||
factor_definitions=FACTOR_DEFINITIONS,
|
||||
)
|
||||
|
||||
print("\n训练流程完成!")
|
||||
# %% md
|
||||
# ## 5. 总结
|
||||
|
||||
Reference in New Issue
Block a user