- 添加 FactorConfig 类,支持链式 API 和动态因子列表 - 重构 prepare_data 使用 FactorEngine 计算因子,确保防泄露机制生效 - 新增 prepare_data_and_train / train_and_predict 分步执行接口 - 修复 factors/__init__.py docstring 语法错误
303 lines
8.4 KiB
Python
303 lines
8.4 KiB
Python
"""训练流程入口脚本

运行方式:
    uv run python -m src.training.main

或:
    uv run python src/training/main.py

本脚本提供两种运行方式:
1. run_full_pipeline(): 完整训练流程(数据准备 -> 训练 -> 预测)
2. prepare_data_and_train() + train_and_predict(): 分步执行,便于调试和调整

因子配置示例:
    from src.factors import MovingAverageFactor, ReturnRankFactor

    # 直接传入因子实例列表 - 最简单的方式
    factors = [
        MovingAverageFactor(period=5),
        MovingAverageFactor(period=10),
        MovingAverageFactor(period=20),
        ReturnRankFactor(period=5),
        ReturnRankFactor(period=10),
    ]

    # 运行完整流程
    result = run_full_pipeline(factors=factors)
"""
|
||
|
||
from pathlib import Path
|
||
from typing import Optional, List
|
||
|
||
import polars as pl
|
||
|
||
from src.factors import BaseFactor
|
||
from src.training.pipeline import (
|
||
FactorConfig,
|
||
predict_top_stocks,
|
||
prepare_data,
|
||
save_top_stocks,
|
||
train_model,
|
||
)
|
||
|
||
|
||
def prepare_data_and_train(
    factors: Optional[List[BaseFactor]] = None,
    data_dir: str = "data",
    train_start: str = "20190101",
    train_end: str = "20231231",
    val_start: str = "20240102",
    val_end: str = "20240531",
    test_start: str = "20240602",
    test_end: str = "20241231",
) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame, FactorConfig, str]:
    """Step 1: build the train / validation / test datasets.

    Despite the name, no model is trained here. The function forwards its
    arguments to ``prepare_data``, prints the size and the first three rows
    of each returned split, prints the feature and label column names, and
    hands everything back to the caller.

    Args:
        factors: Factor instances to compute. ``None`` lets ``prepare_data``
            pick its own default set (the original documentation names
            MA5, MA10 and ReturnRank5 -- confirm against ``prepare_data``).
        data_dir: Data directory passed through to ``prepare_data``.
        train_start: First date of the training split.
        train_end: Last date of the training split.
        val_start: First date of the validation split.
        val_end: Last date of the validation split.
        test_start: First date of the test split.
        test_end: Last date of the test split.

    Returns:
        ``(train_data, val_data, test_data, factor_config, label_col)``.
        ``label_col`` is always the literal ``"label"``.
    """
    banner = "=" * 50

    print(banner)
    print("[Step 1] 数据处理")
    print(banner)
    date_ranges = (
        ("训练集", train_start, train_end),
        ("验证集", val_start, val_end),
        ("测试集", test_start, test_end),
    )
    for split_name, first_day, last_day in date_ranges:
        print(f"{split_name}: {first_day} -> {last_day}")
    print()

    # 1. Delegate the actual data preparation to the pipeline module.
    train_data, val_data, test_data, factor_config = prepare_data(
        factors=factors,
        data_dir=data_dir,
        train_start=train_start,
        train_end=train_end,
        val_start=val_start,
        val_end=val_end,
        test_start=test_start,
        test_end=test_end,
    )

    named_splits = (
        ("训练集", train_data),
        ("验证集", val_data),
        ("测试集", test_data),
    )

    for split_name, frame in named_splits:
        print(f"{split_name}样本数: {len(frame)}")
    print()

    # Show a small preview of every split.
    print(banner)
    for split_name, frame in named_splits:
        print(f"[数据预览] {split_name}前3行:")
        print(frame.head(3))
        print()

    # 2. Report which columns are features and which one is the label.
    label_col = "label"
    print(f"特征列: {factor_config.get_feature_names()}")
    print(f"标签列: {label_col}")
    print()

    return train_data, val_data, test_data, factor_config, label_col
|
||
|
||
|
||
def train_and_predict(
    train_data: pl.DataFrame,
    val_data: pl.DataFrame,
    test_data: pl.DataFrame,
    factor_config: FactorConfig,
    label_col: str = "label",
    top_n: int = 5,
    output_path: str = "output/top_stocks.tsv",
) -> pl.DataFrame:
    """Step 2: train a model, predict on the test split and save the picks.

    Feature column names are taken from ``factor_config``. The model is
    fitted with ``train_model``, the selection is produced by
    ``predict_top_stocks`` and written out with ``save_top_stocks``. The
    parent directory of ``output_path`` is created if it does not exist.

    Args:
        train_data: Training split.
        val_data: Validation split.
        test_data: Test split.
        factor_config: Factor configuration supplying the feature names.
        label_col: Name of the label column.
        top_n: Number of stocks to select per day (passed to
            ``predict_top_stocks``).
        output_path: Destination file for the selection.

    Returns:
        The stock selection returned by ``predict_top_stocks``.
    """
    banner = "=" * 50

    print(banner)
    print("[Step 2] 模型训练与预测")
    print(banner)
    print()

    # Resolve the feature columns once; training and prediction share them.
    feature_cols = factor_config.get_feature_names()
    print(f"使用特征: {feature_cols}")
    print()

    # 3. Fit the model.
    print("[Training] Training model...")
    fitted_model, fitted_pipeline = train_model(
        train_data=train_data,
        val_data=val_data,
        feature_cols=feature_cols,
        label_col=label_col,
    )
    print()

    # 4. Predict on the test split.
    print("[Predict] Predicting on test set...")
    selection = predict_top_stocks(
        model=fitted_model,
        pipeline=fitted_pipeline,
        test_data=test_data,
        feature_cols=feature_cols,
        top_n=top_n,
    )
    print()

    # 5. Persist the result, creating the output directory when needed.
    print(f"[Saving] Saving results to {output_path}...")
    output_dir = Path(output_path).parent
    output_dir.mkdir(parents=True, exist_ok=True)
    save_top_stocks(selection, output_path)
    print()

    return selection
|
||
|
||
|
||
def run_full_pipeline(
    factors: Optional[List[BaseFactor]] = None,
    train_start: str = "20190101",
    train_end: str = "20231231",
    val_start: str = "20240102",
    val_end: str = "20240531",
    test_start: str = "20240602",
    test_end: str = "20241231",
    top_n: int = 5,
    output_path: str = "output/top_stocks.tsv",
    data_dir: str = "data",
) -> pl.DataFrame:
    """Run the whole flow: data preparation, training, prediction, saving.

    Equivalent to calling ``prepare_data_and_train`` followed by
    ``train_and_predict`` with the values returned by the first step.

    Args:
        factors: Factor instances to compute. ``None`` defers to the default
            set chosen downstream (the original documentation names MA5,
            MA10 and ReturnRank5 -- confirm against ``prepare_data``).
        train_start: First date of the training split.
        train_end: Last date of the training split.
        val_start: First date of the validation split.
        val_end: Last date of the validation split.
        test_start: First date of the test split.
        test_end: Last date of the test split.
        top_n: Number of stocks to select per day.
        output_path: Destination file for the selection.
        data_dir: Data directory forwarded to ``prepare_data_and_train``.
            Added as the last parameter so existing positional callers are
            unaffected; the default matches ``prepare_data_and_train``.

    Returns:
        The stock selection returned by ``train_and_predict``.
    """
    # Step 1: data preparation. data_dir is forwarded so the one-shot entry
    # point is as configurable as the step-by-step one.
    train_data, val_data, test_data, factor_config, label_col = prepare_data_and_train(
        factors=factors,
        data_dir=data_dir,
        train_start=train_start,
        train_end=train_end,
        val_start=val_start,
        val_end=val_end,
        test_start=test_start,
        test_end=test_end,
    )

    # Step 2: training and prediction.
    result = train_and_predict(
        train_data=train_data,
        val_data=val_data,
        test_data=test_data,
        factor_config=factor_config,
        label_col=label_col,
        top_n=top_n,
        output_path=output_path,
    )

    print("=" * 50)
    print("[Done] 训练流程完成!")
    print("=" * 50)

    return result
|
||
|
||
|
||
if __name__ == "__main__":
    from src.factors import MovingAverageFactor, ReturnRankFactor

    # ========== Factor configuration ==========
    # A plain list of factor instances: 5/10/20-day moving averages followed
    # by 5/10-day return ranks.
    factors = [
        *(MovingAverageFactor(period=window) for window in (5, 10, 20)),
        *(ReturnRankFactor(period=window) for window in (5, 10)),
    ]

    # Date boundaries shared by both run modes below.
    split_dates = {
        "train_start": "20190101",
        "train_end": "20231231",
        "val_start": "20240102",
        "val_end": "20240531",
        "test_start": "20240602",
        "test_end": "20241231",
    }

    # ========== Run modes ==========

    # Option 1: full pipeline in a single call.
    # result = run_full_pipeline(
    #     factors=factors,
    #     top_n=5,
    #     output_path="output/top_stocks.tsv",
    #     **split_dates,
    # )

    # Option 2: step by step (easier to debug).
    # Step 1: data preparation.
    train_data, val_data, test_data, factor_config, label_col = prepare_data_and_train(
        factors=factors,
        **split_dates,
    )

    # Custom logic can be inserted here, for example:
    # - inspect the data distribution
    # - adjust the features
    # - save intermediate results
    print("\n[Info] 因子配置详情:")
    print(f" 因子列表: {factor_config.get_feature_names()}")
    print(f" 最大回溯天数: {factor_config.get_max_lookback()}")

    # Step 2: training and prediction (currently disabled).
    # result = train_and_predict(
    #     train_data=train_data,
    #     val_data=val_data,
    #     test_data=test_data,
    #     factor_config=factor_config,
    #     label_col=label_col,
    #     top_n=5,
    #     output_path="output/top_stocks.tsv",
    # )
    #
    # print("\n[Result] Top stocks selection:")
    # print(result)
|