refactor(training): 重构训练管道,支持灵活因子配置
- 添加 FactorConfig 类,支持链式 API 和动态因子列表 - 重构 prepare_data 使用 FactorEngine 计算因子,确保防泄露机制生效 - 新增 prepare_data_and_train / train_and_predict 分步执行接口 - 修复 factors/__init__.py docstring 语法错误
This commit is contained in:
@@ -5,26 +5,298 @@
|
|||||||
|
|
||||||
或:
|
或:
|
||||||
uv run python src/training/main.py
|
uv run python src/training/main.py
|
||||||
|
|
||||||
|
本脚本提供两种运行方式:
|
||||||
|
1. run_full_pipeline(): 完整训练流程(数据准备 -> 训练 -> 预测)
|
||||||
|
2. prepare_data_and_train() + train_and_predict(): 分步执行,便于调试和调整
|
||||||
|
|
||||||
|
因子配置示例:
|
||||||
|
from src.factors import MovingAverageFactor, ReturnRankFactor
|
||||||
|
|
||||||
|
# 直接传入因子实例列表 - 最简单的方式
|
||||||
|
factors = [
|
||||||
|
MovingAverageFactor(period=5),
|
||||||
|
MovingAverageFactor(period=10),
|
||||||
|
MovingAverageFactor(period=20),
|
||||||
|
ReturnRankFactor(period=5),
|
||||||
|
ReturnRankFactor(period=10),
|
||||||
|
]
|
||||||
|
|
||||||
|
# 运行完整流程
|
||||||
|
result = run_full_pipeline(factors=factors)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from src.training.pipeline import run_training
|
from pathlib import Path
|
||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
from src.factors import BaseFactor
|
||||||
|
from src.training.pipeline import (
|
||||||
|
FactorConfig,
|
||||||
|
predict_top_stocks,
|
||||||
|
prepare_data,
|
||||||
|
save_top_stocks,
|
||||||
|
train_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_data_and_train(
|
||||||
|
factors: Optional[List[BaseFactor]] = None,
|
||||||
|
data_dir: str = "data",
|
||||||
|
train_start: str = "20190101",
|
||||||
|
train_end: str = "20231231",
|
||||||
|
val_start: str = "20240102",
|
||||||
|
val_end: str = "20240531",
|
||||||
|
test_start: str = "20240602",
|
||||||
|
test_end: str = "20241231",
|
||||||
|
) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame, FactorConfig, str]:
|
||||||
|
"""第一步:数据处理
|
||||||
|
|
||||||
|
加载原始数据,计算因子和标签,拆分训练/验证/测试集。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
factors: 因子实例列表,默认为 None(使用 MA5, MA10, ReturnRank5)
|
||||||
|
data_dir: 数据目录
|
||||||
|
train_start: 训练集开始日期
|
||||||
|
train_end: 训练集结束日期
|
||||||
|
val_start: 验证集开始日期
|
||||||
|
val_end: 验证集结束日期
|
||||||
|
test_start: 测试集开始日期
|
||||||
|
test_end: 测试集结束日期
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (train_data, val_data, test_data, factor_config, label_col)
|
||||||
|
"""
|
||||||
|
print("=" * 50)
|
||||||
|
print("[Step 1] 数据处理")
|
||||||
|
print("=" * 50)
|
||||||
|
print(f"训练集: {train_start} -> {train_end}")
|
||||||
|
print(f"验证集: {val_start} -> {val_end}")
|
||||||
|
print(f"测试集: {test_start} -> {test_end}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# 1. 准备数据
|
||||||
|
train_data, val_data, test_data, factor_config = prepare_data(
|
||||||
|
factors=factors,
|
||||||
|
data_dir=data_dir,
|
||||||
|
train_start=train_start,
|
||||||
|
train_end=train_end,
|
||||||
|
val_start=val_start,
|
||||||
|
val_end=val_end,
|
||||||
|
test_start=test_start,
|
||||||
|
test_end=test_end,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"训练集样本数: {len(train_data)}")
|
||||||
|
print(f"验证集样本数: {len(val_data)}")
|
||||||
|
print(f"测试集样本数: {len(test_data)}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# 打印少量数据样本展示
|
||||||
|
print("=" * 50)
|
||||||
|
print("[数据预览] 训练集前3行:")
|
||||||
|
print(train_data.head(3))
|
||||||
|
print()
|
||||||
|
print("[数据预览] 验证集前3行:")
|
||||||
|
print(val_data.head(3))
|
||||||
|
print()
|
||||||
|
print("[数据预览] 测试集前3行:")
|
||||||
|
print(test_data.head(3))
|
||||||
|
print()
|
||||||
|
|
||||||
|
# 2. 获取特征列名
|
||||||
|
feature_cols = factor_config.get_feature_names()
|
||||||
|
label_col = "label"
|
||||||
|
|
||||||
|
print(f"特征列: {feature_cols}")
|
||||||
|
print(f"标签列: {label_col}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
return train_data, val_data, test_data, factor_config, label_col
|
||||||
|
|
||||||
|
|
||||||
|
def train_and_predict(
|
||||||
|
train_data: pl.DataFrame,
|
||||||
|
val_data: pl.DataFrame,
|
||||||
|
test_data: pl.DataFrame,
|
||||||
|
factor_config: FactorConfig,
|
||||||
|
label_col: str = "label",
|
||||||
|
top_n: int = 5,
|
||||||
|
output_path: str = "output/top_stocks.tsv",
|
||||||
|
) -> pl.DataFrame:
|
||||||
|
"""第二步:训练和预测
|
||||||
|
|
||||||
|
使用处理好的数据训练模型,进行测试集预测并保存结果。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_data: 训练数据
|
||||||
|
val_data: 验证数据
|
||||||
|
test_data: 测试数据
|
||||||
|
factor_config: 因子配置对象
|
||||||
|
label_col: 标签列名
|
||||||
|
top_n: 每日选股数量
|
||||||
|
output_path: 输出文件路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
选股结果DataFrame
|
||||||
|
"""
|
||||||
|
print("=" * 50)
|
||||||
|
print("[Step 2] 模型训练与预测")
|
||||||
|
print("=" * 50)
|
||||||
|
print()
|
||||||
|
|
||||||
|
# 获取特征列名
|
||||||
|
feature_cols = factor_config.get_feature_names()
|
||||||
|
print(f"使用特征: {feature_cols}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# 3. 训练模型
|
||||||
|
print("[Training] Training model...")
|
||||||
|
model, pipeline = train_model(
|
||||||
|
train_data=train_data,
|
||||||
|
val_data=val_data,
|
||||||
|
feature_cols=feature_cols,
|
||||||
|
label_col=label_col,
|
||||||
|
)
|
||||||
|
print()
|
||||||
|
|
||||||
|
# 4. 测试集预测
|
||||||
|
print("[Predict] Predicting on test set...")
|
||||||
|
top_stocks = predict_top_stocks(
|
||||||
|
model=model,
|
||||||
|
pipeline=pipeline,
|
||||||
|
test_data=test_data,
|
||||||
|
feature_cols=feature_cols,
|
||||||
|
top_n=top_n,
|
||||||
|
)
|
||||||
|
print()
|
||||||
|
|
||||||
|
# 5. 保存结果
|
||||||
|
print(f"[Saving] Saving results to {output_path}...")
|
||||||
|
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
save_top_stocks(top_stocks, output_path)
|
||||||
|
print()
|
||||||
|
|
||||||
|
return top_stocks
|
||||||
|
|
||||||
|
|
||||||
|
def run_full_pipeline(
|
||||||
|
factors: Optional[List[BaseFactor]] = None,
|
||||||
|
train_start: str = "20190101",
|
||||||
|
train_end: str = "20231231",
|
||||||
|
val_start: str = "20240102",
|
||||||
|
val_end: str = "20240531",
|
||||||
|
test_start: str = "20240602",
|
||||||
|
test_end: str = "20241231",
|
||||||
|
top_n: int = 5,
|
||||||
|
output_path: str = "output/top_stocks.tsv",
|
||||||
|
) -> pl.DataFrame:
|
||||||
|
"""运行完整训练流程
|
||||||
|
|
||||||
|
相当于依次调用 prepare_data_and_train 和 train_and_predict。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
factors: 因子实例列表,默认为 None(使用 MA5, MA10, ReturnRank5)
|
||||||
|
train_start: 训练集开始日期
|
||||||
|
train_end: 训练集结束日期
|
||||||
|
val_start: 验证集开始日期
|
||||||
|
val_end: 验证集结束日期
|
||||||
|
test_start: 测试集开始日期
|
||||||
|
test_end: 测试集结束日期
|
||||||
|
top_n: 每日选股数量
|
||||||
|
output_path: 输出文件路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
选股结果DataFrame
|
||||||
|
"""
|
||||||
|
# 第一步:数据处理
|
||||||
|
train_data, val_data, test_data, factor_config, label_col = prepare_data_and_train(
|
||||||
|
factors=factors,
|
||||||
|
train_start=train_start,
|
||||||
|
train_end=train_end,
|
||||||
|
val_start=val_start,
|
||||||
|
val_end=val_end,
|
||||||
|
test_start=test_start,
|
||||||
|
test_end=test_end,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 第二步:训练和预测
|
||||||
|
result = train_and_predict(
|
||||||
|
train_data=train_data,
|
||||||
|
val_data=val_data,
|
||||||
|
test_data=test_data,
|
||||||
|
factor_config=factor_config,
|
||||||
|
label_col=label_col,
|
||||||
|
top_n=top_n,
|
||||||
|
output_path=output_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("=" * 50)
|
||||||
|
print("[Done] 训练流程完成!")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 运行完整训练流程
|
from src.factors import MovingAverageFactor, ReturnRankFactor
|
||||||
# 训练集:20190101 - 20231231
|
|
||||||
# 验证集:20240102 - 20240531 (与训练集间隔1天,避免数据泄露)
|
# ========== 因子配置 ==========
|
||||||
# 测试集:20240602 - 20241231 (与验证集间隔1天,避免数据泄露)
|
# 直接传入因子实例列表 - 简单直观
|
||||||
result = run_training(
|
factors = [
|
||||||
|
MovingAverageFactor(period=5), # 5日移动平均线
|
||||||
|
MovingAverageFactor(period=10), # 10日移动平均线
|
||||||
|
MovingAverageFactor(period=20), # 20日移动平均线
|
||||||
|
ReturnRankFactor(period=5), # 5日收益率排名
|
||||||
|
ReturnRankFactor(period=10), # 10日收益率排名
|
||||||
|
]
|
||||||
|
|
||||||
|
# ========== 运行方式 ==========
|
||||||
|
|
||||||
|
# 方式一:完整流程(一次性执行)
|
||||||
|
# result = run_full_pipeline(
|
||||||
|
# factors=factors,
|
||||||
|
# train_start="20190101",
|
||||||
|
# train_end="20231231",
|
||||||
|
# val_start="20240102",
|
||||||
|
# val_end="20240531",
|
||||||
|
# test_start="20240602",
|
||||||
|
# test_end="20241231",
|
||||||
|
# top_n=5,
|
||||||
|
# output_path="output/top_stocks.tsv",
|
||||||
|
# )
|
||||||
|
|
||||||
|
# 方式二:分步执行(便于调试)
|
||||||
|
# 第一步:数据处理
|
||||||
|
train_data, val_data, test_data, factor_config, label_col = prepare_data_and_train(
|
||||||
|
factors=factors,
|
||||||
train_start="20190101",
|
train_start="20190101",
|
||||||
train_end="20231231",
|
train_end="20231231",
|
||||||
val_start="20240102",
|
val_start="20240102",
|
||||||
val_end="20240531",
|
val_end="20240531",
|
||||||
test_start="20240602",
|
test_start="20240602",
|
||||||
test_end="20241231",
|
test_end="20241231",
|
||||||
top_n=5,
|
|
||||||
output_path="output/top_stocks.tsv",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
print("\n[Result] Top stocks selection:")
|
# 可在此处添加自定义逻辑,例如:
|
||||||
print(result)
|
# - 查看数据分布
|
||||||
|
# - 调整特征
|
||||||
|
# - 保存中间结果
|
||||||
|
print("\n[Info] 因子配置详情:")
|
||||||
|
print(f" 因子列表: {factor_config.get_feature_names()}")
|
||||||
|
print(f" 最大回溯天数: {factor_config.get_max_lookback()}")
|
||||||
|
|
||||||
|
# 第二步:训练和预测
|
||||||
|
# result = train_and_predict(
|
||||||
|
# train_data=train_data,
|
||||||
|
# val_data=val_data,
|
||||||
|
# test_data=test_data,
|
||||||
|
# factor_config=factor_config,
|
||||||
|
# label_col=label_col,
|
||||||
|
# top_n=5,
|
||||||
|
# output_path="output/top_stocks.tsv",
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# print("\n[Result] Top stocks selection:")
|
||||||
|
# print(result)
|
||||||
|
|||||||
@@ -1,21 +1,48 @@
|
|||||||
"""训练管道 - 包含数据处理、模型训练和预测功能
|
"""训练管道 - 包含数据处理、模型训练和预测功能
|
||||||
|
|
||||||
本模块提供:
|
本模块提供:
|
||||||
1. 数据准备:从因子计算结果中准备训练/测试数据
|
1. 数据准备:使用 FactorEngine 从因子计算结果中准备训练/测试数据
|
||||||
2. 数据处理:Fillna(0) -> Dropna
|
2. 数据处理:Fillna(0) -> Dropna
|
||||||
3. 模型训练:使用LightGBM训练分类模型
|
3. 模型训练:使用LightGBM训练分类模型
|
||||||
4. 预测和选股:输出每日top5股票池
|
4. 预测和选股:输出每日top5股票池
|
||||||
|
|
||||||
|
注意:本模块使用 src.factors 框架进行因子计算,确保防泄露机制生效。
|
||||||
|
|
||||||
|
因子配置示例:
|
||||||
|
from src.factors import MovingAverageFactor, ReturnRankFactor
|
||||||
|
from src.training.pipeline import prepare_data
|
||||||
|
|
||||||
|
# 直接传入因子实例列表 - 简单直观
|
||||||
|
factors = [
|
||||||
|
MovingAverageFactor(period=5),
|
||||||
|
MovingAverageFactor(period=10),
|
||||||
|
ReturnRankFactor(period=5),
|
||||||
|
]
|
||||||
|
|
||||||
|
train_data, val_data, test_data, factor_config = prepare_data(
|
||||||
|
factors=factors,
|
||||||
|
...
|
||||||
|
)
|
||||||
|
|
||||||
|
# 或者使用 FactorConfig 包装(支持链式添加)
|
||||||
|
from src.training.pipeline import FactorConfig
|
||||||
|
|
||||||
|
config = FactorConfig()
|
||||||
|
.add(MovingAverageFactor(period=5))
|
||||||
|
.add(MovingAverageFactor(period=10))
|
||||||
|
.add(ReturnRankFactor(period=5))
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import polars as pl
|
import polars as pl
|
||||||
|
|
||||||
from src.factors import DataLoader, FactorEngine
|
from src.factors import DataLoader, FactorEngine, BaseFactor
|
||||||
from src.factors.data_spec import DataSpec
|
from src.factors.data_spec import DataSpec
|
||||||
|
from src.factors.momentum import MovingAverageFactor, ReturnRankFactor
|
||||||
from src.pipeline import (
|
from src.pipeline import (
|
||||||
DropNAProcessor,
|
DropNAProcessor,
|
||||||
FillNAProcessor,
|
FillNAProcessor,
|
||||||
@@ -26,7 +53,98 @@ from src.pipeline import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ========== 因子配置类 ==========
|
||||||
|
|
||||||
|
|
||||||
|
class FactorConfig:
|
||||||
|
"""因子配置类 - 管理因子列表
|
||||||
|
|
||||||
|
用于包装因子实例列表,支持链式添加。
|
||||||
|
|
||||||
|
示例:
|
||||||
|
# 方式1:初始化时传入列表
|
||||||
|
config = FactorConfig([
|
||||||
|
MovingAverageFactor(period=5),
|
||||||
|
ReturnRankFactor(period=5),
|
||||||
|
])
|
||||||
|
|
||||||
|
# 方式2:链式添加
|
||||||
|
config = FactorConfig()
|
||||||
|
.add(MovingAverageFactor(period=5))
|
||||||
|
.add(ReturnRankFactor(period=5))
|
||||||
|
|
||||||
|
# 获取因子实例列表
|
||||||
|
factors = config.get_factors()
|
||||||
|
|
||||||
|
# 获取特征列名
|
||||||
|
feature_cols = config.get_feature_names()
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, factors: Optional[List[BaseFactor]] = None):
|
||||||
|
"""初始化因子配置
|
||||||
|
|
||||||
|
Args:
|
||||||
|
factors: 因子实例列表
|
||||||
|
"""
|
||||||
|
self._factors: List[BaseFactor] = factors or []
|
||||||
|
|
||||||
|
def add(self, factor: BaseFactor) -> "FactorConfig":
|
||||||
|
"""添加因子到配置
|
||||||
|
|
||||||
|
支持链式调用:
|
||||||
|
config = FactorConfig()
|
||||||
|
.add(MovingAverageFactor(period=5))
|
||||||
|
.add(ReturnRankFactor(period=5))
|
||||||
|
|
||||||
|
Args:
|
||||||
|
factor: 因子实例
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
self,支持链式调用
|
||||||
|
"""
|
||||||
|
if not isinstance(factor, BaseFactor):
|
||||||
|
raise ValueError(f"必须是 BaseFactor 实例, got {type(factor)}")
|
||||||
|
self._factors.append(factor)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def get_factors(self) -> List[BaseFactor]:
|
||||||
|
"""获取因子实例列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
因子实例列表
|
||||||
|
"""
|
||||||
|
return self._factors
|
||||||
|
|
||||||
|
def get_feature_names(self) -> List[str]:
|
||||||
|
"""获取所有因子的特征列名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
特征列名列表
|
||||||
|
"""
|
||||||
|
return [f.name for f in self._factors]
|
||||||
|
|
||||||
|
def get_max_lookback(self) -> int:
|
||||||
|
"""获取所有因子中最大的 lookback 天数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
最大 lookback 天数
|
||||||
|
"""
|
||||||
|
max_lookback = 0
|
||||||
|
for factor in self._factors:
|
||||||
|
for spec in factor.data_specs:
|
||||||
|
max_lookback = max(max_lookback, spec.lookback_days)
|
||||||
|
return max_lookback
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self._factors)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
names = [f.name for f in self._factors]
|
||||||
|
return f"FactorConfig({names})"
|
||||||
|
|
||||||
|
|
||||||
def prepare_data(
|
def prepare_data(
|
||||||
|
factors: Optional[List[BaseFactor]] = None,
|
||||||
data_dir: str = "data",
|
data_dir: str = "data",
|
||||||
train_start: str = "20180101",
|
train_start: str = "20180101",
|
||||||
train_end: str = "20230101",
|
train_end: str = "20230101",
|
||||||
@@ -34,12 +152,13 @@ def prepare_data(
|
|||||||
val_end: str = "20230601",
|
val_end: str = "20230601",
|
||||||
test_start: str = "20230601",
|
test_start: str = "20230601",
|
||||||
test_end: str = "20240101",
|
test_end: str = "20240101",
|
||||||
) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
|
) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame, FactorConfig]:
|
||||||
"""准备训练、验证和测试数据
|
"""准备训练、验证和测试数据
|
||||||
|
|
||||||
从DuckDB加载原始日线数据,计算所需因子并生成标签。
|
使用 FactorEngine 计算因子,确保防泄露机制生效。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
factors: 因子实例列表,默认为 None(使用 MA5, MA10, ReturnRank5)
|
||||||
data_dir: 数据目录
|
data_dir: 数据目录
|
||||||
train_start: 训练集开始日期
|
train_start: 训练集开始日期
|
||||||
train_end: 训练集结束日期
|
train_end: 训练集结束日期
|
||||||
@@ -49,45 +168,108 @@ def prepare_data(
|
|||||||
test_end: 测试集结束日期
|
test_end: 测试集结束日期
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(train_data, val_data, test_data): 训练集、验证集和测试集的DataFrame
|
(train_data, val_data, test_data, factor_config):
|
||||||
|
训练集、验证集、测试集的DataFrame,以及使用的因子配置
|
||||||
"""
|
"""
|
||||||
from src.data.storage import Storage
|
from src.data.storage import Storage
|
||||||
|
|
||||||
storage = Storage()
|
storage = Storage()
|
||||||
|
|
||||||
# 加载日线数据(需要更多历史数据用于计算因子)
|
# 1. 处理因子配置
|
||||||
# 训练集需要更多历史数据(用于计算因子lookback)
|
if factors is None:
|
||||||
lookback_days = 20 # 足够计算MA10和5日收益率
|
# 默认因子配置
|
||||||
|
factor_config = FactorConfig(
|
||||||
|
[
|
||||||
|
MovingAverageFactor(period=5),
|
||||||
|
MovingAverageFactor(period=10),
|
||||||
|
ReturnRankFactor(period=5),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
elif isinstance(factors, FactorConfig):
|
||||||
|
factor_config = factors
|
||||||
|
elif isinstance(factors, list):
|
||||||
|
# 转换为 FactorConfig
|
||||||
|
factor_config = FactorConfig(factors)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"factors 必须是 List[BaseFactor] 或 FactorConfig, got {type(factors)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
factors_list = factor_config.get_factors()
|
||||||
|
feature_cols = factor_config.get_feature_names()
|
||||||
|
|
||||||
|
if not factors_list:
|
||||||
|
raise ValueError("至少需要提供一个因子")
|
||||||
|
|
||||||
|
print(f"[PrepareData] 使用因子: {feature_cols}")
|
||||||
|
|
||||||
|
# 2. 初始化 FactorEngine
|
||||||
|
loader = DataLoader(data_dir=data_dir)
|
||||||
|
engine = FactorEngine(loader)
|
||||||
|
|
||||||
|
# 3. 计算需要的回溯天数
|
||||||
|
max_lookback = factor_config.get_max_lookback()
|
||||||
start_with_lookback = str(int(train_start) - 10000) # 往前取一年
|
start_with_lookback = str(int(train_start) - 10000) # 往前取一年
|
||||||
|
|
||||||
# 查询全部数据(包含train、val、test),然后再拆分
|
# 获取所有股票列表
|
||||||
# 注意:DuckDB 中 trade_date 是 DATE 类型,需要转换
|
|
||||||
start_dt = f"{start_with_lookback[:4]}-{start_with_lookback[4:6]}-{start_with_lookback[6:8]}"
|
start_dt = f"{start_with_lookback[:4]}-{start_with_lookback[4:6]}-{start_with_lookback[6:8]}"
|
||||||
end_dt = f"{test_end[:4]}-{test_end[4:6]}-{test_end[6:8]}"
|
end_dt = f"{test_end[:4]}-{test_end[4:6]}-{test_end[6:8]}"
|
||||||
|
all_stocks_query = f"""
|
||||||
all_query = f"""
|
SELECT DISTINCT ts_code FROM daily
|
||||||
SELECT ts_code, trade_date, close, pre_close
|
|
||||||
FROM daily
|
|
||||||
WHERE trade_date >= '{start_dt}' AND trade_date <= '{end_dt}'
|
WHERE trade_date >= '{start_dt}' AND trade_date <= '{end_dt}'
|
||||||
ORDER BY ts_code, trade_date
|
|
||||||
"""
|
"""
|
||||||
all_raw = storage._connection.sql(all_query).pl()
|
all_stocks_df = storage._connection.sql(all_stocks_query).pl()
|
||||||
# 转换 trade_date 为字符串格式
|
all_stocks = all_stocks_df["ts_code"].to_list()
|
||||||
all_raw = all_raw.with_columns(
|
|
||||||
pl.col("trade_date").dt.strftime("%Y-%m-%d").alias("trade_date")
|
print(f"[PrepareData] 股票数量: {len(all_stocks)}")
|
||||||
|
|
||||||
|
# 4. 计算所有因子并合并
|
||||||
|
all_features = None
|
||||||
|
|
||||||
|
for factor in factors_list:
|
||||||
|
print(f"[PrepareData] 计算因子: {factor.name} ({factor.factor_type})")
|
||||||
|
|
||||||
|
if factor.factor_type == "time_series":
|
||||||
|
# 时序因子需要传入股票列表
|
||||||
|
result = engine.compute(
|
||||||
|
factor,
|
||||||
|
stock_codes=all_stocks,
|
||||||
|
start_date=start_with_lookback,
|
||||||
|
end_date=test_end,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 截面因子不需要股票列表
|
||||||
|
result = engine.compute(
|
||||||
|
factor,
|
||||||
|
start_date=start_with_lookback,
|
||||||
|
end_date=test_end,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 过滤不符合条件的股票
|
# 合并结果
|
||||||
all_raw = _filter_invalid_stocks(all_raw)
|
if all_features is None:
|
||||||
print(f"[PrepareData] After filtering: total={len(all_raw)}")
|
all_features = result
|
||||||
|
else:
|
||||||
# 计算因子和标签(需要全局数据一次性计算)
|
# 确保没有重复的 _right 列
|
||||||
all_data = _compute_features_and_label(
|
result = result.select([
|
||||||
all_raw,
|
c for c in result.columns if not c.endswith("_right")
|
||||||
start_date=train_start,
|
])
|
||||||
end_date=test_end
|
all_features = all_features.select([
|
||||||
|
c for c in all_features.columns if not c.endswith("_right")
|
||||||
|
])
|
||||||
|
all_features = all_features.join(
|
||||||
|
result, on=["trade_date", "ts_code"], how="outer"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if all_features is None:
|
||||||
|
raise ValueError("没有计算任何因子")
|
||||||
|
|
||||||
|
# 5. 计算标签(未来5日收益率)
|
||||||
|
all_data = _compute_label(all_features, start_date=train_start, end_date=test_end)
|
||||||
|
|
||||||
|
# 6. 过滤不符合条件的股票
|
||||||
|
all_data = _filter_invalid_stocks(all_data)
|
||||||
|
print(f"[PrepareData] After filtering: total={len(all_data)}")
|
||||||
|
|
||||||
# 转换日期格式用于比较
|
# 转换日期格式用于比较
|
||||||
train_start_fmt = f"{train_start[:4]}-{train_start[4:6]}-{train_start[6:8]}"
|
train_start_fmt = f"{train_start[:4]}-{train_start[4:6]}-{train_start[6:8]}"
|
||||||
train_end_fmt = f"{train_end[:4]}-{train_end[4:6]}-{train_end[6:8]}"
|
train_end_fmt = f"{train_end[:4]}-{train_end[4:6]}-{train_end[6:8]}"
|
||||||
@@ -98,18 +280,22 @@ def prepare_data(
|
|||||||
|
|
||||||
# 拆分数据
|
# 拆分数据
|
||||||
train_data = all_data.filter(
|
train_data = all_data.filter(
|
||||||
(pl.col("trade_date") >= train_start_fmt) & (pl.col("trade_date") <= train_end_fmt)
|
(pl.col("trade_date") >= train_start_fmt)
|
||||||
|
& (pl.col("trade_date") <= train_end_fmt)
|
||||||
)
|
)
|
||||||
val_data = all_data.filter(
|
val_data = all_data.filter(
|
||||||
(pl.col("trade_date") >= val_start_fmt) & (pl.col("trade_date") <= val_end_fmt)
|
(pl.col("trade_date") >= val_start_fmt) & (pl.col("trade_date") <= val_end_fmt)
|
||||||
)
|
)
|
||||||
test_data = all_data.filter(
|
test_data = all_data.filter(
|
||||||
(pl.col("trade_date") >= test_start_fmt) & (pl.col("trade_date") <= test_end_fmt)
|
(pl.col("trade_date") >= test_start_fmt)
|
||||||
|
& (pl.col("trade_date") <= test_end_fmt)
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"[PrepareData] Split result: train={len(train_data)}, val={len(val_data)}, test={len(test_data)}")
|
print(
|
||||||
|
f"[PrepareData] Split result: train={len(train_data)}, val={len(val_data)}, test={len(test_data)}"
|
||||||
|
)
|
||||||
|
|
||||||
return train_data, val_data, test_data
|
return train_data, val_data, test_data, factor_config
|
||||||
|
|
||||||
|
|
||||||
def _filter_invalid_stocks(df: pl.DataFrame) -> pl.DataFrame:
|
def _filter_invalid_stocks(df: pl.DataFrame) -> pl.DataFrame:
|
||||||
@@ -137,59 +323,54 @@ def _filter_invalid_stocks(df: pl.DataFrame) -> pl.DataFrame:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _compute_features_and_label(
|
def _compute_label(
|
||||||
raw_data: pl.DataFrame,
|
features_df: pl.DataFrame,
|
||||||
start_date: str,
|
start_date: str,
|
||||||
end_date: str,
|
end_date: str,
|
||||||
) -> pl.DataFrame:
|
) -> pl.DataFrame:
|
||||||
"""计算因子和标签
|
"""计算标签(未来5日收益率)
|
||||||
|
|
||||||
因子:
|
标签定义:未来5日收益率大于0为1,否则为0
|
||||||
1. return_5_rank: 5日收益率截面排名
|
|
||||||
2. ma_5: 5日移动平均
|
|
||||||
3. ma_10: 10日移动平均
|
|
||||||
|
|
||||||
标签:未来5日收益率大于0为1,否则为0
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
raw_data: 原始日线数据
|
features_df: 包含因子的DataFrame
|
||||||
start_date: 开始日期
|
start_date: 开始日期
|
||||||
end_date: 结束日期
|
end_date: 结束日期
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
包含因子和标签的DataFrame
|
包含因子和标签的DataFrame
|
||||||
"""
|
"""
|
||||||
# 确保按日期排序
|
from src.data.storage import Storage
|
||||||
raw_data = raw_data.sort(["ts_code", "trade_date"])
|
|
||||||
|
|
||||||
# 计算收益率(未来5日)
|
storage = Storage()
|
||||||
raw_data = raw_data.with_columns(
|
|
||||||
[
|
# 从数据库获取收盘价数据用于计算标签
|
||||||
# 当日收益率
|
start_dt = f"{start_date[:4]}-{start_date[4:6]}-{start_date[6:8]}"
|
||||||
((pl.col("close") - pl.col("pre_close")) / pl.col("pre_close")).alias(
|
end_dt = f"{end_date[:4]}-{end_date[4:6]}-{end_date[6:8]}"
|
||||||
"daily_return"
|
|
||||||
),
|
# 需要多取5天数据来计算未来收益率
|
||||||
]
|
end_dt_extended = f"{end_date[:4]}-{end_date[4:6]}-{int(end_date[6:8]) + 5}"
|
||||||
|
|
||||||
|
price_query = f"""
|
||||||
|
SELECT ts_code, trade_date, close
|
||||||
|
FROM daily
|
||||||
|
WHERE trade_date >= '{start_dt}' AND trade_date <= '{end_dt_extended}'
|
||||||
|
ORDER BY ts_code, trade_date
|
||||||
|
"""
|
||||||
|
price_data = storage._connection.sql(price_query).pl()
|
||||||
|
price_data = price_data.with_columns(
|
||||||
|
pl.col("trade_date").dt.strftime("%Y-%m-%d").alias("trade_date")
|
||||||
)
|
)
|
||||||
|
|
||||||
# 按股票分组计算
|
# 按股票计算未来5日收益率
|
||||||
result_list = []
|
result_list = []
|
||||||
|
for ts_code in price_data["ts_code"].unique():
|
||||||
|
stock_data = price_data.filter(pl.col("ts_code") == ts_code).sort("trade_date")
|
||||||
|
|
||||||
for ts_code in raw_data["ts_code"].unique():
|
if len(stock_data) < 6:
|
||||||
stock_data = raw_data.filter(pl.col("ts_code") == ts_code).sort("trade_date")
|
|
||||||
|
|
||||||
if len(stock_data) < 20:
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 计算MA5和MA10
|
# 计算未来5日收益率
|
||||||
stock_data = stock_data.with_columns(
|
|
||||||
[
|
|
||||||
pl.col("close").rolling_mean(5).alias("ma_5"),
|
|
||||||
pl.col("close").rolling_mean(10).alias("ma_10"),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# 计算未来5日收益率(用于标签)
|
|
||||||
future_return = stock_data["close"].shift(-5) - stock_data["close"]
|
future_return = stock_data["close"].shift(-5) - stock_data["close"]
|
||||||
future_return_pct = future_return / stock_data["close"]
|
future_return_pct = future_return / stock_data["close"]
|
||||||
stock_data = stock_data.with_columns(
|
stock_data = stock_data.with_columns(
|
||||||
@@ -205,46 +386,21 @@ def _compute_features_and_label(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
result_list.append(stock_data)
|
result_list.append(stock_data.select(["trade_date", "ts_code", "label"]))
|
||||||
|
|
||||||
if not result_list:
|
if not result_list:
|
||||||
return pl.DataFrame()
|
return pl.DataFrame()
|
||||||
|
|
||||||
result = pl.concat(result_list)
|
label_df = pl.concat(result_list)
|
||||||
|
|
||||||
# 转换日期格式:YYYYMMDD -> YYYY-MM-DD
|
# 将标签合并到因子数据
|
||||||
start_date_formatted = f"{start_date[:4]}-{start_date[4:6]}-{start_date[6:8]}"
|
result = features_df.join(label_df, on=["trade_date", "ts_code"], how="inner")
|
||||||
end_date_formatted = f"{end_date[:4]}-{end_date[4:6]}-{end_date[6:8]}"
|
|
||||||
|
|
||||||
# 过滤有效日期范围
|
# 过滤有效日期范围
|
||||||
result = result.filter(
|
result = result.filter(
|
||||||
(pl.col("trade_date") >= start_date_formatted) & (pl.col("trade_date") <= end_date_formatted)
|
(pl.col("trade_date") >= start_dt) & (pl.col("trade_date") <= end_dt)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 计算5日收益率排名(截面)
|
|
||||||
result = result.with_columns(
|
|
||||||
[
|
|
||||||
pl.col("daily_return")
|
|
||||||
.rank(method="average")
|
|
||||||
.over("trade_date")
|
|
||||||
.alias("return_5_rank")
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# 归一化排名到0-1
|
|
||||||
result = result.with_columns(
|
|
||||||
[
|
|
||||||
(
|
|
||||||
pl.col("return_5_rank")
|
|
||||||
/ pl.col("return_5_rank").max().over("trade_date")
|
|
||||||
).alias("return_5_rank")
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# 选择需要的列
|
|
||||||
feature_cols = ["trade_date", "ts_code", "return_5_rank", "ma_5", "ma_10", "label"]
|
|
||||||
result = result.select(feature_cols)
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@@ -271,7 +427,7 @@ def train_model(
|
|||||||
feature_cols: List[str],
|
feature_cols: List[str],
|
||||||
label_col: str = "label",
|
label_col: str = "label",
|
||||||
model_params: Optional[dict] = None,
|
model_params: Optional[dict] = None,
|
||||||
) -> Tuple[LightGBMModel, ProcessingPipeline]:
|
) -> tuple[LightGBMModel, ProcessingPipeline]:
|
||||||
"""训练LightGBM分类模型
|
"""训练LightGBM分类模型
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -301,8 +457,12 @@ def train_model(
|
|||||||
valid_mask = y_train.is_in([0, 1])
|
valid_mask = y_train.is_in([0, 1])
|
||||||
X_train_processed = X_train_processed.filter(valid_mask)
|
X_train_processed = X_train_processed.filter(valid_mask)
|
||||||
y_train = y_train.filter(valid_mask)
|
y_train = y_train.filter(valid_mask)
|
||||||
print(f"[TrainModel] After filtering valid labels: {len(X_train_processed)} samples")
|
print(
|
||||||
print(f"[TrainModel] Train label distribution: {dict(y_train.value_counts().sort('label').iter_rows())}")
|
f"[TrainModel] After filtering valid labels: {len(X_train_processed)} samples"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"[TrainModel] Train label distribution: {dict(y_train.value_counts().sort('label').iter_rows())}"
|
||||||
|
)
|
||||||
|
|
||||||
# 准备验证集
|
# 准备验证集
|
||||||
X_val_processed = None
|
X_val_processed = None
|
||||||
@@ -320,7 +480,9 @@ def train_model(
|
|||||||
X_val_processed = X_val_processed.filter(val_valid_mask)
|
X_val_processed = X_val_processed.filter(val_valid_mask)
|
||||||
y_val = y_val.filter(val_valid_mask)
|
y_val = y_val.filter(val_valid_mask)
|
||||||
print(f"[TrainModel] Val after filtering: {len(X_val_processed)} samples")
|
print(f"[TrainModel] Val after filtering: {len(X_val_processed)} samples")
|
||||||
print(f"[TrainModel] Val label distribution: {dict(y_val.value_counts().sort('label').iter_rows())}")
|
print(
|
||||||
|
f"[TrainModel] Val label distribution: {dict(y_val.value_counts().sort('label').iter_rows())}"
|
||||||
|
)
|
||||||
|
|
||||||
# 创建模型
|
# 创建模型
|
||||||
params = model_params or {
|
params = model_params or {
|
||||||
@@ -384,7 +546,10 @@ def predict_top_stocks(
|
|||||||
# 使用 key_data 添加预测结果,保持行数一致
|
# 使用 key_data 添加预测结果,保持行数一致
|
||||||
result = key_data.with_columns(
|
result = key_data.with_columns(
|
||||||
pl.Series(
|
pl.Series(
|
||||||
name="pred_prob", values=probs[:, 1] if len(probs.shape) > 1 and probs.shape[1] > 1 else probs.flatten()
|
name="pred_prob",
|
||||||
|
values=probs[:, 1]
|
||||||
|
if len(probs.shape) > 1 and probs.shape[1] > 1
|
||||||
|
else probs.flatten(),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -396,7 +561,11 @@ def predict_top_stocks(
|
|||||||
# 按概率降序排序,选出top N
|
# 按概率降序排序,选出top N
|
||||||
day_top = day_data.sort("pred_prob", descending=True).head(top_n)
|
day_top = day_data.sort("pred_prob", descending=True).head(top_n)
|
||||||
|
|
||||||
top_stocks.append(day_top.select(["trade_date", "pred_prob", "ts_code"]).rename({"pred_prob": "score"}))
|
top_stocks.append(
|
||||||
|
day_top.select(["trade_date", "pred_prob", "ts_code"]).rename(
|
||||||
|
{"pred_prob": "score"}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return pl.concat(top_stocks)
|
return pl.concat(top_stocks)
|
||||||
|
|
||||||
@@ -415,6 +584,7 @@ def save_top_stocks(top_stocks: pl.DataFrame, output_path: str) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def run_training(
|
def run_training(
|
||||||
|
factors: Optional[List[BaseFactor]] = None,
|
||||||
data_dir: str = "data",
|
data_dir: str = "data",
|
||||||
output_path: str = "output/top_stocks.tsv",
|
output_path: str = "output/top_stocks.tsv",
|
||||||
train_start: str = "20180101",
|
train_start: str = "20180101",
|
||||||
@@ -428,6 +598,7 @@ def run_training(
|
|||||||
"""运行完整训练流程
|
"""运行完整训练流程
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
factors: 因子实例列表,默认为 None(使用 MA5, MA10, ReturnRank5)
|
||||||
data_dir: 数据目录
|
data_dir: 数据目录
|
||||||
output_path: 输出文件路径
|
output_path: 输出文件路径
|
||||||
train_start: 训练集开始日期
|
train_start: 训练集开始日期
|
||||||
@@ -448,7 +619,8 @@ def run_training(
|
|||||||
|
|
||||||
# 1. 准备数据
|
# 1. 准备数据
|
||||||
print("[Training] Preparing data...")
|
print("[Training] Preparing data...")
|
||||||
train_data, val_data, test_data = prepare_data(
|
train_data, val_data, test_data, factor_config = prepare_data(
|
||||||
|
factors=factors,
|
||||||
data_dir=data_dir,
|
data_dir=data_dir,
|
||||||
train_start=train_start,
|
train_start=train_start,
|
||||||
train_end=train_end,
|
train_end=train_end,
|
||||||
@@ -461,9 +633,10 @@ def run_training(
|
|||||||
print(f"[Training] Val samples: {len(val_data)}")
|
print(f"[Training] Val samples: {len(val_data)}")
|
||||||
print(f"[Training] Test samples: {len(test_data)}")
|
print(f"[Training] Test samples: {len(test_data)}")
|
||||||
|
|
||||||
# 2. 定义特征列
|
# 2. 获取特征列名
|
||||||
feature_cols = ["return_5_rank", "ma_5", "ma_10"]
|
feature_cols = factor_config.get_feature_names()
|
||||||
label_col = "label"
|
label_col = "label"
|
||||||
|
print(f"[Training] Feature columns: {feature_cols}")
|
||||||
|
|
||||||
# 3. 训练模型
|
# 3. 训练模型
|
||||||
print("[Training] Training model...")
|
print("[Training] Training model...")
|
||||||
|
|||||||
Reference in New Issue
Block a user