refactor(experiment): 重构模型保存机制,支持 processors 持久化
- 模型保存路径改为 models/{model_type}/ 目录结构
- save_model_with_factors 新增 fitted_processors 参数
- 新增 load_processors 函数加载处理器状态
- Storage 查询排序优化:ORDER BY ts_code, trade_date
This commit is contained in:
@@ -217,7 +217,7 @@ class Storage:
|
||||
params.append(ts_code)
|
||||
|
||||
where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
|
||||
query = f"SELECT * FROM {name} {where_clause} ORDER BY trade_date"
|
||||
query = f"SELECT * FROM {name} {where_clause} ORDER BY ts_code, trade_date"
|
||||
|
||||
try:
|
||||
# Execute query with parameters (SQL injection safe)
|
||||
@@ -255,7 +255,7 @@ class Storage:
|
||||
conditions.append(f"ts_code = '{ts_code}'")
|
||||
|
||||
where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
|
||||
query = f"SELECT * FROM {name} {where_clause} ORDER BY trade_date"
|
||||
query = f"SELECT * FROM {name} {where_clause} ORDER BY ts_code, trade_date"
|
||||
|
||||
# 使用 DuckDB 的 Polars 导出(需要 pyarrow)
|
||||
df = self._connection.sql(query).pl()
|
||||
|
||||
@@ -228,7 +228,7 @@ SELECTED_FACTORS = [
|
||||
"GTJA_alpha162",
|
||||
"GTJA_alpha163",
|
||||
"GTJA_alpha164",
|
||||
"GTJA_alpha165",
|
||||
# "GTJA_alpha165",
|
||||
"GTJA_alpha166",
|
||||
"GTJA_alpha167",
|
||||
"GTJA_alpha168",
|
||||
@@ -243,7 +243,7 @@ SELECTED_FACTORS = [
|
||||
"GTJA_alpha178",
|
||||
"GTJA_alpha179",
|
||||
"GTJA_alpha180",
|
||||
"GTJA_alpha183",
|
||||
# "GTJA_alpha183",
|
||||
"GTJA_alpha184",
|
||||
"GTJA_alpha185",
|
||||
"GTJA_alpha187",
|
||||
@@ -258,44 +258,44 @@ FACTOR_DEFINITIONS = {}
|
||||
# 需要排除的因子列表(这些因子不会被计算和使用)
|
||||
# 用于临时屏蔽效果不好的因子,无需从 SELECTED_FACTORS 中删除
|
||||
EXCLUDED_FACTORS: List[str] = [
|
||||
'GTJA_alpha005',
|
||||
'GTJA_alpha028',
|
||||
'GTJA_alpha023',
|
||||
'GTJA_alpha002',
|
||||
'GTJA_alpha010',
|
||||
'GTJA_alpha011',
|
||||
'GTJA_alpha044',
|
||||
'GTJA_alpha036',
|
||||
'GTJA_alpha027',
|
||||
'GTJA_alpha109',
|
||||
'GTJA_alpha104',
|
||||
'GTJA_alpha103',
|
||||
'GTJA_alpha085',
|
||||
'GTJA_alpha111',
|
||||
'GTJA_alpha092',
|
||||
'GTJA_alpha067',
|
||||
'GTJA_alpha060',
|
||||
'GTJA_alpha062',
|
||||
'GTJA_alpha063',
|
||||
'GTJA_alpha079',
|
||||
'GTJA_alpha073',
|
||||
'GTJA_alpha087',
|
||||
'GTJA_alpha117',
|
||||
'GTJA_alpha113',
|
||||
'GTJA_alpha138',
|
||||
'GTJA_alpha121',
|
||||
'GTJA_alpha124',
|
||||
'GTJA_alpha133',
|
||||
'GTJA_alpha131',
|
||||
'GTJA_alpha118',
|
||||
'GTJA_alpha164',
|
||||
'GTJA_alpha162',
|
||||
'GTJA_alpha157',
|
||||
'GTJA_alpha171',
|
||||
'GTJA_alpha177',
|
||||
'GTJA_alpha180',
|
||||
'GTJA_alpha188',
|
||||
'GTJA_alpha191',
|
||||
# "GTJA_alpha005",
|
||||
# "GTJA_alpha028",
|
||||
# "GTJA_alpha023",
|
||||
# "GTJA_alpha002",
|
||||
# "GTJA_alpha010",
|
||||
# "GTJA_alpha011",
|
||||
# "GTJA_alpha044",
|
||||
# "GTJA_alpha036",
|
||||
# "GTJA_alpha027",
|
||||
# "GTJA_alpha109",
|
||||
# "GTJA_alpha104",
|
||||
# "GTJA_alpha103",
|
||||
# "GTJA_alpha085",
|
||||
# "GTJA_alpha111",
|
||||
# "GTJA_alpha092",
|
||||
# "GTJA_alpha067",
|
||||
# "GTJA_alpha060",
|
||||
# "GTJA_alpha062",
|
||||
# "GTJA_alpha063",
|
||||
# "GTJA_alpha079",
|
||||
# "GTJA_alpha073",
|
||||
# "GTJA_alpha087",
|
||||
# "GTJA_alpha117",
|
||||
# "GTJA_alpha113",
|
||||
# "GTJA_alpha138",
|
||||
# "GTJA_alpha121",
|
||||
# "GTJA_alpha124",
|
||||
# "GTJA_alpha133",
|
||||
# "GTJA_alpha131",
|
||||
# "GTJA_alpha118",
|
||||
# "GTJA_alpha164",
|
||||
# "GTJA_alpha162",
|
||||
# "GTJA_alpha157",
|
||||
# "GTJA_alpha171",
|
||||
# "GTJA_alpha177",
|
||||
# "GTJA_alpha180",
|
||||
# "GTJA_alpha188",
|
||||
# "GTJA_alpha191",
|
||||
]
|
||||
|
||||
|
||||
@@ -317,11 +317,11 @@ def get_label_factor(label_name: str) -> dict:
|
||||
# 辅助函数
|
||||
# =============================================================================
|
||||
def register_factors(
|
||||
engine: FactorEngine,
|
||||
selected_factors: List[str],
|
||||
factor_definitions: dict,
|
||||
label_factor: dict,
|
||||
excluded_factors: Optional[List[str]] = None,
|
||||
engine: FactorEngine,
|
||||
selected_factors: List[str],
|
||||
factor_definitions: dict,
|
||||
label_factor: dict,
|
||||
excluded_factors: Optional[List[str]] = None,
|
||||
) -> List[str]:
|
||||
"""注册因子。
|
||||
|
||||
@@ -402,11 +402,11 @@ def register_factors(
|
||||
|
||||
|
||||
def prepare_data(
|
||||
engine: FactorEngine,
|
||||
feature_cols: List[str],
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
label_name: str,
|
||||
engine: FactorEngine,
|
||||
feature_cols: List[str],
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
label_name: str,
|
||||
) -> pl.DataFrame:
|
||||
"""准备数据。
|
||||
|
||||
@@ -464,11 +464,11 @@ def stock_pool_filter(df: pl.DataFrame) -> pl.Series:
|
||||
"""
|
||||
# 代码筛选(排除创业板、科创板、北交所)
|
||||
code_filter = (
|
||||
~df["ts_code"].str.starts_with("30") # 排除创业板
|
||||
& ~df["ts_code"].str.starts_with("68") # 排除科创板
|
||||
& ~df["ts_code"].str.starts_with("8") # 排除北交所
|
||||
& ~df["ts_code"].str.starts_with("9") # 排除北交所
|
||||
& ~df["ts_code"].str.starts_with("4") # 排除北交所
|
||||
~df["ts_code"].str.starts_with("30") # 排除创业板
|
||||
& ~df["ts_code"].str.starts_with("68") # 排除科创板
|
||||
& ~df["ts_code"].str.starts_with("8") # 排除北交所
|
||||
& ~df["ts_code"].str.starts_with("9") # 排除北交所
|
||||
& ~df["ts_code"].str.starts_with("4") # 排除北交所
|
||||
)
|
||||
|
||||
# 在已筛选的股票中,选取市值最小的500只
|
||||
@@ -490,7 +490,7 @@ OUTPUT_DIR = "output"
|
||||
SAVE_PREDICTIONS = True
|
||||
|
||||
# 模型保存配置
|
||||
SAVE_MODEL = False # 是否保存模型
|
||||
SAVE_MODEL = True # 是否保存模型
|
||||
MODEL_SAVE_DIR = "models" # 模型保存目录
|
||||
|
||||
# Top N 配置:每日推荐股票数量
|
||||
@@ -523,58 +523,68 @@ def get_output_path(model_type: str, test_start: str, test_end: str) -> str:
|
||||
|
||||
|
||||
def get_model_save_path(
|
||||
model_type: str, model_name: Optional[str] = None
|
||||
model_type: str,
|
||||
) -> Optional[str]:
|
||||
"""生成模型保存路径。
|
||||
|
||||
模型将保存在 models/{model_type}/ 目录下,包含 model.pkl 和 factors.json
|
||||
|
||||
Args:
|
||||
model_type: 模型类型("regression" 或 "rank")
|
||||
model_name: 模型名称,默认为 model_type
|
||||
|
||||
Returns:
|
||||
模型保存路径,如果 SAVE_MODEL 为 False 则返回 None
|
||||
模型保存路径(models/{model_type}/model.pkl),如果 SAVE_MODEL 为 False 则返回 None
|
||||
"""
|
||||
if not SAVE_MODEL:
|
||||
return None
|
||||
|
||||
import os
|
||||
|
||||
# 确保模型保存目录存在
|
||||
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
|
||||
# 模型保存目录:models/{model_type}/
|
||||
model_dir = os.path.join(MODEL_SAVE_DIR, model_type)
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
# 使用 model_name 或默认使用 model_type
|
||||
name = model_name if model_name else model_type
|
||||
filename = f"{name}.pkl"
|
||||
return os.path.join(MODEL_SAVE_DIR, filename)
|
||||
# 模型文件路径
|
||||
return os.path.join(model_dir, "model.pkl")
|
||||
|
||||
|
||||
def save_model_with_factors(
|
||||
model,
|
||||
model_path: str,
|
||||
selected_factors: List[str],
|
||||
factor_definitions: dict,
|
||||
) -> None:
|
||||
"""保存模型及关联的因子信息。
|
||||
model,
|
||||
model_path: str,
|
||||
selected_factors: list[str],
|
||||
factor_definitions: dict,
|
||||
fitted_processors: list | None = None,
|
||||
) -> str:
|
||||
"""保存模型及关联的因子信息和处理器。
|
||||
|
||||
除了保存模型本身,还会保存一个同名的 .factors.json 文件,
|
||||
包含 SELECTED_FACTORS 和 FACTOR_DEFINITIONS,以便后续加载模型时
|
||||
知道使用了哪些因子。
|
||||
将模型、因子信息和处理器保存到同一文件夹(models/{model_type}/)下:
|
||||
- model.pkl: 模型文件
|
||||
- factors.json: 因子信息文件
|
||||
- processors.pkl: 处理器状态文件(如果提供)
|
||||
|
||||
Args:
|
||||
model: 训练好的模型实例(需有 save 方法)
|
||||
model_path: 模型保存路径
|
||||
model_path: 模型保存路径(由 get_model_save_path 生成)
|
||||
selected_factors: 从 metadata 中选择的因子名称列表
|
||||
factor_definitions: 通过表达式定义的因子字典
|
||||
fitted_processors: 已拟合的处理器列表(可选)
|
||||
|
||||
Returns:
|
||||
模型文件夹路径
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
|
||||
# 获取模型文件夹路径
|
||||
model_dir = os.path.dirname(model_path)
|
||||
|
||||
# 1. 保存模型本身
|
||||
model.save(model_path)
|
||||
print(f"[模型保存] 模型已保存至: {model_path}")
|
||||
|
||||
# 2. 保存因子信息到 .factors.json 文件
|
||||
factors_path = model_path.replace(".pkl", ".factors.json")
|
||||
# 2. 保存因子信息到 factors.json 文件
|
||||
factors_path = os.path.join(model_dir, "factors.json")
|
||||
|
||||
factors_info = {
|
||||
"selected_factors": selected_factors,
|
||||
@@ -592,12 +602,22 @@ def save_model_with_factors(
|
||||
print(f" - 来自 metadata: {factors_info['selected_factors_count']} 个")
|
||||
print(f" - 来自表达式定义: {factors_info['factor_definitions_count']} 个")
|
||||
|
||||
# 3. 保存处理器(如果提供)
|
||||
if fitted_processors is not None:
|
||||
processors_path = os.path.join(model_dir, "processors.pkl")
|
||||
with open(processors_path, "wb") as f:
|
||||
pickle.dump(fitted_processors, f)
|
||||
print(f"[模型保存] 处理器已保存至: {processors_path}")
|
||||
print(f"[模型保存] 共 {len(fitted_processors)} 个处理器")
|
||||
|
||||
return model_dir
|
||||
|
||||
|
||||
def load_model_factors(model_path: str) -> Optional[dict]:
|
||||
"""加载模型关联的因子信息。
|
||||
|
||||
Args:
|
||||
model_path: 模型保存路径
|
||||
model_path: 模型保存路径(models/{model_type}/model.pkl)
|
||||
|
||||
Returns:
|
||||
包含因子信息的字典,如果文件不存在则返回 None
|
||||
@@ -605,7 +625,9 @@ def load_model_factors(model_path: str) -> Optional[dict]:
|
||||
import json
|
||||
import os
|
||||
|
||||
factors_path = model_path.replace(".pkl", ".factors.json")
|
||||
# 获取模型文件夹路径
|
||||
model_dir = os.path.dirname(model_path)
|
||||
factors_path = os.path.join(model_dir, "factors.json")
|
||||
|
||||
if not os.path.exists(factors_path):
|
||||
print(f"[警告] 未找到因子信息文件: {factors_path}")
|
||||
@@ -618,3 +640,30 @@ def load_model_factors(model_path: str) -> Optional[dict]:
|
||||
f"[模型加载] 已加载因子信息,总计 {factors_info.get('total_feature_count', 'N/A')} 个因子"
|
||||
)
|
||||
return factors_info
|
||||
|
||||
|
||||
def load_processors(model_path: str) -> list | None:
|
||||
"""加载模型关联的处理器。
|
||||
|
||||
Args:
|
||||
model_path: 模型保存路径(models/{model_type}/model.pkl)
|
||||
|
||||
Returns:
|
||||
处理器列表,如果文件不存在则返回 None
|
||||
"""
|
||||
import pickle
|
||||
import os
|
||||
|
||||
# 获取模型文件夹路径
|
||||
model_dir = os.path.dirname(model_path)
|
||||
processors_path = os.path.join(model_dir, "processors.pkl")
|
||||
|
||||
if not os.path.exists(processors_path):
|
||||
print(f"[警告] 未找到处理器文件: {processors_path}")
|
||||
return None
|
||||
|
||||
with open(processors_path, "rb") as f:
|
||||
fitted_processors = pickle.load(f)
|
||||
|
||||
print(f"[模型加载] 已加载 {len(fitted_processors)} 个处理器")
|
||||
return fitted_processors
|
||||
|
||||
@@ -603,6 +603,7 @@ if SAVE_MODEL:
|
||||
model_path=model_save_path,
|
||||
selected_factors=SELECTED_FACTORS,
|
||||
factor_definitions=FACTOR_DEFINITIONS,
|
||||
fitted_processors=fitted_processors,
|
||||
)
|
||||
|
||||
print("\n训练流程完成!")
|
||||
|
||||
413
src/experiment/predict.py
Normal file
413
src/experiment/predict.py
Normal file
@@ -0,0 +1,413 @@
|
||||
"""预测脚本 - 加载模型并对指定时间段进行预测。
|
||||
|
||||
支持两种模型类型:
|
||||
- regression: 回归模型
|
||||
- rank: 排序学习模型
|
||||
|
||||
脚本会自动分析 models 目录下的模型类型,用户只需指定模型类型和预测时间段。
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
|
||||
from src.factors import FactorEngine
|
||||
from src.training import (
|
||||
STFilter,
|
||||
StockPoolManager,
|
||||
Winsorizer,
|
||||
NullFiller,
|
||||
StandardScaler,
|
||||
CrossSectionalStandardScaler,
|
||||
)
|
||||
from src.training.components.models import LightGBMModel, LightGBMLambdaRankModel
|
||||
from src.experiment.common import (
|
||||
get_label_factor,
|
||||
stock_pool_filter,
|
||||
STOCK_FILTER_REQUIRED_COLUMNS,
|
||||
OUTPUT_DIR,
|
||||
TOP_N,
|
||||
load_processors,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 配置区域 - 用户需要修改这些配置
|
||||
# =============================================================================
|
||||
# 模型类型: "regression" 或 "rank"
|
||||
MODEL_TYPE = "rank"
|
||||
|
||||
# 预测时间段(不从中读取,使用这里的配置)
|
||||
PREDICT_START = "20250101"
|
||||
PREDICT_END = "20261231"
|
||||
|
||||
# 数据回看窗口天数(用于计算时序因子,需要向前获取额外数据)
|
||||
# 例如:如果因子使用了 ts_mean(close, 60),则回看窗口至少为 60 天
|
||||
LOOKBACK_DAYS = 365 * 3 # 向前获取 1 年的数据确保所有因子都能正确计算
|
||||
|
||||
# 模型路径配置
|
||||
MODEL_SAVE_DIR = "models"
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def detect_model_type(models_dir: str) -> str:
|
||||
"""自动检测模型类型。
|
||||
|
||||
检查 models 目录下有哪些模型类型可用。
|
||||
|
||||
Args:
|
||||
models_dir: 模型保存目录
|
||||
|
||||
Returns:
|
||||
检测到的模型类型,如果多个则优先返回 regression
|
||||
"""
|
||||
available_types: List[str] = []
|
||||
|
||||
# 检查 regression 模型
|
||||
regression_path = os.path.join(models_dir, "regression", "model.pkl")
|
||||
if os.path.exists(regression_path):
|
||||
available_types.append("regression")
|
||||
|
||||
# 检查 rank 模型
|
||||
rank_path = os.path.join(models_dir, "rank", "model.pkl")
|
||||
if os.path.exists(rank_path):
|
||||
available_types.append("rank")
|
||||
|
||||
if not available_types:
|
||||
raise FileNotFoundError(
|
||||
f"未在 {models_dir} 目录下找到任何模型。"
|
||||
f"请确保模型已训练并保存在 models/regression/model.pkl 或 models/rank/model.pkl"
|
||||
)
|
||||
|
||||
print(f"[模型检测] 可用的模型类型: {available_types}")
|
||||
|
||||
# 如果用户指定的类型可用,直接返回
|
||||
if MODEL_TYPE in available_types:
|
||||
return MODEL_TYPE
|
||||
|
||||
# 如果用户未指定或指定的不可用,返回第一个可用的
|
||||
print(f"[模型检测] 使用默认模型类型: {available_types[0]}")
|
||||
return available_types[0]
|
||||
|
||||
|
||||
def load_model_and_factors(
|
||||
model_type: str, models_dir: str
|
||||
) -> Tuple[Any, Dict[str, Any], List[str], str]:
|
||||
"""加载模型及其关联的因子信息。
|
||||
|
||||
Args:
|
||||
model_type: 模型类型("regression" 或 "rank")
|
||||
models_dir: 模型保存目录
|
||||
|
||||
Returns:
|
||||
(model, factors_info, feature_cols, model_path) 元组
|
||||
"""
|
||||
model_path = os.path.join(models_dir, model_type, "model.pkl")
|
||||
factors_path = os.path.join(models_dir, model_type, "factors.json")
|
||||
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"加载模型: {model_type}")
|
||||
print(f"{'=' * 80}")
|
||||
|
||||
# 检查模型文件是否存在
|
||||
if not os.path.exists(model_path):
|
||||
raise FileNotFoundError(f"模型文件不存在: {model_path}")
|
||||
|
||||
# 加载模型(根据模型类型选择正确的加载方法)
|
||||
print(f"[模型加载] 正在加载模型: {model_path}")
|
||||
|
||||
if model_type == "regression":
|
||||
model = LightGBMModel.load(model_path)
|
||||
elif model_type == "rank":
|
||||
model = LightGBMLambdaRankModel.load(model_path)
|
||||
else:
|
||||
raise ValueError(f"不支持的模型类型: {model_type}")
|
||||
|
||||
print(f"[模型加载] 已加载模型: {model_path}")
|
||||
print(f"[模型加载] 模型类型: {type(model).__name__}")
|
||||
|
||||
# 加载因子信息
|
||||
if not os.path.exists(factors_path):
|
||||
raise FileNotFoundError(f"因子信息文件不存在: {factors_path}")
|
||||
|
||||
with open(factors_path, "r", encoding="utf-8") as f:
|
||||
factors_info = json.load(f)
|
||||
|
||||
print(f"[因子加载] 已加载因子信息: {factors_path}")
|
||||
print(f"[因子加载] 因子总数: {factors_info.get('total_feature_count', 'N/A')}")
|
||||
print(
|
||||
f"[因子加载] 来自 metadata: {factors_info.get('selected_factors_count', 'N/A')}"
|
||||
)
|
||||
print(
|
||||
f"[因子加载] 来自表达式: {factors_info.get('factor_definitions_count', 'N/A')}"
|
||||
)
|
||||
|
||||
# 构建特征列列表
|
||||
selected_factors = factors_info.get("selected_factors", [])
|
||||
factor_definitions = factors_info.get("factor_definitions", {})
|
||||
feature_cols = selected_factors + list(factor_definitions.keys())
|
||||
|
||||
print(f"[特征列] 共 {len(feature_cols)} 个特征")
|
||||
|
||||
return model, factors_info, feature_cols, model_path
|
||||
|
||||
|
||||
def get_lookback_start_date(start_date: str, lookback_days: int) -> str:
|
||||
"""计算考虑回看窗口后的实际开始日期。
|
||||
|
||||
为了确保时序因子(如 ts_mean(close, 20))在预测开始日期能正确计算,
|
||||
需要向前获取额外的历史数据。
|
||||
|
||||
Args:
|
||||
start_date: 预测开始日期 (YYYYMMDD)
|
||||
lookback_days: 回看窗口天数
|
||||
|
||||
Returns:
|
||||
考虑回看后的实际开始日期 (YYYYMMDD)
|
||||
"""
|
||||
start_dt = datetime.strptime(start_date, "%Y%m%d")
|
||||
lookback_dt = start_dt - timedelta(days=lookback_days)
|
||||
return lookback_dt.strftime("%Y%m%d")
|
||||
|
||||
|
||||
def prepare_data(
|
||||
engine: FactorEngine,
|
||||
feature_cols: list[str],
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
label_name: str,
|
||||
) -> pl.DataFrame:
|
||||
"""准备预测数据。
|
||||
|
||||
Args:
|
||||
engine: FactorEngine实例
|
||||
feature_cols: 特征列名称列表
|
||||
start_date: 开始日期 (YYYYMMDD)
|
||||
end_date: 结束日期 (YYYYMMDD)
|
||||
label_name: label列名称
|
||||
|
||||
Returns:
|
||||
包含因子计算结果的数据框
|
||||
"""
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"准备数据: {start_date} - {end_date}")
|
||||
print(f"{'=' * 80}")
|
||||
|
||||
factor_names = feature_cols + [label_name]
|
||||
|
||||
data = engine.compute(
|
||||
factor_names=factor_names,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
|
||||
print(f"数据形状: {data.shape}")
|
||||
print(f"前5行预览:")
|
||||
print(data.head())
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def apply_processors(data: pl.DataFrame, fitted_processors: List[Any]) -> pl.DataFrame:
|
||||
"""应用已拟合的数据处理器。
|
||||
|
||||
Args:
|
||||
data: 输入数据
|
||||
fitted_processors: 已拟合的处理器列表(从训练时保存)
|
||||
|
||||
Returns:
|
||||
处理后的数据
|
||||
"""
|
||||
print(f"\n{'=' * 80}")
|
||||
print("应用数据处理器(使用训练时保存的参数)")
|
||||
print(f"{'=' * 80}")
|
||||
|
||||
for i, processor in enumerate(fitted_processors, 1):
|
||||
print(f" [{i}/{len(fitted_processors)}] {processor.__class__.__name__}")
|
||||
data = processor.transform(data) # 使用 transform,不是 fit_transform
|
||||
|
||||
print(f"处理后数据形状: {data.shape}")
|
||||
return data
|
||||
|
||||
|
||||
def predict():
|
||||
"""主预测流程。"""
|
||||
print("\n" + "=" * 80)
|
||||
print("ProStock 预测脚本")
|
||||
print("=" * 80)
|
||||
|
||||
# 1. 自动检测模型类型
|
||||
model_type = detect_model_type(MODEL_SAVE_DIR)
|
||||
print(f"\n[配置] 使用模型类型: {model_type}")
|
||||
print(f"[配置] 预测时间段: {PREDICT_START} - {PREDICT_END}")
|
||||
|
||||
# 2. 加载模型和因子信息
|
||||
model, factors_info, feature_cols, model_path = load_model_and_factors(
|
||||
model_type, MODEL_SAVE_DIR
|
||||
)
|
||||
|
||||
# 提取因子配置
|
||||
selected_factors = factors_info.get("selected_factors", [])
|
||||
factor_definitions = factors_info.get("factor_definitions", {})
|
||||
|
||||
# 3. 创建 FactorEngine
|
||||
print("\n[1] 创建 FactorEngine")
|
||||
engine = FactorEngine()
|
||||
|
||||
# 4. 注册因子
|
||||
print("\n[2] 注册因子")
|
||||
label_name = "future_return_5"
|
||||
label_factor = get_label_factor(label_name)
|
||||
|
||||
# 注册来自 metadata 的因子
|
||||
print(" 注册 metadata 因子:")
|
||||
for name in selected_factors:
|
||||
engine.add_factor(name)
|
||||
print(f" - {name}")
|
||||
|
||||
# 注册表达式因子
|
||||
print(" 注册表达式因子:")
|
||||
for name, expr in factor_definitions.items():
|
||||
engine.add_factor(name, expr)
|
||||
print(f" - {name}: {expr}")
|
||||
|
||||
# 注册 label 因子
|
||||
print(" 注册 Label 因子:")
|
||||
for name, expr in label_factor.items():
|
||||
engine.add_factor(name, expr)
|
||||
print(f" - {name}: {expr}")
|
||||
|
||||
# 5. 准备数据(考虑回看窗口)
|
||||
print(f"\n[数据准备] 预测时间段: {PREDICT_START} - {PREDICT_END}")
|
||||
print(f"[数据准备] 回看窗口: {LOOKBACK_DAYS} 天")
|
||||
|
||||
actual_start = get_lookback_start_date(PREDICT_START, LOOKBACK_DAYS)
|
||||
print(f"[数据准备] 实际加载数据时间段: {actual_start} - {PREDICT_END}")
|
||||
|
||||
data = prepare_data(
|
||||
engine=engine,
|
||||
feature_cols=feature_cols,
|
||||
start_date=actual_start,
|
||||
end_date=PREDICT_END,
|
||||
label_name=label_name,
|
||||
)
|
||||
|
||||
# 过滤回看数据,只保留预测日期范围内的数据
|
||||
print(f"[数据准备] 过滤回看数据,保留 {PREDICT_START} 之后的数据...")
|
||||
data = data.filter(data["trade_date"] >= PREDICT_START)
|
||||
print(f"[数据准备] 过滤后数据形状: {data.shape}")
|
||||
|
||||
# 6. 股票池筛选
|
||||
print("\n[3] 股票池筛选")
|
||||
pool_manager = StockPoolManager(
|
||||
filter_func=stock_pool_filter,
|
||||
required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
|
||||
data_router=engine.router,
|
||||
)
|
||||
|
||||
st_filter = STFilter(data_router=engine.router)
|
||||
|
||||
# 先执行 ST 过滤
|
||||
if st_filter:
|
||||
print(" 应用 ST 过滤器...")
|
||||
data = st_filter.filter(data)
|
||||
print(f" ST 过滤后数据规模: {data.shape}")
|
||||
|
||||
# 股票池筛选
|
||||
print(" 执行每日股票池筛选...")
|
||||
filtered_data = pool_manager.filter_and_select_daily(data)
|
||||
print(f" 筛选前数据规模: {data.shape}")
|
||||
print(f" 筛选后数据规模: {filtered_data.shape}")
|
||||
print(f" 筛选前股票数: {data['ts_code'].n_unique()}")
|
||||
print(f" 筛选后股票数: {filtered_data['ts_code'].n_unique()}")
|
||||
|
||||
# 7. 加载并应用数据处理器
|
||||
print("\n[3.5] 加载数据处理器")
|
||||
model_dir = os.path.dirname(model_path)
|
||||
fitted_processors = load_processors(model_path)
|
||||
|
||||
if fitted_processors is None:
|
||||
raise FileNotFoundError(
|
||||
f"未找到处理器文件,请确保模型已正确训练并保存处理器到 {model_dir}/processors.pkl"
|
||||
)
|
||||
|
||||
processed_data = apply_processors(filtered_data, fitted_processors)
|
||||
|
||||
# 8. 生成预测
|
||||
print("\n[4] 生成预测")
|
||||
print("-" * 60)
|
||||
X = processed_data.select(feature_cols)
|
||||
print(f" 预测样本数: {len(X)}")
|
||||
print(f" 特征数: {len(feature_cols)}")
|
||||
|
||||
predictions = model.predict(X)
|
||||
print(f" 预测完成!")
|
||||
|
||||
print(f"\n 预测结果统计:")
|
||||
print(f" 均值: {predictions.mean():.6f}")
|
||||
print(f" 标准差: {predictions.std():.6f}")
|
||||
print(f" 最小值: {predictions.min():.6f}")
|
||||
print(f" 最大值: {predictions.max():.6f}")
|
||||
|
||||
# 添加预测列
|
||||
processed_data = processed_data.with_columns([pl.Series("prediction", predictions)])
|
||||
|
||||
# 9. 保存结果
|
||||
print("\n[5] 保存预测结果")
|
||||
print("-" * 60)
|
||||
|
||||
# 确保输出目录存在
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
# 生成输出文件名
|
||||
start_dt = datetime.strptime(PREDICT_START, "%Y%m%d")
|
||||
end_dt = datetime.strptime(PREDICT_END, "%Y%m%d")
|
||||
date_str = f"{start_dt.strftime('%Y%m%d')}_{end_dt.strftime('%Y%m%d')}"
|
||||
|
||||
# 保存每日 Top N
|
||||
print(f" 保存每日 Top {TOP_N} 股票...")
|
||||
output_path = os.path.join(OUTPUT_DIR, "predict_output.csv")
|
||||
|
||||
# 按日期分组,取每日 top N
|
||||
topn_by_date = []
|
||||
unique_dates = processed_data["trade_date"].unique().sort()
|
||||
for date in unique_dates:
|
||||
day_data = processed_data.filter(processed_data["trade_date"] == date)
|
||||
# 按 prediction 降序排序,取前 N
|
||||
topn = day_data.sort("prediction", descending=True).head(TOP_N)
|
||||
topn_by_date.append(topn)
|
||||
|
||||
# 合并所有日期的 top N
|
||||
topn_results = pl.concat(topn_by_date)
|
||||
|
||||
# 格式化日期并调整列顺序:日期、分数、股票
|
||||
topn_to_save = topn_results.select(
|
||||
[
|
||||
pl.col("trade_date").str.slice(0, 4)
|
||||
+ "-"
|
||||
+ pl.col("trade_date").str.slice(4, 2)
|
||||
+ "-"
|
||||
+ pl.col("trade_date").str.slice(6, 2).alias("date"),
|
||||
pl.col("prediction").alias("score"),
|
||||
pl.col("ts_code"),
|
||||
]
|
||||
)
|
||||
topn_to_save.write_csv(output_path, include_header=True)
|
||||
print(f" 保存路径: {output_path}")
|
||||
print(
|
||||
f" 保存行数: {len(topn_to_save)}({len(unique_dates)}个交易日 × 每日top{TOP_N})"
|
||||
)
|
||||
print(f"\n 预览(前15行):")
|
||||
print(topn_to_save.head(15))
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("预测完成!")
|
||||
print("=" * 80)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
predict()
|
||||
@@ -598,4 +598,5 @@ if SAVE_MODEL:
|
||||
model_path=model_save_path,
|
||||
selected_factors=SELECTED_FACTORS,
|
||||
factor_definitions=FACTOR_DEFINITIONS,
|
||||
fitted_processors=fitted_processors,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user