"""Core tests for the model framework.

Covers the core abstract classes, the plugin registry, the built-in
processors, the processing pipeline, the models and the data-splitting
strategies.
"""

import pytest
import polars as pl
import numpy as np
from typing import List, Optional

# Importing from src.models is intended to register all built-in components.
from src.models import (
    PluginRegistry,
    PipelineStage,
    BaseProcessor,
    BaseModel,
    BaseSplitter,
    ProcessingPipeline,
)
from src.models.core import TaskType


# ========== Core abstract classes ==========


class TestPipelineStage:
    """Tests for the PipelineStage enum."""

    def test_stage_values(self):
        """Each stage member exposes the expected name."""
        assert PipelineStage.ALL.name == "ALL"
        assert PipelineStage.TRAIN.name == "TRAIN"
        assert PipelineStage.TEST.name == "TEST"
        assert PipelineStage.VALIDATION.name == "VALIDATION"


class TestBaseProcessor:
    """Tests for the BaseProcessor base class."""

    def test_processor_initialization(self):
        """The constructor stores columns, and the processor starts unfitted."""

        class DummyProcessor(BaseProcessor):
            stage = PipelineStage.ALL

            def fit(self, data: pl.DataFrame) -> "DummyProcessor":
                self._is_fitted = True
                return self

            def transform(self, data: pl.DataFrame) -> pl.DataFrame:
                return data

        processor = DummyProcessor(columns=["col1", "col2"])
        assert processor.columns == ["col1", "col2"]
        assert processor.stage == PipelineStage.ALL
        assert not processor._is_fitted

    def test_processor_fit_transform(self):
        """fit_transform fits the processor and returns the transformed frame."""

        class AddOneProcessor(BaseProcessor):
            stage = PipelineStage.ALL

            def fit(self, data: pl.DataFrame) -> "AddOneProcessor":
                self._is_fitted = True
                return self

            def transform(self, data: pl.DataFrame) -> pl.DataFrame:
                result = data.clone()
                for col in self.columns or []:
                    result = result.with_columns((pl.col(col) + 1).alias(col))
                return result

        processor = AddOneProcessor(columns=["value"])
        df = pl.DataFrame({"value": [1, 2, 3]})
        result = processor.fit_transform(df)
        assert processor._is_fitted
        assert result["value"].to_list() == [2, 3, 4]


class TestBaseModel:
    """Tests for the BaseModel base class."""

    def test_model_initialization(self):
        """The constructor stores task_type, params and name; the model starts unfitted."""

        class DummyModel(BaseModel):
            def fit(self, X, y, X_val=None, y_val=None, **kwargs):
                self._is_fitted = True
                return self

            def predict(self, X):
                return np.zeros(len(X))

        model = DummyModel(
            task_type="regression", params={"lr": 0.01}, name="test_model"
        )
        assert model.task_type == "regression"
        assert model.params == {"lr": 0.01}
        assert model.name == "test_model"
        assert not model._is_fitted

    def test_predict_proba_not_implemented(self):
        """predict_proba raises NotImplementedError when a subclass does not override it."""

        class DummyModel(BaseModel):
            def fit(self, X, y, X_val=None, y_val=None, **kwargs):
                return self

            def predict(self, X):
                return np.zeros(len(X))

        model = DummyModel(task_type="regression")
        df = pl.DataFrame({"feature": [1, 2, 3]})
        with pytest.raises(NotImplementedError):
            model.predict_proba(df)


class TestBaseSplitter:
    """Tests for the BaseSplitter base class."""

    def test_splitter_interface(self):
        """A concrete splitter can implement split() and get_split_dates()."""

        class DummySplitter(BaseSplitter):
            def split(self, data, date_col="trade_date"):
                yield [0, 1], [2, 3]

            def get_split_dates(self, data, date_col="trade_date"):
                return [("20200101", "20201231", "20210101", "20211231")]

        splitter = DummySplitter()
        df = pl.DataFrame(
            {"trade_date": ["20200101", "20200601", "20210101", "20210601"]}
        )

        splits = list(splitter.split(df))
        assert len(splits) == 1
        assert splits[0] == ([0, 1], [2, 3])

        dates = splitter.get_split_dates(df)
        assert dates == [("20200101", "20201231", "20210101", "20211231")]


# ========== Plugin registry ==========


class TestPluginRegistry:
    """Tests for the PluginRegistry."""

    def setup_method(self):
        """Clear every registration before each test."""
        # NOTE(review): clear_all() wipes the global registry, including the
        # built-in components registered when src.models was imported. Nothing
        # here restores them afterwards, so tests that run later and look up
        # built-ins through the registry may fail. Consider snapshotting and
        # restoring the registry in teardown_method.
        PluginRegistry.clear_all()

    def test_register_and_get_processor(self):
        """A decorated processor class can be looked up and listed by name."""

        @PluginRegistry.register_processor("test_processor")
        class TestProcessor(BaseProcessor):
            stage = PipelineStage.ALL

            def fit(self, data):
                return self

            def transform(self, data):
                return data

        processor_class = PluginRegistry.get_processor("test_processor")
        assert processor_class == TestProcessor
        assert "test_processor" in PluginRegistry.list_processors()

    def test_register_and_get_model(self):
        """A decorated model class can be looked up and listed by name."""

        @PluginRegistry.register_model("test_model")
        class TestModel(BaseModel):
            def fit(self, X, y, X_val=None, y_val=None, **kwargs):
                return self

            def predict(self, X):
                return np.zeros(len(X))

        model_class = PluginRegistry.get_model("test_model")
        assert model_class == TestModel
        assert "test_model" in PluginRegistry.list_models()

    def test_register_and_get_splitter(self):
        """A decorated splitter class can be looked up and listed by name."""

        @PluginRegistry.register_splitter("test_splitter")
        class TestSplitter(BaseSplitter):
            def split(self, data, date_col="trade_date"):
                yield [], []

            def get_split_dates(self, data, date_col="trade_date"):
                return []

        splitter_class = PluginRegistry.get_splitter("test_splitter")
        assert splitter_class == TestSplitter
        assert "test_splitter" in PluginRegistry.list_splitters()

    def test_get_nonexistent_processor(self):
        """Looking up an unknown processor raises a KeyError that mentions the name."""
        with pytest.raises(KeyError) as exc_info:
            PluginRegistry.get_processor("nonexistent")
        assert "nonexistent" in str(exc_info.value)

    def test_register_with_default_name(self):
        """Without an explicit name, the class name is used as the registry key."""

        @PluginRegistry.register_processor()
        class MyCustomProcessor(BaseProcessor):
            stage = PipelineStage.ALL

            def fit(self, data):
                return self

            def transform(self, data):
                return data

        assert "MyCustomProcessor" in PluginRegistry.list_processors()


# ========== Built-in processors ==========


class TestBuiltInProcessors:
    """Tests for the built-in processors."""

    def test_dropna_processor(self):
        """DropNAProcessor drops rows with a null in any of the selected columns."""
        from src.models.processors import DropNAProcessor

        processor = DropNAProcessor(columns=["a", "b"])
        df = pl.DataFrame({"a": [1, None, 3], "b": [4, 5, None], "c": [7, 8, 9]})

        result = processor.fit_transform(df)

        # Only the first row has no missing values in columns a and b.
        assert len(result) == 1
        assert result["a"].to_list() == [1]
        assert result["b"].to_list() == [4]

    def test_fillna_processor(self):
        """FillNAProcessor with method='mean' fills nulls with the column mean."""
        from src.models.processors import FillNAProcessor

        processor = FillNAProcessor(columns=["a"], method="mean")
        df = pl.DataFrame({"a": [1.0, 2.0, None, 4.0]})

        result = processor.fit_transform(df)

        # mean = (1 + 2 + 4) / 3 = 2.333...
        assert result["a"][2] == pytest.approx(2.333, rel=0.01)

    def test_standard_scaler(self):
        """StandardScaler produces a column with zero mean and unit standard deviation."""
        from src.models.processors import StandardScaler

        processor = StandardScaler(columns=["value"])
        df = pl.DataFrame({"value": [1.0, 2.0, 3.0, 4.0, 5.0]})

        result = processor.fit_transform(df)

        # After z-scoring the mean is 0 and the std is 1. The std check assumes
        # the scaler uses the same ddof as Series.std(); TODO confirm.
        assert result["value"].mean() == pytest.approx(0.0, abs=1e-10)
        assert result["value"].std() == pytest.approx(1.0, rel=0.01)

    def test_winsorizer(self):
        """Winsorizer clips values to the configured lower/upper quantiles."""
        from src.models.processors import Winsorizer

        processor = Winsorizer(columns=["value"], lower=0.1, upper=0.9)
        df = pl.DataFrame({"value": list(range(100))})  # 0..99

        result = processor.fit_transform(df)

        # The 10% and 90% quantiles are expected to be 10 and 89
        # (Polars quantile behaviour).
        assert result["value"].min() == 10
        assert result["value"].max() == 89

    def test_rank_transformer(self):
        """RankTransformer replaces values with their rank within the date."""
        from src.models.processors import RankTransformer

        processor = RankTransformer(columns=["value"])
        df = pl.DataFrame(
            {"trade_date": ["20200101"] * 5, "value": [10, 30, 20, 50, 40]}
        )

        result = processor.fit_transform(df)

        # Expected ranks: 1, 3, 2, 5, 4
        assert result["value"].to_list() == [1, 3, 2, 5, 4]

    def test_neutralizer(self):
        """Neutralizer demeans values within each group."""
        from src.models.processors import Neutralizer

        processor = Neutralizer(columns=["value"], group_col="industry")
        df = pl.DataFrame(
            {
                "trade_date": ["20200101", "20200101", "20200101", "20200101"],
                "industry": ["A", "A", "B", "B"],
                "value": [10, 20, 30, 50],
            }
        )

        result = processor.fit_transform(df)

        # After demeaning by group, each group's mean is 0.
        group_a = result.filter(pl.col("industry") == "A")
        group_b = result.filter(pl.col("industry") == "B")
        assert group_a["value"].mean() == pytest.approx(0.0, abs=1e-10)
        assert group_b["value"].mean() == pytest.approx(0.0, abs=1e-10)


# ========== Processing pipeline ==========


class TestProcessingPipeline:
    """Tests for ProcessingPipeline."""

    def test_pipeline_fit_transform(self):
        """fit_transform applies every processor in the pipeline."""
        from src.models.processors import StandardScaler

        scaler1 = StandardScaler(columns=["a"])
        scaler2 = StandardScaler(columns=["b"])
        pipeline = ProcessingPipeline([scaler1, scaler2])

        df = pl.DataFrame({"a": [1.0, 2.0, 3.0], "b": [10.0, 20.0, 30.0]})

        result = pipeline.fit_transform(df)

        # Both columns should be standardized.
        assert result["a"].mean() == pytest.approx(0.0, abs=1e-10)
        assert result["b"].mean() == pytest.approx(0.0, abs=1e-10)

    def test_pipeline_transform_uses_fitted_params(self):
        """transform reuses parameters learned during fit instead of refitting."""
        from src.models.processors import StandardScaler

        scaler = StandardScaler(columns=["value"])
        pipeline = ProcessingPipeline([scaler])

        # Training data: mean = 2, std = 1
        train_df = pl.DataFrame({"value": [1.0, 2.0, 3.0]})

        # Test data with a different distribution (mean would be 5 if refitted)
        test_df = pl.DataFrame({"value": [4.0, 5.0, 6.0]})

        pipeline.fit_transform(train_df)
        result = pipeline.transform(test_df)

        # Standardized with the training mean 2 and std 1: 4 -> (4 - 2) / 1 = 2
        assert result["value"].to_list()[0] == pytest.approx(2.0, abs=1e-10)


# ========== Splitting strategies ==========


class TestSplitters:
    """Tests for the built-in splitting strategies."""

    def test_time_series_split(self):
        """TimeSeriesSplit yields the requested folds with train strictly before test."""
        from src.models.core import TimeSeriesSplit

        splitter = TimeSeriesSplit(n_splits=2, gap=1, min_train_size=3)

        # Ten days of data
        df = pl.DataFrame(
            {
                "trade_date": [f"202001{i:02d}" for i in range(1, 11)],
                "value": list(range(10)),
            }
        )

        splits = list(splitter.split(df))

        # Expect two folds.
        assert len(splits) == 2

        # In every fold the training rows precede the test rows.
        for train_idx, test_idx in splits:
            assert max(train_idx) < min(test_idx)

    def test_walk_forward_split(self):
        """WalkForwardSplit yields fixed-size train and test windows."""
        from src.models.core import WalkForwardSplit

        splitter = WalkForwardSplit(train_window=5, test_window=2, gap=1)

        df = pl.DataFrame(
            {
                "trade_date": [f"202001{i:02d}" for i in range(1, 13)],
                "value": list(range(12)),
            }
        )

        splits = list(splitter.split(df))

        # 12 rows are enough for at least one 5 + 1 + 2 window. Without this
        # check the loop below would pass vacuously if no split were produced.
        assert splits, "WalkForwardSplit produced no splits"

        # The training window size is fixed.
        for train_idx, test_idx in splits:
            assert len(train_idx) == 5
            assert len(test_idx) == 2

    def test_expanding_window_split(self):
        """ExpandingWindowSplit grows the training set by test_window each fold."""
        from src.models.core import ExpandingWindowSplit

        splitter = ExpandingWindowSplit(initial_train_size=3, test_window=2, gap=1)

        df = pl.DataFrame(
            {
                "trade_date": [f"202001{i:02d}" for i in range(1, 15)],
                "value": list(range(14)),
            }
        )

        splits = list(splitter.split(df))

        # The training set should grow fold by fold.
        train_sizes = [len(train_idx) for train_idx, _ in splits]
        # Fail with a clear message rather than an IndexError below.
        assert len(train_sizes) >= 3, f"expected at least 3 splits, got {train_sizes}"
        assert train_sizes[0] == 3
        assert train_sizes[1] == 5  # 3 + 2
        assert train_sizes[2] == 7  # 5 + 2


# ========== Built-in models (optional, require extra dependencies) ==========


class TestModels:
    """Tests for the built-in models; each one is skipped if its dependency is missing."""

    def test_lightgbm_model(self):
        """LightGBMModel fits and returns one prediction per row."""
        # Skip only when lightgbm is unavailable. The former unconditional
        # @pytest.mark.skip meant this test never ran, even with lightgbm present.
        pytest.importorskip("lightgbm")
        from src.models.models import LightGBMModel

        model = LightGBMModel(task_type="regression", params={"n_estimators": 10})

        X = pl.DataFrame(
            {
                "feature1": [1.0, 2.0, 3.0, 4.0, 5.0] * 10,
                "feature2": [5.0, 4.0, 3.0, 2.0, 1.0] * 10,
            }
        )
        y = pl.Series("target", [1.0, 2.0, 3.0, 4.0, 5.0] * 10)

        model.fit(X, y)
        predictions = model.predict(X)

        assert len(predictions) == len(X)
        assert model._is_fitted


if __name__ == "__main__":
    pytest.main([__file__, "-v"])