feat(factors): 添加因子计算框架
- 新增因子基类 (BaseFactor, CrossSectionalFactor, TimeSeriesFactor) - 新增数据规格和上下文类 (DataSpec, FactorContext, FactorData) - 新增数据加载器 (DataLoader) 和执行引擎 (FactorEngine) - 新增组合因子支持 (CompositeFactor, ScalarFactor) - 添加因子模块完整测试用例 - 添加 Git 提交规范文档
This commit is contained in:
417
tests/factors/test_composite.py
Normal file
417
tests/factors/test_composite.py
Normal file
@@ -0,0 +1,417 @@
|
||||
"""测试组合因子 - CompositeFactor、ScalarFactor
|
||||
|
||||
测试需求(来自 factor_implementation_plan.md):
|
||||
- CompositeFactor:
|
||||
- 测试同类型因子组合成功(cs + cs)
|
||||
- 测试同类型因子组合成功(ts + ts)
|
||||
- 测试不同类型因子组合抛出 ValueError(cs + ts)
|
||||
- 测试无效运算符抛出 ValueError
|
||||
- 测试 `_merge_data_specs()` 正确合并(相同 source)
|
||||
- 测试 `_merge_data_specs()` 正确合并(不同 source)
|
||||
- 测试 `_merge_data_specs()` lookback 取最大值
|
||||
- 测试 `compute()` 执行正确的数学运算
|
||||
|
||||
- ScalarFactor:
|
||||
- 测试标量乘法 `0.5 * factor`
|
||||
- 测试标量乘法 `factor * 0.5`
|
||||
- 测试标量加法(如支持)
|
||||
- 测试继承基础因子的 data_specs
|
||||
- 测试 `compute()` 返回正确缩放后的值
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import polars as pl
|
||||
|
||||
from src.factors import DataSpec, FactorContext, FactorData
|
||||
from src.factors.base import CrossSectionalFactor, TimeSeriesFactor
|
||||
from src.factors.composite import CompositeFactor, ScalarFactor
|
||||
|
||||
|
||||
# ========== 测试数据准备 ==========
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_dataspec():
|
||||
"""创建一个示例 DataSpec"""
|
||||
return DataSpec(
|
||||
source="daily", columns=["ts_code", "trade_date", "close"], lookback_days=5
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cs_factor1(sample_dataspec):
|
||||
"""截面因子 1"""
|
||||
|
||||
class CSFactor1(CrossSectionalFactor):
|
||||
name = "cs_factor1"
|
||||
data_specs = [sample_dataspec]
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([1.0, 2.0])
|
||||
|
||||
return CSFactor1()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cs_factor2(sample_dataspec):
|
||||
"""截面因子 2"""
|
||||
|
||||
class CSFactor2(CrossSectionalFactor):
|
||||
name = "cs_factor2"
|
||||
data_specs = [sample_dataspec]
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([3.0, 4.0])
|
||||
|
||||
return CSFactor2()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ts_factor1(sample_dataspec):
|
||||
"""时序因子 1"""
|
||||
|
||||
class TSFactor1(TimeSeriesFactor):
|
||||
name = "ts_factor1"
|
||||
data_specs = [sample_dataspec]
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([10.0, 20.0, 30.0])
|
||||
|
||||
return TSFactor1()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ts_factor2(sample_dataspec):
|
||||
"""时序因子 2"""
|
||||
|
||||
class TSFactor2(TimeSeriesFactor):
|
||||
name = "ts_factor2"
|
||||
data_specs = [sample_dataspec]
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([1.0, 2.0, 3.0])
|
||||
|
||||
return TSFactor2()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_factor_data():
|
||||
"""创建一个示例 FactorData"""
|
||||
df = pl.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ", "000002.SZ", "000001.SZ", "000002.SZ"],
|
||||
"trade_date": ["20240101", "20240101", "20240102", "20240102"],
|
||||
"close": [10.0, 20.0, 11.0, 21.0],
|
||||
}
|
||||
)
|
||||
context = FactorContext(current_date="20240102")
|
||||
return FactorData(df, context)
|
||||
|
||||
|
||||
# ========== CompositeFactor 测试 ==========
|
||||
|
||||
|
||||
class TestCompositeFactorTypeValidation:
|
||||
"""测试类型验证"""
|
||||
|
||||
def test_same_type_combination_cs(self, cs_factor1, cs_factor2):
|
||||
"""测试同类型截面因子组合成功"""
|
||||
combined = cs_factor1 + cs_factor2
|
||||
assert isinstance(combined, CompositeFactor)
|
||||
assert combined.factor_type == "cross_sectional"
|
||||
assert combined.name == "(cs_factor1_+_cs_factor2)"
|
||||
|
||||
def test_same_type_combination_ts(self, ts_factor1, ts_factor2):
|
||||
"""测试同类型时序因子组合成功"""
|
||||
combined = ts_factor1 - ts_factor2
|
||||
assert isinstance(combined, CompositeFactor)
|
||||
assert combined.factor_type == "time_series"
|
||||
assert combined.name == "(ts_factor1_-_ts_factor2)"
|
||||
|
||||
def test_different_type_raises(self, cs_factor1, ts_factor1):
|
||||
"""测试不同类型因子组合抛出 ValueError"""
|
||||
with pytest.raises(
|
||||
ValueError, match="Cannot combine factors of different types"
|
||||
):
|
||||
cs_factor1 + ts_factor1
|
||||
|
||||
def test_invalid_operator_raises(self, cs_factor1, cs_factor2):
|
||||
"""测试无效运算符抛出 ValueError"""
|
||||
with pytest.raises(ValueError, match="Unsupported operator"):
|
||||
CompositeFactor(cs_factor1, cs_factor2, "%")
|
||||
|
||||
|
||||
class TestCompositeFactorMergeDataSpecs:
|
||||
"""测试 _merge_data_specs"""
|
||||
|
||||
def test_merge_same_source_same_columns(self):
|
||||
"""测试相同 source 和 columns 的 DataSpec 合并"""
|
||||
spec1 = DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=5)
|
||||
spec2 = DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=10)
|
||||
|
||||
class Factor1(CrossSectionalFactor):
|
||||
name = "f1"
|
||||
data_specs = [spec1]
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([1.0])
|
||||
|
||||
class Factor2(CrossSectionalFactor):
|
||||
name = "f2"
|
||||
data_specs = [spec2]
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([2.0])
|
||||
|
||||
combined = Factor1() + Factor2()
|
||||
|
||||
# 应该合并成一个 DataSpec,lookback_days 取最大值 10
|
||||
assert len(combined.data_specs) == 1
|
||||
assert combined.data_specs[0].lookback_days == 10
|
||||
assert combined.data_specs[0].source == "daily"
|
||||
|
||||
def test_merge_different_source(self):
|
||||
"""测试不同 source 的 DataSpec 不合并"""
|
||||
spec1 = DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=5)
|
||||
spec2 = DataSpec(
|
||||
"fundamental", ["ts_code", "trade_date", "pe"], lookback_days=1
|
||||
)
|
||||
|
||||
class Factor1(CrossSectionalFactor):
|
||||
name = "f1"
|
||||
data_specs = [spec1]
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([1.0])
|
||||
|
||||
class Factor2(CrossSectionalFactor):
|
||||
name = "f2"
|
||||
data_specs = [spec2]
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([2.0])
|
||||
|
||||
combined = Factor1() + Factor2()
|
||||
|
||||
# 应该有两个 DataSpec
|
||||
assert len(combined.data_specs) == 2
|
||||
sources = {s.source for s in combined.data_specs}
|
||||
assert sources == {"daily", "fundamental"}
|
||||
|
||||
def test_merge_same_source_different_columns(self):
|
||||
"""测试相同 source 但不同 columns 的 DataSpec 不合并"""
|
||||
spec1 = DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=5)
|
||||
spec2 = DataSpec("daily", ["ts_code", "trade_date", "open"], lookback_days=3)
|
||||
|
||||
class Factor1(CrossSectionalFactor):
|
||||
name = "f1"
|
||||
data_specs = [spec1]
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([1.0])
|
||||
|
||||
class Factor2(CrossSectionalFactor):
|
||||
name = "f2"
|
||||
data_specs = [spec2]
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([2.0])
|
||||
|
||||
combined = Factor1() + Factor2()
|
||||
|
||||
# 应该有两个 DataSpec(因为 columns 不同)
|
||||
assert len(combined.data_specs) == 2
|
||||
|
||||
def test_merge_lookback_max(self):
|
||||
"""测试 lookback_days 取最大值"""
|
||||
spec1 = DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=5)
|
||||
spec2 = DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=20)
|
||||
spec3 = DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=10)
|
||||
|
||||
class Factor1(CrossSectionalFactor):
|
||||
name = "f1"
|
||||
data_specs = [spec1]
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([1.0])
|
||||
|
||||
class Factor2(CrossSectionalFactor):
|
||||
name = "f2"
|
||||
data_specs = [spec2, spec3]
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([2.0])
|
||||
|
||||
combined = Factor1() + Factor2()
|
||||
|
||||
# 应该合并成一个 DataSpec,lookback_days 取最大值 20
|
||||
assert len(combined.data_specs) == 1
|
||||
assert combined.data_specs[0].lookback_days == 20
|
||||
|
||||
|
||||
class TestCompositeFactorCompute:
|
||||
"""测试 compute 运算"""
|
||||
|
||||
def test_compute_addition(self, cs_factor1, cs_factor2, sample_factor_data):
|
||||
"""测试加法运算"""
|
||||
combined = cs_factor1 + cs_factor2
|
||||
result = combined.compute(sample_factor_data)
|
||||
|
||||
# [1.0, 2.0] + [3.0, 4.0] = [4.0, 6.0]
|
||||
expected = pl.Series([4.0, 6.0])
|
||||
assert (result - expected).abs().max() < 1e-10
|
||||
|
||||
def test_compute_subtraction(self, cs_factor1, cs_factor2, sample_factor_data):
|
||||
"""测试减法运算"""
|
||||
combined = cs_factor1 - cs_factor2
|
||||
result = combined.compute(sample_factor_data)
|
||||
|
||||
# [1.0, 2.0] - [3.0, 4.0] = [-2.0, -2.0]
|
||||
expected = pl.Series([-2.0, -2.0])
|
||||
assert (result - expected).abs().max() < 1e-10
|
||||
|
||||
def test_compute_multiplication(self, cs_factor1, cs_factor2, sample_factor_data):
|
||||
"""测试乘法运算"""
|
||||
combined = cs_factor1 * cs_factor2
|
||||
result = combined.compute(sample_factor_data)
|
||||
|
||||
# [1.0, 2.0] * [3.0, 4.0] = [3.0, 8.0]
|
||||
expected = pl.Series([3.0, 8.0])
|
||||
assert (result - expected).abs().max() < 1e-10
|
||||
|
||||
def test_compute_division(self, cs_factor1, cs_factor2, sample_factor_data):
|
||||
"""测试除法运算"""
|
||||
combined = cs_factor1 / cs_factor2
|
||||
result = combined.compute(sample_factor_data)
|
||||
|
||||
# [1.0, 2.0] / [3.0, 4.0] = [0.333..., 0.5]
|
||||
expected = pl.Series([1.0 / 3.0, 0.5])
|
||||
assert (result - expected).abs().max() < 1e-10
|
||||
|
||||
def test_compute_with_ts_factors(self, ts_factor1, ts_factor2, sample_factor_data):
|
||||
"""测试时序因子的组合运算"""
|
||||
combined = ts_factor1 + ts_factor2
|
||||
result = combined.compute(sample_factor_data)
|
||||
|
||||
# [10.0, 20.0, 30.0] + [1.0, 2.0, 3.0] = [11.0, 22.0, 33.0]
|
||||
expected = pl.Series([11.0, 22.0, 33.0])
|
||||
assert (result - expected).abs().max() < 1e-10
|
||||
|
||||
def test_chained_combination(self, cs_factor1, cs_factor2, sample_factor_data):
|
||||
"""测试链式组合 (f1 + f2) * f1"""
|
||||
|
||||
# 创建第三个因子
|
||||
class CSFactor3(CrossSectionalFactor):
|
||||
name = "cs_factor3"
|
||||
data_specs = [
|
||||
DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=5)
|
||||
]
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([0.5, 1.0])
|
||||
|
||||
f3 = CSFactor3()
|
||||
|
||||
# (f1 + f2) * f3
|
||||
# f1 + f2 = [1.0, 2.0] + [3.0, 4.0] = [4.0, 6.0]
|
||||
# [4.0, 6.0] * [0.5, 1.0] = [2.0, 6.0]
|
||||
combined = (cs_factor1 + cs_factor2) * f3
|
||||
result = combined.compute(sample_factor_data)
|
||||
|
||||
expected = pl.Series([2.0, 6.0])
|
||||
assert (result - expected).abs().max() < 1e-10
|
||||
|
||||
|
||||
# ========== ScalarFactor 测试 ==========
|
||||
|
||||
|
||||
class TestScalarFactor:
|
||||
"""测试 ScalarFactor"""
|
||||
|
||||
def test_scalar_multiplication_left(self, cs_factor1):
|
||||
"""测试标量乘法 `0.5 * factor`(左乘)"""
|
||||
scaled = 0.5 * cs_factor1
|
||||
assert isinstance(scaled, ScalarFactor)
|
||||
assert scaled.scalar == 0.5
|
||||
assert scaled.op == "*"
|
||||
assert scaled.factor == cs_factor1
|
||||
|
||||
def test_scalar_multiplication_right(self, cs_factor1):
|
||||
"""测试标量乘法 `factor * 0.5`(右乘)"""
|
||||
scaled = cs_factor1 * 0.5
|
||||
assert isinstance(scaled, ScalarFactor)
|
||||
assert scaled.scalar == 0.5
|
||||
assert scaled.op == "*"
|
||||
|
||||
def test_scalar_integer_multiplication(self, cs_factor1):
|
||||
"""测试整数标量乘法"""
|
||||
scaled = 2 * cs_factor1
|
||||
assert isinstance(scaled, ScalarFactor)
|
||||
assert scaled.scalar == 2.0
|
||||
|
||||
def test_inherits_data_specs(self, cs_factor1):
|
||||
"""测试继承基础因子的 data_specs"""
|
||||
scaled = 0.5 * cs_factor1
|
||||
assert scaled.data_specs == cs_factor1.data_specs
|
||||
|
||||
def test_compute_multiplication(self, cs_factor1, sample_factor_data):
|
||||
"""测试标量乘法 compute 结果"""
|
||||
scaled = 0.5 * cs_factor1
|
||||
result = scaled.compute(sample_factor_data)
|
||||
|
||||
# [1.0, 2.0] * 0.5 = [0.5, 1.0]
|
||||
expected = pl.Series([0.5, 1.0])
|
||||
assert (result - expected).abs().max() < 1e-10
|
||||
|
||||
def test_compute_with_ts_factor(self, ts_factor1, sample_factor_data):
|
||||
"""测试时序因子的标量乘法"""
|
||||
scaled = 0.1 * ts_factor1
|
||||
result = scaled.compute(sample_factor_data)
|
||||
|
||||
# [10.0, 20.0, 30.0] * 0.1 = [1.0, 2.0, 3.0]
|
||||
expected = pl.Series([1.0, 2.0, 3.0])
|
||||
assert (result - expected).abs().max() < 1e-10
|
||||
|
||||
def test_factor_type_preserved(self, cs_factor1, ts_factor1):
|
||||
"""测试 factor_type 被正确保留"""
|
||||
scaled_cs = 0.5 * cs_factor1
|
||||
scaled_ts = 0.5 * ts_factor1
|
||||
|
||||
assert scaled_cs.factor_type == "cross_sectional"
|
||||
assert scaled_ts.factor_type == "time_series"
|
||||
|
||||
def test_scalar_name_format(self, cs_factor1):
|
||||
"""测试 ScalarFactor 的 name 格式"""
|
||||
scaled = 0.5 * cs_factor1
|
||||
assert scaled.name == "(0.5_*_cs_factor1)"
|
||||
|
||||
|
||||
# ========== 组合和标量混合测试 ==========
|
||||
|
||||
|
||||
class TestMixedOperations:
|
||||
"""测试组合因子和标量因子的混合运算"""
|
||||
|
||||
def test_scalar_then_combine(self, cs_factor1, cs_factor2, sample_factor_data):
|
||||
"""测试先标量缩放再组合"""
|
||||
# 0.5 * f1 + 0.3 * f2
|
||||
combined = 0.5 * cs_factor1 + 0.3 * cs_factor2
|
||||
result = combined.compute(sample_factor_data)
|
||||
|
||||
# 0.5 * [1.0, 2.0] + 0.3 * [3.0, 4.0]
|
||||
# = [0.5, 1.0] + [0.9, 1.2]
|
||||
# = [1.4, 2.2]
|
||||
expected = pl.Series([1.4, 2.2])
|
||||
assert (result - expected).abs().max() < 1e-10
|
||||
|
||||
def test_complex_formula(self, cs_factor1, cs_factor2, sample_factor_data):
|
||||
"""测试复杂公式: (f1 + f2) * 0.5 - f1 * 0.2"""
|
||||
formula = (cs_factor1 + cs_factor2) * 0.5 - cs_factor1 * 0.2
|
||||
result = formula.compute(sample_factor_data)
|
||||
|
||||
# (f1 + f2) = [4.0, 6.0]
|
||||
# (f1 + f2) * 0.5 = [2.0, 3.0]
|
||||
# f1 * 0.2 = [0.2, 0.4]
|
||||
# [2.0, 3.0] - [0.2, 0.4] = [1.8, 2.6]
|
||||
expected = pl.Series([1.8, 2.6])
|
||||
assert (result - expected).abs().max() < 1e-10
|
||||
Reference in New Issue
Block a user