refactor(experiment): 重构模型保存机制,支持 processors 持久化
- 模型保存路径改为 models/{model_type}/ 目录结构
- save_model_with_factors 新增 fitted_processors 参数
- 新增 load_processors 函数加载处理器状态
- Storage 查询排序优化:ORDER BY ts_code, trade_date
This commit is contained in:
@@ -228,7 +228,7 @@ SELECTED_FACTORS = [
|
||||
"GTJA_alpha162",
|
||||
"GTJA_alpha163",
|
||||
"GTJA_alpha164",
|
||||
"GTJA_alpha165",
|
||||
# "GTJA_alpha165",
|
||||
"GTJA_alpha166",
|
||||
"GTJA_alpha167",
|
||||
"GTJA_alpha168",
|
||||
@@ -243,7 +243,7 @@ SELECTED_FACTORS = [
|
||||
"GTJA_alpha178",
|
||||
"GTJA_alpha179",
|
||||
"GTJA_alpha180",
|
||||
"GTJA_alpha183",
|
||||
# "GTJA_alpha183",
|
||||
"GTJA_alpha184",
|
||||
"GTJA_alpha185",
|
||||
"GTJA_alpha187",
|
||||
@@ -258,44 +258,44 @@ FACTOR_DEFINITIONS = {}
|
||||
# 需要排除的因子列表(这些因子不会被计算和使用)
|
||||
# 用于临时屏蔽效果不好的因子,无需从 SELECTED_FACTORS 中删除
|
||||
EXCLUDED_FACTORS: List[str] = [
|
||||
'GTJA_alpha005',
|
||||
'GTJA_alpha028',
|
||||
'GTJA_alpha023',
|
||||
'GTJA_alpha002',
|
||||
'GTJA_alpha010',
|
||||
'GTJA_alpha011',
|
||||
'GTJA_alpha044',
|
||||
'GTJA_alpha036',
|
||||
'GTJA_alpha027',
|
||||
'GTJA_alpha109',
|
||||
'GTJA_alpha104',
|
||||
'GTJA_alpha103',
|
||||
'GTJA_alpha085',
|
||||
'GTJA_alpha111',
|
||||
'GTJA_alpha092',
|
||||
'GTJA_alpha067',
|
||||
'GTJA_alpha060',
|
||||
'GTJA_alpha062',
|
||||
'GTJA_alpha063',
|
||||
'GTJA_alpha079',
|
||||
'GTJA_alpha073',
|
||||
'GTJA_alpha087',
|
||||
'GTJA_alpha117',
|
||||
'GTJA_alpha113',
|
||||
'GTJA_alpha138',
|
||||
'GTJA_alpha121',
|
||||
'GTJA_alpha124',
|
||||
'GTJA_alpha133',
|
||||
'GTJA_alpha131',
|
||||
'GTJA_alpha118',
|
||||
'GTJA_alpha164',
|
||||
'GTJA_alpha162',
|
||||
'GTJA_alpha157',
|
||||
'GTJA_alpha171',
|
||||
'GTJA_alpha177',
|
||||
'GTJA_alpha180',
|
||||
'GTJA_alpha188',
|
||||
'GTJA_alpha191',
|
||||
# "GTJA_alpha005",
|
||||
# "GTJA_alpha028",
|
||||
# "GTJA_alpha023",
|
||||
# "GTJA_alpha002",
|
||||
# "GTJA_alpha010",
|
||||
# "GTJA_alpha011",
|
||||
# "GTJA_alpha044",
|
||||
# "GTJA_alpha036",
|
||||
# "GTJA_alpha027",
|
||||
# "GTJA_alpha109",
|
||||
# "GTJA_alpha104",
|
||||
# "GTJA_alpha103",
|
||||
# "GTJA_alpha085",
|
||||
# "GTJA_alpha111",
|
||||
# "GTJA_alpha092",
|
||||
# "GTJA_alpha067",
|
||||
# "GTJA_alpha060",
|
||||
# "GTJA_alpha062",
|
||||
# "GTJA_alpha063",
|
||||
# "GTJA_alpha079",
|
||||
# "GTJA_alpha073",
|
||||
# "GTJA_alpha087",
|
||||
# "GTJA_alpha117",
|
||||
# "GTJA_alpha113",
|
||||
# "GTJA_alpha138",
|
||||
# "GTJA_alpha121",
|
||||
# "GTJA_alpha124",
|
||||
# "GTJA_alpha133",
|
||||
# "GTJA_alpha131",
|
||||
# "GTJA_alpha118",
|
||||
# "GTJA_alpha164",
|
||||
# "GTJA_alpha162",
|
||||
# "GTJA_alpha157",
|
||||
# "GTJA_alpha171",
|
||||
# "GTJA_alpha177",
|
||||
# "GTJA_alpha180",
|
||||
# "GTJA_alpha188",
|
||||
# "GTJA_alpha191",
|
||||
]
|
||||
|
||||
|
||||
@@ -317,11 +317,11 @@ def get_label_factor(label_name: str) -> dict:
|
||||
# 辅助函数
|
||||
# =============================================================================
|
||||
def register_factors(
|
||||
engine: FactorEngine,
|
||||
selected_factors: List[str],
|
||||
factor_definitions: dict,
|
||||
label_factor: dict,
|
||||
excluded_factors: Optional[List[str]] = None,
|
||||
engine: FactorEngine,
|
||||
selected_factors: List[str],
|
||||
factor_definitions: dict,
|
||||
label_factor: dict,
|
||||
excluded_factors: Optional[List[str]] = None,
|
||||
) -> List[str]:
|
||||
"""注册因子。
|
||||
|
||||
@@ -402,11 +402,11 @@ def register_factors(
|
||||
|
||||
|
||||
def prepare_data(
|
||||
engine: FactorEngine,
|
||||
feature_cols: List[str],
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
label_name: str,
|
||||
engine: FactorEngine,
|
||||
feature_cols: List[str],
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
label_name: str,
|
||||
) -> pl.DataFrame:
|
||||
"""准备数据。
|
||||
|
||||
@@ -464,11 +464,11 @@ def stock_pool_filter(df: pl.DataFrame) -> pl.Series:
|
||||
"""
|
||||
# 代码筛选(排除创业板、科创板、北交所)
|
||||
code_filter = (
|
||||
~df["ts_code"].str.starts_with("30") # 排除创业板
|
||||
& ~df["ts_code"].str.starts_with("68") # 排除科创板
|
||||
& ~df["ts_code"].str.starts_with("8") # 排除北交所
|
||||
& ~df["ts_code"].str.starts_with("9") # 排除北交所
|
||||
& ~df["ts_code"].str.starts_with("4") # 排除北交所
|
||||
~df["ts_code"].str.starts_with("30") # 排除创业板
|
||||
& ~df["ts_code"].str.starts_with("68") # 排除科创板
|
||||
& ~df["ts_code"].str.starts_with("8") # 排除北交所
|
||||
& ~df["ts_code"].str.starts_with("9") # 排除北交所
|
||||
& ~df["ts_code"].str.starts_with("4") # 排除北交所
|
||||
)
|
||||
|
||||
# 在已筛选的股票中,选取市值最小的500只
|
||||
@@ -490,7 +490,7 @@ OUTPUT_DIR = "output"
|
||||
SAVE_PREDICTIONS = True
|
||||
|
||||
# 模型保存配置
|
||||
SAVE_MODEL = False # 是否保存模型
|
||||
SAVE_MODEL = True # 是否保存模型
|
||||
MODEL_SAVE_DIR = "models" # 模型保存目录
|
||||
|
||||
# Top N 配置:每日推荐股票数量
|
||||
@@ -523,58 +523,68 @@ def get_output_path(model_type: str, test_start: str, test_end: str) -> str:
|
||||
|
||||
|
||||
def get_model_save_path(
|
||||
model_type: str, model_name: Optional[str] = None
|
||||
model_type: str,
|
||||
) -> Optional[str]:
|
||||
"""生成模型保存路径。
|
||||
|
||||
模型将保存在 models/{model_type}/ 目录下,包含 model.pkl 和 factors.json
|
||||
|
||||
Args:
|
||||
model_type: 模型类型("regression" 或 "rank")
|
||||
model_name: 模型名称,默认为 model_type
|
||||
|
||||
Returns:
|
||||
模型保存路径,如果 SAVE_MODEL 为 False 则返回 None
|
||||
模型保存路径(models/{model_type}/model.pkl),如果 SAVE_MODEL 为 False 则返回 None
|
||||
"""
|
||||
if not SAVE_MODEL:
|
||||
return None
|
||||
|
||||
import os
|
||||
|
||||
# 确保模型保存目录存在
|
||||
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
|
||||
# 模型保存目录:models/{model_type}/
|
||||
model_dir = os.path.join(MODEL_SAVE_DIR, model_type)
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
# 使用 model_name 或默认使用 model_type
|
||||
name = model_name if model_name else model_type
|
||||
filename = f"{name}.pkl"
|
||||
return os.path.join(MODEL_SAVE_DIR, filename)
|
||||
# 模型文件路径
|
||||
return os.path.join(model_dir, "model.pkl")
|
||||
|
||||
|
||||
def save_model_with_factors(
|
||||
model,
|
||||
model_path: str,
|
||||
selected_factors: List[str],
|
||||
factor_definitions: dict,
|
||||
) -> None:
|
||||
"""保存模型及关联的因子信息。
|
||||
model,
|
||||
model_path: str,
|
||||
selected_factors: list[str],
|
||||
factor_definitions: dict,
|
||||
fitted_processors: list | None = None,
|
||||
) -> str:
|
||||
"""保存模型及关联的因子信息和处理器。
|
||||
|
||||
除了保存模型本身,还会保存一个同名的 .factors.json 文件,
|
||||
包含 SELECTED_FACTORS 和 FACTOR_DEFINITIONS,以便后续加载模型时
|
||||
知道使用了哪些因子。
|
||||
将模型、因子信息和处理器保存到同一文件夹(models/{model_type}/)下:
|
||||
- model.pkl: 模型文件
|
||||
- factors.json: 因子信息文件
|
||||
- processors.pkl: 处理器状态文件(如果提供)
|
||||
|
||||
Args:
|
||||
model: 训练好的模型实例(需有 save 方法)
|
||||
model_path: 模型保存路径
|
||||
model_path: 模型保存路径(由 get_model_save_path 生成)
|
||||
selected_factors: 从 metadata 中选择的因子名称列表
|
||||
factor_definitions: 通过表达式定义的因子字典
|
||||
fitted_processors: 已拟合的处理器列表(可选)
|
||||
|
||||
Returns:
|
||||
模型文件夹路径
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
|
||||
# 获取模型文件夹路径
|
||||
model_dir = os.path.dirname(model_path)
|
||||
|
||||
# 1. 保存模型本身
|
||||
model.save(model_path)
|
||||
print(f"[模型保存] 模型已保存至: {model_path}")
|
||||
|
||||
# 2. 保存因子信息到 .factors.json 文件
|
||||
factors_path = model_path.replace(".pkl", ".factors.json")
|
||||
# 2. 保存因子信息到 factors.json 文件
|
||||
factors_path = os.path.join(model_dir, "factors.json")
|
||||
|
||||
factors_info = {
|
||||
"selected_factors": selected_factors,
|
||||
@@ -592,12 +602,22 @@ def save_model_with_factors(
|
||||
print(f" - 来自 metadata: {factors_info['selected_factors_count']} 个")
|
||||
print(f" - 来自表达式定义: {factors_info['factor_definitions_count']} 个")
|
||||
|
||||
# 3. 保存处理器(如果提供)
|
||||
if fitted_processors is not None:
|
||||
processors_path = os.path.join(model_dir, "processors.pkl")
|
||||
with open(processors_path, "wb") as f:
|
||||
pickle.dump(fitted_processors, f)
|
||||
print(f"[模型保存] 处理器已保存至: {processors_path}")
|
||||
print(f"[模型保存] 共 {len(fitted_processors)} 个处理器")
|
||||
|
||||
return model_dir
|
||||
|
||||
|
||||
def load_model_factors(model_path: str) -> Optional[dict]:
|
||||
"""加载模型关联的因子信息。
|
||||
|
||||
Args:
|
||||
model_path: 模型保存路径
|
||||
model_path: 模型保存路径(models/{model_type}/model.pkl)
|
||||
|
||||
Returns:
|
||||
包含因子信息的字典,如果文件不存在则返回 None
|
||||
@@ -605,7 +625,9 @@ def load_model_factors(model_path: str) -> Optional[dict]:
|
||||
import json
|
||||
import os
|
||||
|
||||
factors_path = model_path.replace(".pkl", ".factors.json")
|
||||
# 获取模型文件夹路径
|
||||
model_dir = os.path.dirname(model_path)
|
||||
factors_path = os.path.join(model_dir, "factors.json")
|
||||
|
||||
if not os.path.exists(factors_path):
|
||||
print(f"[警告] 未找到因子信息文件: {factors_path}")
|
||||
@@ -618,3 +640,30 @@ def load_model_factors(model_path: str) -> Optional[dict]:
|
||||
f"[模型加载] 已加载因子信息,总计 {factors_info.get('total_feature_count', 'N/A')} 个因子"
|
||||
)
|
||||
return factors_info
|
||||
|
||||
|
||||
def load_processors(model_path: str) -> list | None:
|
||||
"""加载模型关联的处理器。
|
||||
|
||||
Args:
|
||||
model_path: 模型保存路径(models/{model_type}/model.pkl)
|
||||
|
||||
Returns:
|
||||
处理器列表,如果文件不存在则返回 None
|
||||
"""
|
||||
import pickle
|
||||
import os
|
||||
|
||||
# 获取模型文件夹路径
|
||||
model_dir = os.path.dirname(model_path)
|
||||
processors_path = os.path.join(model_dir, "processors.pkl")
|
||||
|
||||
if not os.path.exists(processors_path):
|
||||
print(f"[警告] 未找到处理器文件: {processors_path}")
|
||||
return None
|
||||
|
||||
with open(processors_path, "rb") as f:
|
||||
fitted_processors = pickle.load(f)
|
||||
|
||||
print(f"[模型加载] 已加载 {len(fitted_processors)} 个处理器")
|
||||
return fitted_processors
|
||||
|
||||
Reference in New Issue
Block a user