refactor(training): 重构训练管道,支持灵活因子配置

- 添加 FactorConfig 类,支持链式 API 和动态因子列表
- 重构 prepare_data 使用 FactorEngine 计算因子,确保防泄露机制生效
- 新增 prepare_data_and_train / train_and_predict 分步执行接口
- 修复 factors/__init__.py docstring 语法错误
This commit is contained in:
2026-02-25 23:39:02 +08:00
parent a9e4746239
commit 990a77ec6c
2 changed files with 564 additions and 119 deletions

View File

@@ -5,26 +5,298 @@
或:
uv run python src/training/main.py
本脚本提供两种运行方式:
1. run_full_pipeline(): 完整训练流程(数据准备 -> 训练 -> 预测)
2. prepare_data_and_train() + train_and_predict(): 分步执行,便于调试和调整
因子配置示例:
from src.factors import MovingAverageFactor, ReturnRankFactor
# 直接传入因子实例列表 - 最简单的方式
factors = [
MovingAverageFactor(period=5),
MovingAverageFactor(period=10),
MovingAverageFactor(period=20),
ReturnRankFactor(period=5),
ReturnRankFactor(period=10),
]
# 运行完整流程
result = run_full_pipeline(factors=factors)
"""
from src.training.pipeline import run_training
from pathlib import Path
from typing import Optional, List
import polars as pl
from src.factors import BaseFactor
from src.training.pipeline import (
FactorConfig,
predict_top_stocks,
prepare_data,
save_top_stocks,
train_model,
)
def prepare_data_and_train(
    factors: Optional[List[BaseFactor]] = None,
    data_dir: str = "data",
    train_start: str = "20190101",
    train_end: str = "20231231",
    val_start: str = "20240102",
    val_end: str = "20240531",
    test_start: str = "20240602",
    test_end: str = "20241231",
) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame, FactorConfig, str]:
    """Step 1: build the train / validation / test datasets.

    Delegates the actual work to ``prepare_data`` and prints a progress
    report (date ranges, sample counts, a three-row preview of every split
    and the feature / label column names). Despite its name, this function
    does not train anything; training happens in ``train_and_predict``.

    Args:
        factors: Factor instances to compute. ``None`` is forwarded as-is so
            that ``prepare_data`` can fall back to its default factor set.
        data_dir: Data directory forwarded to ``prepare_data``.
        train_start: First date of the training split (e.g. "20190101").
        train_end: Last date of the training split.
        val_start: First date of the validation split.
        val_end: Last date of the validation split.
        test_start: First date of the test split.
        test_end: Last date of the test split.

    Returns:
        ``(train_data, val_data, test_data, factor_config, label_col)`` where
        ``label_col`` is always the string ``"label"``.
    """
    banner = "=" * 50

    print(banner)
    print("[Step 1] 数据处理")
    print(banner)
    print(f"训练集: {train_start} -> {train_end}")
    print(f"验证集: {val_start} -> {val_end}")
    print(f"测试集: {test_start} -> {test_end}")
    print()

    # Build all three splits in a single call so they share one factor config.
    prepared = prepare_data(
        factors=factors,
        data_dir=data_dir,
        train_start=train_start,
        train_end=train_end,
        val_start=val_start,
        val_end=val_end,
        test_start=test_start,
        test_end=test_end,
    )
    train_data, val_data, test_data, factor_config = prepared

    print(f"训练集样本数: {len(train_data)}")
    print(f"验证集样本数: {len(val_data)}")
    print(f"测试集样本数: {len(test_data)}")
    print()

    # Show a small sample of every split for a quick sanity check.
    print(banner)
    previews = (
        ("[数据预览] 训练集前3行:", train_data),
        ("[数据预览] 验证集前3行:", val_data),
        ("[数据预览] 测试集前3行:", test_data),
    )
    for title, frame in previews:
        print(title)
        print(frame.head(3))
        print()

    # Feature names are looked up for reporting only; callers obtain them
    # again from the returned factor_config.
    label_col = "label"
    feature_cols = factor_config.get_feature_names()
    print(f"特征列: {feature_cols}")
    print(f"标签列: {label_col}")
    print()

    return train_data, val_data, test_data, factor_config, label_col
def train_and_predict(
    train_data: pl.DataFrame,
    val_data: pl.DataFrame,
    test_data: pl.DataFrame,
    factor_config: FactorConfig,
    label_col: str = "label",
    top_n: int = 5,
    output_path: str = "output/top_stocks.tsv",
) -> pl.DataFrame:
    """Step 2: train the model, predict on the test split, save the picks.

    Args:
        train_data: Training split.
        val_data: Validation split.
        test_data: Test split.
        factor_config: Factor configuration; supplies the feature column
            names through ``get_feature_names()``.
        label_col: Name of the label column.
        top_n: Number of stocks to select per day.
        output_path: Destination file for the selection result. Its parent
            directory is created when missing.

    Returns:
        The stock-selection DataFrame produced by ``predict_top_stocks``.
    """
    banner = "=" * 50

    print(banner)
    print("[Step 2] 模型训练与预测")
    print(banner)
    print()

    # The feature columns come from the factor configuration.
    feature_cols = factor_config.get_feature_names()
    print(f"使用特征: {feature_cols}")
    print()

    # Fit the model together with its preprocessing pipeline.
    print("[Training] Training model...")
    fitted_model, fitted_pipeline = train_model(
        train_data=train_data,
        val_data=val_data,
        feature_cols=feature_cols,
        label_col=label_col,
    )
    print()

    # Score the test split and keep the selected stocks.
    print("[Predict] Predicting on test set...")
    picks = predict_top_stocks(
        model=fitted_model,
        pipeline=fitted_pipeline,
        test_data=test_data,
        feature_cols=feature_cols,
        top_n=top_n,
    )
    print()

    # Persist the result, creating the output directory on demand.
    print(f"[Saving] Saving results to {output_path}...")
    target_dir = Path(output_path).parent
    target_dir.mkdir(parents=True, exist_ok=True)
    save_top_stocks(picks, output_path)
    print()

    return picks
def run_full_pipeline(
    factors: Optional[List[BaseFactor]] = None,
    train_start: str = "20190101",
    train_end: str = "20231231",
    val_start: str = "20240102",
    val_end: str = "20240531",
    test_start: str = "20240602",
    test_end: str = "20241231",
    top_n: int = 5,
    output_path: str = "output/top_stocks.tsv",
    data_dir: str = "data",
) -> pl.DataFrame:
    """Run the complete workflow: data preparation, training, prediction.

    Equivalent to calling ``prepare_data_and_train`` followed by
    ``train_and_predict`` with the same arguments.

    Args:
        factors: Factor instances to compute. ``None`` is forwarded as-is so
            the data-preparation step can use its default factor set.
        train_start: First date of the training split (e.g. "20190101").
        train_end: Last date of the training split.
        val_start: First date of the validation split.
        val_end: Last date of the validation split.
        test_start: First date of the test split.
        test_end: Last date of the test split.
        top_n: Number of stocks to select per day.
        output_path: Destination file for the selection result.
        data_dir: Data directory forwarded to ``prepare_data_and_train``.
            Placed last so existing positional callers keep working.

    Returns:
        The stock-selection DataFrame returned by ``train_and_predict``.
    """
    # Step 1: data preparation.
    train_data, val_data, test_data, factor_config, label_col = prepare_data_and_train(
        factors=factors,
        data_dir=data_dir,
        train_start=train_start,
        train_end=train_end,
        val_start=val_start,
        val_end=val_end,
        test_start=test_start,
        test_end=test_end,
    )

    # Step 2: training and prediction.
    result = train_and_predict(
        train_data=train_data,
        val_data=val_data,
        test_data=test_data,
        factor_config=factor_config,
        label_col=label_col,
        top_n=top_n,
        output_path=output_path,
    )

    print("=" * 50)
    print("[Done] 训练流程完成!")
    print("=" * 50)

    return result
if __name__ == "__main__":
# 运行完整训练流程
# 训练集20190101 - 20231231
# 验证集20240102 - 20240531 (与训练集间隔1天避免数据泄露)
# 测试集20240602 - 20241231 (与验证集间隔1天避免数据泄露)
result = run_training(
from src.factors import MovingAverageFactor, ReturnRankFactor
# ========== 因子配置 ==========
# 直接传入因子实例列表 - 简单直观
factors = [
MovingAverageFactor(period=5), # 5日移动平均线
MovingAverageFactor(period=10), # 10日移动平均线
MovingAverageFactor(period=20), # 20日移动平均线
ReturnRankFactor(period=5), # 5日收益率排名
ReturnRankFactor(period=10), # 10日收益率排名
]
# ========== 运行方式 ==========
# 方式一:完整流程(一次性执行)
# result = run_full_pipeline(
# factors=factors,
# train_start="20190101",
# train_end="20231231",
# val_start="20240102",
# val_end="20240531",
# test_start="20240602",
# test_end="20241231",
# top_n=5,
# output_path="output/top_stocks.tsv",
# )
# 方式二:分步执行(便于调试)
# 第一步:数据处理
train_data, val_data, test_data, factor_config, label_col = prepare_data_and_train(
factors=factors,
train_start="20190101",
train_end="20231231",
val_start="20240102",
val_end="20240531",
test_start="20240602",
test_end="20241231",
top_n=5,
output_path="output/top_stocks.tsv",
)
print("\n[Result] Top stocks selection:")
print(result)
# 可在此处添加自定义逻辑,例如:
# - 查看数据分布
# - 调整特征
# - 保存中间结果
print("\n[Info] 因子配置详情:")
print(f" 因子列表: {factor_config.get_feature_names()}")
print(f" 最大回溯天数: {factor_config.get_max_lookback()}")
# 第二步:训练和预测
# result = train_and_predict(
# train_data=train_data,
# val_data=val_data,
# test_data=test_data,
# factor_config=factor_config,
# label_col=label_col,
# top_n=5,
# output_path="output/top_stocks.tsv",
# )
#
# print("\n[Result] Top stocks selection:")
# print(result)

View File

@@ -1,21 +1,48 @@
"""训练管道 - 包含数据处理、模型训练和预测功能
本模块提供:
1. 数据准备:从因子计算结果中准备训练/测试数据
1. 数据准备:使用 FactorEngine 从因子计算结果中准备训练/测试数据
2. 数据处理Fillna(0) -> Dropna
3. 模型训练使用LightGBM训练分类模型
4. 预测和选股输出每日top5股票池
注意:本模块使用 src.factors 框架进行因子计算,确保防泄露机制生效。
因子配置示例:
from src.factors import MovingAverageFactor, ReturnRankFactor
from src.training.pipeline import prepare_data
# 直接传入因子实例列表 - 简单直观
factors = [
MovingAverageFactor(period=5),
MovingAverageFactor(period=10),
ReturnRankFactor(period=5),
]
train_data, val_data, test_data, factor_config = prepare_data(
factors=factors,
...
)
# 或者使用 FactorConfig 包装(支持链式添加)
from src.training.pipeline import FactorConfig
config = (
    FactorConfig()
    .add(MovingAverageFactor(period=5))
    .add(MovingAverageFactor(period=10))
    .add(ReturnRankFactor(period=5))
)
"""
from datetime import datetime
from pathlib import Path
from typing import List, Optional, Tuple
from typing import List, Optional
import numpy as np
import polars as pl
from src.factors import DataLoader, FactorEngine
from src.factors import DataLoader, FactorEngine, BaseFactor
from src.factors.data_spec import DataSpec
from src.factors.momentum import MovingAverageFactor, ReturnRankFactor
from src.pipeline import (
DropNAProcessor,
FillNAProcessor,
@@ -26,7 +53,98 @@ from src.pipeline import (
)
# ========== 因子配置类 ==========
class FactorConfig:
    """Ordered container for the factor instances used to build features.

    Wraps a list of factor instances and supports fluent (chained) additions.

    Example:
        # Option 1: pass a list at construction time
        config = FactorConfig([
            MovingAverageFactor(period=5),
            ReturnRankFactor(period=5),
        ])

        # Option 2: chain ``add`` calls (the parentheses are required for
        # the multi-line form to be valid Python)
        config = (
            FactorConfig()
            .add(MovingAverageFactor(period=5))
            .add(ReturnRankFactor(period=5))
        )

        # Factor instances
        factors = config.get_factors()

        # Feature column names
        feature_cols = config.get_feature_names()
    """

    def __init__(self, factors: "Optional[List[BaseFactor]]" = None):
        """Initialise the configuration.

        Args:
            factors: Initial factor instances. The sequence is copied, so a
                later ``add`` never mutates the caller's own list.
        """
        # Copy instead of aliasing: storing the argument itself would make
        # add() append to the list object owned by the caller.
        self._factors: List[BaseFactor] = list(factors) if factors else []

    def add(self, factor: "BaseFactor") -> "FactorConfig":
        """Append a factor and return ``self`` so calls can be chained.

        Args:
            factor: Factor instance to append.

        Returns:
            This configuration object.

        Raises:
            ValueError: If ``factor`` is not a ``BaseFactor`` instance.
        """
        if not isinstance(factor, BaseFactor):
            raise ValueError(f"必须是 BaseFactor 实例, got {type(factor)}")
        self._factors.append(factor)
        return self

    def get_factors(self) -> "List[BaseFactor]":
        """Return the configured factor instances in insertion order.

        The returned object is the internal list, not a copy.
        """
        return self._factors

    def get_feature_names(self) -> List[str]:
        """Return the ``name`` of every configured factor, in order."""
        return [factor.name for factor in self._factors]

    def get_max_lookback(self) -> int:
        """Return the largest ``lookback_days`` over all factors' data specs.

        Returns:
            The maximum lookback, or 0 when no factor (or no data spec) is
            configured. The result is never below 0.
        """
        max_lookback = 0
        for factor in self._factors:
            for spec in factor.data_specs:
                max_lookback = max(max_lookback, spec.lookback_days)
        return max_lookback

    def __len__(self) -> int:
        """Return the number of configured factors."""
        return len(self._factors)

    def __repr__(self) -> str:
        """Return ``FactorConfig([...])`` listing the factor names."""
        return f"FactorConfig({self.get_feature_names()})"
def prepare_data(
factors: Optional[List[BaseFactor]] = None,
data_dir: str = "data",
train_start: str = "20180101",
train_end: str = "20230101",
@@ -34,12 +152,13 @@ def prepare_data(
val_end: str = "20230601",
test_start: str = "20230601",
test_end: str = "20240101",
) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame, FactorConfig]:
"""准备训练、验证和测试数据
从DuckDB加载原始日线数据计算所需因子并生成标签
使用 FactorEngine 计算因子,确保防泄露机制生效
Args:
factors: 因子实例列表,默认为 None使用 MA5, MA10, ReturnRank5
data_dir: 数据目录
train_start: 训练集开始日期
train_end: 训练集结束日期
@@ -49,44 +168,107 @@ def prepare_data(
test_end: 测试集结束日期
Returns:
(train_data, val_data, test_data): 训练集、验证集和测试集的DataFrame
(train_data, val_data, test_data, factor_config):
训练集、验证集、测试集的DataFrame以及使用的因子配置
"""
from src.data.storage import Storage
storage = Storage()
# 加载日线数据(需要更多历史数据用于计算因子)
# 训练集需要更多历史数据用于计算因子lookback
lookback_days = 20 # 足够计算MA10和5日收益率
# 1. 处理因子配置
if factors is None:
# 默认因子配置
factor_config = FactorConfig(
[
MovingAverageFactor(period=5),
MovingAverageFactor(period=10),
ReturnRankFactor(period=5),
]
)
elif isinstance(factors, FactorConfig):
factor_config = factors
elif isinstance(factors, list):
# 转换为 FactorConfig
factor_config = FactorConfig(factors)
else:
raise ValueError(
f"factors 必须是 List[BaseFactor] 或 FactorConfig got {type(factors)}"
)
factors_list = factor_config.get_factors()
feature_cols = factor_config.get_feature_names()
if not factors_list:
raise ValueError("至少需要提供一个因子")
print(f"[PrepareData] 使用因子: {feature_cols}")
# 2. 初始化 FactorEngine
loader = DataLoader(data_dir=data_dir)
engine = FactorEngine(loader)
# 3. 计算需要的回溯天数
max_lookback = factor_config.get_max_lookback()
start_with_lookback = str(int(train_start) - 10000) # 往前取一年
# 查询全部数据包含train、val、test然后再拆分
# 注意DuckDB 中 trade_date 是 DATE 类型,需要转换
# 获取所有股票列表
start_dt = f"{start_with_lookback[:4]}-{start_with_lookback[4:6]}-{start_with_lookback[6:8]}"
end_dt = f"{test_end[:4]}-{test_end[4:6]}-{test_end[6:8]}"
all_query = f"""
SELECT ts_code, trade_date, close, pre_close
FROM daily
all_stocks_query = f"""
SELECT DISTINCT ts_code FROM daily
WHERE trade_date >= '{start_dt}' AND trade_date <= '{end_dt}'
ORDER BY ts_code, trade_date
"""
all_raw = storage._connection.sql(all_query).pl()
# 转换 trade_date 为字符串格式
all_raw = all_raw.with_columns(
pl.col("trade_date").dt.strftime("%Y-%m-%d").alias("trade_date")
)
all_stocks_df = storage._connection.sql(all_stocks_query).pl()
all_stocks = all_stocks_df["ts_code"].to_list()
# 过滤不符合条件的股票
all_raw = _filter_invalid_stocks(all_raw)
print(f"[PrepareData] After filtering: total={len(all_raw)}")
print(f"[PrepareData] 股票数量: {len(all_stocks)}")
# 计算因子和标签(需要全局数据一次性计算)
all_data = _compute_features_and_label(
all_raw,
start_date=train_start,
end_date=test_end
)
# 4. 计算所有因子并合并
all_features = None
for factor in factors_list:
print(f"[PrepareData] 计算因子: {factor.name} ({factor.factor_type})")
if factor.factor_type == "time_series":
# 时序因子需要传入股票列表
result = engine.compute(
factor,
stock_codes=all_stocks,
start_date=start_with_lookback,
end_date=test_end,
)
else:
# 截面因子不需要股票列表
result = engine.compute(
factor,
start_date=start_with_lookback,
end_date=test_end,
)
# 合并结果
if all_features is None:
all_features = result
else:
# 确保没有重复的 _right 列
result = result.select([
c for c in result.columns if not c.endswith("_right")
])
all_features = all_features.select([
c for c in all_features.columns if not c.endswith("_right")
])
all_features = all_features.join(
result, on=["trade_date", "ts_code"], how="outer"
)
if all_features is None:
raise ValueError("没有计算任何因子")
# 5. 计算标签未来5日收益率
all_data = _compute_label(all_features, start_date=train_start, end_date=test_end)
# 6. 过滤不符合条件的股票
all_data = _filter_invalid_stocks(all_data)
print(f"[PrepareData] After filtering: total={len(all_data)}")
# 转换日期格式用于比较
train_start_fmt = f"{train_start[:4]}-{train_start[4:6]}-{train_start[6:8]}"
@@ -98,18 +280,22 @@ def prepare_data(
# 拆分数据
train_data = all_data.filter(
(pl.col("trade_date") >= train_start_fmt) & (pl.col("trade_date") <= train_end_fmt)
(pl.col("trade_date") >= train_start_fmt)
& (pl.col("trade_date") <= train_end_fmt)
)
val_data = all_data.filter(
(pl.col("trade_date") >= val_start_fmt) & (pl.col("trade_date") <= val_end_fmt)
)
test_data = all_data.filter(
(pl.col("trade_date") >= test_start_fmt) & (pl.col("trade_date") <= test_end_fmt)
(pl.col("trade_date") >= test_start_fmt)
& (pl.col("trade_date") <= test_end_fmt)
)
print(f"[PrepareData] Split result: train={len(train_data)}, val={len(val_data)}, test={len(test_data)}")
print(
f"[PrepareData] Split result: train={len(train_data)}, val={len(val_data)}, test={len(test_data)}"
)
return train_data, val_data, test_data
return train_data, val_data, test_data, factor_config
def _filter_invalid_stocks(df: pl.DataFrame) -> pl.DataFrame:
@@ -137,59 +323,54 @@ def _filter_invalid_stocks(df: pl.DataFrame) -> pl.DataFrame:
)
def _compute_features_and_label(
raw_data: pl.DataFrame,
def _compute_label(
features_df: pl.DataFrame,
start_date: str,
end_date: str,
) -> pl.DataFrame:
"""计算因子和标签
"""计算标签未来5日收益率
因子:
1. return_5_rank: 5日收益率截面排名
2. ma_5: 5日移动平均
3. ma_10: 10日移动平均
标签未来5日收益率大于0为1否则为0
标签定义未来5日收益率大于0为1否则为0
Args:
raw_data: 原始日线数据
features_df: 包含因子的DataFrame
start_date: 开始日期
end_date: 结束日期
Returns:
包含因子和标签的DataFrame
"""
# 确保按日期排序
raw_data = raw_data.sort(["ts_code", "trade_date"])
from src.data.storage import Storage
# 计算收益率未来5日
raw_data = raw_data.with_columns(
[
# 当日收益率
((pl.col("close") - pl.col("pre_close")) / pl.col("pre_close")).alias(
"daily_return"
),
]
storage = Storage()
# 从数据库获取收盘价数据用于计算标签
start_dt = f"{start_date[:4]}-{start_date[4:6]}-{start_date[6:8]}"
end_dt = f"{end_date[:4]}-{end_date[4:6]}-{end_date[6:8]}"
# 需要多取5天数据来计算未来收益率
end_dt_extended = f"{end_date[:4]}-{end_date[4:6]}-{int(end_date[6:8]) + 5}"
price_query = f"""
SELECT ts_code, trade_date, close
FROM daily
WHERE trade_date >= '{start_dt}' AND trade_date <= '{end_dt_extended}'
ORDER BY ts_code, trade_date
"""
price_data = storage._connection.sql(price_query).pl()
price_data = price_data.with_columns(
pl.col("trade_date").dt.strftime("%Y-%m-%d").alias("trade_date")
)
# 按股票分组计算
# 按股票计算未来5日收益率
result_list = []
for ts_code in price_data["ts_code"].unique():
stock_data = price_data.filter(pl.col("ts_code") == ts_code).sort("trade_date")
for ts_code in raw_data["ts_code"].unique():
stock_data = raw_data.filter(pl.col("ts_code") == ts_code).sort("trade_date")
if len(stock_data) < 20:
if len(stock_data) < 6:
continue
# 计算MA5和MA10
stock_data = stock_data.with_columns(
[
pl.col("close").rolling_mean(5).alias("ma_5"),
pl.col("close").rolling_mean(10).alias("ma_10"),
]
)
# 计算未来5日收益率用于标签
# 计算未来5日收益率
future_return = stock_data["close"].shift(-5) - stock_data["close"]
future_return_pct = future_return / stock_data["close"]
stock_data = stock_data.with_columns(
@@ -205,46 +386,21 @@ def _compute_features_and_label(
]
)
result_list.append(stock_data)
result_list.append(stock_data.select(["trade_date", "ts_code", "label"]))
if not result_list:
return pl.DataFrame()
result = pl.concat(result_list)
label_df = pl.concat(result_list)
# 转换日期格式YYYYMMDD -> YYYY-MM-DD
start_date_formatted = f"{start_date[:4]}-{start_date[4:6]}-{start_date[6:8]}"
end_date_formatted = f"{end_date[:4]}-{end_date[4:6]}-{end_date[6:8]}"
# 将标签合并到因子数据
result = features_df.join(label_df, on=["trade_date", "ts_code"], how="inner")
# 过滤有效日期范围
result = result.filter(
(pl.col("trade_date") >= start_date_formatted) & (pl.col("trade_date") <= end_date_formatted)
(pl.col("trade_date") >= start_dt) & (pl.col("trade_date") <= end_dt)
)
# 计算5日收益率排名截面
result = result.with_columns(
[
pl.col("daily_return")
.rank(method="average")
.over("trade_date")
.alias("return_5_rank")
]
)
# 归一化排名到0-1
result = result.with_columns(
[
(
pl.col("return_5_rank")
/ pl.col("return_5_rank").max().over("trade_date")
).alias("return_5_rank")
]
)
# 选择需要的列
feature_cols = ["trade_date", "ts_code", "return_5_rank", "ma_5", "ma_10", "label"]
result = result.select(feature_cols)
return result
@@ -271,7 +427,7 @@ def train_model(
feature_cols: List[str],
label_col: str = "label",
model_params: Optional[dict] = None,
) -> Tuple[LightGBMModel, ProcessingPipeline]:
) -> tuple[LightGBMModel, ProcessingPipeline]:
"""训练LightGBM分类模型
Args:
@@ -301,8 +457,12 @@ def train_model(
valid_mask = y_train.is_in([0, 1])
X_train_processed = X_train_processed.filter(valid_mask)
y_train = y_train.filter(valid_mask)
print(f"[TrainModel] After filtering valid labels: {len(X_train_processed)} samples")
print(f"[TrainModel] Train label distribution: {dict(y_train.value_counts().sort('label').iter_rows())}")
print(
f"[TrainModel] After filtering valid labels: {len(X_train_processed)} samples"
)
print(
f"[TrainModel] Train label distribution: {dict(y_train.value_counts().sort('label').iter_rows())}"
)
# 准备验证集
X_val_processed = None
@@ -311,16 +471,18 @@ def train_model(
X_val = val_data.select(feature_cols)
y_val = val_data[label_col]
print(f"[TrainModel] Val samples: {len(X_val)}")
# 处理验证集数据(使用训练集的参数)
X_val_processed = pipeline.transform(X_val, stage=PipelineStage.TEST)
# 过滤验证集有效标签
val_valid_mask = y_val.is_in([0, 1])
X_val_processed = X_val_processed.filter(val_valid_mask)
y_val = y_val.filter(val_valid_mask)
print(f"[TrainModel] Val after filtering: {len(X_val_processed)} samples")
print(f"[TrainModel] Val label distribution: {dict(y_val.value_counts().sort('label').iter_rows())}")
print(
f"[TrainModel] Val label distribution: {dict(y_val.value_counts().sort('label').iter_rows())}"
)
# 创建模型
params = model_params or {
@@ -384,7 +546,10 @@ def predict_top_stocks(
# 使用 key_data 添加预测结果,保持行数一致
result = key_data.with_columns(
pl.Series(
name="pred_prob", values=probs[:, 1] if len(probs.shape) > 1 and probs.shape[1] > 1 else probs.flatten()
name="pred_prob",
values=probs[:, 1]
if len(probs.shape) > 1 and probs.shape[1] > 1
else probs.flatten(),
),
)
@@ -396,7 +561,11 @@ def predict_top_stocks(
# 按概率降序排序选出top N
day_top = day_data.sort("pred_prob", descending=True).head(top_n)
top_stocks.append(day_top.select(["trade_date", "pred_prob", "ts_code"]).rename({"pred_prob": "score"}))
top_stocks.append(
day_top.select(["trade_date", "pred_prob", "ts_code"]).rename(
{"pred_prob": "score"}
)
)
return pl.concat(top_stocks)
@@ -415,6 +584,7 @@ def save_top_stocks(top_stocks: pl.DataFrame, output_path: str) -> None:
def run_training(
factors: Optional[List[BaseFactor]] = None,
data_dir: str = "data",
output_path: str = "output/top_stocks.tsv",
train_start: str = "20180101",
@@ -428,6 +598,7 @@ def run_training(
"""运行完整训练流程
Args:
factors: 因子实例列表,默认为 None使用 MA5, MA10, ReturnRank5
data_dir: 数据目录
output_path: 输出文件路径
train_start: 训练集开始日期
@@ -448,7 +619,8 @@ def run_training(
# 1. 准备数据
print("[Training] Preparing data...")
train_data, val_data, test_data = prepare_data(
train_data, val_data, test_data, factor_config = prepare_data(
factors=factors,
data_dir=data_dir,
train_start=train_start,
train_end=train_end,
@@ -461,9 +633,10 @@ def run_training(
print(f"[Training] Val samples: {len(val_data)}")
print(f"[Training] Test samples: {len(test_data)}")
# 2. 定义特征列
feature_cols = ["return_5_rank", "ma_5", "ma_10"]
# 2. 获取特征列
feature_cols = factor_config.get_feature_names()
label_col = "label"
print(f"[Training] Feature columns: {feature_cols}")
# 3. 训练模型
print("[Training] Training model...")