- 新增因子基类 (BaseFactor, CrossSectionalFactor, TimeSeriesFactor)
- 新增数据规格和上下文类 (DataSpec, FactorContext, FactorData)
- 新增数据加载器 (DataLoader) 和执行引擎 (FactorEngine)
- 新增组合因子支持 (CompositeFactor, ScalarFactor)
- 添加因子模块完整测试用例
- 添加 Git 提交规范文档
407 lines
13 KiB
Python
407 lines
13 KiB
Python
"""测试因子基类 - BaseFactor、CrossSectionalFactor、TimeSeriesFactor

测试需求(来自 factor_implementation_plan.md):

- BaseFactor:
  - 测试有效子类创建通过验证
  - 测试缺少 `name` 时抛出 ValueError
  - 测试 `name` 为空字符串时抛出 ValueError
  - 测试缺少 `factor_type` 时抛出 ValueError
  - 测试无效的 `factor_type`(非 cs/ts)时抛出 ValueError
  - 测试缺少 `data_specs` 时抛出 ValueError
  - 测试 `data_specs` 为空列表时抛出 ValueError
  - 测试 `compute()` 抽象方法强制子类实现
  - 测试参数化初始化 `params` 正确存储
  - 测试 `_validate_params()` 被调用

- CrossSectionalFactor:
  - 测试 `factor_type` 自动设置为 "cross_sectional"
  - 测试子类必须实现 `compute()`
  - 测试 `compute()` 返回类型为 pl.Series

- TimeSeriesFactor:
  - 测试 `factor_type` 自动设置为 "time_series"
  - 测试子类必须实现 `compute()`
  - 测试 `compute()` 返回类型为 pl.Series
"""
|
||
|
||
import pytest
|
||
import polars as pl
|
||
|
||
from src.factors import DataSpec, FactorContext, FactorData
|
||
from src.factors.base import BaseFactor, CrossSectionalFactor, TimeSeriesFactor
|
||
|
||
|
||
# ========== 测试数据准备 ==========
|
||
|
||
|
||
@pytest.fixture
def sample_dataspec():
    """Build the sample DataSpec used by the factor classes in these tests."""
    required_columns = ["ts_code", "trade_date", "close"]
    return DataSpec(source="daily", columns=required_columns, lookback_days=5)
|
||
|
||
|
||
@pytest.fixture
def sample_factor_data():
    """Build a sample FactorData: two stock codes over two trade dates.

    The accompanying FactorContext uses "20240102" as its current date.
    """
    rows = {
        "ts_code": ["000001.SZ", "000002.SZ", "000001.SZ", "000002.SZ"],
        "trade_date": ["20240101", "20240101", "20240102", "20240102"],
        "close": [10.0, 20.0, 11.0, 21.0],
    }
    frame = pl.DataFrame(rows)
    ctx = FactorContext(current_date="20240102")
    return FactorData(frame, ctx)
|
||
|
||
|
||
# ========== BaseFactor 测试 ==========
|
||
|
||
|
||
class TestBaseFactorValidation:
    """Validation of the class attributes declared by BaseFactor subclasses."""

    def test_valid_cross_sectional_subclass(self, sample_dataspec):
        """A complete cross-sectional subclass can be declared and instantiated."""

        class WellFormedFactor(CrossSectionalFactor):
            name = "valid_cs"
            data_specs = [sample_dataspec]

            def compute(self, data):
                return pl.Series([1.0])

        # Instantiation is expected to succeed.
        instance = WellFormedFactor()
        assert instance.name == "valid_cs"
        assert instance.factor_type == "cross_sectional"

    def test_valid_time_series_subclass(self, sample_dataspec):
        """A complete time-series subclass can be declared and instantiated."""

        class WellFormedFactor(TimeSeriesFactor):
            name = "valid_ts"
            data_specs = [sample_dataspec]

            def compute(self, data):
                return pl.Series([1.0])

        instance = WellFormedFactor()
        assert instance.name == "valid_ts"
        assert instance.factor_type == "time_series"

    def test_missing_name(self, sample_dataspec):
        """Declaring a subclass without `name` raises ValueError."""
        expect_error = pytest.raises(ValueError, match="must define 'name'")
        with expect_error:

            class NamelessFactor(CrossSectionalFactor):
                # `name` is deliberately left undefined.
                data_specs = [sample_dataspec]

                def compute(self, data):
                    return pl.Series([1.0])

    def test_empty_name(self, sample_dataspec):
        """Declaring a subclass whose `name` is an empty string raises ValueError."""
        expect_error = pytest.raises(ValueError, match="must define 'name'")
        with expect_error:

            class EmptyNameFactor(CrossSectionalFactor):
                name = ""  # empty string
                data_specs = [sample_dataspec]

                def compute(self, data):
                    return pl.Series([1.0])

    def test_missing_factor_type(self, sample_dataspec):
        """Declaring a direct BaseFactor subclass without `factor_type` raises ValueError."""
        expect_error = pytest.raises(ValueError, match="must define 'factor_type'")
        with expect_error:

            class UntypedFactor(BaseFactor):
                name = "bad_factor"
                # `factor_type` is deliberately left undefined.
                data_specs = [sample_dataspec]

                def compute(self, data):
                    return pl.Series([1.0])

    def test_invalid_factor_type(self, sample_dataspec):
        """A `factor_type` other than the two supported values raises ValueError."""
        expect_error = pytest.raises(
            ValueError,
            match="must be 'cross_sectional' or 'time_series'",
        )
        with expect_error:

            class WrongTypeFactor(BaseFactor):
                name = "bad_factor"
                factor_type = "invalid_type"
                data_specs = [sample_dataspec]

                def compute(self, data):
                    return pl.Series([1.0])

    def test_missing_data_specs(self):
        """Declaring a subclass without `data_specs` raises ValueError."""
        expect_error = pytest.raises(ValueError, match="must define 'data_specs'")
        with expect_error:

            class SpeclessFactor(BaseFactor):
                name = "bad_factor"
                factor_type = "cross_sectional"
                # `data_specs` is deliberately left undefined.

                def compute(self, data):
                    return pl.Series([1.0])

    def test_empty_data_specs(self):
        """Declaring a subclass whose `data_specs` is an empty list raises ValueError."""
        expect_error = pytest.raises(ValueError, match="cannot be empty")
        with expect_error:

            class EmptySpecsFactor(CrossSectionalFactor):
                name = "bad_factor"
                data_specs = []  # empty list

                def compute(self, data):
                    return pl.Series([1.0])
|
||
|
||
|
||
class TestBaseFactorCompute:
    """Check that subclasses are forced to implement the abstract `compute()`.

    The helper subclasses are declared outside the `pytest.raises` block so
    that only the instantiation itself is expected to raise TypeError; an
    error raised while declaring the class is then reported as a failure
    instead of being able to satisfy the expectation.
    """

    def test_compute_must_be_implemented_cs(self, sample_dataspec):
        """Instantiating a cross-sectional subclass without compute() raises TypeError."""

        class BadFactor(CrossSectionalFactor):
            name = "bad_cs"
            data_specs = [sample_dataspec]
            # compute() is deliberately not implemented.

        with pytest.raises(TypeError, match="Can't instantiate abstract class"):
            BadFactor()

    def test_compute_must_be_implemented_ts(self, sample_dataspec):
        """Instantiating a time-series subclass without compute() raises TypeError."""

        class BadFactor(TimeSeriesFactor):
            name = "bad_ts"
            data_specs = [sample_dataspec]
            # compute() is deliberately not implemented.

        with pytest.raises(TypeError, match="Can't instantiate abstract class"):
            BadFactor()
|
||
|
||
|
||
class TestBaseFactorParams:
    """Parameterized initialization of factors."""

    def test_params_stored_correctly(self, sample_dataspec):
        """Keyword arguments forwarded to the base __init__ are readable via `params`."""

        class TwoParamFactor(CrossSectionalFactor):
            name = "param_factor"
            data_specs = [sample_dataspec]

            def __init__(self, period: int = 20, weight: float = 1.0):
                super().__init__(period=period, weight=weight)

            def compute(self, data):
                return pl.Series([1.0])

        instance = TwoParamFactor(period=10, weight=0.5)
        stored = instance.params
        assert stored["period"] == 10
        assert stored["weight"] == 0.5

    def test_validate_params_called(self, sample_dataspec):
        """Creating an instance invokes `_validate_params()`."""
        calls = []

        class RecordingFactor(CrossSectionalFactor):
            name = "validated_factor"
            data_specs = [sample_dataspec]

            def __init__(self, period: int = 20):
                super().__init__(period=period)

            def _validate_params(self):
                # Record the invocation so the test can count it afterwards.
                calls.append(True)
                period = self.params.get("period", 0)
                if period <= 0:
                    raise ValueError("period must be positive")

            def compute(self, data):
                return pl.Series([1.0])

        # Constructing the instance should trigger _validate_params once.
        instance = RecordingFactor(period=10)
        assert len(calls) == 1
        assert instance.params["period"] == 10

    def test_validate_params_raises(self, sample_dataspec):
        """An exception raised by `_validate_params()` propagates out of construction."""

        class StrictPeriodFactor(CrossSectionalFactor):
            name = "bad_param_factor"
            data_specs = [sample_dataspec]

            def __init__(self, period: int = 20):
                super().__init__(period=period)

            def _validate_params(self):
                period = self.params.get("period", 0)
                if period <= 0:
                    raise ValueError("period must be positive")

            def compute(self, data):
                return pl.Series([1.0])

        expect_error = pytest.raises(ValueError, match="period must be positive")
        with expect_error:
            StrictPeriodFactor(period=-5)
|
||
|
||
|
||
class TestBaseFactorRepr:
    """Checks on the factor `__repr__`."""

    def test_repr(self, sample_dataspec):
        """The repr mentions the class name, the factor name and the factor type."""

        # The class name is asserted below, so it must stay "TestFactor".
        class TestFactor(CrossSectionalFactor):
            name = "test_factor"
            data_specs = [sample_dataspec]

            def compute(self, data):
                return pl.Series([1.0])

        text = repr(TestFactor())
        for fragment in ("TestFactor", "test_factor", "cross_sectional"):
            assert fragment in text
|
||
|
||
|
||
# ========== CrossSectionalFactor 测试 ==========
|
||
|
||
|
||
class TestCrossSectionalFactor:
    """Tests specific to CrossSectionalFactor."""

    def test_factor_type_auto_set(self, sample_dataspec):
        """`factor_type` is 'cross_sectional' without the subclass declaring it."""

        class PlainCSFactor(CrossSectionalFactor):
            name = "cs_factor"
            data_specs = [sample_dataspec]

            def compute(self, data):
                return pl.Series([1.0])

        instance = PlainCSFactor()
        assert instance.factor_type == "cross_sectional"

    def test_compute_returns_series(self, sample_factor_data, sample_dataspec):
        """`compute()` hands back a pl.Series."""

        class ConstantCSFactor(CrossSectionalFactor):
            name = "cs_factor"
            data_specs = [sample_dataspec]

            def compute(self, data):
                # A fixed two-element Series is enough for the type check.
                return pl.Series([1.0, 2.0])

        output = ConstantCSFactor().compute(sample_factor_data)
        assert isinstance(output, pl.Series)

    def test_compute_with_cross_section(self, sample_dataspec):
        """`compute()` can rank the frame returned by `get_cross_section()`."""
        frame = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ"],
                "trade_date": ["20240101", "20240101"],
                "close": [10.0, 20.0],
            }
        )
        factor_data = FactorData(frame, FactorContext(current_date="20240101"))

        class RankFactor(CrossSectionalFactor):
            name = "rank_factor"
            data_specs = [sample_dataspec]

            def compute(self, data):
                section = data.get_cross_section()
                return section["close"].rank()

        output = RankFactor().compute(factor_data)
        assert isinstance(output, pl.Series)
        assert len(output) == 2
|
||
|
||
|
||
# ========== TimeSeriesFactor 测试 ==========
|
||
|
||
|
||
class TestTimeSeriesFactor:
    """Tests specific to TimeSeriesFactor."""

    def test_factor_type_auto_set(self, sample_dataspec):
        """`factor_type` is 'time_series' without the subclass declaring it."""

        class PlainTSFactor(TimeSeriesFactor):
            name = "ts_factor"
            data_specs = [sample_dataspec]

            def compute(self, data):
                return pl.Series([1.0])

        instance = PlainTSFactor()
        assert instance.factor_type == "time_series"

    def test_compute_returns_series(self, sample_factor_data, sample_dataspec):
        """`compute()` hands back a pl.Series with one value per input row."""

        class DoubleCloseFactor(TimeSeriesFactor):
            name = "ts_factor"
            data_specs = [sample_dataspec]

            def compute(self, data):
                close = data.get_column("close")
                return close * 2

        output = DoubleCloseFactor().compute(sample_factor_data)
        assert isinstance(output, pl.Series)
        assert len(output) == 4  # the fixture frame has 4 rows

    def test_compute_with_rolling(self, sample_dataspec):
        """`compute()` can apply a rolling mean whose window comes from `params`."""
        frame = pl.DataFrame(
            {
                "ts_code": ["000001.SZ"] * 5,
                "trade_date": [
                    "20240101",
                    "20240102",
                    "20240103",
                    "20240104",
                    "20240105",
                ],
                "close": [10.0, 11.0, 12.0, 13.0, 14.0],
            }
        )
        factor_data = FactorData(frame, FactorContext(current_stock="000001.SZ"))

        class MAFactor(TimeSeriesFactor):
            name = "ma_factor"
            data_specs = [sample_dataspec]

            def __init__(self, period: int = 3):
                super().__init__(period=period)

            def compute(self, data):
                window = self.params["period"]
                return data.get_column("close").rolling_mean(window)

        output = MAFactor(period=3).compute(factor_data)
        assert isinstance(output, pl.Series)
        assert len(output) == 5
        # With period=3 the first two entries are expected to be null.
        assert output[0] is None
        assert output[1] is None
        assert output[2] is not None
|