267 lines
7.9 KiB
Python
267 lines
7.9 KiB
Python
|
|
"""测试执行引擎 - FactorEngine
|
|||
|
|
|
|||
|
|
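Test requirements (from factor_implementation_plan.md):
- `compute()` dispatches correctly to cross-sectional computation
- `compute()` dispatches correctly to time-series computation
- An invalid factor_type raises ValueError

Cross-sectional computation tests (leakage prevention):
- Data is trimmed correctly (the window passed in is [T-lookback+1, T])
- Data for the future date T+1 is not included
- Each date is computed independently
- The result contains all dates and all stocks
- The result DataFrame has the correct format
- With multiple DataSpecs, lookback takes the maximum value

Time-series computation tests (leakage prevention):
- Each stock only sees its own data
- Data of other stocks is not included
- The full time series is passed in (vectorized computation)
- The result contains all stocks and all dates
- The result DataFrame has the correct format
- Stocks missing from the data are skipped (or filled with null)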
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import pytest
|
|||
|
|
import polars as pl
|
|||
|
|
|
|||
|
|
from src.factors import (
|
|||
|
|
DataSpec,
|
|||
|
|
FactorContext,
|
|||
|
|
FactorData,
|
|||
|
|
DataLoader,
|
|||
|
|
FactorEngine,
|
|||
|
|
CrossSectionalFactor,
|
|||
|
|
TimeSeriesFactor,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class SimpleCrossSectionalFactor(CrossSectionalFactor):
    """Minimal cross-sectional factor used by the engine tests.

    The factor value is the rank of ``close`` within the frame returned by
    ``FactorData.get_cross_section()``.
    """

    name = "close_rank"
    data_specs = [
        DataSpec(
            "daily",
            ["ts_code", "trade_date", "close"],
            lookback_days=1,
        ),
    ]

    def compute(self, data: FactorData) -> pl.Series:
        """Rank the ``close`` column of the current cross-section."""
        return data.get_cross_section()["close"].rank()
|
|||
|
|
|
|||
|
|
|
|||
|
|
class SimpleTimeSeriesFactor(TimeSeriesFactor):
    """Minimal time-series factor: rolling mean of ``close`` (window 3 by default)."""

    name = "ma3"
    data_specs = [
        DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=5),
    ]

    def __init__(self, period: int = 3):
        """Forward ``period`` to the base class as a keyword argument."""
        super().__init__(period=period)

    def compute(self, data: FactorData) -> pl.Series:
        """Return the rolling mean of ``close`` with window ``self.params["period"]``."""
        close = data.get_column("close")
        window = self.params["period"]
        return close.rolling_mean(window)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ReturnFactor(CrossSectionalFactor):
    """Cross-sectional test factor that declares a 2-day lookback.

    NOTE: despite its name, this stub does not compute a return. It hands
    back the ``close`` column of the current cross-section unchanged, so the
    engine can be exercised with ``lookback_days=2``.
    """

    name = "return"
    data_specs = [
        DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=2)
    ]

    def compute(self, data: FactorData) -> pl.Series:
        """Return the ``close`` column of the current cross-section."""
        # Deliberately simplified: a real implementation would derive the
        # return from the lookback window instead of echoing the close price.
        cs = data.get_cross_section()
        return cs["close"]
|
|||
|
|
|
|||
|
|
|
|||
|
|
@pytest.fixture
def loader():
    """Provide a DataLoader configured with ``data_dir="data"``."""
    data_loader = DataLoader(data_dir="data")
    return data_loader
|
|||
|
|
|
|||
|
|
|
|||
|
|
@pytest.fixture
def engine(loader):
    """Provide a FactorEngine built on top of the ``loader`` fixture."""
    factor_engine = FactorEngine(loader)
    return factor_engine
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestFactorEngineDispatch:
    """Checks how ``FactorEngine.compute`` dispatches on the factor passed in."""

    def test_dispatch_cross_sectional(self, engine):
        """A cross-sectional factor yields a DataFrame with key and factor columns."""
        cs_factor = SimpleCrossSectionalFactor()

        result = engine.compute(
            cs_factor,
            start_date="20240101",
            end_date="20240105",
        )

        assert isinstance(result, pl.DataFrame)
        for column in ("trade_date", "ts_code", "close_rank"):
            assert column in result.columns

    def test_dispatch_time_series(self, engine, loader):
        """A time-series factor yields a DataFrame with key and factor columns."""
        ts_factor = SimpleTimeSeriesFactor(period=3)

        # Pick a handful of stock codes out of whatever the loader returns.
        code_spec = DataSpec("daily", ["ts_code", "trade_date"], lookback_days=1)
        sample_data = loader.load([code_spec])
        unique_codes = sample_data["ts_code"].unique()
        stock_codes = unique_codes.head(3).to_list()

        result = engine.compute(
            ts_factor,
            stock_codes=stock_codes,
            start_date="20240101",
            end_date="20240110",
        )

        assert isinstance(result, pl.DataFrame)
        for column in ("trade_date", "ts_code", "ma3"):
            assert column in result.columns

    def test_unknown_factor_type(self, engine):
        """An unrecognised ``factor_type`` makes ``compute`` raise ValueError."""

        class UnknownFactor:
            name = "unknown"
            factor_type = "unknown_type"
            data_specs = []

        bogus_factor = UnknownFactor()

        with pytest.raises(ValueError, match="Unknown factor type"):
            engine.compute(bogus_factor)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestCrossSectionalComputation:
    """Cross-sectional computation tests (leakage-prevention checks)."""

    def test_result_format(self, engine):
        """The result exposes the key columns and the factor column, with string keys."""
        factor = SimpleCrossSectionalFactor()

        result = engine.compute(factor, start_date="20240101", end_date="20240105")

        # Column presence
        assert "trade_date" in result.columns
        assert "ts_code" in result.columns
        assert factor.name in result.columns

        # Column dtypes
        assert result["trade_date"].dtype == pl.Utf8
        assert result["ts_code"].dtype == pl.Utf8

    def test_all_dates_present(self, engine):
        """Every returned trade date lies inside the requested [start, end] range."""
        factor = SimpleCrossSectionalFactor()

        start_date = "20240101"
        end_date = "20240105"

        result = engine.compute(factor, start_date=start_date, end_date=end_date)

        if len(result) == 0:
            # Nothing to verify without rows; report a skip instead of a
            # silent pass so missing data stays visible in the test report.
            pytest.skip("no rows returned for the requested date range")

        dates = result["trade_date"].unique().to_list()
        # YYYYMMDD strings sort chronologically, so plain string comparison
        # is sufficient. A null date is treated as out of range.
        out_of_range = [
            d for d in dates if d is None or not (start_date <= d <= end_date)
        ]
        assert not out_of_range, (
            f"dates outside [{start_date}, {end_date}]: {out_of_range}"
        )

    def test_lookback_window(self, engine):
        """A factor declaring ``lookback_days=2`` can be computed by the engine.

        NOTE(review): the requirement is "with multiple DataSpecs, lookback
        takes the maximum", but ReturnFactor declares a single DataSpec, so
        this is only a smoke test of the lookback path.
        """
        factor = ReturnFactor()

        # lookback_days = 2
        result = engine.compute(factor, start_date="20240103", end_date="20240105")

        # The engine should produce a result frame
        assert isinstance(result, pl.DataFrame)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestTimeSeriesComputation:
    """Time-series computation tests (leakage-prevention checks)."""

    @staticmethod
    def _is_missing(value) -> bool:
        """Return True when *value* is a null (``None``) or a float NaN."""
        import math  # local import keeps this block self-contained

        return value is None or (isinstance(value, float) and math.isnan(value))

    def test_result_format(self, engine):
        """The result exposes the key columns and the factor column."""
        factor = SimpleTimeSeriesFactor(period=3)

        result = engine.compute(
            factor,
            stock_codes=["000001.SZ"],
            start_date="20240101",
            end_date="20240110",
        )

        # Column presence
        assert "trade_date" in result.columns
        assert "ts_code" in result.columns
        assert factor.name in result.columns

    def test_single_stock_data(self, engine):
        """The result contains no stock other than the requested one."""
        factor = SimpleTimeSeriesFactor(period=3)

        stock_codes = ["000001.SZ"]

        result = engine.compute(
            factor,
            stock_codes=stock_codes,
            start_date="20240101",
            end_date="20240110",
        )

        # Only requested stocks may appear. With a single requested code the
        # subset check also holds for an empty result, so no branch is needed.
        stocks = result["ts_code"].unique().to_list()
        assert set(stocks) <= set(stock_codes)

    def test_ma_calculation(self, engine):
        """The 3-period moving average is missing for the first two rows only."""
        factor = SimpleTimeSeriesFactor(period=3)

        result = engine.compute(
            factor,
            stock_codes=["000001.SZ"],
            start_date="20240101",
            end_date="20240110",
        )

        if len(result) <= 2:
            # Too few rows to observe a full window; skip rather than pass
            # without asserting anything.
            pytest.skip("fewer than three rows returned; window cannot be checked")

        ma_values = result[factor.name].to_list()
        # With period=3 the first two entries have no complete window
        assert self._is_missing(ma_values[0])
        assert self._is_missing(ma_values[1])
        # The third entry closes the first full window, so it must hold a
        # real number (NaN would slip past a plain ``is not None`` check)
        assert not self._is_missing(ma_values[2])

    def test_missing_stock_skipped(self, engine):
        """A stock absent from the data yields an empty result or only nulls."""
        factor = SimpleTimeSeriesFactor(period=3)

        result = engine.compute(
            factor,
            stock_codes=["NONEXISTENT.STOCK"],
            start_date="20240101",
            end_date="20240110",
        )

        assert isinstance(result, pl.DataFrame)
        # Either the unknown stock is skipped (empty frame) or its rows are
        # present with the factor value filled with null
        assert (
            len(result) == 0
            or result[factor.name].null_count() == len(result)
        )
|