Files
ProStock/tests/factors/test_engine.py
liaozhaorun 0a16129548 feat(factors): 添加因子计算框架
- 新增因子基类 (BaseFactor, CrossSectionalFactor, TimeSeriesFactor)
- 新增数据规格和上下文类 (DataSpec, FactorContext, FactorData)
- 新增数据加载器 (DataLoader) 和执行引擎 (FactorEngine)
- 新增组合因子支持 (CompositeFactor, ScalarFactor)
- 添加因子模块完整测试用例
- 添加 Git 提交规范文档
2026-02-22 14:41:32 +08:00

267 lines
7.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""测试执行引擎 - FactorEngine
测试需求(来自 factor_implementation_plan.md):
- 测试 `compute()` 正确分发给截面计算
- 测试 `compute()` 正确分发给时序计算
- 测试无效 factor_type 时抛出 ValueError
截面计算测试(防泄露验证):
- 测试数据裁剪正确(传入 [T-lookback+1, T])
- 测试不包含未来日期 T+1 的数据
- 测试每个日期独立计算
- 测试结果包含所有日期和所有股票
- 测试结果 DataFrame 格式正确
- 测试多个 DataSpec 时 lookback 取最大值
时序计算测试(防泄露验证):
- 测试每只股票只看到自己的数据
- 测试不包含其他股票的数据
- 测试传入的是完整时间序列(向量化计算)
- 测试结果包含所有股票和所有日期
- 测试结果 DataFrame 格式正确
- 测试股票不在数据中时跳过(或填充 null)
"""
import pytest
import polars as pl
from src.factors import (
DataSpec,
FactorContext,
FactorData,
DataLoader,
FactorEngine,
CrossSectionalFactor,
TimeSeriesFactor,
)
class SimpleCrossSectionalFactor(CrossSectionalFactor):
    """Minimal cross-sectional factor used by the tests: rank of the close price."""

    name = "close_rank"
    # One "daily" spec with a single day of lookback.
    data_specs = [
        DataSpec(
            "daily",
            ["ts_code", "trade_date", "close"],
            lookback_days=1,
        ),
    ]

    def compute(self, data: FactorData) -> pl.Series:
        """Return the rank of the ``close`` column of the current cross-section.

        Args:
            data: Factor data handed in by the engine; only
                ``get_cross_section()`` is used here.

        Returns:
            The result of ``rank()`` applied to the ``close`` column.
        """
        return data.get_cross_section()["close"].rank()
class SimpleTimeSeriesFactor(TimeSeriesFactor):
    """Minimal time-series factor used by the tests: rolling mean of close.

    The window length is the ``period`` parameter (3 by default).
    """

    name = "ma3"
    # One "daily" spec with five days of lookback.
    data_specs = [
        DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=5),
    ]

    def __init__(self, period: int = 3):
        """Forward ``period`` to the base class as a keyword parameter.

        Args:
            period: Rolling window length used by ``compute``.
        """
        super().__init__(period=period)

    def compute(self, data: FactorData) -> pl.Series:
        """Return the rolling mean of the ``close`` column.

        Args:
            data: Factor data handed in by the engine; only
                ``get_column("close")`` is used here.

        Returns:
            ``rolling_mean`` of the close column over ``self.params["period"]``.
        """
        closes = data.get_column("close")
        window = self.params["period"]
        return closes.rolling_mean(window)
class ReturnFactor(CrossSectionalFactor):
    """Cross-sectional factor that declares a 2-day lookback.

    It is named ``return`` and requests two days of data, as a real
    return calculation would, but ``compute`` is deliberately simplified and
    just emits the close price of the current cross-section.
    """

    name = "return"
    data_specs = [
        DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=2)
    ]

    def compute(self, data: FactorData) -> pl.Series:
        """Return the ``close`` column of the current cross-section.

        Args:
            data: Factor data handed in by the engine; only
                ``get_cross_section()`` is used here.

        Returns:
            The ``close`` column, used as a stand-in factor value.
        """
        # Simplified on purpose: a real implementation would compute the
        # return from the lookback window instead of echoing the close price.
        cs = data.get_cross_section()
        return cs["close"]
@pytest.fixture
def loader():
    """Provide a ``DataLoader`` constructed with ``data_dir="data"``."""
    data_loader = DataLoader(data_dir="data")
    return data_loader
@pytest.fixture
def engine(loader):
    """Provide a ``FactorEngine`` built on top of the ``loader`` fixture."""
    factor_engine = FactorEngine(loader)
    return factor_engine
class TestFactorEngineDispatch:
    """Checks that ``FactorEngine.compute`` dispatches on the factor type."""

    def test_dispatch_cross_sectional(self, engine):
        """A cross-sectional factor yields a DataFrame with the expected columns."""
        result = engine.compute(
            SimpleCrossSectionalFactor(),
            start_date="20240101",
            end_date="20240105",
        )

        assert isinstance(result, pl.DataFrame)
        for column in ("trade_date", "ts_code", "close_rank"):
            assert column in result.columns

    def test_dispatch_time_series(self, engine, loader):
        """A time-series factor yields a DataFrame with the expected columns."""
        factor = SimpleTimeSeriesFactor(period=3)

        # Pick up to three stock codes from whatever the loader returns.
        probe_spec = DataSpec("daily", ["ts_code", "trade_date"], lookback_days=1)
        sample_data = loader.load([probe_spec])
        stock_codes = sample_data["ts_code"].unique().head(3).to_list()

        result = engine.compute(
            factor,
            stock_codes=stock_codes,
            start_date="20240101",
            end_date="20240110",
        )

        assert isinstance(result, pl.DataFrame)
        for column in ("trade_date", "ts_code", "ma3"):
            assert column in result.columns

    def test_unknown_factor_type(self, engine):
        """An unrecognised ``factor_type`` makes ``compute`` raise ValueError."""

        class UnknownFactor:
            # Not derived from any factor base class; only carries the
            # attributes set below.
            name = "unknown"
            factor_type = "unknown_type"
            data_specs = []

        bogus_factor = UnknownFactor()
        with pytest.raises(ValueError, match="Unknown factor type"):
            engine.compute(bogus_factor)
class TestCrossSectionalComputation:
    """Cross-sectional computation tests (look-ahead leak checks)."""

    def test_result_format(self, engine):
        """The result has the key columns, typed as strings, plus the factor column."""
        factor = SimpleCrossSectionalFactor()
        result = engine.compute(factor, start_date="20240101", end_date="20240105")

        # Column presence.
        assert "trade_date" in result.columns
        assert "ts_code" in result.columns
        assert factor.name in result.columns

        # Column dtypes.
        assert result["trade_date"].dtype == pl.Utf8
        assert result["ts_code"].dtype == pl.Utf8

    def test_all_dates_present(self, engine):
        """Every returned date lies inside the requested ``[start, end]`` range."""
        factor = SimpleCrossSectionalFactor()
        start_date = "20240101"
        end_date = "20240105"

        result = engine.compute(factor, start_date=start_date, end_date=end_date)

        # The data set may hold no rows for this range; only check when it does.
        if len(result) > 0:
            dates = result["trade_date"].unique().to_list()
            assert len(dates) > 0

            # The non-empty check above cannot fail on its own, so also verify
            # that no date outside the requested window (in particular no
            # future date) leaked into the result.
            # NOTE(review): assumes trade_date values are YYYYMMDD strings like
            # the start/end arguments, so string comparison orders them
            # chronologically - confirm against the engine output.
            out_of_range = [d for d in dates if not (start_date <= d <= end_date)]
            assert not out_of_range, f"dates outside requested range: {out_of_range}"

    def test_lookback_window(self, engine):
        """A factor with a 2-day lookback can be computed and returns a DataFrame.

        NOTE(review): the original title mentions taking the maximum lookback
        across several DataSpecs, but ``ReturnFactor`` declares only one spec.
        """
        factor = ReturnFactor()  # lookback_days = 2
        result = engine.compute(factor, start_date="20240103", end_date="20240105")

        # The computation should succeed and produce a frame.
        assert isinstance(result, pl.DataFrame)
class TestTimeSeriesComputation:
    """Time-series computation tests (look-ahead leak checks)."""

    def test_result_format(self, engine):
        """The result has the key columns plus the factor column."""
        factor = SimpleTimeSeriesFactor(period=3)
        result = engine.compute(
            factor,
            stock_codes=["000001.SZ"],
            start_date="20240101",
            end_date="20240110",
        )

        # Column presence.
        assert "trade_date" in result.columns
        assert "ts_code" in result.columns
        assert factor.name in result.columns

    def test_single_stock_data(self, engine):
        """Only the requested stock codes appear in the result."""
        factor = SimpleTimeSeriesFactor(period=3)
        stock_codes = ["000001.SZ"]
        result = engine.compute(
            factor,
            stock_codes=stock_codes,
            start_date="20240101",
            end_date="20240110",
        )

        # The data set may hold no rows for this stock; only check when it does.
        if len(result) > 0:
            stocks = result["ts_code"].unique().to_list()
            assert set(stocks) == set(stock_codes)

    def test_ma_calculation(self, engine):
        """The first ``period - 1`` moving-average values are missing, the next is set."""
        import math  # local import keeps this block self-contained

        def is_missing(value):
            # Accept both a null (None) and a float NaN as "no value"; a real
            # NaN check replaces comparing str(value) against "nan".
            return value is None or (isinstance(value, float) and math.isnan(value))

        factor = SimpleTimeSeriesFactor(period=3)
        result = engine.compute(
            factor,
            stock_codes=["000001.SZ"],
            start_date="20240101",
            end_date="20240110",
        )

        if len(result) > 2:
            # With period=3 the first two entries cannot have a full window.
            # NOTE(review): relies on the rows being in date order - confirm
            # against the engine output.
            ma_values = result[factor.name].to_list()
            assert is_missing(ma_values[0])
            assert is_missing(ma_values[1])
            # The third entry has a complete window and must hold a value.
            assert ma_values[2] is not None

    def test_missing_stock_skipped(self, engine):
        """A stock code absent from the data still yields a DataFrame."""
        factor = SimpleTimeSeriesFactor(period=3)
        result = engine.compute(
            factor,
            stock_codes=["NONEXISTENT.STOCK"],
            start_date="20240101",
            end_date="20240110",
        )

        # Either an empty frame or rows for the stock with null values is
        # acceptable here, so only the return type is checked.
        assert isinstance(result, pl.DataFrame)