339 lines
10 KiB
Python
339 lines
10 KiB
Python
|
|
"""Tests for the training module's base infrastructure.

Covers the foundational components implemented in Commit 1:
- BaseModel abstract base class
- BaseProcessor abstract base class
- ModelRegistry model registry
- ProcessorRegistry processor registry
"""
|
|||
|
|
|
|||
|
|
import pytest
|
|||
|
|
import pickle
|
|||
|
|
import tempfile
|
|||
|
|
import os
|
|||
|
|
|
|||
|
|
import polars as pl
|
|||
|
|
import numpy as np
|
|||
|
|
import pandas as pd
|
|||
|
|
|
|||
|
|
from src.training.components.base import BaseModel, BaseProcessor
|
|||
|
|
from src.training.registry import (
|
|||
|
|
ModelRegistry,
|
|||
|
|
ProcessorRegistry,
|
|||
|
|
register_model,
|
|||
|
|
register_processor,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestBaseModel:
    """Tests for the BaseModel abstract base class."""

    def test_base_model_abstract_methods(self):
        """Instantiating BaseModel directly must raise TypeError."""
        # The abstract class cannot be instantiated on its own.
        with pytest.raises(TypeError):
            BaseModel()

    def test_base_model_concrete_implementation(self):
        """A subclass implementing fit and predict can be created and used."""

        class MockModel(BaseModel):
            name = "mock_model"

            def __init__(self):
                self.fitted = False

            def fit(self, X: pl.DataFrame, y: pl.Series) -> "MockModel":
                self.fitted = True
                return self

            def predict(self, X: pl.DataFrame) -> np.ndarray:
                return np.zeros(len(X))

        # The concrete implementation can be instantiated.
        model = MockModel()
        assert model.name == "mock_model"
        assert not model.fitted

        # fit flips the fitted flag.
        df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
        series = pl.Series("target", [0, 1, 0])
        model.fit(df, series)
        assert model.fitted

        # predict returns one zero per input row.
        predictions = model.predict(df)
        assert len(predictions) == 3
        assert np.all(predictions == 0)

    def test_base_model_save_load(self):
        """Check that the persistence interface (save/load) is exposed.

        pickle cannot serialize a class defined inside a function, so this
        test does not perform a real save/load round trip; it only verifies
        that a fitted model keeps its state and that the ``save`` and
        ``load`` attributes are available. Real model classes are expected
        to be defined at module level, where pickling works.
        """

        class MockModel(BaseModel):
            name = "mock_model"

            def __init__(self, value: int = 42):
                self.value = value
                self.fitted = False

            def fit(self, X: pl.DataFrame, y: pl.Series) -> "MockModel":
                self.fitted = True
                return self

            def predict(self, X: pl.DataFrame) -> np.ndarray:
                return np.full(len(X), self.value)

        # Create and train the model.
        model = MockModel(value=100)
        df = pl.DataFrame({"a": [1, 2, 3]})
        series = pl.Series("target", [0, 1, 0])
        model.fit(df, series)

        # The model state is what we configured and trained.
        assert model.value == 100
        assert model.fitted

        # Only the presence of the persistence API is checked here (see the
        # docstring for why no actual serialization is attempted).
        assert hasattr(model, "save")
        assert hasattr(MockModel, "load")

    def test_feature_importance_default(self):
        """The default feature_importance implementation returns None."""

        class MockModel(BaseModel):
            name = "mock"

            def fit(self, X, y):
                return self

            def predict(self, X):
                return np.array([])

        model = MockModel()
        assert model.feature_importance() is None
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestBaseProcessor:
    """Exercise the BaseProcessor abstract base class."""

    def test_base_processor_abstract_methods(self):
        """Direct instantiation must fail with TypeError."""
        # transform is abstract, so the base class cannot be instantiated.
        with pytest.raises(TypeError):
            BaseProcessor()

    def test_base_processor_concrete_implementation(self):
        """A subclass that implements transform can be built and applied."""

        class AddOneProcessor(BaseProcessor):
            name = "add_one"

            def transform(self, X: pl.DataFrame) -> pl.DataFrame:
                # Add one to every numeric column; column names are kept.
                shifted = [
                    pl.col(column) + 1
                    for column in X.columns
                    if X[column].dtype.is_numeric()
                ]
                return X.with_columns(shifted)

        add_one = AddOneProcessor()
        assert add_one.name == "add_one"

        # Apply transform and compare each column with its shifted values.
        frame = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
        transformed = add_one.transform(frame)
        assert transformed["a"].to_list() == [2, 3, 4]
        assert transformed["b"].to_list() == [5, 6, 7]

    def test_fit_transform_chain(self):
        """fit_transform learns state and returns the transformed frame."""

        class StatefulProcessor(BaseProcessor):
            name = "stateful"

            def __init__(self):
                self.mean = None

            def fit(self, X: pl.DataFrame) -> "StatefulProcessor":
                # Remember the mean of every numeric column.
                self.mean = {
                    column: X[column].mean()
                    for column in X.columns
                    if X[column].dtype.is_numeric()
                }
                return self

            def transform(self, X: pl.DataFrame) -> pl.DataFrame:
                # Subtract the stored mean from each numeric column.
                centered = []
                for column in X.columns:
                    if X[column].dtype.is_numeric():
                        centered.append(
                            (pl.col(column) - self.mean[column]).alias(column)
                        )
                return X.with_columns(centered)

        centering = StatefulProcessor()
        frame = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})

        # fit_transform should return the transformed result.
        centered_frame = centering.fit_transform(frame)
        assert centering.mean is not None
        for column, expected_mean in (("a", 2.0), ("b", 5.0)):
            assert centering.mean[column] == expected_mean

        # The result should be mean-centered.
        for column in ("a", "b"):
            assert centered_frame[column].to_list() == [-1.0, 0.0, 1.0]

    def test_fit_default_implementation(self):
        """The default fit implementation returns the processor itself."""

        class SimpleProcessor(BaseProcessor):
            name = "simple"

            def transform(self, X: pl.DataFrame) -> pl.DataFrame:
                return X

        passthrough = SimpleProcessor()
        frame = pl.DataFrame({"a": [1, 2, 3]})

        # fit returns self by default.
        assert passthrough.fit(frame) is passthrough
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestModelRegistry:
    """Tests for the ModelRegistry model registry."""

    def setup_method(self):
        """Clear the registry before each test."""
        ModelRegistry.clear()

    def teardown_method(self):
        """Clear the registry after each test.

        Without this, the classes registered by the last test in this class
        would remain in the shared registry and leak into tests that run
        afterwards.
        """
        ModelRegistry.clear()

    def test_register_and_get_model(self):
        """A registered model class can be listed and retrieved by name."""

        class TestModel(BaseModel):
            name = "test_model"

            def fit(self, X, y):
                return self

            def predict(self, X):
                return np.array([])

        ModelRegistry.register("test", TestModel)
        assert "test" in ModelRegistry.list_models()

        # Fetch the model class and instantiate it.
        model_class = ModelRegistry.get_model("test")
        assert model_class is TestModel
        assert model_class().name == "test_model"

    def test_register_duplicate_raises(self):
        """Registering the same name twice raises ValueError."""

        class TestModel(BaseModel):
            name = "test"

            def fit(self, X, y):
                return self

            def predict(self, X):
                return np.array([])

        ModelRegistry.register("dup_test", TestModel)

        with pytest.raises(ValueError, match="已被注册"):
            ModelRegistry.register("dup_test", TestModel)

    def test_register_invalid_class(self):
        """Registering a class that is not a BaseModel raises ValueError."""

        class NotAModel:
            pass

        with pytest.raises(ValueError, match="必须继承 BaseModel"):
            ModelRegistry.register("invalid", NotAModel)

    def test_get_unknown_model(self):
        """Looking up an unregistered name raises KeyError."""
        with pytest.raises(KeyError, match="未知模型"):
            ModelRegistry.get_model("unknown")

    def test_register_model_decorator(self):
        """The register_model decorator registers the decorated class."""

        @register_model("decorated")
        class DecoratedModel(BaseModel):
            name = "decorated"

            def fit(self, X, y):
                return self

            def predict(self, X):
                return np.array([])

        assert "decorated" in ModelRegistry.list_models()
        model_class = ModelRegistry.get_model("decorated")
        assert model_class is DecoratedModel
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestProcessorRegistry:
    """Tests for the ProcessorRegistry processor registry."""

    def setup_method(self):
        """Clear the registry before each test."""
        ProcessorRegistry.clear()

    def teardown_method(self):
        """Clear the registry after each test.

        Without this, the classes registered by the last test in this class
        would remain in the shared registry and leak into tests that run
        afterwards.
        """
        ProcessorRegistry.clear()

    def test_register_and_get_processor(self):
        """A registered processor class can be listed and retrieved by name."""

        class TestProcessor(BaseProcessor):
            name = "test_processor"

            def transform(self, X):
                return X

        ProcessorRegistry.register("test", TestProcessor)
        assert "test" in ProcessorRegistry.list_processors()

        processor_class = ProcessorRegistry.get_processor("test")
        assert processor_class is TestProcessor

    def test_register_duplicate_raises(self):
        """Registering the same name twice raises ValueError."""

        class TestProcessor(BaseProcessor):
            name = "test"

            def transform(self, X):
                return X

        ProcessorRegistry.register("dup_test", TestProcessor)

        with pytest.raises(ValueError, match="已被注册"):
            ProcessorRegistry.register("dup_test", TestProcessor)

    def test_register_invalid_class(self):
        """Registering a class that is not a BaseProcessor raises ValueError."""

        class NotAProcessor:
            pass

        with pytest.raises(ValueError, match="必须继承 BaseProcessor"):
            ProcessorRegistry.register("invalid", NotAProcessor)

    def test_get_unknown_processor(self):
        """Looking up an unregistered name raises KeyError."""
        with pytest.raises(KeyError, match="未知处理器"):
            ProcessorRegistry.get_processor("unknown")

    def test_register_processor_decorator(self):
        """The register_processor decorator registers the decorated class."""

        @register_processor("decorated")
        class DecoratedProcessor(BaseProcessor):
            name = "decorated"

            def transform(self, X):
                return X

        assert "decorated" in ProcessorRegistry.list_processors()
        processor_class = ProcessorRegistry.get_processor("decorated")
        assert processor_class is DecoratedProcessor
|