195 lines
6.3 KiB
Python
195 lines
6.3 KiB
Python
"""compute_factors 入口脚本的单元测试(使用 mock 隔离外部依赖)。"""
|
||
|
||
from unittest.mock import MagicMock, patch
|
||
|
||
import polars as pl
|
||
import pytest
|
||
|
||
from src.data.compute_factors import run
|
||
|
||
|
||
@pytest.fixture
def mock_manager():
    """Provide a mocked FactorManager.

    ``get_all_factors()`` yields two metadata rows (``ma_5`` and ``ma_20``),
    and ``get_factors_by_name(name)`` fabricates a one-row frame for any
    requested name, with a placeholder DSL of ``dsl_of_<name>``.
    """
    all_factors = pl.DataFrame(
        {
            "factor_id": ["F_001", "F_002"],
            "name": ["ma_5", "ma_20"],
            "dsl": ["ts_mean(close, 5)", "ts_mean(close, 20)"],
        }
    )

    def _lookup_by_name(name):
        # Every lookup "succeeds": echo the name back with a synthetic DSL.
        return pl.DataFrame({"name": [name], "dsl": [f"dsl_of_{name}"]})

    manager = MagicMock()
    manager.get_all_factors.return_value = all_factors
    manager.get_factors_by_name.side_effect = _lookup_by_name
    return manager
|
||
|
||
|
||
@pytest.fixture
def mock_engine():
    """Provide a mocked FactorEngine.

    ``compute(cols, start, end)`` returns a two-row frame keyed by
    ``trade_date``/``ts_code``. Its value column is named after the first
    requested column, ``cols[0]``.
    """

    def _fake_compute(cols, start, end):
        # Only the first requested column is materialised; start/end are ignored.
        value_column = cols[0]
        return pl.DataFrame(
            {
                "trade_date": ["20240101", "20240101"],
                "ts_code": ["000001.SZ", "000002.SZ"],
                value_column: [1.0, 2.0],
            }
        )

    engine = MagicMock()
    engine.compute.side_effect = _fake_compute
    return engine
|
||
|
||
|
||
@pytest.fixture
def mock_storage():
    """Provide a mocked FactorStorage on which every validation passes.

    ``validate`` returns ``(True, stats)`` with zero absolute differences over
    two matched rows, and ``save`` returns ``None``.

    The fixture no longer requests ``tmp_path``. It was never used, so
    requesting it only created a throwaway temporary directory per test.
    """
    storage = MagicMock()
    storage.validate.return_value = (
        True,
        {"max_abs_diff": 0.0, "mean_abs_diff": 0.0, "matched_rows": 2},
    )
    # Return None instead of MagicMock's default auto-created child mock.
    storage.save.return_value = None
    return storage
|
||
|
||
|
||
def test_run_auto_discover_factors(mock_manager, mock_engine, mock_storage):
    """An empty ``factor_names`` makes ``run`` load every factor from the metadata."""
    engine_patch = patch("src.data.compute_factors.FactorEngine", return_value=mock_engine)
    storage_patch = patch("src.data.compute_factors.FactorStorage", return_value=mock_storage)
    manager_patch = patch("src.data.compute_factors.FactorManager", return_value=mock_manager)

    with engine_patch, storage_patch, manager_patch:
        result = run(
            factor_names=[],
            metadata="dummy.jsonl",
            start_date="20240101",
            end_date="20240131",
        )

    # Both metadata rows from the mock_manager fixture should succeed, in order.
    assert result["success"] == ["ma_5", "ma_20"]
    assert len(result["failed"]) == 0
    mock_manager.get_all_factors.assert_called()
|
||
|
||
|
||
def test_run_validate_fail_skip(mock_manager, mock_engine, mock_storage):
    """A failed validation without ``force`` skips the write and reports the factor."""
    failing_stats = {"max_abs_diff": 1.0, "mean_abs_diff": 0.5, "matched_rows": 2}
    mock_storage.validate.return_value = (False, failing_stats)

    with (
        patch("src.data.compute_factors.FactorEngine", return_value=mock_engine),
        patch("src.data.compute_factors.FactorStorage", return_value=mock_storage),
        patch("src.data.compute_factors.FactorManager", return_value=mock_manager),
    ):
        result = run(
            factor_names=["ma_5"],
            metadata="dummy.jsonl",
            start_date="20240101",
            end_date="20240131",
            force=False,
        )

    assert result["success"] == []
    assert len(result["failed"]) == 1
    failure = result["failed"][0]
    assert failure["name"] == "ma_5"
    assert failure["reason"] == "校验失败"
    mock_storage.save.assert_not_called()
|
||
|
||
|
||
def test_run_validate_fail_force(mock_manager, mock_engine, mock_storage):
    """With ``force=True``, a factor is written even when validation fails."""
    mock_storage.validate.return_value = (
        False,
        {"max_abs_diff": 1.0, "mean_abs_diff": 0.5, "matched_rows": 2},
    )
    run_kwargs = {
        "factor_names": ["ma_5"],
        "metadata": "dummy.jsonl",
        "start_date": "20240101",
        "end_date": "20240131",
        "force": True,
    }

    with (
        patch("src.data.compute_factors.FactorEngine", return_value=mock_engine),
        patch("src.data.compute_factors.FactorStorage", return_value=mock_storage),
        patch("src.data.compute_factors.FactorManager", return_value=mock_manager),
    ):
        result = run(**run_kwargs)

    assert result["success"] == ["ma_5"]
    assert len(result["failed"]) == 0
    mock_storage.save.assert_called_once()
|
||
|
||
|
||
def test_run_missing_metadata_entry(mock_manager, mock_engine, mock_storage):
    """A factor absent from the metadata is reported as failed and never saved.

    ``get_factors_by_name`` is forced to return an empty frame. The failure
    reason must contain "未找到" ("not found"), and storage must stay
    untouched.
    """
    # Drop the fixture's "always found" side_effect so return_value takes effect.
    mock_manager.get_factors_by_name.side_effect = None
    mock_manager.get_factors_by_name.return_value = pl.DataFrame(
        {"name": [], "dsl": []}
    )

    with (
        patch("src.data.compute_factors.FactorEngine", return_value=mock_engine),
        patch("src.data.compute_factors.FactorStorage", return_value=mock_storage),
        patch("src.data.compute_factors.FactorManager", return_value=mock_manager),
    ):
        result = run(
            factor_names=["unknown"],
            metadata="dummy.jsonl",
            start_date="20240101",
            end_date="20240131",
        )

    assert result["success"] == []
    assert len(result["failed"]) == 1
    assert "未找到" in result["failed"][0]["reason"]
    # A factor that could not be resolved must never reach storage.
    mock_storage.save.assert_not_called()
|
||
|
||
|
||
def test_run_engine_exception(mock_manager, mock_engine, mock_storage):
    """An exception from ``engine.compute`` marks the factor failed and skips saving.

    The exception message must be propagated into the failure reason.
    """
    mock_engine.compute.side_effect = ValueError("compute error")

    with (
        patch("src.data.compute_factors.FactorEngine", return_value=mock_engine),
        patch("src.data.compute_factors.FactorStorage", return_value=mock_storage),
        patch("src.data.compute_factors.FactorManager", return_value=mock_manager),
    ):
        result = run(
            factor_names=["ma_5"],
            metadata="dummy.jsonl",
            start_date="20240101",
            end_date="20240131",
        )

    assert result["success"] == []
    assert len(result["failed"]) == 1
    assert "compute error" in result["failed"][0]["reason"]
    # Nothing was computed, so nothing may be persisted.
    mock_storage.save.assert_not_called()
|
||
|
||
|
||
def test_run_missing_result_column(mock_manager, mock_engine, mock_storage):
    """A compute result lacking the requested factor column marks the factor failed.

    The engine is rigged to return ``wrong_col`` instead of ``ma_5``. The
    failure reason must contain "缺少列" ("missing column"), and nothing may
    be saved. Fixture parameters follow the same order as the sibling tests.
    """
    mock_engine.compute.side_effect = lambda cols, start, end: pl.DataFrame(
        {
            "trade_date": ["20240101"],
            "ts_code": ["000001.SZ"],
            "wrong_col": [1.0],
        }
    )

    with (
        patch("src.data.compute_factors.FactorEngine", return_value=mock_engine),
        patch("src.data.compute_factors.FactorStorage", return_value=mock_storage),
        patch("src.data.compute_factors.FactorManager", return_value=mock_manager),
    ):
        result = run(
            factor_names=["ma_5"],
            metadata="dummy.jsonl",
            start_date="20240101",
            end_date="20240131",
        )

    assert result["success"] == []
    assert len(result["failed"]) == 1
    assert "缺少列" in result["failed"][0]["reason"]
    # A result without the factor column must never be persisted.
    mock_storage.save.assert_not_called()
|