Files
ProStock/tests/models/test_core.py
liaozhaorun 9f95be56a0 feat(models): 实现机器学习模型训练框架
- 添加核心抽象:Processor、Model、Splitter、Metric 基类
- 实现阶段感知机制(TRAIN/TEST/ALL),防止数据泄露
- 内置 8 个数据处理器和 3 种时序划分策略
- 支持 LightGBM、CatBoost 模型
- PluginRegistry 装饰器注册,插件式架构
- 22 个单元测试
2026-02-23 01:37:34 +08:00

479 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""模型框架核心测试
测试核心抽象类、插件注册中心、处理器、模型和划分策略。
"""
import pytest
import polars as pl
import numpy as np
from typing import List, Optional
# 确保导入时注册所有组件
from src.models import (
PluginRegistry,
PipelineStage,
BaseProcessor,
BaseModel,
BaseSplitter,
ProcessingPipeline,
)
from src.models.core import TaskType
# ========== 测试核心抽象类 ==========
class TestPipelineStage:
    """Tests for the PipelineStage enumeration."""

    def test_stage_values(self):
        """Every expected stage member exists and reports its own name."""
        for member_name in ("ALL", "TRAIN", "TEST", "VALIDATION"):
            assert getattr(PipelineStage, member_name).name == member_name
class TestBaseProcessor:
    """Tests for the BaseProcessor abstract base class."""

    def test_processor_initialization(self):
        """A new processor keeps its columns and stage and starts out unfitted."""

        class DummyProcessor(BaseProcessor):
            stage = PipelineStage.ALL

            def fit(self, data: pl.DataFrame) -> "DummyProcessor":
                self._is_fitted = True
                return self

            def transform(self, data: pl.DataFrame) -> pl.DataFrame:
                return data

        proc = DummyProcessor(columns=["col1", "col2"])

        assert proc.columns == ["col1", "col2"]
        assert proc.stage == PipelineStage.ALL
        assert not proc._is_fitted

    def test_processor_fit_transform(self):
        """fit_transform marks the processor fitted and returns the transformed frame."""

        class AddOneProcessor(BaseProcessor):
            stage = PipelineStage.ALL

            def fit(self, data: pl.DataFrame) -> "AddOneProcessor":
                self._is_fitted = True
                return self

            def transform(self, data: pl.DataFrame) -> pl.DataFrame:
                # Add 1 to each configured column; other columns pass through.
                out = data.clone()
                targets = self.columns if self.columns else []
                for name in targets:
                    out = out.with_columns((pl.col(name) + 1).alias(name))
                return out

        proc = AddOneProcessor(columns=["value"])
        frame = pl.DataFrame({"value": [1, 2, 3]})

        transformed = proc.fit_transform(frame)

        assert proc._is_fitted
        assert transformed["value"].to_list() == [2, 3, 4]
class TestBaseModel:
    """Tests for the BaseModel abstract base class."""

    def test_model_initialization(self):
        """Constructor arguments are exposed as attributes and the model starts unfitted."""

        class DummyModel(BaseModel):
            def fit(self, X, y, X_val=None, y_val=None, **kwargs):
                self._is_fitted = True
                return self

            def predict(self, X):
                return np.zeros(len(X))

        model = DummyModel(
            task_type="regression",
            params={"lr": 0.01},
            name="test_model",
        )

        assert model.task_type == "regression"
        assert model.params == {"lr": 0.01}
        assert model.name == "test_model"
        assert not model._is_fitted

    def test_predict_proba_not_implemented(self):
        """predict_proba raises NotImplementedError when a subclass does not override it."""

        class DummyModel(BaseModel):
            def fit(self, X, y, X_val=None, y_val=None, **kwargs):
                return self

            def predict(self, X):
                return np.zeros(len(X))

        model = DummyModel(task_type="regression")
        features = pl.DataFrame({"feature": [1, 2, 3]})

        with pytest.raises(NotImplementedError):
            model.predict_proba(features)
class TestBaseSplitter:
    """Tests for the BaseSplitter interface."""

    def test_splitter_interface(self):
        """A minimal subclass can yield index folds and report its split dates."""

        class DummySplitter(BaseSplitter):
            def split(self, data, date_col="trade_date"):
                yield [0, 1], [2, 3]

            def get_split_dates(self, data, date_col="trade_date"):
                return [("20200101", "20201231", "20210101", "20211231")]

        splitter = DummySplitter()
        frame = pl.DataFrame(
            {"trade_date": ["20200101", "20200601", "20210101", "20210601"]}
        )

        folds = list(splitter.split(frame))
        assert len(folds) == 1
        assert folds[0] == ([0, 1], [2, 3])

        assert splitter.get_split_dates(frame) == [
            ("20200101", "20201231", "20210101", "20211231")
        ]
# ========== 测试插件注册中心 ==========
class TestPluginRegistry:
    """Tests for the plugin registry: registering, looking up and listing components."""

    @staticmethod
    def _snapshot_registry():
        """Return deep copies of PluginRegistry's dict-valued class attributes.

        NOTE(review): assumes the registry keeps its entries in class-level
        dicts; state stored anywhere else is not captured, in which case the
        restore in teardown_method is simply a no-op. copy.deepcopy treats
        classes and functions as atomic, so registered classes keep their
        identity.
        """
        import copy

        return {
            attr: copy.deepcopy(value)
            for attr, value in vars(PluginRegistry).items()
            if isinstance(value, dict) and not attr.startswith("__")
        }

    def setup_method(self):
        """Start each test from an empty registry, remembering what was registered."""
        # Built-in components are registered when src.models is imported;
        # clear_all() alone would wipe them for every later test in the session,
        # so snapshot the current state first.
        self._registry_snapshot = self._snapshot_registry()
        PluginRegistry.clear_all()

    def teardown_method(self):
        """Drop whatever the test registered and restore the pre-test registry."""
        PluginRegistry.clear_all()
        for attr, saved in self._registry_snapshot.items():
            current = getattr(PluginRegistry, attr)
            # Restore in place so any other reference to the same dict sees it.
            current.clear()
            current.update(saved)

    def test_register_and_get_processor(self):
        """A processor registered under a name can be retrieved and is listed."""

        @PluginRegistry.register_processor("test_processor")
        class TestProcessor(BaseProcessor):
            stage = PipelineStage.ALL

            def fit(self, data):
                return self

            def transform(self, data):
                return data

        processor_class = PluginRegistry.get_processor("test_processor")
        assert processor_class == TestProcessor
        assert "test_processor" in PluginRegistry.list_processors()

    def test_register_and_get_model(self):
        """A model registered under a name can be retrieved and is listed."""

        @PluginRegistry.register_model("test_model")
        class TestModel(BaseModel):
            def fit(self, X, y, X_val=None, y_val=None, **kwargs):
                return self

            def predict(self, X):
                return np.zeros(len(X))

        model_class = PluginRegistry.get_model("test_model")
        assert model_class == TestModel
        assert "test_model" in PluginRegistry.list_models()

    def test_register_and_get_splitter(self):
        """A splitter registered under a name can be retrieved and is listed."""

        @PluginRegistry.register_splitter("test_splitter")
        class TestSplitter(BaseSplitter):
            def split(self, data, date_col="trade_date"):
                yield [], []

            def get_split_dates(self, data, date_col="trade_date"):
                return []

        splitter_class = PluginRegistry.get_splitter("test_splitter")
        assert splitter_class == TestSplitter
        assert "test_splitter" in PluginRegistry.list_splitters()

    def test_get_nonexistent_processor(self):
        """Looking up an unknown processor raises KeyError mentioning the name."""
        with pytest.raises(KeyError) as exc_info:
            PluginRegistry.get_processor("nonexistent")
        assert "nonexistent" in str(exc_info.value)

    def test_register_with_default_name(self):
        """Registering without an explicit name uses the class name."""

        @PluginRegistry.register_processor()
        class MyCustomProcessor(BaseProcessor):
            stage = PipelineStage.ALL

            def fit(self, data):
                return self

            def transform(self, data):
                return data

        assert "MyCustomProcessor" in PluginRegistry.list_processors()
# ========== 测试内置处理器 ==========
class TestBuiltInProcessors:
    """Tests for the processors shipped in src.models.processors."""

    def test_dropna_processor(self):
        """DropNAProcessor drops rows with a null in any of its configured columns."""
        from src.models.processors import DropNAProcessor

        processor = DropNAProcessor(columns=["a", "b"])
        df = pl.DataFrame({"a": [1, None, 3], "b": [4, 5, None], "c": [7, 8, 9]})
        result = processor.fit_transform(df)
        # Only the first row has no nulls in "a" and "b".
        assert len(result) == 1
        assert result["a"].to_list() == [1]
        assert result["b"].to_list() == [4]

    def test_fillna_processor(self):
        """FillNAProcessor(method="mean") fills nulls with the column mean and keeps observed values."""
        from src.models.processors import FillNAProcessor

        processor = FillNAProcessor(columns=["a"], method="mean")
        df = pl.DataFrame({"a": [1.0, 2.0, None, 4.0]})
        result = processor.fit_transform(df)
        # The mean of the observed values is (1 + 2 + 4) / 3 = 7/3. Compare with
        # the exact value (not a rounded literal) so small fill errors are caught,
        # and check that the non-null entries pass through unchanged.
        assert result["a"].null_count() == 0
        assert result["a"].to_list() == pytest.approx([1.0, 2.0, 7 / 3, 4.0])

    def test_standard_scaler(self):
        """StandardScaler produces a column with mean 0 and standard deviation 1."""
        from src.models.processors import StandardScaler

        processor = StandardScaler(columns=["value"])
        df = pl.DataFrame({"value": [1.0, 2.0, 3.0, 4.0, 5.0]})
        result = processor.fit_transform(df)
        # After z-score standardization the mean is 0 and the std is 1.
        # NOTE(review): Series.std() defaults to ddof=1, so this assumes the
        # scaler also uses the sample standard deviation — confirm.
        assert result["value"].mean() == pytest.approx(0.0, abs=1e-10)
        assert result["value"].std() == pytest.approx(1.0, rel=0.01)

    def test_winsorizer(self):
        """Winsorizer clips values to the configured lower/upper quantiles."""
        from src.models.processors import Winsorizer

        processor = Winsorizer(columns=["value"], lower=0.1, upper=0.9)
        df = pl.DataFrame(
            {
                "value": list(range(100))  # 0-99
            }
        )
        result = processor.fit_transform(df)
        # The 10% and 90% quantiles should be 10 and 89 (Polars quantile behaviour).
        assert result["value"].min() == 10
        assert result["value"].max() == 89

    def test_rank_transformer(self):
        """RankTransformer replaces values with their rank within the date."""
        from src.models.processors import RankTransformer

        processor = RankTransformer(columns=["value"])
        df = pl.DataFrame(
            {"trade_date": ["20200101"] * 5, "value": [10, 30, 20, 50, 40]}
        )
        result = processor.fit_transform(df)
        # Expected ranks: 1, 3, 2, 5, 4.
        assert result["value"].to_list() == [1, 3, 2, 5, 4]

    def test_neutralizer(self):
        """Neutralizer removes the group mean within each industry."""
        from src.models.processors import Neutralizer

        processor = Neutralizer(columns=["value"], group_col="industry")
        df = pl.DataFrame(
            {
                "trade_date": ["20200101", "20200101", "20200101", "20200101"],
                "industry": ["A", "A", "B", "B"],
                "value": [10, 20, 30, 50],
            }
        )
        result = processor.fit_transform(df)
        # After demeaning within each group, every group's mean is 0.
        group_a = result.filter(pl.col("industry") == "A")
        group_b = result.filter(pl.col("industry") == "B")
        assert group_a["value"].mean() == pytest.approx(0.0, abs=1e-10)
        assert group_b["value"].mean() == pytest.approx(0.0, abs=1e-10)
# ========== 测试处理流水线 ==========
class TestProcessingPipeline:
    """Tests for ProcessingPipeline."""

    def test_pipeline_fit_transform(self):
        """fit_transform applies every processor in the pipeline."""
        from src.models.processors import StandardScaler

        steps = [StandardScaler(columns=["a"]), StandardScaler(columns=["b"])]
        pipeline = ProcessingPipeline(steps)
        frame = pl.DataFrame({"a": [1.0, 2.0, 3.0], "b": [10.0, 20.0, 30.0]})

        standardized = pipeline.fit_transform(frame)

        # Both columns should have been standardized.
        for column in ("a", "b"):
            assert standardized[column].mean() == pytest.approx(0.0, abs=1e-10)

    def test_pipeline_transform_uses_fitted_params(self):
        """transform() reuses the statistics learned during fit instead of refitting."""
        from src.models.processors import StandardScaler

        pipeline = ProcessingPipeline([StandardScaler(columns=["value"])])
        # Training data: mean = 2, std = 1.
        train_df = pl.DataFrame({"value": [1.0, 2.0, 3.0]})
        # Test data with a different distribution; refitting would give mean = 5.
        test_df = pl.DataFrame({"value": [4.0, 5.0, 6.0]})

        pipeline.fit_transform(train_df)
        transformed = pipeline.transform(test_df)

        # Standardized with the training mean (2) and std (1): 4 -> (4 - 2) / 1 = 2.
        first_value = transformed["value"].to_list()[0]
        assert first_value == pytest.approx(2.0, abs=1e-10)
# ========== 测试划分策略 ==========
class TestSplitters:
    """Tests for the concrete time-series split strategies."""

    def test_time_series_split(self):
        """TimeSeriesSplit yields n_splits folds, each with training rows before test rows."""
        from src.models.core import TimeSeriesSplit

        splitter = TimeSeriesSplit(n_splits=2, gap=1, min_train_size=3)
        # Ten consecutive days, one row per day.
        df = pl.DataFrame(
            {
                "trade_date": [f"202001{i:02d}" for i in range(1, 11)],
                "value": list(range(10)),
            }
        )
        splits = list(splitter.split(df))
        # Expect exactly two folds.
        assert len(splits) == 2
        # In every fold, all training rows precede all test rows.
        for train_idx, test_idx in splits:
            assert max(train_idx) < min(test_idx)

    def test_walk_forward_split(self):
        """WalkForwardSplit yields fixed-size train/test windows with train before test."""
        from src.models.core import WalkForwardSplit

        splitter = WalkForwardSplit(train_window=5, test_window=2, gap=1)
        df = pl.DataFrame(
            {
                "trade_date": [f"202001{i:02d}" for i in range(1, 13)],
                "value": list(range(12)),
            }
        )
        splits = list(splitter.split(df))
        # Without this guard a splitter yielding no folds would pass the loop
        # below vacuously.
        assert splits, "WalkForwardSplit produced no folds"
        # Every fold has a fixed-size training window and test window.
        for train_idx, test_idx in splits:
            assert len(train_idx) == 5
            assert len(test_idx) == 2
            assert max(train_idx) < min(test_idx)

    def test_expanding_window_split(self):
        """ExpandingWindowSplit grows the training set by test_window each fold."""
        from src.models.core import ExpandingWindowSplit

        splitter = ExpandingWindowSplit(initial_train_size=3, test_window=2, gap=1)
        df = pl.DataFrame(
            {
                "trade_date": [f"202001{i:02d}" for i in range(1, 15)],
                "value": list(range(14)),
            }
        )
        splits = list(splitter.split(df))
        train_sizes = [len(train_idx) for train_idx, _ in splits]
        # Fail with a clear message instead of an IndexError if too few folds.
        assert len(train_sizes) >= 3, f"expected >= 3 folds, got {len(train_sizes)}"
        # The training set should grow: 3, then 3 + 2, then 5 + 2.
        assert train_sizes[:3] == [3, 5, 7]
# ========== 测试内置模型(可选,需要安装依赖) ==========
class TestModels:
    """Tests for the built-in models; each test skips itself if its optional dependency is missing."""

    def test_lightgbm_model(self):
        """LightGBMModel fits a tiny regression problem and predicts one value per row."""
        # Skip only when lightgbm is actually missing (not unconditionally), so
        # the test runs wherever the dependency is installed. Check before
        # importing the wrapper, which presumably imports lightgbm itself.
        pytest.importorskip("lightgbm")
        from src.models.models import LightGBMModel

        model = LightGBMModel(task_type="regression", params={"n_estimators": 10})
        X = pl.DataFrame(
            {
                "feature1": [1.0, 2.0, 3.0, 4.0, 5.0] * 10,
                "feature2": [5.0, 4.0, 3.0, 2.0, 1.0] * 10,
            }
        )
        y = pl.Series("target", [1.0, 2.0, 3.0, 4.0, 5.0] * 10)
        model.fit(X, y)
        predictions = model.predict(X)
        assert len(predictions) == len(X)
        assert model._is_fitted
if __name__ == "__main__":
    # Propagate pytest's exit status so a failing run is visible to the shell/CI.
    raise SystemExit(pytest.main([__file__, "-v"]))