feat(experiment): 添加模型保存功能及因子信息持久化
- 新增 SAVE_MODEL 配置控制是否保存模型 - 新增 get_model_save_path() 生成模型保存路径 - 新增 save_model_with_factors() 保存模型及关联因子信息 - 新增 load_model_factors() 加载因子信息用于模型复现 - 更新训练脚本使用新的模型保存方式 - 清理 data/sync.py 中的废弃代码
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
import polars as pl
|
||||
|
||||
@@ -255,8 +255,7 @@ SELECTED_FACTORS = [
|
||||
]
|
||||
|
||||
# 因子定义字典(完整因子库,用于存放尚未注册到metadata的因子)
|
||||
FACTOR_DEFINITIONS = {
|
||||
}
|
||||
FACTOR_DEFINITIONS = {}
|
||||
|
||||
|
||||
def get_label_factor(label_name: str) -> dict:
|
||||
@@ -417,7 +416,10 @@ STOCK_FILTER_REQUIRED_COLUMNS = ["total_mv"]
|
||||
# =============================================================================
|
||||
OUTPUT_DIR = "output"
|
||||
SAVE_PREDICTIONS = True
|
||||
PERSIST_MODEL = False
|
||||
|
||||
# 模型保存配置
|
||||
SAVE_MODEL = True # 是否保存模型
|
||||
MODEL_SAVE_DIR = "models" # 模型保存目录
|
||||
|
||||
# Top N 配置:每日推荐股票数量
|
||||
TOP_N = 5 # 可调整为 10, 20 等
|
||||
@@ -446,3 +448,101 @@ def get_output_path(model_type: str, test_start: str, test_end: str) -> str:
|
||||
|
||||
filename = f"{model_type}_output.csv"
|
||||
return os.path.join(OUTPUT_DIR, filename)
|
||||
|
||||
|
||||
def get_model_save_path(
|
||||
model_type: str, model_name: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
"""生成模型保存路径。
|
||||
|
||||
Args:
|
||||
model_type: 模型类型("regression" 或 "rank")
|
||||
model_name: 模型名称,默认为 model_type
|
||||
|
||||
Returns:
|
||||
模型保存路径,如果 SAVE_MODEL 为 False 则返回 None
|
||||
"""
|
||||
if not SAVE_MODEL:
|
||||
return None
|
||||
|
||||
import os
|
||||
|
||||
# 确保模型保存目录存在
|
||||
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
|
||||
|
||||
# 使用 model_name 或默认使用 model_type
|
||||
name = model_name if model_name else model_type
|
||||
filename = f"{name}.pkl"
|
||||
return os.path.join(MODEL_SAVE_DIR, filename)
|
||||
|
||||
|
||||
def save_model_with_factors(
|
||||
model,
|
||||
model_path: str,
|
||||
selected_factors: List[str],
|
||||
factor_definitions: dict,
|
||||
) -> None:
|
||||
"""保存模型及关联的因子信息。
|
||||
|
||||
除了保存模型本身,还会保存一个同名的 .factors.json 文件,
|
||||
包含 SELECTED_FACTORS 和 FACTOR_DEFINITIONS,以便后续加载模型时
|
||||
知道使用了哪些因子。
|
||||
|
||||
Args:
|
||||
model: 训练好的模型实例(需有 save 方法)
|
||||
model_path: 模型保存路径
|
||||
selected_factors: 从 metadata 中选择的因子名称列表
|
||||
factor_definitions: 通过表达式定义的因子字典
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
|
||||
# 1. 保存模型本身
|
||||
model.save(model_path)
|
||||
print(f"[模型保存] 模型已保存至: {model_path}")
|
||||
|
||||
# 2. 保存因子信息到 .factors.json 文件
|
||||
factors_path = model_path.replace(".pkl", ".factors.json")
|
||||
|
||||
factors_info = {
|
||||
"selected_factors": selected_factors,
|
||||
"factor_definitions": factor_definitions,
|
||||
"total_feature_count": len(selected_factors) + len(factor_definitions),
|
||||
"selected_factors_count": len(selected_factors),
|
||||
"factor_definitions_count": len(factor_definitions),
|
||||
}
|
||||
|
||||
with open(factors_path, "w", encoding="utf-8") as f:
|
||||
json.dump(factors_info, f, ensure_ascii=False, indent=2)
|
||||
|
||||
print(f"[模型保存] 因子信息已保存至: {factors_path}")
|
||||
print(f"[模型保存] 总计 {factors_info['total_feature_count']} 个因子")
|
||||
print(f" - 来自 metadata: {factors_info['selected_factors_count']} 个")
|
||||
print(f" - 来自表达式定义: {factors_info['factor_definitions_count']} 个")
|
||||
|
||||
|
||||
def load_model_factors(model_path: str) -> Optional[dict]:
|
||||
"""加载模型关联的因子信息。
|
||||
|
||||
Args:
|
||||
model_path: 模型保存路径
|
||||
|
||||
Returns:
|
||||
包含因子信息的字典,如果文件不存在则返回 None
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
|
||||
factors_path = model_path.replace(".pkl", ".factors.json")
|
||||
|
||||
if not os.path.exists(factors_path):
|
||||
print(f"[警告] 未找到因子信息文件: {factors_path}")
|
||||
return None
|
||||
|
||||
with open(factors_path, "r", encoding="utf-8") as f:
|
||||
factors_info = json.load(f)
|
||||
|
||||
print(
|
||||
f"[模型加载] 已加载因子信息,总计 {factors_info.get('total_feature_count', 'N/A')} 个因子"
|
||||
)
|
||||
return factors_info
|
||||
|
||||
@@ -32,6 +32,7 @@ from src.training import (
|
||||
NullFiller,
|
||||
StandardScaler,
|
||||
check_data_quality,
|
||||
CrossSectionalStandardScaler,
|
||||
)
|
||||
from src.training.components.models import LightGBMLambdaRankModel
|
||||
from src.training.config import TrainingConfig
|
||||
@@ -53,10 +54,15 @@ from src.experiment.common import (
|
||||
STOCK_FILTER_REQUIRED_COLUMNS,
|
||||
OUTPUT_DIR,
|
||||
SAVE_PREDICTIONS,
|
||||
PERSIST_MODEL,
|
||||
SAVE_MODEL,
|
||||
get_model_save_path,
|
||||
save_model_with_factors,
|
||||
TOP_N,
|
||||
)
|
||||
|
||||
# 训练类型标识
|
||||
TRAINING_TYPE = "rank"
|
||||
|
||||
|
||||
# %% md
|
||||
# ## 2. 本地辅助函数
|
||||
@@ -219,21 +225,23 @@ N_QUANTILES = 20 # 将 label 分为 20 组
|
||||
MODEL_PARAMS = {
|
||||
"objective": "lambdarank",
|
||||
"metric": "ndcg",
|
||||
"ndcg_at": 10, # 评估 NDCG@k
|
||||
"learning_rate": 0.01,
|
||||
"num_leaves": 31,
|
||||
"max_depth": 4,
|
||||
"min_data_in_leaf": 20,
|
||||
"n_estimators": 2000,
|
||||
"early_stopping_round": 100,
|
||||
"ndcg_at": 15,
|
||||
"learning_rate": 0.002,
|
||||
"num_leaves": 63,
|
||||
"max_depth": 5,
|
||||
"min_data_in_leaf": 63,
|
||||
"n_estimators": 1000,
|
||||
"early_stopping_round": 150,
|
||||
"subsample": 0.8,
|
||||
"colsample_bytree": 0.8,
|
||||
"reg_alpha": 0.1,
|
||||
"reg_alpha": 0.5,
|
||||
"reg_lambda": 1.0,
|
||||
"verbose": -1,
|
||||
"random_state": 42,
|
||||
"lambdarank_truncation_level": 10,
|
||||
"label_gain": [i for i in range(1, N_QUANTILES + 1)],
|
||||
"lambdarank_truncation_level": 30,
|
||||
"label_gain": [
|
||||
i for i in range(1, N_QUANTILES + 1)
|
||||
], # 如果收益率被分为了比如 5 档,建议用[0, 1, 3, 7, 15] 这种指数型 gain
|
||||
}
|
||||
|
||||
# 注意:stock_pool_filter, STOCK_FILTER_REQUIRED_COLUMNS, OUTPUT_DIR 等配置
|
||||
@@ -287,7 +295,7 @@ model = LightGBMLambdaRankModel(params=MODEL_PARAMS)
|
||||
processors = [
|
||||
NullFiller(feature_cols=feature_cols, strategy="mean"),
|
||||
Winsorizer(feature_cols=feature_cols, lower=0.01, upper=0.99),
|
||||
StandardScaler(feature_cols=feature_cols),
|
||||
CrossSectionalStandardScaler(feature_cols=feature_cols),
|
||||
]
|
||||
|
||||
# 8. 创建数据划分器
|
||||
@@ -310,7 +318,7 @@ pool_manager = StockPoolManager(
|
||||
# 10. 创建 ST 过滤器
|
||||
st_filter = STFilter(data_router=engine.router)
|
||||
|
||||
# 11. 创建训练器
|
||||
# 11. 创建训练器(禁用自动保存,我们将在训练后手动保存以包含因子信息)
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
pool_manager=pool_manager,
|
||||
@@ -319,7 +327,7 @@ trainer = Trainer(
|
||||
splitter=splitter,
|
||||
target_col=target_col,
|
||||
feature_cols=feature_cols,
|
||||
persist_model=PERSIST_MODEL,
|
||||
persist_model=False, # 禁用自动保存,手动保存以包含因子信息
|
||||
)
|
||||
# %% md
|
||||
# ### 4.1 股票池筛选
|
||||
@@ -582,6 +590,20 @@ print(
|
||||
print(f"\n 预览(前15行):")
|
||||
print(topn_to_save.head(15))
|
||||
|
||||
# 保存模型和因子信息(如果启用)
|
||||
if SAVE_MODEL:
|
||||
print("\n" + "=" * 80)
|
||||
print("保存模型和因子信息")
|
||||
print("=" * 80)
|
||||
model_save_path = get_model_save_path(TRAINING_TYPE)
|
||||
if model_save_path:
|
||||
save_model_with_factors(
|
||||
model=model,
|
||||
model_path=model_save_path,
|
||||
selected_factors=SELECTED_FACTORS,
|
||||
factor_definitions=FACTOR_DEFINITIONS,
|
||||
)
|
||||
|
||||
print("\n训练流程完成!")
|
||||
# %% md
|
||||
# ## 5. 总结
|
||||
|
||||
@@ -37,10 +37,15 @@ from src.experiment.common import (
|
||||
STOCK_FILTER_REQUIRED_COLUMNS,
|
||||
OUTPUT_DIR,
|
||||
SAVE_PREDICTIONS,
|
||||
PERSIST_MODEL,
|
||||
SAVE_MODEL,
|
||||
get_model_save_path,
|
||||
save_model_with_factors,
|
||||
TOP_N,
|
||||
)
|
||||
|
||||
# 训练类型标识
|
||||
TRAINING_TYPE = "regression"
|
||||
|
||||
|
||||
# %% md
|
||||
# ## 2. 配置参数
|
||||
@@ -153,7 +158,7 @@ st_filter = STFilter(
|
||||
data_router=engine.router,
|
||||
)
|
||||
|
||||
# 10. 创建训练器
|
||||
# 10. 创建训练器(禁用自动保存,我们将在训练后手动保存以包含因子信息)
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
pool_manager=pool_manager,
|
||||
@@ -162,7 +167,7 @@ trainer = Trainer(
|
||||
splitter=splitter,
|
||||
target_col=target_col,
|
||||
feature_cols=feature_cols,
|
||||
persist_model=PERSIST_MODEL,
|
||||
persist_model=False, # 禁用自动保存,手动保存以包含因子信息
|
||||
)
|
||||
# %% md
|
||||
# ### 4.2 执行训练
|
||||
@@ -577,3 +582,17 @@ if zero_importance:
|
||||
print(f" - {feat}")
|
||||
else:
|
||||
print("\n所有特征都有一定重要性")
|
||||
|
||||
# 保存模型和因子信息(如果启用)
|
||||
if SAVE_MODEL:
|
||||
print("\n" + "=" * 80)
|
||||
print("保存模型和因子信息")
|
||||
print("=" * 80)
|
||||
model_save_path = get_model_save_path(TRAINING_TYPE)
|
||||
if model_save_path:
|
||||
save_model_with_factors(
|
||||
model=model,
|
||||
model_path=model_save_path,
|
||||
selected_factors=SELECTED_FACTORS,
|
||||
factor_definitions=FACTOR_DEFINITIONS,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user