Files
ProStock/src/training/main.py
liaozhaorun 990a77ec6c refactor(training): 重构训练管道,支持灵活因子配置
- 添加 FactorConfig 类,支持链式 API 和动态因子列表
- 重构 prepare_data 使用 FactorEngine 计算因子,确保防泄露机制生效
- 新增 prepare_data_and_train / train_and_predict 分步执行接口
- 修复 factors/__init__.py docstring 语法错误
2026-02-25 23:39:02 +08:00

303 lines
8.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""训练流程入口脚本
运行方式:
uv run python -m src.training.main
或:
uv run python src/training/main.py
本脚本提供两种运行方式:
1. run_full_pipeline(): 完整训练流程(数据准备 -> 训练 -> 预测)
2. prepare_data_and_train() + train_and_predict(): 分步执行,便于调试和调整
因子配置示例:
    from src.factors import MovingAverageFactor, ReturnRankFactor

    # 直接传入因子实例列表 - 最简单的方式
    factors = [
        MovingAverageFactor(period=5),
        MovingAverageFactor(period=10),
        MovingAverageFactor(period=20),
        ReturnRankFactor(period=5),
        ReturnRankFactor(period=10),
    ]

    # 运行完整流程
    result = run_full_pipeline(factors=factors)
"""
from pathlib import Path
from typing import Optional, List
import polars as pl
from src.factors import BaseFactor
from src.training.pipeline import (
FactorConfig,
predict_top_stocks,
prepare_data,
save_top_stocks,
train_model,
)
def prepare_data_and_train(
    factors: Optional[List[BaseFactor]] = None,
    data_dir: str = "data",
    train_start: str = "20190101",
    train_end: str = "20231231",
    val_start: str = "20240102",
    val_end: str = "20240531",
    test_start: str = "20240602",
    test_end: str = "20241231",
) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame, FactorConfig, str]:
    """Step 1: build the train / validation / test frames.

    Delegates the actual work to ``prepare_data`` and reports progress on
    stdout (date windows, sample counts, a three-row preview of every split
    and the feature / label column names). Despite its name, this function
    does not train anything.

    Args:
        factors: Factor instances forwarded to ``prepare_data``. When None,
            the default factor set is chosen by ``prepare_data``
            (NOTE(review): earlier notes mention MA5, MA10, ReturnRank5 -
            confirm against the pipeline module).
        data_dir: Data directory forwarded to ``prepare_data``.
        train_start: First date of the training window.
        train_end: Last date of the training window.
        val_start: First date of the validation window.
        val_end: Last date of the validation window.
        test_start: First date of the test window.
        test_end: Last date of the test window.

    Returns:
        ``(train_data, val_data, test_data, factor_config, label_col)`` where
        ``label_col`` is always the string ``"label"``.
    """
    rule = "=" * 50

    # Header followed by the three date windows.
    print(rule)
    print("[Step 1] 数据处理")
    print(rule)
    for split_name, first_day, last_day in (
        ("训练集", train_start, train_end),
        ("验证集", val_start, val_end),
        ("测试集", test_start, test_end),
    ):
        print(f"{split_name}: {first_day} -> {last_day}")
    print()

    # 1. Let the pipeline module produce the three splits and the factor config.
    train_data, val_data, test_data, factor_config = prepare_data(
        factors=factors,
        data_dir=data_dir,
        train_start=train_start,
        train_end=train_end,
        val_start=val_start,
        val_end=val_end,
        test_start=test_start,
        test_end=test_end,
    )

    # Sample counts, one line per split, then a blank separator line.
    for caption, frame in (
        ("训练集样本数", train_data),
        ("验证集样本数", val_data),
        ("测试集样本数", test_data),
    ):
        print(f"{caption}: {len(frame)}")
    print()

    # Short preview of each split; every preview is followed by a blank line.
    print(rule)
    for heading, frame in (
        ("[数据预览] 训练集前3行:", train_data),
        ("[数据预览] 验证集前3行:", val_data),
        ("[数据预览] 测试集前3行:", test_data),
    ):
        print(heading)
        print(frame.head(3))
        print()

    # 2. Report which columns act as features and which one is the label.
    label_col = "label"
    feature_names = factor_config.get_feature_names()
    print(f"特征列: {feature_names}")
    print(f"标签列: {label_col}")
    print()

    return train_data, val_data, test_data, factor_config, label_col
def train_and_predict(
    train_data: pl.DataFrame,
    val_data: pl.DataFrame,
    test_data: pl.DataFrame,
    factor_config: FactorConfig,
    label_col: str = "label",
    top_n: int = 5,
    output_path: str = "output/top_stocks.tsv",
) -> pl.DataFrame:
    """Step 2: train a model, predict on the test split and save the picks.

    The feature columns are taken from ``factor_config``; training,
    prediction and persistence are delegated to ``train_model``,
    ``predict_top_stocks`` and ``save_top_stocks`` respectively.

    Args:
        train_data: Training split.
        val_data: Validation split.
        test_data: Test split to predict on.
        factor_config: Factor configuration supplying the feature names.
        label_col: Name of the label column.
        top_n: Forwarded to ``predict_top_stocks`` (number of stocks picked
            per day, according to the original notes).
        output_path: File the selection is written to; missing parent
            directories are created first.

    Returns:
        The selection returned by ``predict_top_stocks``.
    """
    rule = "=" * 50
    # Banner plus trailing blank line, emitted in a single call.
    print(rule, "[Step 2] 模型训练与预测", rule, "", sep="\n")

    # Feature column names come from the factor configuration.
    features = factor_config.get_feature_names()
    print(f"使用特征: {features}")
    print()

    # 3. Fit the model on the training split, using the validation split.
    print("[Training] Training model...")
    fitted_model, fitted_pipeline = train_model(
        train_data=train_data,
        val_data=val_data,
        feature_cols=features,
        label_col=label_col,
    )
    print()

    # 4. Predict on the test split.
    print("[Predict] Predicting on test set...")
    selection = predict_top_stocks(
        model=fitted_model,
        pipeline=fitted_pipeline,
        test_data=test_data,
        feature_cols=features,
        top_n=top_n,
    )
    print()

    # 5. Persist the result, creating the output directory when necessary.
    print(f"[Saving] Saving results to {output_path}...")
    target_dir = Path(output_path).parent
    target_dir.mkdir(parents=True, exist_ok=True)
    save_top_stocks(selection, output_path)
    print()

    return selection
def run_full_pipeline(
    factors: Optional[List[BaseFactor]] = None,
    train_start: str = "20190101",
    train_end: str = "20231231",
    val_start: str = "20240102",
    val_end: str = "20240531",
    test_start: str = "20240602",
    test_end: str = "20241231",
    top_n: int = 5,
    output_path: str = "output/top_stocks.tsv",
    data_dir: str = "data",
) -> pl.DataFrame:
    """Run the whole training flow in one call.

    Equivalent to calling ``prepare_data_and_train`` followed by
    ``train_and_predict`` with the same arguments.

    Args:
        factors: Factor instances forwarded to ``prepare_data_and_train``.
            When None, the default factor set is decided further down the
            pipeline (NOTE(review): earlier notes mention MA5, MA10,
            ReturnRank5 - confirm against the pipeline module).
        train_start: First date of the training window.
        train_end: Last date of the training window.
        val_start: First date of the validation window.
        val_end: Last date of the validation window.
        test_start: First date of the test window.
        test_end: Last date of the test window.
        top_n: Forwarded to ``train_and_predict`` (number of stocks picked
            per day, according to the original notes).
        output_path: File the selection is written to.
        data_dir: Data directory forwarded to ``prepare_data_and_train``.
            Placed last so existing positional callers keep working; the
            default matches the default of ``prepare_data_and_train``.

    Returns:
        The selection returned by ``train_and_predict``.
    """
    # Step 1: data preparation. data_dir is passed through so this entry
    # point is no longer tied to the default data location.
    train_data, val_data, test_data, factor_config, label_col = prepare_data_and_train(
        factors=factors,
        data_dir=data_dir,
        train_start=train_start,
        train_end=train_end,
        val_start=val_start,
        val_end=val_end,
        test_start=test_start,
        test_end=test_end,
    )

    # Step 2: training, prediction and saving.
    result = train_and_predict(
        train_data=train_data,
        val_data=val_data,
        test_data=test_data,
        factor_config=factor_config,
        label_col=label_col,
        top_n=top_n,
        output_path=output_path,
    )

    print("=" * 50)
    print("[Done] 训练流程完成!")
    print("=" * 50)
    return result
if __name__ == "__main__":
    from src.factors import MovingAverageFactor, ReturnRankFactor

    # ---------- Factor configuration ----------
    # Plain list of factor instances: moving averages over 5 / 10 / 20 days
    # followed by return ranks over 5 / 10 days (order matters for the
    # resulting feature columns).
    factors = [
        *(MovingAverageFactor(period=window) for window in (5, 10, 20)),
        *(ReturnRankFactor(period=window) for window in (5, 10)),
    ]

    # Date windows shared by both ways of running the flow.
    date_splits = {
        "train_start": "20190101",
        "train_end": "20231231",
        "val_start": "20240102",
        "val_end": "20240531",
        "test_start": "20240602",
        "test_end": "20241231",
    }

    # ---------- How to run ----------
    # Option 1: everything in one call.
    # result = run_full_pipeline(
    #     factors=factors,
    #     top_n=5,
    #     output_path="output/top_stocks.tsv",
    #     **date_splits,
    # )

    # Option 2: step by step (easier to debug).
    # Step 1: data preparation.
    train_data, val_data, test_data, factor_config, label_col = prepare_data_and_train(
        factors=factors,
        **date_splits,
    )

    # Custom logic can be inserted here, for example:
    # - inspecting the data distribution
    # - adjusting features
    # - saving intermediate results
    print("\n[Info] 因子配置详情:")
    feature_names = factor_config.get_feature_names()
    print(f" 因子列表: {feature_names}")
    max_lookback = factor_config.get_max_lookback()
    print(f" 最大回溯天数: {max_lookback}")

    # Step 2: training and prediction (disabled by default).
    # result = train_and_predict(
    #     train_data=train_data,
    #     val_data=val_data,
    #     test_data=test_data,
    #     factor_config=factor_config,
    #     label_col=label_col,
    #     top_n=5,
    #     output_path="output/top_stocks.tsv",
    # )
    #
    # print("\n[Result] Top stocks selection:")
    # print(result)