refactor(training): 重构训练管道,支持灵活因子配置

- 添加 FactorConfig 类,支持链式 API 和动态因子列表
- 重构 prepare_data 使用 FactorEngine 计算因子,确保防泄露机制生效
- 新增 prepare_data_and_train / train_and_predict 分步执行接口
- 修复 factors/__init__.py docstring 语法错误
This commit is contained in:
2026-02-25 23:39:02 +08:00
parent a9e4746239
commit 990a77ec6c
2 changed files with 564 additions and 119 deletions

View File

@@ -1,21 +1,48 @@
"""训练管道 - 包含数据处理、模型训练和预测功能
本模块提供:
1. 数据准备:从因子计算结果中准备训练/测试数据
1. 数据准备:使用 FactorEngine 从因子计算结果中准备训练/测试数据
2. 数据处理:Fillna(0) -> Dropna
3. 模型训练:使用LightGBM训练分类模型
4. 预测和选股:输出每日top5股票池
注意:本模块使用 src.factors 框架进行因子计算,确保防泄露机制生效。
因子配置示例:
from src.factors import MovingAverageFactor, ReturnRankFactor
from src.training.pipeline import prepare_data
# 直接传入因子实例列表 - 简单直观
factors = [
MovingAverageFactor(period=5),
MovingAverageFactor(period=10),
ReturnRankFactor(period=5),
]
train_data, val_data, test_data, factor_config = prepare_data(
factors=factors,
...
)
# 或者使用 FactorConfig 包装(支持链式添加)
from src.training.pipeline import FactorConfig
config = (
    FactorConfig()
    .add(MovingAverageFactor(period=5))
    .add(MovingAverageFactor(period=10))
    .add(ReturnRankFactor(period=5))
)
"""
from datetime import datetime
from pathlib import Path
from typing import List, Optional, Tuple
from typing import List, Optional
import numpy as np
import polars as pl
from src.factors import DataLoader, FactorEngine
from src.factors import DataLoader, FactorEngine, BaseFactor
from src.factors.data_spec import DataSpec
from src.factors.momentum import MovingAverageFactor, ReturnRankFactor
from src.pipeline import (
DropNAProcessor,
FillNAProcessor,
@@ -26,7 +53,98 @@ from src.pipeline import (
)
# ========== 因子配置类 ==========
class FactorConfig:
    """Factor configuration - manages an ordered list of factor instances.

    Wraps a list of factor instances and supports fluent (chained) adds.

    Example:
        # Option 1: pass a list at construction time
        config = FactorConfig([
            MovingAverageFactor(period=5),
            ReturnRankFactor(period=5),
        ])

        # Option 2: chained adds (parenthesize so the chain parses)
        config = (
            FactorConfig()
            .add(MovingAverageFactor(period=5))
            .add(ReturnRankFactor(period=5))
        )

        # Retrieve the factor instances
        factors = config.get_factors()

        # Retrieve the feature column names
        feature_cols = config.get_feature_names()
    """

    def __init__(self, factors: Optional[List[BaseFactor]] = None):
        """Initialize the factor configuration.

        Args:
            factors: Optional list of factor instances. The elements are
                copied into a fresh internal list so that later mutations of
                the caller's list do not alias this config (and vice versa).

        Raises:
            ValueError: If any element is not a ``BaseFactor`` instance.
        """
        self._factors: List[BaseFactor] = []
        # Route every element through add() so construction-time input gets
        # the same type validation as chained adds, and so the internal list
        # is a defensive copy rather than a shared reference.
        for factor in factors or []:
            self.add(factor)

    def add(self, factor: BaseFactor) -> "FactorConfig":
        """Append a factor to the configuration.

        Supports fluent chaining:
            config = (
                FactorConfig()
                .add(MovingAverageFactor(period=5))
                .add(ReturnRankFactor(period=5))
            )

        Args:
            factor: A factor instance.

        Returns:
            ``self``, to allow chained calls.

        Raises:
            ValueError: If ``factor`` is not a ``BaseFactor`` instance.
        """
        if not isinstance(factor, BaseFactor):
            raise ValueError(f"必须是 BaseFactor 实例, got {type(factor)}")
        self._factors.append(factor)
        return self

    def get_factors(self) -> List[BaseFactor]:
        """Return the internal list of factor instances.

        Note:
            This is the live internal list, not a copy; treat it as
            read-only unless you intend to modify the configuration.

        Returns:
            The list of factor instances, in insertion order.
        """
        return self._factors

    def get_feature_names(self) -> List[str]:
        """Return the feature column name of every configured factor.

        Returns:
            One ``factor.name`` per factor, in insertion order.
        """
        return [f.name for f in self._factors]

    def get_max_lookback(self) -> int:
        """Return the largest lookback window (in days) across all factors.

        Scans every ``DataSpec`` of every factor; used to decide how much
        extra history must be loaded before the training start date.

        Returns:
            The maximum ``lookback_days``, or 0 when no factors are set.
        """
        return max(
            (
                spec.lookback_days
                for factor in self._factors
                for spec in factor.data_specs
            ),
            default=0,
        )

    def __len__(self) -> int:
        # Number of configured factors.
        return len(self._factors)

    def __repr__(self) -> str:
        names = [f.name for f in self._factors]
        return f"FactorConfig({names})"
def prepare_data(
factors: Optional[List[BaseFactor]] = None,
data_dir: str = "data",
train_start: str = "20180101",
train_end: str = "20230101",
@@ -34,12 +152,13 @@ def prepare_data(
val_end: str = "20230601",
test_start: str = "20230601",
test_end: str = "20240101",
) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame, FactorConfig]:
"""准备训练、验证和测试数据
从DuckDB加载原始日线数据计算所需因子并生成标签
使用 FactorEngine 计算因子,确保防泄露机制生效
Args:
factors: 因子实例列表,默认为 None使用 MA5, MA10, ReturnRank5
data_dir: 数据目录
train_start: 训练集开始日期
train_end: 训练集结束日期
@@ -49,44 +168,107 @@ def prepare_data(
test_end: 测试集结束日期
Returns:
(train_data, val_data, test_data): 训练集、验证集和测试集的DataFrame
(train_data, val_data, test_data, factor_config):
训练集、验证集、测试集的DataFrame以及使用的因子配置
"""
from src.data.storage import Storage
storage = Storage()
# 加载日线数据(需要更多历史数据用于计算因子)
# 训练集需要更多历史数据用于计算因子lookback
lookback_days = 20 # 足够计算MA10和5日收益率
# 1. 处理因子配置
if factors is None:
# 默认因子配置
factor_config = FactorConfig(
[
MovingAverageFactor(period=5),
MovingAverageFactor(period=10),
ReturnRankFactor(period=5),
]
)
elif isinstance(factors, FactorConfig):
factor_config = factors
elif isinstance(factors, list):
# 转换为 FactorConfig
factor_config = FactorConfig(factors)
else:
raise ValueError(
f"factors 必须是 List[BaseFactor] 或 FactorConfig got {type(factors)}"
)
factors_list = factor_config.get_factors()
feature_cols = factor_config.get_feature_names()
if not factors_list:
raise ValueError("至少需要提供一个因子")
print(f"[PrepareData] 使用因子: {feature_cols}")
# 2. 初始化 FactorEngine
loader = DataLoader(data_dir=data_dir)
engine = FactorEngine(loader)
# 3. 计算需要的回溯天数
max_lookback = factor_config.get_max_lookback()
start_with_lookback = str(int(train_start) - 10000) # 往前取一年
# 查询全部数据包含train、val、test然后再拆分
# 注意DuckDB 中 trade_date 是 DATE 类型,需要转换
# 获取所有股票列表
start_dt = f"{start_with_lookback[:4]}-{start_with_lookback[4:6]}-{start_with_lookback[6:8]}"
end_dt = f"{test_end[:4]}-{test_end[4:6]}-{test_end[6:8]}"
all_query = f"""
SELECT ts_code, trade_date, close, pre_close
FROM daily
all_stocks_query = f"""
SELECT DISTINCT ts_code FROM daily
WHERE trade_date >= '{start_dt}' AND trade_date <= '{end_dt}'
ORDER BY ts_code, trade_date
"""
all_raw = storage._connection.sql(all_query).pl()
# 转换 trade_date 为字符串格式
all_raw = all_raw.with_columns(
pl.col("trade_date").dt.strftime("%Y-%m-%d").alias("trade_date")
)
all_stocks_df = storage._connection.sql(all_stocks_query).pl()
all_stocks = all_stocks_df["ts_code"].to_list()
# 过滤不符合条件的股票
all_raw = _filter_invalid_stocks(all_raw)
print(f"[PrepareData] After filtering: total={len(all_raw)}")
print(f"[PrepareData] 股票数量: {len(all_stocks)}")
# 计算因子和标签(需要全局数据一次性计算)
all_data = _compute_features_and_label(
all_raw,
start_date=train_start,
end_date=test_end
)
# 4. 计算所有因子并合并
all_features = None
for factor in factors_list:
print(f"[PrepareData] 计算因子: {factor.name} ({factor.factor_type})")
if factor.factor_type == "time_series":
# 时序因子需要传入股票列表
result = engine.compute(
factor,
stock_codes=all_stocks,
start_date=start_with_lookback,
end_date=test_end,
)
else:
# 截面因子不需要股票列表
result = engine.compute(
factor,
start_date=start_with_lookback,
end_date=test_end,
)
# 合并结果
if all_features is None:
all_features = result
else:
# 确保没有重复的 _right 列
result = result.select([
c for c in result.columns if not c.endswith("_right")
])
all_features = all_features.select([
c for c in all_features.columns if not c.endswith("_right")
])
all_features = all_features.join(
result, on=["trade_date", "ts_code"], how="outer"
)
if all_features is None:
raise ValueError("没有计算任何因子")
# 5. 计算标签未来5日收益率
all_data = _compute_label(all_features, start_date=train_start, end_date=test_end)
# 6. 过滤不符合条件的股票
all_data = _filter_invalid_stocks(all_data)
print(f"[PrepareData] After filtering: total={len(all_data)}")
# 转换日期格式用于比较
train_start_fmt = f"{train_start[:4]}-{train_start[4:6]}-{train_start[6:8]}"
@@ -98,18 +280,22 @@ def prepare_data(
# 拆分数据
train_data = all_data.filter(
(pl.col("trade_date") >= train_start_fmt) & (pl.col("trade_date") <= train_end_fmt)
(pl.col("trade_date") >= train_start_fmt)
& (pl.col("trade_date") <= train_end_fmt)
)
val_data = all_data.filter(
(pl.col("trade_date") >= val_start_fmt) & (pl.col("trade_date") <= val_end_fmt)
)
test_data = all_data.filter(
(pl.col("trade_date") >= test_start_fmt) & (pl.col("trade_date") <= test_end_fmt)
(pl.col("trade_date") >= test_start_fmt)
& (pl.col("trade_date") <= test_end_fmt)
)
print(f"[PrepareData] Split result: train={len(train_data)}, val={len(val_data)}, test={len(test_data)}")
print(
f"[PrepareData] Split result: train={len(train_data)}, val={len(val_data)}, test={len(test_data)}"
)
return train_data, val_data, test_data
return train_data, val_data, test_data, factor_config
def _filter_invalid_stocks(df: pl.DataFrame) -> pl.DataFrame:
@@ -137,59 +323,54 @@ def _filter_invalid_stocks(df: pl.DataFrame) -> pl.DataFrame:
)
def _compute_features_and_label(
raw_data: pl.DataFrame,
def _compute_label(
features_df: pl.DataFrame,
start_date: str,
end_date: str,
) -> pl.DataFrame:
"""计算因子和标签
"""计算标签未来5日收益率
因子:
1. return_5_rank: 5日收益率截面排名
2. ma_5: 5日移动平均
3. ma_10: 10日移动平均
标签未来5日收益率大于0为1否则为0
标签定义未来5日收益率大于0为1否则为0
Args:
raw_data: 原始日线数据
features_df: 包含因子的DataFrame
start_date: 开始日期
end_date: 结束日期
Returns:
包含因子和标签的DataFrame
"""
# 确保按日期排序
raw_data = raw_data.sort(["ts_code", "trade_date"])
from src.data.storage import Storage
# 计算收益率未来5日
raw_data = raw_data.with_columns(
[
# 当日收益率
((pl.col("close") - pl.col("pre_close")) / pl.col("pre_close")).alias(
"daily_return"
),
]
storage = Storage()
# 从数据库获取收盘价数据用于计算标签
start_dt = f"{start_date[:4]}-{start_date[4:6]}-{start_date[6:8]}"
end_dt = f"{end_date[:4]}-{end_date[4:6]}-{end_date[6:8]}"
# 需要多取5天数据来计算未来收益率
end_dt_extended = f"{end_date[:4]}-{end_date[4:6]}-{int(end_date[6:8]) + 5}"
price_query = f"""
SELECT ts_code, trade_date, close
FROM daily
WHERE trade_date >= '{start_dt}' AND trade_date <= '{end_dt_extended}'
ORDER BY ts_code, trade_date
"""
price_data = storage._connection.sql(price_query).pl()
price_data = price_data.with_columns(
pl.col("trade_date").dt.strftime("%Y-%m-%d").alias("trade_date")
)
# 按股票分组计算
# 按股票计算未来5日收益率
result_list = []
for ts_code in price_data["ts_code"].unique():
stock_data = price_data.filter(pl.col("ts_code") == ts_code).sort("trade_date")
for ts_code in raw_data["ts_code"].unique():
stock_data = raw_data.filter(pl.col("ts_code") == ts_code).sort("trade_date")
if len(stock_data) < 20:
if len(stock_data) < 6:
continue
# 计算MA5和MA10
stock_data = stock_data.with_columns(
[
pl.col("close").rolling_mean(5).alias("ma_5"),
pl.col("close").rolling_mean(10).alias("ma_10"),
]
)
# 计算未来5日收益率用于标签
# 计算未来5日收益率
future_return = stock_data["close"].shift(-5) - stock_data["close"]
future_return_pct = future_return / stock_data["close"]
stock_data = stock_data.with_columns(
@@ -205,46 +386,21 @@ def _compute_features_and_label(
]
)
result_list.append(stock_data)
result_list.append(stock_data.select(["trade_date", "ts_code", "label"]))
if not result_list:
return pl.DataFrame()
result = pl.concat(result_list)
label_df = pl.concat(result_list)
# 转换日期格式YYYYMMDD -> YYYY-MM-DD
start_date_formatted = f"{start_date[:4]}-{start_date[4:6]}-{start_date[6:8]}"
end_date_formatted = f"{end_date[:4]}-{end_date[4:6]}-{end_date[6:8]}"
# 将标签合并到因子数据
result = features_df.join(label_df, on=["trade_date", "ts_code"], how="inner")
# 过滤有效日期范围
result = result.filter(
(pl.col("trade_date") >= start_date_formatted) & (pl.col("trade_date") <= end_date_formatted)
(pl.col("trade_date") >= start_dt) & (pl.col("trade_date") <= end_dt)
)
# 计算5日收益率排名截面
result = result.with_columns(
[
pl.col("daily_return")
.rank(method="average")
.over("trade_date")
.alias("return_5_rank")
]
)
# 归一化排名到0-1
result = result.with_columns(
[
(
pl.col("return_5_rank")
/ pl.col("return_5_rank").max().over("trade_date")
).alias("return_5_rank")
]
)
# 选择需要的列
feature_cols = ["trade_date", "ts_code", "return_5_rank", "ma_5", "ma_10", "label"]
result = result.select(feature_cols)
return result
@@ -271,7 +427,7 @@ def train_model(
feature_cols: List[str],
label_col: str = "label",
model_params: Optional[dict] = None,
) -> Tuple[LightGBMModel, ProcessingPipeline]:
) -> tuple[LightGBMModel, ProcessingPipeline]:
"""训练LightGBM分类模型
Args:
@@ -301,8 +457,12 @@ def train_model(
valid_mask = y_train.is_in([0, 1])
X_train_processed = X_train_processed.filter(valid_mask)
y_train = y_train.filter(valid_mask)
print(f"[TrainModel] After filtering valid labels: {len(X_train_processed)} samples")
print(f"[TrainModel] Train label distribution: {dict(y_train.value_counts().sort('label').iter_rows())}")
print(
f"[TrainModel] After filtering valid labels: {len(X_train_processed)} samples"
)
print(
f"[TrainModel] Train label distribution: {dict(y_train.value_counts().sort('label').iter_rows())}"
)
# 准备验证集
X_val_processed = None
@@ -311,16 +471,18 @@ def train_model(
X_val = val_data.select(feature_cols)
y_val = val_data[label_col]
print(f"[TrainModel] Val samples: {len(X_val)}")
# 处理验证集数据(使用训练集的参数)
X_val_processed = pipeline.transform(X_val, stage=PipelineStage.TEST)
# 过滤验证集有效标签
val_valid_mask = y_val.is_in([0, 1])
X_val_processed = X_val_processed.filter(val_valid_mask)
y_val = y_val.filter(val_valid_mask)
print(f"[TrainModel] Val after filtering: {len(X_val_processed)} samples")
print(f"[TrainModel] Val label distribution: {dict(y_val.value_counts().sort('label').iter_rows())}")
print(
f"[TrainModel] Val label distribution: {dict(y_val.value_counts().sort('label').iter_rows())}"
)
# 创建模型
params = model_params or {
@@ -384,7 +546,10 @@ def predict_top_stocks(
# 使用 key_data 添加预测结果,保持行数一致
result = key_data.with_columns(
pl.Series(
name="pred_prob", values=probs[:, 1] if len(probs.shape) > 1 and probs.shape[1] > 1 else probs.flatten()
name="pred_prob",
values=probs[:, 1]
if len(probs.shape) > 1 and probs.shape[1] > 1
else probs.flatten(),
),
)
@@ -396,7 +561,11 @@ def predict_top_stocks(
# 按概率降序排序选出top N
day_top = day_data.sort("pred_prob", descending=True).head(top_n)
top_stocks.append(day_top.select(["trade_date", "pred_prob", "ts_code"]).rename({"pred_prob": "score"}))
top_stocks.append(
day_top.select(["trade_date", "pred_prob", "ts_code"]).rename(
{"pred_prob": "score"}
)
)
return pl.concat(top_stocks)
@@ -415,6 +584,7 @@ def save_top_stocks(top_stocks: pl.DataFrame, output_path: str) -> None:
def run_training(
factors: Optional[List[BaseFactor]] = None,
data_dir: str = "data",
output_path: str = "output/top_stocks.tsv",
train_start: str = "20180101",
@@ -428,6 +598,7 @@ def run_training(
"""运行完整训练流程
Args:
factors: 因子实例列表,默认为 None使用 MA5, MA10, ReturnRank5
data_dir: 数据目录
output_path: 输出文件路径
train_start: 训练集开始日期
@@ -448,7 +619,8 @@ def run_training(
# 1. 准备数据
print("[Training] Preparing data...")
train_data, val_data, test_data = prepare_data(
train_data, val_data, test_data, factor_config = prepare_data(
factors=factors,
data_dir=data_dir,
train_start=train_start,
train_end=train_end,
@@ -461,9 +633,10 @@ def run_training(
print(f"[Training] Val samples: {len(val_data)}")
print(f"[Training] Test samples: {len(test_data)}")
# 2. 定义特征列
feature_cols = ["return_5_rank", "ma_5", "ma_10"]
# 2. 获取特征列
feature_cols = factor_config.get_feature_names()
label_col = "label"
print(f"[Training] Feature columns: {feature_cols}")
# 3. 训练模型
print("[Training] Training model...")