Files
ProStock/tests/factors/test_base.py
liaozhaorun 0a16129548 feat(factors): 添加因子计算框架
- 新增因子基类 (BaseFactor, CrossSectionalFactor, TimeSeriesFactor)
- 新增数据规格和上下文类 (DataSpec, FactorContext, FactorData)
- 新增数据加载器 (DataLoader) 和执行引擎 (FactorEngine)
- 新增组合因子支持 (CompositeFactor, ScalarFactor)
- 添加因子模块完整测试用例
- 添加 Git 提交规范文档
2026-02-22 14:41:32 +08:00

407 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""测试因子基类 - BaseFactor、CrossSectionalFactor、TimeSeriesFactor
测试需求(来自 factor_implementation_plan.md):
- BaseFactor:
- 测试有效子类创建通过验证
- 测试缺少 `name` 时抛出 ValueError
- 测试 `name` 为空字符串时抛出 ValueError
- 测试缺少 `factor_type` 时抛出 ValueError
- 测试无效的 `factor_type`(非 cs/ts)时抛出 ValueError
- 测试缺少 `data_specs` 时抛出 ValueError
- 测试 `data_specs` 为空列表时抛出 ValueError
- 测试 `compute()` 抽象方法强制子类实现
- 测试参数化初始化 `params` 正确存储
- 测试 `_validate_params()` 被调用
- CrossSectionalFactor:
- 测试 `factor_type` 自动设置为 "cross_sectional"
- 测试子类必须实现 `compute()`
- 测试 `compute()` 返回类型为 pl.Series
- TimeSeriesFactor:
- 测试 `factor_type` 自动设置为 "time_series"
- 测试子类必须实现 `compute()`
- 测试 `compute()` 返回类型为 pl.Series
"""
import pytest
import polars as pl
from src.factors import DataSpec, FactorContext, FactorData
from src.factors.base import BaseFactor, CrossSectionalFactor, TimeSeriesFactor
# ========== 测试数据准备 ==========
@pytest.fixture
def sample_dataspec():
    """Build a ``DataSpec`` over the ``daily`` source with a 5-day lookback.

    The spec requests the ``ts_code``, ``trade_date`` and ``close`` columns.
    """
    wanted_columns = ["ts_code", "trade_date", "close"]
    return DataSpec(source="daily", columns=wanted_columns, lookback_days=5)
@pytest.fixture
def sample_factor_data():
    """Build a ``FactorData`` holding two stocks over two trading days.

    Rows are ordered by date, then by stock. The context's ``current_date``
    is the later day, ``"20240102"``.
    """
    frame = pl.DataFrame(
        {
            "ts_code": ["000001.SZ", "000002.SZ"] * 2,
            "trade_date": ["20240101"] * 2 + ["20240102"] * 2,
            "close": [10.0, 20.0, 11.0, 21.0],
        }
    )
    return FactorData(frame, FactorContext(current_date="20240102"))
# ========== BaseFactor 测试 ==========
class TestBaseFactorValidation:
    """Class-definition-time validation of ``BaseFactor`` subclasses.

    The failure cases expect ``ValueError`` while the ``class`` statement
    itself executes, before any instance is created.
    """

    def test_valid_cross_sectional_subclass(self, sample_dataspec):
        """A complete cross-sectional subclass is accepted and instantiable."""

        class GoodCrossSectional(CrossSectionalFactor):
            name = "valid_cs"
            data_specs = [sample_dataspec]

            def compute(self, data):
                return pl.Series([1.0])

        instance = GoodCrossSectional()
        assert instance.name == "valid_cs"
        assert instance.factor_type == "cross_sectional"

    def test_valid_time_series_subclass(self, sample_dataspec):
        """A complete time-series subclass is accepted and instantiable."""

        class GoodTimeSeries(TimeSeriesFactor):
            name = "valid_ts"
            data_specs = [sample_dataspec]

            def compute(self, data):
                return pl.Series([1.0])

        instance = GoodTimeSeries()
        assert instance.name == "valid_ts"
        assert instance.factor_type == "time_series"

    def test_missing_name(self, sample_dataspec):
        """Leaving ``name`` undefined raises ValueError."""
        with pytest.raises(ValueError, match="must define 'name'"):

            class NamelessFactor(CrossSectionalFactor):
                # `name` is intentionally not defined.
                data_specs = [sample_dataspec]

                def compute(self, data):
                    return pl.Series([1.0])

    def test_empty_name(self, sample_dataspec):
        """An empty-string ``name`` raises ValueError."""
        with pytest.raises(ValueError, match="must define 'name'"):

            class EmptyNameFactor(CrossSectionalFactor):
                name = ""  # empty string on purpose
                data_specs = [sample_dataspec]

                def compute(self, data):
                    return pl.Series([1.0])

    def test_missing_factor_type(self, sample_dataspec):
        """Subclassing ``BaseFactor`` directly without ``factor_type`` raises ValueError."""
        with pytest.raises(ValueError, match="must define 'factor_type'"):

            class UntypedFactor(BaseFactor):
                name = "bad_factor"
                # `factor_type` is intentionally not defined.
                data_specs = [sample_dataspec]

                def compute(self, data):
                    return pl.Series([1.0])

    def test_invalid_factor_type(self, sample_dataspec):
        """A ``factor_type`` other than cross_sectional/time_series raises ValueError."""
        expected_message = "must be 'cross_sectional' or 'time_series'"
        with pytest.raises(ValueError, match=expected_message):

            class WrongTypeFactor(BaseFactor):
                name = "bad_factor"
                factor_type = "invalid_type"
                data_specs = [sample_dataspec]

                def compute(self, data):
                    return pl.Series([1.0])

    def test_missing_data_specs(self):
        """Leaving ``data_specs`` undefined raises ValueError."""
        with pytest.raises(ValueError, match="must define 'data_specs'"):

            class SpeclessFactor(BaseFactor):
                name = "bad_factor"
                factor_type = "cross_sectional"
                # `data_specs` is intentionally not defined.

                def compute(self, data):
                    return pl.Series([1.0])

    def test_empty_data_specs(self):
        """An empty ``data_specs`` list raises ValueError."""
        with pytest.raises(ValueError, match="cannot be empty"):

            class EmptySpecsFactor(CrossSectionalFactor):
                name = "bad_factor"
                data_specs = []  # empty list on purpose

                def compute(self, data):
                    return pl.Series([1.0])
class TestBaseFactorCompute:
    """``compute()`` is abstract: subclasses that omit it cannot be instantiated."""

    def test_compute_must_be_implemented_cs(self, sample_dataspec):
        """A cross-sectional subclass lacking ``compute()`` fails to instantiate."""
        with pytest.raises(TypeError, match="Can't instantiate abstract class"):

            class NoComputeCS(CrossSectionalFactor):
                name = "bad_cs"
                data_specs = [sample_dataspec]
                # compute() deliberately left unimplemented.

            NoComputeCS()

    def test_compute_must_be_implemented_ts(self, sample_dataspec):
        """A time-series subclass lacking ``compute()`` fails to instantiate."""
        with pytest.raises(TypeError, match="Can't instantiate abstract class"):

            class NoComputeTS(TimeSeriesFactor):
                name = "bad_ts"
                data_specs = [sample_dataspec]
                # compute() deliberately left unimplemented.

            NoComputeTS()
class TestBaseFactorParams:
    """Tests for keyword-parameter initialisation and ``_validate_params()``."""

    def test_params_stored_correctly(self, sample_dataspec):
        """Constructor kwargs, including subclass defaults, end up in ``params``."""

        class ParamFactor(CrossSectionalFactor):
            name = "param_factor"
            data_specs = [sample_dataspec]

            def __init__(self, period: int = 20, weight: float = 1.0):
                super().__init__(period=period, weight=weight)

            def compute(self, data):
                return pl.Series([1.0])

        factor = ParamFactor(period=10, weight=0.5)
        assert factor.params["period"] == 10
        assert factor.params["weight"] == 0.5

        # Defaults declared on the subclass __init__ are forwarded to
        # super().__init__, so they must be visible in params as well.
        default_factor = ParamFactor()
        assert default_factor.params["period"] == 20
        assert default_factor.params["weight"] == 1.0

        # Creating a second instance must not overwrite the first one's
        # params (guards against params being shared between instances).
        assert factor.params["period"] == 10
        assert factor.params["weight"] == 0.5

    def test_validate_params_called(self, sample_dataspec):
        """``_validate_params()`` runs exactly once during construction."""
        validated = []

        class ValidatedFactor(CrossSectionalFactor):
            name = "validated_factor"
            data_specs = [sample_dataspec]

            def __init__(self, period: int = 20):
                super().__init__(period=period)

            def _validate_params(self):
                validated.append(True)
                if self.params.get("period", 0) <= 0:
                    raise ValueError("period must be positive")

            def compute(self, data):
                return pl.Series([1.0])

        # Constructing the instance should invoke _validate_params.
        factor = ValidatedFactor(period=10)
        assert len(validated) == 1
        assert factor.params["period"] == 10

    def test_validate_params_raises(self, sample_dataspec):
        """Errors raised by ``_validate_params()`` propagate out of the constructor."""

        class BadParamFactor(CrossSectionalFactor):
            name = "bad_param_factor"
            data_specs = [sample_dataspec]

            def __init__(self, period: int = 20):
                super().__init__(period=period)

            def _validate_params(self):
                if self.params.get("period", 0) <= 0:
                    raise ValueError("period must be positive")

            def compute(self, data):
                return pl.Series([1.0])

        with pytest.raises(ValueError, match="period must be positive"):
            BadParamFactor(period=-5)
        # Boundary: zero is not positive, so the `<= 0` check must reject it.
        with pytest.raises(ValueError, match="period must be positive"):
            BadParamFactor(period=0)
        # Control: the smallest valid period is accepted and stored.
        assert BadParamFactor(period=1).params["period"] == 1
class TestBaseFactorRepr:
    """Tests for ``BaseFactor.__repr__``."""

    def test_repr(self, sample_dataspec):
        """``repr()`` mentions the class name, the factor name and the factor type."""

        class ReprFactor(CrossSectionalFactor):
            name = "test_factor"
            data_specs = [sample_dataspec]

            def compute(self, data):
                return pl.Series([1.0])

        text = repr(ReprFactor())
        for fragment in ("ReprFactor", "test_factor", "cross_sectional"):
            assert fragment in text
# ========== CrossSectionalFactor 测试 ==========
class TestCrossSectionalFactor:
    """Tests for ``CrossSectionalFactor``."""

    def test_factor_type_auto_set(self, sample_dataspec):
        """``factor_type`` becomes ``"cross_sectional"`` without the subclass setting it."""

        class CSFactor(CrossSectionalFactor):
            name = "cs_factor"
            data_specs = [sample_dataspec]

            def compute(self, data):
                return pl.Series([1.0])

        factor = CSFactor()
        assert factor.factor_type == "cross_sectional"

    def test_compute_returns_series(self, sample_factor_data, sample_dataspec):
        """``compute()`` called with a ``FactorData`` returns a ``pl.Series``."""
        # NOTE(review): this compute() ignores its input, so the assertion only
        # confirms the subclass's return value passes through unchanged; the
        # data-driven path is covered by test_compute_with_cross_section.

        class CSFactor(CrossSectionalFactor):
            name = "cs_factor"
            data_specs = [sample_dataspec]

            def compute(self, data):
                # Return a simple fixed Series.
                return pl.Series([1.0, 2.0])

        factor = CSFactor()
        result = factor.compute(sample_factor_data)
        assert isinstance(result, pl.Series)

    def test_compute_with_cross_section(self, sample_dataspec):
        """``compute()`` can rank values taken from ``get_cross_section()``."""
        df = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ"],
                "trade_date": ["20240101", "20240101"],
                "close": [10.0, 20.0],
            }
        )
        context = FactorContext(current_date="20240101")
        data = FactorData(df, context)

        class RankFactor(CrossSectionalFactor):
            name = "rank_factor"
            data_specs = [sample_dataspec]

            def compute(self, data):
                cs = data.get_cross_section()
                return cs["close"].rank()

        factor = RankFactor()
        result = factor.compute(data)
        assert isinstance(result, pl.Series)
        assert len(result) == 2
        # close 10.0 < 20.0, so the default ("average") rank is 1.0 then 2.0.
        # Assumes get_cross_section() keeps the input row order, which is
        # already sorted by ts_code here.
        assert result.to_list() == [1.0, 2.0]
# ========== TimeSeriesFactor 测试 ==========
class TestTimeSeriesFactor:
    """Tests for ``TimeSeriesFactor``."""

    def test_factor_type_auto_set(self, sample_dataspec):
        """``factor_type`` becomes ``"time_series"`` without the subclass setting it."""

        class TSFactor(TimeSeriesFactor):
            name = "ts_factor"
            data_specs = [sample_dataspec]

            def compute(self, data):
                return pl.Series([1.0])

        factor = TSFactor()
        assert factor.factor_type == "time_series"

    def test_compute_returns_series(self, sample_factor_data, sample_dataspec):
        """``compute()`` returns a ``pl.Series`` derived from ``get_column()``."""

        class TSFactor(TimeSeriesFactor):
            name = "ts_factor"
            data_specs = [sample_dataspec]

            def compute(self, data):
                return data.get_column("close") * 2

        factor = TSFactor()
        result = factor.compute(sample_factor_data)
        assert isinstance(result, pl.Series)
        assert len(result) == 4  # one value per input row
        # Each close (10, 20, 11, 21) doubled. Sorted so the check does not
        # depend on the row order get_column() returns.
        assert sorted(result.to_list()) == [20.0, 22.0, 40.0, 42.0]

    def test_compute_with_rolling(self, sample_dataspec):
        """``compute()`` can apply a rolling mean over a single stock's history."""
        df = pl.DataFrame(
            {
                "ts_code": ["000001.SZ"] * 5,
                "trade_date": [
                    "20240101",
                    "20240102",
                    "20240103",
                    "20240104",
                    "20240105",
                ],
                "close": [10.0, 11.0, 12.0, 13.0, 14.0],
            }
        )
        context = FactorContext(current_stock="000001.SZ")
        data = FactorData(df, context)

        class MAFactor(TimeSeriesFactor):
            name = "ma_factor"
            data_specs = [sample_dataspec]

            def __init__(self, period: int = 3):
                super().__init__(period=period)

            def compute(self, data):
                return data.get_column("close").rolling_mean(self.params["period"])

        factor = MAFactor(period=3)
        result = factor.compute(data)
        assert isinstance(result, pl.Series)
        assert len(result) == 5
        values = result.to_list()
        # With the default minimum window equal to period=3, the first two
        # entries do not have a full window and are null.
        assert values[0] is None
        assert values[1] is None
        # Means of (10, 11, 12), (11, 12, 13) and (12, 13, 14).
        assert values[2:] == pytest.approx([11.0, 12.0, 13.0])