Files
ProStock/tests/test_factorminer_local_engine.py

156 lines
5.0 KiB
Python
Raw Normal View History

"""LocalFactorEvaluator 单元测试。
FactorEngine is replaced with a MagicMock so the tests do not depend on a real database.
"""
from unittest.mock import MagicMock, patch
import numpy as np
import polars as pl
import pytest
from src.factorminer.evaluation.local_engine import LocalFactorEvaluator
@pytest.fixture
def mock_engine():
    """Yield a MagicMock that stands in for the FactorEngine instance.

    The FactorEngine name inside the local_engine module is patched for the
    duration of the test, and calling the patched class returns the yielded
    mock.
    """
    patch_target = "src.factorminer.evaluation.local_engine.FactorEngine"
    with patch(patch_target) as engine_cls:
        # Every FactorEngine(...) call made while patched returns this mock.
        engine_cls.return_value = fake_engine = MagicMock()
        yield fake_engine
def test_evaluate_single(mock_engine):
    """Check that a single factor is computed and pivoted into a matrix.

    The mocked engine returns a long-format frame covering two stock codes
    on two trade dates. The test verifies the add_factor / compute / clear
    calls made on the engine and expects a 2x2 array with one row per stock
    code and one column per trade date.
    """
    long_frame = pl.DataFrame(
        {
            "ts_code": ["000001.SZ", "000001.SZ", "000002.SZ", "000002.SZ"],
            "trade_date": ["20240101", "20240102", "20240101", "20240102"],
            "alpha": [1.0, 2.0, 3.0, 4.0],
        }
    )
    mock_engine.compute.return_value = long_frame
    expected = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float64)

    evaluator = LocalFactorEvaluator("20240101", "20240102")
    matrix = evaluator.evaluate_single("alpha", "cs_rank(close)")

    # Engine interaction: register the factor, compute it, then clean up.
    mock_engine.add_factor.assert_called_once_with("alpha", "cs_rank(close)")
    mock_engine.compute.assert_called_once_with(
        ["alpha"], "20240101", "20240102", None
    )
    mock_engine.clear.assert_called_once()

    assert matrix.shape == (2, 2)
    np.testing.assert_array_equal(matrix, expected)
def test_evaluate_empty_specs(mock_engine):
    """Check that an empty spec list yields an empty dict without computing."""
    no_specs = []
    evaluator = LocalFactorEvaluator("20240101", "20240102")

    outcome = evaluator.evaluate(no_specs)

    assert outcome == {}
    # The engine must not be asked to compute anything for empty input.
    mock_engine.compute.assert_not_called()
def test_evaluate_batch(mock_engine):
    """Check that several factors are evaluated together in one batch.

    The mocked engine returns one frame holding both factor columns; the
    test expects a dict keyed by factor name whose values are 2x2 arrays,
    and a single clear() call on the engine.
    """
    mock_engine.compute.return_value = pl.DataFrame(
        {
            "ts_code": ["000001.SZ", "000001.SZ", "000002.SZ", "000002.SZ"],
            "trade_date": ["20240101", "20240102", "20240101", "20240102"],
            "alpha1": [1.0, 2.0, 3.0, 4.0],
            "alpha2": [5.0, 6.0, 7.0, 8.0],
        }
    )
    specs = [
        ("alpha1", "cs_rank(close)"),
        ("alpha2", "cs_rank(vol)"),
    ]
    expected_matrices = {
        "alpha1": np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float64),
        "alpha2": np.array([[5.0, 6.0], [7.0, 8.0]], dtype=np.float64),
    }

    evaluator = LocalFactorEvaluator("20240101", "20240102")
    matrices = evaluator.evaluate(specs)

    assert set(matrices.keys()) == {"alpha1", "alpha2"}
    # Compare each factor's matrix in the same order as the spec list.
    for factor_name, expected in expected_matrices.items():
        np.testing.assert_array_equal(matrices[factor_name], expected)
    mock_engine.clear.assert_called_once()
def test_evaluate_returns(mock_engine):
    """Check the returns matrix produced by evaluate_returns.

    With periods=5 the evaluator is expected to register the internal
    "__returns" factor as "ts_pct_change(close, 5)", pivot the engine's
    output into a 2x2 array, and call clear() once.
    """
    returns_frame = pl.DataFrame(
        {
            "ts_code": ["000001.SZ", "000001.SZ", "000002.SZ", "000002.SZ"],
            "trade_date": ["20240101", "20240102", "20240101", "20240102"],
            "__returns": [0.01, 0.02, -0.01, 0.03],
        }
    )
    mock_engine.compute.return_value = returns_frame
    expected = np.array([[0.01, 0.02], [-0.01, 0.03]], dtype=np.float64)

    evaluator = LocalFactorEvaluator("20240101", "20240102")
    returns_matrix = evaluator.evaluate_returns(periods=5)

    mock_engine.add_factor.assert_called_once_with(
        "__returns", "ts_pct_change(close, 5)"
    )
    assert returns_matrix.shape == (2, 2)
    np.testing.assert_array_equal(returns_matrix, expected)
    mock_engine.clear.assert_called_once()
def test_evaluate_with_nan(mock_engine):
    """Check that a missing (stock, date) observation becomes np.nan.

    The mocked frame has no row for 000002.SZ on 20240102, so the test
    expects a 2x2 result whose [1, 1] cell is NaN while the three observed
    cells keep their values.
    """
    sparse_frame = pl.DataFrame(
        {
            "ts_code": ["000001.SZ", "000001.SZ", "000002.SZ"],
            "trade_date": ["20240101", "20240102", "20240101"],
            "alpha": [1.0, 2.0, 3.0],
        }
    )
    mock_engine.compute.return_value = sparse_frame

    evaluator = LocalFactorEvaluator("20240101", "20240102")
    matrix = evaluator.evaluate_single("alpha", "cs_rank(close)")

    assert matrix.shape == (2, 2)
    # The only absent combination must surface as NaN.
    assert np.isnan(matrix[1, 1])
    observed_cells = (
        ((0, 0), 1.0),
        ((0, 1), 2.0),
        ((1, 0), 3.0),
    )
    for (row, col), value in observed_cells:
        assert matrix[row, col] == value
def test_stock_codes_filter(mock_engine):
    """Check that a stock-code list is forwarded to FactorEngine.compute.

    The evaluator is built with stock_codes=["000001.SZ"]; the test expects
    that list as the fourth positional argument of compute() and a result
    with a single row.
    """
    single_stock_frame = pl.DataFrame(
        {
            "ts_code": ["000001.SZ", "000001.SZ"],
            "trade_date": ["20240101", "20240102"],
            "alpha": [1.0, 2.0],
        }
    )
    mock_engine.compute.return_value = single_stock_frame

    evaluator = LocalFactorEvaluator(
        "20240101",
        "20240102",
        stock_codes=["000001.SZ"],
    )
    matrix = evaluator.evaluate_single("alpha", "close")

    # A fresh literal is used so the comparison does not share an object
    # with the list handed to the evaluator.
    mock_engine.compute.assert_called_once_with(
        ["alpha"], "20240101", "20240102", ["000001.SZ"]
    )
    assert matrix.shape == (1, 2)
def test_to_matrix_missing_factor_column(mock_engine):
    """Check that a frame lacking the requested factor column raises ValueError.

    The mocked engine returns only the ts_code and trade_date columns, so
    evaluating factor "alpha" is expected to fail with a ValueError whose
    message matches the pattern asserted below.
    """
    frame_without_factor = pl.DataFrame(
        {
            "ts_code": ["000001.SZ"],
            "trade_date": ["20240101"],
        }
    )
    mock_engine.compute.return_value = frame_without_factor

    evaluator = LocalFactorEvaluator("20240101", "20240102")
    with pytest.raises(ValueError, match="未找到因子列 'alpha'"):
        evaluator.evaluate_single("alpha", "close")