ProStock/docs/plan/training_module_plan.md
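liaozhaorun 192718095f feat(training): implement core training-module components (commits 6-9)
- StockPoolManager: screens the stock pool independently each day; supports code filtering and market-cap selection
- Trainer: integrates the full training workflow; supports stage-aware processor behavior and model persistence
- TrainingConfig: pydantic configuration management with required fields and date validation
- experiment module: placeholder structure
- Remove the metrics component from the plan
- Renumber commits (7-10 → 6-9)
- Update training/__init__.py to export all public APIs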
2026-03-03 22:57:01 +08:00

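# Training Module Implementation Plan
## 1. Overview
This plan describes the complete implementation of the ProStock training module, which trains standard regression models and outputs their predictions.
### 1.1 Goals
- Provide a simple training workflow
- Support standard regression models (LightGBM, CatBoost)
- Output interpretable predictions
- Integrate seamlessly with the existing factor engine
### 1.2 Design Principles
- **Separation of concerns**: the training workflow is kept separate from the modeling components
- **Registry mechanism**: decorators make components plug-and-play
- **Configuration-driven**: all parameters are managed through configuration classes
- **Stage awareness**: a Processor behaves differently in the training and test stages
- **Daily screening**: the stock pool is screened independently each day (market cap changes over time)
- **Model persistence**: trained models can be saved and loaded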
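
The registry principle above can be illustrated with a small decorator-based registry. This is only a sketch: the plan does not show the contents of `registry.py`, so `_MODEL_REGISTRY`, the module-level `get_model` function, and `DummyModel` are hypothetical names chosen for illustration.

```python
from typing import Callable, Dict

# Hypothetical stand-in for registry.py; the real module is not shown in this plan.
_MODEL_REGISTRY: Dict[str, type] = {}


def register_model(name: str) -> Callable[[type], type]:
    """Class decorator that records a model class under `name`."""
    def decorator(cls: type) -> type:
        if name in _MODEL_REGISTRY:
            raise ValueError(f"model '{name}' is already registered")
        _MODEL_REGISTRY[name] = cls
        return cls  # The class itself is returned unchanged
    return decorator


def get_model(name: str) -> type:
    """Look up a registered model class by name."""
    if name not in _MODEL_REGISTRY:
        raise KeyError(f"unknown model '{name}'")
    return _MODEL_REGISTRY[name]


@register_model("dummy")
class DummyModel:
    """Placeholder model used only to demonstrate registration."""

    def predict(self, rows):
        return [0.0 for _ in rows]


# A configuration string is enough to obtain the class.
model_cls = get_model("dummy")
print(model_cls().predict([1, 2, 3]))
```

With this pattern, adding a new model only requires decorating its class; code that builds models from a configuration value such as `model_type` does not need to change.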
## 2. Module Structure
```
src/
├── training/                          # Training module (workflow + components)
│   ├── __init__.py                    # Exports core classes
│   ├── core/                          # Training workflow core
│   │   ├── __init__.py
│   │   ├── trainer.py                 # Trainer main class
│   │   └── stock_pool_manager.py      # Stock pool manager (independent daily screening)
│   ├── components/                    # Modeling components
│   │   ├── __init__.py
│   │   ├── base.py                    # BaseModel, BaseProcessor abstract base classes
│   │   ├── splitters.py               # Time-series split strategy (one-off split)
│   │   ├── selectors.py               # Stock pool selector configuration
│   │   ├── models/                    # Model implementations
│   │   │   ├── __init__.py
│   │   │   ├── lightgbm.py            # LightGBM regression model
│   │   │   └── catboost.py            # CatBoost regression model
│   │   └── processors/                # Data processors
│   │       ├── __init__.py
│   │       └── transforms.py          # Standardization (cross-sectional / time-series), winsorization
│   ├── config/                        # Configuration management
│   │   ├── __init__.py
│   │   └── config.py                  # TrainingConfig (pydantic)
│   └── registry.py                    # Component registry
└── experiment/                        # Experiment management (placeholder, not implemented yet)
    └── __init__.py
```
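## 3. Core Component Design
### 3.1 Abstract Base Classes (components/base.py)
#### BaseModel
```python
import pickle
from abc import ABC, abstractmethod
from typing import Optional

import numpy as np
import pandas as pd
import polars as pl


class BaseModel(ABC):
    """Base class for models."""

    name: str = ""  # Model name

    @abstractmethod
    def fit(self, X: pl.DataFrame, y: pl.Series) -> "BaseModel":
        """Train the model."""
        raise NotImplementedError

    @abstractmethod
    def predict(self, X: pl.DataFrame) -> np.ndarray:
        """Predict."""
        raise NotImplementedError

    def feature_importance(self) -> Optional[pd.Series]:
        """Feature importance (None if the model does not provide it)."""
        return None

    def save(self, path: str) -> None:
        """Save the model to a file.

        The default implementation uses pickle; subclasses may override it.
        """
        with open(path, "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def load(cls, path: str) -> "BaseModel":
        """Load a model from a file (only unpickle files you trust)."""
        with open(path, "rb") as f:
            return pickle.load(f)
```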
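#### BaseProcessor
```python
from abc import ABC, abstractmethod

import polars as pl


class BaseProcessor(ABC):
    """Base class for data processors.

    Important: a Processor behaves differently in each stage:
    - Training stage: fit_transform (learn parameters and apply them)
    - Validation/test stage: transform (use the parameters learned in training)

    This means the Processor instance is kept after training and reused
    to transform the validation and test data.
    """

    name: str = ""

    def fit(self, X: pl.DataFrame) -> "BaseProcessor":
        """Learn parameters (called only in the training stage)."""
        return self

    @abstractmethod
    def transform(self, X: pl.DataFrame) -> pl.DataFrame:
        """Transform the data."""
        raise NotImplementedError

    def fit_transform(self, X: pl.DataFrame) -> pl.DataFrame:
        """Fit and transform (used in the training stage)."""
        return self.fit(X).transform(X)
```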
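
The stage-aware contract can be seen in a dependency-free toy example. The sketch below is not part of the planned code: `MeanCenterer` is a hypothetical processor operating on plain lists instead of Polars frames, used only to show that the test data is shifted by the mean learned on the training data.

```python
from typing import List, Optional


class MeanCenterer:
    """Hypothetical processor: subtracts the mean learned during fit."""

    def __init__(self) -> None:
        self.mean_: Optional[float] = None

    def fit(self, values: List[float]) -> "MeanCenterer":
        # Parameters are learned from the training data only.
        self.mean_ = sum(values) / len(values)
        return self

    def transform(self, values: List[float]) -> List[float]:
        if self.mean_ is None:
            raise RuntimeError("fit must be called before transform")
        return [v - self.mean_ for v in values]

    def fit_transform(self, values: List[float]) -> List[float]:
        return self.fit(values).transform(values)


train = [1.0, 2.0, 3.0]
test = [10.0, 11.0]

proc = MeanCenterer()
train_out = proc.fit_transform(train)  # learns mean = 2.0
test_out = proc.transform(test)        # reuses mean = 2.0, does not refit

print(train_out)  # [-1.0, 0.0, 1.0]
print(test_out)   # [8.0, 9.0]
```

Refitting on the test set would instead center it around its own mean (10.5), which leaks test-period information into the preprocessing step.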
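### 3.2 Time-Series Split (components/splitters.py)
**Design note**: rolling training is not implemented for now; a one-off train/test split is used.
#### DateSplitter
```python
from typing import Tuple

import polars as pl


class DateSplitter:
    """One-off split based on date ranges.

    Splits the data into a training set and a test set by date, without rolling.

    Example:
        train_start: "20200101", train_end: "20221231"  (3-year training set)
        test_start:  "20230101", test_end:  "20231231"  (1-year test set)

    Characteristics:
    - One-off split, no rolling
    - Training and test sets do not overlap
    - Based on actual date ranges, not row counts
    """

    def __init__(
        self,
        train_start: str,  # Training period start date, "YYYYMMDD"
        train_end: str,    # Training period end date, "YYYYMMDD"
        test_start: str,   # Test period start date, "YYYYMMDD"
        test_end: str,     # Test period end date, "YYYYMMDD"
    ):
        self.train_start = train_start
        self.train_end = train_end
        self.test_start = test_start
        self.test_end = test_end

    def split(self, data: pl.DataFrame) -> Tuple[pl.DataFrame, pl.DataFrame]:
        """Split the data into a training set and a test set.

        trade_date is compared as a "YYYYMMDD" string; for this fixed-width
        format, string order matches chronological order.
        """
        train_data = data.filter(
            (pl.col("trade_date") >= self.train_start)
            & (pl.col("trade_date") <= self.train_end)
        )
        test_data = data.filter(
            (pl.col("trade_date") >= self.test_start)
            & (pl.col("trade_date") <= self.test_end)
        )
        return train_data, test_data
```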
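### 3.3 Stock Pool Selector Configuration (components/selectors.py)
**Design note**: the stock pool is screened independently each day; market-cap selection must be used together with StockPoolManager.
#### StockFilterConfig
```python
from dataclasses import dataclass
from typing import List


@dataclass
class StockFilterConfig:
    """Stock filter configuration.

    Filters out unwanted stocks (ChiNext, STAR Market, etc.).
    Filtering is based on stock codes only and does not depend on external data.
    """

    exclude_cyb: bool = True  # Exclude ChiNext (300xxx)
    exclude_kcb: bool = True  # Exclude STAR Market (688xxx)
    exclude_bj: bool = True   # Exclude Beijing Stock Exchange (8xxxxx, 4xxxxx)
    exclude_st: bool = True   # Exclude ST stocks

    def filter_codes(self, codes: List[str]) -> List[str]:
        """Apply the filter conditions and return the remaining stock codes."""
        result = []
        for code in codes:
            if self.exclude_cyb and code.startswith("300"):
                continue
            if self.exclude_kcb and code.startswith("688"):
                continue
            if self.exclude_bj and code.startswith(("8", "4")):
                continue
            # ST filtering needs extra data and is handled in StockPoolManager
            result.append(code)
        return result
```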
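#### MarketCapSelectorConfig
```python
from dataclasses import dataclass


@dataclass
class MarketCapSelectorConfig:
    """Market-cap selector configuration.

    Each day, independently selects the n stocks with the largest or smallest
    market cap. Market-cap data is fetched separately from the daily_basic
    table and is used for screening only.
    """

    enabled: bool = True              # Whether selection is enabled
    n: int = 100                      # Select the top n stocks
    ascending: bool = False           # False = largest market cap, True = smallest
    market_cap_col: str = "total_mv"  # Market-cap column name (from daily_basic)
```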
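### 3.4 Stock Pool Manager (core/stock_pool_manager.py)
**Design note**: the stock pool is screened independently each day; market-cap data is fetched separately from the daily_basic table.
```python
from typing import Dict, List, Optional

import polars as pl

# StockFilterConfig and MarketCapSelectorConfig come from components/selectors.py.
# DataRouter is the project's data-access class; its import path is not shown in this plan.


class StockPoolManager:
    """Stock pool manager - independent daily screening.

    Key constraints:
    1. Market-cap data comes only from the daily_basic table and is used only for screening
    2. Market-cap data must never be mixed into the feature matrix
    3. Screening is done independently each day (market cap changes over time)

    Daily processing flow:
        All stocks for the day
        -> code filter (ChiNext, ST, etc.)
        -> query daily_basic for the day's market cap
        -> select the top N by market cap
        -> return the day's selected stock list
    """

    def __init__(
        self,
        filter_config: StockFilterConfig,
        selector_config: Optional[MarketCapSelectorConfig],
        data_router: "DataRouter",  # Used to fetch daily_basic data
        code_col: str = "ts_code",
        date_col: str = "trade_date",
    ):
        self.filter_config = filter_config
        self.selector_config = selector_config
        self.data_router = data_router
        self.code_col = code_col
        self.date_col = date_col

    def filter_and_select_daily(self, data: pl.DataFrame) -> pl.DataFrame:
        """Screen the stock pool independently for each day.

        Args:
            data: full-market data after factor computation; must contain
                the trade_date and ts_code columns

        Returns:
            The filtered data, containing only the stocks selected each day

        Notes:
            - Processing is grouped by date
            - Market-cap data is fetched separately from daily_basic
            - Market-cap data stays isolated from the feature data
        """
        dates = data.select(self.date_col).unique().sort(self.date_col)
        result_frames = []
        for date in dates.to_series():
            # Data for the current day
            daily_data = data.filter(pl.col(self.date_col) == date)
            daily_codes = daily_data.select(self.code_col).to_series().to_list()

            # 1. Code filter
            filtered_codes = self.filter_config.filter_codes(daily_codes)

            # 2. Market-cap selection (if enabled)
            if self.selector_config and self.selector_config.enabled:
                # Fetch the day's market caps from daily_basic
                market_caps = self._get_market_caps_for_date(filtered_codes, date)
                selected_codes = self._select_by_market_cap(filtered_codes, market_caps)
            else:
                selected_codes = filtered_codes

            # 3. Keep the rows of the stocks selected for the day
            daily_selected = daily_data.filter(
                pl.col(self.code_col).is_in(selected_codes)
            )
            result_frames.append(daily_selected)

        if not result_frames:
            # pl.concat fails on an empty list; return an empty frame with the same schema
            return data.clear()
        return pl.concat(result_frames)

    def _get_market_caps_for_date(
        self,
        codes: List[str],
        date: str,
    ) -> Dict[str, float]:
        """Fetch market caps for the given date from the daily_basic table.

        Args:
            codes: list of stock codes
            date: date in "YYYYMMDD" format

        Returns:
            A {stock code: market cap} dictionary
        """
        # Query the daily_basic table through data_router (left unspecified in this plan)
        raise NotImplementedError

    def _select_by_market_cap(
        self,
        codes: List[str],
        market_caps: Dict[str, float],
    ) -> List[str]:
        """Select stocks by market cap."""
        if not market_caps:
            return codes
        # Sort by market cap and take the top N.
        # Codes missing from market_caps are treated as 0.
        sorted_codes = sorted(
            codes,
            key=lambda c: market_caps.get(c, 0),
            reverse=not self.selector_config.ascending,
        )
        return sorted_codes[: self.selector_config.n]
```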
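
One detail of the market-cap selection deserves a closer look: treating a missing market cap as 0 means that, with `ascending=True`, stocks without data sort first and are picked as the "smallest" ones. The sketch below shows one possible alternative that drops such codes instead. It is an illustration under that assumption, not the plan's implementation; `select_by_market_cap` is a hypothetical free function and the market-cap values are made up.

```python
from typing import Dict, List


def select_by_market_cap(
    codes: List[str],
    market_caps: Dict[str, float],
    n: int,
    ascending: bool = False,
) -> List[str]:
    """Return up to n codes ranked by market cap, skipping codes with no data."""
    known = [c for c in codes if c in market_caps]
    ranked = sorted(known, key=lambda c: market_caps[c], reverse=not ascending)
    return ranked[:n]


# Made-up market caps; "600001.SH" intentionally has no entry.
caps = {"000001.SZ": 300.0, "600000.SH": 500.0, "000002.SZ": 100.0}
codes = ["000001.SZ", "600000.SH", "000002.SZ", "600001.SH"]

largest = select_by_market_cap(codes, caps, n=2)
smallest = select_by_market_cap(codes, caps, n=2, ascending=True)

print(largest)   # ['600000.SH', '000001.SZ']
print(smallest)  # ['000002.SZ', '000001.SZ']
```

Whether to drop or keep stocks without market-cap data is a design choice; the point is to make it explicit rather than let a default value decide the ranking.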
### 3.5 Model Implementations (components/models/)
#### LightGBMModel
```python
from typing import List, Optional

import numpy as np
import pandas as pd
import polars as pl

# BaseModel comes from components/base.py, register_model from registry.py.


@register_model("lightgbm")
class LightGBMModel(BaseModel):
    """LightGBM regression model."""

    name = "lightgbm"

    def __init__(
        self,
        objective: str = "regression",
        metric: str = "rmse",
        num_leaves: int = 31,
        learning_rate: float = 0.05,
        n_estimators: int = 100,
        **kwargs,
    ):
        # n_estimators is passed to lgb.train as num_boost_round, so it is
        # kept out of params to avoid a duplicate-parameter warning.
        self.n_estimators = n_estimators
        self.params = {
            "objective": objective,
            "metric": metric,
            "num_leaves": num_leaves,
            "learning_rate": learning_rate,
            **kwargs,
        }
        self.model = None
        self.feature_names_: List[str] = []

    def fit(self, X: pl.DataFrame, y: pl.Series) -> "LightGBMModel":
        """Train the model."""
        import lightgbm as lgb

        # Remember the feature names for feature_importance()
        self.feature_names_ = list(X.columns)
        # Convert to numpy and build the dataset
        train_data = lgb.Dataset(
            X.to_numpy(), label=y.to_numpy(), feature_name=self.feature_names_
        )
        # Train
        self.model = lgb.train(
            self.params, train_data, num_boost_round=self.n_estimators
        )
        return self

    def predict(self, X: pl.DataFrame) -> np.ndarray:
        """Predict."""
        if self.model is None:
            raise RuntimeError("Model not fitted yet")
        return self.model.predict(X.to_numpy())

    def feature_importance(self) -> Optional[pd.Series]:
        """Return feature importance (gain), or None if the model is not fitted."""
        if self.model is None:
            return None
        importance = self.model.feature_importance(importance_type="gain")
        return pd.Series(importance, index=self.feature_names_)

    def save(self, path: str) -> None:
        """Save the model (LightGBM native format)."""
        if self.model is None:
            raise RuntimeError("Model not fitted yet")
        self.model.save_model(path)

    @classmethod
    def load(cls, path: str) -> "LightGBMModel":
        """Load the model."""
        import lightgbm as lgb

        instance = cls()
        instance.model = lgb.Booster(model_file=path)
        # Restore the feature names stored in the model file
        instance.feature_names_ = instance.model.feature_name()
        return instance
```
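### 3.6 Data Processors (components/processors/)
#### StandardScaler
```python
from typing import Dict, List, Optional

import polars as pl

# BaseProcessor comes from components/base.py, register_processor from registry.py.


@register_processor("standard_scaler")
class StandardScaler(BaseProcessor):
    """Standardization processor (time-series standardization).

    Learns the mean and standard deviation on the whole training set,
    then applies them to both the training set and the test set.
    """

    name = "standard_scaler"

    def __init__(self, exclude_cols: Optional[List[str]] = None):
        self.exclude_cols = exclude_cols or ["ts_code", "trade_date"]
        self.mean_: Dict[str, float] = {}
        self.std_: Dict[str, float] = {}

    def fit(self, X: pl.DataFrame) -> "StandardScaler":
        """Compute the mean and standard deviation (on the training set only)."""
        numeric_cols = [
            c for c in X.columns
            if c not in self.exclude_cols and X[c].dtype.is_numeric()
        ]
        for col in numeric_cols:
            self.mean_[col] = X[col].mean()
            std = X[col].std()
            # Guard against division by zero for constant columns
            self.std_[col] = std if std else 1.0
        return self

    def transform(self, X: pl.DataFrame) -> pl.DataFrame:
        """Standardize (using the parameters learned on the training set)."""
        expressions = []
        for col in X.columns:
            if col in self.mean_:
                expr = ((pl.col(col) - self.mean_[col]) / self.std_[col]).alias(col)
                expressions.append(expr)
            else:
                expressions.append(pl.col(col))
        return X.select(expressions)
```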
#### CrossSectionalStandardScaler
```python
from typing import List, Optional

import polars as pl


@register_processor("cs_standard_scaler")
class CrossSectionalStandardScaler(BaseProcessor):
    """Cross-sectional standardization processor.

    Standardizes each day independently: for a given factor, all stocks
    on the same day are standardized together.

    Characteristics:
    - No fit is needed; each day uses its own mean and standard deviation
    - Suited to cross-sectional factors; removes differences such as market cap and industry
    - Formula: z = (x - mean_today) / std_today
    """

    name = "cs_standard_scaler"

    def __init__(
        self,
        exclude_cols: Optional[List[str]] = None,
        date_col: str = "trade_date",
    ):
        self.exclude_cols = exclude_cols or ["ts_code", "trade_date"]
        self.date_col = date_col

    def transform(self, X: pl.DataFrame) -> pl.DataFrame:
        """Cross-sectional standardization.

        Groups by date and standardizes each day with its own mean and
        standard deviation. No fit is needed because each day uses that
        day's statistics.
        """
        numeric_cols = [
            c for c in X.columns
            if c not in self.exclude_cols and X[c].dtype.is_numeric()
        ]
        # Per-date mean and std are computed as window expressions,
        # so no temporary columns are needed.
        return X.with_columns([
            (
                (pl.col(col) - pl.col(col).mean().over(self.date_col))
                / pl.col(col).std().over(self.date_col)
            ).alias(col)
            for col in numeric_cols
        ])
```
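
To make the formula z = (x - mean_today) / std_today concrete, here is a small worked example in plain Python. It is a sketch for illustration only: `cross_sectional_zscore` is a hypothetical helper, the rows are made-up values, and the sample standard deviation (`statistics.stdev`) is used, which needs at least two stocks per day and a non-constant cross-section.

```python
from collections import defaultdict
from statistics import mean, stdev
from typing import Dict, List, Tuple

Row = Tuple[str, str, float]  # (trade_date, ts_code, factor value)


def cross_sectional_zscore(rows: List[Row]) -> List[Row]:
    """Standardize the factor value within each trade_date."""
    by_date: Dict[str, List[float]] = defaultdict(list)
    for trade_date, _, value in rows:
        by_date[trade_date].append(value)
    # One (mean, std) pair per day, computed from that day's stocks only
    stats = {d: (mean(vals), stdev(vals)) for d, vals in by_date.items()}
    return [(d, code, (v - stats[d][0]) / stats[d][1]) for d, code, v in rows]


rows = [
    ("20230103", "A", 1.0),
    ("20230103", "B", 2.0),
    ("20230103", "C", 3.0),
    ("20230104", "A", 10.0),
    ("20230104", "B", 20.0),
]

for row in cross_sectional_zscore(rows):
    print(row)
```

On 20230103 the mean is 2 and the sample standard deviation is 1, so the values become -1, 0 and 1. On 20230104 the mean is 15 and the standard deviation is √50 ≈ 7.07, giving roughly -0.707 and 0.707. The very different raw scales of the two days disappear after standardization.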
#### Winsorizer
```python
from typing import Dict

import polars as pl


@register_processor("winsorizer")
class Winsorizer(BaseProcessor):
    """Winsorization processor.

    Clips the extreme values of each column.
    Clipping can be global (quantiles learned on the whole training set)
    or cross-sectional (each day processed independently).
    """

    name = "winsorizer"

    def __init__(
        self,
        lower: float = 0.01,
        upper: float = 0.99,
        by_date: bool = False,  # True = winsorize each day independently, False = global
        date_col: str = "trade_date",
    ):
        self.lower = lower
        self.upper = upper
        self.by_date = by_date
        self.date_col = date_col
        self.bounds_: Dict[str, Dict[str, float]] = {}  # Quantile bounds (global mode)

    def fit(self, X: pl.DataFrame) -> "Winsorizer":
        """Learn the quantile bounds (global mode only)."""
        if not self.by_date:
            numeric_cols = [c for c in X.columns if X[c].dtype.is_numeric()]
            for col in numeric_cols:
                self.bounds_[col] = {
                    "lower": X[col].quantile(self.lower),
                    "upper": X[col].quantile(self.upper),
                }
        return self

    def transform(self, X: pl.DataFrame) -> pl.DataFrame:
        """Winsorize."""
        if self.by_date:
            # Winsorize each day independently
            return self._transform_by_date(X)
        # Global winsorization
        return self._transform_global(X)

    def _transform_global(self, X: pl.DataFrame) -> pl.DataFrame:
        """Global winsorization (using the bounds learned on the training set)."""
        expressions = []
        for col in X.columns:
            if col in self.bounds_:
                lower = self.bounds_[col]["lower"]
                upper = self.bounds_[col]["upper"]
                expressions.append(pl.col(col).clip(lower, upper).alias(col))
            else:
                expressions.append(pl.col(col))
        return X.select(expressions)

    def _transform_by_date(self, X: pl.DataFrame) -> pl.DataFrame:
        """Winsorize each day independently."""
        # Group by date, compute the quantiles, and clip.
        # Polars implementation...
        raise NotImplementedError
```
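
The per-day mode is left open in the plan. The logic it describes (group by date, compute that day's quantiles, clip) can be sketched in plain Python as follows. This is an illustration, not the Polars implementation: `winsorize_by_date` and `nearest_rank_quantile` are hypothetical helpers, the rows are made-up values, and the nearest-rank quantile rule is an assumption that may differ from the interpolation a Polars version would use.

```python
import math
from collections import defaultdict
from typing import Dict, List, Tuple

Row = Tuple[str, str, float]  # (trade_date, ts_code, factor value)


def nearest_rank_quantile(sorted_vals: List[float], q: float) -> float:
    """Quantile by nearest rank on already-sorted values (illustrative rule)."""
    idx = math.floor(q * (len(sorted_vals) - 1) + 0.5)
    return sorted_vals[idx]


def winsorize_by_date(rows: List[Row], lower: float, upper: float) -> List[Row]:
    """Clip each value to the [lower, upper] quantile range of its own day."""
    by_date: Dict[str, List[float]] = defaultdict(list)
    for trade_date, _, value in rows:
        by_date[trade_date].append(value)
    bounds = {}
    for trade_date, vals in by_date.items():
        ordered = sorted(vals)
        bounds[trade_date] = (
            nearest_rank_quantile(ordered, lower),
            nearest_rank_quantile(ordered, upper),
        )
    return [
        (d, code, min(max(v, bounds[d][0]), bounds[d][1]))
        for d, code, v in rows
    ]


rows = [
    ("20230103", "A", 1.0), ("20230103", "B", 2.0), ("20230103", "C", 3.0),
    ("20230103", "D", 4.0), ("20230103", "E", 100.0),
    ("20230104", "A", -50.0), ("20230104", "B", 10.0), ("20230104", "C", 20.0),
    ("20230104", "D", 30.0), ("20230104", "E", 40.0),
]

clipped = winsorize_by_date(rows, lower=0.25, upper=0.75)
print([v for _, _, v in clipped])
# [2.0, 2.0, 3.0, 4.0, 4.0, 10.0, 10.0, 20.0, 30.0, 30.0]
```

Each day gets its own bounds: 100 is clipped to 4 on the first day, while 40 is clipped to 30 on the second. Because nothing is learned across days, this mode needs no `fit`, which matches the stage table in section 7.2.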
## 4. Training Workflow Design
### 4.1 Trainer Main Class (core/trainer.py)
```python
from typing import List, Optional

import polars as pl

# BaseModel, BaseProcessor, DateSplitter and StockPoolManager are defined in section 3.


class Trainer:
    """Main trainer class.

    Integrates data processing, model training, and prediction into one workflow.

    Key design points:
    1. Factors are computed first (full market), then the stock pool is screened (daily)
    2. Processors are stage-aware: fit_transform on the training set, transform on the test set
    3. One-off training, no rolling
    4. Model persistence is supported
    """

    def __init__(
        self,
        model: BaseModel,
        pool_manager: Optional[StockPoolManager] = None,
        processors: Optional[List[BaseProcessor]] = None,
        splitter: Optional[DateSplitter] = None,
        target_col: str = "target",
        feature_cols: Optional[List[str]] = None,
        persist_model: bool = False,
        model_save_path: Optional[str] = None,
    ):
        self.model = model
        self.pool_manager = pool_manager
        self.processors = processors or []
        self.splitter = splitter
        self.target_col = target_col
        self.feature_cols = feature_cols
        self.persist_model = persist_model
        self.model_save_path = model_save_path
        # Processors after fitting on the training set
        self.fitted_processors: List[BaseProcessor] = []
        self.results: Optional[pl.DataFrame] = None

    def train(self, data: pl.DataFrame) -> "Trainer":
        """Run the training workflow.

        Steps:
        1. Daily stock pool screening (if pool_manager is configured)
        2. Split into training and test sets by date
        3. Training set: processors fit_transform
        4. Train the model
        5. Test set: processors transform (using parameters learned on the training set)
        6. Predict and store the results
        7. Persist the model (if enabled)

        Args:
            data: full-market data after factor computation;
                must contain the ts_code and trade_date columns

        Returns:
            self (supports method chaining)
        """
        # Start from a clean state so repeated calls do not accumulate processors
        self.fitted_processors = []

        # 1. Stock pool screening (independent for each day)
        if self.pool_manager:
            print("[filter] Screening the stock pool day by day...")
            data = self.pool_manager.filter_and_select_daily(data)

        # 2. Train/test split
        print("[split] Splitting into training and test sets...")
        train_data, test_data = self.splitter.split(data)

        # 3. Training set: processors fit_transform
        print("[process] Processing the training set...")
        for processor in self.processors:
            train_data = processor.fit_transform(train_data)
            self.fitted_processors.append(processor)

        # 4. Train the model
        print("[train] Training the model...")
        X_train = train_data.select(self.feature_cols)
        y_train = train_data.select(self.target_col).to_series()
        self.model.fit(X_train, y_train)

        # 5. Test set: processors transform
        print("[process] Processing the test set...")
        for processor in self.fitted_processors:
            test_data = processor.transform(test_data)

        # 6. Predict
        print("[predict] Generating predictions...")
        X_test = test_data.select(self.feature_cols)
        predictions = self.model.predict(X_test)

        # 7. Store the results
        self.results = test_data.with_columns([
            pl.Series("prediction", predictions)
        ])

        # 8. Persist the model
        if self.persist_model and self.model_save_path:
            print(f"[save] Saving the model to {self.model_save_path}...")
            self.save_model(self.model_save_path)
        return self

    def predict(self, data: pl.DataFrame) -> pl.DataFrame:
        """Predict on new data.

        Note: the new data must go through stock pool screening first; the
        fitted processors and the trained model are then used for prediction.
        """
        # Apply the processors
        for processor in self.fitted_processors:
            data = processor.transform(data)
        # Predict
        X = data.select(self.feature_cols)
        predictions = self.model.predict(X)
        return data.with_columns([pl.Series("prediction", predictions)])

    def get_results(self) -> Optional[pl.DataFrame]:
        """Return all prediction results (None before train() is called)."""
        return self.results

    def save_results(self, path: str) -> None:
        """Save the prediction results to a file."""
        if self.results is not None:
            self.results.write_csv(path)

    def save_model(self, path: str) -> None:
        """Save the model."""
        self.model.save(path)
```
### 4.2 Configuration Classes (config/config.py)
```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional

# pydantic v1 import path; in pydantic v2, BaseSettings lives in the
# separate pydantic-settings package.
from pydantic import BaseSettings, Field


@dataclass
class StockFilterConfig:
    """Stock filter configuration."""

    exclude_cyb: bool = True  # Exclude ChiNext
    exclude_kcb: bool = True  # Exclude STAR Market
    exclude_bj: bool = True   # Exclude Beijing Stock Exchange
    exclude_st: bool = True   # Exclude ST stocks


@dataclass
class MarketCapSelectorConfig:
    """Market-cap selector configuration."""

    enabled: bool = True              # Whether selection is enabled
    n: int = 100                      # Select the top n stocks
    ascending: bool = False           # False = largest market cap, True = smallest
    market_cap_col: str = "total_mv"  # Market-cap column name (from daily_basic)


@dataclass
class ProcessorConfig:
    """Processor configuration."""

    name: str
    # Standard-library dataclass, so dataclasses.field is used here (not pydantic.Field)
    params: Dict = field(default_factory=dict)


class TrainingConfig(BaseSettings):
    """Training configuration."""

    # === Data configuration (required) ===
    # Feature column names, at least one (min_items is the pydantic v1 name; v2 uses min_length)
    feature_cols: List[str] = Field(..., min_items=1)
    target_col: str = "target"    # Target column name
    date_col: str = "trade_date"  # Date column name
    code_col: str = "ts_code"     # Stock code column name

    # === Date split (required) ===
    train_start: str = Field(..., description="Training period start, YYYYMMDD")
    train_end: str = Field(..., description="Training period end, YYYYMMDD")
    test_start: str = Field(..., description="Test period start, YYYYMMDD")
    test_end: str = Field(..., description="Test period end, YYYYMMDD")

    # === Stock pool configuration ===
    stock_filter: StockFilterConfig = Field(
        default_factory=lambda: StockFilterConfig(
            exclude_cyb=True,
            exclude_kcb=True,
            exclude_bj=True,
            exclude_st=True,
        )
    )
    # Note: if stock_selector is None, market-cap selection is skipped
    stock_selector: Optional[MarketCapSelectorConfig] = Field(
        default_factory=lambda: MarketCapSelectorConfig(
            enabled=True,
            n=100,
            ascending=False,
            market_cap_col="total_mv",
        )
    )

    # === Model configuration ===
    model_type: str = "lightgbm"
    model_params: Dict = Field(default_factory=dict)

    # === Processor configuration ===
    processors: List[ProcessorConfig] = Field(
        default_factory=lambda: [
            ProcessorConfig(name="winsorizer", params={"lower": 0.01, "upper": 0.99}),
            ProcessorConfig(name="cs_standard_scaler", params={}),
        ]
    )

    # === Persistence configuration ===
    persist_model: bool = False            # Not persisted by default
    model_save_path: Optional[str] = None  # Path used when persisting

    # === Output configuration ===
    output_dir: str = "output"
    save_predictions: bool = True
```
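
The configuration above declares the four split dates as required strings but does not show how they are checked. A minimal, dependency-free sketch of such a check follows. It is an assumption about what date validation could look like, not the project's validator: `parse_yyyymmdd` and `SplitDates` are hypothetical names, and a pydantic validator could apply the same rules.

```python
from dataclasses import dataclass
from datetime import date, datetime


def parse_yyyymmdd(value: str) -> date:
    """Parse a strict 8-digit YYYYMMDD string; raise ValueError otherwise."""
    if len(value) != 8 or not (value.isascii() and value.isdigit()):
        raise ValueError(f"expected YYYYMMDD, got {value!r}")
    return datetime.strptime(value, "%Y%m%d").date()


@dataclass(frozen=True)
class SplitDates:
    """Hypothetical holder for the four split boundaries, validated on creation."""

    train_start: str
    train_end: str
    test_start: str
    test_end: str

    def __post_init__(self) -> None:
        train_start, train_end, test_start, test_end = (
            parse_yyyymmdd(v)
            for v in (self.train_start, self.train_end, self.test_start, self.test_end)
        )
        if train_start > train_end:
            raise ValueError("train_start must not be after train_end")
        if test_start > test_end:
            raise ValueError("test_start must not be after test_end")
        if train_end >= test_start:
            raise ValueError("the training window must end before the test window starts")


# Valid: three training years followed by one test year
ok = SplitDates("20200101", "20221231", "20230101", "20231231")

# Invalid: the training window runs into the test window
try:
    SplitDates("20200101", "20230630", "20230101", "20231231")
except ValueError as exc:
    print(f"rejected: {exc}")
```

Rejecting overlapping windows at configuration time keeps the "training and test sets do not overlap" guarantee of `DateSplitter` from depending on the caller passing sensible dates.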
## 5. Usage Examples
### 5.1 Basic Usage
```python
from src.training import Trainer, LightGBMModel, DateSplitter
from src.training.config import StockFilterConfig, MarketCapSelectorConfig
from src.training.core.stock_pool_manager import StockPoolManager
from src.training.components.processors import Winsorizer, CrossSectionalStandardScaler
from src.factors import FactorEngine
import polars as pl

# 1. Factor computation (full-market data)
engine = FactorEngine()
all_data = engine.compute(
    factor_names=["factor1", "factor2", "factor3", "future_return_5d"],
    start_date="20200101",
    end_date="20231231",
    # stock_codes is not specified, so full-market data is returned
)

# 2. Create the stock pool manager (independent daily screening)
# data_router is a DataRouter instance created elsewhere (not shown in this plan)
pool_manager = StockPoolManager(
    filter_config=StockFilterConfig(
        exclude_cyb=True,
        exclude_kcb=True,
        exclude_bj=True,
        exclude_st=True,
    ),
    selector_config=MarketCapSelectorConfig(
        enabled=True,
        n=100,
        ascending=False,  # Largest market cap
    ),
    data_router=data_router,  # Used to fetch daily_basic data
)

# 3. Create the model
model = LightGBMModel(
    objective="regression",
    n_estimators=100,
    learning_rate=0.05,
)

# 4. Create the processors
processors = [
    Winsorizer(lower=0.01, upper=0.99, by_date=False),  # Global winsorization
    CrossSectionalStandardScaler(),                      # Cross-sectional standardization (per day)
]

# 5. Create the splitter
splitter = DateSplitter(
    train_start="20200101",
    train_end="20221231",
    test_start="20230101",
    test_end="20231231",
)

# 6. Create the trainer
trainer = Trainer(
    model=model,
    pool_manager=pool_manager,  # Daily stock pool screening
    processors=processors,
    splitter=splitter,
    target_col="future_return_5d",
    feature_cols=["factor1", "factor2", "factor3"],
    persist_model=True,  # Enable persistence
    model_save_path="output/model.pkl",
)

# 7. Run training (pass in the full-market data)
trainer.train(all_data)

# 8. Get the results
results = trainer.get_results()  # Includes the predictions

# 9. Save the results
trainer.save_results("output/predictions.csv")

# 10. Load the model and predict on new data
# new_data is not defined in this plan; it must hold the same feature columns,
# already screened and processed the same way as the training data.
loaded_model = LightGBMModel.load("output/model.pkl")
new_predictions = loaded_model.predict(new_data)
```
### 5.2 Configuration-Driven Usage
```python
from src.training.config import TrainingConfig, StockFilterConfig
from src.training import Trainer, DateSplitter
from src.training.registry import ModelRegistry, ProcessorRegistry
from src.training.core.stock_pool_manager import StockPoolManager
from src.factors import FactorEngine

# 1. Configuration (required fields are validated)
config = TrainingConfig(
    feature_cols=["factor1", "factor2", "factor3"],
    target_col="future_return_5d",
    train_start="20200101",
    train_end="20221231",
    test_start="20230101",
    test_end="20231231",
    stock_filter=StockFilterConfig(
        exclude_cyb=True,
        exclude_kcb=True,
    ),
    stock_selector=None,  # Skip market-cap selection
    persist_model=False,  # Do not persist
)

# 2. Factor computation
engine = FactorEngine()
all_data = engine.compute(
    factor_names=config.feature_cols + [config.target_col],
    start_date=config.train_start,
    end_date=config.test_end,
)

# 3. Create the components from the configuration
model = ModelRegistry.get_model(config.model_type)(**config.model_params)
processors = [
    ProcessorRegistry.get_processor(p.name)(**p.params)
    for p in config.processors
]

# 4. Create the stock pool manager
# data_router is a DataRouter instance created elsewhere (not shown in this plan)
pool_manager = StockPoolManager(
    filter_config=config.stock_filter,
    selector_config=config.stock_selector,
    data_router=data_router,
)

# 5. Create and run the trainer
trainer = Trainer(
    model=model,
    pool_manager=pool_manager,
    processors=processors,
    splitter=DateSplitter(
        train_start=config.train_start,
        train_end=config.train_end,
        test_start=config.test_start,
        test_end=config.test_end,
    ),
    target_col=config.target_col,
    feature_cols=config.feature_cols,
)
trainer.train(all_data)
results = trainer.get_results()
```
## 6. Implementation Order
Implement and commit in the following order:
### Commit 1: Base Architecture
- `training/__init__.py`
- `training/components/__init__.py`
- `training/components/base.py` (BaseModel, BaseProcessor, including save/load)
- `training/registry.py` (component registry)
### Commit 2: Data Split
- `training/components/splitters.py` (DateSplitter, one-off split only)
### Commit 3: Stock Pool Selector Configuration
- `training/components/selectors.py` (StockFilterConfig, MarketCapSelectorConfig)
### Commit 4: Data Processors
- `training/components/processors/__init__.py`
- `training/components/processors/transforms.py`
  - Winsorizer
  - StandardScaler
  - CrossSectionalStandardScaler
### Commit 5: LightGBM Model
- `training/components/models/__init__.py`
- `training/components/models/lightgbm.py` (including save_model/load_model)
### Commit 6: Stock Pool Manager
- `training/core/__init__.py`
- `training/core/stock_pool_manager.py` (independent daily screening)
### Commit 7: Trainer
- `training/core/trainer.py`
### Commit 8: Configuration Management
- `training/config/__init__.py`
- `training/config/config.py` (TrainingConfig, including required-field validation)
### Commit 9: Placeholder Experiment Module
- `experiment/__init__.py`
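## 7. Notes
### 7.1 Stock Pool Processing Order (Daily)
```
All stock data for the day
  -> code filter (ChiNext, ST, etc.)
  -> query daily_basic for the day's market cap
  -> select the top N by market cap
  -> return the day's selected stocks
```
- Screening is done independently each day because market cap changes over time
- Market-cap data comes only from daily_basic
- Market-cap data must never be mixed into the feature matrix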
### 7.2 Processor Behavior by Stage
| Processor | Training set | Test set |
|-----------|--------------|----------|
| StandardScaler | fit_transform | transform (uses training-set parameters) |
| CrossSectionalStandardScaler | transform | transform (each day independently) |
| Winsorizer (global) | fit_transform | transform (uses training-set parameters) |
| Winsorizer (by_date) | transform | transform (each day independently) |
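### 7.3 Dependencies
- Polars for data processing
- LightGBM for model training
- Pydantic for configuration management
- Market-cap data from the daily_basic table (a separate data source)
### 7.4 Removed Features
The following originally planned features are dropped from this implementation:
1. **Feature selection** (processors/selectors.py)
2. **Rolling training** (WalkForward, ExpandingWindow)
3. **Result analysis tools** (complex analysis features)
4. **validator.py, evaluator.py** (removed; metrics are not implemented)
### 7.5 Added Features
1. **StockPoolManager** (independent daily screening)
2. **Model persistence** (save/load, off by default)
3. **Required-field validation in configuration** (feature_cols, date range)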