Compare commits
5 Commits
317ecd87e7
...
192718095f
| Author | SHA1 | Date | |
|---|---|---|---|
| 192718095f | |||
| f35a6a76a6 | |||
| 9ca1deae56 | |||
| 6b63c428d9 | |||
| f48b307ad2 |
@@ -40,9 +40,6 @@ src/
|
|||||||
│ │ ├── processors/ # 数据处理器
|
│ │ ├── processors/ # 数据处理器
|
||||||
│ │ │ ├── __init__.py
|
│ │ │ ├── __init__.py
|
||||||
│ │ │ └── transforms.py # 标准化(截面/时序)、缩尾
|
│ │ │ └── transforms.py # 标准化(截面/时序)、缩尾
|
||||||
│ │ └── metrics/ # 评估指标
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ └── metrics.py # IC, RankIC, MSE, MAE
|
|
||||||
│ ├── config/ # 配置管理
|
│ ├── config/ # 配置管理
|
||||||
│ │ ├── __init__.py
|
│ │ ├── __init__.py
|
||||||
│ │ └── config.py # TrainingConfig (pydantic)
|
│ │ └── config.py # TrainingConfig (pydantic)
|
||||||
@@ -568,30 +565,6 @@ class Winsorizer(BaseProcessor):
|
|||||||
pass
|
pass
|
||||||
```
|
```
|
||||||
|
|
||||||
### 3.7 评估指标 (components/metrics/)
|
|
||||||
|
|
||||||
```python
|
|
||||||
def ic_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
|
|
||||||
"""信息系数 (Pearson 相关系数)"""
|
|
||||||
from scipy.stats import pearsonr
|
|
||||||
corr, _ = pearsonr(y_true, y_pred)
|
|
||||||
return corr
|
|
||||||
|
|
||||||
def rank_ic_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
|
|
||||||
"""秩相关系数 (Spearman 相关系数)"""
|
|
||||||
from scipy.stats import spearmanr
|
|
||||||
corr, _ = spearmanr(y_true, y_pred)
|
|
||||||
return corr
|
|
||||||
|
|
||||||
def mse_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
|
|
||||||
"""均方误差"""
|
|
||||||
return np.mean((y_true - y_pred) ** 2)
|
|
||||||
|
|
||||||
def mae_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
|
|
||||||
"""平均绝对误差"""
|
|
||||||
return np.mean(np.abs(y_true - y_pred))
|
|
||||||
```
|
|
||||||
|
|
||||||
## 4. 训练流程设计
|
## 4. 训练流程设计
|
||||||
|
|
||||||
### 4.1 Trainer 主类 (core/trainer.py)
|
### 4.1 Trainer 主类 (core/trainer.py)
|
||||||
@@ -615,7 +588,6 @@ class Trainer:
|
|||||||
pool_manager: Optional[StockPoolManager] = None,
|
pool_manager: Optional[StockPoolManager] = None,
|
||||||
processors: List[BaseProcessor] = None,
|
processors: List[BaseProcessor] = None,
|
||||||
splitter: DateSplitter = None,
|
splitter: DateSplitter = None,
|
||||||
metrics: List[str] = None,
|
|
||||||
target_col: str = "target",
|
target_col: str = "target",
|
||||||
feature_cols: List[str] = None,
|
feature_cols: List[str] = None,
|
||||||
persist_model: bool = False,
|
persist_model: bool = False,
|
||||||
@@ -625,7 +597,6 @@ class Trainer:
|
|||||||
self.pool_manager = pool_manager
|
self.pool_manager = pool_manager
|
||||||
self.processors = processors or []
|
self.processors = processors or []
|
||||||
self.splitter = splitter
|
self.splitter = splitter
|
||||||
self.metrics = metrics or ["ic", "rank_ic", "mse", "mae"]
|
|
||||||
self.target_col = target_col
|
self.target_col = target_col
|
||||||
self.feature_cols = feature_cols
|
self.feature_cols = feature_cols
|
||||||
self.persist_model = persist_model
|
self.persist_model = persist_model
|
||||||
@@ -634,7 +605,6 @@ class Trainer:
|
|||||||
# 存储训练后的处理器
|
# 存储训练后的处理器
|
||||||
self.fitted_processors: List[BaseProcessor] = []
|
self.fitted_processors: List[BaseProcessor] = []
|
||||||
self.results: pl.DataFrame = None
|
self.results: pl.DataFrame = None
|
||||||
self.metrics_results: Dict[str, float] = {}
|
|
||||||
|
|
||||||
def train(self, data: pl.DataFrame) -> "Trainer":
|
def train(self, data: pl.DataFrame) -> "Trainer":
|
||||||
"""执行训练流程
|
"""执行训练流程
|
||||||
@@ -686,38 +656,18 @@ class Trainer:
|
|||||||
X_test = test_data.select(self.feature_cols)
|
X_test = test_data.select(self.feature_cols)
|
||||||
predictions = self.model.predict(X_test)
|
predictions = self.model.predict(X_test)
|
||||||
|
|
||||||
# 7. 评估
|
# 7. 保存结果
|
||||||
print("[评估] 计算指标...")
|
|
||||||
y_test = test_data.select(self.target_col).to_series().to_numpy()
|
|
||||||
self._evaluate(y_test, predictions)
|
|
||||||
|
|
||||||
# 8. 保存结果
|
|
||||||
self.results = test_data.with_columns([
|
self.results = test_data.with_columns([
|
||||||
pl.Series("prediction", predictions)
|
pl.Series("prediction", predictions)
|
||||||
])
|
])
|
||||||
|
|
||||||
# 9. 持久化模型
|
# 8. 持久化模型
|
||||||
if self.persist_model and self.model_save_path:
|
if self.persist_model and self.model_save_path:
|
||||||
print(f"[保存] 保存模型到 {self.model_save_path}...")
|
print(f"[保存] 保存模型到 {self.model_save_path}...")
|
||||||
self.save_model(self.model_save_path)
|
self.save_model(self.model_save_path)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def _evaluate(self, y_true: np.ndarray, y_pred: np.ndarray):
|
|
||||||
"""计算评估指标"""
|
|
||||||
from .components.metrics import ic_score, rank_ic_score, mse_score, mae_score
|
|
||||||
|
|
||||||
metric_funcs = {
|
|
||||||
"ic": ic_score,
|
|
||||||
"rank_ic": rank_ic_score,
|
|
||||||
"mse": mse_score,
|
|
||||||
"mae": mae_score,
|
|
||||||
}
|
|
||||||
|
|
||||||
for metric in self.metrics:
|
|
||||||
if metric in metric_funcs:
|
|
||||||
self.metrics_results[metric] = metric_funcs[metric](y_true, y_pred)
|
|
||||||
|
|
||||||
def predict(self, data: pl.DataFrame) -> pl.DataFrame:
|
def predict(self, data: pl.DataFrame) -> pl.DataFrame:
|
||||||
"""对新数据进行预测
|
"""对新数据进行预测
|
||||||
|
|
||||||
@@ -738,10 +688,6 @@ class Trainer:
|
|||||||
"""获取所有预测结果"""
|
"""获取所有预测结果"""
|
||||||
return self.results
|
return self.results
|
||||||
|
|
||||||
def get_metrics(self) -> Dict[str, float]:
|
|
||||||
"""获取评估指标"""
|
|
||||||
return self.metrics_results
|
|
||||||
|
|
||||||
def save_results(self, path: str):
|
def save_results(self, path: str):
|
||||||
"""保存预测结果到文件"""
|
"""保存预测结果到文件"""
|
||||||
if self.results is not None:
|
if self.results is not None:
|
||||||
@@ -827,9 +773,6 @@ class TrainingConfig(BaseSettings):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
# === 评估指标 ===
|
|
||||||
metrics: List[str] = ["ic", "rank_ic", "mse", "mae"]
|
|
||||||
|
|
||||||
# === 持久化配置 ===
|
# === 持久化配置 ===
|
||||||
persist_model: bool = False # 默认不持久化
|
persist_model: bool = False # 默认不持久化
|
||||||
model_save_path: Optional[str] = None # 持久化路径
|
model_save_path: Optional[str] = None # 持久化路径
|
||||||
@@ -914,8 +857,6 @@ trainer.train(all_data)
|
|||||||
|
|
||||||
# 8. 获取结果
|
# 8. 获取结果
|
||||||
results = trainer.get_results() # 包含预测值
|
results = trainer.get_results() # 包含预测值
|
||||||
metrics = trainer.get_metrics() # IC, RankIC, etc.
|
|
||||||
print("评估指标:", metrics)
|
|
||||||
|
|
||||||
# 9. 保存结果
|
# 9. 保存结果
|
||||||
trainer.save_results("output/predictions.csv")
|
trainer.save_results("output/predictions.csv")
|
||||||
@@ -989,7 +930,6 @@ trainer = Trainer(
|
|||||||
|
|
||||||
trainer.train(all_data)
|
trainer.train(all_data)
|
||||||
results = trainer.get_results()
|
results = trainer.get_results()
|
||||||
metrics = trainer.get_metrics()
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## 6. 实现顺序
|
## 6. 实现顺序
|
||||||
@@ -1019,22 +959,18 @@ metrics = trainer.get_metrics()
|
|||||||
- `training/components/models/__init__.py`
|
- `training/components/models/__init__.py`
|
||||||
- `training/components/models/lightgbm.py`(含 save_model/load_model)
|
- `training/components/models/lightgbm.py`(含 save_model/load_model)
|
||||||
|
|
||||||
### Commit 6: 评估指标
|
### Commit 6: 股票池管理器
|
||||||
- `training/components/metrics/__init__.py`
|
|
||||||
- `training/components/metrics/metrics.py`(IC, RankIC, MSE, MAE)
|
|
||||||
|
|
||||||
### Commit 7: 股票池管理器
|
|
||||||
- `training/core/__init__.py`
|
- `training/core/__init__.py`
|
||||||
- `training/core/stock_pool_manager.py`(每日独立筛选)
|
- `training/core/stock_pool_manager.py`(每日独立筛选)
|
||||||
|
|
||||||
### Commit 8: Trainer 训练器
|
### Commit 7: Trainer 训练器
|
||||||
- `training/core/trainer.py`
|
- `training/core/trainer.py`
|
||||||
|
|
||||||
### Commit 9: 配置管理
|
### Commit 8: 配置管理
|
||||||
- `training/config/__init__.py`
|
- `training/config/__init__.py`
|
||||||
- `training/config/config.py`(TrainingConfig,含必填校验)
|
- `training/config/config.py`(TrainingConfig,含必填校验)
|
||||||
|
|
||||||
### Commit 10: 预留实验模块
|
### Commit 9: 预留实验模块
|
||||||
- `experiment/__init__.py`
|
- `experiment/__init__.py`
|
||||||
|
|
||||||
## 7. 注意事项
|
## 7. 注意事项
|
||||||
@@ -1080,7 +1016,7 @@ metrics = trainer.get_metrics()
|
|||||||
1. **特征选择**(processors/selectors.py)
|
1. **特征选择**(processors/selectors.py)
|
||||||
2. **滚动训练**(WalkForward, ExpandingWindow)
|
2. **滚动训练**(WalkForward, ExpandingWindow)
|
||||||
3. **结果分析工具**(复杂分析功能)
|
3. **结果分析工具**(复杂分析功能)
|
||||||
4. **validator.py, evaluator.py**(简化为 metrics.py)
|
4. **validator.py, evaluator.py**(已删除,不实现 metrics)
|
||||||
|
|
||||||
### 7.5 新增功能
|
### 7.5 新增功能
|
||||||
|
|
||||||
|
|||||||
7
src/experiment/__init__.py
Normal file
7
src/experiment/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
"""实验管理模块(预留结构)
|
||||||
|
|
||||||
|
此模块为预留结构,用于未来的实验管理功能。
|
||||||
|
暂不提供具体实现。
|
||||||
|
"""
|
||||||
|
|
||||||
|
__all__ = []
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
"""训练模块 - ProStock 量化投资框架
|
"""训练模块 - ProStock 量化投资框架
|
||||||
|
|
||||||
提供模型训练、数据处理和评估的完整流程。
|
提供模型训练、数据处理和预测的完整流程。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 基础抽象类
|
# 基础抽象类
|
||||||
@@ -14,6 +14,31 @@ from src.training.registry import (
|
|||||||
register_processor,
|
register_processor,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 数据划分器
|
||||||
|
from src.training.components.splitters import DateSplitter
|
||||||
|
|
||||||
|
# 股票池选择器配置
|
||||||
|
from src.training.components.selectors import (
|
||||||
|
MarketCapSelectorConfig,
|
||||||
|
StockFilterConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 数据处理器
|
||||||
|
from src.training.components.processors import (
|
||||||
|
CrossSectionalStandardScaler,
|
||||||
|
StandardScaler,
|
||||||
|
Winsorizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 模型
|
||||||
|
from src.training.components.models import LightGBMModel
|
||||||
|
|
||||||
|
# 训练核心
|
||||||
|
from src.training.core import StockPoolManager, Trainer
|
||||||
|
|
||||||
|
# 配置
|
||||||
|
from src.training.config import TrainingConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# 基础抽象类
|
# 基础抽象类
|
||||||
"BaseModel",
|
"BaseModel",
|
||||||
@@ -23,4 +48,20 @@ __all__ = [
|
|||||||
"ProcessorRegistry",
|
"ProcessorRegistry",
|
||||||
"register_model",
|
"register_model",
|
||||||
"register_processor",
|
"register_processor",
|
||||||
|
# 数据划分器
|
||||||
|
"DateSplitter",
|
||||||
|
# 股票池选择器配置
|
||||||
|
"StockFilterConfig",
|
||||||
|
"MarketCapSelectorConfig",
|
||||||
|
# 数据处理器
|
||||||
|
"StandardScaler",
|
||||||
|
"CrossSectionalStandardScaler",
|
||||||
|
"Winsorizer",
|
||||||
|
# 模型
|
||||||
|
"LightGBMModel",
|
||||||
|
# 训练核心
|
||||||
|
"StockPoolManager",
|
||||||
|
"Trainer",
|
||||||
|
# 配置
|
||||||
|
"TrainingConfig",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -6,7 +6,33 @@
|
|||||||
# 基础抽象类
|
# 基础抽象类
|
||||||
from src.training.components.base import BaseModel, BaseProcessor
|
from src.training.components.base import BaseModel, BaseProcessor
|
||||||
|
|
||||||
|
# 数据划分器
|
||||||
|
from src.training.components.splitters import DateSplitter
|
||||||
|
|
||||||
|
# 股票池选择器配置
|
||||||
|
from src.training.components.selectors import (
|
||||||
|
MarketCapSelectorConfig,
|
||||||
|
StockFilterConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 数据处理器
|
||||||
|
from src.training.components.processors import (
|
||||||
|
CrossSectionalStandardScaler,
|
||||||
|
StandardScaler,
|
||||||
|
Winsorizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 模型
|
||||||
|
from src.training.components.models import LightGBMModel
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseModel",
|
"BaseModel",
|
||||||
"BaseProcessor",
|
"BaseProcessor",
|
||||||
|
"DateSplitter",
|
||||||
|
"StockFilterConfig",
|
||||||
|
"MarketCapSelectorConfig",
|
||||||
|
"StandardScaler",
|
||||||
|
"CrossSectionalStandardScaler",
|
||||||
|
"Winsorizer",
|
||||||
|
"LightGBMModel",
|
||||||
]
|
]
|
||||||
|
|||||||
8
src/training/components/models/__init__.py
Normal file
8
src/training/components/models/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
"""模型子模块
|
||||||
|
|
||||||
|
包含各种机器学习模型的实现。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.training.components.models.lightgbm import LightGBMModel
|
||||||
|
|
||||||
|
__all__ = ["LightGBMModel"]
|
||||||
194
src/training/components/models/lightgbm.py
Normal file
194
src/training/components/models/lightgbm.py
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
"""LightGBM 模型实现
|
||||||
|
|
||||||
|
提供 LightGBM 回归模型的实现,支持特征重要性和原生模型保存。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
from src.training.components.base import BaseModel
|
||||||
|
from src.training.registry import register_model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model("lightgbm")
|
||||||
|
class LightGBMModel(BaseModel):
|
||||||
|
"""LightGBM 回归模型
|
||||||
|
|
||||||
|
使用 LightGBM 库实现梯度提升回归树。
|
||||||
|
支持自定义参数、特征重要性提取和原生模型格式保存。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
name: 模型名称 "lightgbm"
|
||||||
|
params: LightGBM 参数字典
|
||||||
|
model: 训练后的 LightGBM Booster 对象
|
||||||
|
feature_names_: 特征名称列表
|
||||||
|
"""
|
||||||
|
|
||||||
|
name = "lightgbm"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
objective: str = "regression",
|
||||||
|
metric: str = "rmse",
|
||||||
|
num_leaves: int = 31,
|
||||||
|
learning_rate: float = 0.05,
|
||||||
|
n_estimators: int = 100,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""初始化 LightGBM 模型
|
||||||
|
|
||||||
|
Args:
|
||||||
|
objective: 目标函数,默认 "regression"
|
||||||
|
metric: 评估指标,默认 "rmse"
|
||||||
|
num_leaves: 叶子节点数,默认 31
|
||||||
|
learning_rate: 学习率,默认 0.05
|
||||||
|
n_estimators: 迭代次数,默认 100
|
||||||
|
**kwargs: 其他 LightGBM 参数
|
||||||
|
"""
|
||||||
|
self.params = {
|
||||||
|
"objective": objective,
|
||||||
|
"metric": metric,
|
||||||
|
"num_leaves": num_leaves,
|
||||||
|
"learning_rate": learning_rate,
|
||||||
|
"verbose": -1, # 抑制训练输出
|
||||||
|
**kwargs,
|
||||||
|
}
|
||||||
|
self.n_estimators = n_estimators
|
||||||
|
self.model = None
|
||||||
|
self.feature_names_: Optional[list] = None
|
||||||
|
|
||||||
|
def fit(self, X: pl.DataFrame, y: pl.Series) -> "LightGBMModel":
|
||||||
|
"""训练模型
|
||||||
|
|
||||||
|
Args:
|
||||||
|
X: 特征矩阵 (Polars DataFrame)
|
||||||
|
y: 目标变量 (Polars Series)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
self (支持链式调用)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ImportError: 未安装 lightgbm
|
||||||
|
RuntimeError: 训练失败
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import lightgbm as lgb
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"使用 LightGBMModel 需要安装 lightgbm: pip install lightgbm"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 保存特征名称
|
||||||
|
self.feature_names_ = X.columns
|
||||||
|
|
||||||
|
# 转换为 numpy
|
||||||
|
X_np = X.to_numpy()
|
||||||
|
y_np = y.to_numpy()
|
||||||
|
|
||||||
|
# 创建数据集
|
||||||
|
train_data = lgb.Dataset(X_np, label=y_np)
|
||||||
|
|
||||||
|
# 训练
|
||||||
|
self.model = lgb.train(
|
||||||
|
self.params,
|
||||||
|
train_data,
|
||||||
|
num_boost_round=self.n_estimators,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def predict(self, X: pl.DataFrame) -> np.ndarray:
|
||||||
|
"""预测
|
||||||
|
|
||||||
|
Args:
|
||||||
|
X: 特征矩阵 (Polars DataFrame)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
预测结果 (numpy ndarray)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: 模型未训练时调用
|
||||||
|
"""
|
||||||
|
if self.model is None:
|
||||||
|
raise RuntimeError("模型尚未训练,请先调用 fit()")
|
||||||
|
|
||||||
|
X_np = X.to_numpy()
|
||||||
|
return self.model.predict(X_np)
|
||||||
|
|
||||||
|
def feature_importance(self) -> Optional[pd.Series]:
|
||||||
|
"""返回特征重要性
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
特征重要性序列,如果模型未训练则返回 None
|
||||||
|
"""
|
||||||
|
if self.model is None or self.feature_names_ is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
importance = self.model.feature_importance(importance_type="gain")
|
||||||
|
return pd.Series(importance, index=self.feature_names_)
|
||||||
|
|
||||||
|
def save(self, path: str) -> None:
|
||||||
|
"""保存模型(使用 LightGBM 原生格式)
|
||||||
|
|
||||||
|
使用 LightGBM 的原生格式保存,不依赖 pickle,
|
||||||
|
可以在不同环境中加载。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: 保存路径
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: 模型未训练时调用
|
||||||
|
"""
|
||||||
|
if self.model is None:
|
||||||
|
raise RuntimeError("模型尚未训练,无法保存")
|
||||||
|
|
||||||
|
self.model.save_model(path)
|
||||||
|
|
||||||
|
# 同时保存特征名称(LightGBM 原生格式不保存这个)
|
||||||
|
import json
|
||||||
|
|
||||||
|
meta_path = path + ".meta.json"
|
||||||
|
with open(meta_path, "w") as f:
|
||||||
|
json.dump(
|
||||||
|
{
|
||||||
|
"feature_names": self.feature_names_,
|
||||||
|
"params": self.params,
|
||||||
|
"n_estimators": self.n_estimators,
|
||||||
|
},
|
||||||
|
f,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, path: str) -> "LightGBMModel":
|
||||||
|
"""加载模型
|
||||||
|
|
||||||
|
从 LightGBM 原生格式加载模型。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: 模型文件路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
加载的 LightGBMModel 实例
|
||||||
|
"""
|
||||||
|
import lightgbm as lgb
|
||||||
|
import json
|
||||||
|
|
||||||
|
instance = cls()
|
||||||
|
instance.model = lgb.Booster(model_file=path)
|
||||||
|
|
||||||
|
# 加载元数据
|
||||||
|
meta_path = path + ".meta.json"
|
||||||
|
try:
|
||||||
|
with open(meta_path, "r") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
instance.feature_names_ = meta.get("feature_names")
|
||||||
|
instance.params = meta.get("params", instance.params)
|
||||||
|
instance.n_estimators = meta.get("n_estimators", instance.n_estimators)
|
||||||
|
except FileNotFoundError:
|
||||||
|
# 如果没有元数据文件,继续运行(feature_names_ 为 None)
|
||||||
|
pass
|
||||||
|
|
||||||
|
return instance
|
||||||
16
src/training/components/processors/__init__.py
Normal file
16
src/training/components/processors/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
"""数据处理器子模块
|
||||||
|
|
||||||
|
包含数据预处理、转换等处理器实现。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.training.components.processors.transforms import (
|
||||||
|
CrossSectionalStandardScaler,
|
||||||
|
StandardScaler,
|
||||||
|
Winsorizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"StandardScaler",
|
||||||
|
"CrossSectionalStandardScaler",
|
||||||
|
"Winsorizer",
|
||||||
|
]
|
||||||
275
src/training/components/processors/transforms.py
Normal file
275
src/training/components/processors/transforms.py
Normal file
@@ -0,0 +1,275 @@
|
|||||||
|
"""数据处理器实现
|
||||||
|
|
||||||
|
包含标准化、缩尾等数据处理器。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
from src.training.components.base import BaseProcessor
|
||||||
|
from src.training.registry import register_processor
|
||||||
|
|
||||||
|
|
||||||
|
@register_processor("standard_scaler")
|
||||||
|
class StandardScaler(BaseProcessor):
|
||||||
|
"""标准化处理器(全局标准化)
|
||||||
|
|
||||||
|
在整个训练集上学习均值和标准差,
|
||||||
|
然后应用到训练集和测试集。
|
||||||
|
|
||||||
|
适用于需要全局统计量的场景。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
exclude_cols: 不参与标准化的列名列表
|
||||||
|
mean_: 学习到的均值字典 {列名: 均值}
|
||||||
|
std_: 学习到的标准差字典 {列名: 标准差}
|
||||||
|
"""
|
||||||
|
|
||||||
|
name = "standard_scaler"
|
||||||
|
|
||||||
|
def __init__(self, exclude_cols: Optional[List[str]] = None):
|
||||||
|
"""初始化标准化处理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exclude_cols: 不参与标准化的列名列表,默认为 ["ts_code", "trade_date"]
|
||||||
|
"""
|
||||||
|
self.exclude_cols = exclude_cols or ["ts_code", "trade_date"]
|
||||||
|
self.mean_: dict = {}
|
||||||
|
self.std_: dict = {}
|
||||||
|
|
||||||
|
def fit(self, X: pl.DataFrame) -> "StandardScaler":
|
||||||
|
"""计算均值和标准差(仅在训练集上)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
X: 训练数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
self
|
||||||
|
"""
|
||||||
|
numeric_cols = [
|
||||||
|
c
|
||||||
|
for c in X.columns
|
||||||
|
if c not in self.exclude_cols and X[c].dtype.is_numeric()
|
||||||
|
]
|
||||||
|
|
||||||
|
for col in numeric_cols:
|
||||||
|
self.mean_[col] = X[col].mean()
|
||||||
|
self.std_[col] = X[col].std()
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def transform(self, X: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
"""标准化(使用训练集学到的参数)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
X: 待转换数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
标准化后的数据
|
||||||
|
"""
|
||||||
|
expressions = []
|
||||||
|
for col in X.columns:
|
||||||
|
if col in self.mean_ and col in self.std_:
|
||||||
|
# 避免除以0
|
||||||
|
std_val = self.std_[col] if self.std_[col] != 0 else 1.0
|
||||||
|
expr = ((pl.col(col) - self.mean_[col]) / std_val).alias(col)
|
||||||
|
expressions.append(expr)
|
||||||
|
else:
|
||||||
|
expressions.append(pl.col(col))
|
||||||
|
|
||||||
|
return X.select(expressions)
|
||||||
|
|
||||||
|
|
||||||
|
@register_processor("cs_standard_scaler")
|
||||||
|
class CrossSectionalStandardScaler(BaseProcessor):
|
||||||
|
"""截面标准化处理器
|
||||||
|
|
||||||
|
每天独立进行标准化:对当天所有股票的某一因子进行标准化。
|
||||||
|
|
||||||
|
特点:
|
||||||
|
- 不需要 fit,每天独立计算当天的均值和标准差
|
||||||
|
- 适用于截面因子,消除市值等行业差异
|
||||||
|
- 公式:z = (x - mean_today) / std_today
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
exclude_cols: 不参与标准化的列名列表
|
||||||
|
date_col: 日期列名
|
||||||
|
"""
|
||||||
|
|
||||||
|
name = "cs_standard_scaler"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
exclude_cols: Optional[List[str]] = None,
|
||||||
|
date_col: str = "trade_date",
|
||||||
|
):
|
||||||
|
"""初始化截面标准化处理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exclude_cols: 不参与标准化的列名列表,默认为 ["ts_code", "trade_date"]
|
||||||
|
date_col: 日期列名
|
||||||
|
"""
|
||||||
|
self.exclude_cols = exclude_cols or ["ts_code", "trade_date"]
|
||||||
|
self.date_col = date_col
|
||||||
|
|
||||||
|
def transform(self, X: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
"""截面标准化
|
||||||
|
|
||||||
|
按日期分组,每天独立计算均值和标准差并标准化。
|
||||||
|
不需要 fit,因为每天使用当天的统计量。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
X: 待转换数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
标准化后的数据
|
||||||
|
"""
|
||||||
|
numeric_cols = [
|
||||||
|
c
|
||||||
|
for c in X.columns
|
||||||
|
if c not in self.exclude_cols and X[c].dtype.is_numeric()
|
||||||
|
]
|
||||||
|
|
||||||
|
# 构建表达式列表
|
||||||
|
expressions = []
|
||||||
|
for col in X.columns:
|
||||||
|
if col in numeric_cols:
|
||||||
|
# 截面标准化:每天独立计算均值和标准差
|
||||||
|
# 避免除以0,当std为0时设为1
|
||||||
|
expr = (
|
||||||
|
(pl.col(col) - pl.col(col).mean().over(self.date_col))
|
||||||
|
/ (pl.col(col).std().over(self.date_col) + 1e-10)
|
||||||
|
).alias(col)
|
||||||
|
expressions.append(expr)
|
||||||
|
else:
|
||||||
|
expressions.append(pl.col(col))
|
||||||
|
|
||||||
|
return X.select(expressions)
|
||||||
|
|
||||||
|
|
||||||
|
@register_processor("winsorizer")
|
||||||
|
class Winsorizer(BaseProcessor):
|
||||||
|
"""缩尾处理器
|
||||||
|
|
||||||
|
对每一列的极端值进行截断处理。
|
||||||
|
可以全局截断(在整个训练集上学习分位数),
|
||||||
|
也可以截面截断(每天独立处理)。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
lower: 下分位数(如0.01表示1%分位数)
|
||||||
|
upper: 上分位数(如0.99表示99%分位数)
|
||||||
|
by_date: True=每天独立缩尾, False=全局缩尾
|
||||||
|
date_col: 日期列名
|
||||||
|
bounds_: 存储分位数边界(全局模式)
|
||||||
|
"""
|
||||||
|
|
||||||
|
name = "winsorizer"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
lower: float = 0.01,
|
||||||
|
upper: float = 0.99,
|
||||||
|
by_date: bool = False,
|
||||||
|
date_col: str = "trade_date",
|
||||||
|
):
|
||||||
|
"""初始化缩尾处理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lower: 下分位数,默认0.01
|
||||||
|
upper: 上分位数,默认0.99
|
||||||
|
by_date: 每天独立缩尾,默认False(全局缩尾)
|
||||||
|
date_col: 日期列名
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 分位数参数无效
|
||||||
|
"""
|
||||||
|
if not 0 <= lower < upper <= 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"lower ({lower}) 必须小于 upper ({upper}),且都在 [0, 1] 范围内"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.lower = lower
|
||||||
|
self.upper = upper
|
||||||
|
self.by_date = by_date
|
||||||
|
self.date_col = date_col
|
||||||
|
self.bounds_: dict = {}
|
||||||
|
|
||||||
|
def fit(self, X: pl.DataFrame) -> "Winsorizer":
|
||||||
|
"""学习分位数边界(仅在全局模式下)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
X: 训练数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
self
|
||||||
|
"""
|
||||||
|
if not self.by_date:
|
||||||
|
numeric_cols = [c for c in X.columns if X[c].dtype.is_numeric()]
|
||||||
|
for col in numeric_cols:
|
||||||
|
self.bounds_[col] = {
|
||||||
|
"lower": X[col].quantile(self.lower),
|
||||||
|
"upper": X[col].quantile(self.upper),
|
||||||
|
}
|
||||||
|
return self
|
||||||
|
|
||||||
|
def transform(self, X: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
"""缩尾处理
|
||||||
|
|
||||||
|
Args:
|
||||||
|
X: 待转换数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
缩尾处理后的数据
|
||||||
|
"""
|
||||||
|
if self.by_date:
|
||||||
|
return self._transform_by_date(X)
|
||||||
|
else:
|
||||||
|
return self._transform_global(X)
|
||||||
|
|
||||||
|
def _transform_global(self, X: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
"""全局缩尾(使用训练集学到的边界)"""
|
||||||
|
expressions = []
|
||||||
|
for col in X.columns:
|
||||||
|
if col in self.bounds_:
|
||||||
|
lower = self.bounds_[col]["lower"]
|
||||||
|
upper = self.bounds_[col]["upper"]
|
||||||
|
expr = pl.col(col).clip(lower, upper).alias(col)
|
||||||
|
expressions.append(expr)
|
||||||
|
else:
|
||||||
|
expressions.append(pl.col(col))
|
||||||
|
return X.select(expressions)
|
||||||
|
|
||||||
|
def _transform_by_date(self, X: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
"""每日独立缩尾"""
|
||||||
|
numeric_cols = [c for c in X.columns if X[c].dtype.is_numeric()]
|
||||||
|
|
||||||
|
# 先计算每天的分位数
|
||||||
|
lower_exprs = [
|
||||||
|
pl.col(col).quantile(self.lower).over(self.date_col).alias(f"{col}_lower")
|
||||||
|
for col in numeric_cols
|
||||||
|
]
|
||||||
|
upper_exprs = [
|
||||||
|
pl.col(col).quantile(self.upper).over(self.date_col).alias(f"{col}_upper")
|
||||||
|
for col in numeric_cols
|
||||||
|
]
|
||||||
|
|
||||||
|
# 添加分位数列
|
||||||
|
result = X.with_columns(lower_exprs + upper_exprs)
|
||||||
|
|
||||||
|
# 执行缩尾
|
||||||
|
clip_exprs = []
|
||||||
|
for col in X.columns:
|
||||||
|
if col in numeric_cols:
|
||||||
|
clipped = (
|
||||||
|
pl.col(col)
|
||||||
|
.clip(pl.col(f"{col}_lower"), pl.col(f"{col}_upper"))
|
||||||
|
.alias(col)
|
||||||
|
)
|
||||||
|
clip_exprs.append(clipped)
|
||||||
|
else:
|
||||||
|
clip_exprs.append(pl.col(col))
|
||||||
|
|
||||||
|
result = result.select(clip_exprs)
|
||||||
|
|
||||||
|
return result
|
||||||
81
src/training/components/selectors.py
Normal file
81
src/training/components/selectors.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
"""股票池选择器配置
|
||||||
|
|
||||||
|
提供股票过滤和市值选择的配置类。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StockFilterConfig:
|
||||||
|
"""股票过滤器配置
|
||||||
|
|
||||||
|
用于过滤掉不需要的股票(如创业板、科创板等)。
|
||||||
|
基于股票代码进行过滤,不依赖外部数据。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
exclude_cyb: 是否排除创业板(300xxx)
|
||||||
|
exclude_kcb: 是否排除科创板(688xxx)
|
||||||
|
exclude_bj: 是否排除北交所(8xxxxxx, 4xxxxxx)
|
||||||
|
exclude_st: 是否排除ST股票(需要外部数据支持)
|
||||||
|
"""
|
||||||
|
|
||||||
|
exclude_cyb: bool = True
|
||||||
|
exclude_kcb: bool = True
|
||||||
|
exclude_bj: bool = True
|
||||||
|
exclude_st: bool = True
|
||||||
|
|
||||||
|
def filter_codes(self, codes: List[str]) -> List[str]:
|
||||||
|
"""应用过滤条件,返回过滤后的股票代码列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
codes: 原始股票代码列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
过滤后的股票代码列表
|
||||||
|
|
||||||
|
Note:
|
||||||
|
ST 股票过滤需要额外数据,在 StockPoolManager 中处理。
|
||||||
|
此方法仅基于代码前缀进行过滤。
|
||||||
|
"""
|
||||||
|
result = []
|
||||||
|
for code in codes:
|
||||||
|
# 排除创业板(300xxx)
|
||||||
|
if self.exclude_cyb and code.startswith("300"):
|
||||||
|
continue
|
||||||
|
# 排除科创板(688xxx)
|
||||||
|
if self.exclude_kcb and code.startswith("688"):
|
||||||
|
continue
|
||||||
|
# 排除北交所(8xxxxxx 或 4xxxxxx)
|
||||||
|
if self.exclude_bj and (code.startswith("8") or code.startswith("4")):
|
||||||
|
continue
|
||||||
|
result.append(code)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MarketCapSelectorConfig:
|
||||||
|
"""市值选择器配置
|
||||||
|
|
||||||
|
每日独立选择市值最大或最小的 n 只股票。
|
||||||
|
市值数据从 daily_basic 表独立获取,仅用于筛选。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
enabled: 是否启用选择
|
||||||
|
n: 选择前 n 只
|
||||||
|
ascending: False=最大市值, True=最小市值
|
||||||
|
market_cap_col: 市值列名(来自 daily_basic)
|
||||||
|
"""
|
||||||
|
|
||||||
|
enabled: bool = True
|
||||||
|
n: int = 100
|
||||||
|
ascending: bool = False
|
||||||
|
market_cap_col: str = "total_mv"
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
"""验证配置参数"""
|
||||||
|
if self.n <= 0:
|
||||||
|
raise ValueError(f"n 必须是正整数,得到: {self.n}")
|
||||||
|
if not self.market_cap_col:
|
||||||
|
raise ValueError("market_cap_col 不能为空")
|
||||||
122
src/training/components/splitters.py
Normal file
122
src/training/components/splitters.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
"""数据划分器
|
||||||
|
|
||||||
|
提供基于日期范围的数据划分功能,支持一次性训练/测试划分。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
|
||||||
|
class DateSplitter:
|
||||||
|
"""基于日期范围的一次性划分
|
||||||
|
|
||||||
|
将数据按日期划分为训练集和测试集,不滚动。
|
||||||
|
|
||||||
|
示例:
|
||||||
|
train_start: "20200101", train_end: "20221231" (训练集:3年)
|
||||||
|
test_start: "20230101", test_end: "20231231" (测试集:1年)
|
||||||
|
|
||||||
|
特点:
|
||||||
|
- 一次性划分,不滚动
|
||||||
|
- 训练集和测试集互不重叠
|
||||||
|
- 基于实际日期范围,而非行数
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
train_start: 训练期开始日期,格式 "YYYYMMDD"
|
||||||
|
train_end: 训练期结束日期,格式 "YYYYMMDD"
|
||||||
|
test_start: 测试期开始日期,格式 "YYYYMMDD"
|
||||||
|
test_end: 测试期结束日期,格式 "YYYYMMDD"
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
train_start: str,
|
||||||
|
train_end: str,
|
||||||
|
test_start: str,
|
||||||
|
test_end: str,
|
||||||
|
):
|
||||||
|
"""初始化日期划分器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_start: 训练期开始日期 "YYYYMMDD"
|
||||||
|
train_end: 训练期结束日期 "YYYYMMDD"
|
||||||
|
test_start: 测试期开始日期 "YYYYMMDD"
|
||||||
|
test_end: 测试期结束日期 "YYYYMMDD"
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 日期格式错误或日期范围无效
|
||||||
|
"""
|
||||||
|
# 验证日期格式(简单的长度检查)
|
||||||
|
for name, value in [
|
||||||
|
("train_start", train_start),
|
||||||
|
("train_end", train_end),
|
||||||
|
("test_start", test_start),
|
||||||
|
("test_end", test_end),
|
||||||
|
]:
|
||||||
|
if not isinstance(value, str) or len(value) != 8:
|
||||||
|
raise ValueError(
|
||||||
|
f"{name} 必须是格式为 'YYYYMMDD' 的8位字符串,得到: {value}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 验证日期范围逻辑
|
||||||
|
if train_start > train_end:
|
||||||
|
raise ValueError(
|
||||||
|
f"train_start ({train_start}) 必须早于或等于 train_end ({train_end})"
|
||||||
|
)
|
||||||
|
if test_start > test_end:
|
||||||
|
raise ValueError(
|
||||||
|
f"test_start ({test_start}) 必须早于或等于 test_end ({test_end})"
|
||||||
|
)
|
||||||
|
if test_start <= train_end:
|
||||||
|
raise ValueError(
|
||||||
|
f"测试集开始日期 ({test_start}) 必须晚于训练集结束日期 ({train_end}),"
|
||||||
|
"以确保训练集和测试集不重叠"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.train_start = train_start
|
||||||
|
self.train_end = train_end
|
||||||
|
self.test_start = test_start
|
||||||
|
self.test_end = test_end
|
||||||
|
|
||||||
|
def split(
|
||||||
|
self, data: pl.DataFrame, date_col: str = "trade_date"
|
||||||
|
) -> Tuple[pl.DataFrame, pl.DataFrame]:
|
||||||
|
"""划分数据为训练集和测试集
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 输入数据,必须包含日期列
|
||||||
|
date_col: 日期列名,默认为 "trade_date"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(train_data, test_data) 元组
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 数据中不包含指定的日期列
|
||||||
|
"""
|
||||||
|
if date_col not in data.columns:
|
||||||
|
raise ValueError(f"数据中不包含列 '{date_col}',可用列: {data.columns}")
|
||||||
|
|
||||||
|
# 筛选训练集数据
|
||||||
|
train_data = data.filter(
|
||||||
|
(pl.col(date_col) >= self.train_start)
|
||||||
|
& (pl.col(date_col) <= self.train_end)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 筛选测试集数据
|
||||||
|
test_data = data.filter(
|
||||||
|
(pl.col(date_col) >= self.test_start) & (pl.col(date_col) <= self.test_end)
|
||||||
|
)
|
||||||
|
|
||||||
|
return train_data, test_data
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
"""返回划分器的字符串表示"""
|
||||||
|
return (
|
||||||
|
f"DateSplitter("
|
||||||
|
f"train_start='{self.train_start}', "
|
||||||
|
f"train_end='{self.train_end}', "
|
||||||
|
f"test_start='{self.test_start}', "
|
||||||
|
f"test_end='{self.test_end}'"
|
||||||
|
f")"
|
||||||
|
)
|
||||||
18
src/training/config/__init__.py
Normal file
18
src/training/config/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
"""训练配置管理
|
||||||
|
|
||||||
|
提供 TrainingConfig 配置类和相关配置数据类。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.training.config.config import (
|
||||||
|
MarketCapSelectorConfig,
|
||||||
|
ProcessorConfig,
|
||||||
|
StockFilterConfig,
|
||||||
|
TrainingConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"TrainingConfig",
|
||||||
|
"StockFilterConfig",
|
||||||
|
"MarketCapSelectorConfig",
|
||||||
|
"ProcessorConfig",
|
||||||
|
]
|
||||||
141
src/training/config/config.py
Normal file
141
src/training/config/config.py
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
"""训练配置管理
|
||||||
|
|
||||||
|
提供 TrainingConfig 配置类,使用 pydantic 进行参数验证。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from pydantic import Field, validator
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StockFilterConfig:
|
||||||
|
"""股票过滤器配置"""
|
||||||
|
|
||||||
|
exclude_cyb: bool = True # 排除创业板
|
||||||
|
exclude_kcb: bool = True # 排除科创板
|
||||||
|
exclude_bj: bool = True # 排除北交所
|
||||||
|
exclude_st: bool = True # 排除ST股票
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MarketCapSelectorConfig:
|
||||||
|
"""市值选择器配置"""
|
||||||
|
|
||||||
|
enabled: bool = True # 是否启用
|
||||||
|
n: int = 100 # 选择前 n 只
|
||||||
|
ascending: bool = False # False=最大市值, True=最小市值
|
||||||
|
market_cap_col: str = "total_mv" # 市值列名(来自 daily_basic)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ProcessorConfig:
|
||||||
|
"""处理器配置"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
params: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class TrainingConfig(BaseSettings):
|
||||||
|
"""训练配置类
|
||||||
|
|
||||||
|
所有配置参数通过此类管理,支持 pydantic 验证。
|
||||||
|
"""
|
||||||
|
|
||||||
|
# === 数据配置(必填)===
|
||||||
|
feature_cols: List[str] = Field(..., min_items=1) # 特征列名,至少一个
|
||||||
|
target_col: str = "target" # 目标变量列名
|
||||||
|
date_col: str = "trade_date" # 日期列名
|
||||||
|
code_col: str = "ts_code" # 股票代码列名
|
||||||
|
|
||||||
|
# === 日期划分(必填)===
|
||||||
|
train_start: str = Field(..., description="训练期开始 YYYYMMDD")
|
||||||
|
train_end: str = Field(..., description="训练期结束 YYYYMMDD")
|
||||||
|
test_start: str = Field(..., description="测试期开始 YYYYMMDD")
|
||||||
|
test_end: str = Field(..., description="测试期结束 YYYYMMDD")
|
||||||
|
|
||||||
|
# === 股票池配置 ===
|
||||||
|
stock_filter: StockFilterConfig = Field(
|
||||||
|
default_factory=lambda: StockFilterConfig(
|
||||||
|
exclude_cyb=True,
|
||||||
|
exclude_kcb=True,
|
||||||
|
exclude_bj=True,
|
||||||
|
exclude_st=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
stock_selector: Optional[MarketCapSelectorConfig] = Field(
|
||||||
|
default_factory=lambda: MarketCapSelectorConfig(
|
||||||
|
enabled=True,
|
||||||
|
n=100,
|
||||||
|
ascending=False,
|
||||||
|
market_cap_col="total_mv",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# 注意:如果 stock_selector = None,则跳过市值选择
|
||||||
|
|
||||||
|
# === 模型配置 ===
|
||||||
|
model_type: str = "lightgbm"
|
||||||
|
model_params: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
# === 处理器配置 ===
|
||||||
|
processors: List[ProcessorConfig] = Field(
|
||||||
|
default_factory=lambda: [
|
||||||
|
ProcessorConfig(name="winsorizer", params={"lower": 0.01, "upper": 0.99}),
|
||||||
|
ProcessorConfig(name="cs_standard_scaler", params={}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# === 持久化配置 ===
|
||||||
|
persist_model: bool = False # 默认不持久化
|
||||||
|
model_save_path: Optional[str] = None # 持久化路径
|
||||||
|
|
||||||
|
# === 输出配置 ===
|
||||||
|
output_dir: str = "output"
|
||||||
|
save_predictions: bool = True
|
||||||
|
|
||||||
|
@validator("train_start", "train_end", "test_start", "test_end")
|
||||||
|
def validate_date_format(cls, v: str) -> str:
|
||||||
|
"""验证日期格式为 YYYYMMDD"""
|
||||||
|
if not isinstance(v, str) or len(v) != 8:
|
||||||
|
raise ValueError(f"日期必须是格式为 'YYYYMMDD' 的8位字符串,得到: {v}")
|
||||||
|
try:
|
||||||
|
int(v)
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(f"日期必须是数字字符串,得到: {v}")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@validator("train_end")
|
||||||
|
def validate_train_dates(cls, v: str, values: Dict[str, Any]) -> str:
|
||||||
|
"""验证训练日期范围"""
|
||||||
|
if "train_start" in values and values["train_start"] > v:
|
||||||
|
raise ValueError(
|
||||||
|
f"train_start ({values['train_start']}) 必须早于或等于 train_end ({v})"
|
||||||
|
)
|
||||||
|
return v
|
||||||
|
|
||||||
|
@validator("test_end")
|
||||||
|
def validate_test_dates(cls, v: str, values: Dict[str, Any]) -> str:
|
||||||
|
"""验证测试日期范围"""
|
||||||
|
if "test_start" in values and values["test_start"] > v:
|
||||||
|
raise ValueError(
|
||||||
|
f"test_start ({values['test_start']}) 必须早于或等于 test_end ({v})"
|
||||||
|
)
|
||||||
|
return v
|
||||||
|
|
||||||
|
@validator("test_start")
|
||||||
|
def validate_no_overlap(cls, v: str, values: Dict[str, Any]) -> str:
|
||||||
|
"""验证训练集和测试集不重叠"""
|
||||||
|
if "train_end" in values and v <= values["train_end"]:
|
||||||
|
raise ValueError(
|
||||||
|
f"测试集开始日期 ({v}) 必须晚于训练集结束日期 ({values['train_end']}),"
|
||||||
|
"以确保训练集和测试集不重叠"
|
||||||
|
)
|
||||||
|
return v
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Pydantic 配置"""
|
||||||
|
|
||||||
|
env_prefix = "TRAINING_" # 环境变量前缀
|
||||||
|
env_nested_delimiter = "__"
|
||||||
9
src/training/core/__init__.py
Normal file
9
src/training/core/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
"""训练核心模块
|
||||||
|
|
||||||
|
包含 Trainer 主类和股票池管理器。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.training.core.stock_pool_manager import StockPoolManager
|
||||||
|
from src.training.core.trainer import Trainer
|
||||||
|
|
||||||
|
__all__ = ["StockPoolManager", "Trainer"]
|
||||||
171
src/training/core/stock_pool_manager.py
Normal file
171
src/training/core/stock_pool_manager.py
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
"""股票池管理器
|
||||||
|
|
||||||
|
每日独立筛选股票池,市值数据从 daily_basic 表独立获取。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||||
|
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
from src.training.components.selectors import MarketCapSelectorConfig, StockFilterConfig
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.factors.engine.data_router import DataRouter
|
||||||
|
|
||||||
|
|
||||||
|
class StockPoolManager:
|
||||||
|
"""股票池管理器 - 每日独立筛选
|
||||||
|
|
||||||
|
重要约束:
|
||||||
|
1. 市值数据仅从 daily_basic 表获取,仅用于筛选
|
||||||
|
2. 市值数据绝不混入特征矩阵
|
||||||
|
3. 每日独立筛选(市值是动态变化的)
|
||||||
|
|
||||||
|
处理流程(每日):
|
||||||
|
当日所有股票
|
||||||
|
↓
|
||||||
|
代码过滤(创业板、ST等)
|
||||||
|
↓
|
||||||
|
查询 daily_basic 获取当日市值
|
||||||
|
↓
|
||||||
|
市值选择(前N只)
|
||||||
|
↓
|
||||||
|
返回当日选中股票列表
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
filter_config: StockFilterConfig,
|
||||||
|
selector_config: Optional[MarketCapSelectorConfig],
|
||||||
|
data_router: "DataRouter",
|
||||||
|
code_col: str = "ts_code",
|
||||||
|
date_col: str = "trade_date",
|
||||||
|
):
|
||||||
|
"""初始化股票池管理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filter_config: 股票过滤器配置
|
||||||
|
selector_config: 市值选择器配置,None 表示跳过市值选择
|
||||||
|
data_router: 数据路由器,用于获取 daily_basic 数据
|
||||||
|
code_col: 股票代码列名
|
||||||
|
date_col: 日期列名
|
||||||
|
"""
|
||||||
|
self.filter_config = filter_config
|
||||||
|
self.selector_config = selector_config
|
||||||
|
self.data_router = data_router
|
||||||
|
self.code_col = code_col
|
||||||
|
self.date_col = date_col
|
||||||
|
|
||||||
|
def filter_and_select_daily(self, data: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
"""每日独立筛选股票池
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 因子计算后的全市场数据,必须包含 trade_date 和 ts_code 列
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
筛选后的数据,仅包含每日选中的股票
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- 按日期分组处理
|
||||||
|
- 市值数据从 daily_basic 独立获取
|
||||||
|
- 保持市值数据与特征数据隔离
|
||||||
|
"""
|
||||||
|
dates = data.select(self.date_col).unique().sort(self.date_col)
|
||||||
|
|
||||||
|
result_frames = []
|
||||||
|
for date in dates.to_series():
|
||||||
|
# 获取当日数据
|
||||||
|
daily_data = data.filter(pl.col(self.date_col) == date)
|
||||||
|
daily_codes = daily_data.select(self.code_col).to_series().to_list()
|
||||||
|
|
||||||
|
# 1. 代码过滤
|
||||||
|
filtered_codes = self.filter_config.filter_codes(daily_codes)
|
||||||
|
|
||||||
|
# 2. 市值选择(如果启用)
|
||||||
|
if self.selector_config and self.selector_config.enabled:
|
||||||
|
# 从 daily_basic 获取当日市值
|
||||||
|
market_caps = self._get_market_caps_for_date(filtered_codes, date)
|
||||||
|
selected_codes = self._select_by_market_cap(filtered_codes, market_caps)
|
||||||
|
else:
|
||||||
|
selected_codes = filtered_codes
|
||||||
|
|
||||||
|
# 3. 保留当日选中的股票数据
|
||||||
|
daily_selected = daily_data.filter(
|
||||||
|
pl.col(self.code_col).is_in(selected_codes)
|
||||||
|
)
|
||||||
|
result_frames.append(daily_selected)
|
||||||
|
|
||||||
|
return pl.concat(result_frames)
|
||||||
|
|
||||||
|
def _get_market_caps_for_date(
|
||||||
|
self, codes: List[str], date: str
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
"""从 daily_basic 表获取指定日期的市值数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
codes: 股票代码列表
|
||||||
|
date: 日期 "YYYYMMDD"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
{股票代码: 市值} 的字典
|
||||||
|
"""
|
||||||
|
if not codes:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
assert self.selector_config is not None, (
|
||||||
|
"selector_config should not be None when calling _get_market_caps_for_date"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 通过 data_router 查询 daily_basic 表
|
||||||
|
from src.factors.engine.data_spec import DataSpec
|
||||||
|
|
||||||
|
data_specs = [
|
||||||
|
DataSpec("daily_basic", [self.selector_config.market_cap_col])
|
||||||
|
]
|
||||||
|
df = self.data_router.fetch_data(
|
||||||
|
data_specs=data_specs,
|
||||||
|
start_date=date,
|
||||||
|
end_date=date,
|
||||||
|
stock_codes=codes,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 转换为字典
|
||||||
|
market_caps = {}
|
||||||
|
for row in df.iter_rows(named=True):
|
||||||
|
code = row[self.code_col]
|
||||||
|
cap = row.get(self.selector_config.market_cap_col)
|
||||||
|
if cap is not None and code in codes:
|
||||||
|
market_caps[code] = float(cap)
|
||||||
|
|
||||||
|
return market_caps
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[警告] 获取 {date} 市值数据失败: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def _select_by_market_cap(
|
||||||
|
self, codes: List[str], market_caps: Dict[str, float]
|
||||||
|
) -> List[str]:
|
||||||
|
"""根据市值选择股票
|
||||||
|
|
||||||
|
Args:
|
||||||
|
codes: 股票代码列表
|
||||||
|
market_caps: 市值数据字典
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
选中的股票代码列表
|
||||||
|
"""
|
||||||
|
if self.selector_config is None:
|
||||||
|
return codes
|
||||||
|
|
||||||
|
if not market_caps:
|
||||||
|
return codes[: self.selector_config.n]
|
||||||
|
|
||||||
|
# 按市值排序并选择前N只
|
||||||
|
sorted_codes = sorted(
|
||||||
|
codes,
|
||||||
|
key=lambda c: market_caps.get(c, 0),
|
||||||
|
reverse=not self.selector_config.ascending,
|
||||||
|
)
|
||||||
|
return sorted_codes[: self.selector_config.n]
|
||||||
179
src/training/core/trainer.py
Normal file
179
src/training/core/trainer.py
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
"""训练器主类
|
||||||
|
|
||||||
|
整合数据处理、模型训练、预测的完整流程。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
from src.training.components.base import BaseModel, BaseProcessor
|
||||||
|
from src.training.components.splitters import DateSplitter
|
||||||
|
from src.training.core.stock_pool_manager import StockPoolManager
|
||||||
|
|
||||||
|
|
||||||
|
class Trainer:
|
||||||
|
"""训练器主类
|
||||||
|
|
||||||
|
整合数据处理、模型训练、预测的完整流程。
|
||||||
|
|
||||||
|
关键设计:
|
||||||
|
1. 因子先计算(全市场),再筛选股票池(每日独立)
|
||||||
|
2. Processor 分阶段行为:训练集 fit_transform,测试集 transform
|
||||||
|
3. 一次性训练,不滚动
|
||||||
|
4. 支持模型持久化
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: BaseModel,
|
||||||
|
pool_manager: Optional[StockPoolManager] = None,
|
||||||
|
processors: Optional[List[BaseProcessor]] = None,
|
||||||
|
splitter: Optional[DateSplitter] = None,
|
||||||
|
target_col: str = "target",
|
||||||
|
feature_cols: Optional[List[str]] = None,
|
||||||
|
persist_model: bool = False,
|
||||||
|
model_save_path: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""初始化训练器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: 模型实例
|
||||||
|
pool_manager: 股票池管理器,None 表示不筛选
|
||||||
|
processors: 数据处理器列表
|
||||||
|
splitter: 数据划分器
|
||||||
|
target_col: 目标变量列名
|
||||||
|
feature_cols: 特征列名列表
|
||||||
|
persist_model: 是否保存模型
|
||||||
|
model_save_path: 模型保存路径
|
||||||
|
"""
|
||||||
|
self.model = model
|
||||||
|
self.pool_manager = pool_manager
|
||||||
|
self.processors = processors or []
|
||||||
|
self.splitter = splitter
|
||||||
|
self.target_col = target_col
|
||||||
|
self.feature_cols = feature_cols or []
|
||||||
|
self.persist_model = persist_model
|
||||||
|
self.model_save_path = model_save_path
|
||||||
|
|
||||||
|
# 存储训练后的处理器
|
||||||
|
self.fitted_processors: List[BaseProcessor] = []
|
||||||
|
self.results: Optional[pl.DataFrame] = None
|
||||||
|
|
||||||
|
def train(self, data: pl.DataFrame) -> "Trainer":
|
||||||
|
"""执行训练流程
|
||||||
|
|
||||||
|
流程:
|
||||||
|
1. 股票池每日筛选(如果配置了 pool_manager)
|
||||||
|
2. 按日期划分训练集/测试集
|
||||||
|
3. 训练集:processors fit_transform
|
||||||
|
4. 训练模型
|
||||||
|
5. 测试集:processors transform(使用训练集学到的参数)
|
||||||
|
6. 预测
|
||||||
|
7. 保存结果
|
||||||
|
8. 持久化模型(如果启用)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 因子计算后的全市场数据
|
||||||
|
必须包含 ts_code 和 trade_date 列
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
self (支持链式调用)
|
||||||
|
"""
|
||||||
|
# 1. 股票池筛选(每日独立)
|
||||||
|
if self.pool_manager:
|
||||||
|
print("[筛选] 每日独立筛选股票池...")
|
||||||
|
data = self.pool_manager.filter_and_select_daily(data)
|
||||||
|
|
||||||
|
# 2. 划分训练/测试集
|
||||||
|
if self.splitter:
|
||||||
|
print("[划分] 划分训练集和测试集...")
|
||||||
|
train_data, test_data = self.splitter.split(data)
|
||||||
|
else:
|
||||||
|
# 没有划分器,全部作为训练集
|
||||||
|
train_data = data
|
||||||
|
test_data = data
|
||||||
|
|
||||||
|
# 3. 训练集:processors fit_transform
|
||||||
|
if self.processors:
|
||||||
|
print("[处理] 处理训练集...")
|
||||||
|
for processor in self.processors:
|
||||||
|
train_data = processor.fit_transform(train_data)
|
||||||
|
self.fitted_processors.append(processor)
|
||||||
|
|
||||||
|
# 4. 训练模型
|
||||||
|
print("[训练] 训练模型...")
|
||||||
|
if not self.feature_cols:
|
||||||
|
raise ValueError("feature_cols 不能为空")
|
||||||
|
|
||||||
|
X_train = train_data.select(self.feature_cols)
|
||||||
|
y_train = train_data.select(self.target_col).to_series()
|
||||||
|
self.model.fit(X_train, y_train)
|
||||||
|
|
||||||
|
# 5. 测试集:processors transform
|
||||||
|
if self.processors and test_data is not train_data:
|
||||||
|
print("[处理] 处理测试集...")
|
||||||
|
for processor in self.fitted_processors:
|
||||||
|
test_data = processor.transform(test_data)
|
||||||
|
|
||||||
|
# 6. 预测
|
||||||
|
print("[预测] 生成预测...")
|
||||||
|
X_test = test_data.select(self.feature_cols)
|
||||||
|
predictions = self.model.predict(X_test)
|
||||||
|
|
||||||
|
# 7. 保存结果
|
||||||
|
self.results = test_data.with_columns([pl.Series("prediction", predictions)])
|
||||||
|
|
||||||
|
# 8. 持久化模型
|
||||||
|
if self.persist_model and self.model_save_path:
|
||||||
|
print(f"[保存] 保存模型到 {self.model_save_path}...")
|
||||||
|
self.save_model(self.model_save_path)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def predict(self, data: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
"""对新数据进行预测
|
||||||
|
|
||||||
|
注意:新数据需要先经过股票池筛选,
|
||||||
|
然后使用训练好的 processors 和 model 进行预测。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 输入数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含预测列的数据
|
||||||
|
"""
|
||||||
|
# 应用 processors
|
||||||
|
for processor in self.fitted_processors:
|
||||||
|
data = processor.transform(data)
|
||||||
|
|
||||||
|
# 预测
|
||||||
|
X = data.select(self.feature_cols)
|
||||||
|
predictions = self.model.predict(X)
|
||||||
|
|
||||||
|
return data.with_columns([pl.Series("prediction", predictions)])
|
||||||
|
|
||||||
|
def get_results(self) -> Optional[pl.DataFrame]:
|
||||||
|
"""获取所有预测结果
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
预测结果 DataFrame,包含原始列和 prediction 列
|
||||||
|
"""
|
||||||
|
return self.results
|
||||||
|
|
||||||
|
def save_results(self, path: str) -> None:
|
||||||
|
"""保存预测结果到文件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: 保存路径(CSV 格式)
|
||||||
|
"""
|
||||||
|
if self.results is not None:
|
||||||
|
self.results.write_csv(path)
|
||||||
|
|
||||||
|
def save_model(self, path: str) -> None:
|
||||||
|
"""保存模型
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: 模型保存路径
|
||||||
|
"""
|
||||||
|
self.model.save(path)
|
||||||
235
tests/training/test_lightgbm_model.py
Normal file
235
tests/training/test_lightgbm_model.py
Normal file
@@ -0,0 +1,235 @@
|
|||||||
|
"""测试 LightGBM 模型
|
||||||
|
|
||||||
|
验证 LightGBMModel 的训练、预测、保存和加载功能。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import polars as pl
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.training.components.models.lightgbm import LightGBMModel
|
||||||
|
|
||||||
|
|
||||||
|
class TestLightGBMModel:
|
||||||
|
"""LightGBMModel 测试类"""
|
||||||
|
|
||||||
|
def test_init_default(self):
|
||||||
|
"""测试默认初始化"""
|
||||||
|
model = LightGBMModel()
|
||||||
|
assert model.name == "lightgbm"
|
||||||
|
assert model.params["objective"] == "regression"
|
||||||
|
assert model.params["metric"] == "rmse"
|
||||||
|
assert model.params["num_leaves"] == 31
|
||||||
|
assert model.params["learning_rate"] == 0.05
|
||||||
|
assert model.n_estimators == 100
|
||||||
|
assert model.model is None
|
||||||
|
|
||||||
|
def test_init_custom(self):
|
||||||
|
"""测试自定义参数"""
|
||||||
|
model = LightGBMModel(
|
||||||
|
objective="huber",
|
||||||
|
metric="mae",
|
||||||
|
num_leaves=50,
|
||||||
|
learning_rate=0.1,
|
||||||
|
n_estimators=200,
|
||||||
|
)
|
||||||
|
assert model.params["objective"] == "huber"
|
||||||
|
assert model.params["metric"] == "mae"
|
||||||
|
assert model.params["num_leaves"] == 50
|
||||||
|
assert model.params["learning_rate"] == 0.1
|
||||||
|
assert model.n_estimators == 200
|
||||||
|
|
||||||
|
def test_fit_success(self):
|
||||||
|
"""测试正常训练"""
|
||||||
|
# 创建简单回归数据
|
||||||
|
X = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"feature1": [1.0, 2.0, 3.0, 4.0, 5.0],
|
||||||
|
"feature2": [2.0, 4.0, 6.0, 8.0, 10.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
y = pl.Series("target", [1.5, 3.0, 4.5, 6.0, 7.5])
|
||||||
|
|
||||||
|
model = LightGBMModel(n_estimators=10)
|
||||||
|
result = model.fit(X, y)
|
||||||
|
|
||||||
|
# 验证返回 self(支持链式调用)
|
||||||
|
assert result is model
|
||||||
|
# 验证模型已训练
|
||||||
|
assert model.model is not None
|
||||||
|
# 验证特征名称已保存
|
||||||
|
assert model.feature_names_ == ["feature1", "feature2"]
|
||||||
|
|
||||||
|
def test_predict_before_fit(self):
|
||||||
|
"""测试未训练就预测"""
|
||||||
|
X = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"feature1": [1.0, 2.0],
|
||||||
|
"feature2": [2.0, 4.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
model = LightGBMModel()
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="模型尚未训练"):
|
||||||
|
model.predict(X)
|
||||||
|
|
||||||
|
def test_predict_success(self):
|
||||||
|
"""测试正常预测"""
|
||||||
|
# 创建回归数据
|
||||||
|
np.random.seed(42)
|
||||||
|
n_samples = 100
|
||||||
|
X_train = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"feature1": np.random.randn(n_samples),
|
||||||
|
"feature2": np.random.randn(n_samples),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# y = 2*feature1 + 3*feature2 + noise
|
||||||
|
y_train = pl.Series(
|
||||||
|
"target",
|
||||||
|
2 * X_train["feature1"]
|
||||||
|
+ 3 * X_train["feature2"]
|
||||||
|
+ np.random.randn(n_samples) * 0.1,
|
||||||
|
)
|
||||||
|
|
||||||
|
model = LightGBMModel(n_estimators=20, learning_rate=0.1)
|
||||||
|
model.fit(X_train, y_train)
|
||||||
|
|
||||||
|
# 预测新数据(使用明显不同的值)
|
||||||
|
X_test = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"feature1": [-2.0, 3.0],
|
||||||
|
"feature2": [-1.0, 4.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
predictions = model.predict(X_test)
|
||||||
|
|
||||||
|
# 验证预测结果格式
|
||||||
|
assert isinstance(predictions, np.ndarray)
|
||||||
|
assert len(predictions) == 2
|
||||||
|
# 验证预测值是数值
|
||||||
|
assert all(np.isfinite(predictions))
|
||||||
|
# 验证单调性(第二个样本的 feature 值更大,预测值也应更大)
|
||||||
|
assert predictions[1] > predictions[0]
|
||||||
|
|
||||||
|
def test_feature_importance_before_fit(self):
|
||||||
|
"""测试未训练就获取特征重要性"""
|
||||||
|
model = LightGBMModel()
|
||||||
|
assert model.feature_importance() is None
|
||||||
|
|
||||||
|
def test_feature_importance_after_fit(self):
|
||||||
|
"""测试训练后获取特征重要性"""
|
||||||
|
X = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"feature1": np.random.randn(100),
|
||||||
|
"feature2": np.random.randn(100),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
y = pl.Series("target", X["feature1"] * 2 + X["feature2"] * 0.1)
|
||||||
|
|
||||||
|
model = LightGBMModel(n_estimators=10)
|
||||||
|
model.fit(X, y)
|
||||||
|
|
||||||
|
importance = model.feature_importance()
|
||||||
|
|
||||||
|
# 验证特征重要性格式
|
||||||
|
assert importance is not None
|
||||||
|
assert len(importance) == 2
|
||||||
|
assert "feature1" in importance.index
|
||||||
|
assert "feature2" in importance.index
|
||||||
|
# feature1 的系数更大,重要性应该更高
|
||||||
|
assert importance["feature1"] >= importance["feature2"]
|
||||||
|
|
||||||
|
def test_save_before_fit(self):
|
||||||
|
"""测试未训练就保存"""
|
||||||
|
model = LightGBMModel()
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="模型尚未训练"):
|
||||||
|
model.save("dummy.txt")
|
||||||
|
|
||||||
|
def test_save_and_load(self):
|
||||||
|
"""测试保存和加载"""
|
||||||
|
# 训练模型
|
||||||
|
X = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"feature1": [1.0, 2.0, 3.0, 4.0, 5.0],
|
||||||
|
"feature2": [2.0, 4.0, 6.0, 8.0, 10.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
y = pl.Series("target", [2.0, 4.0, 6.0, 8.0, 10.0])
|
||||||
|
|
||||||
|
model = LightGBMModel(n_estimators=10, learning_rate=0.1)
|
||||||
|
model.fit(X, y)
|
||||||
|
|
||||||
|
# 保存前预测
|
||||||
|
X_test = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"feature1": [6.0],
|
||||||
|
"feature2": [12.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
pred_before = model.predict(X_test)
|
||||||
|
|
||||||
|
# 保存到临时文件
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
save_path = os.path.join(tmpdir, "model.txt")
|
||||||
|
model.save(save_path)
|
||||||
|
|
||||||
|
# 加载模型
|
||||||
|
loaded_model = LightGBMModel.load(save_path)
|
||||||
|
|
||||||
|
# 验证加载后预测结果相同
|
||||||
|
pred_after = loaded_model.predict(X_test)
|
||||||
|
assert pred_after[0] == pytest.approx(pred_before[0], rel=1e-5)
|
||||||
|
|
||||||
|
# 验证元数据已恢复
|
||||||
|
assert loaded_model.feature_names_ == ["feature1", "feature2"]
|
||||||
|
assert loaded_model.n_estimators == 10
|
||||||
|
|
||||||
|
def test_registration(self):
|
||||||
|
"""测试模型已注册到 registry"""
|
||||||
|
from src.training.registry import ModelRegistry
|
||||||
|
|
||||||
|
# 重新导入模型模块以确保注册(处理其他测试 clear 注册表的情况)
|
||||||
|
import importlib
|
||||||
|
import src.training.components.models.lightgbm as lightgbm_module
|
||||||
|
|
||||||
|
importlib.reload(lightgbm_module)
|
||||||
|
from src.training.components.models.lightgbm import (
|
||||||
|
LightGBMModel as ReloadedModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
model_class = ModelRegistry.get_model("lightgbm")
|
||||||
|
assert model_class is ReloadedModel
|
||||||
|
|
||||||
|
def test_fit_predict_consistency(self):
|
||||||
|
"""测试多次预测结果一致"""
|
||||||
|
X = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"feature1": np.random.randn(50),
|
||||||
|
"feature2": np.random.randn(50),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
y = pl.Series("target", X["feature1"] + X["feature2"])
|
||||||
|
|
||||||
|
model = LightGBMModel(n_estimators=10)
|
||||||
|
model.fit(X, y)
|
||||||
|
|
||||||
|
X_test = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"feature1": [1.0, 2.0, 3.0],
|
||||||
|
"feature2": [1.0, 2.0, 3.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 多次预测应该返回相同结果
|
||||||
|
pred1 = model.predict(X_test)
|
||||||
|
pred2 = model.predict(X_test)
|
||||||
|
np.testing.assert_array_almost_equal(pred1, pred2)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v"])
|
||||||
300
tests/training/test_processors.py
Normal file
300
tests/training/test_processors.py
Normal file
@@ -0,0 +1,300 @@
|
|||||||
|
"""测试数据处理器
|
||||||
|
|
||||||
|
验证 StandardScaler、CrossSectionalStandardScaler 和 Winsorizer 功能。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import polars as pl
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.training.components.processors import (
|
||||||
|
CrossSectionalStandardScaler,
|
||||||
|
StandardScaler,
|
||||||
|
Winsorizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestStandardScaler:
|
||||||
|
"""StandardScaler 测试类"""
|
||||||
|
|
||||||
|
def test_init_default(self):
|
||||||
|
"""测试默认初始化"""
|
||||||
|
scaler = StandardScaler()
|
||||||
|
assert scaler.exclude_cols == ["ts_code", "trade_date"]
|
||||||
|
assert scaler.mean_ == {}
|
||||||
|
assert scaler.std_ == {}
|
||||||
|
|
||||||
|
def test_init_custom_exclude(self):
|
||||||
|
"""测试自定义排除列"""
|
||||||
|
scaler = StandardScaler(exclude_cols=["id", "date"])
|
||||||
|
assert scaler.exclude_cols == ["id", "date"]
|
||||||
|
|
||||||
|
def test_fit_transform(self):
|
||||||
|
"""测试拟合和转换"""
|
||||||
|
data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["A", "B", "C", "D"],
|
||||||
|
"trade_date": ["20240101"] * 4,
|
||||||
|
"value": [1.0, 2.0, 3.0, 4.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
scaler = StandardScaler()
|
||||||
|
result = scaler.fit_transform(data)
|
||||||
|
|
||||||
|
# 验证学习到的统计量
|
||||||
|
assert scaler.mean_["value"] == 2.5
|
||||||
|
assert scaler.std_["value"] == pytest.approx(1.290, rel=1e-2)
|
||||||
|
|
||||||
|
# 验证转换结果
|
||||||
|
expected_std = (np.array([1.0, 2.0, 3.0, 4.0]) - 2.5) / 1.290
|
||||||
|
assert result["value"].to_list() == pytest.approx(
|
||||||
|
expected_std.tolist(), rel=1e-2
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_transform_use_fitted_params(self):
|
||||||
|
"""测试转换使用拟合时的参数"""
|
||||||
|
train_data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["A", "B", "C"],
|
||||||
|
"trade_date": ["20240101"] * 3,
|
||||||
|
"value": [1.0, 2.0, 3.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
test_data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["D"],
|
||||||
|
"trade_date": ["20240102"],
|
||||||
|
"value": [100.0], # 远离训练分布
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
scaler = StandardScaler()
|
||||||
|
scaler.fit(train_data)
|
||||||
|
|
||||||
|
# 使用训练集的均值(2.0)和标准差进行转换
|
||||||
|
result = scaler.transform(test_data)
|
||||||
|
expected = (100.0 - 2.0) / 1.0 # 均值2.0, 标准差1.0
|
||||||
|
assert result["value"][0] == pytest.approx(expected, rel=1e-2)
|
||||||
|
|
||||||
|
def test_exclude_non_numeric(self):
|
||||||
|
"""测试自动排除非数值列"""
|
||||||
|
data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["A", "B"],
|
||||||
|
"trade_date": ["20240101", "20240102"],
|
||||||
|
"category": ["X", "Y"], # 字符串列
|
||||||
|
"value": [1.0, 2.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
scaler = StandardScaler()
|
||||||
|
result = scaler.fit_transform(data)
|
||||||
|
|
||||||
|
# category 列应该原样保留
|
||||||
|
assert result["category"].to_list() == ["X", "Y"]
|
||||||
|
# value 列应该被标准化
|
||||||
|
assert "value" in scaler.mean_
|
||||||
|
|
||||||
|
def test_zero_std_handling(self):
|
||||||
|
"""测试处理标准差为0的情况"""
|
||||||
|
data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["A", "B"],
|
||||||
|
"trade_date": ["20240101", "20240102"],
|
||||||
|
"constant": [5.0, 5.0], # 常数列
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
scaler = StandardScaler()
|
||||||
|
result = scaler.fit_transform(data)
|
||||||
|
|
||||||
|
# 标准差为0时,结果应该为0(避免除以0)
|
||||||
|
assert result["constant"].to_list() == [0.0, 0.0]
|
||||||
|
|
||||||
|
|
||||||
|
class TestCrossSectionalStandardScaler:
|
||||||
|
"""CrossSectionalStandardScaler 测试类"""
|
||||||
|
|
||||||
|
def test_init_default(self):
|
||||||
|
"""测试默认初始化"""
|
||||||
|
scaler = CrossSectionalStandardScaler()
|
||||||
|
assert scaler.exclude_cols == ["ts_code", "trade_date"]
|
||||||
|
assert scaler.date_col == "trade_date"
|
||||||
|
|
||||||
|
def test_init_custom(self):
|
||||||
|
"""测试自定义参数"""
|
||||||
|
scaler = CrossSectionalStandardScaler(
|
||||||
|
exclude_cols=["id"],
|
||||||
|
date_col="date",
|
||||||
|
)
|
||||||
|
assert scaler.exclude_cols == ["id"]
|
||||||
|
assert scaler.date_col == "date"
|
||||||
|
|
||||||
|
def test_transform_no_fit_needed(self):
|
||||||
|
"""测试不需要 fit"""
|
||||||
|
data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["A", "B"],
|
||||||
|
"trade_date": ["20240101", "20240101"],
|
||||||
|
"value": [1.0, 3.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
scaler = CrossSectionalStandardScaler()
|
||||||
|
# 截面标准化不需要 fit
|
||||||
|
result = scaler.transform(data)
|
||||||
|
|
||||||
|
# 当天均值=2.0, 样本标准差=sqrt(2)≈1.414, z-score=[-0.707, 0.707]
|
||||||
|
assert result["value"].to_list() == pytest.approx([-0.707, 0.707], rel=1e-2)
|
||||||
|
|
||||||
|
def test_transform_by_date(self):
|
||||||
|
"""测试按日期分组标准化"""
|
||||||
|
data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["A", "B", "C", "D"],
|
||||||
|
"trade_date": ["20240101", "20240101", "20240102", "20240102"],
|
||||||
|
"value": [1.0, 3.0, 10.0, 30.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
scaler = CrossSectionalStandardScaler()
|
||||||
|
result = scaler.transform(data)
|
||||||
|
|
||||||
|
# 2024-01-01: 均值=2.0, 样本std≈1.414 -> [-0.707, 0.707]
|
||||||
|
# 2024-01-02: 均值=20.0, 样本std≈14.14 -> [-0.707, 0.707]
|
||||||
|
values = result["value"].to_list()
|
||||||
|
assert values[0] == pytest.approx(-0.707, abs=1e-2)
|
||||||
|
assert values[1] == pytest.approx(0.707, abs=1e-2)
|
||||||
|
assert values[2] == pytest.approx(-0.707, abs=1e-2)
|
||||||
|
assert values[3] == pytest.approx(0.707, abs=1e-2)
|
||||||
|
|
||||||
|
def test_exclude_columns_preserved(self):
|
||||||
|
"""测试排除列保持原样"""
|
||||||
|
data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["A", "B"],
|
||||||
|
"trade_date": ["20240101", "20240101"],
|
||||||
|
"value": [1.0, 3.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
scaler = CrossSectionalStandardScaler()
|
||||||
|
result = scaler.transform(data)
|
||||||
|
|
||||||
|
assert result["ts_code"].to_list() == ["A", "B"]
|
||||||
|
assert result["trade_date"].to_list() == ["20240101", "20240101"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestWinsorizer:
|
||||||
|
"""Winsorizer 测试类"""
|
||||||
|
|
||||||
|
def test_init_default(self):
|
||||||
|
"""测试默认初始化"""
|
||||||
|
winsorizer = Winsorizer()
|
||||||
|
assert winsorizer.lower == 0.01
|
||||||
|
assert winsorizer.upper == 0.99
|
||||||
|
assert winsorizer.by_date is False
|
||||||
|
assert winsorizer.date_col == "trade_date"
|
||||||
|
|
||||||
|
def test_init_custom(self):
|
||||||
|
"""测试自定义参数"""
|
||||||
|
winsorizer = Winsorizer(lower=0.05, upper=0.95, by_date=True, date_col="date")
|
||||||
|
assert winsorizer.lower == 0.05
|
||||||
|
assert winsorizer.upper == 0.95
|
||||||
|
assert winsorizer.by_date is True
|
||||||
|
assert winsorizer.date_col == "date"
|
||||||
|
|
||||||
|
def test_invalid_quantiles(self):
|
||||||
|
"""测试无效的分位数参数"""
|
||||||
|
with pytest.raises(ValueError, match="lower .* 必须小于 upper"):
|
||||||
|
Winsorizer(lower=0.5, upper=0.3)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="lower .* 必须小于 upper"):
|
||||||
|
Winsorizer(lower=-0.1, upper=0.5)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="lower .* 必须小于 upper"):
|
||||||
|
Winsorizer(lower=0.5, upper=1.5)
|
||||||
|
|
||||||
|
def test_global_winsorize(self):
|
||||||
|
"""测试全局缩尾"""
|
||||||
|
# 创建包含极端值的数据
|
||||||
|
values = list(range(1, 101)) # 1-100
|
||||||
|
values[0] = -1000 # 极端小值
|
||||||
|
values[-1] = 1000 # 极端大值
|
||||||
|
|
||||||
|
data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": [f"A{i}" for i in range(100)],
|
||||||
|
"trade_date": ["20240101"] * 100,
|
||||||
|
"value": values,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
winsorizer = Winsorizer(lower=0.01, upper=0.99)
|
||||||
|
result = winsorizer.fit_transform(data)
|
||||||
|
|
||||||
|
# 1%分位数=2, 99%分位数=99
|
||||||
|
# -1000 应该被截断为 2
|
||||||
|
# 1000 应该被截断为 99
|
||||||
|
result_values = result["value"].to_list()
|
||||||
|
assert result_values[0] == 2 # 原-1000被截断
|
||||||
|
assert result_values[-1] == 99 # 原1000被截断
|
||||||
|
assert result_values[1] == 2 # 原2保持不变
|
||||||
|
assert result_values[98] == 99 # 原99保持不变
|
||||||
|
|
||||||
|
def test_by_date_winsorize(self):
|
||||||
|
"""测试每日独立缩尾"""
|
||||||
|
data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["A", "B", "C", "D", "E", "F"],
|
||||||
|
"trade_date": ["20240101"] * 3 + ["20240102"] * 3,
|
||||||
|
"value": [1.0, 50.0, 100.0, 200.0, 250.0, 300.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
winsorizer = Winsorizer(lower=0.0, upper=0.5, by_date=True)
|
||||||
|
result = winsorizer.transform(data)
|
||||||
|
|
||||||
|
# 每天独立处理:
|
||||||
|
# 2024-01-01: [1, 50, 100], 50%分位数=50
|
||||||
|
# -> 截断为 [1, 50, 50]
|
||||||
|
# 2024-01-02: [200, 250, 300], 50%分位数=250
|
||||||
|
# -> 截断为 [200, 250, 250]
|
||||||
|
result_values = result["value"].to_list()
|
||||||
|
assert result_values[0] == 1.0
|
||||||
|
assert result_values[1] == 50.0
|
||||||
|
assert result_values[2] == 50.0 # 被截断
|
||||||
|
assert result_values[3] == 200.0
|
||||||
|
assert result_values[4] == 250.0
|
||||||
|
assert result_values[5] == 250.0 # 被截断
|
||||||
|
|
||||||
|
def test_global_transform_after_fit(self):
|
||||||
|
"""测试全局模式下,转换使用拟合时的边界"""
|
||||||
|
train_data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["A", "B", "C"],
|
||||||
|
"trade_date": ["20240101"] * 3,
|
||||||
|
"value": [1.0, 50.0, 100.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
test_data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["D"],
|
||||||
|
"trade_date": ["20240102"],
|
||||||
|
"value": [200.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
winsorizer = Winsorizer(lower=0.0, upper=1.0) # 0%和100%分位数
|
||||||
|
winsorizer.fit(train_data)
|
||||||
|
|
||||||
|
# 使用训练集的分位数边界 [1, 100]
|
||||||
|
result = winsorizer.transform(test_data)
|
||||||
|
assert result["value"][0] == 100.0 # 被截断为100
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v"])
|
||||||
183
tests/training/test_selectors.py
Normal file
183
tests/training/test_selectors.py
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
"""测试股票池选择器配置
|
||||||
|
|
||||||
|
验证 StockFilterConfig 和 MarketCapSelectorConfig 功能。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.training.components.selectors import (
|
||||||
|
MarketCapSelectorConfig,
|
||||||
|
StockFilterConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestStockFilterConfig:
|
||||||
|
"""StockFilterConfig 测试类"""
|
||||||
|
|
||||||
|
def test_default_values(self):
|
||||||
|
"""测试默认值"""
|
||||||
|
config = StockFilterConfig()
|
||||||
|
assert config.exclude_cyb is True
|
||||||
|
assert config.exclude_kcb is True
|
||||||
|
assert config.exclude_bj is True
|
||||||
|
assert config.exclude_st is True
|
||||||
|
|
||||||
|
def test_custom_values(self):
|
||||||
|
"""测试自定义值"""
|
||||||
|
config = StockFilterConfig(
|
||||||
|
exclude_cyb=False,
|
||||||
|
exclude_kcb=False,
|
||||||
|
exclude_bj=False,
|
||||||
|
exclude_st=False,
|
||||||
|
)
|
||||||
|
assert config.exclude_cyb is False
|
||||||
|
assert config.exclude_kcb is False
|
||||||
|
assert config.exclude_bj is False
|
||||||
|
assert config.exclude_st is False
|
||||||
|
|
||||||
|
def test_filter_codes_exclude_all(self):
|
||||||
|
"""测试排除所有类型"""
|
||||||
|
config = StockFilterConfig(
|
||||||
|
exclude_cyb=True,
|
||||||
|
exclude_kcb=True,
|
||||||
|
exclude_bj=True,
|
||||||
|
exclude_st=True,
|
||||||
|
)
|
||||||
|
codes = [
|
||||||
|
"000001.SZ", # 主板 - 保留
|
||||||
|
"300001.SZ", # 创业板 - 排除
|
||||||
|
"688001.SH", # 科创板 - 排除
|
||||||
|
"830001.BJ", # 北交所(8开头)- 排除
|
||||||
|
"430001.BJ", # 北交所(4开头)- 排除
|
||||||
|
]
|
||||||
|
result = config.filter_codes(codes)
|
||||||
|
assert result == ["000001.SZ"]
|
||||||
|
|
||||||
|
def test_filter_codes_allow_cyb(self):
|
||||||
|
"""测试允许创业板"""
|
||||||
|
config = StockFilterConfig(
|
||||||
|
exclude_cyb=False,
|
||||||
|
exclude_kcb=True,
|
||||||
|
exclude_bj=True,
|
||||||
|
exclude_st=True,
|
||||||
|
)
|
||||||
|
codes = [
|
||||||
|
"000001.SZ",
|
||||||
|
"300001.SZ",
|
||||||
|
"688001.SH",
|
||||||
|
]
|
||||||
|
result = config.filter_codes(codes)
|
||||||
|
assert result == ["000001.SZ", "300001.SZ"]
|
||||||
|
|
||||||
|
def test_filter_codes_allow_kcb(self):
|
||||||
|
"""测试允许科创板"""
|
||||||
|
config = StockFilterConfig(
|
||||||
|
exclude_cyb=True,
|
||||||
|
exclude_kcb=False,
|
||||||
|
exclude_bj=True,
|
||||||
|
exclude_st=True,
|
||||||
|
)
|
||||||
|
codes = [
|
||||||
|
"000001.SZ",
|
||||||
|
"300001.SZ",
|
||||||
|
"688001.SH",
|
||||||
|
]
|
||||||
|
result = config.filter_codes(codes)
|
||||||
|
assert result == ["000001.SZ", "688001.SH"]
|
||||||
|
|
||||||
|
def test_filter_codes_allow_bj(self):
|
||||||
|
"""测试允许北交所"""
|
||||||
|
config = StockFilterConfig(
|
||||||
|
exclude_cyb=True,
|
||||||
|
exclude_kcb=True,
|
||||||
|
exclude_bj=False,
|
||||||
|
exclude_st=True,
|
||||||
|
)
|
||||||
|
codes = [
|
||||||
|
"000001.SZ",
|
||||||
|
"300001.SZ",
|
||||||
|
"830001.BJ",
|
||||||
|
"430001.BJ",
|
||||||
|
]
|
||||||
|
result = config.filter_codes(codes)
|
||||||
|
assert result == ["000001.SZ", "830001.BJ", "430001.BJ"]
|
||||||
|
|
||||||
|
def test_filter_codes_allow_all(self):
|
||||||
|
"""测试允许所有类型"""
|
||||||
|
config = StockFilterConfig(
|
||||||
|
exclude_cyb=False,
|
||||||
|
exclude_kcb=False,
|
||||||
|
exclude_bj=False,
|
||||||
|
exclude_st=False,
|
||||||
|
)
|
||||||
|
codes = [
|
||||||
|
"000001.SZ",
|
||||||
|
"300001.SZ",
|
||||||
|
"688001.SH",
|
||||||
|
"830001.BJ",
|
||||||
|
"430001.BJ",
|
||||||
|
]
|
||||||
|
result = config.filter_codes(codes)
|
||||||
|
assert result == codes
|
||||||
|
|
||||||
|
def test_filter_codes_empty_list(self):
|
||||||
|
"""测试空列表"""
|
||||||
|
config = StockFilterConfig()
|
||||||
|
result = config.filter_codes([])
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
def test_filter_codes_no_matching(self):
|
||||||
|
"""测试全部排除"""
|
||||||
|
config = StockFilterConfig()
|
||||||
|
codes = ["300001.SZ", "688001.SH", "830001.BJ"]
|
||||||
|
result = config.filter_codes(codes)
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestMarketCapSelectorConfig:
|
||||||
|
"""MarketCapSelectorConfig 测试类"""
|
||||||
|
|
||||||
|
def test_default_values(self):
|
||||||
|
"""测试默认值"""
|
||||||
|
config = MarketCapSelectorConfig()
|
||||||
|
assert config.enabled is True
|
||||||
|
assert config.n == 100
|
||||||
|
assert config.ascending is False
|
||||||
|
assert config.market_cap_col == "total_mv"
|
||||||
|
|
||||||
|
def test_custom_values(self):
|
||||||
|
"""测试自定义值"""
|
||||||
|
config = MarketCapSelectorConfig(
|
||||||
|
enabled=False,
|
||||||
|
n=50,
|
||||||
|
ascending=True,
|
||||||
|
market_cap_col="circ_mv",
|
||||||
|
)
|
||||||
|
assert config.enabled is False
|
||||||
|
assert config.n == 50
|
||||||
|
assert config.ascending is True
|
||||||
|
assert config.market_cap_col == "circ_mv"
|
||||||
|
|
||||||
|
def test_invalid_n_zero(self):
|
||||||
|
"""测试无效的 n=0"""
|
||||||
|
with pytest.raises(ValueError, match="n 必须是正整数"):
|
||||||
|
MarketCapSelectorConfig(n=0)
|
||||||
|
|
||||||
|
def test_invalid_n_negative(self):
|
||||||
|
"""测试无效的负数 n"""
|
||||||
|
with pytest.raises(ValueError, match="n 必须是正整数"):
|
||||||
|
MarketCapSelectorConfig(n=-1)
|
||||||
|
|
||||||
|
def test_invalid_empty_market_cap_col(self):
|
||||||
|
"""测试空的市值列名"""
|
||||||
|
with pytest.raises(ValueError, match="market_cap_col 不能为空"):
|
||||||
|
MarketCapSelectorConfig(market_cap_col="")
|
||||||
|
|
||||||
|
def test_large_n(self):
|
||||||
|
"""测试大的 n 值"""
|
||||||
|
config = MarketCapSelectorConfig(n=5000)
|
||||||
|
assert config.n == 5000
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v"])
|
||||||
244
tests/training/test_splitters.py
Normal file
244
tests/training/test_splitters.py
Normal file
@@ -0,0 +1,244 @@
|
|||||||
|
"""测试 DateSplitter 数据划分器
|
||||||
|
|
||||||
|
验证一次性日期划分功能。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
from src.training.components.splitters import DateSplitter
|
||||||
|
|
||||||
|
|
||||||
|
class TestDateSplitter:
|
||||||
|
"""DateSplitter 测试类"""
|
||||||
|
|
||||||
|
def test_initialization_success(self):
|
||||||
|
"""测试正常初始化"""
|
||||||
|
splitter = DateSplitter(
|
||||||
|
train_start="20200101",
|
||||||
|
train_end="20221231",
|
||||||
|
test_start="20230101",
|
||||||
|
test_end="20231231",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert splitter.train_start == "20200101"
|
||||||
|
assert splitter.train_end == "20221231"
|
||||||
|
assert splitter.test_start == "20230101"
|
||||||
|
assert splitter.test_end == "20231231"
|
||||||
|
|
||||||
|
def test_invalid_date_format(self):
|
||||||
|
"""测试无效的日期格式"""
|
||||||
|
with pytest.raises(ValueError, match="必须是格式为 'YYYYMMDD' 的8位字符串"):
|
||||||
|
DateSplitter(
|
||||||
|
train_start="2020-01-01", # 错误格式
|
||||||
|
train_end="20221231",
|
||||||
|
test_start="20230101",
|
||||||
|
test_end="20231231",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_train_start_after_train_end(self):
|
||||||
|
"""测试训练集开始日期晚于结束日期"""
|
||||||
|
with pytest.raises(ValueError, match="train_start.*必须早于或等于 train_end"):
|
||||||
|
DateSplitter(
|
||||||
|
train_start="20231231",
|
||||||
|
train_end="20200101",
|
||||||
|
test_start="20230101",
|
||||||
|
test_end="20231231",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_test_start_after_test_end(self):
|
||||||
|
"""测试测试集开始日期晚于结束日期"""
|
||||||
|
with pytest.raises(ValueError, match="test_start.*必须早于或等于 test_end"):
|
||||||
|
DateSplitter(
|
||||||
|
train_start="20200101",
|
||||||
|
train_end="20221231",
|
||||||
|
test_start="20231231",
|
||||||
|
test_end="20230101",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_overlapping_dates(self):
|
||||||
|
"""测试训练集和测试集日期重叠"""
|
||||||
|
with pytest.raises(ValueError, match="必须晚于训练集结束日期"):
|
||||||
|
DateSplitter(
|
||||||
|
train_start="20200101",
|
||||||
|
train_end="20221231",
|
||||||
|
test_start="20220601", # 在训练集范围内
|
||||||
|
test_end="20231231",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_split_success(self):
|
||||||
|
"""测试正常划分数据"""
|
||||||
|
# 创建测试数据
|
||||||
|
data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": [
|
||||||
|
"000001.SZ",
|
||||||
|
"000002.SZ",
|
||||||
|
"000003.SZ",
|
||||||
|
"000004.SZ",
|
||||||
|
"000005.SZ",
|
||||||
|
"000006.SZ",
|
||||||
|
],
|
||||||
|
"trade_date": [
|
||||||
|
"20200101",
|
||||||
|
"20211231",
|
||||||
|
"20221231",
|
||||||
|
"20230101",
|
||||||
|
"20230601",
|
||||||
|
"20231231",
|
||||||
|
],
|
||||||
|
"value": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
splitter = DateSplitter(
|
||||||
|
train_start="20200101",
|
||||||
|
train_end="20221231",
|
||||||
|
test_start="20230101",
|
||||||
|
test_end="20231231",
|
||||||
|
)
|
||||||
|
|
||||||
|
train_data, test_data = splitter.split(data)
|
||||||
|
|
||||||
|
# 验证训练集
|
||||||
|
assert len(train_data) == 3
|
||||||
|
assert train_data["trade_date"].to_list() == [
|
||||||
|
"20200101",
|
||||||
|
"20211231",
|
||||||
|
"20221231",
|
||||||
|
]
|
||||||
|
|
||||||
|
# 验证测试集
|
||||||
|
assert len(test_data) == 3
|
||||||
|
assert test_data["trade_date"].to_list() == ["20230101", "20230601", "20231231"]
|
||||||
|
|
||||||
|
def test_split_no_matching_train_data(self):
|
||||||
|
"""测试训练集无匹配数据"""
|
||||||
|
data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["000001.SZ", "000002.SZ"],
|
||||||
|
"trade_date": ["20230101", "20231231"],
|
||||||
|
"value": [1.0, 2.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
splitter = DateSplitter(
|
||||||
|
train_start="20200101",
|
||||||
|
train_end="20221231",
|
||||||
|
test_start="20230101",
|
||||||
|
test_end="20231231",
|
||||||
|
)
|
||||||
|
|
||||||
|
train_data, test_data = splitter.split(data)
|
||||||
|
|
||||||
|
# 训练集应该为空
|
||||||
|
assert len(train_data) == 0
|
||||||
|
# 测试集应该有数据
|
||||||
|
assert len(test_data) == 2
|
||||||
|
|
||||||
|
def test_split_no_matching_test_data(self):
|
||||||
|
"""测试测试集无匹配数据"""
|
||||||
|
data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["000001.SZ", "000002.SZ"],
|
||||||
|
"trade_date": ["20200101", "20211231"],
|
||||||
|
"value": [1.0, 2.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
splitter = DateSplitter(
|
||||||
|
train_start="20200101",
|
||||||
|
train_end="20221231",
|
||||||
|
test_start="20230101",
|
||||||
|
test_end="20231231",
|
||||||
|
)
|
||||||
|
|
||||||
|
train_data, test_data = splitter.split(data)
|
||||||
|
|
||||||
|
# 训练集应该有数据
|
||||||
|
assert len(train_data) == 2
|
||||||
|
# 测试集应该为空
|
||||||
|
assert len(test_data) == 0
|
||||||
|
|
||||||
|
def test_split_with_custom_date_col(self):
|
||||||
|
"""测试使用自定义日期列名"""
|
||||||
|
data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
|
||||||
|
"date": ["20200101", "20211231", "20230101"],
|
||||||
|
"value": [1.0, 2.0, 3.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
splitter = DateSplitter(
|
||||||
|
train_start="20200101",
|
||||||
|
train_end="20221231",
|
||||||
|
test_start="20230101",
|
||||||
|
test_end="20231231",
|
||||||
|
)
|
||||||
|
|
||||||
|
train_data, test_data = splitter.split(data, date_col="date")
|
||||||
|
|
||||||
|
assert len(train_data) == 2
|
||||||
|
assert len(test_data) == 1
|
||||||
|
|
||||||
|
def test_split_missing_date_column(self):
|
||||||
|
"""测试数据缺少日期列"""
|
||||||
|
data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["000001.SZ"],
|
||||||
|
"value": [1.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
splitter = DateSplitter(
|
||||||
|
train_start="20200101",
|
||||||
|
train_end="20221231",
|
||||||
|
test_start="20230101",
|
||||||
|
test_end="20231231",
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="数据中不包含列 'trade_date'"):
|
||||||
|
splitter.split(data)
|
||||||
|
|
||||||
|
def test_repr(self):
|
||||||
|
"""测试 __repr__ 方法"""
|
||||||
|
splitter = DateSplitter(
|
||||||
|
train_start="20200101",
|
||||||
|
train_end="20221231",
|
||||||
|
test_start="20230101",
|
||||||
|
test_end="20231231",
|
||||||
|
)
|
||||||
|
|
||||||
|
repr_str = repr(splitter)
|
||||||
|
assert "DateSplitter" in repr_str
|
||||||
|
assert "20200101" in repr_str
|
||||||
|
assert "20221231" in repr_str
|
||||||
|
assert "20230101" in repr_str
|
||||||
|
assert "20231231" in repr_str
|
||||||
|
|
||||||
|
def test_edge_case_same_day_train(self):
|
||||||
|
"""测试训练集为单日"""
|
||||||
|
data = pl.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["000001.SZ"],
|
||||||
|
"trade_date": ["20200101"],
|
||||||
|
"value": [1.0],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
splitter = DateSplitter(
|
||||||
|
train_start="20200101",
|
||||||
|
train_end="20200101",
|
||||||
|
test_start="20200102",
|
||||||
|
test_end="20200102",
|
||||||
|
)
|
||||||
|
|
||||||
|
train_data, test_data = splitter.split(data)
|
||||||
|
|
||||||
|
assert len(train_data) == 1
|
||||||
|
assert len(test_data) == 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v"])
|
||||||
Reference in New Issue
Block a user