# Changelog (from commit message):
# - 存储层重构: HDF5 → DuckDB(UPSERT模式、线程安全存储)
# - Sync类迁移: DataSync从sync.py迁移到api_daily.py(职责分离)
# - 模型模块重构: src/models → src/pipeline(更清晰的命名)
# - 新增因子模块: factors/momentum (MA、收益率排名)、factors/financial
# - 新增API接口: api_namechange、api_bak_basic
# - 新增训练入口: training模块(main.py、pipeline配置)
# - 工具函数统一: get_today_date等移至utils.py
# - 文档更新: AGENTS.md添加架构变更历史
# (page metadata: 479 lines, 14 KiB, Python)
"""Pipeline 组件库核心测试
|
||
|
||
测试核心抽象类、插件注册中心、处理器、模型和划分策略。
|
||
"""
|
||
|
||
import pytest
|
||
import polars as pl
|
||
import numpy as np
|
||
from typing import List, Optional
|
||
|
||
# 确保导入时注册所有组件
|
||
from src.pipeline import (
|
||
PluginRegistry,
|
||
PipelineStage,
|
||
BaseProcessor,
|
||
BaseModel,
|
||
BaseSplitter,
|
||
ProcessingPipeline,
|
||
)
|
||
from src.pipeline.core import TaskType
|
||
|
||
|
||
# ========== 测试核心抽象类 ==========
|
||
|
||
|
||
class TestPipelineStage:
    """Tests for the pipeline stage enum."""

    def test_stage_values(self):
        # The enum must expose exactly these member names; downstream
        # config files reference stages by name.
        assert PipelineStage.ALL.name == "ALL"
        assert PipelineStage.TRAIN.name == "TRAIN"
        assert PipelineStage.TEST.name == "TEST"
        assert PipelineStage.VALIDATION.name == "VALIDATION"

class TestBaseProcessor:
    """Tests for the processor base class."""

    def test_processor_initialization(self):
        """A new processor stores its columns and starts unfitted."""

        class DummyProcessor(BaseProcessor):
            stage = PipelineStage.ALL

            def fit(self, data: pl.DataFrame) -> "DummyProcessor":
                self._is_fitted = True
                return self

            def transform(self, data: pl.DataFrame) -> pl.DataFrame:
                return data

        processor = DummyProcessor(columns=["col1", "col2"])
        assert processor.columns == ["col1", "col2"]
        assert processor.stage == PipelineStage.ALL
        assert not processor._is_fitted

    def test_processor_fit_transform(self):
        """fit_transform fits the processor and then applies transform."""

        class AddOneProcessor(BaseProcessor):
            stage = PipelineStage.ALL

            def fit(self, data: pl.DataFrame) -> "AddOneProcessor":
                self._is_fitted = True
                return self

            def transform(self, data: pl.DataFrame) -> pl.DataFrame:
                result = data.clone()
                for col in self.columns or []:
                    result = result.with_columns((pl.col(col) + 1).alias(col))
                return result

        processor = AddOneProcessor(columns=["value"])
        df = pl.DataFrame({"value": [1, 2, 3]})

        result = processor.fit_transform(df)

        assert processor._is_fitted
        assert result["value"].to_list() == [2, 3, 4]

class TestBaseModel:
    """Tests for the model base class."""

    def test_model_initialization(self):
        """Constructor stores task type, params and name; model starts unfitted."""

        class DummyModel(BaseModel):
            def fit(self, X, y, X_val=None, y_val=None, **kwargs):
                self._is_fitted = True
                return self

            def predict(self, X):
                return np.zeros(len(X))

        model = DummyModel(
            task_type="regression", params={"lr": 0.01}, name="test_model"
        )

        assert model.task_type == "regression"
        assert model.params == {"lr": 0.01}
        assert model.name == "test_model"
        assert not model._is_fitted

    def test_predict_proba_not_implemented(self):
        """Subclasses that do not implement predict_proba raise NotImplementedError."""

        class DummyModel(BaseModel):
            def fit(self, X, y, X_val=None, y_val=None, **kwargs):
                return self

            def predict(self, X):
                return np.zeros(len(X))

        model = DummyModel(task_type="regression")
        df = pl.DataFrame({"feature": [1, 2, 3]})

        with pytest.raises(NotImplementedError):
            model.predict_proba(df)

class TestBaseSplitter:
    """Tests for the split-strategy base class."""

    def test_splitter_interface(self):
        """A minimal subclass satisfies the split/get_split_dates interface."""

        class DummySplitter(BaseSplitter):
            def split(self, data, date_col="trade_date"):
                yield [0, 1], [2, 3]

            def get_split_dates(self, data, date_col="trade_date"):
                return [("20200101", "20201231", "20210101", "20211231")]

        splitter = DummySplitter()
        df = pl.DataFrame(
            {"trade_date": ["20200101", "20200601", "20210101", "20210601"]}
        )

        splits = list(splitter.split(df))
        assert len(splits) == 1
        assert splits[0] == ([0, 1], [2, 3])

        dates = splitter.get_split_dates(df)
        assert dates == [("20200101", "20201231", "20210101", "20211231")]

# ========== Plugin registry tests ==========


class TestPluginRegistry:
    """Tests for the plugin registry."""

    def setup_method(self):
        """Clear all registrations before each test so tests stay isolated."""
        PluginRegistry.clear_all()

    def test_register_and_get_processor(self):
        """A processor registered by name can be looked up and listed."""

        @PluginRegistry.register_processor("test_processor")
        class TestProcessor(BaseProcessor):
            stage = PipelineStage.ALL

            def fit(self, data):
                return self

            def transform(self, data):
                return data

        processor_class = PluginRegistry.get_processor("test_processor")
        assert processor_class == TestProcessor
        assert "test_processor" in PluginRegistry.list_processors()

    def test_register_and_get_model(self):
        """A model registered by name can be looked up and listed."""

        @PluginRegistry.register_model("test_model")
        class TestModel(BaseModel):
            def fit(self, X, y, X_val=None, y_val=None, **kwargs):
                return self

            def predict(self, X):
                return np.zeros(len(X))

        model_class = PluginRegistry.get_model("test_model")
        assert model_class == TestModel
        assert "test_model" in PluginRegistry.list_models()

    def test_register_and_get_splitter(self):
        """A splitter registered by name can be looked up and listed."""

        @PluginRegistry.register_splitter("test_splitter")
        class TestSplitter(BaseSplitter):
            def split(self, data, date_col="trade_date"):
                yield [], []

            def get_split_dates(self, data, date_col="trade_date"):
                return []

        splitter_class = PluginRegistry.get_splitter("test_splitter")
        assert splitter_class == TestSplitter
        assert "test_splitter" in PluginRegistry.list_splitters()

    def test_get_nonexistent_processor(self):
        """Looking up an unregistered processor raises KeyError naming it."""
        with pytest.raises(KeyError) as exc_info:
            PluginRegistry.get_processor("nonexistent")
        assert "nonexistent" in str(exc_info.value)

    def test_register_with_default_name(self):
        """Registering without a name falls back to the class name."""

        @PluginRegistry.register_processor()
        class MyCustomProcessor(BaseProcessor):
            stage = PipelineStage.ALL

            def fit(self, data):
                return self

            def transform(self, data):
                return data

        assert "MyCustomProcessor" in PluginRegistry.list_processors()

# ========== Built-in processor tests ==========


class TestBuiltInProcessors:
    """Tests for the built-in processors."""

    def test_dropna_processor(self):
        """DropNAProcessor drops rows with nulls in the configured columns only."""
        from src.pipeline.processors import DropNAProcessor

        processor = DropNAProcessor(columns=["a", "b"])
        df = pl.DataFrame({"a": [1, None, 3], "b": [4, 5, None], "c": [7, 8, 9]})

        result = processor.fit_transform(df)

        # Only the first row has no nulls in columns a and b.
        assert len(result) == 1
        assert result["a"].to_list() == [1]
        assert result["b"].to_list() == [4]

    def test_fillna_processor(self):
        """FillNAProcessor with method='mean' fills nulls with the column mean."""
        from src.pipeline.processors import FillNAProcessor

        processor = FillNAProcessor(columns=["a"], method="mean")
        df = pl.DataFrame({"a": [1.0, 2.0, None, 4.0]})

        result = processor.fit_transform(df)

        # mean = (1 + 2 + 4) / 3 = 2.333...
        assert result["a"][2] == pytest.approx(2.333, rel=0.01)

    def test_standard_scaler(self):
        """StandardScaler z-scores the configured columns."""
        from src.pipeline.processors import StandardScaler

        processor = StandardScaler(columns=["value"])
        df = pl.DataFrame({"value": [1.0, 2.0, 3.0, 4.0, 5.0]})

        result = processor.fit_transform(df)

        # After z-score standardization: mean 0, standard deviation 1.
        assert result["value"].mean() == pytest.approx(0.0, abs=1e-10)
        assert result["value"].std() == pytest.approx(1.0, rel=0.01)

    def test_winsorizer(self):
        """Winsorizer clips values to the configured quantiles."""
        from src.pipeline.processors import Winsorizer

        processor = Winsorizer(columns=["value"], lower=0.1, upper=0.9)
        df = pl.DataFrame(
            {
                "value": list(range(100))  # 0-99
            }
        )

        result = processor.fit_transform(df)

        # The 10% and 90% quantiles are 10 and 89 (Polars quantile behavior).
        assert result["value"].min() == 10
        assert result["value"].max() == 89

    def test_rank_transformer(self):
        """RankTransformer replaces values with their within-date rank."""
        from src.pipeline.processors import RankTransformer

        processor = RankTransformer(columns=["value"])
        df = pl.DataFrame(
            {"trade_date": ["20200101"] * 5, "value": [10, 30, 20, 50, 40]}
        )

        result = processor.fit_transform(df)

        # Expected ranks: 1, 3, 2, 5, 4
        assert result["value"].to_list() == [1, 3, 2, 5, 4]

    def test_neutralizer(self):
        """Neutralizer demeans values within each group."""
        from src.pipeline.processors import Neutralizer

        processor = Neutralizer(columns=["value"], group_col="industry")
        df = pl.DataFrame(
            {
                "trade_date": ["20200101", "20200101", "20200101", "20200101"],
                "industry": ["A", "A", "B", "B"],
                "value": [10, 20, 30, 50],
            }
        )

        result = processor.fit_transform(df)

        # After group demeaning, each group's mean is 0.
        group_a = result.filter(pl.col("industry") == "A")
        group_b = result.filter(pl.col("industry") == "B")

        assert group_a["value"].mean() == pytest.approx(0.0, abs=1e-10)
        assert group_b["value"].mean() == pytest.approx(0.0, abs=1e-10)

# ========== Processing pipeline tests ==========


class TestProcessingPipeline:
    """Tests for the processing pipeline."""

    def test_pipeline_fit_transform(self):
        """fit_transform runs every processor in the pipeline."""
        from src.pipeline.processors import StandardScaler

        scaler1 = StandardScaler(columns=["a"])
        scaler2 = StandardScaler(columns=["b"])

        pipeline = ProcessingPipeline([scaler1, scaler2])

        df = pl.DataFrame({"a": [1.0, 2.0, 3.0], "b": [10.0, 20.0, 30.0]})

        result = pipeline.fit_transform(df)

        # Both columns should be standardized.
        assert result["a"].mean() == pytest.approx(0.0, abs=1e-10)
        assert result["b"].mean() == pytest.approx(0.0, abs=1e-10)

    def test_pipeline_transform_uses_fitted_params(self):
        """transform reuses parameters learned during fit, not the new data's stats."""
        from src.pipeline.processors import StandardScaler

        scaler = StandardScaler(columns=["value"])
        pipeline = ProcessingPipeline([scaler])

        # Training data
        train_df = pl.DataFrame(
            {
                "value": [1.0, 2.0, 3.0]  # mean=2, std=1
            }
        )

        # Test data (different distribution)
        test_df = pl.DataFrame(
            {
                "value": [4.0, 5.0, 6.0]  # would have mean=5 if recomputed
            }
        )

        pipeline.fit_transform(train_df)
        result = pipeline.transform(test_df)

        # Standardized with the training mean=2 and std=1:
        # 4 -> (4 - 2) / 1 = 2
        assert result["value"].to_list()[0] == pytest.approx(2.0, abs=1e-10)

# ========== Split strategy tests ==========


class TestSplitters:
    """Tests for the split strategies."""

    def test_time_series_split(self):
        """TimeSeriesSplit yields the requested number of ordered folds."""
        from src.pipeline.core import TimeSeriesSplit

        splitter = TimeSeriesSplit(n_splits=2, gap=1, min_train_size=3)

        # Ten days of data.
        df = pl.DataFrame(
            {
                "trade_date": [f"202001{i:02d}" for i in range(1, 11)],
                "value": list(range(10)),
            }
        )

        splits = list(splitter.split(df))

        # Exactly two folds are expected.
        assert len(splits) == 2

        # In every fold the training set precedes the test set.
        for train_idx, test_idx in splits:
            assert max(train_idx) < min(test_idx)

    def test_walk_forward_split(self):
        """WalkForwardSplit keeps fixed-size train and test windows."""
        from src.pipeline.core import WalkForwardSplit

        splitter = WalkForwardSplit(train_window=5, test_window=2, gap=1)

        df = pl.DataFrame(
            {
                "trade_date": [f"202001{i:02d}" for i in range(1, 13)],
                "value": list(range(12)),
            }
        )

        splits = list(splitter.split(df))

        # Training-set size is fixed per fold.
        for train_idx, test_idx in splits:
            assert len(train_idx) == 5
            assert len(test_idx) == 2

    def test_expanding_window_split(self):
        """ExpandingWindowSplit grows the training window by test_window each fold."""
        from src.pipeline.core import ExpandingWindowSplit

        splitter = ExpandingWindowSplit(initial_train_size=3, test_window=2, gap=1)

        df = pl.DataFrame(
            {
                "trade_date": [f"202001{i:02d}" for i in range(1, 15)],
                "value": list(range(14)),
            }
        )

        splits = list(splitter.split(df))

        # Guard before indexing folds 0-2 so a shortfall fails clearly
        # instead of raising IndexError.
        assert len(splits) >= 3

        # The training set should grow fold over fold.
        train_sizes = [len(train_idx) for train_idx, _ in splits]
        assert train_sizes[0] == 3
        assert train_sizes[1] == 5  # 3 + 2
        assert train_sizes[2] == 7  # 5 + 2

# ========== Built-in model tests (optional, need extra dependencies) ==========


class TestModels:
    """Tests for built-in models (skipped when the dependency is missing)."""

    def test_lightgbm_model(self):
        """LightGBMModel fits and predicts on a small regression dataset."""
        # Skip at runtime only when lightgbm is missing, instead of the
        # previous unconditional @pytest.mark.skip, so the test actually
        # runs in environments that have the dependency installed.
        pytest.importorskip("lightgbm")
        from src.pipeline.models import LightGBMModel

        model = LightGBMModel(task_type="regression", params={"n_estimators": 10})

        X = pl.DataFrame(
            {
                "feature1": [1.0, 2.0, 3.0, 4.0, 5.0] * 10,
                "feature2": [5.0, 4.0, 3.0, 2.0, 1.0] * 10,
            }
        )
        y = pl.Series("target", [1.0, 2.0, 3.0, 4.0, 5.0] * 10)

        model.fit(X, y)
        predictions = model.predict(X)

        assert len(predictions) == len(X)
        assert model._is_fitted

if __name__ == "__main__":
|
||
pytest.main([__file__, "-v"])
|