Files
ProStock/src/experiment/predict.py
liaozhaorun ccd42082c2 refactor(experiment): 重构模型保存机制,支持 processors 持久化
- 模型保存路径改为 models/{model_type}/ 目录结构
- save_model_with_factors 新增 fitted_processors 参数
- 新增 load_processors 函数加载处理器状态
- Storage 查询排序优化:ORDER BY ts_code, trade_date
2026-03-19 21:06:11 +08:00

414 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""预测脚本 - 加载模型并对指定时间段进行预测。
支持两种模型类型:
- regression: 回归模型
- rank: 排序学习模型
脚本会自动分析 models 目录下的模型类型,用户只需指定模型类型和预测时间段。
"""
import json
import os
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import polars as pl
from src.factors import FactorEngine
from src.training import (
STFilter,
StockPoolManager,
Winsorizer,
NullFiller,
StandardScaler,
CrossSectionalStandardScaler,
)
from src.training.components.models import LightGBMModel, LightGBMLambdaRankModel
from src.experiment.common import (
get_label_factor,
stock_pool_filter,
STOCK_FILTER_REQUIRED_COLUMNS,
OUTPUT_DIR,
TOP_N,
load_processors,
)
# =============================================================================
# 配置区域 - 用户需要修改这些配置
# =============================================================================
# 模型类型: "regression" 或 "rank"
MODEL_TYPE = "rank"
# 预测时间段(不从中读取,使用这里的配置)
PREDICT_START = "20250101"
PREDICT_END = "20261231"
# 数据回看窗口天数(用于计算时序因子,需要向前获取额外数据)
# 例如:如果因子使用了 ts_mean(close, 60),则回看窗口至少为 60 天
LOOKBACK_DAYS = 365 * 3 # 向前获取 1 年的数据确保所有因子都能正确计算
# 模型路径配置
MODEL_SAVE_DIR = "models"
# =============================================================================
def detect_model_type(models_dir: str) -> str:
"""自动检测模型类型。
检查 models 目录下有哪些模型类型可用。
Args:
models_dir: 模型保存目录
Returns:
检测到的模型类型,如果多个则优先返回 regression
"""
available_types: List[str] = []
# 检查 regression 模型
regression_path = os.path.join(models_dir, "regression", "model.pkl")
if os.path.exists(regression_path):
available_types.append("regression")
# 检查 rank 模型
rank_path = os.path.join(models_dir, "rank", "model.pkl")
if os.path.exists(rank_path):
available_types.append("rank")
if not available_types:
raise FileNotFoundError(
f"未在 {models_dir} 目录下找到任何模型。"
f"请确保模型已训练并保存在 models/regression/model.pkl 或 models/rank/model.pkl"
)
print(f"[模型检测] 可用的模型类型: {available_types}")
# 如果用户指定的类型可用,直接返回
if MODEL_TYPE in available_types:
return MODEL_TYPE
# 如果用户未指定或指定的不可用,返回第一个可用的
print(f"[模型检测] 使用默认模型类型: {available_types[0]}")
return available_types[0]
def load_model_and_factors(
model_type: str, models_dir: str
) -> Tuple[Any, Dict[str, Any], List[str], str]:
"""加载模型及其关联的因子信息。
Args:
model_type: 模型类型("regression""rank"
models_dir: 模型保存目录
Returns:
(model, factors_info, feature_cols, model_path) 元组
"""
model_path = os.path.join(models_dir, model_type, "model.pkl")
factors_path = os.path.join(models_dir, model_type, "factors.json")
print(f"\n{'=' * 80}")
print(f"加载模型: {model_type}")
print(f"{'=' * 80}")
# 检查模型文件是否存在
if not os.path.exists(model_path):
raise FileNotFoundError(f"模型文件不存在: {model_path}")
# 加载模型(根据模型类型选择正确的加载方法)
print(f"[模型加载] 正在加载模型: {model_path}")
if model_type == "regression":
model = LightGBMModel.load(model_path)
elif model_type == "rank":
model = LightGBMLambdaRankModel.load(model_path)
else:
raise ValueError(f"不支持的模型类型: {model_type}")
print(f"[模型加载] 已加载模型: {model_path}")
print(f"[模型加载] 模型类型: {type(model).__name__}")
# 加载因子信息
if not os.path.exists(factors_path):
raise FileNotFoundError(f"因子信息文件不存在: {factors_path}")
with open(factors_path, "r", encoding="utf-8") as f:
factors_info = json.load(f)
print(f"[因子加载] 已加载因子信息: {factors_path}")
print(f"[因子加载] 因子总数: {factors_info.get('total_feature_count', 'N/A')}")
print(
f"[因子加载] 来自 metadata: {factors_info.get('selected_factors_count', 'N/A')}"
)
print(
f"[因子加载] 来自表达式: {factors_info.get('factor_definitions_count', 'N/A')}"
)
# 构建特征列列表
selected_factors = factors_info.get("selected_factors", [])
factor_definitions = factors_info.get("factor_definitions", {})
feature_cols = selected_factors + list(factor_definitions.keys())
print(f"[特征列] 共 {len(feature_cols)} 个特征")
return model, factors_info, feature_cols, model_path
def get_lookback_start_date(start_date: str, lookback_days: int) -> str:
"""计算考虑回看窗口后的实际开始日期。
为了确保时序因子(如 ts_mean(close, 20))在预测开始日期能正确计算,
需要向前获取额外的历史数据。
Args:
start_date: 预测开始日期 (YYYYMMDD)
lookback_days: 回看窗口天数
Returns:
考虑回看后的实际开始日期 (YYYYMMDD)
"""
start_dt = datetime.strptime(start_date, "%Y%m%d")
lookback_dt = start_dt - timedelta(days=lookback_days)
return lookback_dt.strftime("%Y%m%d")
def prepare_data(
engine: FactorEngine,
feature_cols: list[str],
start_date: str,
end_date: str,
label_name: str,
) -> pl.DataFrame:
"""准备预测数据。
Args:
engine: FactorEngine实例
feature_cols: 特征列名称列表
start_date: 开始日期 (YYYYMMDD)
end_date: 结束日期 (YYYYMMDD)
label_name: label列名称
Returns:
包含因子计算结果的数据框
"""
print(f"\n{'=' * 80}")
print(f"准备数据: {start_date} - {end_date}")
print(f"{'=' * 80}")
factor_names = feature_cols + [label_name]
data = engine.compute(
factor_names=factor_names,
start_date=start_date,
end_date=end_date,
)
print(f"数据形状: {data.shape}")
print(f"前5行预览:")
print(data.head())
return data
def apply_processors(data: pl.DataFrame, fitted_processors: List[Any]) -> pl.DataFrame:
"""应用已拟合的数据处理器。
Args:
data: 输入数据
fitted_processors: 已拟合的处理器列表(从训练时保存)
Returns:
处理后的数据
"""
print(f"\n{'=' * 80}")
print("应用数据处理器(使用训练时保存的参数)")
print(f"{'=' * 80}")
for i, processor in enumerate(fitted_processors, 1):
print(f" [{i}/{len(fitted_processors)}] {processor.__class__.__name__}")
data = processor.transform(data) # 使用 transform不是 fit_transform
print(f"处理后数据形状: {data.shape}")
return data
def predict():
"""主预测流程。"""
print("\n" + "=" * 80)
print("ProStock 预测脚本")
print("=" * 80)
# 1. 自动检测模型类型
model_type = detect_model_type(MODEL_SAVE_DIR)
print(f"\n[配置] 使用模型类型: {model_type}")
print(f"[配置] 预测时间段: {PREDICT_START} - {PREDICT_END}")
# 2. 加载模型和因子信息
model, factors_info, feature_cols, model_path = load_model_and_factors(
model_type, MODEL_SAVE_DIR
)
# 提取因子配置
selected_factors = factors_info.get("selected_factors", [])
factor_definitions = factors_info.get("factor_definitions", {})
# 3. 创建 FactorEngine
print("\n[1] 创建 FactorEngine")
engine = FactorEngine()
# 4. 注册因子
print("\n[2] 注册因子")
label_name = "future_return_5"
label_factor = get_label_factor(label_name)
# 注册来自 metadata 的因子
print(" 注册 metadata 因子:")
for name in selected_factors:
engine.add_factor(name)
print(f" - {name}")
# 注册表达式因子
print(" 注册表达式因子:")
for name, expr in factor_definitions.items():
engine.add_factor(name, expr)
print(f" - {name}: {expr}")
# 注册 label 因子
print(" 注册 Label 因子:")
for name, expr in label_factor.items():
engine.add_factor(name, expr)
print(f" - {name}: {expr}")
# 5. 准备数据(考虑回看窗口)
print(f"\n[数据准备] 预测时间段: {PREDICT_START} - {PREDICT_END}")
print(f"[数据准备] 回看窗口: {LOOKBACK_DAYS}")
actual_start = get_lookback_start_date(PREDICT_START, LOOKBACK_DAYS)
print(f"[数据准备] 实际加载数据时间段: {actual_start} - {PREDICT_END}")
data = prepare_data(
engine=engine,
feature_cols=feature_cols,
start_date=actual_start,
end_date=PREDICT_END,
label_name=label_name,
)
# 过滤回看数据,只保留预测日期范围内的数据
print(f"[数据准备] 过滤回看数据,保留 {PREDICT_START} 之后的数据...")
data = data.filter(data["trade_date"] >= PREDICT_START)
print(f"[数据准备] 过滤后数据形状: {data.shape}")
# 6. 股票池筛选
print("\n[3] 股票池筛选")
pool_manager = StockPoolManager(
filter_func=stock_pool_filter,
required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
data_router=engine.router,
)
st_filter = STFilter(data_router=engine.router)
# 先执行 ST 过滤
if st_filter:
print(" 应用 ST 过滤器...")
data = st_filter.filter(data)
print(f" ST 过滤后数据规模: {data.shape}")
# 股票池筛选
print(" 执行每日股票池筛选...")
filtered_data = pool_manager.filter_and_select_daily(data)
print(f" 筛选前数据规模: {data.shape}")
print(f" 筛选后数据规模: {filtered_data.shape}")
print(f" 筛选前股票数: {data['ts_code'].n_unique()}")
print(f" 筛选后股票数: {filtered_data['ts_code'].n_unique()}")
# 7. 加载并应用数据处理器
print("\n[3.5] 加载数据处理器")
model_dir = os.path.dirname(model_path)
fitted_processors = load_processors(model_path)
if fitted_processors is None:
raise FileNotFoundError(
f"未找到处理器文件,请确保模型已正确训练并保存处理器到 {model_dir}/processors.pkl"
)
processed_data = apply_processors(filtered_data, fitted_processors)
# 8. 生成预测
print("\n[4] 生成预测")
print("-" * 60)
X = processed_data.select(feature_cols)
print(f" 预测样本数: {len(X)}")
print(f" 特征数: {len(feature_cols)}")
predictions = model.predict(X)
print(f" 预测完成!")
print(f"\n 预测结果统计:")
print(f" 均值: {predictions.mean():.6f}")
print(f" 标准差: {predictions.std():.6f}")
print(f" 最小值: {predictions.min():.6f}")
print(f" 最大值: {predictions.max():.6f}")
# 添加预测列
processed_data = processed_data.with_columns([pl.Series("prediction", predictions)])
# 9. 保存结果
print("\n[5] 保存预测结果")
print("-" * 60)
# 确保输出目录存在
os.makedirs(OUTPUT_DIR, exist_ok=True)
# 生成输出文件名
start_dt = datetime.strptime(PREDICT_START, "%Y%m%d")
end_dt = datetime.strptime(PREDICT_END, "%Y%m%d")
date_str = f"{start_dt.strftime('%Y%m%d')}_{end_dt.strftime('%Y%m%d')}"
# 保存每日 Top N
print(f" 保存每日 Top {TOP_N} 股票...")
output_path = os.path.join(OUTPUT_DIR, "predict_output.csv")
# 按日期分组,取每日 top N
topn_by_date = []
unique_dates = processed_data["trade_date"].unique().sort()
for date in unique_dates:
day_data = processed_data.filter(processed_data["trade_date"] == date)
# 按 prediction 降序排序,取前 N
topn = day_data.sort("prediction", descending=True).head(TOP_N)
topn_by_date.append(topn)
# 合并所有日期的 top N
topn_results = pl.concat(topn_by_date)
# 格式化日期并调整列顺序:日期、分数、股票
topn_to_save = topn_results.select(
[
pl.col("trade_date").str.slice(0, 4)
+ "-"
+ pl.col("trade_date").str.slice(4, 2)
+ "-"
+ pl.col("trade_date").str.slice(6, 2).alias("date"),
pl.col("prediction").alias("score"),
pl.col("ts_code"),
]
)
topn_to_save.write_csv(output_path, include_header=True)
print(f" 保存路径: {output_path}")
print(
f" 保存行数: {len(topn_to_save)}{len(unique_dates)}个交易日 × 每日top{TOP_N}"
)
print(f"\n 预览前15行:")
print(topn_to_save.head(15))
print("\n" + "=" * 80)
print("预测完成!")
print("=" * 80)
if __name__ == "__main__":
predict()