2026-03-03 22:04:22 +08:00
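
# Training Module Implementation Plan

## 1. Overview

This plan describes the full implementation of the ProStock training module. The module trains standard regression models and outputs their predictions.

### 1.1 Goals

- Provide a simple training workflow
- Support standard regression models (LightGBM, CatBoost)
- Output interpretable prediction results
- Integrate seamlessly with the existing factor engine

### 1.2 Design Principles

- **Separation of concerns**: the training workflow is kept separate from the modeling components
- **Registration mechanism**: components are registered with decorators, so they are plug-and-play
- **Configuration-driven**: all parameters are managed through configuration classes
- **Stage awareness**: a Processor behaves differently in the training stage and the test stage
- **Daily selection**: the stock pool is selected independently for each trading day, because market caps change over time
- **Model persistence**: trained models can be saved and loaded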

## 2. Module Structure

```
src/
├── training/                        # Training module (workflow + components)
│   ├── __init__.py                  # Exports the core classes
│   ├── core/                        # Training workflow core
│   │   ├── __init__.py
│   │   ├── trainer.py               # Trainer main class
│   │   └── stock_pool_manager.py    # Stock pool manager (independent daily selection)
│   ├── components/                  # Modeling components
│   │   ├── __init__.py
│   │   ├── base.py                  # BaseModel, BaseProcessor abstract base classes
│   │   ├── splitters.py             # Time-series split strategy (one-time split)
│   │   ├── selectors.py             # Stock pool selector configs
│   │   ├── models/                  # Model implementations
│   │   │   ├── __init__.py
│   │   │   ├── lightgbm.py          # LightGBM regression model
│   │   │   └── catboost.py          # CatBoost regression model
│   │   └── processors/              # Data processors
│   │       ├── __init__.py
│   │       └── transforms.py        # Standardization (cross-sectional/global), winsorization
│   ├── config/                      # Configuration management
│   │   ├── __init__.py
│   │   └── config.py                # TrainingConfig (pydantic)
│   └── registry.py                  # Component registry
│
└── experiment/                      # Experiment management (placeholder, not implemented yet)
    └── __init__.py
```

## 3. Core Component Design

### 3.1 Abstract Base Classes (components/base.py)

#### BaseModel

```python
import pickle
from abc import ABC, abstractmethod
from typing import Optional

import numpy as np
import pandas as pd
import polars as pl


class BaseModel(ABC):
    """Base class for models."""

    name: str = ""  # Model name (used as the registry key)

    @abstractmethod
    def fit(self, X: pl.DataFrame, y: pl.Series) -> "BaseModel":
        """Train the model."""
        raise NotImplementedError

    @abstractmethod
    def predict(self, X: pl.DataFrame) -> np.ndarray:
        """Predict."""
        raise NotImplementedError

    def feature_importance(self) -> Optional[pd.Series]:
        """Feature importance; None if the model does not provide it."""
        return None

    def save(self, path: str) -> None:
        """Save the model to a file.

        The default implementation uses pickle; subclasses may override it,
        e.g. to use a library's native format.
        """
        with open(path, "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def load(cls, path: str) -> "BaseModel":
        """Load a model from a file.

        Only load trusted files: unpickling can execute arbitrary code.
        """
        with open(path, "rb") as f:
            return pickle.load(f)
```

#### BaseProcessor

```python
from abc import ABC, abstractmethod

import polars as pl


class BaseProcessor(ABC):
    """Base class for data processors.

    Important: a Processor behaves differently depending on the stage:
    - Training stage: fit_transform (learn the parameters, then apply them)
    - Validation/test stage: transform (apply the parameters learned in training)

    This means a Processor instance is kept after training and reused
    to transform the validation and test data.
    """

    name: str = ""

    def fit(self, X: pl.DataFrame) -> "BaseProcessor":
        """Learn parameters (called only in the training stage).

        Stateless processors can keep this no-op default.
        """
        return self

    @abstractmethod
    def transform(self, X: pl.DataFrame) -> pl.DataFrame:
        """Transform the data."""
        raise NotImplementedError

    def fit_transform(self, X: pl.DataFrame) -> pl.DataFrame:
        """Fit, then transform (used in the training stage)."""
        return self.fit(X).transform(X)
```
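
The fit/transform contract above keeps test-set statistics out of training. The following stdlib-only toy illustrates it, using lists of dicts instead of Polars frames. `ToyZScore` is hypothetical. It uses the population standard deviation, whereas Polars' `std()` defaults to the sample standard deviation. The parameters are learned once on the training rows and reused unchanged on the test rows.

```python
from statistics import mean, pstdev
from typing import Dict, List

Rows = List[Dict[str, float]]


class ToyZScore:
    """Stand-in for a stage-aware processor: learns mean/std in fit, reuses them in transform."""

    def __init__(self, cols: List[str]) -> None:
        self.cols = cols
        self.mean_: Dict[str, float] = {}
        self.std_: Dict[str, float] = {}

    def fit(self, rows: Rows) -> "ToyZScore":
        for c in self.cols:
            values = [r[c] for r in rows]
            self.mean_[c] = mean(values)
            self.std_[c] = pstdev(values) or 1.0  # population std; guard zero variance
        return self

    def transform(self, rows: Rows) -> Rows:
        if not self.mean_:
            raise RuntimeError("fit() must be called before transform()")
        return [{**r, **{c: (r[c] - self.mean_[c]) / self.std_[c] for c in self.cols}} for r in rows]

    def fit_transform(self, rows: Rows) -> Rows:
        return self.fit(rows).transform(rows)


train = [{"f": 1.0}, {"f": 3.0}]
test = [{"f": 5.0}]

proc = ToyZScore(["f"])
train_out = proc.fit_transform(train)  # learns mean=2.0, std=1.0 from the training rows
test_out = proc.transform(test)        # reuses the training statistics; test rows are never fitted
```

If `fit` were called on the test rows instead, their own mean (5.0) would be used. The single test value would then map to 0, hiding the fact that it lies far outside the training distribution.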

### 3.2 Time-Series Split (components/splitters.py)

**Design note**: rolling training is not implemented for now. A one-time train/test split is used instead.

#### DateSplitter

```python
from typing import Tuple

import polars as pl


class DateSplitter:
    """One-time split by date range.

    Splits the data by date into a training set and a test set, without rolling.

    Example:
        train_start: "20200101", train_end: "20221231"  (training set: 3 years)
        test_start:  "20230101", test_end:  "20231231"  (test set: 1 year)

    Properties:
    - One-time split, no rolling
    - The training and test date ranges do not overlap (checked in __init__)
    - Based on actual date ranges, not row counts

    Dates are "YYYYMMDD" strings. For such strings, lexicographic order equals
    chronological order, so plain string comparison is used. Note that a
    forward-looking target (e.g. future_return_5d) on the last training days is
    computed from prices inside the test period. Leave a gap between train_end
    and test_start if that leakage matters.
    """

    def __init__(
        self,
        train_start: str,  # Training period start date "YYYYMMDD"
        train_end: str,    # Training period end date "YYYYMMDD"
        test_start: str,   # Test period start date "YYYYMMDD"
        test_end: str,     # Test period end date "YYYYMMDD"
        date_col: str = "trade_date",
    ):
        if not (train_start <= train_end < test_start <= test_end):
            raise ValueError("Expected train_start <= train_end < test_start <= test_end")
        self.train_start = train_start
        self.train_end = train_end
        self.test_start = test_start
        self.test_end = test_end
        self.date_col = date_col

    def split(self, data: pl.DataFrame) -> Tuple[pl.DataFrame, pl.DataFrame]:
        """Split the data into a training set and a test set (both bounds inclusive)."""
        date = pl.col(self.date_col)
        train_data = data.filter((date >= self.train_start) & (date <= self.train_end))
        test_data = data.filter((date >= self.test_start) & (date <= self.test_end))
        return train_data, test_data
```
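
`DateSplitter` compares "YYYYMMDD" strings directly. This ordering is chronological only if every value is exactly eight digits. The following stdlib sketch shows stricter up-front checks. `validate_split` and `min_gap_days` are hypothetical, and the gap is counted in calendar days, not trading days. A gap can matter because a forward-looking target such as `future_return_5d` on the last training days is computed from prices inside the test period.

```python
from datetime import date, datetime


def _parse_yyyymmdd(value: str) -> date:
    # Require exactly 8 digits: strptime alone accepts single-digit month/day fields
    if len(value) != 8 or not value.isdigit():
        raise ValueError(f"expected YYYYMMDD, got {value!r}")
    return datetime.strptime(value, "%Y%m%d").date()  # also rejects e.g. "20231301"


def validate_split(
    train_start: str,
    train_end: str,
    test_start: str,
    test_end: str,
    min_gap_days: int = 0,
) -> int:
    """Check format and order; return the calendar days strictly between train_end and test_start."""
    ts, te, vs, ve = (_parse_yyyymmdd(d) for d in (train_start, train_end, test_start, test_end))
    if not (ts <= te < vs <= ve):
        raise ValueError("expected train_start <= train_end < test_start <= test_end")
    gap = (vs - te).days - 1
    if gap < min_gap_days:
        raise ValueError(f"gap of {gap} calendar days is below min_gap_days={min_gap_days}")
    return gap


# The split from the example above has no gap at all
print(validate_split("20200101", "20221231", "20230101", "20231231"))
```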

### 3.3 Stock Pool Selector Configuration (components/selectors.py)

**Design note**: the stock pool is selected independently for each day. Market cap selection must be used together with StockPoolManager.

#### StockFilterConfig

```python
from dataclasses import dataclass
from typing import List


@dataclass
class StockFilterConfig:
    """Stock filter configuration.

    Filters out unwanted stocks (e.g. ChiNext, STAR Market) using only the
    stock code, with no external data. Codes are assumed to be in ts_code
    format, e.g. "000001.SZ".
    """

    exclude_cyb: bool = True  # Exclude ChiNext (300xxx, 301xxx)
    exclude_kcb: bool = True  # Exclude STAR Market (688xxx, 689xxx)
    exclude_bj: bool = True   # Exclude Beijing Stock Exchange (.BJ suffix; 4xxxxx, 8xxxxx)
    exclude_st: bool = True   # Exclude ST stocks (needs extra data, see below)

    def filter_codes(self, codes: List[str]) -> List[str]:
        """Apply the code-based filters and return the remaining codes."""
        result = []
        for code in codes:
            if self.exclude_cyb and code.startswith(("300", "301")):
                continue
            if self.exclude_kcb and code.startswith(("688", "689")):
                continue
            if self.exclude_bj and (code.endswith(".BJ") or code.startswith(("4", "8"))):
                continue
            # ST status cannot be derived from the code: it needs per-date
            # status data and is handled in StockPoolManager
            result.append(code)
        return result
```

#### MarketCapSelectorConfig

```python
from dataclasses import dataclass


@dataclass
class MarketCapSelectorConfig:
    """Market cap selector configuration.

    Each day, independently selects the n stocks with the largest (or smallest)
    market cap. Market caps are fetched separately from the daily_basic table
    and are used only for selection.
    """

    enabled: bool = True              # Whether selection is enabled
    n: int = 100                      # Number of stocks to select per day
    ascending: bool = False           # False = largest caps, True = smallest caps
    market_cap_col: str = "total_mv"  # Market cap column (from daily_basic)
```
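
Because selection is re-run every day, a stock can enter or leave the pool from one date to the next. The following stdlib sketch performs per-day top-n selection on toy codes and market caps. `select_daily_top_n` and the data are hypothetical. Stocks with no market cap on a given day are dropped rather than ranked.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Optional, Tuple

Record = Tuple[str, str, Optional[float]]  # (trade_date, ts_code, market_cap)


def select_daily_top_n(
    records: Iterable[Record], n: int, ascending: bool = False
) -> Dict[str, List[str]]:
    """Return {trade_date: selected codes}, ranking each day on its own."""
    by_date: Dict[str, List[Tuple[float, str]]] = defaultdict(list)
    for trade_date, code, mv in records:
        if mv is not None:  # no market cap that day: not eligible
            by_date[trade_date].append((mv, code))
    selected: Dict[str, List[str]] = {}
    for trade_date in sorted(by_date):
        ranked = sorted(by_date[trade_date], key=lambda item: item[0], reverse=not ascending)
        selected[trade_date] = [code for _, code in ranked[:n]]
    return selected


records = [
    ("20230103", "AAA", 2000.0), ("20230103", "BBB", 250.0), ("20230103", "CCC", 800.0),
    ("20230104", "AAA", 2100.0), ("20230104", "BBB", 900.0), ("20230104", "CCC", 850.0),
    ("20230104", "DDD", None),
]
largest = select_daily_top_n(records, n=2)   # BBB enters the pool on 20230104
smallest = select_daily_top_n(records, n=1, ascending=True)
```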

### 3.4 Stock Pool Manager (core/stock_pool_manager.py)

**Design note**: the stock pool is selected independently for each day. Market cap data is fetched separately from the daily_basic table.

```python
from typing import Dict, List, Optional, Set

import polars as pl

from ..components.selectors import MarketCapSelectorConfig, StockFilterConfig

# DataRouter: the project's existing data access class (import path not specified in this plan)


class StockPoolManager:
    """Stock pool manager: independent daily selection.

    Key constraints:
    1. Market cap data comes only from the daily_basic table and is used only for selection
    2. Market cap data is never merged into the feature matrix
    3. Selection is done independently for each day (market caps change over time)

    Daily processing flow:
        All stocks on the day
            ↓
        Code/ST filtering (ChiNext, ST, etc.)
            ↓
        Query daily_basic for that day's market caps
            ↓
        Market cap selection (top N)
            ↓
        Return that day's selected stocks
    """

    def __init__(
        self,
        filter_config: StockFilterConfig,
        selector_config: Optional[MarketCapSelectorConfig],
        data_router: "DataRouter",  # Used to query daily_basic
        code_col: str = "ts_code",
        date_col: str = "trade_date",
    ):
        self.filter_config = filter_config
        self.selector_config = selector_config
        self.data_router = data_router
        self.code_col = code_col
        self.date_col = date_col

    def filter_and_select_daily(self, data: pl.DataFrame) -> pl.DataFrame:
        """Select the stock pool independently for each day.

        Args:
            data: Whole-market data after factor computation; must contain
                the trade_date and ts_code columns

        Returns:
            The filtered data, containing only each day's selected stocks

        Notes:
            - Processed date by date
            - Market caps are fetched separately from daily_basic
            - Market cap data is kept separate from the feature data
        """
        result_frames = []
        for date in data.get_column(self.date_col).unique().sort():
            # Rows for this day
            daily_data = data.filter(pl.col(self.date_col) == date)
            daily_codes = daily_data.get_column(self.code_col).to_list()

            # 1. Code filter (board rules)
            filtered_codes = self.filter_config.filter_codes(daily_codes)

            # 2. ST filter: needs per-date status data, not derivable from the code
            if self.filter_config.exclude_st:
                st_codes = self._get_st_codes_for_date(date)
                filtered_codes = [c for c in filtered_codes if c not in st_codes]

            # 3. Market cap selection (if enabled)
            if self.selector_config and self.selector_config.enabled:
                market_caps = self._get_market_caps_for_date(filtered_codes, date)
                selected_codes = self._select_by_market_cap(filtered_codes, market_caps)
            else:
                selected_codes = filtered_codes

            # 4. Keep only the selected stocks' rows (no market cap column is added)
            result_frames.append(daily_data.filter(pl.col(self.code_col).is_in(selected_codes)))

        # pl.concat fails on an empty list, so return an empty frame with the same schema
        return pl.concat(result_frames) if result_frames else data.head(0)

    def _get_st_codes_for_date(self, date: str) -> Set[str]:
        """Return the codes flagged as ST on the given date.

        The data source for ST status is not specified in this plan.
        """
        raise NotImplementedError

    def _get_market_caps_for_date(self, codes: List[str], date: str) -> Dict[str, float]:
        """Fetch market caps for the given date from the daily_basic table.

        Args:
            codes: Stock codes
            date: Date "YYYYMMDD"

        Returns:
            A {stock code: market cap} dict (column selector_config.market_cap_col)
        """
        # To be implemented: query the daily_basic table through data_router
        raise NotImplementedError

    def _select_by_market_cap(self, codes: List[str], market_caps: Dict[str, float]) -> List[str]:
        """Select stocks by market cap."""
        # If no market caps are available for the day, selection is skipped
        # and all filtered codes are kept
        if not market_caps:
            return codes

        # Drop codes without a market cap instead of defaulting them to 0, which
        # would rank them as the smallest caps when ascending=True
        ranked = [c for c in codes if market_caps.get(c) is not None]
        # Sort by market cap and keep the top N
        ranked.sort(key=lambda c: market_caps[c], reverse=not self.selector_config.ascending)
        return ranked[: self.selector_config.n]
```

### 3.5 Model Implementations (components/models/)

#### LightGBMModel

```python
from typing import Optional

import numpy as np
import pandas as pd
import polars as pl

from ..base import BaseModel
from ...registry import register_model


@register_model("lightgbm")
class LightGBMModel(BaseModel):
    """LightGBM regression model."""

    name = "lightgbm"

    def __init__(
        self,
        objective: str = "regression",
        metric: str = "rmse",
        num_leaves: int = 31,
        learning_rate: float = 0.05,
        n_estimators: int = 100,
        **kwargs,
    ):
        # n_estimators is passed to lgb.train as num_boost_round rather than kept
        # in params, where LightGBM treats it as an alias of num_iterations and
        # warns that it overrides num_boost_round
        self.n_estimators = n_estimators
        self.params = {
            "objective": objective,
            "metric": metric,
            "num_leaves": num_leaves,
            "learning_rate": learning_rate,
            **kwargs,
        }
        self.model = None  # lgb.Booster after fit() or load()

    def fit(self, X: pl.DataFrame, y: pl.Series) -> "LightGBMModel":
        """Train the model."""
        import lightgbm as lgb  # Imported lazily so LightGBM is only required when used

        train_data = lgb.Dataset(
            X.to_numpy(),
            label=y.to_numpy(),
            feature_name=X.columns,  # Keep column names for feature_importance()
        )
        self.model = lgb.train(self.params, train_data, num_boost_round=self.n_estimators)
        return self

    def predict(self, X: pl.DataFrame) -> np.ndarray:
        """Predict."""
        if self.model is None:
            raise RuntimeError("Model not fitted yet")
        return self.model.predict(X.to_numpy())

    def feature_importance(self) -> Optional[pd.Series]:
        """Return gain-based feature importance."""
        if self.model is None:
            return None
        importance = self.model.feature_importance(importance_type="gain")
        # The booster stores the feature names, so this also works after load()
        return pd.Series(importance, index=self.model.feature_name())

    def save(self, path: str) -> None:
        """Save the model in LightGBM's native text format (not pickle)."""
        if self.model is None:
            raise RuntimeError("Model not fitted yet")
        self.model.save_model(path)

    @classmethod
    def load(cls, path: str) -> "LightGBMModel":
        """Load a model written by save()."""
        import lightgbm as lgb

        instance = cls()
        instance.model = lgb.Booster(model_file=path)
        return instance
```

### 3.6 Data Processors (components/processors/)

#### StandardScaler

```python
from typing import Dict, List, Optional

import polars as pl

from ..base import BaseProcessor
from ...registry import register_processor


@register_processor("standard_scaler")
class StandardScaler(BaseProcessor):
    """Standardization processor (global, over the pooled training set).

    Learns each column's mean and standard deviation over the whole training
    set, then applies them to both the training set and the test set.
    Every numeric column not in exclude_cols is transformed, including the
    target column unless it is listed there.
    """

    name = "standard_scaler"

    def __init__(self, exclude_cols: Optional[List[str]] = None):
        self.exclude_cols = exclude_cols or ["ts_code", "trade_date"]
        self.mean_: Dict[str, float] = {}
        self.std_: Dict[str, float] = {}

    def fit(self, X: pl.DataFrame) -> "StandardScaler":
        """Compute the mean and std (on the training set only)."""
        self.mean_, self.std_ = {}, {}
        numeric_cols = [
            c for c in X.columns
            if c not in self.exclude_cols and X[c].dtype.is_numeric()
        ]
        for col in numeric_cols:
            self.mean_[col] = X[col].mean()
            std = X[col].std()
            # Guard against constant (std == 0) or all-null (std is None) columns
            self.std_[col] = std if std else 1.0
        return self

    def transform(self, X: pl.DataFrame) -> pl.DataFrame:
        """Standardize using the parameters learned on the training set."""
        return X.with_columns([
            ((pl.col(col) - self.mean_[col]) / self.std_[col]).alias(col)
            for col in X.columns
            if col in self.mean_
        ])
```

#### CrossSectionalStandardScaler

```python
from typing import List, Optional

import polars as pl

from ..base import BaseProcessor
from ...registry import register_processor


@register_processor("cs_standard_scaler")
class CrossSectionalStandardScaler(BaseProcessor):
    """Cross-sectional standardization processor.

    Standardizes each factor independently for each day, across all stocks on that day.

    Properties:
    - No fit needed: each day's mean and std are computed from that day's data
    - Removes day-to-day shifts in a factor's level and scale, so values are
      comparable across dates; it does not neutralize size or industry effects
    - Formula: z = (x - mean_today) / std_today
    """

    name = "cs_standard_scaler"

    def __init__(self, exclude_cols: Optional[List[str]] = None, date_col: str = "trade_date"):
        self.exclude_cols = exclude_cols or ["ts_code", "trade_date"]
        self.date_col = date_col

    def transform(self, X: pl.DataFrame) -> pl.DataFrame:
        """Cross-sectional standardization.

        Groups by date and standardizes with that day's mean and std. No fit is
        needed because each day uses its own statistics. Days where a column's
        std is 0 or undefined (fewer than two values) yield null.
        """
        numeric_cols = [
            c for c in X.columns
            if c not in self.exclude_cols and X[c].dtype.is_numeric()
        ]

        exprs = []
        for col in numeric_cols:
            # Window expressions compute the statistics per date without intermediate columns
            mean = pl.col(col).mean().over(self.date_col)
            std = pl.col(col).std().over(self.date_col)  # sample std (ddof=1)
            exprs.append(
                pl.when(std > 0).then((pl.col(col) - mean) / std).otherwise(None).alias(col)
            )
        return X.with_columns(exprs)
```
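
The following sketch makes the formula concrete. When two days have the same cross-sectional ordering but different scales, their per-day z-scores coincide. The `cs_zscore` helper is hypothetical. Like Polars' default `std()`, it uses the sample standard deviation, and days with fewer than two values or zero spread give `None`.

```python
from collections import defaultdict
from statistics import mean, stdev
from typing import Dict, List, Optional, Tuple

Row = Dict[str, object]


def cs_zscore(rows: List[Row], col: str, date_col: str = "trade_date") -> List[Row]:
    """Standardize `col` within each date: z = (x - mean_today) / std_today."""
    values_by_date: Dict[object, List[float]] = defaultdict(list)
    for row in rows:
        values_by_date[row[date_col]].append(row[col])

    stats: Dict[object, Tuple[float, float]] = {}
    for d, values in values_by_date.items():
        spread = stdev(values) if len(values) > 1 else 0.0  # sample std (ddof=1)
        stats[d] = (mean(values), spread)

    out: List[Row] = []
    for row in rows:
        m, s = stats[row[date_col]]
        z: Optional[float] = (row[col] - m) / s if s > 0 else None
        out.append({**row, col: z})
    return out


rows = [
    {"trade_date": "20230103", "f": 1.0}, {"trade_date": "20230103", "f": 2.0},
    {"trade_date": "20230103", "f": 3.0}, {"trade_date": "20230104", "f": 10.0},
    {"trade_date": "20230104", "f": 20.0}, {"trade_date": "20230104", "f": 30.0},
]
z = [r["f"] for r in cs_zscore(rows, "f")]  # both days map to -1, 0, 1
```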

#### Winsorizer

```python
from typing import Dict, List, Optional

import polars as pl

from ..base import BaseProcessor
from ...registry import register_processor


@register_processor("winsorizer")
class Winsorizer(BaseProcessor):
    """Winsorization processor.

    Clips extreme values in each numeric column. It can winsorize globally,
    learning the quantiles over the whole training set, or cross-sectionally,
    processing each day independently.
    """

    name = "winsorizer"

    def __init__(
        self,
        lower: float = 0.01,
        upper: float = 0.99,
        by_date: bool = False,  # True = winsorize each day independently, False = global
        date_col: str = "trade_date",
        exclude_cols: Optional[List[str]] = None,
    ):
        if not 0.0 <= lower < upper <= 1.0:
            raise ValueError("Expected 0 <= lower < upper <= 1")
        self.lower = lower
        self.upper = upper
        self.by_date = by_date
        self.date_col = date_col
        self.exclude_cols = exclude_cols or ["ts_code", "trade_date"]
        self.bounds_: Dict[str, Dict[str, float]] = {}  # Quantile bounds (global mode)

    def _numeric_cols(self, X: pl.DataFrame) -> List[str]:
        return [c for c in X.columns if c not in self.exclude_cols and X[c].dtype.is_numeric()]

    def fit(self, X: pl.DataFrame) -> "Winsorizer":
        """Learn the quantile bounds (global mode only)."""
        if not self.by_date:
            self.bounds_ = {
                col: {"lower": X[col].quantile(self.lower), "upper": X[col].quantile(self.upper)}
                for col in self._numeric_cols(X)
            }
        return self

    def transform(self, X: pl.DataFrame) -> pl.DataFrame:
        """Winsorize."""
        if self.by_date:
            return self._transform_by_date(X)  # Each day independently
        return self._transform_global(X)       # With the bounds learned on the training set

    def _transform_global(self, X: pl.DataFrame) -> pl.DataFrame:
        """Global winsorization (with the bounds learned on the training set)."""
        return X.with_columns([
            pl.col(col).clip(b["lower"], b["upper"]).alias(col)
            for col, b in self.bounds_.items()
            if col in X.columns
        ])

    def _transform_by_date(self, X: pl.DataFrame) -> pl.DataFrame:
        """Winsorize each day independently."""
        # Per-date quantiles via window expressions; clip() with expression
        # bounds requires a recent Polars version
        return X.with_columns([
            pl.col(col).clip(
                pl.col(col).quantile(self.lower).over(self.date_col),
                pl.col(col).quantile(self.upper).over(self.date_col),
            ).alias(col)
            for col in self._numeric_cols(X)
        ])
```
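
In global mode, the bounds are learned once on the training set and then applied to test data. Test values beyond the training range are therefore clipped to it. The following stdlib sketch separates the fit and apply steps. `quantile_linear` and `fit_bounds` are hypothetical and use linear interpolation, whereas Polars' `quantile()` defaults to `interpolation="nearest"`. The exact bounds can therefore differ slightly.

```python
from typing import List, Tuple


def quantile_linear(values: List[float], q: float) -> float:
    """q-th quantile with linear interpolation between the closest ranks."""
    xs = sorted(values)
    pos = q * (len(xs) - 1)
    lo = int(pos)                  # floor, since pos >= 0
    hi = min(lo + 1, len(xs) - 1)
    return xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)


def fit_bounds(train: List[float], lower: float, upper: float) -> Tuple[float, float]:
    """Learn the clipping bounds on the training values only."""
    return quantile_linear(train, lower), quantile_linear(train, upper)


def clip(values: List[float], bounds: Tuple[float, float]) -> List[float]:
    lo, hi = bounds
    return [min(max(v, lo), hi) for v in values]


train = [float(i) for i in range(11)]           # 0.0 .. 10.0
bounds = fit_bounds(train, lower=0.1, upper=0.9)
clipped_test = clip([-5.0, 5.0, 50.0], bounds)  # out-of-range test values hit the training bounds
```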

## 4. Training Workflow Design

### 4.1 Trainer Main Class (core/trainer.py)

```python
from typing import List, Optional

import polars as pl

from ..components.base import BaseModel, BaseProcessor
from ..components.splitters import DateSplitter
from .stock_pool_manager import StockPoolManager


class Trainer:
    """Main trainer class.

    Combines data processing, model training and prediction into one workflow.

    Key design points:
    1. Factors are computed first (whole market), then the stock pool is selected (independently per day)
    2. Processors are stage-aware: fit_transform on the training set, transform on the test set
    3. One-time training, no rolling
    4. Supports model persistence
    """

    def __init__(
        self,
        model: BaseModel,
        pool_manager: Optional[StockPoolManager] = None,
        processors: Optional[List[BaseProcessor]] = None,
        splitter: Optional[DateSplitter] = None,
        target_col: str = "target",
        feature_cols: Optional[List[str]] = None,
        persist_model: bool = False,
        model_save_path: Optional[str] = None,
    ):
        if splitter is None:
            raise ValueError("splitter is required")
        if not feature_cols:
            raise ValueError("feature_cols must contain at least one column")
        self.model = model
        self.pool_manager = pool_manager
        self.processors = processors or []
        self.splitter = splitter
        self.target_col = target_col
        self.feature_cols = feature_cols
        self.persist_model = persist_model
        self.model_save_path = model_save_path

        # Processors fitted on the training set
        self.fitted_processors: List[BaseProcessor] = []
        self.results: Optional[pl.DataFrame] = None

    def train(self, data: pl.DataFrame) -> "Trainer":
        """Run the training workflow.

        Steps:
        1. Daily stock pool selection (if pool_manager is configured)
        2. Split into training and test sets by date
        3. Training set: processors fit_transform
        4. Train the model
        5. Test set: processors transform (with the parameters learned on the training set)
        6. Predict
        7. Store the results
        8. Persist the model (if enabled)

        Note: processors may also transform target_col (any numeric column they
        do not exclude), which affects the target values stored in results.

        Args:
            data: Whole-market data after factor computation;
                must contain the ts_code and trade_date columns

        Returns:
            self (supports method chaining)
        """
        # 1. Stock pool selection (independently per day)
        if self.pool_manager:
            print("[select] Selecting the stock pool for each day...")
            data = self.pool_manager.filter_and_select_daily(data)

        # 2. Split into training/test sets
        print("[split] Splitting into training and test sets...")
        train_data, test_data = self.splitter.split(data)
        if train_data.is_empty():
            raise ValueError("Training set is empty; check the date range")

        # 3. Training set: processors fit_transform
        #    (reset first so that calling train() again does not stack processors)
        print("[process] Processing the training set...")
        self.fitted_processors = []
        for processor in self.processors:
            train_data = processor.fit_transform(train_data)
            self.fitted_processors.append(processor)

        # 4. Train the model
        print("[train] Training the model...")
        X_train = train_data.select(self.feature_cols)
        y_train = train_data.get_column(self.target_col)
        self.model.fit(X_train, y_train)

        # 5. Test set: processors transform
        print("[process] Processing the test set...")
        for processor in self.fitted_processors:
            test_data = processor.transform(test_data)

        # 6. Predict
        print("[predict] Generating predictions...")
        predictions = self.model.predict(test_data.select(self.feature_cols))

        # 7. Store the results
        self.results = test_data.with_columns(pl.Series("prediction", predictions))

        # 8. Persist the model
        if self.persist_model and self.model_save_path:
            print(f"[save] Saving the model to {self.model_save_path}...")
            self.save_model(self.model_save_path)

        return self

    def predict(self, data: pl.DataFrame) -> pl.DataFrame:
        """Predict on new data.

        Note: the new data must already have gone through stock pool selection;
        the fitted processors and the model are then applied to it.
        """
        # Apply the fitted processors
        for processor in self.fitted_processors:
            data = processor.transform(data)

        # Predict
        predictions = self.model.predict(data.select(self.feature_cols))
        return data.with_columns(pl.Series("prediction", predictions))

    def get_results(self) -> Optional[pl.DataFrame]:
        """Return all predictions (None before train() has run)."""
        return self.results

    def save_results(self, path: str) -> None:
        """Save the predictions to a CSV file."""
        if self.results is not None:
            self.results.write_csv(path)

    def save_model(self, path: str) -> None:
        """Save the model.

        Only the model is saved; the fitted processors (e.g. Winsorizer bounds)
        are not part of this file.
        """
        self.model.save(path)
```

### 4.2 Configuration Class (config/config.py)

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from pydantic import Field, field_validator, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict  # Pydantic v2 (v1: pydantic.BaseSettings)

# Defined once in components/selectors.py and re-exported here
from ..components.selectors import MarketCapSelectorConfig, StockFilterConfig


@dataclass
class ProcessorConfig:
    """Processor configuration."""
    name: str                                             # Registry name, e.g. "winsorizer"
    params: Dict[str, Any] = field(default_factory=dict)  # Constructor keyword arguments


class TrainingConfig(BaseSettings):
    """Training configuration."""

    # Allow field names starting with "model_" (model_type, model_params, ...)
    model_config = SettingsConfigDict(protected_namespaces=())

    # === Data (required) ===
    feature_cols: List[str] = Field(..., min_length=1)  # Feature columns, at least one
    target_col: str = "target"   # Target column
    date_col: str = "trade_date"  # Date column
    code_col: str = "ts_code"     # Stock code column

    # === Date split (required) ===
    train_start: str = Field(..., description="Training period start, YYYYMMDD")
    train_end: str = Field(..., description="Training period end, YYYYMMDD")
    test_start: str = Field(..., description="Test period start, YYYYMMDD")
    test_end: str = Field(..., description="Test period end, YYYYMMDD")

    # === Stock pool ===
    stock_filter: StockFilterConfig = Field(default_factory=StockFilterConfig)
    # Set stock_selector=None to skip market cap selection
    stock_selector: Optional[MarketCapSelectorConfig] = Field(default_factory=MarketCapSelectorConfig)

    # === Model ===
    model_type: str = "lightgbm"
    model_params: Dict[str, Any] = Field(default_factory=dict)

    # === Processors ===
    processors: List[ProcessorConfig] = Field(
        default_factory=lambda: [
            ProcessorConfig(name="winsorizer", params={"lower": 0.01, "upper": 0.99}),
            ProcessorConfig(name="cs_standard_scaler", params={}),
        ]
    )

    # === Persistence ===
    persist_model: bool = False           # Off by default
    model_save_path: Optional[str] = None  # Where to save the model

    # === Output ===
    output_dir: str = "output"
    save_predictions: bool = True

    @field_validator("train_start", "train_end", "test_start", "test_end")
    @classmethod
    def _check_date_format(cls, v: str) -> str:
        if len(v) != 8 or not v.isdigit():
            raise ValueError(f"expected a YYYYMMDD date string, got {v!r}")
        return v

    @model_validator(mode="after")
    def _check_date_order(self) -> "TrainingConfig":
        if not (self.train_start <= self.train_end < self.test_start <= self.test_end):
            raise ValueError("expected train_start <= train_end < test_start <= test_end")
        return self
```

## 5. Usage Examples

### 5.1 Basic Usage

```python
from src.training import Trainer, LightGBMModel, DateSplitter
from src.training.config import StockFilterConfig, MarketCapSelectorConfig
from src.training.core.stock_pool_manager import StockPoolManager
from src.training.components.processors import Winsorizer, CrossSectionalStandardScaler
from src.factors import FactorEngine

# data_router: an existing DataRouter instance (its construction is outside this plan)

# 1. Compute factors (whole-market data)
engine = FactorEngine()
all_data = engine.compute(
    factor_names=["factor1", "factor2", "factor3", "future_return_5d"],
    start_date="20200101",
    end_date="20231231",
    # No stock_codes given: fetch the whole market
)

# 2. Create the stock pool manager (independent daily selection)
pool_manager = StockPoolManager(
    filter_config=StockFilterConfig(
        exclude_cyb=True,
        exclude_kcb=True,
        exclude_bj=True,
        exclude_st=True,
    ),
    selector_config=MarketCapSelectorConfig(
        enabled=True,
        n=100,
        ascending=False,  # Largest market caps
    ),
    data_router=data_router,  # Used to query daily_basic
)

# 3. Create the model
model = LightGBMModel(
    objective="regression",
    n_estimators=100,
    learning_rate=0.05,
)

# 4. Create the processors
processors = [
    Winsorizer(lower=0.01, upper=0.99, by_date=False),  # Global winsorization
    CrossSectionalStandardScaler(),                      # Cross-sectional standardization (per day)
]

# 5. Create the splitter
splitter = DateSplitter(
    train_start="20200101",
    train_end="20221231",
    test_start="20230101",
    test_end="20231231",
)

# 6. Create the trainer
trainer = Trainer(
    model=model,
    pool_manager=pool_manager,  # Daily stock pool selection
    processors=processors,
    splitter=splitter,
    target_col="future_return_5d",
    feature_cols=["factor1", "factor2", "factor3"],
    persist_model=True,                  # Enable persistence
    model_save_path="output/model.txt",  # LightGBMModel.save writes LightGBM's native text format
)

# 7. Train (pass the whole-market data)
trainer.train(all_data)

# 8. Get the results
results = trainer.get_results()  # Test-set rows with a "prediction" column

# 9. Save the results
trainer.save_results("output/predictions.csv")

# 10. Load the model and predict on new data
# new_data: placeholder for new factor data that has already gone through stock pool
# selection and the same fitted processors (they are not saved with the model)
loaded_model = LightGBMModel.load("output/model.txt")
new_predictions = loaded_model.predict(new_data.select(["factor1", "factor2", "factor3"]))
```

### 5.2 Configuration-Driven Usage

```python
from src.training.config import TrainingConfig, StockFilterConfig
from src.training import Trainer, DateSplitter
from src.training.registry import ModelRegistry, ProcessorRegistry
from src.training.core.stock_pool_manager import StockPoolManager
from src.factors import FactorEngine

# data_router: an existing DataRouter instance (its construction is outside this plan)

# 1. Configuration (required fields are validated)
config = TrainingConfig(
    feature_cols=["factor1", "factor2", "factor3"],
    target_col="future_return_5d",
    train_start="20200101",
    train_end="20221231",
    test_start="20230101",
    test_end="20231231",
    stock_filter=StockFilterConfig(
        exclude_cyb=True,
        exclude_kcb=True,
    ),
    stock_selector=None,  # Skip market cap selection
    persist_model=False,  # Do not persist the model
)

# 2. Compute factors
engine = FactorEngine()
all_data = engine.compute(
    factor_names=config.feature_cols + [config.target_col],
    start_date=config.train_start,
    end_date=config.test_end,
)

# 3. Build the components from the config
model = ModelRegistry.get_model(config.model_type)(**config.model_params)
processors = [
    ProcessorRegistry.get_processor(p.name)(**p.params)
    for p in config.processors
]

# 4. Create the stock pool manager
pool_manager = StockPoolManager(
    filter_config=config.stock_filter,
    selector_config=config.stock_selector,
    data_router=data_router,
)

# 5. Create and run the trainer
trainer = Trainer(
    model=model,
    pool_manager=pool_manager,
    processors=processors,
    splitter=DateSplitter(
        train_start=config.train_start,
        train_end=config.train_end,
        test_start=config.test_start,
        test_end=config.test_end,
    ),
    target_col=config.target_col,
    feature_cols=config.feature_cols,
)

trainer.train(all_data)
results = trainer.get_results()
```

## 6. Implementation Order

Implement and commit in the following order:

### Commit 1: Base infrastructure
- `training/__init__.py`
- `training/components/__init__.py`
- `training/components/base.py` (BaseModel, BaseProcessor, including save/load)
- `training/registry.py` (component registry)

### Commit 2: Data splitting
- `training/components/splitters.py` (DateSplitter, one-time split only)

### Commit 3: Stock pool selector configs
- `training/components/selectors.py` (StockFilterConfig, MarketCapSelectorConfig)

### Commit 4: Data processors
- `training/components/processors/__init__.py`
- `training/components/processors/transforms.py`
  - Winsorizer
  - StandardScaler
  - CrossSectionalStandardScaler

### Commit 5: LightGBM model
- `training/components/models/__init__.py`
- `training/components/models/lightgbm.py` (including save/load)

### Commit 6: Stock pool manager
- `training/core/__init__.py`
- `training/core/stock_pool_manager.py` (independent daily selection)

### Commit 7: Trainer
- `training/core/trainer.py`

### Commit 8: Configuration management
- `training/config/__init__.py`
- `training/config/config.py` (TrainingConfig, with required-field validation)

### Commit 9: Reserved experiment module
- `experiment/__init__.py`

## 7. Notes

### 7.1 Stock Pool Processing Order (Daily)

```
All stock data for the day
    ↓
Code/ST filtering (ChiNext, ST, etc.)
    ↓
Query daily_basic for that day's market caps
    ↓
Market cap selection (top N)
    ↓
Return that day's selected stocks
```

- Selection is independent for each day, because market caps change over time
- Market cap data comes only from daily_basic
- Market cap data is never merged into the feature matrix

### 7.2 Processor Stage Behavior

| Processor | Training set | Test set |
|-----------|--------------|----------|
| StandardScaler | fit_transform | transform (with training-set parameters) |
| CrossSectionalStandardScaler | fit_transform (fit is a no-op) | transform (each day independently) |
| Winsorizer (global) | fit_transform | transform (with training-set bounds) |
| Winsorizer (by_date) | fit_transform (fit is a no-op) | transform (each day independently) |

### 7.3 Dependencies

- Polars for data processing
- LightGBM for model training
- Pydantic for configuration management
- Market cap data from the daily_basic table (a separate data source)

### 7.4 Removed Features

The following items from the original plan are dropped from this implementation:

1. **Feature selection** (processors/selectors.py)
2. **Rolling training** (WalkForward, ExpandingWindow)
3. **Result analysis tools** (complex analysis features)
4. **validator.py, evaluator.py** (removed; no metrics are implemented)

### 7.5 New Features

1. **StockPoolManager** (independent daily selection)
2. **Model persistence** (save/load, off by default)
3. **Required-field validation in the config** (feature_cols, date ranges)