Files
ProStock/tests/test_factorminer_local_engine.py
liaozhaorun 4e676093d3 feat(factorminer): 添加 LocalFactorEvaluator 封装本地因子矩阵计算
- 新增 LocalFactorEvaluator 类,将 FactorEngine 的 Polars 长表输出转换为 (M, T) numpy 矩阵
- 支持批量因子计算、单因子计算及收益率矩阵计算
- 补充完整单元测试,覆盖 pivot、缺失值填充、股票代码过滤及异常处理场景
2026-04-08 22:36:44 +08:00

156 lines
5.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""LocalFactorEvaluator 单元测试。
使用 MagicMock 模拟 FactorEngine，避免依赖真实数据库。
"""
from unittest.mock import MagicMock, call, patch

import numpy as np
import polars as pl
import pytest

from src.factorminer.evaluation.local_engine import LocalFactorEvaluator
@pytest.fixture
def mock_engine():
    """Patch ``FactorEngine`` in ``local_engine`` and yield the fake instance.

    While the fixture is active, calling the patched ``FactorEngine`` name
    returns the MagicMock yielded here. The test can therefore configure its
    return values (e.g. ``compute``) and assert on the calls it received.
    """
    fake_engine = MagicMock()
    with patch(
        "src.factorminer.evaluation.local_engine.FactorEngine",
        return_value=fake_engine,
    ):
        yield fake_engine
def test_evaluate_single(mock_engine):
    """A single factor is registered, computed, and pivoted into a matrix.

    Rows of the result correspond to ``ts_code`` and columns to
    ``trade_date``. The engine must be cleared exactly once afterwards.
    """
    long_frame = pl.DataFrame(
        {
            "ts_code": ["000001.SZ", "000001.SZ", "000002.SZ", "000002.SZ"],
            "trade_date": ["20240101", "20240102", "20240101", "20240102"],
            "alpha": [1.0, 2.0, 3.0, 4.0],
        }
    )
    mock_engine.compute.return_value = long_frame
    expected = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float64)

    evaluator = LocalFactorEvaluator("20240101", "20240102")
    matrix = evaluator.evaluate_single("alpha", "cs_rank(close)")

    # Engine interaction: register once, compute once (no stock filter), clear once.
    mock_engine.add_factor.assert_called_once_with("alpha", "cs_rank(close)")
    mock_engine.compute.assert_called_once_with(
        ["alpha"], "20240101", "20240102", None
    )
    mock_engine.clear.assert_called_once()

    assert matrix.shape == (2, 2)
    np.testing.assert_array_equal(matrix, expected)
def test_evaluate_empty_specs(mock_engine):
    """An empty spec list returns ``{}`` immediately without using the engine.

    With nothing to evaluate, no factor may be registered (``add_factor``)
    and no computation may be triggered (``compute``).
    """
    evaluator = LocalFactorEvaluator("20240101", "20240102")
    result = evaluator.evaluate([])
    assert result == {}
    mock_engine.add_factor.assert_not_called()
    mock_engine.compute.assert_not_called()
def test_evaluate_batch(mock_engine):
    """Several factors are evaluated in one batch, one matrix per factor name.

    Checks that every ``(name, expression)`` spec is registered in order, that
    a single ``compute`` call covers all names, and that the long frame is
    split into one pivoted matrix per factor column.
    """
    mock_engine.compute.return_value = pl.DataFrame(
        {
            "ts_code": ["000001.SZ", "000001.SZ", "000002.SZ", "000002.SZ"],
            "trade_date": ["20240101", "20240102", "20240101", "20240102"],
            "alpha1": [1.0, 2.0, 3.0, 4.0],
            "alpha2": [5.0, 6.0, 7.0, 8.0],
        }
    )
    evaluator = LocalFactorEvaluator("20240101", "20240102")
    result = evaluator.evaluate(
        [
            ("alpha1", "cs_rank(close)"),
            ("alpha2", "cs_rank(vol)"),
        ]
    )

    # Each spec must reach the engine unchanged, in order, with no extras.
    assert mock_engine.add_factor.call_args_list == [
        call("alpha1", "cs_rank(close)"),
        call("alpha2", "cs_rank(vol)"),
    ]
    # Batching means one compute over all factor names (no stock filter given).
    mock_engine.compute.assert_called_once_with(
        ["alpha1", "alpha2"], "20240101", "20240102", None
    )

    assert set(result.keys()) == {"alpha1", "alpha2"}
    np.testing.assert_array_equal(
        result["alpha1"],
        np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float64),
    )
    np.testing.assert_array_equal(
        result["alpha2"],
        np.array([[5.0, 6.0], [7.0, 8.0]], dtype=np.float64),
    )
    mock_engine.clear.assert_called_once()
def test_evaluate_returns(mock_engine):
    """Returns are evaluated through a ``__returns`` factor and pivoted.

    ``evaluate_returns(periods=5)`` must register the expression
    ``ts_pct_change(close, 5)`` under the name ``__returns``. It returns the
    pivoted matrix and clears the engine exactly once.
    """
    returns_frame = pl.DataFrame(
        {
            "ts_code": ["000001.SZ", "000001.SZ", "000002.SZ", "000002.SZ"],
            "trade_date": ["20240101", "20240102", "20240101", "20240102"],
            "__returns": [0.01, 0.02, -0.01, 0.03],
        }
    )
    mock_engine.compute.return_value = returns_frame

    evaluator = LocalFactorEvaluator("20240101", "20240102")
    matrix = evaluator.evaluate_returns(periods=5)

    mock_engine.add_factor.assert_called_once_with(
        "__returns", "ts_pct_change(close, 5)"
    )
    assert matrix.shape == (2, 2)
    expected = np.array([[0.01, 0.02], [-0.01, 0.03]], dtype=np.float64)
    np.testing.assert_array_equal(matrix, expected)
    mock_engine.clear.assert_called_once()
def test_evaluate_with_nan(mock_engine):
    """Cells absent from the long frame are filled with ``np.nan``.

    ``000002.SZ`` has no row for ``20240102``, so that cell must be NaN, while
    the three observed cells keep their values.
    """
    mock_engine.compute.return_value = pl.DataFrame(
        {
            "ts_code": ["000001.SZ", "000001.SZ", "000002.SZ"],
            "trade_date": ["20240101", "20240102", "20240101"],
            "alpha": [1.0, 2.0, 3.0],
        }
    )
    evaluator = LocalFactorEvaluator("20240101", "20240102")
    matrix = evaluator.evaluate_single("alpha", "cs_rank(close)")

    assert matrix.shape == (2, 2)
    # assert_array_equal treats NaN in matching positions as equal, so this
    # checks the NaN-filled cell and the observed values in one assertion.
    np.testing.assert_array_equal(
        matrix,
        np.array([[1.0, 2.0], [3.0, np.nan]]),
    )
def test_stock_codes_filter(mock_engine):
    """A ``stock_codes`` list is forwarded to ``FactorEngine.compute``.

    Also checks that the filtered result is pivoted correctly. Previously only
    its shape was checked, so wrong values or orientation would slip through.
    """
    mock_engine.compute.return_value = pl.DataFrame(
        {
            "ts_code": ["000001.SZ", "000001.SZ"],
            "trade_date": ["20240101", "20240102"],
            "alpha": [1.0, 2.0],
        }
    )
    evaluator = LocalFactorEvaluator("20240101", "20240102", stock_codes=["000001.SZ"])
    result = evaluator.evaluate_single("alpha", "close")
    mock_engine.compute.assert_called_once_with(
        ["alpha"], "20240101", "20240102", ["000001.SZ"]
    )
    assert result.shape == (1, 2)
    # One stock row, one column per trade date.
    np.testing.assert_array_equal(
        result,
        np.array([[1.0, 2.0]], dtype=np.float64),
    )
def test_to_matrix_missing_factor_column(mock_engine):
    """A ValueError is raised when the computed frame lacks the factor column."""
    frame_without_alpha = pl.DataFrame(
        {"ts_code": ["000001.SZ"], "trade_date": ["20240101"]}
    )
    mock_engine.compute.return_value = frame_without_alpha
    evaluator = LocalFactorEvaluator("20240101", "20240102")
    # The expected message reads "factor column 'alpha' not found".
    with pytest.raises(ValueError, match="未找到因子列 'alpha'"):
        evaluator.evaluate_single("alpha", "close")