Files
ProStock/tests/factors/test_base.py

407 lines
13 KiB
Python
Raw Normal View History

"""测试因子基类 - BaseFactor、CrossSectionalFactor、TimeSeriesFactor
测试需求来自 factor_implementation_plan.md
- BaseFactor:
- 测试有效子类创建通过验证
- 测试缺少 `name` 时抛出 ValueError
- 测试 `name` 为空字符串时抛出 ValueError
- 测试缺少 `factor_type` 时抛出 ValueError
- 测试无效的 `factor_type`（非 cs/ts）时抛出 ValueError
- 测试缺少 `data_specs` 时抛出 ValueError
- 测试 `data_specs` 为空列表时抛出 ValueError
- 测试 `compute()` 抽象方法强制子类实现
- 测试参数化初始化 `params` 正确存储
- 测试 `_validate_params()` 被调用
- CrossSectionalFactor:
- 测试 `factor_type` 自动设置为 "cross_sectional"
- 测试子类必须实现 `compute()`
- 测试 `compute()` 返回类型为 pl.Series
- TimeSeriesFactor:
- 测试 `factor_type` 自动设置为 "time_series"
- 测试子类必须实现 `compute()`
- 测试 `compute()` 返回类型为 pl.Series
"""
import pytest
import polars as pl
from src.factors import DataSpec, FactorContext, FactorData
from src.factors.base import BaseFactor, CrossSectionalFactor, TimeSeriesFactor
# ========== 测试数据准备 ==========
@pytest.fixture
def sample_dataspec():
    """DataSpec on the "daily" source for ts_code/trade_date/close, 5-day lookback."""
    columns = ["ts_code", "trade_date", "close"]
    return DataSpec(source="daily", columns=columns, lookback_days=5)
@pytest.fixture
def sample_factor_data():
    """FactorData for two stocks over two trading days, evaluated on 20240102.

    Rows alternate 000001.SZ / 000002.SZ, first for 20240101 and then for
    20240102 (four rows in total).
    """
    stock_codes = ["000001.SZ", "000002.SZ"] * 2
    trade_dates = ["20240101"] * 2 + ["20240102"] * 2
    close_prices = [10.0, 20.0, 11.0, 21.0]
    frame = pl.DataFrame(
        {"ts_code": stock_codes, "trade_date": trade_dates, "close": close_prices}
    )
    return FactorData(frame, FactorContext(current_date="20240102"))
# ========== BaseFactor 测试 ==========
class TestBaseFactorValidation:
    """Validation BaseFactor applies when a subclass is defined.

    The failure cases below only define a class (no instantiation) inside
    ``pytest.raises``, so the checks are expected at class-definition time.
    """

    def test_valid_cross_sectional_subclass(self, sample_dataspec):
        """A complete CrossSectionalFactor subclass passes validation."""

        class ValidFactor(CrossSectionalFactor):
            name = "valid_cs"
            data_specs = [sample_dataspec]

            def compute(self, data):
                return pl.Series([1.0])

        # Instantiation must succeed for a well-formed subclass.
        instance = ValidFactor()
        assert (instance.name, instance.factor_type) == ("valid_cs", "cross_sectional")

    def test_valid_time_series_subclass(self, sample_dataspec):
        """A complete TimeSeriesFactor subclass passes validation."""

        class ValidFactor(TimeSeriesFactor):
            name = "valid_ts"
            data_specs = [sample_dataspec]

            def compute(self, data):
                return pl.Series([1.0])

        instance = ValidFactor()
        assert (instance.name, instance.factor_type) == ("valid_ts", "time_series")

    def test_missing_name(self, sample_dataspec):
        """Omitting `name` raises ValueError."""
        pattern = "must define 'name'"
        with pytest.raises(ValueError, match=pattern):

            class BadFactor(CrossSectionalFactor):
                # `name` is intentionally left undefined.
                data_specs = [sample_dataspec]

                def compute(self, data):
                    return pl.Series([1.0])

    def test_empty_name(self, sample_dataspec):
        """An empty-string `name` raises ValueError."""
        pattern = "must define 'name'"
        with pytest.raises(ValueError, match=pattern):

            class BadFactor(CrossSectionalFactor):
                name = ""  # empty string is rejected like a missing name
                data_specs = [sample_dataspec]

                def compute(self, data):
                    return pl.Series([1.0])

    def test_missing_factor_type(self, sample_dataspec):
        """A direct BaseFactor subclass without `factor_type` raises ValueError."""
        pattern = "must define 'factor_type'"
        with pytest.raises(ValueError, match=pattern):

            class BadFactor(BaseFactor):
                name = "bad_factor"
                # `factor_type` is intentionally left undefined.
                data_specs = [sample_dataspec]

                def compute(self, data):
                    return pl.Series([1.0])

    def test_invalid_factor_type(self, sample_dataspec):
        """A `factor_type` other than cross_sectional/time_series raises ValueError."""
        pattern = "must be 'cross_sectional' or 'time_series'"
        with pytest.raises(ValueError, match=pattern):

            class BadFactor(BaseFactor):
                name = "bad_factor"
                factor_type = "invalid_type"
                data_specs = [sample_dataspec]

                def compute(self, data):
                    return pl.Series([1.0])

    def test_missing_data_specs(self):
        """Omitting `data_specs` raises ValueError."""
        pattern = "must define 'data_specs'"
        with pytest.raises(ValueError, match=pattern):

            class BadFactor(BaseFactor):
                name = "bad_factor"
                factor_type = "cross_sectional"
                # `data_specs` is intentionally left undefined.

                def compute(self, data):
                    return pl.Series([1.0])

    def test_empty_data_specs(self):
        """An empty `data_specs` list raises ValueError."""
        pattern = "cannot be empty"
        with pytest.raises(ValueError, match=pattern):

            class BadFactor(CrossSectionalFactor):
                name = "bad_factor"
                data_specs = []  # defined, but with no specs

                def compute(self, data):
                    return pl.Series([1.0])
class TestBaseFactorCompute:
    """compute() is abstract: subclasses that omit it cannot be instantiated."""

    @staticmethod
    def _assert_compute_required(base_cls, factor_name, spec):
        """Define a `base_cls` subclass lacking compute() and expect instantiation to fail."""
        with pytest.raises(TypeError, match="Can't instantiate abstract class"):

            class BadFactor(base_cls):
                name = factor_name
                data_specs = [spec]
                # compute() is deliberately not implemented.

            BadFactor()

    def test_compute_must_be_implemented_cs(self, sample_dataspec):
        """A CrossSectionalFactor subclass must implement compute()."""
        self._assert_compute_required(CrossSectionalFactor, "bad_cs", sample_dataspec)

    def test_compute_must_be_implemented_ts(self, sample_dataspec):
        """A TimeSeriesFactor subclass must implement compute()."""
        self._assert_compute_required(TimeSeriesFactor, "bad_ts", sample_dataspec)
class TestBaseFactorParams:
    """Keyword parameters forwarded to BaseFactor.__init__ and the validation hook."""

    def test_params_stored_correctly(self, sample_dataspec):
        """Keyword arguments passed to super().__init__ are exposed via `params`."""

        class ParamFactor(CrossSectionalFactor):
            name = "param_factor"
            data_specs = [sample_dataspec]

            def __init__(self, period: int = 20, weight: float = 1.0):
                super().__init__(period=period, weight=weight)

            def compute(self, data):
                return pl.Series([1.0])

        stored = ParamFactor(period=10, weight=0.5).params
        assert (stored["period"], stored["weight"]) == (10, 0.5)

    def test_validate_params_called(self, sample_dataspec):
        """_validate_params() runs exactly once when the factor is constructed."""
        calls = []

        class ValidatedFactor(CrossSectionalFactor):
            name = "validated_factor"
            data_specs = [sample_dataspec]

            def __init__(self, period: int = 20):
                super().__init__(period=period)

            def _validate_params(self):
                # Record every invocation so the test can count them.
                calls.append(True)
                period = self.params.get("period", 0)
                if period <= 0:
                    raise ValueError("period must be positive")

            def compute(self, data):
                return pl.Series([1.0])

        # Construction alone is expected to trigger _validate_params().
        factor = ValidatedFactor(period=10)
        assert calls == [True]
        assert factor.params["period"] == 10

    def test_validate_params_raises(self, sample_dataspec):
        """An exception raised by _validate_params() propagates out of construction."""

        class BadParamFactor(CrossSectionalFactor):
            name = "bad_param_factor"
            data_specs = [sample_dataspec]

            def __init__(self, period: int = 20):
                super().__init__(period=period)

            def _validate_params(self):
                period = self.params.get("period", 0)
                if period <= 0:
                    raise ValueError("period must be positive")

            def compute(self, data):
                return pl.Series([1.0])

        with pytest.raises(ValueError, match="period must be positive"):
            BadParamFactor(period=-5)
class TestBaseFactorRepr:
    """String representation of factor instances."""

    def test_repr(self, sample_dataspec):
        """repr() mentions the class name, the factor name and the factor type."""

        class TestFactor(CrossSectionalFactor):
            name = "test_factor"
            data_specs = [sample_dataspec]

            def compute(self, data):
                return pl.Series([1.0])

        rendered = repr(TestFactor())
        for fragment in ("TestFactor", "test_factor", "cross_sectional"):
            assert fragment in rendered
# ========== CrossSectionalFactor 测试 ==========
class TestCrossSectionalFactor:
    """Behaviour specific to CrossSectionalFactor."""

    def test_factor_type_auto_set(self, sample_dataspec):
        """factor_type becomes 'cross_sectional' without the subclass declaring it."""

        class CSFactor(CrossSectionalFactor):
            name = "cs_factor"
            data_specs = [sample_dataspec]

            def compute(self, data):
                return pl.Series([1.0])

        factor = CSFactor()
        assert factor.factor_type == "cross_sectional"

    def test_compute_returns_series(self, sample_factor_data, sample_dataspec):
        """A compute() built from the input data returns a pl.Series."""

        class CSFactor(CrossSectionalFactor):
            name = "cs_factor"
            data_specs = [sample_dataspec]

            def compute(self, data):
                # Derive the result from the input instead of returning a
                # hard-coded Series; otherwise the isinstance check below is
                # trivially true and exercises nothing.
                return data.get_cross_section()["close"].rank()

        factor = CSFactor()
        result = factor.compute(sample_factor_data)
        assert isinstance(result, pl.Series)

    def test_compute_with_cross_section(self, sample_dataspec):
        """compute() can rank values taken from get_cross_section()."""
        df = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ"],
                "trade_date": ["20240101", "20240101"],
                "close": [10.0, 20.0],
            }
        )
        context = FactorContext(current_date="20240101")
        data = FactorData(df, context)

        class RankFactor(CrossSectionalFactor):
            name = "rank_factor"
            data_specs = [sample_dataspec]

            def compute(self, data):
                cs = data.get_cross_section()
                return cs["close"].rank()

        factor = RankFactor()
        result = factor.compute(data)
        assert isinstance(result, pl.Series)
        assert len(result) == 2
        # Two distinct closes rank to 1 and 2. Compare as sorted values so the
        # check does not depend on the row order of the cross-section.
        assert sorted(result.to_list()) == [1.0, 2.0]
# ========== TimeSeriesFactor 测试 ==========
class TestTimeSeriesFactor:
    """Behaviour specific to TimeSeriesFactor."""

    def test_factor_type_auto_set(self, sample_dataspec):
        """factor_type becomes 'time_series' without the subclass declaring it."""

        class TSFactor(TimeSeriesFactor):
            name = "ts_factor"
            data_specs = [sample_dataspec]

            def compute(self, data):
                return pl.Series([1.0])

        factor = TSFactor()
        assert factor.factor_type == "time_series"

    def test_compute_returns_series(self, sample_factor_data, sample_dataspec):
        """compute() returns a pl.Series with one value per input row."""

        class TSFactor(TimeSeriesFactor):
            name = "ts_factor"
            data_specs = [sample_dataspec]

            def compute(self, data):
                return data.get_column("close") * 2

        factor = TSFactor()
        result = factor.compute(sample_factor_data)
        assert isinstance(result, pl.Series)
        assert len(result) == 4  # four input rows
        # Closes 10/20/11/21 doubled. Compare as sorted values so the check
        # does not depend on the row order FactorData hands back.
        assert sorted(result.to_list()) == [20.0, 22.0, 40.0, 42.0]

    def test_compute_with_rolling(self, sample_dataspec):
        """compute() can apply a rolling mean whose window comes from `params`."""
        df = pl.DataFrame(
            {
                "ts_code": ["000001.SZ"] * 5,
                "trade_date": [
                    "20240101",
                    "20240102",
                    "20240103",
                    "20240104",
                    "20240105",
                ],
                "close": [10.0, 11.0, 12.0, 13.0, 14.0],
            }
        )
        context = FactorContext(current_stock="000001.SZ")
        data = FactorData(df, context)

        class MAFactor(TimeSeriesFactor):
            name = "ma_factor"
            data_specs = [sample_dataspec]

            def __init__(self, period: int = 3):
                super().__init__(period=period)

            def compute(self, data):
                return data.get_column("close").rolling_mean(self.params["period"])

        factor = MAFactor(period=3)
        result = factor.compute(data)
        assert isinstance(result, pl.Series)
        assert len(result) == 5
        # With period=3 the first two windows are incomplete, so they are null.
        assert result[0] is None
        assert result[1] is None
        # The remaining entries are the 3-day means of 10..14; checking the
        # values (not just non-null) confirms the configured window was used.
        assert result[2:].to_list() == pytest.approx([11.0, 12.0, 13.0])