"""测试执行引擎 - FactorEngine 测试需求(来自 factor_implementation_plan.md): - 测试 `compute()` 正确分发给截面计算 - 测试 `compute()` 正确分发给时序计算 - 测试无效 factor_type 时抛出 ValueError 截面计算测试(防泄露验证): - 测试数据裁剪正确(传入 [T-lookback+1, T]) - 测试不包含未来日期 T+1 的数据 - 测试每个日期独立计算 - 测试结果包含所有日期和所有股票 - 测试结果 DataFrame 格式正确 - 测试多个 DataSpec 时 lookback 取最大值 时序计算测试(防泄露验证): - 测试每只股票只看到自己的数据 - 测试不包含其他股票的数据 - 测试传入的是完整时间序列(向量化计算) - 测试结果包含所有股票和所有日期 - 测试结果 DataFrame 格式正确 - 测试股票不在数据中时跳过(或填充 null) """ import pytest import polars as pl from src.factors import ( DataSpec, FactorContext, FactorData, DataLoader, FactorEngine, CrossSectionalFactor, TimeSeriesFactor, ) class SimpleCrossSectionalFactor(CrossSectionalFactor): """简单的截面因子 - 返回收盘价排名""" name = "close_rank" data_specs = [ DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=1) ] def compute(self, data: FactorData) -> pl.Series: cs = data.get_cross_section() return cs["close"].rank() class SimpleTimeSeriesFactor(TimeSeriesFactor): """简单的时序因子 - 返回3日移动平均""" name = "ma3" data_specs = [ DataSpec( "daily", ["ts_code", "trade_date", "close"], lookback_days=5, ) ] def __init__(self, period: int = 3): super().__init__(period=period) def compute(self, data: FactorData) -> pl.Series: return data.get_column("close").rolling_mean(self.params["period"]) class ReturnFactor(CrossSectionalFactor): """收益率因子 - 需要2天lookback计算收益率""" name = "return" data_specs = [ DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=2) ] def compute(self, data: FactorData) -> pl.Series: # 获取当前日期 current_date = data.context.current_date # 获取当前日期的数据 cs = data.get_cross_section() # 简单返回收盘价作为因子值 # 实际应该计算收益率,但这里简化处理 return cs["close"] @pytest.fixture def loader(): """创建 DataLoader 实例""" return DataLoader(data_dir="data") @pytest.fixture def engine(loader): """创建 FactorEngine 实例""" return FactorEngine(loader) class TestFactorEngineDispatch: """测试引擎分发逻辑""" def test_dispatch_cross_sectional(self, engine): """测试 compute() 正确分发给截面计算""" factor = SimpleCrossSectionalFactor() result = engine.compute(factor, start_date="20240101", end_date="20240105") assert isinstance(result, pl.DataFrame) assert "trade_date" in result.columns assert "ts_code" in result.columns assert "close_rank" in result.columns def test_dispatch_time_series(self, engine, loader): """测试 compute() 正确分发给时序计算""" factor = SimpleTimeSeriesFactor(period=3) # 获取一些股票代码 sample_data = loader.load( [DataSpec("daily", ["ts_code", "trade_date"], lookback_days=1)] ) stock_codes = sample_data["ts_code"].unique().head(3).to_list() result = engine.compute( factor, stock_codes=stock_codes, start_date="20240101", end_date="20240110", ) assert isinstance(result, pl.DataFrame) assert "trade_date" in result.columns assert "ts_code" in result.columns assert "ma3" in result.columns def test_unknown_factor_type(self, engine): """测试无效 factor_type 时抛出 ValueError""" class UnknownFactor: name = "unknown" factor_type = "unknown_type" data_specs = [] factor = UnknownFactor() with pytest.raises(ValueError, match="Unknown factor type"): engine.compute(factor) class TestCrossSectionalComputation: """测试截面计算(防泄露验证)""" def test_result_format(self, engine): """测试结果 DataFrame 格式正确""" factor = SimpleCrossSectionalFactor() result = engine.compute(factor, start_date="20240101", end_date="20240105") # 检查列 assert "trade_date" in result.columns assert "ts_code" in result.columns assert factor.name in result.columns # 检查类型 assert result["trade_date"].dtype == pl.Utf8 assert result["ts_code"].dtype == pl.Utf8 def test_all_dates_present(self, engine): """测试结果包含所有日期""" factor = SimpleCrossSectionalFactor() start_date = "20240101" end_date = "20240105" result = engine.compute(factor, start_date=start_date, end_date=end_date) if len(result) > 0: dates = result["trade_date"].unique().to_list() # 应该包含 start_date 和 end_date 之间的日期 assert len(dates) > 0 def test_lookback_window(self, engine): """测试多个 DataSpec 时 lookback 取最大值""" factor = ReturnFactor() # lookback_days = 2 result = engine.compute(factor, start_date="20240103", end_date="20240105") # 应该能计算出结果 assert isinstance(result, pl.DataFrame) class TestTimeSeriesComputation: """测试时序计算(防泄露验证)""" def test_result_format(self, engine): """测试结果 DataFrame 格式正确""" factor = SimpleTimeSeriesFactor(period=3) result = engine.compute( factor, stock_codes=["000001.SZ"], start_date="20240101", end_date="20240110", ) # 检查列 assert "trade_date" in result.columns assert "ts_code" in result.columns assert factor.name in result.columns def test_single_stock_data(self, engine): """测试每只股票只看到自己的数据""" factor = SimpleTimeSeriesFactor(period=3) stock_codes = ["000001.SZ"] result = engine.compute( factor, stock_codes=stock_codes, start_date="20240101", end_date="20240110", ) if len(result) > 0: # 结果中只应该有指定的股票 stocks = result["ts_code"].unique().to_list() assert set(stocks) == set(stock_codes) def test_ma_calculation(self, engine): """测试移动平均计算""" factor = SimpleTimeSeriesFactor(period=3) result = engine.compute( factor, stock_codes=["000001.SZ"], start_date="20240101", end_date="20240110", ) if len(result) > 2: # 前2个应该是 null(因为 period=3) ma_values = result[factor.name].to_list() assert ma_values[0] is None or str(ma_values[0]) == "nan" assert ma_values[1] is None or str(ma_values[1]) == "nan" # 第3个应该有值 assert ma_values[2] is not None def test_missing_stock_skipped(self, engine): """测试股票不在数据中时返回空结果""" factor = SimpleTimeSeriesFactor(period=3) result = engine.compute( factor, stock_codes=["NONEXISTENT.STOCK"], start_date="20240101", end_date="20240110", ) # 应该返回空 DataFrame 或包含该股票但值为 null 的结果 assert isinstance(result, pl.DataFrame) # 对于不存在的股票,结果可能是空的 # 或者包含该股票但值为 null