Compare commits

..

5 Commits

Author SHA1 Message Date
192718095f feat(training): 实现训练模块核心组件(commits 6-9)
- StockPoolManager:每日独立筛选股票池,支持代码过滤和市值选择
- Trainer:整合训练完整流程,支持 processor 分阶段行为和模型持久化
- TrainingConfig:pydantic 配置管理,含必填字段和日期验证
- experiment 模块:预留结构
- 从计划中移除 metrics 组件
- 调整 commit 序号(7-10 → 6-9)
- 更新 training/__init__.py 导出所有公开 API
2026-03-03 22:57:01 +08:00
f35a6a76a6 feat(training): 实现 LightGBM 模型
- 新增 LightGBMModel:LightGBM 回归模型实现
- 支持自定义参数(objective, num_leaves, learning_rate, n_estimators 等)
- 使用 LightGBM 原生格式保存/加载模型(不依赖 pickle)
- 支持特征重要性提取
- 已注册到 ModelRegistry(@register_model("lightgbm"))
2026-03-03 22:30:37 +08:00
9ca1deae56 feat(training): 实现数据处理器
- 新增 StandardScaler:全局标准化,训练集学习参数,测试集复用
- 新增 CrossSectionalStandardScaler:截面标准化,每天独立计算
- 新增 Winsorizer:支持全局/截面两种缩尾模式
- 处理器统一遵循 fit/transform 接口,Trainer 可无差别调用
- 添加 17 个单元测试覆盖各种场景
2026-03-03 22:23:43 +08:00
6b63c428d9 feat(training): 实现股票池选择器配置
- 新增 StockFilterConfig:支持按代码前缀过滤创业板/科创板/北交所
- 新增 MarketCapSelectorConfig:配置市值选择参数(数量、排序、列名)
- 添加参数验证(n>0, 列名非空)
- 在 components 模块导出配置类
- 添加 15 个单元测试覆盖各种场景
2026-03-03 22:10:36 +08:00
f48b307ad2 feat(training): 实现 DateSplitter 数据划分器
- 新增 DateSplitter 类,支持基于日期范围的一次性训练/测试划分
- 实现日期格式验证和日期范围逻辑检查
- 支持自定义日期列名参数
- 添加完整的单元测试(12个测试用例)
- 在 components 模块导出 DateSplitter
2026-03-03 22:07:45 +08:00
19 changed files with 2258 additions and 72 deletions

View File

@@ -40,9 +40,6 @@ src/
│ │ ├── processors/ # 数据处理器 │ │ ├── processors/ # 数据处理器
│ │ │ ├── __init__.py │ │ │ ├── __init__.py
│ │ │ └── transforms.py # 标准化(截面/时序)、缩尾 │ │ │ └── transforms.py # 标准化(截面/时序)、缩尾
│ │ └── metrics/ # 评估指标
│ │ ├── __init__.py
│ │ └── metrics.py # IC, RankIC, MSE, MAE
│ ├── config/ # 配置管理 │ ├── config/ # 配置管理
│ │ ├── __init__.py │ │ ├── __init__.py
│ │ └── config.py # TrainingConfig (pydantic) │ │ └── config.py # TrainingConfig (pydantic)
@@ -568,30 +565,6 @@ class Winsorizer(BaseProcessor):
pass pass
``` ```
### 3.7 评估指标 (components/metrics/)
```python
def ic_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
"""信息系数 (Pearson 相关系数)"""
from scipy.stats import pearsonr
corr, _ = pearsonr(y_true, y_pred)
return corr
def rank_ic_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
"""秩相关系数 (Spearman 相关系数)"""
from scipy.stats import spearmanr
corr, _ = spearmanr(y_true, y_pred)
return corr
def mse_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
"""均方误差"""
return np.mean((y_true - y_pred) ** 2)
def mae_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
"""平均绝对误差"""
return np.mean(np.abs(y_true - y_pred))
```
## 4. 训练流程设计 ## 4. 训练流程设计
### 4.1 Trainer 主类 (core/trainer.py) ### 4.1 Trainer 主类 (core/trainer.py)
@@ -615,7 +588,6 @@ class Trainer:
pool_manager: Optional[StockPoolManager] = None, pool_manager: Optional[StockPoolManager] = None,
processors: List[BaseProcessor] = None, processors: List[BaseProcessor] = None,
splitter: DateSplitter = None, splitter: DateSplitter = None,
metrics: List[str] = None,
target_col: str = "target", target_col: str = "target",
feature_cols: List[str] = None, feature_cols: List[str] = None,
persist_model: bool = False, persist_model: bool = False,
@@ -625,7 +597,6 @@ class Trainer:
self.pool_manager = pool_manager self.pool_manager = pool_manager
self.processors = processors or [] self.processors = processors or []
self.splitter = splitter self.splitter = splitter
self.metrics = metrics or ["ic", "rank_ic", "mse", "mae"]
self.target_col = target_col self.target_col = target_col
self.feature_cols = feature_cols self.feature_cols = feature_cols
self.persist_model = persist_model self.persist_model = persist_model
@@ -634,7 +605,6 @@ class Trainer:
# 存储训练后的处理器 # 存储训练后的处理器
self.fitted_processors: List[BaseProcessor] = [] self.fitted_processors: List[BaseProcessor] = []
self.results: pl.DataFrame = None self.results: pl.DataFrame = None
self.metrics_results: Dict[str, float] = {}
def train(self, data: pl.DataFrame) -> "Trainer": def train(self, data: pl.DataFrame) -> "Trainer":
"""执行训练流程 """执行训练流程
@@ -686,38 +656,18 @@ class Trainer:
X_test = test_data.select(self.feature_cols) X_test = test_data.select(self.feature_cols)
predictions = self.model.predict(X_test) predictions = self.model.predict(X_test)
# 7. 评估 # 7. 保存结果
print("[评估] 计算指标...")
y_test = test_data.select(self.target_col).to_series().to_numpy()
self._evaluate(y_test, predictions)
# 8. 保存结果
self.results = test_data.with_columns([ self.results = test_data.with_columns([
pl.Series("prediction", predictions) pl.Series("prediction", predictions)
]) ])
# 9. 持久化模型 # 8. 持久化模型
if self.persist_model and self.model_save_path: if self.persist_model and self.model_save_path:
print(f"[保存] 保存模型到 {self.model_save_path}...") print(f"[保存] 保存模型到 {self.model_save_path}...")
self.save_model(self.model_save_path) self.save_model(self.model_save_path)
return self return self
def _evaluate(self, y_true: np.ndarray, y_pred: np.ndarray):
"""计算评估指标"""
from .components.metrics import ic_score, rank_ic_score, mse_score, mae_score
metric_funcs = {
"ic": ic_score,
"rank_ic": rank_ic_score,
"mse": mse_score,
"mae": mae_score,
}
for metric in self.metrics:
if metric in metric_funcs:
self.metrics_results[metric] = metric_funcs[metric](y_true, y_pred)
def predict(self, data: pl.DataFrame) -> pl.DataFrame: def predict(self, data: pl.DataFrame) -> pl.DataFrame:
"""对新数据进行预测 """对新数据进行预测
@@ -738,10 +688,6 @@ class Trainer:
"""获取所有预测结果""" """获取所有预测结果"""
return self.results return self.results
def get_metrics(self) -> Dict[str, float]:
"""获取评估指标"""
return self.metrics_results
def save_results(self, path: str): def save_results(self, path: str):
"""保存预测结果到文件""" """保存预测结果到文件"""
if self.results is not None: if self.results is not None:
@@ -827,9 +773,6 @@ class TrainingConfig(BaseSettings):
] ]
) )
# === 评估指标 ===
metrics: List[str] = ["ic", "rank_ic", "mse", "mae"]
# === 持久化配置 === # === 持久化配置 ===
persist_model: bool = False # 默认不持久化 persist_model: bool = False # 默认不持久化
model_save_path: Optional[str] = None # 持久化路径 model_save_path: Optional[str] = None # 持久化路径
@@ -914,8 +857,6 @@ trainer.train(all_data)
# 8. 获取结果 # 8. 获取结果
results = trainer.get_results() # 包含预测值 results = trainer.get_results() # 包含预测值
metrics = trainer.get_metrics() # IC, RankIC, etc.
print("评估指标:", metrics)
# 9. 保存结果 # 9. 保存结果
trainer.save_results("output/predictions.csv") trainer.save_results("output/predictions.csv")
@@ -989,7 +930,6 @@ trainer = Trainer(
trainer.train(all_data) trainer.train(all_data)
results = trainer.get_results() results = trainer.get_results()
metrics = trainer.get_metrics()
``` ```
## 6. 实现顺序 ## 6. 实现顺序
@@ -1019,22 +959,18 @@ metrics = trainer.get_metrics()
- `training/components/models/__init__.py` - `training/components/models/__init__.py`
- `training/components/models/lightgbm.py`(含 save_model/load_model - `training/components/models/lightgbm.py`(含 save_model/load_model
### Commit 6: 评估指标 ### Commit 6: 股票池管理器
- `training/components/metrics/__init__.py`
- `training/components/metrics/metrics.py`IC, RankIC, MSE, MAE
### Commit 7: 股票池管理器
- `training/core/__init__.py` - `training/core/__init__.py`
- `training/core/stock_pool_manager.py`(每日独立筛选) - `training/core/stock_pool_manager.py`(每日独立筛选)
### Commit 8: Trainer 训练器 ### Commit 7: Trainer 训练器
- `training/core/trainer.py` - `training/core/trainer.py`
### Commit 9: 配置管理 ### Commit 8: 配置管理
- `training/config/__init__.py` - `training/config/__init__.py`
- `training/config/config.py`TrainingConfig含必填校验 - `training/config/config.py`TrainingConfig含必填校验
### Commit 10: 预留实验模块 ### Commit 9: 预留实验模块
- `experiment/__init__.py` - `experiment/__init__.py`
## 7. 注意事项 ## 7. 注意事项
@@ -1080,7 +1016,7 @@ metrics = trainer.get_metrics()
1. **特征选择**processors/selectors.py 1. **特征选择**processors/selectors.py
2. **滚动训练**WalkForward, ExpandingWindow 2. **滚动训练**WalkForward, ExpandingWindow
3. **结果分析工具**(复杂分析功能) 3. **结果分析工具**(复杂分析功能)
4. **validator.py, evaluator.py**简化为 metrics.py 4. **validator.py, evaluator.py**已删除,不实现 metrics
### 7.5 新增功能 ### 7.5 新增功能

View File

@@ -0,0 +1,7 @@
"""Experiment management module (reserved structure).

This module is a placeholder for future experiment-management features;
no concrete implementation is provided yet.
"""
__all__ = []

View File

@@ -1,6 +1,6 @@
"""训练模块 - ProStock 量化投资框架 """训练模块 - ProStock 量化投资框架
提供模型训练、数据处理和评估的完整流程。 提供模型训练、数据处理和预测的完整流程。
""" """
# 基础抽象类 # 基础抽象类
@@ -14,6 +14,31 @@ from src.training.registry import (
register_processor, register_processor,
) )
# 数据划分器
from src.training.components.splitters import DateSplitter
# 股票池选择器配置
from src.training.components.selectors import (
MarketCapSelectorConfig,
StockFilterConfig,
)
# 数据处理器
from src.training.components.processors import (
CrossSectionalStandardScaler,
StandardScaler,
Winsorizer,
)
# 模型
from src.training.components.models import LightGBMModel
# 训练核心
from src.training.core import StockPoolManager, Trainer
# 配置
from src.training.config import TrainingConfig
__all__ = [ __all__ = [
# 基础抽象类 # 基础抽象类
"BaseModel", "BaseModel",
@@ -23,4 +48,20 @@ __all__ = [
"ProcessorRegistry", "ProcessorRegistry",
"register_model", "register_model",
"register_processor", "register_processor",
# 数据划分器
"DateSplitter",
# 股票池选择器配置
"StockFilterConfig",
"MarketCapSelectorConfig",
# 数据处理器
"StandardScaler",
"CrossSectionalStandardScaler",
"Winsorizer",
# 模型
"LightGBMModel",
# 训练核心
"StockPoolManager",
"Trainer",
# 配置
"TrainingConfig",
] ]

View File

@@ -6,7 +6,33 @@
# 基础抽象类 # 基础抽象类
from src.training.components.base import BaseModel, BaseProcessor from src.training.components.base import BaseModel, BaseProcessor
# 数据划分器
from src.training.components.splitters import DateSplitter
# 股票池选择器配置
from src.training.components.selectors import (
MarketCapSelectorConfig,
StockFilterConfig,
)
# 数据处理器
from src.training.components.processors import (
CrossSectionalStandardScaler,
StandardScaler,
Winsorizer,
)
# 模型
from src.training.components.models import LightGBMModel
__all__ = [ __all__ = [
"BaseModel", "BaseModel",
"BaseProcessor", "BaseProcessor",
"DateSplitter",
"StockFilterConfig",
"MarketCapSelectorConfig",
"StandardScaler",
"CrossSectionalStandardScaler",
"Winsorizer",
"LightGBMModel",
] ]

View File

@@ -0,0 +1,8 @@
"""模型子模块
包含各种机器学习模型的实现。
"""
from src.training.components.models.lightgbm import LightGBMModel
__all__ = ["LightGBMModel"]

View File

@@ -0,0 +1,194 @@
"""LightGBM 模型实现
提供 LightGBM 回归模型的实现,支持特征重要性和原生模型保存。
"""
from typing import Optional
import numpy as np
import pandas as pd
import polars as pl
from src.training.components.base import BaseModel
from src.training.registry import register_model
@register_model("lightgbm")
class LightGBMModel(BaseModel):
    """LightGBM regression model.

    Wraps the LightGBM gradient-boosting library behind the project's
    ``BaseModel`` interface.  Supports custom booster parameters, feature
    importance extraction, and saving/loading via LightGBM's native model
    format (no pickle involved).

    Attributes:
        name: Registry name, "lightgbm".
        params: Parameter dict forwarded to ``lgb.train``.
        n_estimators: Number of boosting rounds.
        model: Trained ``lgb.Booster``, or None before ``fit``.
        feature_names_: Column names seen during ``fit`` (None before).
    """

    name = "lightgbm"

    def __init__(
        self,
        objective: str = "regression",
        metric: str = "rmse",
        num_leaves: int = 31,
        learning_rate: float = 0.05,
        n_estimators: int = 100,
        **kwargs,
    ):
        """Initialize the model.

        Args:
            objective: LightGBM objective, default "regression".
            metric: Evaluation metric, default "rmse".
            num_leaves: Maximum leaves per tree, default 31.
            learning_rate: Boosting learning rate, default 0.05.
            n_estimators: Number of boosting rounds, default 100.
            **kwargs: Any additional LightGBM parameters.
        """
        self.params = {
            "objective": objective,
            "metric": metric,
            "num_leaves": num_leaves,
            "learning_rate": learning_rate,
            "verbose": -1,  # suppress LightGBM's console output
            **kwargs,
        }
        self.n_estimators = n_estimators
        self.model = None
        self.feature_names_: Optional[list] = None

    def fit(self, X: pl.DataFrame, y: pl.Series) -> "LightGBMModel":
        """Train the booster.

        Args:
            X: Feature matrix (Polars DataFrame).
            y: Target variable (Polars Series).

        Returns:
            self (chainable).

        Raises:
            ImportError: If lightgbm is not installed.
        """
        try:
            import lightgbm as lgb
        except ImportError as exc:
            # Chain the original error so the missing-dependency cause stays visible.
            raise ImportError(
                "使用 LightGBMModel 需要安装 lightgbm: pip install lightgbm"
            ) from exc
        # Remember feature order; LightGBM's native format does not store it.
        self.feature_names_ = X.columns
        train_set = lgb.Dataset(X.to_numpy(), label=y.to_numpy())
        self.model = lgb.train(
            self.params,
            train_set,
            num_boost_round=self.n_estimators,
        )
        return self

    def predict(self, X: pl.DataFrame) -> np.ndarray:
        """Predict on new data.

        Args:
            X: Feature matrix (Polars DataFrame).

        Returns:
            Predictions as a numpy ndarray.

        Raises:
            RuntimeError: If called before ``fit``.
        """
        if self.model is None:
            raise RuntimeError("模型尚未训练,请先调用 fit()")
        return self.model.predict(X.to_numpy())

    def feature_importance(self) -> Optional[pd.Series]:
        """Return gain-based feature importance.

        Returns:
            Importance indexed by feature name, or None before training.
        """
        if self.model is None or self.feature_names_ is None:
            return None
        importance = self.model.feature_importance(importance_type="gain")
        return pd.Series(importance, index=self.feature_names_)

    def save(self, path: str) -> None:
        """Save the model in LightGBM's native format (pickle-free).

        A sidecar ``<path>.meta.json`` stores the feature names and
        construction parameters, which the native format does not keep.

        Args:
            path: Destination file path.

        Raises:
            RuntimeError: If called before ``fit``.
        """
        if self.model is None:
            raise RuntimeError("模型尚未训练,无法保存")
        self.model.save_model(path)
        import json

        meta_path = path + ".meta.json"
        with open(meta_path, "w") as f:
            json.dump(
                {
                    "feature_names": self.feature_names_,
                    "params": self.params,
                    "n_estimators": self.n_estimators,
                },
                f,
            )

    @classmethod
    def load(cls, path: str) -> "LightGBMModel":
        """Load a model saved in LightGBM's native format.

        Args:
            path: Model file path.

        Returns:
            A restored LightGBMModel instance.  When the sidecar metadata
            file is missing, ``feature_names_`` stays None and the default
            params/n_estimators are kept.
        """
        import lightgbm as lgb
        import json

        instance = cls()
        instance.model = lgb.Booster(model_file=path)
        meta_path = path + ".meta.json"
        try:
            with open(meta_path, "r") as f:
                meta = json.load(f)
            instance.feature_names_ = meta.get("feature_names")
            instance.params = meta.get("params", instance.params)
            instance.n_estimators = meta.get("n_estimators", instance.n_estimators)
        except FileNotFoundError:
            # Best-effort: proceed without metadata.
            pass
        return instance

View File

@@ -0,0 +1,16 @@
"""数据处理器子模块
包含数据预处理、转换等处理器实现。
"""
from src.training.components.processors.transforms import (
CrossSectionalStandardScaler,
StandardScaler,
Winsorizer,
)
__all__ = [
"StandardScaler",
"CrossSectionalStandardScaler",
"Winsorizer",
]

View File

@@ -0,0 +1,275 @@
"""数据处理器实现
包含标准化、缩尾等数据处理器。
"""
from typing import List, Optional
import polars as pl
from src.training.components.base import BaseProcessor
from src.training.registry import register_processor
@register_processor("standard_scaler")
class StandardScaler(BaseProcessor):
    """Global standardization processor.

    Learns per-column mean and standard deviation once on the training set,
    then applies those same parameters to any later frame (train and test
    alike).  Suited to scenarios that need global statistics.

    Attributes:
        exclude_cols: Columns never standardized.
        mean_: Learned means, {column: mean}.
        std_: Learned standard deviations, {column: std}.
    """

    name = "standard_scaler"

    def __init__(self, exclude_cols: Optional[List[str]] = None):
        """Create the processor.

        Args:
            exclude_cols: Columns to skip; defaults to ["ts_code", "trade_date"].
        """
        self.exclude_cols = exclude_cols or ["ts_code", "trade_date"]
        self.mean_: dict = {}
        self.std_: dict = {}

    def fit(self, X: pl.DataFrame) -> "StandardScaler":
        """Learn mean and std from the training frame only.

        Args:
            X: Training data.

        Returns:
            self
        """
        for name in X.columns:
            if name in self.exclude_cols or not X[name].dtype.is_numeric():
                continue
            self.mean_[name] = X[name].mean()
            self.std_[name] = X[name].std()
        return self

    def transform(self, X: pl.DataFrame) -> pl.DataFrame:
        """Standardize using the parameters learned during fit.

        Args:
            X: Frame to transform.

        Returns:
            Frame with fitted columns standardized; other columns untouched.
        """
        exprs = []
        for name in X.columns:
            if name in self.mean_ and name in self.std_:
                std = self.std_[name]
                # Guard against a zero std (constant column): divide by 1.0.
                denom = std if std != 0 else 1.0
                exprs.append(((pl.col(name) - self.mean_[name]) / denom).alias(name))
            else:
                exprs.append(pl.col(name))
        return X.select(exprs)
@register_processor("cs_standard_scaler")
class CrossSectionalStandardScaler(BaseProcessor):
    """Cross-sectional standardization processor.

    Standardizes each factor within each trading day: for every date the
    value is centered and scaled by that day's own mean and std.

    Characteristics:
        - Stateless: no fit step; each day uses its own statistics.
        - Suited to cross-sectional factors (removes level effects such as size).
        - Formula: z = (x - mean_today) / std_today.

    Attributes:
        exclude_cols: Columns never standardized.
        date_col: Date column used for grouping.
    """

    name = "cs_standard_scaler"

    def __init__(
        self,
        exclude_cols: Optional[List[str]] = None,
        date_col: str = "trade_date",
    ):
        """Create the processor.

        Args:
            exclude_cols: Columns to skip; defaults to ["ts_code", "trade_date"].
            date_col: Date column name.
        """
        self.exclude_cols = exclude_cols or ["ts_code", "trade_date"]
        self.date_col = date_col

    def transform(self, X: pl.DataFrame) -> pl.DataFrame:
        """Standardize every numeric column within each date group.

        No fit is needed: each day's own mean/std are computed on the fly.

        Args:
            X: Frame to transform.

        Returns:
            Frame with numeric columns standardized per date.
        """
        skip = set(self.exclude_cols)
        numeric = {
            name for name in X.columns
            if name not in skip and X[name].dtype.is_numeric()
        }
        exprs = []
        for name in X.columns:
            if name in numeric:
                day_mean = pl.col(name).mean().over(self.date_col)
                day_std = pl.col(name).std().over(self.date_col)
                # The 1e-10 epsilon keeps the division finite when a day's
                # std is zero (constant column on that day).
                exprs.append(((pl.col(name) - day_mean) / (day_std + 1e-10)).alias(name))
            else:
                exprs.append(pl.col(name))
        return X.select(exprs)
@register_processor("winsorizer")
class Winsorizer(BaseProcessor):
    """Winsorization (tail-clipping) processor.

    Clips extreme values per column in one of two modes:
        - global (by_date=False): quantile bounds are learned on the training
          set in ``fit`` and reused in ``transform``;
        - per-date (by_date=True): each day's bounds are computed independently.

    Attributes:
        lower: Lower quantile (e.g. 0.01 = the 1% quantile).
        upper: Upper quantile (e.g. 0.99 = the 99% quantile).
        by_date: True = winsorize each day independently, False = globally.
        date_col: Date column name (used in per-date mode).
        exclude_cols: Columns never winsorized.
        bounds_: Learned quantile bounds, global mode only.
    """

    name = "winsorizer"

    def __init__(
        self,
        lower: float = 0.01,
        upper: float = 0.99,
        by_date: bool = False,
        date_col: str = "trade_date",
        exclude_cols: Optional[List[str]] = None,
    ):
        """Create the processor.

        Args:
            lower: Lower quantile, default 0.01.
            upper: Upper quantile, default 0.99.
            by_date: Per-date winsorization; default False (global).
            date_col: Date column name.
            exclude_cols: Columns to leave untouched.  Defaults to None,
                which preserves the original behavior of clipping every
                numeric column — callers with numeric id/date columns
                should list them here so they are not clipped.

        Raises:
            ValueError: If the quantile parameters are invalid.
        """
        if not 0 <= lower < upper <= 1:
            raise ValueError(
                f"lower ({lower}) 必须小于 upper ({upper}),且都在 [0, 1] 范围内"
            )
        self.lower = lower
        self.upper = upper
        self.by_date = by_date
        self.date_col = date_col
        self.exclude_cols = list(exclude_cols) if exclude_cols else []
        self.bounds_: dict = {}

    def _numeric_cols(self, X: pl.DataFrame) -> List[str]:
        """Numeric columns eligible for winsorization (excludes exclude_cols)."""
        return [
            c
            for c in X.columns
            if c not in self.exclude_cols and X[c].dtype.is_numeric()
        ]

    def fit(self, X: pl.DataFrame) -> "Winsorizer":
        """Learn quantile bounds (global mode only; no-op in per-date mode).

        Args:
            X: Training data.

        Returns:
            self
        """
        if not self.by_date:
            for col in self._numeric_cols(X):
                self.bounds_[col] = {
                    "lower": X[col].quantile(self.lower),
                    "upper": X[col].quantile(self.upper),
                }
        return self

    def transform(self, X: pl.DataFrame) -> pl.DataFrame:
        """Clip extreme values according to the configured mode.

        Args:
            X: Frame to transform.

        Returns:
            Winsorized frame.
        """
        if self.by_date:
            return self._transform_by_date(X)
        return self._transform_global(X)

    def _transform_global(self, X: pl.DataFrame) -> pl.DataFrame:
        """Global clipping with bounds learned in fit."""
        exprs = []
        for col in X.columns:
            if col in self.bounds_:
                b = self.bounds_[col]
                exprs.append(pl.col(col).clip(b["lower"], b["upper"]).alias(col))
            else:
                exprs.append(pl.col(col))
        return X.select(exprs)

    def _transform_by_date(self, X: pl.DataFrame) -> pl.DataFrame:
        """Per-date clipping with each day's own quantile bounds."""
        numeric_cols = self._numeric_cols(X)
        # Materialize per-date bound columns first.
        # NOTE(review): the helper names "<col>_lower"/"<col>_upper" would
        # collide with equally named input columns — assumed absent.
        lower_exprs = [
            pl.col(c).quantile(self.lower).over(self.date_col).alias(f"{c}_lower")
            for c in numeric_cols
        ]
        upper_exprs = [
            pl.col(c).quantile(self.upper).over(self.date_col).alias(f"{c}_upper")
            for c in numeric_cols
        ]
        with_bounds = X.with_columns(lower_exprs + upper_exprs)
        # Clip and simultaneously drop the helper columns via select.
        clip_exprs = []
        for col in X.columns:
            if col in numeric_cols:
                clip_exprs.append(
                    pl.col(col)
                    .clip(pl.col(f"{col}_lower"), pl.col(f"{col}_upper"))
                    .alias(col)
                )
            else:
                clip_exprs.append(pl.col(col))
        return with_bounds.select(clip_exprs)

View File

@@ -0,0 +1,81 @@
"""股票池选择器配置
提供股票过滤和市值选择的配置类。
"""
from dataclasses import dataclass
from typing import List, Optional
@dataclass
class StockFilterConfig:
    """Stock filter configuration.

    Filters out unwanted boards purely from the stock code itself, with no
    external data dependency.

    Attributes:
        exclude_cyb: Drop ChiNext stocks (codes starting with 300).
        exclude_kcb: Drop STAR-market stocks (codes starting with 688).
        exclude_bj: Drop Beijing-exchange stocks (codes starting with 8 or 4).
        exclude_st: Drop ST stocks (needs external data; not handled here).
    """

    exclude_cyb: bool = True
    exclude_kcb: bool = True
    exclude_bj: bool = True
    exclude_st: bool = True

    def filter_codes(self, codes: List[str]) -> List[str]:
        """Return the codes that survive the prefix-based filters.

        Args:
            codes: Raw stock code list.

        Returns:
            Codes not matching any excluded prefix, in original order.

        Note:
            ST filtering requires extra data and is handled in
            StockPoolManager; this method only inspects code prefixes.
        """
        banned: List[str] = []
        if self.exclude_cyb:
            banned.append("300")  # ChiNext
        if self.exclude_kcb:
            banned.append("688")  # STAR market
        if self.exclude_bj:
            banned.extend(("8", "4"))  # Beijing exchange
        prefixes = tuple(banned)
        # str.startswith(()) is always False, so an empty filter keeps everything.
        return [code for code in codes if not code.startswith(prefixes)]
@dataclass
class MarketCapSelectorConfig:
    """Market-cap selector configuration.

    Picks the n largest (or smallest) stocks by market cap, independently
    per day.  Market-cap data comes from the daily_basic table and is used
    for screening only.

    Attributes:
        enabled: Whether selection is active.
        n: Number of stocks to keep.
        ascending: False = largest caps first, True = smallest caps first.
        market_cap_col: Market-cap column name (from daily_basic).
    """

    enabled: bool = True
    n: int = 100
    ascending: bool = False
    market_cap_col: str = "total_mv"

    def __post_init__(self):
        """Validate parameters at construction time.

        Raises:
            ValueError: If n is not positive or the column name is empty.
        """
        if self.n <= 0:
            raise ValueError(f"n 必须是正整数,得到: {self.n}")
        if not self.market_cap_col:
            raise ValueError("market_cap_col 不能为空")

View File

@@ -0,0 +1,122 @@
"""数据划分器
提供基于日期范围的数据划分功能,支持一次性训练/测试划分。
"""
from typing import Tuple
import polars as pl
class DateSplitter:
    """One-shot train/test split by date range (no rolling).

    Example:
        train_start "20200101", train_end "20221231" (3 years of training)
        test_start "20230101", test_end "20231231" (1 year of testing)

    Characteristics:
        - Single split, no rolling windows.
        - Train and test ranges never overlap (enforced).
        - Based on actual date ranges, not row counts.

    Attributes:
        train_start: Training start date, "YYYYMMDD".
        train_end: Training end date, "YYYYMMDD".
        test_start: Test start date, "YYYYMMDD".
        test_end: Test end date, "YYYYMMDD".
    """

    def __init__(
        self,
        train_start: str,
        train_end: str,
        test_start: str,
        test_end: str,
    ):
        """Create the splitter.

        Args:
            train_start: Training start date "YYYYMMDD".
            train_end: Training end date "YYYYMMDD".
            test_start: Test start date "YYYYMMDD".
            test_end: Test end date "YYYYMMDD".

        Raises:
            ValueError: If a date is malformed, a range is inverted, or the
                test period overlaps the training period.
        """
        # Format check — consistent with TrainingConfig: exactly 8 digits
        # (length alone would accept strings like "2020-1-1").
        for name, value in [
            ("train_start", train_start),
            ("train_end", train_end),
            ("test_start", test_start),
            ("test_end", test_end),
        ]:
            if not isinstance(value, str) or len(value) != 8 or not value.isdigit():
                raise ValueError(
                    f"{name} 必须是格式为 'YYYYMMDD' 的8位字符串,得到: {value}"
                )
        # Range checks: lexicographic compare is valid for 8-digit YYYYMMDD.
        if train_start > train_end:
            raise ValueError(
                f"train_start ({train_start}) 必须早于或等于 train_end ({train_end})"
            )
        if test_start > test_end:
            raise ValueError(
                f"test_start ({test_start}) 必须早于或等于 test_end ({test_end})"
            )
        if test_start <= train_end:
            # Fixed: the two message fragments were previously concatenated
            # with no separator; a comma makes the message readable.
            raise ValueError(
                f"测试集开始日期 ({test_start}) 必须晚于训练集结束日期 ({train_end}),"
                "以确保训练集和测试集不重叠"
            )
        self.train_start = train_start
        self.train_end = train_end
        self.test_start = test_start
        self.test_end = test_end

    def split(
        self, data: "pl.DataFrame", date_col: str = "trade_date"
    ) -> Tuple["pl.DataFrame", "pl.DataFrame"]:
        """Split the data into training and test frames.

        Args:
            data: Input data; must contain the date column.
            date_col: Date column name, default "trade_date".

        Returns:
            (train_data, test_data) tuple.

        Raises:
            ValueError: If the date column is missing.
        """
        if date_col not in data.columns:
            raise ValueError(f"数据中不包含列 '{date_col}',可用列: {data.columns}")
        train_data = data.filter(
            (pl.col(date_col) >= self.train_start)
            & (pl.col(date_col) <= self.train_end)
        )
        test_data = data.filter(
            (pl.col(date_col) >= self.test_start) & (pl.col(date_col) <= self.test_end)
        )
        return train_data, test_data

    def __repr__(self) -> str:
        """Debug-friendly representation showing all four boundaries."""
        return (
            f"DateSplitter("
            f"train_start='{self.train_start}', "
            f"train_end='{self.train_end}', "
            f"test_start='{self.test_start}', "
            f"test_end='{self.test_end}'"
            f")"
        )

View File

@@ -0,0 +1,18 @@
"""训练配置管理
提供 TrainingConfig 配置类和相关配置数据类。
"""
from src.training.config.config import (
MarketCapSelectorConfig,
ProcessorConfig,
StockFilterConfig,
TrainingConfig,
)
__all__ = [
"TrainingConfig",
"StockFilterConfig",
"MarketCapSelectorConfig",
"ProcessorConfig",
]

View File

@@ -0,0 +1,141 @@
"""训练配置管理
提供 TrainingConfig 配置类,使用 pydantic 进行参数验证。
"""
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from pydantic import Field, validator
from pydantic_settings import BaseSettings
@dataclass
class StockFilterConfig:
    """Stock filter configuration (config-file variant).

    NOTE(review): this duplicates components.selectors.StockFilterConfig —
    presumably kept separate for config loading; consider unifying the two.
    """

    exclude_cyb: bool = True  # drop ChiNext (300xxx)
    exclude_kcb: bool = True  # drop STAR market (688xxx)
    exclude_bj: bool = True  # drop Beijing exchange (8xxx / 4xxx)
    exclude_st: bool = True  # drop ST stocks
@dataclass
class MarketCapSelectorConfig:
    """Market-cap selector configuration (config-file variant).

    NOTE(review): unlike components.selectors.MarketCapSelectorConfig this
    variant performs no __post_init__ validation — confirm that is intended.
    """

    enabled: bool = True  # whether selection is active
    n: int = 100  # keep the top n stocks
    ascending: bool = False  # False = largest caps, True = smallest caps
    market_cap_col: str = "total_mv"  # market-cap column (from daily_basic)
@dataclass
class ProcessorConfig:
    """Processor configuration: registry name plus constructor params.

    Attributes:
        name: Processor registry name (e.g. "winsorizer").
        params: Keyword arguments passed to the processor's constructor.
    """

    name: str
    # Fixed: the previous pydantic Field(default_factory=dict) is not a
    # valid default on a plain stdlib dataclass — direct construction left
    # ``params`` set to a FieldInfo object instead of a dict.  Use
    # dataclasses.field so each instance gets its own empty dict.
    params: Dict[str, Any] = field(default_factory=dict)
class TrainingConfig(BaseSettings):
    """Training configuration.

    All training parameters live here, validated by pydantic.  Values may
    also come from the environment (prefix TRAINING_, nested via "__").

    NOTE(review): this class mixes pydantic v1-style ``@validator`` and
    ``min_items`` with pydantic_settings (a v2-era package); both spellings
    are deprecated in pydantic v2 and should eventually migrate to
    ``field_validator`` / ``min_length``.
    """

    # === Data columns (feature_cols is required) ===
    feature_cols: List[str] = Field(..., min_items=1)  # at least one feature
    target_col: str = "target"
    date_col: str = "trade_date"
    code_col: str = "ts_code"

    # === Date split (all required) ===
    train_start: str = Field(..., description="训练期开始 YYYYMMDD")
    train_end: str = Field(..., description="训练期结束 YYYYMMDD")
    test_start: str = Field(..., description="测试期开始 YYYYMMDD")
    test_end: str = Field(..., description="测试期结束 YYYYMMDD")

    # === Stock pool ===
    stock_filter: StockFilterConfig = Field(
        default_factory=lambda: StockFilterConfig(
            exclude_cyb=True,
            exclude_kcb=True,
            exclude_bj=True,
            exclude_st=True,
        )
    )
    # Setting stock_selector to None skips market-cap selection entirely.
    stock_selector: Optional[MarketCapSelectorConfig] = Field(
        default_factory=lambda: MarketCapSelectorConfig(
            enabled=True,
            n=100,
            ascending=False,
            market_cap_col="total_mv",
        )
    )

    # === Model ===
    model_type: str = "lightgbm"
    model_params: Dict[str, Any] = Field(default_factory=dict)

    # === Processors (applied in order) ===
    processors: List[ProcessorConfig] = Field(
        default_factory=lambda: [
            ProcessorConfig(name="winsorizer", params={"lower": 0.01, "upper": 0.99}),
            ProcessorConfig(name="cs_standard_scaler", params={}),
        ]
    )

    # === Persistence ===
    persist_model: bool = False  # off by default
    model_save_path: Optional[str] = None

    # === Output ===
    output_dir: str = "output"
    save_predictions: bool = True

    @validator("train_start", "train_end", "test_start", "test_end")
    def validate_date_format(cls, v: str) -> str:
        """Require an 8-digit YYYYMMDD string."""
        if not isinstance(v, str) or len(v) != 8:
            raise ValueError(f"日期必须是格式为 'YYYYMMDD' 的8位字符串,得到: {v}")
        try:
            int(v)
        except ValueError:
            raise ValueError(f"日期必须是数字字符串,得到: {v}")
        return v

    @validator("train_end")
    def validate_train_dates(cls, v: str, values: Dict[str, Any]) -> str:
        """train_start must not be after train_end."""
        if "train_start" in values and values["train_start"] > v:
            raise ValueError(
                f"train_start ({values['train_start']}) 必须早于或等于 train_end ({v})"
            )
        return v

    @validator("test_end")
    def validate_test_dates(cls, v: str, values: Dict[str, Any]) -> str:
        """test_start must not be after test_end."""
        if "test_start" in values and values["test_start"] > v:
            raise ValueError(
                f"test_start ({values['test_start']}) 必须早于或等于 test_end ({v})"
            )
        return v

    @validator("test_start")
    def validate_no_overlap(cls, v: str, values: Dict[str, Any]) -> str:
        """The test period must start strictly after the training period ends."""
        if "train_end" in values and v <= values["train_end"]:
            # Fixed: the message fragments were previously concatenated with
            # no separator; a comma makes the message readable.
            raise ValueError(
                f"测试集开始日期 ({v}) 必须晚于训练集结束日期 ({values['train_end']}),"
                "以确保训练集和测试集不重叠"
            )
        return v

    class Config:
        """Pydantic settings: read from env as TRAINING_*, nested via '__'."""

        env_prefix = "TRAINING_"
        env_nested_delimiter = "__"

View File

@@ -0,0 +1,9 @@
"""训练核心模块
包含 Trainer 主类和股票池管理器。
"""
from src.training.core.stock_pool_manager import StockPoolManager
from src.training.core.trainer import Trainer
__all__ = ["StockPoolManager", "Trainer"]

View File

@@ -0,0 +1,171 @@
"""股票池管理器
每日独立筛选股票池,市值数据从 daily_basic 表独立获取。
"""
from typing import TYPE_CHECKING, Dict, List, Optional
import polars as pl
from src.training.components.selectors import MarketCapSelectorConfig, StockFilterConfig
if TYPE_CHECKING:
from src.factors.engine.data_router import DataRouter
class StockPoolManager:
    """Stock-pool manager — independent screening for every trading day.

    Hard constraints:
        1. Market-cap data comes only from the daily_basic table and is used
           only for screening.
        2. Market-cap values are never mixed into the feature matrix.
        3. Screening is per day (market cap changes daily).

    Daily pipeline:
        all stocks of the day
            -> code-prefix filtering (ChiNext, ST, ...)
            -> fetch that day's market caps from daily_basic
            -> keep the top-N by market cap
            -> selected stocks for the day
    """

    def __init__(
        self,
        filter_config: StockFilterConfig,
        selector_config: Optional[MarketCapSelectorConfig],
        data_router: "DataRouter",
        code_col: str = "ts_code",
        date_col: str = "trade_date",
    ):
        """Create the manager.

        Args:
            filter_config: Code-prefix filter configuration.
            selector_config: Market-cap selector configuration; None skips
                market-cap selection entirely.
            data_router: Data router used to read daily_basic.
            code_col: Stock-code column name.
            date_col: Date column name.
        """
        self.filter_config = filter_config
        self.selector_config = selector_config
        self.data_router = data_router
        self.code_col = code_col
        self.date_col = date_col

    def filter_and_select_daily(self, data: pl.DataFrame) -> pl.DataFrame:
        """Screen the stock pool independently for each day.

        Args:
            data: Post-factor full-market data; must contain the date and
                code columns.

        Returns:
            Only the rows of stocks selected on their respective day.

        Note:
            Market-cap data is fetched separately from daily_basic and is
            never joined into the returned frame.
        """
        # Guard: pl.concat([]) would raise on an empty input frame.
        if data.is_empty():
            return data
        dates = data.select(self.date_col).unique().sort(self.date_col)
        result_frames = []
        for date in dates.to_series():
            daily_data = data.filter(pl.col(self.date_col) == date)
            daily_codes = daily_data.select(self.code_col).to_series().to_list()
            # 1. Prefix-based code filtering.
            filtered_codes = self.filter_config.filter_codes(daily_codes)
            # 2. Optional market-cap selection.
            if self.selector_config and self.selector_config.enabled:
                market_caps = self._get_market_caps_for_date(filtered_codes, date)
                selected_codes = self._select_by_market_cap(filtered_codes, market_caps)
            else:
                selected_codes = filtered_codes
            # 3. Keep only the selected stocks for this day.
            result_frames.append(
                daily_data.filter(pl.col(self.code_col).is_in(selected_codes))
            )
        return pl.concat(result_frames)

    def _get_market_caps_for_date(
        self, codes: List[str], date: str
    ) -> Dict[str, float]:
        """Fetch market caps for the given codes on one date from daily_basic.

        Best-effort: on any fetch error a warning is printed and an empty
        dict is returned, so screening degrades instead of crashing.

        Args:
            codes: Stock codes to look up.
            date: Date string "YYYYMMDD".

        Returns:
            {stock code: market cap} for the codes found.
        """
        if not codes:
            return {}
        assert self.selector_config is not None, (
            "selector_config should not be None when calling _get_market_caps_for_date"
        )
        # Set membership: O(1) per row instead of scanning the list each time.
        wanted = set(codes)
        try:
            from src.factors.engine.data_spec import DataSpec

            data_specs = [
                DataSpec("daily_basic", [self.selector_config.market_cap_col])
            ]
            df = self.data_router.fetch_data(
                data_specs=data_specs,
                start_date=date,
                end_date=date,
                stock_codes=codes,
            )
            market_caps: Dict[str, float] = {}
            for row in df.iter_rows(named=True):
                code = row[self.code_col]
                cap = row.get(self.selector_config.market_cap_col)
                if cap is not None and code in wanted:
                    market_caps[code] = float(cap)
            return market_caps
        except Exception as e:  # deliberate best-effort catch-all
            print(f"[警告] 获取 {date} 市值数据失败: {e}")
            return {}

    def _select_by_market_cap(
        self, codes: List[str], market_caps: Dict[str, float]
    ) -> List[str]:
        """Pick the top-N codes by market cap.

        Falls back to the first N codes (input order) when no market-cap
        data is available.

        Args:
            codes: Candidate stock codes.
            market_caps: {code: market cap}.

        Returns:
            Selected stock codes.
        """
        if self.selector_config is None:
            return codes
        if not market_caps:
            return codes[: self.selector_config.n]
        # Missing caps rank as 0; descending order unless ascending is set.
        ranked = sorted(
            codes,
            key=lambda c: market_caps.get(c, 0),
            reverse=not self.selector_config.ascending,
        )
        return ranked[: self.selector_config.n]

View File

@@ -0,0 +1,179 @@
"""训练器主类
整合数据处理、模型训练、预测的完整流程。
"""
from typing import List, Optional
import polars as pl
from src.training.components.base import BaseModel, BaseProcessor
from src.training.components.splitters import DateSplitter
from src.training.core.stock_pool_manager import StockPoolManager
class Trainer:
"""训练器主类
整合数据处理、模型训练、预测的完整流程。
关键设计:
1. 因子先计算(全市场),再筛选股票池(每日独立)
2. Processor 分阶段行为:训练集 fit_transform测试集 transform
3. 一次性训练,不滚动
4. 支持模型持久化
"""
def __init__(
self,
model: BaseModel,
pool_manager: Optional[StockPoolManager] = None,
processors: Optional[List[BaseProcessor]] = None,
splitter: Optional[DateSplitter] = None,
target_col: str = "target",
feature_cols: Optional[List[str]] = None,
persist_model: bool = False,
model_save_path: Optional[str] = None,
):
"""初始化训练器
Args:
model: 模型实例
pool_manager: 股票池管理器None 表示不筛选
processors: 数据处理器列表
splitter: 数据划分器
target_col: 目标变量列名
feature_cols: 特征列名列表
persist_model: 是否保存模型
model_save_path: 模型保存路径
"""
self.model = model
self.pool_manager = pool_manager
self.processors = processors or []
self.splitter = splitter
self.target_col = target_col
self.feature_cols = feature_cols or []
self.persist_model = persist_model
self.model_save_path = model_save_path
# 存储训练后的处理器
self.fitted_processors: List[BaseProcessor] = []
self.results: Optional[pl.DataFrame] = None
def train(self, data: pl.DataFrame) -> "Trainer":
"""执行训练流程
流程:
1. 股票池每日筛选(如果配置了 pool_manager
2. 按日期划分训练集/测试集
3. 训练集processors fit_transform
4. 训练模型
5. 测试集processors transform使用训练集学到的参数
6. 预测
7. 保存结果
8. 持久化模型(如果启用)
Args:
data: 因子计算后的全市场数据
必须包含 ts_code 和 trade_date 列
Returns:
self (支持链式调用)
"""
# 1. 股票池筛选(每日独立)
if self.pool_manager:
print("[筛选] 每日独立筛选股票池...")
data = self.pool_manager.filter_and_select_daily(data)
# 2. 划分训练/测试集
if self.splitter:
print("[划分] 划分训练集和测试集...")
train_data, test_data = self.splitter.split(data)
else:
# 没有划分器,全部作为训练集
train_data = data
test_data = data
# 3. 训练集processors fit_transform
if self.processors:
print("[处理] 处理训练集...")
for processor in self.processors:
train_data = processor.fit_transform(train_data)
self.fitted_processors.append(processor)
# 4. 训练模型
print("[训练] 训练模型...")
if not self.feature_cols:
raise ValueError("feature_cols 不能为空")
X_train = train_data.select(self.feature_cols)
y_train = train_data.select(self.target_col).to_series()
self.model.fit(X_train, y_train)
# 5. 测试集processors transform
if self.processors and test_data is not train_data:
print("[处理] 处理测试集...")
for processor in self.fitted_processors:
test_data = processor.transform(test_data)
# 6. 预测
print("[预测] 生成预测...")
X_test = test_data.select(self.feature_cols)
predictions = self.model.predict(X_test)
# 7. 保存结果
self.results = test_data.with_columns([pl.Series("prediction", predictions)])
# 8. 持久化模型
if self.persist_model and self.model_save_path:
print(f"[保存] 保存模型到 {self.model_save_path}...")
self.save_model(self.model_save_path)
return self
def predict(self, data: pl.DataFrame) -> "pl.DataFrame":
    """Generate predictions for new data.

    Note: the input is expected to be pre-filtered to the stock pool;
    the already-fitted processors and the trained model are then applied.

    Args:
        data: input data frame

    Returns:
        The (processed) input data with an added ``prediction`` column.
    """
    # Apply every processor fitted during training, in order.
    transformed = data
    for proc in self.fitted_processors:
        transformed = proc.transform(transformed)
    # Score the feature matrix with the trained model.
    features = transformed.select(self.feature_cols)
    preds = self.model.predict(features)
    prediction_col = pl.Series("prediction", preds)
    return transformed.with_columns([prediction_col])
def get_results(self) -> Optional[pl.DataFrame]:
    """Return the stored prediction results.

    Returns:
        The prediction DataFrame (original columns plus ``prediction``),
        or None when train() has not produced results yet.
    """
    return self.results
def save_results(self, path: str) -> None:
    """Write the prediction results to a CSV file.

    Silently does nothing when there are no results yet.

    Args:
        path: destination path (CSV format)
    """
    # Guard clause: nothing to persist before train() has run.
    if self.results is None:
        return
    self.results.write_csv(path)
def save_model(self, path: str) -> None:
    """Persist the trained model via the model's own save routine.

    Args:
        path: destination path for the model file
    """
    self.model.save(path)

View File

@@ -0,0 +1,235 @@
"""测试 LightGBM 模型
验证 LightGBMModel 的训练、预测、保存和加载功能。
"""
import os
import tempfile
import numpy as np
import polars as pl
import pytest
from src.training.components.models.lightgbm import LightGBMModel
class TestLightGBMModel:
    """Tests for LightGBMModel: construction, fit/predict, persistence, registry."""

    def test_init_default(self):
        """Default construction exposes the documented default hyper-parameters."""
        model = LightGBMModel()
        assert model.name == "lightgbm"
        assert model.params["objective"] == "regression"
        assert model.params["metric"] == "rmse"
        assert model.params["num_leaves"] == 31
        assert model.params["learning_rate"] == 0.05
        assert model.n_estimators == 100
        # No booster exists until fit() is called.
        assert model.model is None

    def test_init_custom(self):
        """Constructor keyword arguments override the defaults."""
        model = LightGBMModel(
            objective="huber",
            metric="mae",
            num_leaves=50,
            learning_rate=0.1,
            n_estimators=200,
        )
        assert model.params["objective"] == "huber"
        assert model.params["metric"] == "mae"
        assert model.params["num_leaves"] == 50
        assert model.params["learning_rate"] == 0.1
        assert model.n_estimators == 200

    def test_fit_success(self):
        """fit() trains a booster, records feature names, and returns self."""
        # Simple, perfectly collinear regression data is enough here.
        X = pl.DataFrame(
            {
                "feature1": [1.0, 2.0, 3.0, 4.0, 5.0],
                "feature2": [2.0, 4.0, 6.0, 8.0, 10.0],
            }
        )
        y = pl.Series("target", [1.5, 3.0, 4.5, 6.0, 7.5])
        model = LightGBMModel(n_estimators=10)
        result = model.fit(X, y)
        # fit() returns self to support method chaining.
        assert result is model
        # A trained booster now exists.
        assert model.model is not None
        # Feature names were captured from the input frame.
        assert model.feature_names_ == ["feature1", "feature2"]

    def test_predict_before_fit(self):
        """predict() before fit() raises a descriptive RuntimeError."""
        X = pl.DataFrame(
            {
                "feature1": [1.0, 2.0],
                "feature2": [2.0, 4.0],
            }
        )
        model = LightGBMModel()
        with pytest.raises(RuntimeError, match="模型尚未训练"):
            model.predict(X)

    def test_predict_success(self):
        """predict() returns finite values that respect the learned trend."""
        # Regression data with a known linear relationship.
        np.random.seed(42)
        n_samples = 100
        X_train = pl.DataFrame(
            {
                "feature1": np.random.randn(n_samples),
                "feature2": np.random.randn(n_samples),
            }
        )
        # y = 2*feature1 + 3*feature2 + noise
        y_train = pl.Series(
            "target",
            2 * X_train["feature1"]
            + 3 * X_train["feature2"]
            + np.random.randn(n_samples) * 0.1,
        )
        model = LightGBMModel(n_estimators=20, learning_rate=0.1)
        model.fit(X_train, y_train)
        # Predict on clearly separated inputs so ordering is unambiguous.
        X_test = pl.DataFrame(
            {
                "feature1": [-2.0, 3.0],
                "feature2": [-1.0, 4.0],
            }
        )
        predictions = model.predict(X_test)
        # Output shape and type.
        assert isinstance(predictions, np.ndarray)
        assert len(predictions) == 2
        # All predictions are finite numbers.
        assert all(np.isfinite(predictions))
        # Monotonicity: the sample with larger feature values predicts higher.
        assert predictions[1] > predictions[0]

    def test_feature_importance_before_fit(self):
        """feature_importance() is None before training."""
        model = LightGBMModel()
        assert model.feature_importance() is None

    def test_feature_importance_after_fit(self):
        """After fit(), importances exist and rank the dominant feature higher."""
        X = pl.DataFrame(
            {
                "feature1": np.random.randn(100),
                "feature2": np.random.randn(100),
            }
        )
        # feature1 carries a much larger coefficient than feature2.
        y = pl.Series("target", X["feature1"] * 2 + X["feature2"] * 0.1)
        model = LightGBMModel(n_estimators=10)
        model.fit(X, y)
        importance = model.feature_importance()
        # One importance entry per feature, keyed by feature name.
        assert importance is not None
        assert len(importance) == 2
        assert "feature1" in importance.index
        assert "feature2" in importance.index
        # The dominant feature must not rank below the weak one.
        assert importance["feature1"] >= importance["feature2"]

    def test_save_before_fit(self):
        """save() before fit() raises a descriptive RuntimeError."""
        model = LightGBMModel()
        with pytest.raises(RuntimeError, match="模型尚未训练"):
            model.save("dummy.txt")

    def test_save_and_load(self):
        """A saved model reloads with identical predictions and metadata."""
        # Train a tiny model.
        X = pl.DataFrame(
            {
                "feature1": [1.0, 2.0, 3.0, 4.0, 5.0],
                "feature2": [2.0, 4.0, 6.0, 8.0, 10.0],
            }
        )
        y = pl.Series("target", [2.0, 4.0, 6.0, 8.0, 10.0])
        model = LightGBMModel(n_estimators=10, learning_rate=0.1)
        model.fit(X, y)
        # Capture a prediction before saving for comparison.
        X_test = pl.DataFrame(
            {
                "feature1": [6.0],
                "feature2": [12.0],
            }
        )
        pred_before = model.predict(X_test)
        # Round-trip through a temporary file.
        with tempfile.TemporaryDirectory() as tmpdir:
            save_path = os.path.join(tmpdir, "model.txt")
            model.save(save_path)
            # Load the model back from disk.
            loaded_model = LightGBMModel.load(save_path)
            # Loaded model predicts the same value (within float tolerance).
            pred_after = loaded_model.predict(X_test)
            assert pred_after[0] == pytest.approx(pred_before[0], rel=1e-5)
            # Metadata was restored along with the booster.
            assert loaded_model.feature_names_ == ["feature1", "feature2"]
            assert loaded_model.n_estimators == 10

    def test_registration(self):
        """The model class is registered under the key "lightgbm"."""
        from src.training.registry import ModelRegistry
        # Re-import the model module to guarantee registration
        # (handles the case where another test cleared the registry).
        import importlib
        import src.training.components.models.lightgbm as lightgbm_module
        importlib.reload(lightgbm_module)
        from src.training.components.models.lightgbm import (
            LightGBMModel as ReloadedModel,
        )
        model_class = ModelRegistry.get_model("lightgbm")
        assert model_class is ReloadedModel

    def test_fit_predict_consistency(self):
        """Repeated predict() calls on the same input are deterministic."""
        X = pl.DataFrame(
            {
                "feature1": np.random.randn(50),
                "feature2": np.random.randn(50),
            }
        )
        y = pl.Series("target", X["feature1"] + X["feature2"])
        model = LightGBMModel(n_estimators=10)
        model.fit(X, y)
        X_test = pl.DataFrame(
            {
                "feature1": [1.0, 2.0, 3.0],
                "feature2": [1.0, 2.0, 3.0],
            }
        )
        # Two calls with identical input must agree element-wise.
        pred1 = model.predict(X_test)
        pred2 = model.predict(X_test)
        np.testing.assert_array_almost_equal(pred1, pred2)
if __name__ == "__main__":
    # Allow running this test module directly, outside the pytest CLI.
    cli_args = [__file__, "-v"]
    pytest.main(cli_args)

View File

@@ -0,0 +1,300 @@
"""测试数据处理器
验证 StandardScaler、CrossSectionalStandardScaler 和 Winsorizer 功能。
"""
import numpy as np
import polars as pl
import pytest
from src.training.components.processors import (
CrossSectionalStandardScaler,
StandardScaler,
Winsorizer,
)
class TestStandardScaler:
    """Tests for StandardScaler (global fit on train, reuse on test)."""

    def test_init_default(self):
        """Defaults: identifier columns excluded, no statistics learned yet."""
        scaler = StandardScaler()
        assert scaler.exclude_cols == ["ts_code", "trade_date"]
        assert scaler.mean_ == {}
        assert scaler.std_ == {}

    def test_init_custom_exclude(self):
        """A custom exclude-column list is stored as given."""
        scaler = StandardScaler(exclude_cols=["id", "date"])
        assert scaler.exclude_cols == ["id", "date"]

    def test_fit_transform(self):
        """fit_transform learns mean/std and z-scores the numeric column."""
        data = pl.DataFrame(
            {
                "ts_code": ["A", "B", "C", "D"],
                "trade_date": ["20240101"] * 4,
                "value": [1.0, 2.0, 3.0, 4.0],
            }
        )
        scaler = StandardScaler()
        result = scaler.fit_transform(data)
        # Learned statistics: mean 2.5, sample std ~1.291.
        assert scaler.mean_["value"] == 2.5
        assert scaler.std_["value"] == pytest.approx(1.290, rel=1e-2)
        # Transformed values are (x - mean) / std.
        expected_std = (np.array([1.0, 2.0, 3.0, 4.0]) - 2.5) / 1.290
        assert result["value"].to_list() == pytest.approx(
            expected_std.tolist(), rel=1e-2
        )

    def test_transform_use_fitted_params(self):
        """transform() reuses the statistics learned at fit time."""
        train_data = pl.DataFrame(
            {
                "ts_code": ["A", "B", "C"],
                "trade_date": ["20240101"] * 3,
                "value": [1.0, 2.0, 3.0],
            }
        )
        test_data = pl.DataFrame(
            {
                "ts_code": ["D"],
                "trade_date": ["20240102"],
                "value": [100.0],  # far outside the training distribution
            }
        )
        scaler = StandardScaler()
        scaler.fit(train_data)
        # Transform with the training mean (2.0) and std (1.0).
        result = scaler.transform(test_data)
        expected = (100.0 - 2.0) / 1.0  # train mean 2.0, sample std 1.0
        assert result["value"][0] == pytest.approx(expected, rel=1e-2)

    def test_exclude_non_numeric(self):
        """Non-numeric columns are passed through untouched."""
        data = pl.DataFrame(
            {
                "ts_code": ["A", "B"],
                "trade_date": ["20240101", "20240102"],
                "category": ["X", "Y"],  # string column
                "value": [1.0, 2.0],
            }
        )
        scaler = StandardScaler()
        result = scaler.fit_transform(data)
        # The string column is preserved as-is.
        assert result["category"].to_list() == ["X", "Y"]
        # The numeric column was standardized (statistics were learned).
        assert "value" in scaler.mean_

    def test_zero_std_handling(self):
        """A constant column maps to 0 instead of dividing by zero."""
        data = pl.DataFrame(
            {
                "ts_code": ["A", "B"],
                "trade_date": ["20240101", "20240102"],
                "constant": [5.0, 5.0],  # zero-variance column
            }
        )
        scaler = StandardScaler()
        result = scaler.fit_transform(data)
        # With std == 0 the output must be 0 (no division-by-zero).
        assert result["constant"].to_list() == [0.0, 0.0]
class TestCrossSectionalStandardScaler:
    """Tests for CrossSectionalStandardScaler (per-day z-scoring)."""

    def test_init_default(self):
        """Defaults: identifier columns excluded, trade_date is the group key."""
        scaler = CrossSectionalStandardScaler()
        assert scaler.exclude_cols == ["ts_code", "trade_date"]
        assert scaler.date_col == "trade_date"

    def test_init_custom(self):
        """Custom exclude columns and date column are stored as given."""
        scaler = CrossSectionalStandardScaler(
            exclude_cols=["id"],
            date_col="date",
        )
        assert scaler.exclude_cols == ["id"]
        assert scaler.date_col == "date"

    def test_transform_no_fit_needed(self):
        """transform() works without a prior fit() call."""
        frame = pl.DataFrame(
            {
                "ts_code": ["A", "B"],
                "trade_date": ["20240101", "20240101"],
                "value": [1.0, 3.0],
            }
        )
        scaler = CrossSectionalStandardScaler()
        # Cross-sectional scaling is stateless, so no fit is required.
        out = scaler.transform(frame)
        # Daily mean = 2.0, sample std = sqrt(2) ~ 1.414 -> z = [-0.707, 0.707]
        assert out["value"].to_list() == pytest.approx([-0.707, 0.707], rel=1e-2)

    def test_transform_by_date(self):
        """Each trading day is standardized independently."""
        frame = pl.DataFrame(
            {
                "ts_code": ["A", "B", "C", "D"],
                "trade_date": ["20240101", "20240101", "20240102", "20240102"],
                "value": [1.0, 3.0, 10.0, 30.0],
            }
        )
        out = CrossSectionalStandardScaler().transform(frame)
        # 2024-01-01: mean 2.0, sample std ~1.414 -> [-0.707, 0.707]
        # 2024-01-02: mean 20.0, sample std ~14.14 -> [-0.707, 0.707]
        zscores = out["value"].to_list()
        expected = [-0.707, 0.707, -0.707, 0.707]
        for got, want in zip(zscores, expected):
            assert got == pytest.approx(want, abs=1e-2)

    def test_exclude_columns_preserved(self):
        """Excluded identifier columns come through unchanged."""
        frame = pl.DataFrame(
            {
                "ts_code": ["A", "B"],
                "trade_date": ["20240101", "20240101"],
                "value": [1.0, 3.0],
            }
        )
        out = CrossSectionalStandardScaler().transform(frame)
        assert out["ts_code"].to_list() == ["A", "B"]
        assert out["trade_date"].to_list() == ["20240101", "20240101"]
class TestWinsorizer:
    """Tests for Winsorizer in both global and per-day modes."""

    def test_init_default(self):
        """Defaults: 1%/99% quantiles, global mode, trade_date as date column."""
        winsorizer = Winsorizer()
        assert winsorizer.lower == 0.01
        assert winsorizer.upper == 0.99
        assert winsorizer.by_date is False
        assert winsorizer.date_col == "trade_date"

    def test_init_custom(self):
        """Custom quantile bounds, per-day mode, and date column are stored."""
        winsorizer = Winsorizer(lower=0.05, upper=0.95, by_date=True, date_col="date")
        assert winsorizer.lower == 0.05
        assert winsorizer.upper == 0.95
        assert winsorizer.by_date is True
        assert winsorizer.date_col == "date"

    def test_invalid_quantiles(self):
        """Bounds must satisfy 0 <= lower < upper <= 1."""
        with pytest.raises(ValueError, match="lower .* 必须小于 upper"):
            Winsorizer(lower=0.5, upper=0.3)
        with pytest.raises(ValueError, match="lower .* 必须小于 upper"):
            Winsorizer(lower=-0.1, upper=0.5)
        with pytest.raises(ValueError, match="lower .* 必须小于 upper"):
            Winsorizer(lower=0.5, upper=1.5)

    def test_global_winsorize(self):
        """Global mode clips extremes to the dataset-wide quantiles."""
        # Data containing two extreme outliers at both ends.
        values = list(range(1, 101))  # 1-100
        values[0] = -1000  # extreme low outlier
        values[-1] = 1000  # extreme high outlier
        data = pl.DataFrame(
            {
                "ts_code": [f"A{i}" for i in range(100)],
                "trade_date": ["20240101"] * 100,
                "value": values,
            }
        )
        winsorizer = Winsorizer(lower=0.01, upper=0.99)
        result = winsorizer.fit_transform(data)
        # 1% quantile = 2, 99% quantile = 99.
        # -1000 should be clipped up to 2; 1000 clipped down to 99.
        result_values = result["value"].to_list()
        assert result_values[0] == 2  # original -1000, clipped
        assert result_values[-1] == 99  # original 1000, clipped
        assert result_values[1] == 2  # original 2, unchanged
        assert result_values[98] == 99  # original 99, unchanged

    def test_by_date_winsorize(self):
        """Per-day mode computes and applies quantile bounds day by day."""
        data = pl.DataFrame(
            {
                "ts_code": ["A", "B", "C", "D", "E", "F"],
                "trade_date": ["20240101"] * 3 + ["20240102"] * 3,
                "value": [1.0, 50.0, 100.0, 200.0, 250.0, 300.0],
            }
        )
        winsorizer = Winsorizer(lower=0.0, upper=0.5, by_date=True)
        result = winsorizer.transform(data)
        # Each day is processed independently:
        # 2024-01-01: [1, 50, 100], 50% quantile = 50
        #   -> clipped to [1, 50, 50]
        # 2024-01-02: [200, 250, 300], 50% quantile = 250
        #   -> clipped to [200, 250, 250]
        result_values = result["value"].to_list()
        assert result_values[0] == 1.0
        assert result_values[1] == 50.0
        assert result_values[2] == 50.0  # clipped
        assert result_values[3] == 200.0
        assert result_values[4] == 250.0
        assert result_values[5] == 250.0  # clipped

    def test_global_transform_after_fit(self):
        """Global mode: transform() reuses the bounds learned at fit time."""
        train_data = pl.DataFrame(
            {
                "ts_code": ["A", "B", "C"],
                "trade_date": ["20240101"] * 3,
                "value": [1.0, 50.0, 100.0],
            }
        )
        test_data = pl.DataFrame(
            {
                "ts_code": ["D"],
                "trade_date": ["20240102"],
                "value": [200.0],
            }
        )
        winsorizer = Winsorizer(lower=0.0, upper=1.0)  # 0% and 100% quantiles
        winsorizer.fit(train_data)
        # Apply the training-set bounds [1, 100] to unseen data.
        result = winsorizer.transform(test_data)
        assert result["value"][0] == 100.0  # clipped down to 100
if __name__ == "__main__":
    # Allow running this test module directly, outside the pytest CLI.
    cli_args = [__file__, "-v"]
    pytest.main(cli_args)

View File

@@ -0,0 +1,183 @@
"""测试股票池选择器配置
验证 StockFilterConfig 和 MarketCapSelectorConfig 功能。
"""
import pytest
from src.training.components.selectors import (
MarketCapSelectorConfig,
StockFilterConfig,
)
class TestStockFilterConfig:
    """Tests for StockFilterConfig code-prefix filtering."""

    def test_default_values(self):
        """All exclusion flags default to True."""
        cfg = StockFilterConfig()
        assert cfg.exclude_cyb is True
        assert cfg.exclude_kcb is True
        assert cfg.exclude_bj is True
        assert cfg.exclude_st is True

    def test_custom_values(self):
        """Every exclusion flag can be switched off explicitly."""
        cfg = StockFilterConfig(
            exclude_cyb=False,
            exclude_kcb=False,
            exclude_bj=False,
            exclude_st=False,
        )
        for flag in ("exclude_cyb", "exclude_kcb", "exclude_bj", "exclude_st"):
            assert getattr(cfg, flag) is False

    def test_filter_codes_exclude_all(self):
        """With every flag on, only main-board codes survive."""
        cfg = StockFilterConfig(
            exclude_cyb=True,
            exclude_kcb=True,
            exclude_bj=True,
            exclude_st=True,
        )
        candidates = [
            "000001.SZ",  # main board - kept
            "300001.SZ",  # ChiNext - dropped
            "688001.SH",  # STAR market - dropped
            "830001.BJ",  # BSE (8-prefix) - dropped
            "430001.BJ",  # BSE (4-prefix) - dropped
        ]
        assert cfg.filter_codes(candidates) == ["000001.SZ"]

    def test_filter_codes_allow_cyb(self):
        """ChiNext codes pass through when exclude_cyb is off."""
        cfg = StockFilterConfig(
            exclude_cyb=False,
            exclude_kcb=True,
            exclude_bj=True,
            exclude_st=True,
        )
        candidates = ["000001.SZ", "300001.SZ", "688001.SH"]
        assert cfg.filter_codes(candidates) == ["000001.SZ", "300001.SZ"]

    def test_filter_codes_allow_kcb(self):
        """STAR-market codes pass through when exclude_kcb is off."""
        cfg = StockFilterConfig(
            exclude_cyb=True,
            exclude_kcb=False,
            exclude_bj=True,
            exclude_st=True,
        )
        candidates = ["000001.SZ", "300001.SZ", "688001.SH"]
        assert cfg.filter_codes(candidates) == ["000001.SZ", "688001.SH"]

    def test_filter_codes_allow_bj(self):
        """BSE codes (both 8- and 4-prefix) pass through when exclude_bj is off."""
        cfg = StockFilterConfig(
            exclude_cyb=True,
            exclude_kcb=True,
            exclude_bj=False,
            exclude_st=True,
        )
        candidates = ["000001.SZ", "300001.SZ", "830001.BJ", "430001.BJ"]
        assert cfg.filter_codes(candidates) == [
            "000001.SZ",
            "830001.BJ",
            "430001.BJ",
        ]

    def test_filter_codes_allow_all(self):
        """With every flag off the input list is returned unchanged."""
        cfg = StockFilterConfig(
            exclude_cyb=False,
            exclude_kcb=False,
            exclude_bj=False,
            exclude_st=False,
        )
        candidates = [
            "000001.SZ",
            "300001.SZ",
            "688001.SH",
            "830001.BJ",
            "430001.BJ",
        ]
        assert cfg.filter_codes(candidates) == candidates

    def test_filter_codes_empty_list(self):
        """An empty input list yields an empty output list."""
        assert StockFilterConfig().filter_codes([]) == []

    def test_filter_codes_no_matching(self):
        """A list containing only excluded boards filters down to nothing."""
        cfg = StockFilterConfig()
        assert cfg.filter_codes(["300001.SZ", "688001.SH", "830001.BJ"]) == []
class TestMarketCapSelectorConfig:
    """Tests for MarketCapSelectorConfig validation and defaults."""

    def test_default_values(self):
        """Defaults: enabled, top 100 by descending total market cap."""
        cfg = MarketCapSelectorConfig()
        assert cfg.enabled is True
        assert cfg.n == 100
        assert cfg.ascending is False
        assert cfg.market_cap_col == "total_mv"

    def test_custom_values(self):
        """All parameters can be overridden at construction time."""
        cfg = MarketCapSelectorConfig(
            enabled=False,
            n=50,
            ascending=True,
            market_cap_col="circ_mv",
        )
        assert cfg.enabled is False
        assert cfg.n == 50
        assert cfg.ascending is True
        assert cfg.market_cap_col == "circ_mv"

    def test_invalid_n_zero(self):
        """n == 0 is rejected."""
        with pytest.raises(ValueError, match="n 必须是正整数"):
            MarketCapSelectorConfig(n=0)

    def test_invalid_n_negative(self):
        """Negative n is rejected."""
        with pytest.raises(ValueError, match="n 必须是正整数"):
            MarketCapSelectorConfig(n=-1)

    def test_invalid_empty_market_cap_col(self):
        """An empty market-cap column name is rejected."""
        with pytest.raises(ValueError, match="market_cap_col 不能为空"):
            MarketCapSelectorConfig(market_cap_col="")

    def test_large_n(self):
        """A large n (covering the whole market) is accepted."""
        cfg = MarketCapSelectorConfig(n=5000)
        assert cfg.n == 5000
if __name__ == "__main__":
    # Allow running this test module directly, outside the pytest CLI.
    cli_args = [__file__, "-v"]
    pytest.main(cli_args)

View File

@@ -0,0 +1,244 @@
"""测试 DateSplitter 数据划分器
验证一次性日期划分功能。
"""
import pytest
import polars as pl
from src.training.components.splitters import DateSplitter
class TestDateSplitter:
    """Tests for DateSplitter: validation, splitting, edge cases."""

    def test_initialization_success(self):
        """Valid, non-overlapping date ranges are stored as given."""
        splitter = DateSplitter(
            train_start="20200101",
            train_end="20221231",
            test_start="20230101",
            test_end="20231231",
        )
        assert splitter.train_start == "20200101"
        assert splitter.train_end == "20221231"
        assert splitter.test_start == "20230101"
        assert splitter.test_end == "20231231"

    def test_invalid_date_format(self):
        """Dates must be 8-digit 'YYYYMMDD' strings."""
        with pytest.raises(ValueError, match="必须是格式为 'YYYYMMDD' 的8位字符串"):
            DateSplitter(
                train_start="2020-01-01",  # wrong format (dashes)
                train_end="20221231",
                test_start="20230101",
                test_end="20231231",
            )

    def test_train_start_after_train_end(self):
        """Train start must not come after train end."""
        with pytest.raises(ValueError, match="train_start.*必须早于或等于 train_end"):
            DateSplitter(
                train_start="20231231",
                train_end="20200101",
                test_start="20230101",
                test_end="20231231",
            )

    def test_test_start_after_test_end(self):
        """Test start must not come after test end."""
        with pytest.raises(ValueError, match="test_start.*必须早于或等于 test_end"):
            DateSplitter(
                train_start="20200101",
                train_end="20221231",
                test_start="20231231",
                test_end="20230101",
            )

    def test_overlapping_dates(self):
        """The test range must begin strictly after the train range ends."""
        with pytest.raises(ValueError, match="必须晚于训练集结束日期"):
            DateSplitter(
                train_start="20200101",
                train_end="20221231",
                test_start="20220601",  # falls inside the train range
                test_end="20231231",
            )

    def test_split_success(self):
        """Rows are routed to train/test by their trade_date, bounds inclusive."""
        # Rows straddle both the train and the test window boundaries.
        data = pl.DataFrame(
            {
                "ts_code": [
                    "000001.SZ",
                    "000002.SZ",
                    "000003.SZ",
                    "000004.SZ",
                    "000005.SZ",
                    "000006.SZ",
                ],
                "trade_date": [
                    "20200101",
                    "20211231",
                    "20221231",
                    "20230101",
                    "20230601",
                    "20231231",
                ],
                "value": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            }
        )
        splitter = DateSplitter(
            train_start="20200101",
            train_end="20221231",
            test_start="20230101",
            test_end="20231231",
        )
        train_data, test_data = splitter.split(data)
        # Training set: the three rows inside [20200101, 20221231].
        assert len(train_data) == 3
        assert train_data["trade_date"].to_list() == [
            "20200101",
            "20211231",
            "20221231",
        ]
        # Test set: the three rows inside [20230101, 20231231].
        assert len(test_data) == 3
        assert test_data["trade_date"].to_list() == ["20230101", "20230601", "20231231"]

    def test_split_no_matching_train_data(self):
        """An empty training window yields an empty train frame, not an error."""
        data = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ"],
                "trade_date": ["20230101", "20231231"],
                "value": [1.0, 2.0],
            }
        )
        splitter = DateSplitter(
            train_start="20200101",
            train_end="20221231",
            test_start="20230101",
            test_end="20231231",
        )
        train_data, test_data = splitter.split(data)
        # No rows fall in the training window.
        assert len(train_data) == 0
        # Both rows fall in the test window.
        assert len(test_data) == 2

    def test_split_no_matching_test_data(self):
        """An empty test window yields an empty test frame, not an error."""
        data = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ"],
                "trade_date": ["20200101", "20211231"],
                "value": [1.0, 2.0],
            }
        )
        splitter = DateSplitter(
            train_start="20200101",
            train_end="20221231",
            test_start="20230101",
            test_end="20231231",
        )
        train_data, test_data = splitter.split(data)
        # Both rows fall in the training window.
        assert len(train_data) == 2
        # No rows fall in the test window.
        assert len(test_data) == 0

    def test_split_with_custom_date_col(self):
        """split() accepts a non-default date column name."""
        data = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
                "date": ["20200101", "20211231", "20230101"],
                "value": [1.0, 2.0, 3.0],
            }
        )
        splitter = DateSplitter(
            train_start="20200101",
            train_end="20221231",
            test_start="20230101",
            test_end="20231231",
        )
        train_data, test_data = splitter.split(data, date_col="date")
        assert len(train_data) == 2
        assert len(test_data) == 1

    def test_split_missing_date_column(self):
        """A missing date column raises a descriptive ValueError."""
        data = pl.DataFrame(
            {
                "ts_code": ["000001.SZ"],
                "value": [1.0],
            }
        )
        splitter = DateSplitter(
            train_start="20200101",
            train_end="20221231",
            test_start="20230101",
            test_end="20231231",
        )
        with pytest.raises(ValueError, match="数据中不包含列 'trade_date'"):
            splitter.split(data)

    def test_repr(self):
        """__repr__ includes the class name and all four boundary dates."""
        splitter = DateSplitter(
            train_start="20200101",
            train_end="20221231",
            test_start="20230101",
            test_end="20231231",
        )
        repr_str = repr(splitter)
        assert "DateSplitter" in repr_str
        assert "20200101" in repr_str
        assert "20221231" in repr_str
        assert "20230101" in repr_str
        assert "20231231" in repr_str

    def test_edge_case_same_day_train(self):
        """A single-day train window (start == end) is valid and inclusive."""
        data = pl.DataFrame(
            {
                "ts_code": ["000001.SZ"],
                "trade_date": ["20200101"],
                "value": [1.0],
            }
        )
        splitter = DateSplitter(
            train_start="20200101",
            train_end="20200101",
            test_start="20200102",
            test_end="20200102",
        )
        train_data, test_data = splitter.split(data)
        assert len(train_data) == 1
        assert len(test_data) == 0
if __name__ == "__main__":
    # Allow running this test module directly, outside the pytest CLI.
    cli_args = [__file__, "-v"]
    pytest.main(cli_args)