feat(training): 添加训练模块基础架构
实现 Commit 1:训练模块基础架构 新增文件: - src/training/__init__.py - 主模块导出 - src/training/components/__init__.py - components 子模块导出 - src/training/components/base.py - BaseModel/BaseProcessor 抽象基类 - src/training/registry.py - 模型和处理器注册中心 - tests/training/test_base.py - 基础架构单元测试 功能特性: - BaseModel: 提供 fit, predict, feature_importance, save/load 接口 - BaseProcessor: 提供 fit, transform, fit_transform 接口 - ModelRegistry/ProcessorRegistry: 支持装饰器风格组件注册 - 支持即插即用的组件扩展机制 Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-opencode) Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
@@ -1,46 +1,26 @@
|
||||
"""ProStock 训练流程模块
|
||||
"""训练模块 - ProStock 量化投资框架
|
||||
|
||||
本模块提供完整的模型训练流程:
|
||||
1. 数据处理:Fillna(0) -> Dropna
|
||||
2. 模型训练:LightGBM分类模型
|
||||
3. 预测选股:每日top5股票池
|
||||
|
||||
使用示例:
|
||||
from src.training import run_training
|
||||
|
||||
# 运行完整训练流程
|
||||
result = run_training(
|
||||
train_start="20180101",
|
||||
train_end="20230101",
|
||||
test_start="20230101",
|
||||
test_end="20240101",
|
||||
top_n=5,
|
||||
output_path="output/top_stocks.tsv"
|
||||
)
|
||||
|
||||
因子使用:
|
||||
from src.factors import MovingAverageFactor, ReturnRankFactor
|
||||
|
||||
ma5 = MovingAverageFactor(period=5) # 5日移动平均
|
||||
ma10 = MovingAverageFactor(period=10) # 10日移动平均
|
||||
ret5 = ReturnRankFactor(period=5) # 5日收益率排名
|
||||
提供模型训练、数据处理和评估的完整流程。
|
||||
"""
|
||||
|
||||
from src.training.pipeline import (
|
||||
create_pipeline,
|
||||
predict_top_stocks,
|
||||
prepare_data,
|
||||
run_training,
|
||||
save_top_stocks,
|
||||
train_model,
|
||||
# 基础抽象类
|
||||
from src.training.components.base import BaseModel, BaseProcessor
|
||||
|
||||
# 注册中心
|
||||
from src.training.registry import (
|
||||
ModelRegistry,
|
||||
ProcessorRegistry,
|
||||
register_model,
|
||||
register_processor,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# 管道函数
|
||||
"prepare_data",
|
||||
"create_pipeline",
|
||||
"train_model",
|
||||
"predict_top_stocks",
|
||||
"save_top_stocks",
|
||||
"run_training",
|
||||
# 基础抽象类
|
||||
"BaseModel",
|
||||
"BaseProcessor",
|
||||
# 注册中心
|
||||
"ModelRegistry",
|
||||
"ProcessorRegistry",
|
||||
"register_model",
|
||||
"register_processor",
|
||||
]
|
||||
|
||||
12
src/training/components/__init__.py
Normal file
12
src/training/components/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""训练组件子模块
|
||||
|
||||
包含模型、处理器、划分器、选择器等组件。
|
||||
"""
|
||||
|
||||
# 基础抽象类
|
||||
from src.training.components.base import BaseModel, BaseProcessor
|
||||
|
||||
__all__ = [
|
||||
"BaseModel",
|
||||
"BaseProcessor",
|
||||
]
|
||||
141
src/training/components/base.py
Normal file
141
src/training/components/base.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""基础抽象类定义
|
||||
|
||||
定义 BaseModel 和 BaseProcessor 抽象基类,
|
||||
为所有训练组件提供统一的接口。
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
import pickle
|
||||
|
||||
import polars as pl
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class BaseModel(ABC):
|
||||
"""模型基类
|
||||
|
||||
所有机器学习模型必须继承此类并实现抽象方法。
|
||||
提供统一的训练、预测、特征重要性和持久化接口。
|
||||
|
||||
Attributes:
|
||||
name: 模型名称,子类必须定义
|
||||
"""
|
||||
|
||||
name: str = "" # 模型名称
|
||||
|
||||
@abstractmethod
|
||||
def fit(self, X: pl.DataFrame, y: pl.Series) -> "BaseModel":
|
||||
"""训练模型
|
||||
|
||||
Args:
|
||||
X: 特征矩阵 (Polars DataFrame)
|
||||
y: 目标变量 (Polars Series)
|
||||
|
||||
Returns:
|
||||
self (支持链式调用)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def predict(self, X: pl.DataFrame) -> np.ndarray:
|
||||
"""预测
|
||||
|
||||
Args:
|
||||
X: 特征矩阵 (Polars DataFrame)
|
||||
|
||||
Returns:
|
||||
预测结果 (numpy ndarray)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def feature_importance(self) -> Optional[pd.Series]:
|
||||
"""特征重要性
|
||||
|
||||
Returns:
|
||||
特征重要性序列,如果不支持则返回 None
|
||||
"""
|
||||
return None
|
||||
|
||||
def save(self, path: str) -> None:
|
||||
"""保存模型到文件
|
||||
|
||||
默认实现使用 pickle 序列化,子类可覆盖以使用更高效的格式。
|
||||
|
||||
Args:
|
||||
path: 保存路径
|
||||
|
||||
Raises:
|
||||
RuntimeError: 模型未训练时调用
|
||||
"""
|
||||
with open(path, "wb") as f:
|
||||
pickle.dump(self, f)
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: str) -> "BaseModel":
|
||||
"""从文件加载模型
|
||||
|
||||
Args:
|
||||
path: 模型文件路径
|
||||
|
||||
Returns:
|
||||
加载的模型实例
|
||||
"""
|
||||
with open(path, "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
|
||||
class BaseProcessor(ABC):
|
||||
"""数据处理器基类
|
||||
|
||||
重要:Processor 在不同阶段行为不同:
|
||||
- 训练阶段:fit_transform(学习参数并应用)
|
||||
- 验证/测试阶段:transform(使用训练阶段学到的参数)
|
||||
|
||||
这意味着 Processor 实例会在训练后被保存,
|
||||
用于后续的验证和测试数据转换。
|
||||
|
||||
Attributes:
|
||||
name: 处理器名称,子类必须定义
|
||||
"""
|
||||
|
||||
name: str = ""
|
||||
|
||||
def fit(self, X: pl.DataFrame) -> "BaseProcessor":
|
||||
"""学习参数(仅在训练阶段调用)
|
||||
|
||||
子类应覆盖此方法以学习统计参数(如均值、标准差等)。
|
||||
|
||||
Args:
|
||||
X: 训练数据 (Polars DataFrame)
|
||||
|
||||
Returns:
|
||||
self (支持链式调用)
|
||||
"""
|
||||
return self
|
||||
|
||||
@abstractmethod
|
||||
def transform(self, X: pl.DataFrame) -> pl.DataFrame:
|
||||
"""转换数据
|
||||
|
||||
Args:
|
||||
X: 输入数据 (Polars DataFrame)
|
||||
|
||||
Returns:
|
||||
转换后的数据 (Polars DataFrame)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def fit_transform(self, X: pl.DataFrame) -> pl.DataFrame:
|
||||
"""拟合并转换(训练阶段使用)
|
||||
|
||||
先调用 fit 学习参数,然后调用 transform 应用转换。
|
||||
|
||||
Args:
|
||||
X: 训练数据 (Polars DataFrame)
|
||||
|
||||
Returns:
|
||||
转换后的数据 (Polars DataFrame)
|
||||
"""
|
||||
return self.fit(X).transform(X)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,667 +0,0 @@
|
||||
"""训练管道 - 包含数据处理、模型训练和预测功能
|
||||
|
||||
本模块提供:
|
||||
1. 数据准备:使用 FactorEngine 从因子计算结果中准备训练/测试数据
|
||||
2. 数据处理:Fillna(0) -> Dropna
|
||||
3. 模型训练:使用LightGBM训练分类模型
|
||||
4. 预测和选股:输出每日top5股票池
|
||||
|
||||
注意:本模块使用 src.factors 框架进行因子计算,确保防泄露机制生效。
|
||||
|
||||
因子配置示例:
|
||||
from src.factors import MovingAverageFactor, ReturnRankFactor
|
||||
from src.training.pipeline import prepare_data
|
||||
|
||||
# 直接传入因子实例列表 - 简单直观
|
||||
factors = [
|
||||
MovingAverageFactor(period=5),
|
||||
MovingAverageFactor(period=10),
|
||||
ReturnRankFactor(period=5),
|
||||
]
|
||||
|
||||
train_data, val_data, test_data, factor_config = prepare_data(
|
||||
factors=factors,
|
||||
...
|
||||
)
|
||||
|
||||
# 或者使用 FactorConfig 包装(支持链式添加)
|
||||
from src.training.pipeline import FactorConfig
|
||||
|
||||
config = FactorConfig()
|
||||
.add(MovingAverageFactor(period=5))
|
||||
.add(MovingAverageFactor(period=10))
|
||||
.add(ReturnRankFactor(period=5))
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
|
||||
from src.factors import DataLoader, FactorEngine, BaseFactor
|
||||
from src.factors.data_spec import DataSpec
|
||||
from src.factors.momentum import MovingAverageFactor, ReturnRankFactor
|
||||
from src.pipeline import (
|
||||
DropNAProcessor,
|
||||
FillNAProcessor,
|
||||
LightGBMModel,
|
||||
PipelineStage,
|
||||
ProcessingPipeline,
|
||||
TaskType,
|
||||
)
|
||||
|
||||
|
||||
# ========== 因子配置类 ==========
|
||||
|
||||
|
||||
class FactorConfig:
|
||||
"""因子配置类 - 管理因子列表
|
||||
|
||||
用于包装因子实例列表,支持链式添加。
|
||||
|
||||
示例:
|
||||
# 方式1:初始化时传入列表
|
||||
config = FactorConfig([
|
||||
MovingAverageFactor(period=5),
|
||||
ReturnRankFactor(period=5),
|
||||
])
|
||||
|
||||
# 方式2:链式添加
|
||||
config = FactorConfig()
|
||||
.add(MovingAverageFactor(period=5))
|
||||
.add(ReturnRankFactor(period=5))
|
||||
|
||||
# 获取因子实例列表
|
||||
factors = config.get_factors()
|
||||
|
||||
# 获取特征列名
|
||||
feature_cols = config.get_feature_names()
|
||||
"""
|
||||
|
||||
def __init__(self, factors: Optional[List[BaseFactor]] = None):
|
||||
"""初始化因子配置
|
||||
|
||||
Args:
|
||||
factors: 因子实例列表
|
||||
"""
|
||||
self._factors: List[BaseFactor] = factors or []
|
||||
|
||||
def add(self, factor: BaseFactor) -> "FactorConfig":
|
||||
"""添加因子到配置
|
||||
|
||||
支持链式调用:
|
||||
config = FactorConfig()
|
||||
.add(MovingAverageFactor(period=5))
|
||||
.add(ReturnRankFactor(period=5))
|
||||
|
||||
Args:
|
||||
factor: 因子实例
|
||||
|
||||
Returns:
|
||||
self,支持链式调用
|
||||
"""
|
||||
if not isinstance(factor, BaseFactor):
|
||||
raise ValueError(f"必须是 BaseFactor 实例, got {type(factor)}")
|
||||
self._factors.append(factor)
|
||||
return self
|
||||
|
||||
def get_factors(self) -> List[BaseFactor]:
|
||||
"""获取因子实例列表
|
||||
|
||||
Returns:
|
||||
因子实例列表
|
||||
"""
|
||||
return self._factors
|
||||
|
||||
def get_feature_names(self) -> List[str]:
|
||||
"""获取所有因子的特征列名
|
||||
|
||||
Returns:
|
||||
特征列名列表
|
||||
"""
|
||||
return [f.name for f in self._factors]
|
||||
|
||||
def get_max_lookback(self) -> int:
|
||||
"""获取所有因子中最大的 lookback 天数
|
||||
|
||||
Returns:
|
||||
最大 lookback 天数
|
||||
"""
|
||||
max_lookback = 0
|
||||
for factor in self._factors:
|
||||
for spec in factor.data_specs:
|
||||
max_lookback = max(max_lookback, spec.lookback_days)
|
||||
return max_lookback
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._factors)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
names = [f.name for f in self._factors]
|
||||
return f"FactorConfig({names})"
|
||||
|
||||
|
||||
def prepare_data(
|
||||
factors: Optional[List[BaseFactor]] = None,
|
||||
data_dir: str = "data",
|
||||
train_start: str = "20180101",
|
||||
train_end: str = "20230101",
|
||||
val_start: str = "20230101",
|
||||
val_end: str = "20230601",
|
||||
test_start: str = "20230601",
|
||||
test_end: str = "20240101",
|
||||
) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame, FactorConfig]:
|
||||
"""准备训练、验证和测试数据
|
||||
|
||||
使用 FactorEngine 计算因子,确保防泄露机制生效。
|
||||
|
||||
Args:
|
||||
factors: 因子实例列表,默认为 None(使用 MA5, MA10, ReturnRank5)
|
||||
data_dir: 数据目录
|
||||
train_start: 训练集开始日期
|
||||
train_end: 训练集结束日期
|
||||
val_start: 验证集开始日期
|
||||
val_end: 验证集结束日期
|
||||
test_start: 测试集开始日期
|
||||
test_end: 测试集结束日期
|
||||
|
||||
Returns:
|
||||
(train_data, val_data, test_data, factor_config):
|
||||
训练集、验证集、测试集的DataFrame,以及使用的因子配置
|
||||
"""
|
||||
from src.data.storage import Storage
|
||||
|
||||
storage = Storage()
|
||||
|
||||
# 1. 处理因子配置
|
||||
if factors is None:
|
||||
# 默认因子配置
|
||||
factor_config = FactorConfig(
|
||||
[
|
||||
MovingAverageFactor(period=5),
|
||||
MovingAverageFactor(period=10),
|
||||
ReturnRankFactor(period=5),
|
||||
]
|
||||
)
|
||||
elif isinstance(factors, FactorConfig):
|
||||
factor_config = factors
|
||||
elif isinstance(factors, list):
|
||||
# 转换为 FactorConfig
|
||||
factor_config = FactorConfig(factors)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"factors 必须是 List[BaseFactor] 或 FactorConfig, got {type(factors)}"
|
||||
)
|
||||
|
||||
factors_list = factor_config.get_factors()
|
||||
feature_cols = factor_config.get_feature_names()
|
||||
|
||||
if not factors_list:
|
||||
raise ValueError("至少需要提供一个因子")
|
||||
|
||||
print(f"[PrepareData] 使用因子: {feature_cols}")
|
||||
|
||||
# 2. 初始化 FactorEngine
|
||||
loader = DataLoader(data_dir=data_dir)
|
||||
engine = FactorEngine(loader)
|
||||
|
||||
# 3. 计算需要的回溯天数
|
||||
max_lookback = factor_config.get_max_lookback()
|
||||
start_with_lookback = str(int(train_start) - 10000) # 往前取一年
|
||||
|
||||
# 获取所有股票列表
|
||||
start_dt = f"{start_with_lookback[:4]}-{start_with_lookback[4:6]}-{start_with_lookback[6:8]}"
|
||||
end_dt = f"{test_end[:4]}-{test_end[4:6]}-{test_end[6:8]}"
|
||||
all_stocks_query = f"""
|
||||
SELECT DISTINCT ts_code FROM daily
|
||||
WHERE trade_date >= '{start_dt}' AND trade_date <= '{end_dt}'
|
||||
"""
|
||||
all_stocks_df = storage._connection.sql(all_stocks_query).pl()
|
||||
all_stocks = all_stocks_df["ts_code"].to_list()
|
||||
|
||||
print(f"[PrepareData] 股票数量: {len(all_stocks)}")
|
||||
|
||||
# 4. 计算所有因子并合并
|
||||
all_features = None
|
||||
|
||||
for factor in factors_list:
|
||||
print(f"[PrepareData] 计算因子: {factor.name} ({factor.factor_type})")
|
||||
|
||||
if factor.factor_type == "time_series":
|
||||
# 时序因子需要传入股票列表
|
||||
result = engine.compute(
|
||||
factor,
|
||||
stock_codes=all_stocks,
|
||||
start_date=start_with_lookback,
|
||||
end_date=test_end,
|
||||
)
|
||||
else:
|
||||
# 截面因子不需要股票列表
|
||||
result = engine.compute(
|
||||
factor,
|
||||
start_date=start_with_lookback,
|
||||
end_date=test_end,
|
||||
)
|
||||
|
||||
# 合并结果
|
||||
if all_features is None:
|
||||
all_features = result
|
||||
else:
|
||||
# 确保没有重复的 _right 列
|
||||
result = result.select([
|
||||
c for c in result.columns if not c.endswith("_right")
|
||||
])
|
||||
all_features = all_features.select([
|
||||
c for c in all_features.columns if not c.endswith("_right")
|
||||
])
|
||||
all_features = all_features.join(
|
||||
result, on=["trade_date", "ts_code"], how="outer"
|
||||
)
|
||||
|
||||
if all_features is None:
|
||||
raise ValueError("没有计算任何因子")
|
||||
|
||||
# 5. 计算标签(未来5日收益率)
|
||||
all_data = _compute_label(all_features, start_date=train_start, end_date=test_end)
|
||||
|
||||
# 6. 过滤不符合条件的股票
|
||||
all_data = _filter_invalid_stocks(all_data)
|
||||
print(f"[PrepareData] After filtering: total={len(all_data)}")
|
||||
|
||||
# 转换日期格式用于比较
|
||||
train_start_fmt = f"{train_start[:4]}-{train_start[4:6]}-{train_start[6:8]}"
|
||||
train_end_fmt = f"{train_end[:4]}-{train_end[4:6]}-{train_end[6:8]}"
|
||||
val_start_fmt = f"{val_start[:4]}-{val_start[4:6]}-{val_start[6:8]}"
|
||||
val_end_fmt = f"{val_end[:4]}-{val_end[4:6]}-{val_end[6:8]}"
|
||||
test_start_fmt = f"{test_start[:4]}-{test_start[4:6]}-{test_start[6:8]}"
|
||||
test_end_fmt = f"{test_end[:4]}-{test_end[4:6]}-{test_end[6:8]}"
|
||||
|
||||
# 拆分数据
|
||||
train_data = all_data.filter(
|
||||
(pl.col("trade_date") >= train_start_fmt)
|
||||
& (pl.col("trade_date") <= train_end_fmt)
|
||||
)
|
||||
val_data = all_data.filter(
|
||||
(pl.col("trade_date") >= val_start_fmt) & (pl.col("trade_date") <= val_end_fmt)
|
||||
)
|
||||
test_data = all_data.filter(
|
||||
(pl.col("trade_date") >= test_start_fmt)
|
||||
& (pl.col("trade_date") <= test_end_fmt)
|
||||
)
|
||||
|
||||
print(
|
||||
f"[PrepareData] Split result: train={len(train_data)}, val={len(val_data)}, test={len(test_data)}"
|
||||
)
|
||||
|
||||
return train_data, val_data, test_data, factor_config
|
||||
|
||||
|
||||
def _filter_invalid_stocks(df: pl.DataFrame) -> pl.DataFrame:
|
||||
"""过滤不符合条件的股票
|
||||
|
||||
过滤规则:
|
||||
1. 过滤北交所股票(ts_code 以 BJ 结尾)
|
||||
2. 过滤创业板股票(ts_code 以 30 开头)
|
||||
3. 过滤科创板股票(ts_code 以 68 开头)
|
||||
4. 过滤退市/风险股票(ts_code 以 8 开头)
|
||||
|
||||
Args:
|
||||
df: 原始数据
|
||||
|
||||
Returns:
|
||||
过滤后的数据
|
||||
"""
|
||||
ts_code_col = pl.col("ts_code")
|
||||
|
||||
return df.filter(
|
||||
~ts_code_col.str.ends_with("BJ")
|
||||
& ~ts_code_col.str.starts_with("30")
|
||||
& ~ts_code_col.str.starts_with("68")
|
||||
& ~ts_code_col.str.starts_with("8")
|
||||
)
|
||||
|
||||
|
||||
def _compute_label(
|
||||
features_df: pl.DataFrame,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
) -> pl.DataFrame:
|
||||
"""计算标签(未来5日收益率)
|
||||
|
||||
标签定义:未来5日收益率大于0为1,否则为0
|
||||
|
||||
Args:
|
||||
features_df: 包含因子的DataFrame
|
||||
start_date: 开始日期
|
||||
end_date: 结束日期
|
||||
|
||||
Returns:
|
||||
包含因子和标签的DataFrame
|
||||
"""
|
||||
from src.data.storage import Storage
|
||||
|
||||
storage = Storage()
|
||||
|
||||
# 从数据库获取收盘价数据用于计算标签
|
||||
start_dt = f"{start_date[:4]}-{start_date[4:6]}-{start_date[6:8]}"
|
||||
end_dt = f"{end_date[:4]}-{end_date[4:6]}-{end_date[6:8]}"
|
||||
|
||||
# 需要多取5天数据来计算未来收益率
|
||||
end_dt_extended = f"{end_date[:4]}-{end_date[4:6]}-{int(end_date[6:8]) + 5}"
|
||||
|
||||
price_query = f"""
|
||||
SELECT ts_code, trade_date, close
|
||||
FROM daily
|
||||
WHERE trade_date >= '{start_dt}' AND trade_date <= '{end_dt_extended}'
|
||||
ORDER BY ts_code, trade_date
|
||||
"""
|
||||
price_data = storage._connection.sql(price_query).pl()
|
||||
price_data = price_data.with_columns(
|
||||
pl.col("trade_date").dt.strftime("%Y-%m-%d").alias("trade_date")
|
||||
)
|
||||
|
||||
# 按股票计算未来5日收益率
|
||||
result_list = []
|
||||
for ts_code in price_data["ts_code"].unique():
|
||||
stock_data = price_data.filter(pl.col("ts_code") == ts_code).sort("trade_date")
|
||||
|
||||
if len(stock_data) < 6:
|
||||
continue
|
||||
|
||||
# 计算未来5日收益率
|
||||
future_return = stock_data["close"].shift(-5) - stock_data["close"]
|
||||
future_return_pct = future_return / stock_data["close"]
|
||||
stock_data = stock_data.with_columns(
|
||||
[
|
||||
future_return_pct.alias("future_return_5"),
|
||||
]
|
||||
)
|
||||
|
||||
# 生成标签:收益率>0为1,否则为0
|
||||
stock_data = stock_data.with_columns(
|
||||
[
|
||||
(pl.col("future_return_5") > 0).cast(pl.Int8).alias("label"),
|
||||
]
|
||||
)
|
||||
|
||||
result_list.append(stock_data.select(["trade_date", "ts_code", "label"]))
|
||||
|
||||
if not result_list:
|
||||
return pl.DataFrame()
|
||||
|
||||
label_df = pl.concat(result_list)
|
||||
|
||||
# 将标签合并到因子数据
|
||||
result = features_df.join(label_df, on=["trade_date", "ts_code"], how="inner")
|
||||
|
||||
# 过滤有效日期范围
|
||||
result = result.filter(
|
||||
(pl.col("trade_date") >= start_dt) & (pl.col("trade_date") <= end_dt)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def create_pipeline() -> ProcessingPipeline:
|
||||
"""创建数据处理流水线
|
||||
|
||||
处理流程:
|
||||
1. FillNA(0): 将缺失值填充为0
|
||||
|
||||
注意:不使用 Dropna,因为会导致训练和预测时的行数不匹配
|
||||
|
||||
Returns:
|
||||
配置好的ProcessingPipeline
|
||||
"""
|
||||
processors = [
|
||||
FillNAProcessor(method="zero"), # 缺失值填充为0
|
||||
]
|
||||
return ProcessingPipeline(processors)
|
||||
|
||||
|
||||
def train_model(
|
||||
train_data: pl.DataFrame,
|
||||
val_data: Optional[pl.DataFrame],
|
||||
feature_cols: List[str],
|
||||
label_col: str = "label",
|
||||
model_params: Optional[dict] = None,
|
||||
) -> tuple[LightGBMModel, ProcessingPipeline]:
|
||||
"""训练LightGBM分类模型
|
||||
|
||||
Args:
|
||||
train_data: 训练数据
|
||||
val_data: 验证数据(用于早停)
|
||||
feature_cols: 特征列名列表
|
||||
label_col: 标签列名
|
||||
model_params: 模型参数字典
|
||||
|
||||
Returns:
|
||||
(训练好的模型, 处理流水线)
|
||||
"""
|
||||
# 创建处理流水线
|
||||
pipeline = create_pipeline()
|
||||
print("[TrainModel] Pipeline created: FillNA(0)")
|
||||
|
||||
# 准备训练特征和标签
|
||||
X_train = train_data.select(feature_cols)
|
||||
y_train = train_data[label_col]
|
||||
print(f"[TrainModel] Train samples: {len(X_train)}, features: {feature_cols}")
|
||||
|
||||
# 处理训练数据
|
||||
X_train_processed = pipeline.fit_transform(X_train, stage=PipelineStage.TRAIN)
|
||||
print(f"[TrainModel] After processing: {len(X_train_processed)} samples")
|
||||
|
||||
# 过滤训练集有效标签(排除-1等无效值)
|
||||
valid_mask = y_train.is_in([0, 1])
|
||||
X_train_processed = X_train_processed.filter(valid_mask)
|
||||
y_train = y_train.filter(valid_mask)
|
||||
print(
|
||||
f"[TrainModel] After filtering valid labels: {len(X_train_processed)} samples"
|
||||
)
|
||||
print(
|
||||
f"[TrainModel] Train label distribution: {dict(y_train.value_counts().sort('label').iter_rows())}"
|
||||
)
|
||||
|
||||
# 准备验证集
|
||||
X_val_processed = None
|
||||
y_val = None
|
||||
if val_data is not None and len(val_data) > 0:
|
||||
X_val = val_data.select(feature_cols)
|
||||
y_val = val_data[label_col]
|
||||
print(f"[TrainModel] Val samples: {len(X_val)}")
|
||||
|
||||
# 处理验证集数据(使用训练集的参数)
|
||||
X_val_processed = pipeline.transform(X_val, stage=PipelineStage.TEST)
|
||||
|
||||
# 过滤验证集有效标签
|
||||
val_valid_mask = y_val.is_in([0, 1])
|
||||
X_val_processed = X_val_processed.filter(val_valid_mask)
|
||||
y_val = y_val.filter(val_valid_mask)
|
||||
print(f"[TrainModel] Val after filtering: {len(X_val_processed)} samples")
|
||||
print(
|
||||
f"[TrainModel] Val label distribution: {dict(y_val.value_counts().sort('label').iter_rows())}"
|
||||
)
|
||||
|
||||
# 创建模型
|
||||
params = model_params or {
|
||||
"n_estimators": 100,
|
||||
"learning_rate": 0.05,
|
||||
"max_depth": 5,
|
||||
"num_leaves": 31,
|
||||
}
|
||||
print(f"[TrainModel] Model params: {params}")
|
||||
model = LightGBMModel(
|
||||
task_type="classification",
|
||||
params=params,
|
||||
)
|
||||
|
||||
# 训练模型(使用验证集早停)
|
||||
print("[TrainModel] Training LightGBM...")
|
||||
if X_val_processed is not None and y_val is not None:
|
||||
print("[TrainModel] Using validation set for early stopping")
|
||||
model.fit(X_train_processed, y_train, X_val_processed, y_val)
|
||||
else:
|
||||
model.fit(X_train_processed, y_train)
|
||||
print("[TrainModel] Training completed!")
|
||||
|
||||
return model, pipeline
|
||||
|
||||
|
||||
def predict_top_stocks(
|
||||
model: LightGBMModel,
|
||||
pipeline: ProcessingPipeline,
|
||||
test_data: pl.DataFrame,
|
||||
feature_cols: List[str],
|
||||
top_n: int = 5,
|
||||
) -> pl.DataFrame:
|
||||
"""预测并选出每日top N股票
|
||||
|
||||
Args:
|
||||
model: 训练好的模型
|
||||
pipeline: 数据处理流水线
|
||||
test_data: 测试数据
|
||||
feature_cols: 特征列名
|
||||
top_n: 每日选出的股票数量
|
||||
|
||||
Returns:
|
||||
包含日期和股票代码的DataFrame
|
||||
"""
|
||||
# 准备特征和必要列
|
||||
X_test = test_data.select(feature_cols)
|
||||
key_cols = ["trade_date", "ts_code"]
|
||||
key_data = test_data.select(key_cols)
|
||||
|
||||
print(f"[Predict] Test samples: {len(X_test)}, top_n: {top_n}")
|
||||
|
||||
# 处理数据(使用训练阶段的参数)
|
||||
X_test_processed = pipeline.transform(X_test, stage=PipelineStage.TEST)
|
||||
print(f"[Predict] Data processed, shape: {X_test_processed.shape}")
|
||||
|
||||
# 预测概率
|
||||
probs = model.predict_proba(X_test_processed)
|
||||
print(f"[Predict] Predictions generated, probability shape: {probs.shape}")
|
||||
|
||||
# 使用 key_data 添加预测结果,保持行数一致
|
||||
result = key_data.with_columns(
|
||||
pl.Series(
|
||||
name="pred_prob",
|
||||
values=probs[:, 1]
|
||||
if len(probs.shape) > 1 and probs.shape[1] > 1
|
||||
else probs.flatten(),
|
||||
),
|
||||
)
|
||||
|
||||
# 每日选出top N
|
||||
top_stocks = []
|
||||
for date in result["trade_date"].unique().sort():
|
||||
day_data = result.filter(pl.col("trade_date") == date)
|
||||
|
||||
# 按概率降序排序,选出top N
|
||||
day_top = day_data.sort("pred_prob", descending=True).head(top_n)
|
||||
|
||||
top_stocks.append(
|
||||
day_top.select(["trade_date", "pred_prob", "ts_code"]).rename(
|
||||
{"pred_prob": "score"}
|
||||
)
|
||||
)
|
||||
|
||||
return pl.concat(top_stocks)
|
||||
|
||||
|
||||
def save_top_stocks(top_stocks: pl.DataFrame, output_path: str) -> None:
|
||||
"""保存选股结果到TSV文件
|
||||
|
||||
Args:
|
||||
top_stocks: 选股结果
|
||||
output_path: 输出文件路径
|
||||
"""
|
||||
# 转换为pandas并保存为TSV
|
||||
df = top_stocks.to_pandas()
|
||||
df.to_csv(output_path, sep="\t", index=False)
|
||||
print(f"[Training] Top stocks saved to: {output_path}")
|
||||
|
||||
|
||||
def run_training(
|
||||
factors: Optional[List[BaseFactor]] = None,
|
||||
data_dir: str = "data",
|
||||
output_path: str = "output/top_stocks.tsv",
|
||||
train_start: str = "20180101",
|
||||
train_end: str = "20230101",
|
||||
val_start: str = "20230101",
|
||||
val_end: str = "20230601",
|
||||
test_start: str = "20230601",
|
||||
test_end: str = "20240101",
|
||||
top_n: int = 5,
|
||||
) -> pl.DataFrame:
|
||||
"""运行完整训练流程
|
||||
|
||||
Args:
|
||||
factors: 因子实例列表,默认为 None(使用 MA5, MA10, ReturnRank5)
|
||||
data_dir: 数据目录
|
||||
output_path: 输出文件路径
|
||||
train_start: 训练集开始日期
|
||||
train_end: 训练集结束日期
|
||||
val_start: 验证集开始日期
|
||||
val_end: 验证集结束日期
|
||||
test_start: 测试集开始日期
|
||||
test_end: 测试集结束日期
|
||||
top_n: 每日选股数量
|
||||
|
||||
Returns:
|
||||
选股结果DataFrame
|
||||
"""
|
||||
print(f"[Training] Starting training pipeline...")
|
||||
print(f"[Training] Train period: {train_start} -> {train_end}")
|
||||
print(f"[Training] Val period: {val_start} -> {val_end}")
|
||||
print(f"[Training] Test period: {test_start} -> {test_end}")
|
||||
|
||||
# 1. 准备数据
|
||||
print("[Training] Preparing data...")
|
||||
train_data, val_data, test_data, factor_config = prepare_data(
|
||||
factors=factors,
|
||||
data_dir=data_dir,
|
||||
train_start=train_start,
|
||||
train_end=train_end,
|
||||
val_start=val_start,
|
||||
val_end=val_end,
|
||||
test_start=test_start,
|
||||
test_end=test_end,
|
||||
)
|
||||
print(f"[Training] Train samples: {len(train_data)}")
|
||||
print(f"[Training] Val samples: {len(val_data)}")
|
||||
print(f"[Training] Test samples: {len(test_data)}")
|
||||
|
||||
# 2. 获取特征列名
|
||||
feature_cols = factor_config.get_feature_names()
|
||||
label_col = "label"
|
||||
print(f"[Training] Feature columns: {feature_cols}")
|
||||
|
||||
# 3. 训练模型
|
||||
print("[Training] Training model...")
|
||||
model, pipeline = train_model(
|
||||
train_data=train_data,
|
||||
val_data=val_data,
|
||||
feature_cols=feature_cols,
|
||||
label_col=label_col,
|
||||
)
|
||||
|
||||
# 4. 测试集预测
|
||||
print("[Training] Predicting on test set...")
|
||||
top_stocks = predict_top_stocks(
|
||||
model=model,
|
||||
pipeline=pipeline,
|
||||
test_data=test_data,
|
||||
feature_cols=feature_cols,
|
||||
top_n=top_n,
|
||||
)
|
||||
|
||||
# 5. 保存结果
|
||||
print(f"[Training] Saving results to {output_path}...")
|
||||
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
save_top_stocks(top_stocks, output_path)
|
||||
|
||||
print("[Training] Training completed!")
|
||||
|
||||
return top_stocks
|
||||
184
src/training/registry.py
Normal file
184
src/training/registry.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""组件注册中心
|
||||
|
||||
提供装饰器风格的组件注册机制,支持即插即用。
|
||||
"""
|
||||
|
||||
from typing import Dict, Type, Callable, Any
|
||||
from src.training.components.base import BaseModel, BaseProcessor
|
||||
|
||||
|
||||
class ModelRegistry:
|
||||
"""模型注册中心
|
||||
|
||||
管理所有可用的模型类,支持通过名称获取模型类。
|
||||
|
||||
Example:
|
||||
>>> @register_model("lightgbm")
|
||||
... class LightGBMModel(BaseModel):
|
||||
... pass
|
||||
>>>
|
||||
>>> model_class = ModelRegistry.get_model("lightgbm")
|
||||
>>> model = model_class(**params)
|
||||
"""
|
||||
|
||||
_registry: Dict[str, Type[BaseModel]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, name: str, model_class: Type[BaseModel]) -> None:
|
||||
"""注册模型类
|
||||
|
||||
Args:
|
||||
name: 模型名称
|
||||
model_class: 模型类(必须继承 BaseModel)
|
||||
|
||||
Raises:
|
||||
ValueError: 名称已被注册或类不继承 BaseModel
|
||||
"""
|
||||
if name in cls._registry:
|
||||
raise ValueError(f"模型 '{name}' 已被注册")
|
||||
if not issubclass(model_class, BaseModel):
|
||||
raise ValueError(f"模型类必须继承 BaseModel")
|
||||
cls._registry[name] = model_class
|
||||
|
||||
@classmethod
|
||||
def get_model(cls, name: str) -> Type[BaseModel]:
|
||||
"""获取模型类
|
||||
|
||||
Args:
|
||||
name: 模型名称
|
||||
|
||||
Returns:
|
||||
模型类
|
||||
|
||||
Raises:
|
||||
KeyError: 未找到该名称的模型
|
||||
"""
|
||||
if name not in cls._registry:
|
||||
available = ", ".join(cls._registry.keys())
|
||||
raise KeyError(f"未知模型 '{name}',可用模型: {available}")
|
||||
return cls._registry[name]
|
||||
|
||||
@classmethod
|
||||
def list_models(cls) -> list[str]:
|
||||
"""列出所有已注册的模型名称"""
|
||||
return list(cls._registry.keys())
|
||||
|
||||
@classmethod
|
||||
def clear(cls) -> None:
|
||||
"""清空注册表(主要用于测试)"""
|
||||
cls._registry.clear()
|
||||
|
||||
|
||||
class ProcessorRegistry:
|
||||
"""处理器注册中心
|
||||
|
||||
管理所有可用的数据处理器类,支持通过名称获取处理器类。
|
||||
|
||||
Example:
|
||||
>>> @register_processor("standard_scaler")
|
||||
... class StandardScaler(BaseProcessor):
|
||||
... pass
|
||||
>>>
|
||||
>>> processor_class = ProcessorRegistry.get_processor("standard_scaler")
|
||||
>>> processor = processor_class(**params)
|
||||
"""
|
||||
|
||||
_registry: Dict[str, Type[BaseProcessor]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, name: str, processor_class: Type[BaseProcessor]) -> None:
|
||||
"""注册处理器类
|
||||
|
||||
Args:
|
||||
name: 处理器名称
|
||||
processor_class: 处理器类(必须继承 BaseProcessor)
|
||||
|
||||
Raises:
|
||||
ValueError: 名称已被注册或类不继承 BaseProcessor
|
||||
"""
|
||||
if name in cls._registry:
|
||||
raise ValueError(f"处理器 '{name}' 已被注册")
|
||||
if not issubclass(processor_class, BaseProcessor):
|
||||
raise ValueError(f"处理器类必须继承 BaseProcessor")
|
||||
cls._registry[name] = processor_class
|
||||
|
||||
@classmethod
|
||||
def get_processor(cls, name: str) -> Type[BaseProcessor]:
|
||||
"""获取处理器类
|
||||
|
||||
Args:
|
||||
name: 处理器名称
|
||||
|
||||
Returns:
|
||||
处理器类
|
||||
|
||||
Raises:
|
||||
KeyError: 未找到该名称的处理器
|
||||
"""
|
||||
if name not in cls._registry:
|
||||
available = ", ".join(cls._registry.keys())
|
||||
raise KeyError(f"未知处理器 '{name}',可用处理器: {available}")
|
||||
return cls._registry[name]
|
||||
|
||||
@classmethod
|
||||
def list_processors(cls) -> list[str]:
|
||||
"""列出所有已注册的处理器名称"""
|
||||
return list(cls._registry.keys())
|
||||
|
||||
@classmethod
|
||||
def clear(cls) -> None:
|
||||
"""清空注册表(主要用于测试)"""
|
||||
cls._registry.clear()
|
||||
|
||||
|
||||
def register_model(name: str) -> Callable[[Type[BaseModel]], Type[BaseModel]]:
|
||||
"""模型注册装饰器
|
||||
|
||||
用于装饰继承 BaseModel 的类,将其注册到 ModelRegistry。
|
||||
|
||||
Args:
|
||||
name: 模型名称
|
||||
|
||||
Returns:
|
||||
装饰器函数
|
||||
|
||||
Example:
|
||||
>>> @register_model("lightgbm")
|
||||
... class LightGBMModel(BaseModel):
|
||||
... name = "lightgbm"
|
||||
... def fit(self, X, y): ...
|
||||
... def predict(self, X): ...
|
||||
"""
|
||||
|
||||
def decorator(cls: Type[BaseModel]) -> Type[BaseModel]:
|
||||
ModelRegistry.register(name, cls)
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def register_processor(
|
||||
name: str,
|
||||
) -> Callable[[Type[BaseProcessor]], Type[BaseProcessor]]:
|
||||
"""处理器注册装饰器
|
||||
|
||||
用于装饰继承 BaseProcessor 的类,将其注册到 ProcessorRegistry。
|
||||
|
||||
Args:
|
||||
name: 处理器名称
|
||||
|
||||
Returns:
|
||||
装饰器函数
|
||||
|
||||
Example:
|
||||
>>> @register_processor("standard_scaler")
|
||||
... class StandardScaler(BaseProcessor):
|
||||
... name = "standard_scaler"
|
||||
... def transform(self, X): ...
|
||||
"""
|
||||
|
||||
def decorator(cls: Type[BaseProcessor]) -> Type[BaseProcessor]:
|
||||
ProcessorRegistry.register(name, cls)
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
Reference in New Issue
Block a user