2026-02-23 16:23:53 +08:00
|
|
|
|
"""Pipeline 组件库核心测试
|
2026-02-23 01:37:34 +08:00
|
|
|
|
|
|
|
|
|
|
测试核心抽象类、插件注册中心、处理器、模型和划分策略。
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
|
|
import polars as pl
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
from typing import List, Optional
|
|
|
|
|
|
|
|
|
|
|
|
# 确保导入时注册所有组件
|
2026-02-23 16:23:53 +08:00
|
|
|
|
from src.pipeline import (
|
2026-02-23 01:37:34 +08:00
|
|
|
|
PluginRegistry,
|
|
|
|
|
|
PipelineStage,
|
|
|
|
|
|
BaseProcessor,
|
|
|
|
|
|
BaseModel,
|
|
|
|
|
|
BaseSplitter,
|
|
|
|
|
|
ProcessingPipeline,
|
|
|
|
|
|
)
|
2026-02-23 16:23:53 +08:00
|
|
|
|
from src.pipeline.core import TaskType
|
2026-02-23 01:37:34 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ========== 测试核心抽象类 ==========
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestPipelineStage:
    """Tests for the pipeline stage enumeration."""

    def test_stage_values(self):
        """Each expected stage member exists and reports its own name."""
        for member_name in ("ALL", "TRAIN", "TEST", "VALIDATION"):
            member = getattr(PipelineStage, member_name)
            assert member.name == member_name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestBaseProcessor:
    """Tests for the processor base class."""

    def test_processor_initialization(self):
        """A freshly constructed processor stores its columns and is unfitted."""

        class DummyProcessor(BaseProcessor):
            stage = PipelineStage.ALL

            def fit(self, data: pl.DataFrame) -> "DummyProcessor":
                self._is_fitted = True
                return self

            def transform(self, data: pl.DataFrame) -> pl.DataFrame:
                return data

        wanted_columns = ["col1", "col2"]
        dummy = DummyProcessor(columns=["col1", "col2"])

        assert dummy.columns == wanted_columns
        assert dummy.stage == PipelineStage.ALL
        assert not dummy._is_fitted

    def test_processor_fit_transform(self):
        """``fit_transform`` marks the processor fitted and applies ``transform``."""

        class AddOneProcessor(BaseProcessor):
            stage = PipelineStage.ALL

            def fit(self, data: pl.DataFrame) -> "AddOneProcessor":
                self._is_fitted = True
                return self

            def transform(self, data: pl.DataFrame) -> pl.DataFrame:
                # Work on a clone and bump each configured column by one.
                bumped = data.clone()
                for name in self.columns or []:
                    incremented = (pl.col(name) + 1).alias(name)
                    bumped = bumped.with_columns(incremented)
                return bumped

        add_one = AddOneProcessor(columns=["value"])
        frame = pl.DataFrame({"value": [1, 2, 3]})

        transformed = add_one.fit_transform(frame)

        assert add_one._is_fitted
        assert transformed["value"].to_list() == [2, 3, 4]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestBaseModel:
    """Tests for the model base class."""

    def test_model_initialization(self):
        """Constructor arguments are exposed as attributes; model starts unfitted."""

        class DummyModel(BaseModel):
            def fit(self, X, y, X_val=None, y_val=None, **kwargs):
                self._is_fitted = True
                return self

            def predict(self, X):
                return np.zeros(len(X))

        init_kwargs = {
            "task_type": "regression",
            "params": {"lr": 0.01},
            "name": "test_model",
        }
        dummy = DummyModel(**init_kwargs)

        assert dummy.task_type == "regression"
        assert dummy.params == {"lr": 0.01}
        assert dummy.name == "test_model"
        assert not dummy._is_fitted

    def test_predict_proba_not_implemented(self):
        """``predict_proba`` raises NotImplementedError when not overridden."""

        class DummyModel(BaseModel):
            def fit(self, X, y, X_val=None, y_val=None, **kwargs):
                return self

            def predict(self, X):
                return np.zeros(len(X))

        dummy = DummyModel(task_type="regression")
        features = pl.DataFrame({"feature": [1, 2, 3]})

        with pytest.raises(NotImplementedError):
            dummy.predict_proba(features)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestBaseSplitter:
    """Tests for the splitter base class."""

    def test_splitter_interface(self):
        """A minimal subclass can provide ``split`` and ``get_split_dates``."""

        class DummySplitter(BaseSplitter):
            def split(self, data, date_col="trade_date"):
                yield [0, 1], [2, 3]

            def get_split_dates(self, data, date_col="trade_date"):
                return [("20200101", "20201231", "20210101", "20211231")]

        dummy = DummySplitter()
        trade_dates = ["20200101", "20200601", "20210101", "20210601"]
        frame = pl.DataFrame({"trade_date": trade_dates})

        # The generator yields exactly one (train, test) index pair.
        folds = list(dummy.split(frame))
        assert len(folds) == 1
        assert folds[0] == ([0, 1], [2, 3])

        boundaries = dummy.get_split_dates(frame)
        assert boundaries == [("20200101", "20201231", "20210101", "20211231")]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ========== 测试插件注册中心 ==========
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestPluginRegistry:
    """Tests for the plugin registry."""

    def setup_method(self):
        """Call ``PluginRegistry.clear_all()`` before every test."""
        PluginRegistry.clear_all()

    def test_register_and_get_processor(self):
        """A processor registered under a name can be looked up and listed."""

        @PluginRegistry.register_processor("test_processor")
        class TestProcessor(BaseProcessor):
            stage = PipelineStage.ALL

            def fit(self, data):
                return self

            def transform(self, data):
                return data

        looked_up = PluginRegistry.get_processor("test_processor")
        assert looked_up == TestProcessor
        assert "test_processor" in PluginRegistry.list_processors()

    def test_register_and_get_model(self):
        """A model registered under a name can be looked up and listed."""

        @PluginRegistry.register_model("test_model")
        class TestModel(BaseModel):
            def fit(self, X, y, X_val=None, y_val=None, **kwargs):
                return self

            def predict(self, X):
                return np.zeros(len(X))

        looked_up = PluginRegistry.get_model("test_model")
        assert looked_up == TestModel
        assert "test_model" in PluginRegistry.list_models()

    def test_register_and_get_splitter(self):
        """A splitter registered under a name can be looked up and listed."""

        @PluginRegistry.register_splitter("test_splitter")
        class TestSplitter(BaseSplitter):
            def split(self, data, date_col="trade_date"):
                yield [], []

            def get_split_dates(self, data, date_col="trade_date"):
                return []

        looked_up = PluginRegistry.get_splitter("test_splitter")
        assert looked_up == TestSplitter
        assert "test_splitter" in PluginRegistry.list_splitters()

    def test_get_nonexistent_processor(self):
        """Looking up an unknown processor raises KeyError naming the key."""
        with pytest.raises(KeyError) as caught:
            PluginRegistry.get_processor("nonexistent")

        message = str(caught.value)
        assert "nonexistent" in message

    def test_register_with_default_name(self):
        """Registering without a name lists the processor under its class name."""

        @PluginRegistry.register_processor()
        class MyCustomProcessor(BaseProcessor):
            stage = PipelineStage.ALL

            def fit(self, data):
                return self

            def transform(self, data):
                return data

        registered_names = PluginRegistry.list_processors()
        assert "MyCustomProcessor" in registered_names
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ========== 测试内置处理器 ==========
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestBuiltInProcessors:
    """Tests for the built-in processors."""

    def test_dropna_processor(self):
        """DropNAProcessor removes rows with nulls in the selected columns."""
        from src.pipeline.processors import DropNAProcessor

        frame = pl.DataFrame({"a": [1, None, 3], "b": [4, 5, None], "c": [7, 8, 9]})
        dropper = DropNAProcessor(columns=["a", "b"])

        cleaned = dropper.fit_transform(frame)

        # Only the first row has no missing value in "a" or "b".
        assert len(cleaned) == 1
        assert cleaned["a"].to_list() == [1]
        assert cleaned["b"].to_list() == [4]

    def test_fillna_processor(self):
        """FillNAProcessor with ``method="mean"`` fills nulls with the mean."""
        from src.pipeline.processors import FillNAProcessor

        frame = pl.DataFrame({"a": [1.0, 2.0, None, 4.0]})
        filler = FillNAProcessor(columns=["a"], method="mean")

        filled = filler.fit_transform(frame)

        # mean = (1 + 2 + 4) / 3 = 2.333...
        assert filled["a"][2] == pytest.approx(2.333, rel=0.01)

    def test_standard_scaler(self):
        """StandardScaler output has zero mean and unit standard deviation."""
        from src.pipeline.processors import StandardScaler

        frame = pl.DataFrame({"value": [1.0, 2.0, 3.0, 4.0, 5.0]})
        scaler = StandardScaler(columns=["value"])

        scaled = scaler.fit_transform(frame)
        scaled_values = scaled["value"]

        # After z-score scaling the mean is 0 and the standard deviation is 1.
        assert scaled_values.mean() == pytest.approx(0.0, abs=1e-10)
        assert scaled_values.std() == pytest.approx(1.0, rel=0.01)

    def test_winsorizer(self):
        """Winsorizer clips values to the configured lower/upper quantiles."""
        from src.pipeline.processors import Winsorizer

        # Values 0-99.
        frame = pl.DataFrame({"value": list(range(100))})
        clipper = Winsorizer(columns=["value"], lower=0.1, upper=0.9)

        clipped = clipper.fit_transform(frame)
        clipped_values = clipped["value"]

        # The 10% and 90% quantiles should be 10 and 89 (Polars quantile behaviour).
        assert clipped_values.min() == 10
        assert clipped_values.max() == 89

    def test_rank_transformer(self):
        """RankTransformer replaces values with their rank."""
        from src.pipeline.processors import RankTransformer

        frame = pl.DataFrame(
            {
                "trade_date": ["20200101"] * 5,
                "value": [10, 30, 20, 50, 40],
            }
        )
        ranker = RankTransformer(columns=["value"])

        ranked = ranker.fit_transform(frame)

        # Expected ranks: 1, 3, 2, 5, 4.
        assert ranked["value"].to_list() == [1, 3, 2, 5, 4]

    def test_neutralizer(self):
        """Neutralizer leaves every group with a mean of zero."""
        from src.pipeline.processors import Neutralizer

        frame = pl.DataFrame(
            {
                "trade_date": ["20200101", "20200101", "20200101", "20200101"],
                "industry": ["A", "A", "B", "B"],
                "value": [10, 20, 30, 50],
            }
        )
        neutralizer = Neutralizer(columns=["value"], group_col="industry")

        neutralized = neutralizer.fit_transform(frame)

        # After group-wise demeaning, the mean of each group is 0.
        for industry in ("A", "B"):
            members = neutralized.filter(pl.col("industry") == industry)
            assert members["value"].mean() == pytest.approx(0.0, abs=1e-10)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ========== 测试处理流水线 ==========
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestProcessingPipeline:
    """Tests for the processing pipeline."""

    def test_pipeline_fit_transform(self):
        """``fit_transform`` applies every processor in the pipeline."""
        from src.pipeline.processors import StandardScaler

        stages = [StandardScaler(columns=["a"]), StandardScaler(columns=["b"])]
        pipeline = ProcessingPipeline(stages)
        frame = pl.DataFrame({"a": [1.0, 2.0, 3.0], "b": [10.0, 20.0, 30.0]})

        scaled = pipeline.fit_transform(frame)

        # Both columns should have been standardised.
        for column in ("a", "b"):
            assert scaled[column].mean() == pytest.approx(0.0, abs=1e-10)

    def test_pipeline_transform_uses_fitted_params(self):
        """``transform`` reuses parameters learned during fitting."""
        from src.pipeline.processors import StandardScaler

        pipeline = ProcessingPipeline([StandardScaler(columns=["value"])])

        # Training data: mean = 2, standard deviation = 1.
        train_frame = pl.DataFrame({"value": [1.0, 2.0, 3.0]})
        # Test data from a different distribution: mean would be 5 if re-fitted.
        test_frame = pl.DataFrame({"value": [4.0, 5.0, 6.0]})

        pipeline.fit_transform(train_frame)
        transformed = pipeline.transform(test_frame)

        # Scaled with the training mean (2) and std (1): 4 -> (4 - 2) / 1 = 2.
        first_value = transformed["value"].to_list()[0]
        assert first_value == pytest.approx(2.0, abs=1e-10)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ========== 测试划分策略 ==========
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestSplitters:
    """Tests for the splitting strategies."""

    @staticmethod
    def _daily_frame(n_days):
        """Build a frame with ``n_days`` January-2020 dates and values 0..n_days-1."""
        return pl.DataFrame(
            {
                "trade_date": [f"202001{i:02d}" for i in range(1, n_days + 1)],
                "value": list(range(n_days)),
            }
        )

    def test_time_series_split(self):
        """TimeSeriesSplit yields the requested folds with train before test."""
        from src.pipeline.core import TimeSeriesSplit

        splitter = TimeSeriesSplit(n_splits=2, gap=1, min_train_size=3)
        # Ten days of data.
        frame = self._daily_frame(10)

        folds = list(splitter.split(frame))

        # Two folds are expected.
        assert len(folds) == 2

        # In every fold the training indices precede the test indices.
        for train_rows, test_rows in folds:
            assert max(train_rows) < min(test_rows)

    def test_walk_forward_split(self):
        """WalkForwardSplit keeps train and test window sizes fixed."""
        from src.pipeline.core import WalkForwardSplit

        splitter = WalkForwardSplit(train_window=5, test_window=2, gap=1)
        frame = self._daily_frame(12)

        folds = list(splitter.split(frame))

        # Window sizes stay constant across folds.
        for train_rows, test_rows in folds:
            assert len(train_rows) == 5
            assert len(test_rows) == 2

    def test_expanding_window_split(self):
        """ExpandingWindowSplit grows the training set by the test window."""
        from src.pipeline.core import ExpandingWindowSplit

        splitter = ExpandingWindowSplit(initial_train_size=3, test_window=2, gap=1)
        frame = self._daily_frame(14)

        folds = list(splitter.split(frame))

        # The training set should grow from fold to fold.
        train_sizes = [len(fold[0]) for fold in folds]
        assert train_sizes[0] == 3
        assert train_sizes[1] == 5  # 3 + 2
        assert train_sizes[2] == 7  # 5 + 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ========== 测试内置模型(可选,需要安装依赖) ==========
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestModels:
    """Tests for the built-in models (skipped when a dependency is missing)."""

    def test_lightgbm_model(self):
        """LightGBMModel can be fitted and predicts one value per input row.

        The test is skipped only when ``lightgbm`` cannot be imported. The
        previous unconditional ``@pytest.mark.skip`` meant it never ran, even
        on machines where the dependency was available.
        """
        pytest.importorskip("lightgbm", reason="需要安装 lightgbm")

        from src.pipeline.models import LightGBMModel

        model = LightGBMModel(task_type="regression", params={"n_estimators": 10})

        # 50 rows: the same five-value pattern repeated ten times.
        X = pl.DataFrame(
            {
                "feature1": [1.0, 2.0, 3.0, 4.0, 5.0] * 10,
                "feature2": [5.0, 4.0, 3.0, 2.0, 1.0] * 10,
            }
        )
        y = pl.Series("target", [1.0, 2.0, 3.0, 4.0, 5.0] * 10)

        model.fit(X, y)
        predictions = model.predict(X)

        assert len(predictions) == len(X)
        assert model._is_fitted
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate pytest's exit code so that running this file directly reports
    # failures to the shell / CI instead of always exiting with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|