449 lines
13 KiB
Python
449 lines
13 KiB
Python
|
|
"""训练管道 - 包含数据处理、模型训练和预测功能

本模块提供:
1. 数据准备:从 DuckDB 日线数据计算因子和标签,生成训练/测试数据
2. 数据处理:FillNA(0)(不使用 Dropna,以免训练/预测时特征与键列行数不一致)
3. 模型训练:使用 LightGBM 训练二分类模型
4. 预测和选股:输出每日 top N(默认 5)股票池
"""
|
|||
|
|
|
|||
|
|
from datetime import datetime
|
|||
|
|
from pathlib import Path
|
|||
|
|
from typing import List, Optional, Tuple
|
|||
|
|
|
|||
|
|
import numpy as np
|
|||
|
|
import polars as pl
|
|||
|
|
|
|||
|
|
from src.factors import DataLoader, FactorEngine
|
|||
|
|
from src.factors.data_spec import DataSpec
|
|||
|
|
from src.pipeline import (
|
|||
|
|
DropNAProcessor,
|
|||
|
|
FillNAProcessor,
|
|||
|
|
LightGBMModel,
|
|||
|
|
PipelineStage,
|
|||
|
|
ProcessingPipeline,
|
|||
|
|
TaskType,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _parse_yyyymmdd(value: str) -> datetime:
    """Parse a ``YYYYMMDD`` string, raising ValueError if it is not a valid date."""
    return datetime.strptime(value, "%Y%m%d")


def _load_daily_bars(storage, start: datetime, end: datetime) -> pl.DataFrame:
    """Load daily bars with ``start <= trade_date <= end`` from the ``daily`` table.

    Selects ts_code, trade_date, close and pre_close, ordered by ts_code then
    trade_date, and converts trade_date to a 'YYYY-MM-DD' string.
    """
    # The bounds are rendered from datetime objects, so only well-formed date
    # literals can ever reach the SQL text.
    query = f"""
    SELECT ts_code, trade_date, close, pre_close
    FROM daily
    WHERE trade_date >= '{start:%Y-%m-%d}' AND trade_date <= '{end:%Y-%m-%d}'
    ORDER BY ts_code, trade_date
    """
    # NOTE(review): relies on Storage's private ``_connection`` (a DuckDB
    # connection according to the original comments).
    raw = storage._connection.sql(query).pl()
    # trade_date is a DATE column in DuckDB; downstream code compares strings.
    return raw.with_columns(pl.col("trade_date").dt.strftime("%Y-%m-%d").alias("trade_date"))


def prepare_data(
    data_dir: str = "data",
    train_start: str = "20180101",
    train_end: str = "20230101",
    test_start: str = "20230101",
    test_end: str = "20240101",
) -> Tuple[pl.DataFrame, pl.DataFrame]:
    """Prepare the training and test datasets.

    Loads raw daily bars from DuckDB, drops excluded boards via
    ``_filter_invalid_stocks`` and computes features and labels via
    ``_compute_features_and_label``.

    Both periods are queried with roughly one extra year of history before
    their start date. Rolling features (MA5, MA10, 5-day return) are therefore
    already populated on the first day of each period. The feature step trims
    rows back to the requested ``[start, end]`` window.

    Args:
        data_dir: Data directory. Currently unused; kept for API compatibility.
        train_start: Training period start date, ``YYYYMMDD``.
        train_end: Training period end date, ``YYYYMMDD`` (inclusive).
        test_start: Test period start date, ``YYYYMMDD``.
        test_end: Test period end date, ``YYYYMMDD`` (inclusive).

    Returns:
        ``(train_data, test_data)`` DataFrames.

    Raises:
        ValueError: If any date string is not a valid ``YYYYMMDD`` date.
    """
    from datetime import timedelta

    from src.data.storage import Storage

    # Validate all dates before touching the database.
    train_start_dt = _parse_yyyymmdd(train_start)
    train_end_dt = _parse_yyyymmdd(train_end)
    test_start_dt = _parse_yyyymmdd(test_start)
    test_end_dt = _parse_yyyymmdd(test_end)

    # Calendar-day lookback (~1 year), far more than the ~10 trading days that
    # MA10 / the 5-day return need. Applied to BOTH periods so test features
    # are not null (and later zero-filled) at the start of the test window.
    lookback = timedelta(days=365)

    storage = Storage()
    train_raw = _load_daily_bars(storage, train_start_dt - lookback, train_end_dt)
    test_raw = _load_daily_bars(storage, test_start_dt - lookback, test_end_dt)

    # Drop excluded boards / codes.
    train_raw = _filter_invalid_stocks(train_raw)
    test_raw = _filter_invalid_stocks(test_raw)
    print(f"[PrepareData] After filtering: train={len(train_raw)}, test={len(test_raw)}")

    # Compute features and labels, keeping only rows inside each requested window.
    train_data = _compute_features_and_label(train_raw, train_start, train_end)
    test_data = _compute_features_and_label(test_raw, test_start, test_end)

    return train_data, test_data
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _filter_invalid_stocks(df: pl.DataFrame) -> pl.DataFrame:
    """Remove rows whose ts_code belongs to an excluded board.

    A row is dropped when its ts_code:
        1. ends with "BJ" (Beijing Stock Exchange);
        2. starts with "30" (ChiNext board);
        3. starts with "68" (STAR Market);
        4. starts with "8" (described by the original author as delisted/risk
           codes — NOTE(review): confirm this prefix really identifies those).

    Rows with a null ts_code are also dropped, because the filter predicate
    evaluates to null for them.

    Args:
        df: Raw rows containing a ``ts_code`` column.

    Returns:
        The rows that pass all rules.
    """
    code = pl.col("ts_code").str
    is_excluded = (
        code.ends_with("BJ")
        | code.starts_with("30")
        | code.starts_with("68")
        | code.starts_with("8")
    )
    return df.filter(~is_excluded)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _compute_features_and_label(
    raw_data: pl.DataFrame,
    start_date: str,
    end_date: str,
) -> pl.DataFrame:
    """Compute model features and the binary label.

    All per-stock quantities are computed over consecutive rows of the same
    ts_code (ordered by trade_date). ``raw_data`` should therefore include some
    history before ``start_date``.

    Features:
        1. return_5_rank: 5-row return ``close / close.shift(5) - 1``, ranked
           (average method) across stocks on each trade_date. The rank is
           scaled to (0, 1] by dividing by that day's maximum rank.
        2. ma_5: 5-row moving average of close.
        3. ma_10: 10-row moving average of close.

    Label:
        1 if the forward return ``(close.shift(-5) - close) / close`` is > 0,
        otherwise 0. It is null when fewer than 5 later rows exist for the stock
        (e.g. the last days of the window).

    Stocks with fewer than 20 rows in ``raw_data`` are dropped.

    Args:
        raw_data: Daily bars with at least ts_code, trade_date
            ('YYYY-MM-DD' strings) and close.
        start_date: First date to keep, ``YYYYMMDD``.
        end_date: Last date to keep (inclusive), ``YYYYMMDD``.

    Returns:
        DataFrame with columns trade_date, ts_code, return_5_rank, ma_5, ma_10
        and label, sorted by trade_date then ts_code.
    """
    output_cols = ["trade_date", "ts_code", "return_5_rank", "ma_5", "ma_10", "label"]
    min_rows_per_stock = 20
    horizon = 5

    # YYYYMMDD -> YYYY-MM-DD, matching the trade_date string format.
    start_date_formatted = f"{start_date[:4]}-{start_date[4:6]}-{start_date[6:8]}"
    end_date_formatted = f"{end_date[:4]}-{end_date[4:6]}-{end_date[6:8]}"

    # Window expressions below rely on rows being date-ordered within each stock.
    data = raw_data.sort(["ts_code", "trade_date"])

    # Drop stocks with too little history. Replaces the old per-stock Python
    # loop, which re-filtered the whole frame once per stock.
    data = data.filter(pl.col("ts_code").len().over("ts_code") >= min_rows_per_stock)

    close = pl.col("close")
    data = data.with_columns(
        [
            close.rolling_mean(5).over("ts_code").alias("ma_5"),
            close.rolling_mean(10).over("ts_code").alias("ma_10"),
            # 5-day return, as the feature name / docs require (the previous
            # code mistakenly ranked the 1-day return).
            (close / close.shift(horizon) - 1).over("ts_code").alias("return_5"),
            ((close.shift(-horizon) - close) / close).over("ts_code").alias("future_return_5"),
        ]
    )
    # Null future return -> null label; train_model drops those rows.
    data = data.with_columns((pl.col("future_return_5") > 0).cast(pl.Int8).alias("label"))

    # Keep only the requested window (string comparison works for YYYY-MM-DD).
    result = data.filter(
        (pl.col("trade_date") >= start_date_formatted) & (pl.col("trade_date") <= end_date_formatted)
    )

    # Cross-sectional rank of the 5-day return on each date ...
    result = result.with_columns(
        pl.col("return_5").rank(method="average").over("trade_date").alias("return_5_rank")
    )
    # ... normalized to (0, 1] by the day's maximum rank.
    result = result.with_columns(
        (pl.col("return_5_rank") / pl.col("return_5_rank").max().over("trade_date")).alias(
            "return_5_rank"
        )
    )

    # Deterministic output order (previously depended on unique() order).
    return result.select(output_cols).sort(["trade_date", "ts_code"])
|
|||
|
|
|
|||
|
|
|
|||
|
|
def create_pipeline() -> ProcessingPipeline:
    """Build the feature preprocessing pipeline.

    Steps:
        1. FillNAProcessor(method="zero"): replace missing feature values with 0.

    Dropna is deliberately not used. Dropping rows would make the processed
    feature matrix shorter than the key/label columns it is later aligned with
    by position, so training and prediction row counts would no longer match.

    Returns:
        A ProcessingPipeline wrapping the processor above.
    """
    zero_fill = FillNAProcessor(method="zero")
    steps = [zero_fill]
    return ProcessingPipeline(steps)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def train_model(
    train_data: pl.DataFrame,
    feature_cols: List[str],
    label_col: str = "label",
    model_params: Optional[dict] = None,
) -> Tuple[LightGBMModel, ProcessingPipeline]:
    """Train a LightGBM binary classifier.

    Rows whose label is not 0 or 1 (including null labels) are dropped first.
    The feature columns are then fitted/transformed with ``create_pipeline()``
    and passed to ``LightGBMModel.fit``.

    Args:
        train_data: Training frame containing ``feature_cols`` and ``label_col``.
        feature_cols: Feature column names.
        label_col: Label column name.
        model_params: LightGBM parameters. ``None`` selects the defaults below;
            an explicit dict (even an empty one) is passed through unchanged.

    Returns:
        (trained model, fitted processing pipeline).

    Raises:
        ValueError: If no row has a valid 0/1 label, or if the pipeline changes
            the number of rows (features and labels would then be misaligned).
    """
    pipeline = create_pipeline()
    print("[TrainModel] Pipeline created: FillNA(0)")
    print(f"[TrainModel] Raw samples: {len(train_data)}, features: {feature_cols}")

    # Filter the whole frame up front so features and labels stay row-aligned
    # no matter what the processing pipeline does.
    labelled = train_data.filter(pl.col(label_col).is_in([0, 1]))
    print(f"[TrainModel] After filtering valid labels: {len(labelled)} samples")
    if labelled.is_empty():
        raise ValueError(f"No rows with a valid 0/1 '{label_col}' value to train on")

    X_train = labelled.select(feature_cols)
    y_train = labelled[label_col]

    X_train_processed = pipeline.fit_transform(X_train, stage=PipelineStage.TRAIN)
    print(f"[TrainModel] After processing: {len(X_train_processed)} samples")
    if len(X_train_processed) != len(y_train):
        raise ValueError(
            f"Pipeline changed row count ({len(y_train)} -> {len(X_train_processed)}); "
            "features and labels would be misaligned"
        )

    # value_counts names its value column after the series, i.e. label_col
    # (the old hard-coded 'label' broke for any other label column).
    label_counts = y_train.value_counts().sort(label_col)
    print(f"[TrainModel] Label distribution: {dict(label_counts.iter_rows())}")

    # `is None` (not `or`) so an explicit empty dict is respected.
    if model_params is not None:
        params = model_params
    else:
        params = {
            "n_estimators": 100,
            "learning_rate": 0.05,
            "max_depth": 5,
            "num_leaves": 31,
        }
    print(f"[TrainModel] Model params: {params}")
    model = LightGBMModel(
        task_type="classification",
        params=params,
    )

    print("[TrainModel] Training LightGBM...")
    model.fit(X_train_processed, y_train)
    print("[TrainModel] Training completed!")

    return model, pipeline
|
|||
|
|
|
|||
|
|
|
|||
|
|
def predict_top_stocks(
    model: LightGBMModel,
    pipeline: ProcessingPipeline,
    test_data: pl.DataFrame,
    feature_cols: List[str],
    top_n: int = 5,
) -> pl.DataFrame:
    """Score the test set and keep the ``top_n`` highest-probability stocks per day.

    Features are transformed with the already-fitted ``pipeline``. The score
    is the model's probability for the positive class: column 1 of
    ``predict_proba`` when it returns 2-D output, otherwise the flattened
    output.

    Args:
        model: Trained model.
        pipeline: Processing pipeline fitted during training.
        test_data: Test frame with trade_date, ts_code and ``feature_cols``.
        feature_cols: Feature column names (same as used for training).
        top_n: Number of stocks to keep per trade_date.

    Returns:
        DataFrame with columns trade_date, score and ts_code. It is sorted by
        trade_date ascending, then score descending. It is empty (with that
        schema) when ``test_data`` has no rows.

    Raises:
        ValueError: If the pipeline changes the number of rows, which would
            misalign predictions with their (trade_date, ts_code) keys.
    """
    if test_data.height == 0:
        # Previously crashed in pl.concat([]) / downstream selects.
        return pl.DataFrame(schema={"trade_date": pl.Utf8, "score": pl.Float64, "ts_code": pl.Utf8})

    X_test = test_data.select(feature_cols)
    key_data = test_data.select(["trade_date", "ts_code"])
    print(f"[Predict] Test samples: {len(X_test)}, top_n: {top_n}")

    # Transform with the parameters learned at training time.
    X_test_processed = pipeline.transform(X_test, stage=PipelineStage.TEST)
    print(f"[Predict] Data processed, shape: {X_test_processed.shape}")
    if len(X_test_processed) != len(key_data):
        raise ValueError(
            f"Pipeline changed row count ({len(key_data)} -> {len(X_test_processed)}); "
            "predictions would be misaligned with their keys"
        )

    probs = model.predict_proba(X_test_processed)
    print(f"[Predict] Predictions generated, probability shape: {probs.shape}")
    probs = np.asarray(probs)
    positive_prob = probs[:, 1] if probs.ndim > 1 and probs.shape[1] > 1 else probs.ravel()

    scored = key_data.with_columns(pl.Series(name="pred_prob", values=positive_prob))

    # Per-day ordinal rank (ties broken by row order) replaces the old loop that
    # re-filtered the whole frame once per date; exactly top_n rows per day.
    day_rank = pl.col("pred_prob").rank(method="ordinal", descending=True).over("trade_date")
    return (
        scored.filter(day_rank <= top_n)
        .sort(["trade_date", "pred_prob"], descending=[False, True])
        .select(["trade_date", "pred_prob", "ts_code"])
        .rename({"pred_prob": "score"})
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
def save_top_stocks(top_stocks: pl.DataFrame, output_path: str) -> None:
    """Write the stock-selection result to a tab-separated file.

    The frame is converted to pandas and written with a header row and no
    index column. This function does not create the parent directory.

    Args:
        top_stocks: Selection result to write.
        output_path: Destination file path.
    """
    as_pandas = top_stocks.to_pandas()
    as_pandas.to_csv(output_path, index=False, sep="\t")
    print(f"[Training] Top stocks saved to: {output_path}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def run_training(
    data_dir: str = "data",
    output_path: str = "output/top_stocks.tsv",
    train_start: str = "20180101",
    train_end: str = "20230101",
    test_start: str = "20230101",
    test_end: str = "20240101",
    top_n: int = 5,
) -> pl.DataFrame:
    """Run the end-to-end workflow: prepare data, train, predict, save.

    Args:
        data_dir: Data directory, forwarded to ``prepare_data``.
        output_path: TSV path for the selection result; its parent directory
            is created if missing.
        train_start: Training period start date, ``YYYYMMDD``.
        train_end: Training period end date, ``YYYYMMDD``.
        test_start: Test period start date, ``YYYYMMDD``.
        test_end: Test period end date, ``YYYYMMDD``.
        top_n: Number of stocks selected per day.

    Returns:
        The per-day top-N selection DataFrame (also written to ``output_path``).
    """
    for banner in (
        "[Training] Starting training pipeline...",
        f"[Training] Train period: {train_start} -> {train_end}",
        f"[Training] Test period: {test_start} -> {test_end}",
    ):
        print(banner)

    # Step 1: data
    print("[Training] Preparing data...")
    train_data, test_data = prepare_data(
        data_dir=data_dir,
        train_start=train_start,
        train_end=train_end,
        test_start=test_start,
        test_end=test_end,
    )
    print(f"[Training] Train samples: {len(train_data)}")
    print(f"[Training] Test samples: {len(test_data)}")

    # Step 2: feature set shared by training and prediction
    features = ["return_5_rank", "ma_5", "ma_10"]

    # Step 3: fit
    print("[Training] Training model...")
    model, pipeline = train_model(
        train_data=train_data,
        feature_cols=features,
        label_col="label",
    )

    # Step 4: score the test period
    print("[Training] Predicting on test set...")
    top_stocks = predict_top_stocks(
        model=model,
        pipeline=pipeline,
        test_data=test_data,
        feature_cols=features,
        top_n=top_n,
    )

    # Step 5: persist
    print(f"[Training] Saving results to {output_path}...")
    output_dir = Path(output_path).parent
    output_dir.mkdir(parents=True, exist_ok=True)
    save_top_stocks(top_stocks, output_path)

    print("[Training] Training completed!")
    return top_stocks
|