feat(experiment): 添加模型保存功能及因子信息持久化
- 新增 SAVE_MODEL 配置控制是否保存模型
- 新增 get_model_save_path() 生成模型保存路径
- 新增 save_model_with_factors() 保存模型及关联因子信息
- 新增 load_model_factors() 加载因子信息用于模型复现
- 更新训练脚本使用新的模型保存方式
- 清理 data/sync.py 中的废弃代码
This commit is contained in:
@@ -37,10 +37,15 @@ from src.experiment.common import (
|
||||
STOCK_FILTER_REQUIRED_COLUMNS,
|
||||
OUTPUT_DIR,
|
||||
SAVE_PREDICTIONS,
|
||||
PERSIST_MODEL,
|
||||
SAVE_MODEL,
|
||||
get_model_save_path,
|
||||
save_model_with_factors,
|
||||
TOP_N,
|
||||
)
|
||||
|
||||
# 训练类型标识
|
||||
TRAINING_TYPE = "regression"
|
||||
|
||||
|
||||
# %% md
|
||||
# ## 2. 配置参数
|
||||
@@ -153,7 +158,7 @@ st_filter = STFilter(
|
||||
data_router=engine.router,
|
||||
)
|
||||
|
||||
# 10. 创建训练器
|
||||
# 10. 创建训练器(禁用自动保存,我们将在训练后手动保存以包含因子信息)
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
pool_manager=pool_manager,
|
||||
@@ -162,7 +167,7 @@ trainer = Trainer(
|
||||
splitter=splitter,
|
||||
target_col=target_col,
|
||||
feature_cols=feature_cols,
|
||||
persist_model=PERSIST_MODEL,
|
||||
persist_model=False, # 禁用自动保存,手动保存以包含因子信息
|
||||
)
|
||||
# %% md
|
||||
# ### 4.2 执行训练
|
||||
@@ -577,3 +582,17 @@ if zero_importance:
|
||||
print(f" - {feat}")
|
||||
else:
|
||||
print("\n所有特征都有一定重要性")
|
||||
|
||||
# 保存模型和因子信息(如果启用)
|
||||
if SAVE_MODEL:
|
||||
print("\n" + "=" * 80)
|
||||
print("保存模型和因子信息")
|
||||
print("=" * 80)
|
||||
model_save_path = get_model_save_path(TRAINING_TYPE)
|
||||
if model_save_path:
|
||||
save_model_with_factors(
|
||||
model=model,
|
||||
model_path=model_save_path,
|
||||
selected_factors=SELECTED_FACTORS,
|
||||
factor_definitions=FACTOR_DEFINITIONS,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user