Files
ProStock/tests/factors/test_composite.py
liaozhaorun 0a16129548 feat(factors): 添加因子计算框架
- 新增因子基类 (BaseFactor, CrossSectionalFactor, TimeSeriesFactor)
- 新增数据规格和上下文类 (DataSpec, FactorContext, FactorData)
- 新增数据加载器 (DataLoader) 和执行引擎 (FactorEngine)
- 新增组合因子支持 (CompositeFactor, ScalarFactor)
- 添加因子模块完整测试用例
- 添加 Git 提交规范文档
2026-02-22 14:41:32 +08:00

418 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""测试组合因子 - CompositeFactor、ScalarFactor
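Test requirements (from factor_implementation_plan.md):
- CompositeFactor:
  - combining factors of the same type succeeds (cs + cs)
  - combining factors of the same type succeeds (ts + ts)
  - combining factors of different types raises ValueError (cs + ts)
  - an invalid operator raises ValueError
  - `_merge_data_specs()` merges correctly (same source)
  - `_merge_data_specs()` merges correctly (different source)
  - `_merge_data_specs()` takes the maximum lookback
  - `compute()` performs the correct arithmetic
- ScalarFactor:
  - scalar multiplication `0.5 * factor`
  - scalar multiplication `factor * 0.5`
  - scalar addition (if supported)
  - inherits the data_specs of the underlying factor
  - `compute()` returns correctly scaled values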
"""
import pytest
import polars as pl
from src.factors import DataSpec, FactorContext, FactorData
from src.factors.base import CrossSectionalFactor, TimeSeriesFactor
from src.factors.composite import CompositeFactor, ScalarFactor
# ========== 测试数据准备 ==========
@pytest.fixture
def sample_dataspec():
    """Provide a sample DataSpec for the "daily" source.

    The spec requests the ``ts_code``, ``trade_date`` and ``close`` columns
    with a lookback of 5 days.
    """
    requested_columns = ["ts_code", "trade_date", "close"]
    return DataSpec("daily", requested_columns, lookback_days=5)
@pytest.fixture
def cs_factor1(sample_dataspec):
    """Provide cross-sectional stub factor 1, whose compute() yields [1.0, 2.0]."""

    class CSFactor1(CrossSectionalFactor):
        # Stub factor: ignores its input and returns fixed values.
        name = "cs_factor1"
        data_specs = [sample_dataspec]

        def compute(self, data):
            fixed_values = [1.0, 2.0]
            return pl.Series(fixed_values)

    factor = CSFactor1()
    return factor
@pytest.fixture
def cs_factor2(sample_dataspec):
    """Provide cross-sectional stub factor 2, whose compute() yields [3.0, 4.0]."""

    class CSFactor2(CrossSectionalFactor):
        # Stub factor: ignores its input and returns fixed values.
        name = "cs_factor2"
        data_specs = [sample_dataspec]

        def compute(self, data):
            fixed_values = [3.0, 4.0]
            return pl.Series(fixed_values)

    factor = CSFactor2()
    return factor
@pytest.fixture
def ts_factor1(sample_dataspec):
    """Provide time-series stub factor 1, whose compute() yields [10.0, 20.0, 30.0]."""

    class TSFactor1(TimeSeriesFactor):
        # Stub factor: ignores its input and returns fixed values.
        name = "ts_factor1"
        data_specs = [sample_dataspec]

        def compute(self, data):
            fixed_values = [10.0, 20.0, 30.0]
            return pl.Series(fixed_values)

    factor = TSFactor1()
    return factor
@pytest.fixture
def ts_factor2(sample_dataspec):
    """Provide time-series stub factor 2, whose compute() yields [1.0, 2.0, 3.0]."""

    class TSFactor2(TimeSeriesFactor):
        # Stub factor: ignores its input and returns fixed values.
        name = "ts_factor2"
        data_specs = [sample_dataspec]

        def compute(self, data):
            fixed_values = [1.0, 2.0, 3.0]
            return pl.Series(fixed_values)

    factor = TSFactor2()
    return factor
@pytest.fixture
def sample_factor_data():
    """Provide a sample FactorData: two codes over two trade dates.

    The accompanying FactorContext has ``current_date`` set to "20240102".
    """
    rows_by_column = {
        "ts_code": ["000001.SZ", "000002.SZ", "000001.SZ", "000002.SZ"],
        "trade_date": ["20240101", "20240101", "20240102", "20240102"],
        "close": [10.0, 20.0, 11.0, 21.0],
    }
    frame = pl.DataFrame(rows_by_column)
    ctx = FactorContext(current_date="20240102")
    return FactorData(frame, ctx)
# ========== CompositeFactor 测试 ==========
class TestCompositeFactorTypeValidation:
    """Type and operator validation performed when factors are combined."""

    def test_same_type_combination_cs(self, cs_factor1, cs_factor2):
        """Two cross-sectional factors combine into a cross-sectional composite."""
        summed = cs_factor1 + cs_factor2

        assert isinstance(summed, CompositeFactor)
        assert summed.factor_type == "cross_sectional"
        assert summed.name == "(cs_factor1_+_cs_factor2)"

    def test_same_type_combination_ts(self, ts_factor1, ts_factor2):
        """Two time-series factors combine into a time-series composite."""
        difference = ts_factor1 - ts_factor2

        assert isinstance(difference, CompositeFactor)
        assert difference.factor_type == "time_series"
        assert difference.name == "(ts_factor1_-_ts_factor2)"

    def test_different_type_raises(self, cs_factor1, ts_factor1):
        """Mixing a cross-sectional and a time-series factor raises ValueError."""
        expected_message = "Cannot combine factors of different types"
        with pytest.raises(ValueError, match=expected_message):
            # Only the exception matters; the result is deliberately discarded.
            _ = cs_factor1 + ts_factor1

    def test_invalid_operator_raises(self, cs_factor1, cs_factor2):
        """Constructing a CompositeFactor with operator "%" raises ValueError."""
        expected_message = "Unsupported operator"
        with pytest.raises(ValueError, match=expected_message):
            CompositeFactor(cs_factor1, cs_factor2, "%")
class TestCompositeFactorMergeDataSpecs:
    """Checks on the data_specs produced when two factors are combined."""

    def test_merge_same_source_same_columns(self):
        """Specs with identical source and columns merge into a single spec."""
        five_day = DataSpec(
            source="daily",
            columns=["ts_code", "trade_date", "close"],
            lookback_days=5,
        )
        ten_day = DataSpec(
            source="daily",
            columns=["ts_code", "trade_date", "close"],
            lookback_days=10,
        )

        class Factor1(CrossSectionalFactor):
            name = "f1"
            data_specs = [five_day]

            def compute(self, data):
                return pl.Series([1.0])

        class Factor2(CrossSectionalFactor):
            name = "f2"
            data_specs = [ten_day]

            def compute(self, data):
                return pl.Series([2.0])

        merged_specs = (Factor1() + Factor2()).data_specs

        # Expect one merged DataSpec carrying the larger lookback (10).
        assert len(merged_specs) == 1
        only_spec = merged_specs[0]
        assert only_spec.lookback_days == 10
        assert only_spec.source == "daily"

    def test_merge_different_source(self):
        """Specs from different sources are kept separate."""
        daily_spec = DataSpec(
            source="daily",
            columns=["ts_code", "trade_date", "close"],
            lookback_days=5,
        )
        fundamental_spec = DataSpec(
            source="fundamental",
            columns=["ts_code", "trade_date", "pe"],
            lookback_days=1,
        )

        class Factor1(CrossSectionalFactor):
            name = "f1"
            data_specs = [daily_spec]

            def compute(self, data):
                return pl.Series([1.0])

        class Factor2(CrossSectionalFactor):
            name = "f2"
            data_specs = [fundamental_spec]

            def compute(self, data):
                return pl.Series([2.0])

        merged_specs = (Factor1() + Factor2()).data_specs

        # Expect both DataSpecs to survive, one per source.
        assert len(merged_specs) == 2
        seen_sources = {spec.source for spec in merged_specs}
        assert seen_sources == {"daily", "fundamental"}

    def test_merge_same_source_different_columns(self):
        """Specs with the same source but different columns are kept separate."""
        close_spec = DataSpec(
            source="daily",
            columns=["ts_code", "trade_date", "close"],
            lookback_days=5,
        )
        open_spec = DataSpec(
            source="daily",
            columns=["ts_code", "trade_date", "open"],
            lookback_days=3,
        )

        class Factor1(CrossSectionalFactor):
            name = "f1"
            data_specs = [close_spec]

            def compute(self, data):
                return pl.Series([1.0])

        class Factor2(CrossSectionalFactor):
            name = "f2"
            data_specs = [open_spec]

            def compute(self, data):
                return pl.Series([2.0])

        merged_specs = (Factor1() + Factor2()).data_specs

        # Expect two DataSpecs because the column lists differ.
        assert len(merged_specs) == 2

    def test_merge_lookback_max(self):
        """The merged spec takes the maximum lookback_days of all inputs."""
        lookback_5 = DataSpec(
            source="daily",
            columns=["ts_code", "trade_date", "close"],
            lookback_days=5,
        )
        lookback_20 = DataSpec(
            source="daily",
            columns=["ts_code", "trade_date", "close"],
            lookback_days=20,
        )
        lookback_10 = DataSpec(
            source="daily",
            columns=["ts_code", "trade_date", "close"],
            lookback_days=10,
        )

        class Factor1(CrossSectionalFactor):
            name = "f1"
            data_specs = [lookback_5]

            def compute(self, data):
                return pl.Series([1.0])

        class Factor2(CrossSectionalFactor):
            name = "f2"
            data_specs = [lookback_20, lookback_10]

            def compute(self, data):
                return pl.Series([2.0])

        merged_specs = (Factor1() + Factor2()).data_specs

        # Expect one merged DataSpec carrying the largest lookback (20).
        assert len(merged_specs) == 1
        assert merged_specs[0].lookback_days == 20
class TestCompositeFactorCompute:
    """Arithmetic checks for CompositeFactor.compute().

    Results are compared as plain lists via ``pytest.approx`` (absolute
    tolerance 1e-10) rather than through ``(result - expected).abs().max()``.
    The Series-based check can pass on wrong output: Polars broadcasts a
    length-1 Series in arithmetic, and ``max()`` skips null/NaN entries.
    The list comparison also verifies the length and rejects null/NaN.
    """

    def test_compute_addition(self, cs_factor1, cs_factor2, sample_factor_data):
        """Addition: [1.0, 2.0] + [3.0, 4.0] = [4.0, 6.0]."""
        combined = cs_factor1 + cs_factor2
        result = combined.compute(sample_factor_data)
        assert result.to_list() == pytest.approx([4.0, 6.0], abs=1e-10)

    def test_compute_subtraction(self, cs_factor1, cs_factor2, sample_factor_data):
        """Subtraction: [1.0, 2.0] - [3.0, 4.0] = [-2.0, -2.0]."""
        combined = cs_factor1 - cs_factor2
        result = combined.compute(sample_factor_data)
        assert result.to_list() == pytest.approx([-2.0, -2.0], abs=1e-10)

    def test_compute_multiplication(self, cs_factor1, cs_factor2, sample_factor_data):
        """Multiplication: [1.0, 2.0] * [3.0, 4.0] = [3.0, 8.0]."""
        combined = cs_factor1 * cs_factor2
        result = combined.compute(sample_factor_data)
        assert result.to_list() == pytest.approx([3.0, 8.0], abs=1e-10)

    def test_compute_division(self, cs_factor1, cs_factor2, sample_factor_data):
        """Division: [1.0, 2.0] / [3.0, 4.0] = [0.333..., 0.5]."""
        combined = cs_factor1 / cs_factor2
        result = combined.compute(sample_factor_data)
        assert result.to_list() == pytest.approx([1.0 / 3.0, 0.5], abs=1e-10)

    def test_compute_with_ts_factors(self, ts_factor1, ts_factor2, sample_factor_data):
        """Time-series addition: [10, 20, 30] + [1, 2, 3] = [11, 22, 33]."""
        combined = ts_factor1 + ts_factor2
        result = combined.compute(sample_factor_data)
        assert result.to_list() == pytest.approx([11.0, 22.0, 33.0], abs=1e-10)

    def test_chained_combination(self, cs_factor1, cs_factor2, sample_factor_data):
        """Chained combination: (f1 + f2) * f3."""

        # Third stub factor, defined locally for this test.
        class CSFactor3(CrossSectionalFactor):
            name = "cs_factor3"
            data_specs = [
                DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=5)
            ]

            def compute(self, data):
                return pl.Series([0.5, 1.0])

        f3 = CSFactor3()
        # f1 + f2 = [1.0, 2.0] + [3.0, 4.0] = [4.0, 6.0]
        # [4.0, 6.0] * [0.5, 1.0] = [2.0, 6.0]
        combined = (cs_factor1 + cs_factor2) * f3
        result = combined.compute(sample_factor_data)
        assert result.to_list() == pytest.approx([2.0, 6.0], abs=1e-10)
# ========== ScalarFactor 测试 ==========
class TestScalarFactor:
    """Tests for ScalarFactor construction, metadata and compute().

    Numeric results are compared as plain lists via ``pytest.approx``
    (absolute tolerance 1e-10). A ``(result - expected).abs().max()`` check
    could pass on wrong output, because Polars broadcasts a length-1 Series
    and ``max()`` skips null/NaN entries.
    """

    def test_scalar_multiplication_left(self, cs_factor1):
        """Left multiplication ``0.5 * factor`` builds a ScalarFactor."""
        scaled = 0.5 * cs_factor1
        assert isinstance(scaled, ScalarFactor)
        assert scaled.scalar == 0.5
        assert scaled.op == "*"
        assert scaled.factor == cs_factor1

    def test_scalar_multiplication_right(self, cs_factor1):
        """Right multiplication ``factor * 0.5`` builds a ScalarFactor."""
        scaled = cs_factor1 * 0.5
        assert isinstance(scaled, ScalarFactor)
        assert scaled.scalar == 0.5
        assert scaled.op == "*"

    def test_scalar_integer_multiplication(self, cs_factor1):
        """An integer scalar is accepted and compares equal to 2.0."""
        scaled = 2 * cs_factor1
        assert isinstance(scaled, ScalarFactor)
        assert scaled.scalar == 2.0

    def test_inherits_data_specs(self, cs_factor1):
        """The scaled factor exposes the base factor's data_specs."""
        scaled = 0.5 * cs_factor1
        assert scaled.data_specs == cs_factor1.data_specs

    def test_compute_multiplication(self, cs_factor1, sample_factor_data):
        """compute() scales values: [1.0, 2.0] * 0.5 = [0.5, 1.0]."""
        scaled = 0.5 * cs_factor1
        result = scaled.compute(sample_factor_data)
        assert result.to_list() == pytest.approx([0.5, 1.0], abs=1e-10)

    def test_compute_with_ts_factor(self, ts_factor1, sample_factor_data):
        """Time-series scaling: [10.0, 20.0, 30.0] * 0.1 = [1.0, 2.0, 3.0]."""
        scaled = 0.1 * ts_factor1
        result = scaled.compute(sample_factor_data)
        assert result.to_list() == pytest.approx([1.0, 2.0, 3.0], abs=1e-10)

    def test_factor_type_preserved(self, cs_factor1, ts_factor1):
        """factor_type of the base factor is preserved after scaling."""
        scaled_cs = 0.5 * cs_factor1
        scaled_ts = 0.5 * ts_factor1
        assert scaled_cs.factor_type == "cross_sectional"
        assert scaled_ts.factor_type == "time_series"

    def test_scalar_name_format(self, cs_factor1):
        """The ScalarFactor name has the form "(scalar_*_factor_name)"."""
        scaled = 0.5 * cs_factor1
        assert scaled.name == "(0.5_*_cs_factor1)"
# ========== 组合和标量混合测试 ==========
class TestMixedOperations:
    """Tests mixing composite (factor-factor) and scalar (factor-number) operations.

    Numeric results are compared as plain lists via ``pytest.approx``
    (absolute tolerance 1e-10). A ``(result - expected).abs().max()`` check
    could pass on wrong output, because Polars broadcasts a length-1 Series
    and ``max()`` skips null/NaN entries.
    """

    def test_scalar_then_combine(self, cs_factor1, cs_factor2, sample_factor_data):
        """Scale first, then combine: 0.5 * f1 + 0.3 * f2."""
        combined = 0.5 * cs_factor1 + 0.3 * cs_factor2
        result = combined.compute(sample_factor_data)
        # 0.5 * [1.0, 2.0] + 0.3 * [3.0, 4.0]
        # = [0.5, 1.0] + [0.9, 1.2]
        # = [1.4, 2.2]
        assert result.to_list() == pytest.approx([1.4, 2.2], abs=1e-10)

    def test_complex_formula(self, cs_factor1, cs_factor2, sample_factor_data):
        """Compound formula: (f1 + f2) * 0.5 - f1 * 0.2."""
        formula = (cs_factor1 + cs_factor2) * 0.5 - cs_factor1 * 0.2
        result = formula.compute(sample_factor_data)
        # (f1 + f2) = [4.0, 6.0]
        # (f1 + f2) * 0.5 = [2.0, 3.0]
        # f1 * 0.2 = [0.2, 0.4]
        # [2.0, 3.0] - [0.2, 0.4] = [1.8, 2.6]
        assert result.to_list() == pytest.approx([1.8, 2.6], abs=1e-10)