"""Unit tests for LocalFactorEvaluator.

FactorEngine is replaced with a MagicMock so the tests do not depend on a
real database.
"""

from unittest.mock import MagicMock, patch

import numpy as np
import polars as pl
import pytest

from src.factorminer.evaluation.local_engine import LocalFactorEvaluator


@pytest.fixture
def mock_engine():
    """Yield a MagicMock standing in for a FactorEngine instance.

    ``FactorEngine`` is patched inside the ``local_engine`` module, and the
    patched class is configured so that calling it returns ``engine``. Tests
    configure and inspect that single mock. The patch is undone when the
    ``with`` block exits after the test finishes.
    """
    engine = MagicMock()
    target = "src.factorminer.evaluation.local_engine.FactorEngine"
    # Extra kwargs to ``patch`` are forwarded to the replacement MagicMock,
    # so ``return_value=engine`` makes ``FactorEngine(...)`` return ``engine``.
    with patch(target, return_value=engine):
        yield engine


def test_evaluate_single(mock_engine):
    """A single factor is computed and pivoted into a stock-by-date matrix."""
    codes = ["000001.SZ"] * 2 + ["000002.SZ"] * 2
    dates = ["20240101", "20240102"] * 2
    mock_engine.compute.return_value = pl.DataFrame(
        {"ts_code": codes, "trade_date": dates, "alpha": [1.0, 2.0, 3.0, 4.0]}
    )

    evaluator = LocalFactorEvaluator("20240101", "20240102")
    matrix = evaluator.evaluate_single("alpha", "cs_rank(close)")

    # Each engine step (register, compute, clear) must happen exactly once.
    mock_engine.add_factor.assert_called_once_with("alpha", "cs_rank(close)")
    mock_engine.compute.assert_called_once_with(
        ["alpha"], "20240101", "20240102", None
    )
    mock_engine.clear.assert_called_once()

    # Rows follow ts_code, columns follow trade_date.
    expected = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float64)
    assert matrix.shape == (2, 2)
    np.testing.assert_array_equal(matrix, expected)


def test_evaluate_empty_specs(mock_engine):
    """An empty spec list yields ``{}`` without calling ``engine.compute``."""
    evaluator = LocalFactorEvaluator("20240101", "20240102")

    outcome = evaluator.evaluate([])

    assert outcome == {}
    mock_engine.compute.assert_not_called()


def test_evaluate_batch(mock_engine):
    """Several factors evaluated together each come back as their own matrix."""
    frame = pl.DataFrame(
        {
            "ts_code": ["000001.SZ"] * 2 + ["000002.SZ"] * 2,
            "trade_date": ["20240101", "20240102"] * 2,
            "alpha1": [1.0, 2.0, 3.0, 4.0],
            "alpha2": [5.0, 6.0, 7.0, 8.0],
        }
    )
    mock_engine.compute.return_value = frame
    specs = [
        ("alpha1", "cs_rank(close)"),
        ("alpha2", "cs_rank(vol)"),
    ]

    evaluator = LocalFactorEvaluator("20240101", "20240102")
    matrices = evaluator.evaluate(specs)

    expected = {
        "alpha1": np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float64),
        "alpha2": np.array([[5.0, 6.0], [7.0, 8.0]], dtype=np.float64),
    }
    assert set(matrices.keys()) == set(expected)
    for name, want in expected.items():
        np.testing.assert_array_equal(matrices[name], want)
    mock_engine.clear.assert_called_once()


def test_evaluate_returns(mock_engine):
    """The returns matrix is registered as ``__returns`` and pivoted like a factor."""
    mock_engine.compute.return_value = pl.DataFrame(
        {
            "ts_code": ["000001.SZ"] * 2 + ["000002.SZ"] * 2,
            "trade_date": ["20240101", "20240102"] * 2,
            "__returns": [0.01, 0.02, -0.01, 0.03],
        }
    )
    evaluator = LocalFactorEvaluator("20240101", "20240102")

    returns = evaluator.evaluate_returns(periods=5)

    # ``periods`` must show up as the window argument of ts_pct_change.
    mock_engine.add_factor.assert_called_once_with(
        "__returns", "ts_pct_change(close, 5)"
    )
    expected = np.array([[0.01, 0.02], [-0.01, 0.03]], dtype=np.float64)
    assert returns.shape == expected.shape
    np.testing.assert_array_equal(returns, expected)
    mock_engine.clear.assert_called_once()


def test_evaluate_with_nan(mock_engine):
    """Cells with no row in the computed frame are filled with ``np.nan``."""
    # 000002.SZ has no row for 20240102, so matrix cell [1, 1] must be NaN.
    mock_engine.compute.return_value = pl.DataFrame(
        {
            "ts_code": ["000001.SZ", "000001.SZ", "000002.SZ"],
            "trade_date": ["20240101", "20240102", "20240101"],
            "alpha": [1.0, 2.0, 3.0],
        }
    )
    evaluator = LocalFactorEvaluator("20240101", "20240102")

    matrix = evaluator.evaluate_single("alpha", "cs_rank(close)")

    assert matrix.shape == (2, 2)
    # assert_array_equal treats NaNs in matching positions as equal, so this
    # single call pins both the NaN cell and the three populated cells.
    np.testing.assert_array_equal(matrix, np.array([[1.0, 2.0], [3.0, np.nan]]))


def test_stock_codes_filter(mock_engine):
    """A ``stock_codes`` list is forwarded verbatim to ``FactorEngine.compute``.

    Also checks that the filtered, single-stock result keeps the same layout
    and engine lifecycle as ``test_evaluate_single``: one row per stock and
    one column per trade date, with add_factor / compute / clear each called
    once.
    """
    mock_engine.compute.return_value = pl.DataFrame(
        {
            "ts_code": ["000001.SZ", "000001.SZ"],
            "trade_date": ["20240101", "20240102"],
            "alpha": [1.0, 2.0],
        }
    )

    evaluator = LocalFactorEvaluator(
        "20240101", "20240102", stock_codes=["000001.SZ"]
    )
    result = evaluator.evaluate_single("alpha", "close")

    mock_engine.add_factor.assert_called_once_with("alpha", "close")
    # The stock filter must reach the engine as the fourth positional argument.
    mock_engine.compute.assert_called_once_with(
        ["alpha"], "20240101", "20240102", ["000001.SZ"]
    )
    mock_engine.clear.assert_called_once()

    # Checking values, not just shape, catches a scrambled or dropped row.
    assert result.shape == (1, 2)
    np.testing.assert_array_equal(
        result, np.array([[1.0, 2.0]], dtype=np.float64)
    )


def test_to_matrix_missing_factor_column(mock_engine):
    """A computed frame without the requested factor column raises ValueError."""
    # Only the key columns come back; the "alpha" column is absent.
    keys_only = pl.DataFrame({"ts_code": ["000001.SZ"], "trade_date": ["20240101"]})
    mock_engine.compute.return_value = keys_only
    evaluator = LocalFactorEvaluator("20240101", "20240102")

    # The expected message reads "factor column 'alpha' not found".
    with pytest.raises(ValueError, match="未找到因子列 'alpha'"):
        evaluator.evaluate_single("alpha", "close")