2026-02-23 16:23:53 +08:00
|
|
|
|
"""训练管道 - 包含数据处理、模型训练和预测功能
|
|
|
|
|
|
|
|
|
|
|
|
本模块提供:
|
2026-02-25 23:39:02 +08:00
|
|
|
|
1. 数据准备:使用 FactorEngine 从因子计算结果中准备训练/测试数据
|
2026-02-23 16:23:53 +08:00
|
|
|
|
2. 数据处理:Fillna(0) -> Dropna
|
|
|
|
|
|
3. 模型训练:使用LightGBM训练分类模型
|
|
|
|
|
|
4. 预测和选股:输出每日top5股票池
|
2026-02-25 23:39:02 +08:00
|
|
|
|
|
|
|
|
|
|
注意:本模块使用 src.factors 框架进行因子计算,确保防泄露机制生效。
|
|
|
|
|
|
|
|
|
|
|
|
因子配置示例:
|
|
|
|
|
|
from src.factors import MovingAverageFactor, ReturnRankFactor
|
|
|
|
|
|
from src.training.pipeline import prepare_data
|
|
|
|
|
|
|
|
|
|
|
|
# 直接传入因子实例列表 - 简单直观
|
|
|
|
|
|
factors = [
|
|
|
|
|
|
MovingAverageFactor(period=5),
|
|
|
|
|
|
MovingAverageFactor(period=10),
|
|
|
|
|
|
ReturnRankFactor(period=5),
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
train_data, val_data, test_data, factor_config = prepare_data(
|
|
|
|
|
|
factors=factors,
|
|
|
|
|
|
...
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
# 或者使用 FactorConfig 包装(支持链式添加)
|
|
|
|
|
|
from src.training.pipeline import FactorConfig
|
|
|
|
|
|
|
|
|
|
|
|
config = FactorConfig()
|
|
|
|
|
|
.add(MovingAverageFactor(period=5))
|
|
|
|
|
|
.add(MovingAverageFactor(period=10))
|
|
|
|
|
|
.add(ReturnRankFactor(period=5))
|
2026-02-23 16:23:53 +08:00
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
|
from pathlib import Path
|
2026-02-25 23:39:02 +08:00
|
|
|
|
from typing import List, Optional
|
2026-02-23 16:23:53 +08:00
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
import polars as pl
|
|
|
|
|
|
|
2026-02-25 23:39:02 +08:00
|
|
|
|
from src.factors import DataLoader, FactorEngine, BaseFactor
|
2026-02-23 16:23:53 +08:00
|
|
|
|
from src.factors.data_spec import DataSpec
|
2026-02-25 23:39:02 +08:00
|
|
|
|
from src.factors.momentum import MovingAverageFactor, ReturnRankFactor
|
2026-02-23 16:23:53 +08:00
|
|
|
|
from src.pipeline import (
|
|
|
|
|
|
DropNAProcessor,
|
|
|
|
|
|
FillNAProcessor,
|
|
|
|
|
|
LightGBMModel,
|
|
|
|
|
|
PipelineStage,
|
|
|
|
|
|
ProcessingPipeline,
|
|
|
|
|
|
TaskType,
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-02-25 23:39:02 +08:00
|
|
|
|
# ========== 因子配置类 ==========
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FactorConfig:
|
|
|
|
|
|
"""因子配置类 - 管理因子列表
|
|
|
|
|
|
|
|
|
|
|
|
用于包装因子实例列表,支持链式添加。
|
|
|
|
|
|
|
|
|
|
|
|
示例:
|
|
|
|
|
|
# 方式1:初始化时传入列表
|
|
|
|
|
|
config = FactorConfig([
|
|
|
|
|
|
MovingAverageFactor(period=5),
|
|
|
|
|
|
ReturnRankFactor(period=5),
|
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
|
|
# 方式2:链式添加
|
|
|
|
|
|
config = FactorConfig()
|
|
|
|
|
|
.add(MovingAverageFactor(period=5))
|
|
|
|
|
|
.add(ReturnRankFactor(period=5))
|
|
|
|
|
|
|
|
|
|
|
|
# 获取因子实例列表
|
|
|
|
|
|
factors = config.get_factors()
|
|
|
|
|
|
|
|
|
|
|
|
# 获取特征列名
|
|
|
|
|
|
feature_cols = config.get_feature_names()
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, factors: Optional[List[BaseFactor]] = None):
|
|
|
|
|
|
"""初始化因子配置
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
factors: 因子实例列表
|
|
|
|
|
|
"""
|
|
|
|
|
|
self._factors: List[BaseFactor] = factors or []
|
|
|
|
|
|
|
|
|
|
|
|
def add(self, factor: BaseFactor) -> "FactorConfig":
|
|
|
|
|
|
"""添加因子到配置
|
|
|
|
|
|
|
|
|
|
|
|
支持链式调用:
|
|
|
|
|
|
config = FactorConfig()
|
|
|
|
|
|
.add(MovingAverageFactor(period=5))
|
|
|
|
|
|
.add(ReturnRankFactor(period=5))
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
factor: 因子实例
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
self,支持链式调用
|
|
|
|
|
|
"""
|
|
|
|
|
|
if not isinstance(factor, BaseFactor):
|
|
|
|
|
|
raise ValueError(f"必须是 BaseFactor 实例, got {type(factor)}")
|
|
|
|
|
|
self._factors.append(factor)
|
|
|
|
|
|
return self
|
|
|
|
|
|
|
|
|
|
|
|
def get_factors(self) -> List[BaseFactor]:
|
|
|
|
|
|
"""获取因子实例列表
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
因子实例列表
|
|
|
|
|
|
"""
|
|
|
|
|
|
return self._factors
|
|
|
|
|
|
|
|
|
|
|
|
def get_feature_names(self) -> List[str]:
|
|
|
|
|
|
"""获取所有因子的特征列名
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
特征列名列表
|
|
|
|
|
|
"""
|
|
|
|
|
|
return [f.name for f in self._factors]
|
|
|
|
|
|
|
|
|
|
|
|
def get_max_lookback(self) -> int:
|
|
|
|
|
|
"""获取所有因子中最大的 lookback 天数
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
最大 lookback 天数
|
|
|
|
|
|
"""
|
|
|
|
|
|
max_lookback = 0
|
|
|
|
|
|
for factor in self._factors:
|
|
|
|
|
|
for spec in factor.data_specs:
|
|
|
|
|
|
max_lookback = max(max_lookback, spec.lookback_days)
|
|
|
|
|
|
return max_lookback
|
|
|
|
|
|
|
|
|
|
|
|
def __len__(self) -> int:
|
|
|
|
|
|
return len(self._factors)
|
|
|
|
|
|
|
|
|
|
|
|
def __repr__(self) -> str:
|
|
|
|
|
|
names = [f.name for f in self._factors]
|
|
|
|
|
|
return f"FactorConfig({names})"
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-02-23 16:23:53 +08:00
|
|
|
|
def prepare_data(
|
2026-02-25 23:39:02 +08:00
|
|
|
|
factors: Optional[List[BaseFactor]] = None,
|
2026-02-23 16:23:53 +08:00
|
|
|
|
data_dir: str = "data",
|
|
|
|
|
|
train_start: str = "20180101",
|
|
|
|
|
|
train_end: str = "20230101",
|
2026-02-25 21:11:19 +08:00
|
|
|
|
val_start: str = "20230101",
|
|
|
|
|
|
val_end: str = "20230601",
|
|
|
|
|
|
test_start: str = "20230601",
|
2026-02-23 16:23:53 +08:00
|
|
|
|
test_end: str = "20240101",
|
2026-02-25 23:39:02 +08:00
|
|
|
|
) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame, FactorConfig]:
|
2026-02-25 21:11:19 +08:00
|
|
|
|
"""准备训练、验证和测试数据
|
2026-02-23 16:23:53 +08:00
|
|
|
|
|
2026-02-25 23:39:02 +08:00
|
|
|
|
使用 FactorEngine 计算因子,确保防泄露机制生效。
|
2026-02-23 16:23:53 +08:00
|
|
|
|
|
|
|
|
|
|
Args:
|
2026-02-25 23:39:02 +08:00
|
|
|
|
factors: 因子实例列表,默认为 None(使用 MA5, MA10, ReturnRank5)
|
2026-02-23 16:23:53 +08:00
|
|
|
|
data_dir: 数据目录
|
|
|
|
|
|
train_start: 训练集开始日期
|
|
|
|
|
|
train_end: 训练集结束日期
|
2026-02-25 21:11:19 +08:00
|
|
|
|
val_start: 验证集开始日期
|
|
|
|
|
|
val_end: 验证集结束日期
|
2026-02-23 16:23:53 +08:00
|
|
|
|
test_start: 测试集开始日期
|
|
|
|
|
|
test_end: 测试集结束日期
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
2026-02-25 23:39:02 +08:00
|
|
|
|
(train_data, val_data, test_data, factor_config):
|
|
|
|
|
|
训练集、验证集、测试集的DataFrame,以及使用的因子配置
|
2026-02-23 16:23:53 +08:00
|
|
|
|
"""
|
|
|
|
|
|
from src.data.storage import Storage
|
|
|
|
|
|
|
|
|
|
|
|
storage = Storage()
|
|
|
|
|
|
|
2026-02-25 23:39:02 +08:00
|
|
|
|
# 1. 处理因子配置
|
|
|
|
|
|
if factors is None:
|
|
|
|
|
|
# 默认因子配置
|
|
|
|
|
|
factor_config = FactorConfig(
|
|
|
|
|
|
[
|
|
|
|
|
|
MovingAverageFactor(period=5),
|
|
|
|
|
|
MovingAverageFactor(period=10),
|
|
|
|
|
|
ReturnRankFactor(period=5),
|
|
|
|
|
|
]
|
|
|
|
|
|
)
|
|
|
|
|
|
elif isinstance(factors, FactorConfig):
|
|
|
|
|
|
factor_config = factors
|
|
|
|
|
|
elif isinstance(factors, list):
|
|
|
|
|
|
# 转换为 FactorConfig
|
|
|
|
|
|
factor_config = FactorConfig(factors)
|
|
|
|
|
|
else:
|
|
|
|
|
|
raise ValueError(
|
|
|
|
|
|
f"factors 必须是 List[BaseFactor] 或 FactorConfig, got {type(factors)}"
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
factors_list = factor_config.get_factors()
|
|
|
|
|
|
feature_cols = factor_config.get_feature_names()
|
|
|
|
|
|
|
|
|
|
|
|
if not factors_list:
|
|
|
|
|
|
raise ValueError("至少需要提供一个因子")
|
|
|
|
|
|
|
|
|
|
|
|
print(f"[PrepareData] 使用因子: {feature_cols}")
|
|
|
|
|
|
|
|
|
|
|
|
# 2. 初始化 FactorEngine
|
|
|
|
|
|
loader = DataLoader(data_dir=data_dir)
|
|
|
|
|
|
engine = FactorEngine(loader)
|
|
|
|
|
|
|
|
|
|
|
|
# 3. 计算需要的回溯天数
|
|
|
|
|
|
max_lookback = factor_config.get_max_lookback()
|
2026-02-23 16:23:53 +08:00
|
|
|
|
start_with_lookback = str(int(train_start) - 10000) # 往前取一年
|
|
|
|
|
|
|
2026-02-25 23:39:02 +08:00
|
|
|
|
# 获取所有股票列表
|
2026-02-23 16:23:53 +08:00
|
|
|
|
start_dt = f"{start_with_lookback[:4]}-{start_with_lookback[4:6]}-{start_with_lookback[6:8]}"
|
2026-02-25 21:11:19 +08:00
|
|
|
|
end_dt = f"{test_end[:4]}-{test_end[4:6]}-{test_end[6:8]}"
|
2026-02-25 23:39:02 +08:00
|
|
|
|
all_stocks_query = f"""
|
|
|
|
|
|
SELECT DISTINCT ts_code FROM daily
|
2026-02-23 16:23:53 +08:00
|
|
|
|
WHERE trade_date >= '{start_dt}' AND trade_date <= '{end_dt}'
|
|
|
|
|
|
"""
|
2026-02-25 23:39:02 +08:00
|
|
|
|
all_stocks_df = storage._connection.sql(all_stocks_query).pl()
|
|
|
|
|
|
all_stocks = all_stocks_df["ts_code"].to_list()
|
|
|
|
|
|
|
|
|
|
|
|
print(f"[PrepareData] 股票数量: {len(all_stocks)}")
|
|
|
|
|
|
|
|
|
|
|
|
# 4. 计算所有因子并合并
|
|
|
|
|
|
all_features = None
|
|
|
|
|
|
|
|
|
|
|
|
for factor in factors_list:
|
|
|
|
|
|
print(f"[PrepareData] 计算因子: {factor.name} ({factor.factor_type})")
|
|
|
|
|
|
|
|
|
|
|
|
if factor.factor_type == "time_series":
|
|
|
|
|
|
# 时序因子需要传入股票列表
|
|
|
|
|
|
result = engine.compute(
|
|
|
|
|
|
factor,
|
|
|
|
|
|
stock_codes=all_stocks,
|
|
|
|
|
|
start_date=start_with_lookback,
|
|
|
|
|
|
end_date=test_end,
|
|
|
|
|
|
)
|
|
|
|
|
|
else:
|
|
|
|
|
|
# 截面因子不需要股票列表
|
|
|
|
|
|
result = engine.compute(
|
|
|
|
|
|
factor,
|
|
|
|
|
|
start_date=start_with_lookback,
|
|
|
|
|
|
end_date=test_end,
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
# 合并结果
|
|
|
|
|
|
if all_features is None:
|
|
|
|
|
|
all_features = result
|
|
|
|
|
|
else:
|
|
|
|
|
|
# 确保没有重复的 _right 列
|
|
|
|
|
|
result = result.select([
|
|
|
|
|
|
c for c in result.columns if not c.endswith("_right")
|
|
|
|
|
|
])
|
|
|
|
|
|
all_features = all_features.select([
|
|
|
|
|
|
c for c in all_features.columns if not c.endswith("_right")
|
|
|
|
|
|
])
|
|
|
|
|
|
all_features = all_features.join(
|
|
|
|
|
|
result, on=["trade_date", "ts_code"], how="outer"
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
if all_features is None:
|
|
|
|
|
|
raise ValueError("没有计算任何因子")
|
|
|
|
|
|
|
|
|
|
|
|
# 5. 计算标签(未来5日收益率)
|
|
|
|
|
|
all_data = _compute_label(all_features, start_date=train_start, end_date=test_end)
|
|
|
|
|
|
|
|
|
|
|
|
# 6. 过滤不符合条件的股票
|
|
|
|
|
|
all_data = _filter_invalid_stocks(all_data)
|
|
|
|
|
|
print(f"[PrepareData] After filtering: total={len(all_data)}")
|
2026-02-23 16:23:53 +08:00
|
|
|
|
|
2026-02-25 21:11:19 +08:00
|
|
|
|
# 转换日期格式用于比较
|
|
|
|
|
|
train_start_fmt = f"{train_start[:4]}-{train_start[4:6]}-{train_start[6:8]}"
|
|
|
|
|
|
train_end_fmt = f"{train_end[:4]}-{train_end[4:6]}-{train_end[6:8]}"
|
|
|
|
|
|
val_start_fmt = f"{val_start[:4]}-{val_start[4:6]}-{val_start[6:8]}"
|
|
|
|
|
|
val_end_fmt = f"{val_end[:4]}-{val_end[4:6]}-{val_end[6:8]}"
|
|
|
|
|
|
test_start_fmt = f"{test_start[:4]}-{test_start[4:6]}-{test_start[6:8]}"
|
|
|
|
|
|
test_end_fmt = f"{test_end[:4]}-{test_end[4:6]}-{test_end[6:8]}"
|
|
|
|
|
|
|
|
|
|
|
|
# 拆分数据
|
|
|
|
|
|
train_data = all_data.filter(
|
2026-02-25 23:39:02 +08:00
|
|
|
|
(pl.col("trade_date") >= train_start_fmt)
|
|
|
|
|
|
& (pl.col("trade_date") <= train_end_fmt)
|
2026-02-25 21:11:19 +08:00
|
|
|
|
)
|
|
|
|
|
|
val_data = all_data.filter(
|
|
|
|
|
|
(pl.col("trade_date") >= val_start_fmt) & (pl.col("trade_date") <= val_end_fmt)
|
|
|
|
|
|
)
|
|
|
|
|
|
test_data = all_data.filter(
|
2026-02-25 23:39:02 +08:00
|
|
|
|
(pl.col("trade_date") >= test_start_fmt)
|
|
|
|
|
|
& (pl.col("trade_date") <= test_end_fmt)
|
2026-02-25 21:11:19 +08:00
|
|
|
|
)
|
2026-02-23 16:23:53 +08:00
|
|
|
|
|
2026-02-25 23:39:02 +08:00
|
|
|
|
print(
|
|
|
|
|
|
f"[PrepareData] Split result: train={len(train_data)}, val={len(val_data)}, test={len(test_data)}"
|
|
|
|
|
|
)
|
2026-02-23 16:23:53 +08:00
|
|
|
|
|
2026-02-25 23:39:02 +08:00
|
|
|
|
return train_data, val_data, test_data, factor_config
|
2026-02-23 16:23:53 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _filter_invalid_stocks(df: pl.DataFrame) -> pl.DataFrame:
|
|
|
|
|
|
"""过滤不符合条件的股票
|
|
|
|
|
|
|
|
|
|
|
|
过滤规则:
|
|
|
|
|
|
1. 过滤北交所股票(ts_code 以 BJ 结尾)
|
|
|
|
|
|
2. 过滤创业板股票(ts_code 以 30 开头)
|
|
|
|
|
|
3. 过滤科创板股票(ts_code 以 68 开头)
|
|
|
|
|
|
4. 过滤退市/风险股票(ts_code 以 8 开头)
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
df: 原始数据
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
过滤后的数据
|
|
|
|
|
|
"""
|
|
|
|
|
|
ts_code_col = pl.col("ts_code")
|
|
|
|
|
|
|
|
|
|
|
|
return df.filter(
|
|
|
|
|
|
~ts_code_col.str.ends_with("BJ")
|
|
|
|
|
|
& ~ts_code_col.str.starts_with("30")
|
|
|
|
|
|
& ~ts_code_col.str.starts_with("68")
|
|
|
|
|
|
& ~ts_code_col.str.starts_with("8")
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-02-25 23:39:02 +08:00
|
|
|
|
def _compute_label(
|
|
|
|
|
|
features_df: pl.DataFrame,
|
2026-02-23 16:23:53 +08:00
|
|
|
|
start_date: str,
|
|
|
|
|
|
end_date: str,
|
|
|
|
|
|
) -> pl.DataFrame:
|
2026-02-25 23:39:02 +08:00
|
|
|
|
"""计算标签(未来5日收益率)
|
2026-02-23 16:23:53 +08:00
|
|
|
|
|
2026-02-25 23:39:02 +08:00
|
|
|
|
标签定义:未来5日收益率大于0为1,否则为0
|
2026-02-23 16:23:53 +08:00
|
|
|
|
|
|
|
|
|
|
Args:
|
2026-02-25 23:39:02 +08:00
|
|
|
|
features_df: 包含因子的DataFrame
|
2026-02-23 16:23:53 +08:00
|
|
|
|
start_date: 开始日期
|
|
|
|
|
|
end_date: 结束日期
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
包含因子和标签的DataFrame
|
|
|
|
|
|
"""
|
2026-02-25 23:39:02 +08:00
|
|
|
|
from src.data.storage import Storage
|
|
|
|
|
|
|
|
|
|
|
|
storage = Storage()
|
|
|
|
|
|
|
|
|
|
|
|
# 从数据库获取收盘价数据用于计算标签
|
|
|
|
|
|
start_dt = f"{start_date[:4]}-{start_date[4:6]}-{start_date[6:8]}"
|
|
|
|
|
|
end_dt = f"{end_date[:4]}-{end_date[4:6]}-{end_date[6:8]}"
|
|
|
|
|
|
|
|
|
|
|
|
# 需要多取5天数据来计算未来收益率
|
|
|
|
|
|
end_dt_extended = f"{end_date[:4]}-{end_date[4:6]}-{int(end_date[6:8]) + 5}"
|
|
|
|
|
|
|
|
|
|
|
|
price_query = f"""
|
|
|
|
|
|
SELECT ts_code, trade_date, close
|
|
|
|
|
|
FROM daily
|
|
|
|
|
|
WHERE trade_date >= '{start_dt}' AND trade_date <= '{end_dt_extended}'
|
|
|
|
|
|
ORDER BY ts_code, trade_date
|
|
|
|
|
|
"""
|
|
|
|
|
|
price_data = storage._connection.sql(price_query).pl()
|
|
|
|
|
|
price_data = price_data.with_columns(
|
|
|
|
|
|
pl.col("trade_date").dt.strftime("%Y-%m-%d").alias("trade_date")
|
2026-02-23 16:23:53 +08:00
|
|
|
|
)
|
|
|
|
|
|
|
2026-02-25 23:39:02 +08:00
|
|
|
|
# 按股票计算未来5日收益率
|
2026-02-23 16:23:53 +08:00
|
|
|
|
result_list = []
|
2026-02-25 23:39:02 +08:00
|
|
|
|
for ts_code in price_data["ts_code"].unique():
|
|
|
|
|
|
stock_data = price_data.filter(pl.col("ts_code") == ts_code).sort("trade_date")
|
2026-02-23 16:23:53 +08:00
|
|
|
|
|
2026-02-25 23:39:02 +08:00
|
|
|
|
if len(stock_data) < 6:
|
2026-02-23 16:23:53 +08:00
|
|
|
|
continue
|
|
|
|
|
|
|
2026-02-25 23:39:02 +08:00
|
|
|
|
# 计算未来5日收益率
|
2026-02-23 16:23:53 +08:00
|
|
|
|
future_return = stock_data["close"].shift(-5) - stock_data["close"]
|
|
|
|
|
|
future_return_pct = future_return / stock_data["close"]
|
|
|
|
|
|
stock_data = stock_data.with_columns(
|
|
|
|
|
|
[
|
|
|
|
|
|
future_return_pct.alias("future_return_5"),
|
|
|
|
|
|
]
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
# 生成标签:收益率>0为1,否则为0
|
|
|
|
|
|
stock_data = stock_data.with_columns(
|
|
|
|
|
|
[
|
|
|
|
|
|
(pl.col("future_return_5") > 0).cast(pl.Int8).alias("label"),
|
|
|
|
|
|
]
|
|
|
|
|
|
)
|
|
|
|
|
|
|
2026-02-25 23:39:02 +08:00
|
|
|
|
result_list.append(stock_data.select(["trade_date", "ts_code", "label"]))
|
2026-02-23 16:23:53 +08:00
|
|
|
|
|
|
|
|
|
|
if not result_list:
|
|
|
|
|
|
return pl.DataFrame()
|
|
|
|
|
|
|
2026-02-25 23:39:02 +08:00
|
|
|
|
label_df = pl.concat(result_list)
|
2026-02-23 16:23:53 +08:00
|
|
|
|
|
2026-02-25 23:39:02 +08:00
|
|
|
|
# 将标签合并到因子数据
|
|
|
|
|
|
result = features_df.join(label_df, on=["trade_date", "ts_code"], how="inner")
|
2026-02-23 16:23:53 +08:00
|
|
|
|
|
|
|
|
|
|
# 过滤有效日期范围
|
|
|
|
|
|
result = result.filter(
|
2026-02-25 23:39:02 +08:00
|
|
|
|
(pl.col("trade_date") >= start_dt) & (pl.col("trade_date") <= end_dt)
|
2026-02-23 16:23:53 +08:00
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_pipeline() -> ProcessingPipeline:
|
|
|
|
|
|
"""创建数据处理流水线
|
|
|
|
|
|
|
|
|
|
|
|
处理流程:
|
|
|
|
|
|
1. FillNA(0): 将缺失值填充为0
|
|
|
|
|
|
|
|
|
|
|
|
注意:不使用 Dropna,因为会导致训练和预测时的行数不匹配
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
配置好的ProcessingPipeline
|
|
|
|
|
|
"""
|
|
|
|
|
|
processors = [
|
|
|
|
|
|
FillNAProcessor(method="zero"), # 缺失值填充为0
|
|
|
|
|
|
]
|
|
|
|
|
|
return ProcessingPipeline(processors)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_model(
|
|
|
|
|
|
train_data: pl.DataFrame,
|
2026-02-25 21:11:19 +08:00
|
|
|
|
val_data: Optional[pl.DataFrame],
|
2026-02-23 16:23:53 +08:00
|
|
|
|
feature_cols: List[str],
|
|
|
|
|
|
label_col: str = "label",
|
|
|
|
|
|
model_params: Optional[dict] = None,
|
2026-02-25 23:39:02 +08:00
|
|
|
|
) -> tuple[LightGBMModel, ProcessingPipeline]:
|
2026-02-23 16:23:53 +08:00
|
|
|
|
"""训练LightGBM分类模型
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
train_data: 训练数据
|
2026-02-25 21:11:19 +08:00
|
|
|
|
val_data: 验证数据(用于早停)
|
2026-02-23 16:23:53 +08:00
|
|
|
|
feature_cols: 特征列名列表
|
|
|
|
|
|
label_col: 标签列名
|
|
|
|
|
|
model_params: 模型参数字典
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
(训练好的模型, 处理流水线)
|
|
|
|
|
|
"""
|
|
|
|
|
|
# 创建处理流水线
|
|
|
|
|
|
pipeline = create_pipeline()
|
|
|
|
|
|
print("[TrainModel] Pipeline created: FillNA(0)")
|
|
|
|
|
|
|
2026-02-25 21:11:19 +08:00
|
|
|
|
# 准备训练特征和标签
|
2026-02-23 16:23:53 +08:00
|
|
|
|
X_train = train_data.select(feature_cols)
|
|
|
|
|
|
y_train = train_data[label_col]
|
2026-02-25 21:11:19 +08:00
|
|
|
|
print(f"[TrainModel] Train samples: {len(X_train)}, features: {feature_cols}")
|
2026-02-23 16:23:53 +08:00
|
|
|
|
|
2026-02-25 21:11:19 +08:00
|
|
|
|
# 处理训练数据
|
2026-02-23 16:23:53 +08:00
|
|
|
|
X_train_processed = pipeline.fit_transform(X_train, stage=PipelineStage.TRAIN)
|
|
|
|
|
|
print(f"[TrainModel] After processing: {len(X_train_processed)} samples")
|
|
|
|
|
|
|
2026-02-25 21:11:19 +08:00
|
|
|
|
# 过滤训练集有效标签(排除-1等无效值)
|
2026-02-23 16:23:53 +08:00
|
|
|
|
valid_mask = y_train.is_in([0, 1])
|
|
|
|
|
|
X_train_processed = X_train_processed.filter(valid_mask)
|
|
|
|
|
|
y_train = y_train.filter(valid_mask)
|
2026-02-25 23:39:02 +08:00
|
|
|
|
print(
|
|
|
|
|
|
f"[TrainModel] After filtering valid labels: {len(X_train_processed)} samples"
|
|
|
|
|
|
)
|
|
|
|
|
|
print(
|
|
|
|
|
|
f"[TrainModel] Train label distribution: {dict(y_train.value_counts().sort('label').iter_rows())}"
|
|
|
|
|
|
)
|
2026-02-25 21:11:19 +08:00
|
|
|
|
|
|
|
|
|
|
# 准备验证集
|
|
|
|
|
|
X_val_processed = None
|
|
|
|
|
|
y_val = None
|
|
|
|
|
|
if val_data is not None and len(val_data) > 0:
|
|
|
|
|
|
X_val = val_data.select(feature_cols)
|
|
|
|
|
|
y_val = val_data[label_col]
|
|
|
|
|
|
print(f"[TrainModel] Val samples: {len(X_val)}")
|
2026-02-25 23:39:02 +08:00
|
|
|
|
|
2026-02-25 21:11:19 +08:00
|
|
|
|
# 处理验证集数据(使用训练集的参数)
|
|
|
|
|
|
X_val_processed = pipeline.transform(X_val, stage=PipelineStage.TEST)
|
2026-02-25 23:39:02 +08:00
|
|
|
|
|
2026-02-25 21:11:19 +08:00
|
|
|
|
# 过滤验证集有效标签
|
|
|
|
|
|
val_valid_mask = y_val.is_in([0, 1])
|
|
|
|
|
|
X_val_processed = X_val_processed.filter(val_valid_mask)
|
|
|
|
|
|
y_val = y_val.filter(val_valid_mask)
|
|
|
|
|
|
print(f"[TrainModel] Val after filtering: {len(X_val_processed)} samples")
|
2026-02-25 23:39:02 +08:00
|
|
|
|
print(
|
|
|
|
|
|
f"[TrainModel] Val label distribution: {dict(y_val.value_counts().sort('label').iter_rows())}"
|
|
|
|
|
|
)
|
2026-02-23 16:23:53 +08:00
|
|
|
|
|
|
|
|
|
|
# 创建模型
|
|
|
|
|
|
params = model_params or {
|
|
|
|
|
|
"n_estimators": 100,
|
|
|
|
|
|
"learning_rate": 0.05,
|
|
|
|
|
|
"max_depth": 5,
|
|
|
|
|
|
"num_leaves": 31,
|
|
|
|
|
|
}
|
|
|
|
|
|
print(f"[TrainModel] Model params: {params}")
|
|
|
|
|
|
model = LightGBMModel(
|
|
|
|
|
|
task_type="classification",
|
|
|
|
|
|
params=params,
|
|
|
|
|
|
)
|
|
|
|
|
|
|
2026-02-25 21:11:19 +08:00
|
|
|
|
# 训练模型(使用验证集早停)
|
2026-02-23 16:23:53 +08:00
|
|
|
|
print("[TrainModel] Training LightGBM...")
|
2026-02-25 21:11:19 +08:00
|
|
|
|
if X_val_processed is not None and y_val is not None:
|
|
|
|
|
|
print("[TrainModel] Using validation set for early stopping")
|
|
|
|
|
|
model.fit(X_train_processed, y_train, X_val_processed, y_val)
|
|
|
|
|
|
else:
|
|
|
|
|
|
model.fit(X_train_processed, y_train)
|
2026-02-23 16:23:53 +08:00
|
|
|
|
print("[TrainModel] Training completed!")
|
|
|
|
|
|
|
|
|
|
|
|
return model, pipeline
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict_top_stocks(
|
|
|
|
|
|
model: LightGBMModel,
|
|
|
|
|
|
pipeline: ProcessingPipeline,
|
|
|
|
|
|
test_data: pl.DataFrame,
|
|
|
|
|
|
feature_cols: List[str],
|
|
|
|
|
|
top_n: int = 5,
|
|
|
|
|
|
) -> pl.DataFrame:
|
|
|
|
|
|
"""预测并选出每日top N股票
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
model: 训练好的模型
|
|
|
|
|
|
pipeline: 数据处理流水线
|
|
|
|
|
|
test_data: 测试数据
|
|
|
|
|
|
feature_cols: 特征列名
|
|
|
|
|
|
top_n: 每日选出的股票数量
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
包含日期和股票代码的DataFrame
|
|
|
|
|
|
"""
|
|
|
|
|
|
# 准备特征和必要列
|
|
|
|
|
|
X_test = test_data.select(feature_cols)
|
|
|
|
|
|
key_cols = ["trade_date", "ts_code"]
|
|
|
|
|
|
key_data = test_data.select(key_cols)
|
|
|
|
|
|
|
|
|
|
|
|
print(f"[Predict] Test samples: {len(X_test)}, top_n: {top_n}")
|
|
|
|
|
|
|
|
|
|
|
|
# 处理数据(使用训练阶段的参数)
|
|
|
|
|
|
X_test_processed = pipeline.transform(X_test, stage=PipelineStage.TEST)
|
|
|
|
|
|
print(f"[Predict] Data processed, shape: {X_test_processed.shape}")
|
|
|
|
|
|
|
|
|
|
|
|
# 预测概率
|
|
|
|
|
|
probs = model.predict_proba(X_test_processed)
|
|
|
|
|
|
print(f"[Predict] Predictions generated, probability shape: {probs.shape}")
|
|
|
|
|
|
|
|
|
|
|
|
# 使用 key_data 添加预测结果,保持行数一致
|
|
|
|
|
|
result = key_data.with_columns(
|
|
|
|
|
|
pl.Series(
|
2026-02-25 23:39:02 +08:00
|
|
|
|
name="pred_prob",
|
|
|
|
|
|
values=probs[:, 1]
|
|
|
|
|
|
if len(probs.shape) > 1 and probs.shape[1] > 1
|
|
|
|
|
|
else probs.flatten(),
|
2026-02-23 16:23:53 +08:00
|
|
|
|
),
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
# 每日选出top N
|
|
|
|
|
|
top_stocks = []
|
|
|
|
|
|
for date in result["trade_date"].unique().sort():
|
|
|
|
|
|
day_data = result.filter(pl.col("trade_date") == date)
|
|
|
|
|
|
|
|
|
|
|
|
# 按概率降序排序,选出top N
|
|
|
|
|
|
day_top = day_data.sort("pred_prob", descending=True).head(top_n)
|
|
|
|
|
|
|
2026-02-25 23:39:02 +08:00
|
|
|
|
top_stocks.append(
|
|
|
|
|
|
day_top.select(["trade_date", "pred_prob", "ts_code"]).rename(
|
|
|
|
|
|
{"pred_prob": "score"}
|
|
|
|
|
|
)
|
|
|
|
|
|
)
|
2026-02-23 16:23:53 +08:00
|
|
|
|
|
|
|
|
|
|
return pl.concat(top_stocks)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_top_stocks(top_stocks: pl.DataFrame, output_path: str) -> None:
|
|
|
|
|
|
"""保存选股结果到TSV文件
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
top_stocks: 选股结果
|
|
|
|
|
|
output_path: 输出文件路径
|
|
|
|
|
|
"""
|
|
|
|
|
|
# 转换为pandas并保存为TSV
|
|
|
|
|
|
df = top_stocks.to_pandas()
|
|
|
|
|
|
df.to_csv(output_path, sep="\t", index=False)
|
|
|
|
|
|
print(f"[Training] Top stocks saved to: {output_path}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_training(
|
2026-02-25 23:39:02 +08:00
|
|
|
|
factors: Optional[List[BaseFactor]] = None,
|
2026-02-23 16:23:53 +08:00
|
|
|
|
data_dir: str = "data",
|
|
|
|
|
|
output_path: str = "output/top_stocks.tsv",
|
|
|
|
|
|
train_start: str = "20180101",
|
|
|
|
|
|
train_end: str = "20230101",
|
2026-02-25 21:11:19 +08:00
|
|
|
|
val_start: str = "20230101",
|
|
|
|
|
|
val_end: str = "20230601",
|
|
|
|
|
|
test_start: str = "20230601",
|
2026-02-23 16:23:53 +08:00
|
|
|
|
test_end: str = "20240101",
|
|
|
|
|
|
top_n: int = 5,
|
|
|
|
|
|
) -> pl.DataFrame:
|
|
|
|
|
|
"""运行完整训练流程
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
2026-02-25 23:39:02 +08:00
|
|
|
|
factors: 因子实例列表,默认为 None(使用 MA5, MA10, ReturnRank5)
|
2026-02-23 16:23:53 +08:00
|
|
|
|
data_dir: 数据目录
|
|
|
|
|
|
output_path: 输出文件路径
|
|
|
|
|
|
train_start: 训练集开始日期
|
|
|
|
|
|
train_end: 训练集结束日期
|
2026-02-25 21:11:19 +08:00
|
|
|
|
val_start: 验证集开始日期
|
|
|
|
|
|
val_end: 验证集结束日期
|
2026-02-23 16:23:53 +08:00
|
|
|
|
test_start: 测试集开始日期
|
|
|
|
|
|
test_end: 测试集结束日期
|
|
|
|
|
|
top_n: 每日选股数量
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
选股结果DataFrame
|
|
|
|
|
|
"""
|
|
|
|
|
|
print(f"[Training] Starting training pipeline...")
|
|
|
|
|
|
print(f"[Training] Train period: {train_start} -> {train_end}")
|
2026-02-25 21:11:19 +08:00
|
|
|
|
print(f"[Training] Val period: {val_start} -> {val_end}")
|
2026-02-23 16:23:53 +08:00
|
|
|
|
print(f"[Training] Test period: {test_start} -> {test_end}")
|
|
|
|
|
|
|
|
|
|
|
|
# 1. 准备数据
|
|
|
|
|
|
print("[Training] Preparing data...")
|
2026-02-25 23:39:02 +08:00
|
|
|
|
train_data, val_data, test_data, factor_config = prepare_data(
|
|
|
|
|
|
factors=factors,
|
2026-02-23 16:23:53 +08:00
|
|
|
|
data_dir=data_dir,
|
|
|
|
|
|
train_start=train_start,
|
|
|
|
|
|
train_end=train_end,
|
2026-02-25 21:11:19 +08:00
|
|
|
|
val_start=val_start,
|
|
|
|
|
|
val_end=val_end,
|
2026-02-23 16:23:53 +08:00
|
|
|
|
test_start=test_start,
|
|
|
|
|
|
test_end=test_end,
|
|
|
|
|
|
)
|
|
|
|
|
|
print(f"[Training] Train samples: {len(train_data)}")
|
2026-02-25 21:11:19 +08:00
|
|
|
|
print(f"[Training] Val samples: {len(val_data)}")
|
2026-02-23 16:23:53 +08:00
|
|
|
|
print(f"[Training] Test samples: {len(test_data)}")
|
|
|
|
|
|
|
2026-02-25 23:39:02 +08:00
|
|
|
|
# 2. 获取特征列名
|
|
|
|
|
|
feature_cols = factor_config.get_feature_names()
|
2026-02-23 16:23:53 +08:00
|
|
|
|
label_col = "label"
|
2026-02-25 23:39:02 +08:00
|
|
|
|
print(f"[Training] Feature columns: {feature_cols}")
|
2026-02-23 16:23:53 +08:00
|
|
|
|
|
|
|
|
|
|
# 3. 训练模型
|
|
|
|
|
|
print("[Training] Training model...")
|
|
|
|
|
|
model, pipeline = train_model(
|
|
|
|
|
|
train_data=train_data,
|
2026-02-25 21:11:19 +08:00
|
|
|
|
val_data=val_data,
|
2026-02-23 16:23:53 +08:00
|
|
|
|
feature_cols=feature_cols,
|
|
|
|
|
|
label_col=label_col,
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
# 4. 测试集预测
|
|
|
|
|
|
print("[Training] Predicting on test set...")
|
|
|
|
|
|
top_stocks = predict_top_stocks(
|
|
|
|
|
|
model=model,
|
|
|
|
|
|
pipeline=pipeline,
|
|
|
|
|
|
test_data=test_data,
|
|
|
|
|
|
feature_cols=feature_cols,
|
|
|
|
|
|
top_n=top_n,
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
# 5. 保存结果
|
|
|
|
|
|
print(f"[Training] Saving results to {output_path}...")
|
|
|
|
|
|
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
save_top_stocks(top_stocks, output_path)
|
|
|
|
|
|
|
|
|
|
|
|
print("[Training] Training completed!")
|
|
|
|
|
|
|
|
|
|
|
|
return top_stocks
|