feat(training): 实现训练模块核心组件(commits 6-9)
- StockPoolManager:每日独立筛选股票池,支持代码过滤和市值选择 - Trainer:整合训练完整流程,支持 processor 分阶段行为和模型持久化 - TrainingConfig:pydantic 配置管理,含必填字段和日期验证 - experiment 模块:预留结构 - 从计划中移除 metrics 组件 - 调整 commit 序号(7-10 → 6-9) - 更新 training/__init__.py 导出所有公开 API
This commit is contained in:
@@ -40,9 +40,6 @@ src/
|
||||
│ │ ├── processors/ # 数据处理器
|
||||
│ │ │ ├── __init__.py
|
||||
│ │ │ └── transforms.py # 标准化(截面/时序)、缩尾
|
||||
│ │ └── metrics/ # 评估指标
|
||||
│ │ ├── __init__.py
|
||||
│ │ └── metrics.py # IC, RankIC, MSE, MAE
|
||||
│ ├── config/ # 配置管理
|
||||
│ │ ├── __init__.py
|
||||
│ │ └── config.py # TrainingConfig (pydantic)
|
||||
@@ -568,30 +565,6 @@ class Winsorizer(BaseProcessor):
|
||||
pass
|
||||
```
|
||||
|
||||
### 3.7 评估指标 (components/metrics/)
|
||||
|
||||
```python
|
||||
def ic_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
|
||||
"""信息系数 (Pearson 相关系数)"""
|
||||
from scipy.stats import pearsonr
|
||||
corr, _ = pearsonr(y_true, y_pred)
|
||||
return corr
|
||||
|
||||
def rank_ic_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
|
||||
"""秩相关系数 (Spearman 相关系数)"""
|
||||
from scipy.stats import spearmanr
|
||||
corr, _ = spearmanr(y_true, y_pred)
|
||||
return corr
|
||||
|
||||
def mse_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
|
||||
"""均方误差"""
|
||||
return np.mean((y_true - y_pred) ** 2)
|
||||
|
||||
def mae_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
|
||||
"""平均绝对误差"""
|
||||
return np.mean(np.abs(y_true - y_pred))
|
||||
```
|
||||
|
||||
## 4. 训练流程设计
|
||||
|
||||
### 4.1 Trainer 主类 (core/trainer.py)
|
||||
@@ -615,7 +588,6 @@ class Trainer:
|
||||
pool_manager: Optional[StockPoolManager] = None,
|
||||
processors: List[BaseProcessor] = None,
|
||||
splitter: DateSplitter = None,
|
||||
metrics: List[str] = None,
|
||||
target_col: str = "target",
|
||||
feature_cols: List[str] = None,
|
||||
persist_model: bool = False,
|
||||
@@ -625,7 +597,6 @@ class Trainer:
|
||||
self.pool_manager = pool_manager
|
||||
self.processors = processors or []
|
||||
self.splitter = splitter
|
||||
self.metrics = metrics or ["ic", "rank_ic", "mse", "mae"]
|
||||
self.target_col = target_col
|
||||
self.feature_cols = feature_cols
|
||||
self.persist_model = persist_model
|
||||
@@ -634,7 +605,6 @@ class Trainer:
|
||||
# 存储训练后的处理器
|
||||
self.fitted_processors: List[BaseProcessor] = []
|
||||
self.results: pl.DataFrame = None
|
||||
self.metrics_results: Dict[str, float] = {}
|
||||
|
||||
def train(self, data: pl.DataFrame) -> "Trainer":
|
||||
"""执行训练流程
|
||||
@@ -686,38 +656,18 @@ class Trainer:
|
||||
X_test = test_data.select(self.feature_cols)
|
||||
predictions = self.model.predict(X_test)
|
||||
|
||||
# 7. 评估
|
||||
print("[评估] 计算指标...")
|
||||
y_test = test_data.select(self.target_col).to_series().to_numpy()
|
||||
self._evaluate(y_test, predictions)
|
||||
|
||||
# 8. 保存结果
|
||||
# 7. 保存结果
|
||||
self.results = test_data.with_columns([
|
||||
pl.Series("prediction", predictions)
|
||||
])
|
||||
|
||||
# 9. 持久化模型
|
||||
# 8. 持久化模型
|
||||
if self.persist_model and self.model_save_path:
|
||||
print(f"[保存] 保存模型到 {self.model_save_path}...")
|
||||
self.save_model(self.model_save_path)
|
||||
|
||||
return self
|
||||
|
||||
def _evaluate(self, y_true: np.ndarray, y_pred: np.ndarray):
|
||||
"""计算评估指标"""
|
||||
from .components.metrics import ic_score, rank_ic_score, mse_score, mae_score
|
||||
|
||||
metric_funcs = {
|
||||
"ic": ic_score,
|
||||
"rank_ic": rank_ic_score,
|
||||
"mse": mse_score,
|
||||
"mae": mae_score,
|
||||
}
|
||||
|
||||
for metric in self.metrics:
|
||||
if metric in metric_funcs:
|
||||
self.metrics_results[metric] = metric_funcs[metric](y_true, y_pred)
|
||||
|
||||
def predict(self, data: pl.DataFrame) -> pl.DataFrame:
|
||||
"""对新数据进行预测
|
||||
|
||||
@@ -738,10 +688,6 @@ class Trainer:
|
||||
"""获取所有预测结果"""
|
||||
return self.results
|
||||
|
||||
def get_metrics(self) -> Dict[str, float]:
|
||||
"""获取评估指标"""
|
||||
return self.metrics_results
|
||||
|
||||
def save_results(self, path: str):
|
||||
"""保存预测结果到文件"""
|
||||
if self.results is not None:
|
||||
@@ -827,9 +773,6 @@ class TrainingConfig(BaseSettings):
|
||||
]
|
||||
)
|
||||
|
||||
# === 评估指标 ===
|
||||
metrics: List[str] = ["ic", "rank_ic", "mse", "mae"]
|
||||
|
||||
# === 持久化配置 ===
|
||||
persist_model: bool = False # 默认不持久化
|
||||
model_save_path: Optional[str] = None # 持久化路径
|
||||
@@ -914,8 +857,6 @@ trainer.train(all_data)
|
||||
|
||||
# 8. 获取结果
|
||||
results = trainer.get_results() # 包含预测值
|
||||
metrics = trainer.get_metrics() # IC, RankIC, etc.
|
||||
print("评估指标:", metrics)
|
||||
|
||||
# 9. 保存结果
|
||||
trainer.save_results("output/predictions.csv")
|
||||
@@ -989,7 +930,6 @@ trainer = Trainer(
|
||||
|
||||
trainer.train(all_data)
|
||||
results = trainer.get_results()
|
||||
metrics = trainer.get_metrics()
|
||||
```
|
||||
|
||||
## 6. 实现顺序
|
||||
@@ -1019,22 +959,18 @@ metrics = trainer.get_metrics()
|
||||
- `training/components/models/__init__.py`
|
||||
- `training/components/models/lightgbm.py`(含 save_model/load_model)
|
||||
|
||||
### Commit 6: 评估指标
|
||||
- `training/components/metrics/__init__.py`
|
||||
- `training/components/metrics/metrics.py`(IC, RankIC, MSE, MAE)
|
||||
|
||||
### Commit 7: 股票池管理器
|
||||
### Commit 6: 股票池管理器
|
||||
- `training/core/__init__.py`
|
||||
- `training/core/stock_pool_manager.py`(每日独立筛选)
|
||||
|
||||
### Commit 8: Trainer 训练器
|
||||
### Commit 7: Trainer 训练器
|
||||
- `training/core/trainer.py`
|
||||
|
||||
### Commit 9: 配置管理
|
||||
### Commit 8: 配置管理
|
||||
- `training/config/__init__.py`
|
||||
- `training/config/config.py`(TrainingConfig,含必填校验)
|
||||
|
||||
### Commit 10: 预留实验模块
|
||||
### Commit 9: 预留实验模块
|
||||
- `experiment/__init__.py`
|
||||
|
||||
## 7. 注意事项
|
||||
@@ -1080,7 +1016,7 @@ metrics = trainer.get_metrics()
|
||||
1. **特征选择**(processors/selectors.py)
|
||||
2. **滚动训练**(WalkForward, ExpandingWindow)
|
||||
3. **结果分析工具**(复杂分析功能)
|
||||
4. **validator.py, evaluator.py**(简化为 metrics.py)
|
||||
4. **validator.py, evaluator.py**(已删除,不实现 metrics)
|
||||
|
||||
### 7.5 新增功能
|
||||
|
||||
|
||||
Reference in New Issue
Block a user