feat(training): 实现训练模块核心组件(commits 6-9)

- StockPoolManager:每日独立筛选股票池,支持代码过滤和市值选择
- Trainer:整合训练完整流程,支持 processor 分阶段行为和模型持久化
- TrainingConfig:pydantic 配置管理,含必填字段和日期验证
- experiment 模块:预留结构
- 从计划中移除 metrics 组件
- 调整 commit 序号(7-10 → 6-9)
- 更新 training/__init__.py 导出所有公开 API
This commit is contained in:
2026-03-03 22:57:01 +08:00
parent f35a6a76a6
commit 192718095f
9 changed files with 584 additions and 73 deletions

View File

@@ -40,9 +40,6 @@ src/
│ │ ├── processors/ # 数据处理器
│ │ │ ├── __init__.py
│ │ │ └── transforms.py # 标准化(截面/时序)、缩尾
│ │ └── metrics/ # 评估指标
│ │ ├── __init__.py
│ │ └── metrics.py # IC, RankIC, MSE, MAE
│ ├── config/ # 配置管理
│ │ ├── __init__.py
│ │ └── config.py # TrainingConfig (pydantic)
@@ -568,30 +565,6 @@ class Winsorizer(BaseProcessor):
pass
```
### 3.7 评估指标 (components/metrics/)
```python
def ic_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
"""信息系数 (Pearson 相关系数)"""
from scipy.stats import pearsonr
corr, _ = pearsonr(y_true, y_pred)
return corr
def rank_ic_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
"""秩相关系数 (Spearman 相关系数)"""
from scipy.stats import spearmanr
corr, _ = spearmanr(y_true, y_pred)
return corr
def mse_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
"""均方误差"""
return np.mean((y_true - y_pred) ** 2)
def mae_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
"""平均绝对误差"""
return np.mean(np.abs(y_true - y_pred))
```
## 4. 训练流程设计
### 4.1 Trainer 主类 (core/trainer.py)
@@ -615,7 +588,6 @@ class Trainer:
pool_manager: Optional[StockPoolManager] = None,
processors: List[BaseProcessor] = None,
splitter: DateSplitter = None,
metrics: List[str] = None,
target_col: str = "target",
feature_cols: List[str] = None,
persist_model: bool = False,
@@ -625,7 +597,6 @@ class Trainer:
self.pool_manager = pool_manager
self.processors = processors or []
self.splitter = splitter
self.metrics = metrics or ["ic", "rank_ic", "mse", "mae"]
self.target_col = target_col
self.feature_cols = feature_cols
self.persist_model = persist_model
@@ -634,7 +605,6 @@ class Trainer:
# 存储训练后的处理器
self.fitted_processors: List[BaseProcessor] = []
self.results: pl.DataFrame = None
self.metrics_results: Dict[str, float] = {}
def train(self, data: pl.DataFrame) -> "Trainer":
"""执行训练流程
@@ -686,38 +656,18 @@ class Trainer:
X_test = test_data.select(self.feature_cols)
predictions = self.model.predict(X_test)
# 7. 评估
print("[评估] 计算指标...")
y_test = test_data.select(self.target_col).to_series().to_numpy()
self._evaluate(y_test, predictions)
# 8. 保存结果
# 7. 保存结果
self.results = test_data.with_columns([
pl.Series("prediction", predictions)
])
# 9. 持久化模型
# 8. 持久化模型
if self.persist_model and self.model_save_path:
print(f"[保存] 保存模型到 {self.model_save_path}...")
self.save_model(self.model_save_path)
return self
def _evaluate(self, y_true: np.ndarray, y_pred: np.ndarray):
"""计算评估指标"""
from .components.metrics import ic_score, rank_ic_score, mse_score, mae_score
metric_funcs = {
"ic": ic_score,
"rank_ic": rank_ic_score,
"mse": mse_score,
"mae": mae_score,
}
for metric in self.metrics:
if metric in metric_funcs:
self.metrics_results[metric] = metric_funcs[metric](y_true, y_pred)
def predict(self, data: pl.DataFrame) -> pl.DataFrame:
"""对新数据进行预测
@@ -738,10 +688,6 @@ class Trainer:
"""获取所有预测结果"""
return self.results
def get_metrics(self) -> Dict[str, float]:
"""获取评估指标"""
return self.metrics_results
def save_results(self, path: str):
"""保存预测结果到文件"""
if self.results is not None:
@@ -827,9 +773,6 @@ class TrainingConfig(BaseSettings):
]
)
# === 评估指标 ===
metrics: List[str] = ["ic", "rank_ic", "mse", "mae"]
# === 持久化配置 ===
persist_model: bool = False # 默认不持久化
model_save_path: Optional[str] = None # 持久化路径
@@ -914,8 +857,6 @@ trainer.train(all_data)
# 8. 获取结果
results = trainer.get_results() # 包含预测值
metrics = trainer.get_metrics() # IC, RankIC, etc.
print("评估指标:", metrics)
# 9. 保存结果
trainer.save_results("output/predictions.csv")
@@ -989,7 +930,6 @@ trainer = Trainer(
trainer.train(all_data)
results = trainer.get_results()
metrics = trainer.get_metrics()
```
## 6. 实现顺序
@@ -1019,22 +959,18 @@ metrics = trainer.get_metrics()
- `training/components/models/__init__.py`
- `training/components/models/lightgbm.py`(含 save_model/load_model)
### Commit 6: 评估指标
- `training/components/metrics/__init__.py`
- `training/components/metrics/metrics.py`(IC, RankIC, MSE, MAE)
### Commit 7: 股票池管理器
### Commit 6: 股票池管理器
- `training/core/__init__.py`
- `training/core/stock_pool_manager.py`(每日独立筛选)
### Commit 8: Trainer 训练器
### Commit 7: Trainer 训练器
- `training/core/trainer.py`
### Commit 9: 配置管理
### Commit 8: 配置管理
- `training/config/__init__.py`
- `training/config/config.py`(TrainingConfig,含必填校验)
### Commit 10: 预留实验模块
### Commit 9: 预留实验模块
- `experiment/__init__.py`
## 7. 注意事项
@@ -1080,7 +1016,7 @@ metrics = trainer.get_metrics()
1. **特征选择**(processors/selectors.py)
2. **滚动训练**(WalkForward, ExpandingWindow)
3. **结果分析工具**(复杂分析功能)
4. **validator.py, evaluator.py**(简化为 metrics.py)
4. **validator.py, evaluator.py**(已删除,不实现 metrics)
### 7.5 新增功能