refactor(training): 重构训练管道,支持灵活因子配置
- 添加 FactorConfig 类,支持链式 API 和动态因子列表 - 重构 prepare_data 使用 FactorEngine 计算因子,确保防泄露机制生效 - 新增 prepare_data_and_train / train_and_predict 分步执行接口 - 修复 factors/__init__.py docstring 语法错误
This commit is contained in:
@@ -5,26 +5,298 @@
|
||||
|
||||
或:
|
||||
uv run python src/training/main.py
|
||||
|
||||
本脚本提供两种运行方式:
|
||||
1. run_full_pipeline(): 完整训练流程(数据准备 -> 训练 -> 预测)
|
||||
2. prepare_data_and_train() + train_and_predict(): 分步执行,便于调试和调整
|
||||
|
||||
因子配置示例:
|
||||
from src.factors import MovingAverageFactor, ReturnRankFactor
|
||||
|
||||
# 直接传入因子实例列表 - 最简单的方式
|
||||
factors = [
|
||||
MovingAverageFactor(period=5),
|
||||
MovingAverageFactor(period=10),
|
||||
MovingAverageFactor(period=20),
|
||||
ReturnRankFactor(period=5),
|
||||
ReturnRankFactor(period=10),
|
||||
]
|
||||
|
||||
# 运行完整流程
|
||||
result = run_full_pipeline(factors=factors)
|
||||
"""
|
||||
|
||||
from src.training.pipeline import run_training
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
|
||||
import polars as pl
|
||||
|
||||
from src.factors import BaseFactor
|
||||
from src.training.pipeline import (
|
||||
FactorConfig,
|
||||
predict_top_stocks,
|
||||
prepare_data,
|
||||
save_top_stocks,
|
||||
train_model,
|
||||
)
|
||||
|
||||
|
||||
def prepare_data_and_train(
    factors: Optional[List[BaseFactor]] = None,
    data_dir: str = "data",
    train_start: str = "20190101",
    train_end: str = "20231231",
    val_start: str = "20240102",
    val_end: str = "20240531",
    test_start: str = "20240602",
    test_end: str = "20241231",
) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame, FactorConfig, str]:
    """Step 1: build the train / validation / test datasets.

    Forwards every argument to ``prepare_data`` and prints a short report
    (date ranges, row counts, the first three rows of each split, and the
    feature / label column names). Despite its name, this function does not
    train a model; training happens in ``train_and_predict``.

    Args:
        factors: Factor instances handed to ``prepare_data``. ``None`` leaves
            the choice of defaults to ``prepare_data`` (the original notes
            mention MA5, MA10 and ReturnRank5 - confirm against the pipeline).
        data_dir: Data directory passed through to ``prepare_data``.
        train_start: First date of the training split.
        train_end: Last date of the training split.
        val_start: First date of the validation split.
        val_end: Last date of the validation split.
        test_start: First date of the test split.
        test_end: Last date of the test split.

    Returns:
        ``(train_data, val_data, test_data, factor_config, label_col)``, where
        ``label_col`` is always the string ``"label"``.
    """
    banner = "=" * 50

    print(banner)
    print("[Step 1] 数据处理")
    print(banner)
    for range_line in (
        f"训练集: {train_start} -> {train_end}",
        f"验证集: {val_start} -> {val_end}",
        f"测试集: {test_start} -> {test_end}",
    ):
        print(range_line)
    print()

    # 1. Compute the datasets; all arguments are forwarded unchanged.
    split_kwargs = dict(
        factors=factors,
        data_dir=data_dir,
        train_start=train_start,
        train_end=train_end,
        val_start=val_start,
        val_end=val_end,
        test_start=test_start,
        test_end=test_end,
    )
    train_data, val_data, test_data, factor_config = prepare_data(**split_kwargs)

    for count_line in (
        f"训练集样本数: {len(train_data)}",
        f"验证集样本数: {len(val_data)}",
        f"测试集样本数: {len(test_data)}",
    ):
        print(count_line)
    print()

    # Show a small sample of each split so the result can be eyeballed.
    print(banner)
    previews = (
        ("[数据预览] 训练集前3行:", train_data),
        ("[数据预览] 验证集前3行:", val_data),
        ("[数据预览] 测试集前3行:", test_data),
    )
    for heading, frame in previews:
        print(heading)
        print(frame.head(3))
        print()

    # 2. Report which columns are features and which one is the label.
    feature_cols = factor_config.get_feature_names()
    label_col = "label"

    print(f"特征列: {feature_cols}")
    print(f"标签列: {label_col}")
    print()

    return train_data, val_data, test_data, factor_config, label_col
|
||||
|
||||
|
||||
def train_and_predict(
    train_data: pl.DataFrame,
    val_data: pl.DataFrame,
    test_data: pl.DataFrame,
    factor_config: FactorConfig,
    label_col: str = "label",
    top_n: int = 5,
    output_path: str = "output/top_stocks.tsv",
) -> pl.DataFrame:
    """Step 2: train the model, predict on the test split, save the picks.

    The feature columns are taken from ``factor_config.get_feature_names()``.
    The model is fitted with ``train_model``, the selection is produced by
    ``predict_top_stocks`` and written out with ``save_top_stocks``. The
    parent directory of ``output_path`` is created if it does not exist.

    Args:
        train_data: Training split.
        val_data: Validation split.
        test_data: Test split to predict on.
        factor_config: Factor configuration that supplies the feature names.
        label_col: Name of the label column.
        top_n: Number of stocks to select (per day, according to the
            original documentation).
        output_path: File the selection is saved to.

    Returns:
        The stock selection returned by ``predict_top_stocks``.
    """
    banner = "=" * 50
    print(banner)
    print("[Step 2] 模型训练与预测")
    print(banner)
    print()

    # Resolve the feature columns once; training and prediction share them.
    feature_cols = factor_config.get_feature_names()
    print(f"使用特征: {feature_cols}")
    print()

    # 3. Fit the model.
    print("[Training] Training model...")
    fitted_model, fitted_pipeline = train_model(
        train_data=train_data,
        val_data=val_data,
        feature_cols=feature_cols,
        label_col=label_col,
    )
    print()

    # 4. Predict on the test split.
    print("[Predict] Predicting on test set...")
    top_stocks = predict_top_stocks(
        model=fitted_model,
        pipeline=fitted_pipeline,
        test_data=test_data,
        feature_cols=feature_cols,
        top_n=top_n,
    )
    print()

    # 5. Persist the result, creating the target directory first.
    print(f"[Saving] Saving results to {output_path}...")
    target_dir = Path(output_path).parent
    target_dir.mkdir(parents=True, exist_ok=True)
    save_top_stocks(top_stocks, output_path)
    print()

    return top_stocks
|
||||
|
||||
|
||||
def run_full_pipeline(
    factors: Optional[List[BaseFactor]] = None,
    train_start: str = "20190101",
    train_end: str = "20231231",
    val_start: str = "20240102",
    val_end: str = "20240531",
    test_start: str = "20240602",
    test_end: str = "20241231",
    top_n: int = 5,
    output_path: str = "output/top_stocks.tsv",
    data_dir: str = "data",
) -> pl.DataFrame:
    """Run the whole flow: data preparation, training, prediction, saving.

    Equivalent to calling ``prepare_data_and_train`` followed by
    ``train_and_predict`` with the values returned by the first step.

    Args:
        factors: Factor instances forwarded to the data-preparation step.
            ``None`` leaves the choice of defaults to that step (the original
            notes mention MA5, MA10 and ReturnRank5 - confirm against the
            pipeline).
        train_start: First date of the training split.
        train_end: Last date of the training split.
        val_start: First date of the validation split.
        val_end: Last date of the validation split.
        test_start: First date of the test split.
        test_end: Last date of the test split.
        top_n: Number of stocks to select (per day, according to the
            original documentation).
        output_path: File the selection is saved to.
        data_dir: Data directory forwarded to the data-preparation step.
            Placed last so existing positional callers keep working; the
            default matches the default of ``prepare_data_and_train``.

    Returns:
        The stock selection returned by ``train_and_predict``.
    """
    # Step 1: data preparation.
    train_data, val_data, test_data, factor_config, label_col = prepare_data_and_train(
        factors=factors,
        data_dir=data_dir,
        train_start=train_start,
        train_end=train_end,
        val_start=val_start,
        val_end=val_end,
        test_start=test_start,
        test_end=test_end,
    )

    # Step 2: training and prediction.
    result = train_and_predict(
        train_data=train_data,
        val_data=val_data,
        test_data=test_data,
        factor_config=factor_config,
        label_col=label_col,
        top_n=top_n,
        output_path=output_path,
    )

    print("=" * 50)
    print("[Done] 训练流程完成!")
    print("=" * 50)

    return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Date splits used below (YYYYMMDD):
    #   train:      20190101 - 20231231
    #   validation: 20240102 - 20240531 (one-day gap after the training set,
    #               to avoid data leakage)
    #   test:       20240602 - 20241231 (one-day gap after the validation set,
    #               to avoid data leakage)
    from src.factors import MovingAverageFactor, ReturnRankFactor

    # ========== Factor configuration ==========
    # Pass the factor instances directly as a list.
    factors = [
        MovingAverageFactor(period=5),  # 5-day moving average
        MovingAverageFactor(period=10),  # 10-day moving average
        MovingAverageFactor(period=20),  # 20-day moving average
        ReturnRankFactor(period=5),  # 5-day return rank
        ReturnRankFactor(period=10),  # 10-day return rank
    ]

    # ========== How to run ==========

    # Option 1: full pipeline in a single call.
    # result = run_full_pipeline(
    #     factors=factors,
    #     train_start="20190101",
    #     train_end="20231231",
    #     val_start="20240102",
    #     val_end="20240531",
    #     test_start="20240602",
    #     test_end="20241231",
    #     top_n=5,
    #     output_path="output/top_stocks.tsv",
    # )

    # Option 2: step-by-step execution (easier to debug).
    # Step 1: data preparation. Only the arguments that
    # prepare_data_and_train accepts are passed; top_n and output_path
    # belong to train_and_predict.
    train_data, val_data, test_data, factor_config, label_col = prepare_data_and_train(
        factors=factors,
        train_start="20190101",
        train_end="20231231",
        val_start="20240102",
        val_end="20240531",
        test_start="20240602",
        test_end="20241231",
    )

    # Custom logic can be added here, for example:
    # - inspect the data distribution
    # - adjust the features
    # - save intermediate results
    print("\n[Info] 因子配置详情:")
    print(f" 因子列表: {factor_config.get_feature_names()}")
    print(f" 最大回溯天数: {factor_config.get_max_lookback()}")

    # Step 2: training and prediction.
    # result = train_and_predict(
    #     train_data=train_data,
    #     val_data=val_data,
    #     test_data=test_data,
    #     factor_config=factor_config,
    #     label_col=label_col,
    #     top_n=5,
    #     output_path="output/top_stocks.tsv",
    # )
    #
    # print("\n[Result] Top stocks selection:")
    # print(result)
|
||||
|
||||
Reference in New Issue
Block a user