Files
ProStock/tests/test_compute_factors.py

195 lines
6.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Unit tests for the compute_factors entry script (external dependencies isolated with mocks)."""
from unittest.mock import MagicMock, patch
import polars as pl
import pytest
from src.data.compute_factors import run
@pytest.fixture
def mock_manager():
    """Return a MagicMock standing in for FactorManager.

    ``get_all_factors`` yields a two-row metadata frame (``ma_5``, ``ma_20``);
    ``get_factors_by_name`` echoes the requested name back together with a
    placeholder DSL string ``dsl_of_<name>``.
    """
    all_factors = pl.DataFrame(
        {
            "factor_id": ["F_001", "F_002"],
            "name": ["ma_5", "ma_20"],
            "dsl": ["ts_mean(close, 5)", "ts_mean(close, 20)"],
        }
    )

    def _lookup_by_name(name):
        # One-row frame for whatever name was asked for.
        return pl.DataFrame({"name": [name], "dsl": [f"dsl_of_{name}"]})

    fake = MagicMock()
    fake.get_all_factors.return_value = all_factors
    fake.get_factors_by_name.side_effect = _lookup_by_name
    return fake
@pytest.fixture
def mock_engine():
    """Return a MagicMock standing in for FactorEngine.

    ``compute(cols, start, end)`` returns a single-date, two-stock frame whose
    value column is named after ``cols[0]``; ``start``/``end`` are ignored.
    """

    def _fake_compute(cols, start, end):
        value_column = cols[0]
        return pl.DataFrame(
            {
                "trade_date": ["20240101"] * 2,
                "ts_code": ["000001.SZ", "000002.SZ"],
                value_column: [1.0, 2.0],
            }
        )

    engine = MagicMock()
    engine.compute.side_effect = _fake_compute
    return engine
@pytest.fixture
def mock_storage():
    """Provide a MagicMock standing in for FactorStorage where every validation passes.

    ``validate`` returns ``(True, stats)`` with zero diffs over 2 matched rows
    and ``save`` returns ``None``. Individual tests override
    ``validate.return_value`` to simulate a failed validation.
    """
    # No tmp_path dependency: nothing here touches the filesystem, so requesting
    # it only made pytest create an unused temp directory for every test.
    storage = MagicMock()
    storage.validate.return_value = (
        True,
        {"max_abs_diff": 0.0, "mean_abs_diff": 0.0, "matched_rows": 2},
    )
    # Explicit None: a bare MagicMock would otherwise return a (truthy) child mock.
    storage.save.return_value = None
    return storage
def test_run_auto_discover_factors(mock_manager, mock_engine, mock_storage):
    """An empty factor_names list makes run() process every factor in metadata."""
    date_window = {"start_date": "20240101", "end_date": "20240131"}
    with (
        patch("src.data.compute_factors.FactorEngine", return_value=mock_engine),
        patch("src.data.compute_factors.FactorStorage", return_value=mock_storage),
        patch("src.data.compute_factors.FactorManager", return_value=mock_manager),
    ):
        outcome = run(factor_names=[], metadata="dummy.jsonl", **date_window)

    # Both metadata rows succeed, in metadata order, with no failures.
    assert outcome["success"] == ["ma_5", "ma_20"]
    assert not outcome["failed"]
    mock_manager.get_all_factors.assert_called()
def test_run_validate_fail_skip(mock_manager, mock_engine, mock_storage):
    """A failed validation without force marks the factor failed and skips the write."""
    mismatch = {"max_abs_diff": 1.0, "mean_abs_diff": 0.5, "matched_rows": 2}
    mock_storage.validate.return_value = (False, mismatch)

    with (
        patch("src.data.compute_factors.FactorEngine", return_value=mock_engine),
        patch("src.data.compute_factors.FactorStorage", return_value=mock_storage),
        patch("src.data.compute_factors.FactorManager", return_value=mock_manager),
    ):
        outcome = run(
            factor_names=["ma_5"],
            metadata="dummy.jsonl",
            start_date="20240101",
            end_date="20240131",
            force=False,
        )

    assert outcome["success"] == []
    assert len(outcome["failed"]) == 1
    failure = outcome["failed"][0]
    assert failure["name"] == "ma_5"
    assert failure["reason"] == "校验失败"
    mock_storage.save.assert_not_called()
def test_run_validate_fail_force(mock_manager, mock_engine, mock_storage):
    """With force=True, a factor is written even though validation failed."""
    mismatch = {"max_abs_diff": 1.0, "mean_abs_diff": 0.5, "matched_rows": 2}
    mock_storage.validate.return_value = (False, mismatch)

    with (
        patch("src.data.compute_factors.FactorEngine", return_value=mock_engine),
        patch("src.data.compute_factors.FactorStorage", return_value=mock_storage),
        patch("src.data.compute_factors.FactorManager", return_value=mock_manager),
    ):
        outcome = run(
            factor_names=["ma_5"],
            metadata="dummy.jsonl",
            start_date="20240101",
            end_date="20240131",
            force=True,
        )

    assert outcome["success"] == ["ma_5"]
    assert not outcome["failed"]
    mock_storage.save.assert_called_once()
def test_run_missing_metadata_entry(mock_manager, mock_engine, mock_storage):
    """A factor name absent from metadata is reported as failed and never written."""
    # Clear the fixture's echoing side_effect so that return_value takes effect.
    mock_manager.get_factors_by_name.side_effect = None
    # Empty lookup result. The explicit String schema matches the column types of
    # real metadata instead of letting polars infer a Null dtype from empty lists.
    mock_manager.get_factors_by_name.return_value = pl.DataFrame(
        {"name": [], "dsl": []},
        schema={"name": pl.Utf8, "dsl": pl.Utf8},
    )
    with (
        patch("src.data.compute_factors.FactorEngine", return_value=mock_engine),
        patch("src.data.compute_factors.FactorStorage", return_value=mock_storage),
        patch("src.data.compute_factors.FactorManager", return_value=mock_manager),
    ):
        result = run(
            factor_names=["unknown"],
            metadata="dummy.jsonl",
            start_date="20240101",
            end_date="20240131",
        )
    assert result["success"] == []
    assert len(result["failed"]) == 1
    assert result["failed"][0]["name"] == "unknown"
    assert "未找到" in result["failed"][0]["reason"]
    # Like the validate-fail-skip case, a failed factor must not be persisted.
    mock_storage.save.assert_not_called()
def test_run_engine_exception(mock_manager, mock_engine, mock_storage):
    """An exception from engine.compute marks the factor failed and skips the write."""
    # Replaces the fixture's side_effect, so every compute() call raises.
    mock_engine.compute.side_effect = ValueError("compute error")
    with (
        patch("src.data.compute_factors.FactorEngine", return_value=mock_engine),
        patch("src.data.compute_factors.FactorStorage", return_value=mock_storage),
        patch("src.data.compute_factors.FactorManager", return_value=mock_manager),
    ):
        result = run(
            factor_names=["ma_5"],
            metadata="dummy.jsonl",
            start_date="20240101",
            end_date="20240131",
        )
    assert result["success"] == []
    assert len(result["failed"]) == 1
    assert result["failed"][0]["name"] == "ma_5"
    # The original exception message should be surfaced in the failure reason.
    assert "compute error" in result["failed"][0]["reason"]
    # Nothing was computed, so nothing may be written.
    mock_storage.save.assert_not_called()
def test_run_missing_result_column(mock_engine, mock_storage, mock_manager):
    """A compute result lacking the factor's column marks it failed and skips the write."""
    # The frame carries "wrong_col" instead of a column named after the requested factor.
    mock_engine.compute.side_effect = lambda cols, start, end: pl.DataFrame(
        {
            "trade_date": ["20240101"],
            "ts_code": ["000001.SZ"],
            "wrong_col": [1.0],
        }
    )
    with (
        patch("src.data.compute_factors.FactorEngine", return_value=mock_engine),
        patch("src.data.compute_factors.FactorStorage", return_value=mock_storage),
        patch("src.data.compute_factors.FactorManager", return_value=mock_manager),
    ):
        result = run(
            factor_names=["ma_5"],
            metadata="dummy.jsonl",
            start_date="20240101",
            end_date="20240131",
        )
    assert result["success"] == []
    assert len(result["failed"]) == 1
    assert result["failed"][0]["name"] == "ma_5"
    assert "缺少列" in result["failed"][0]["reason"]
    # A malformed result must never reach storage.
    mock_storage.save.assert_not_called()