refactor(training): 重构训练管道,支持灵活因子配置

- 添加 FactorConfig 类,支持链式 API 和动态因子列表
- 重构 prepare_data 使用 FactorEngine 计算因子,确保防泄露机制生效
- 新增 prepare_data_and_train / train_and_predict 分步执行接口
- 修复 factors/__init__.py docstring 语法错误
This commit is contained in:
2026-02-25 23:39:02 +08:00
parent a9e4746239
commit 990a77ec6c
2 changed files with 564 additions and 119 deletions

View File

@@ -5,26 +5,298 @@
或:
uv run python src/training/main.py
本脚本提供两种运行方式:
1. run_full_pipeline(): 完整训练流程(数据准备 -> 训练 -> 预测)
2. prepare_data_and_train() + train_and_predict(): 分步执行,便于调试和调整
因子配置示例:
from src.factors import MovingAverageFactor, ReturnRankFactor
# 直接传入因子实例列表 - 最简单的方式
factors = [
MovingAverageFactor(period=5),
MovingAverageFactor(period=10),
MovingAverageFactor(period=20),
ReturnRankFactor(period=5),
ReturnRankFactor(period=10),
]
# 运行完整流程
result = run_full_pipeline(factors=factors)
"""
from src.training.pipeline import run_training
from pathlib import Path
from typing import Optional, List
import polars as pl
from src.factors import BaseFactor
from src.training.pipeline import (
FactorConfig,
predict_top_stocks,
prepare_data,
save_top_stocks,
train_model,
)
def prepare_data_and_train(
    factors: Optional[List[BaseFactor]] = None,
    data_dir: str = "data",
    train_start: str = "20190101",
    train_end: str = "20231231",
    val_start: str = "20240102",
    val_end: str = "20240531",
    test_start: str = "20240602",
    test_end: str = "20241231",
) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame, FactorConfig, str]:
    """Step 1: prepare the train / validation / test data sets.

    Delegates the actual work to ``prepare_data`` and prints a short report:
    the date range of every split, the row count of every split, the first
    three rows of every split, and the feature / label column names.
    Despite its name, this function does not train a model.

    Args:
        factors: Factor instances forwarded unchanged to ``prepare_data``.
            ``None`` is forwarded as-is; per the original documentation the
            defaults are then MA5, MA10 and ReturnRank5 (that choice is made
            by ``prepare_data``, not here).
        data_dir: Data directory forwarded to ``prepare_data``.
        train_start: First date of the training split.
        train_end: Last date of the training split.
        val_start: First date of the validation split.
        val_end: Last date of the validation split.
        test_start: First date of the test split.
        test_end: Last date of the test split.

    Returns:
        ``(train_data, val_data, test_data, factor_config, label_col)``;
        ``label_col`` is always the string ``"label"``.
    """
    divider = "=" * 50

    print(divider)
    print("[Step 1] 数据处理")
    print(divider)
    print(f"训练集: {train_start} -> {train_end}")
    print(f"验证集: {val_start} -> {val_end}")
    print(f"测试集: {test_start} -> {test_end}")
    print()

    # 1. Build the three splits and the factor configuration.
    train_set, val_set, test_set, config = prepare_data(
        factors=factors,
        data_dir=data_dir,
        train_start=train_start,
        train_end=train_end,
        val_start=val_start,
        val_end=val_end,
        test_start=test_start,
        test_end=test_end,
    )

    print(f"训练集样本数: {len(train_set)}")
    print(f"验证集样本数: {len(val_set)}")
    print(f"测试集样本数: {len(test_set)}")
    print()

    # Show a few rows of each split so the caller can eyeball the data.
    print(divider)
    previews = (
        ("[数据预览] 训练集前3行:", train_set),
        ("[数据预览] 验证集前3行:", val_set),
        ("[数据预览] 测试集前3行:", test_set),
    )
    for heading, frame in previews:
        print(heading)
        print(frame.head(3))
        print()

    # 2. Report the feature columns and the (fixed) label column.
    feature_names = config.get_feature_names()
    label_col = "label"
    print(f"特征列: {feature_names}")
    print(f"标签列: {label_col}")
    print()

    return train_set, val_set, test_set, config, label_col
def train_and_predict(
    train_data: pl.DataFrame,
    val_data: pl.DataFrame,
    test_data: pl.DataFrame,
    factor_config: FactorConfig,
    label_col: str = "label",
    top_n: int = 5,
    output_path: str = "output/top_stocks.tsv",
) -> pl.DataFrame:
    """Step 2: train the model, predict on the test set and save the picks.

    Args:
        train_data: Training data passed to ``train_model``.
        val_data: Validation data passed to ``train_model``.
        test_data: Test data passed to ``predict_top_stocks``.
        factor_config: Factor configuration; its ``get_feature_names()``
            result is used as the feature column list.
        label_col: Name of the label column.
        top_n: Number of stocks to select, forwarded to
            ``predict_top_stocks``.
        output_path: File the selection is written to; its parent directory
            is created when missing.

    Returns:
        The selection returned by ``predict_top_stocks``.
    """
    divider = "=" * 50
    print(divider)
    print("[Step 2] 模型训练与预测")
    print(divider)
    print()

    # Feature columns come from the factor configuration.
    features = factor_config.get_feature_names()
    print(f"使用特征: {features}")
    print()

    # 3. Fit the model on the training / validation splits.
    print("[Training] Training model...")
    fitted_model, fitted_pipeline = train_model(
        train_data=train_data,
        val_data=val_data,
        feature_cols=features,
        label_col=label_col,
    )
    print()

    # 4. Score the test split and keep the top picks.
    print("[Predict] Predicting on test set...")
    picks = predict_top_stocks(
        model=fitted_model,
        pipeline=fitted_pipeline,
        test_data=test_data,
        feature_cols=features,
        top_n=top_n,
    )
    print()

    # 5. Persist the result, creating the output directory first.
    print(f"[Saving] Saving results to {output_path}...")
    target_dir = Path(output_path).parent
    target_dir.mkdir(parents=True, exist_ok=True)
    save_top_stocks(picks, output_path)
    print()

    return picks
def run_full_pipeline(
    factors: Optional[List[BaseFactor]] = None,
    train_start: str = "20190101",
    train_end: str = "20231231",
    val_start: str = "20240102",
    val_end: str = "20240531",
    test_start: str = "20240602",
    test_end: str = "20241231",
    top_n: int = 5,
    output_path: str = "output/top_stocks.tsv",
    data_dir: str = "data",
) -> pl.DataFrame:
    """Run the complete training flow in one call.

    Equivalent to calling ``prepare_data_and_train`` followed by
    ``train_and_predict`` with the values returned by the first step.

    Args:
        factors: Factor instances forwarded to ``prepare_data_and_train``.
            ``None`` is forwarded as-is; per the original documentation the
            defaults are then MA5, MA10 and ReturnRank5.
        train_start: First date of the training split.
        train_end: Last date of the training split.
        val_start: First date of the validation split.
        val_end: Last date of the validation split.
        test_start: First date of the test split.
        test_end: Last date of the test split.
        top_n: Number of stocks to select.
        output_path: File the selection is written to.
        data_dir: Data directory forwarded to ``prepare_data_and_train``.
            Appended as the last parameter so existing positional callers
            keep working; the default matches that function's own default.

    Returns:
        The selection returned by ``train_and_predict``.
    """
    # Step 1: data preparation.
    train_data, val_data, test_data, factor_config, label_col = prepare_data_and_train(
        factors=factors,
        data_dir=data_dir,
        train_start=train_start,
        train_end=train_end,
        val_start=val_start,
        val_end=val_end,
        test_start=test_start,
        test_end=test_end,
    )

    # Step 2: training, prediction and saving.
    result = train_and_predict(
        train_data=train_data,
        val_data=val_data,
        test_data=test_data,
        factor_config=factor_config,
        label_col=label_col,
        top_n=top_n,
        output_path=output_path,
    )

    print("=" * 50)
    print("[Done] 训练流程完成!")
    print("=" * 50)

    return result
if __name__ == "__main__":
    # Date ranges used below:
    #   training set:   20190101 - 20231231
    #   validation set: 20240102 - 20240531 (one-day gap after training to avoid leakage)
    #   test set:       20240602 - 20241231 (one-day gap after validation to avoid leakage)
    #
    # NOTE(review): the rendered diff interleaved removed lines here (an
    # unclosed ``run_training(`` call, ``top_n`` / ``output_path`` keyword
    # arguments that ``prepare_data_and_train`` does not accept, and a
    # ``print(result)`` with no ``result`` assigned). They are dropped so the
    # script is valid again — confirm against the committed file.
    from src.factors import MovingAverageFactor, ReturnRankFactor

    # ========== Factor configuration ==========
    # Pass factor instances directly as a list.
    factors = [
        MovingAverageFactor(period=5),  # 5-day moving average
        MovingAverageFactor(period=10),  # 10-day moving average
        MovingAverageFactor(period=20),  # 20-day moving average
        ReturnRankFactor(period=5),  # 5-day return rank
        ReturnRankFactor(period=10),  # 10-day return rank
    ]

    # ========== How to run ==========
    # Option 1: full pipeline in a single call
    # result = run_full_pipeline(
    #     factors=factors,
    #     train_start="20190101",
    #     train_end="20231231",
    #     val_start="20240102",
    #     val_end="20240531",
    #     test_start="20240602",
    #     test_end="20241231",
    #     top_n=5,
    #     output_path="output/top_stocks.tsv",
    # )

    # Option 2: step by step (easier to debug)
    # Step 1: data preparation
    train_data, val_data, test_data, factor_config, label_col = prepare_data_and_train(
        factors=factors,
        train_start="20190101",
        train_end="20231231",
        val_start="20240102",
        val_end="20240531",
        test_start="20240602",
        test_end="20241231",
    )

    # Custom logic can be added here, for example:
    # - inspect the data distribution
    # - adjust the features
    # - save intermediate results
    print("\n[Info] 因子配置详情:")
    print(f" 因子列表: {factor_config.get_feature_names()}")
    print(f" 最大回溯天数: {factor_config.get_max_lookback()}")

    # Step 2: training and prediction
    # result = train_and_predict(
    #     train_data=train_data,
    #     val_data=val_data,
    #     test_data=test_data,
    #     factor_config=factor_config,
    #     label_col=label_col,
    #     top_n=5,
    #     output_path="output/top_stocks.tsv",
    # )
    #
    # print("\n[Result] Top stocks selection:")
    # print(result)