feat: 引入 FactorMiner 开源量化因子挖掘项目
This commit is contained in:
194
tests/test_compute_factors.py
Normal file
194
tests/test_compute_factors.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""compute_factors 入口脚本的单元测试(使用 mock 隔离外部依赖)。"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import polars as pl
|
||||
import pytest
|
||||
|
||||
from src.data.compute_factors import run
|
||||
|
||||
|
||||
@pytest.fixture
def mock_manager():
    """Provide a MagicMock standing in for FactorManager.

    ``get_all_factors`` returns a fixed two-row frame (``ma_5`` and ``ma_20``).
    ``get_factors_by_name`` returns a one-row frame whose ``dsl`` value is
    derived from the requested name.
    """

    def _lookup(name):
        # One-row frame echoing the requested name with a synthetic DSL string.
        return pl.DataFrame({"name": [name], "dsl": [f"dsl_of_{name}"]})

    all_factors = pl.DataFrame(
        {
            "factor_id": ["F_001", "F_002"],
            "name": ["ma_5", "ma_20"],
            "dsl": ["ts_mean(close, 5)", "ts_mean(close, 20)"],
        }
    )

    fake_manager = MagicMock()
    fake_manager.get_all_factors.return_value = all_factors
    fake_manager.get_factors_by_name.side_effect = _lookup
    return fake_manager
|
||||
|
||||
|
||||
@pytest.fixture
def mock_engine():
    """Provide a MagicMock standing in for FactorEngine.

    ``compute`` returns a two-row frame with ``trade_date`` and ``ts_code``
    columns plus one value column named after the first requested column.
    """

    def _fake_compute(cols, start, end):
        # start/end are accepted to match the call shape but are not used.
        columns = {
            "trade_date": ["20240101", "20240101"],
            "ts_code": ["000001.SZ", "000002.SZ"],
        }
        columns[cols[0]] = [1.0, 2.0]
        return pl.DataFrame(columns)

    fake_engine = MagicMock()
    fake_engine.compute.side_effect = _fake_compute
    return fake_engine
|
||||
|
||||
|
||||
@pytest.fixture
def mock_storage():
    """Provide a MagicMock standing in for FactorStorage whose validation passes.

    ``validate`` returns ``(True, stats)`` with zero differences over two
    matched rows, and ``save`` returns ``None``.

    The fixture is a pure mock and touches no files, so it does not request
    ``tmp_path``.
    """
    storage = MagicMock()
    storage.validate.return_value = (
        True,
        {"max_abs_diff": 0.0, "mean_abs_diff": 0.0, "matched_rows": 2},
    )
    storage.save.return_value = None
    return storage
|
||||
|
||||
|
||||
def test_run_auto_discover_factors(mock_manager, mock_engine, mock_storage):
    """Check that an empty factor_names list makes run() use every metadata factor."""
    run_kwargs = {
        "factor_names": [],
        "metadata": "dummy.jsonl",
        "start_date": "20240101",
        "end_date": "20240131",
    }

    with (
        patch("src.data.compute_factors.FactorEngine", return_value=mock_engine),
        patch("src.data.compute_factors.FactorStorage", return_value=mock_storage),
        patch("src.data.compute_factors.FactorManager", return_value=mock_manager),
    ):
        outcome = run(**run_kwargs)

    assert outcome["success"] == ["ma_5", "ma_20"]
    assert len(outcome["failed"]) == 0
    # The full factor list must have been requested from the manager.
    mock_manager.get_all_factors.assert_called()
|
||||
|
||||
|
||||
def test_run_validate_fail_skip(mock_manager, mock_engine, mock_storage):
    """Check that a failed validation with force=False skips the write."""
    mismatch_stats = {"max_abs_diff": 1.0, "mean_abs_diff": 0.5, "matched_rows": 2}
    mock_storage.validate.return_value = (False, mismatch_stats)

    with (
        patch("src.data.compute_factors.FactorEngine", return_value=mock_engine),
        patch("src.data.compute_factors.FactorStorage", return_value=mock_storage),
        patch("src.data.compute_factors.FactorManager", return_value=mock_manager),
    ):
        outcome = run(
            factor_names=["ma_5"],
            metadata="dummy.jsonl",
            start_date="20240101",
            end_date="20240131",
            force=False,
        )

    assert outcome["success"] == []
    failures = outcome["failed"]
    assert len(failures) == 1
    first_failure = failures[0]
    assert first_failure["name"] == "ma_5"
    assert first_failure["reason"] == "校验失败"
    # Nothing may be persisted when validation fails without force.
    mock_storage.save.assert_not_called()
|
||||
|
||||
|
||||
def test_run_validate_fail_force(mock_manager, mock_engine, mock_storage):
    """Check that a failed validation with force=True still writes the factor."""
    mismatch_stats = {"max_abs_diff": 1.0, "mean_abs_diff": 0.5, "matched_rows": 2}
    mock_storage.validate.return_value = (False, mismatch_stats)

    with (
        patch("src.data.compute_factors.FactorEngine", return_value=mock_engine),
        patch("src.data.compute_factors.FactorStorage", return_value=mock_storage),
        patch("src.data.compute_factors.FactorManager", return_value=mock_manager),
    ):
        outcome = run(
            factor_names=["ma_5"],
            metadata="dummy.jsonl",
            start_date="20240101",
            end_date="20240131",
            force=True,
        )

    assert outcome["success"] == ["ma_5"]
    assert len(outcome["failed"]) == 0
    # force=True must lead to exactly one save call.
    mock_storage.save.assert_called_once()
|
||||
|
||||
|
||||
def test_run_missing_metadata_entry(mock_manager, mock_engine, mock_storage):
    """Check that a factor missing from the metadata is reported as failed."""
    lookup = mock_manager.get_factors_by_name
    # Clear the fixture's side_effect so the empty return_value takes effect.
    lookup.side_effect = None
    lookup.return_value = pl.DataFrame({"name": [], "dsl": []})

    with (
        patch("src.data.compute_factors.FactorEngine", return_value=mock_engine),
        patch("src.data.compute_factors.FactorStorage", return_value=mock_storage),
        patch("src.data.compute_factors.FactorManager", return_value=mock_manager),
    ):
        outcome = run(
            factor_names=["unknown"],
            metadata="dummy.jsonl",
            start_date="20240101",
            end_date="20240131",
        )

    assert outcome["success"] == []
    failures = outcome["failed"]
    assert len(failures) == 1
    assert "未找到" in failures[0]["reason"]
|
||||
|
||||
|
||||
def test_run_engine_exception(mock_manager, mock_engine, mock_storage):
    """Check that an exception raised by engine.compute is reported as a failure."""
    compute_failure = ValueError("compute error")
    mock_engine.compute.side_effect = compute_failure

    with (
        patch("src.data.compute_factors.FactorEngine", return_value=mock_engine),
        patch("src.data.compute_factors.FactorStorage", return_value=mock_storage),
        patch("src.data.compute_factors.FactorManager", return_value=mock_manager),
    ):
        outcome = run(
            factor_names=["ma_5"],
            metadata="dummy.jsonl",
            start_date="20240101",
            end_date="20240131",
        )

    assert outcome["success"] == []
    failures = outcome["failed"]
    assert len(failures) == 1
    # The exception message must be carried into the failure reason.
    assert "compute error" in failures[0]["reason"]
|
||||
|
||||
|
||||
def test_run_missing_result_column(mock_engine, mock_storage, mock_manager):
    """Check that a result frame lacking the factor column is reported as failed."""

    def _compute_wrong_column(cols, start, end):
        # Returns a frame with no column named after the requested factor.
        return pl.DataFrame(
            {
                "trade_date": ["20240101"],
                "ts_code": ["000001.SZ"],
                "wrong_col": [1.0],
            }
        )

    mock_engine.compute.side_effect = _compute_wrong_column

    with (
        patch("src.data.compute_factors.FactorEngine", return_value=mock_engine),
        patch("src.data.compute_factors.FactorStorage", return_value=mock_storage),
        patch("src.data.compute_factors.FactorManager", return_value=mock_manager),
    ):
        outcome = run(
            factor_names=["ma_5"],
            metadata="dummy.jsonl",
            start_date="20240101",
            end_date="20240131",
        )

    assert outcome["success"] == []
    failures = outcome["failed"]
    assert len(failures) == 1
    assert "缺少列" in failures[0]["reason"]
|
||||
127
tests/test_factor_storage.py
Normal file
127
tests/test_factor_storage.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""FactorStorage 单元测试。"""
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import polars as pl
|
||||
import pytest
|
||||
|
||||
from src.data.factor_storage import FactorStorage
|
||||
|
||||
|
||||
@pytest.fixture
def storage(tmp_path):
    """Provide a FactorStorage whose base_dir is ``tmp_path / "factor"``."""
    factor_dir = tmp_path / "factor"
    return FactorStorage(base_dir=factor_dir)
|
||||
|
||||
|
||||
@pytest.fixture
def sample_df():
    """Provide a three-row frame with trade_date, ts_code and test_factor columns."""
    columns = {
        "trade_date": ["20240101", "20240102", "20240103"],
        "ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
        "test_factor": [1.0, 2.0, 3.0],
    }
    return pl.DataFrame(columns)
|
||||
|
||||
|
||||
def test_exists_and_save_load(storage, sample_df):
    """Check exists() before and after save(), then what load() returns."""
    factor = "test_factor"

    assert not storage.exists(factor)
    storage.save(factor, sample_df)
    assert storage.exists(factor)

    roundtrip = storage.load(factor)
    assert roundtrip.shape == (3, 3)
    assert set(roundtrip.columns) == {"trade_date", "ts_code", "test_factor"}
    assert roundtrip[factor].to_list() == [1.0, 2.0, 3.0]
|
||||
|
||||
|
||||
def test_get_date_range(storage, sample_df):
    """Check get_date_range() for a saved factor and for an unknown one."""
    storage.save("test_factor", sample_df)

    expected_range = ("20240101", "20240103")
    assert storage.get_date_range("test_factor") == expected_range

    # A factor that was never saved has no date range.
    assert storage.get_date_range("missing") is None
|
||||
|
||||
|
||||
def test_load_with_date_filter(storage, sample_df):
    """Check that load() with start_date drops rows dated before it."""
    storage.save("test_factor", sample_df)

    filtered = storage.load("test_factor", start_date="20240102")

    assert filtered.shape == (2, 3)
    assert filtered["trade_date"].to_list() == ["20240102", "20240103"]
|
||||
|
||||
|
||||
def test_incremental_update(storage, sample_df):
    """Check that a second save() adds new rows and overwrites an existing one."""
    factor = "test_factor"

    # First write.
    storage.save(factor, sample_df)

    # Second write: one new date plus one date that is already stored.
    update = pl.DataFrame(
        {
            "trade_date": ["20240103", "20240104"],
            "ts_code": ["000003.SZ", "000004.SZ"],
            "test_factor": [30.0, 4.0],
        }
    )
    storage.save(factor, update)

    merged = storage.load(factor)
    merged = merged.sort(["trade_date", "ts_code"])
    assert merged.shape == (4, 3)
    assert merged[factor].to_list() == [1.0, 2.0, 30.0, 4.0]
|
||||
|
||||
|
||||
def test_validate_pass_when_no_local_file(storage, sample_df):
    """Check that validate() passes with empty stats when nothing was saved."""
    outcome = storage.validate("test_factor", sample_df)
    passed, stats = outcome

    assert passed is True
    assert stats == {}
|
||||
|
||||
|
||||
def test_validate_pass_with_identical_data(storage, sample_df):
    """Check that validating the exact frame that was saved passes with zero diffs."""
    factor = "test_factor"
    storage.save(factor, sample_df)

    passed, stats = storage.validate(factor, sample_df)

    assert passed is True
    assert stats["matched_rows"] == 3
    zero = pytest.approx(0.0)
    assert stats["max_abs_diff"] == zero
    assert stats["mean_abs_diff"] == zero
|
||||
|
||||
|
||||
def test_validate_fail_on_data_mismatch(storage, sample_df):
    """Check that validate() fails when a stored value differs beyond tolerance."""
    storage.save("test_factor", sample_df)

    # Add 1.0 to the factor value on rows dated 20240101; leave the rest as is.
    is_first_day = pl.col("trade_date") == "20240101"
    factor_value = pl.col("test_factor")
    bumped = (
        pl.when(is_first_day)
        .then(factor_value + 1.0)
        .otherwise(factor_value)
        .alias("test_factor")
    )
    modified = sample_df.with_columns(bumped)

    passed, stats = storage.validate("test_factor", modified, tolerance=1e-6)

    assert passed is False
    assert stats["matched_rows"] == 3
    assert stats["max_abs_diff"] == pytest.approx(1.0)
|
||||
|
||||
|
||||
def test_validate_pass_with_non_overlapping_data(storage, sample_df):
    """Check that validate() passes with empty stats when no saved date matches."""
    storage.save("test_factor", sample_df)

    # 20240105 is not among the dates that were saved above.
    disjoint_rows = {
        "trade_date": ["20240105"],
        "ts_code": ["000001.SZ"],
        "test_factor": [99.0],
    }
    non_overlap = pl.DataFrame(disjoint_rows)

    passed, stats = storage.validate("test_factor", non_overlap)

    assert passed is True
    assert stats == {}
|
||||
|
||||
|
||||
def test_save_preserves_column_order(storage):
    """Check that the HDF file written by save() keeps the input column order."""
    expected_order = ["trade_date", "ts_code", "my_factor"]
    frame = pl.DataFrame(
        {
            "trade_date": ["20240101"],
            "ts_code": ["000001.SZ"],
            "my_factor": [1.5],
        }
    )
    storage.save("my_factor", frame)

    # Read the file directly with pandas rather than through storage.load().
    hdf_path = storage._file_path("my_factor")
    pdf = pd.read_hdf(hdf_path, key=storage._HDF_KEY)
    assert list(pdf.columns) == expected_order  # type: ignore[attr-defined]
|
||||
Reference in New Issue
Block a user