feat(factorminer): 添加 LocalFactorEvaluator 封装本地因子矩阵计算
- 新增 LocalFactorEvaluator 类,将 FactorEngine 的 Polars 长表输出转换为 (M, T) numpy 矩阵 - 支持批量因子计算、单因子计算及收益率矩阵计算 - 补充完整单元测试,覆盖 pivot、缺失值填充、股票代码过滤及异常处理场景
This commit is contained in:
159
src/factorminer/evaluation/local_engine.py
Normal file
159
src/factorminer/evaluation/local_engine.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""本地 FactorEngine 封装,用于将 Polars 长表输出转换为 (M, T) numpy 矩阵。
|
||||
|
||||
提供与 FactorMiner 信号计算接口兼容的评估器。
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
|
||||
from src.factors import FactorEngine
|
||||
|
||||
|
||||
class LocalFactorEvaluator:
|
||||
"""封装本地 FactorEngine,输出 (M, T) numpy 矩阵格式的因子信号。
|
||||
|
||||
Attributes:
|
||||
_start_date: 计算开始日期
|
||||
_end_date: 计算结束日期
|
||||
_stock_codes: 可选的股票代码列表
|
||||
_engine: 内部 FactorEngine 实例
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
stock_codes: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
"""初始化评估器。
|
||||
|
||||
Args:
|
||||
start_date: 计算开始日期,YYYYMMDD 格式
|
||||
end_date: 计算结束日期,YYYYMMDD 格式
|
||||
stock_codes: 可选的股票代码列表,None 表示全量
|
||||
"""
|
||||
self._start_date = start_date
|
||||
self._end_date = end_date
|
||||
self._stock_codes = stock_codes
|
||||
self._engine = FactorEngine()
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
specs: List[Tuple[str, str]],
|
||||
) -> Dict[str, np.ndarray]:
|
||||
"""批量计算并返回 {name: (M, T) 矩阵}。
|
||||
|
||||
Args:
|
||||
specs: (因子名, 本地 DSL 公式) 列表
|
||||
|
||||
Returns:
|
||||
每个因子对应的 (asset, time) numpy 矩阵,缺失值填充 np.nan
|
||||
"""
|
||||
if not specs:
|
||||
return {}
|
||||
|
||||
print(f"[local_engine] 开始批量计算 {len(specs)} 个因子...")
|
||||
for name, formula in specs:
|
||||
self._engine.add_factor(name, formula)
|
||||
|
||||
factor_names = [name for name, _ in specs]
|
||||
df = self._engine.compute(
|
||||
factor_names,
|
||||
self._start_date,
|
||||
self._end_date,
|
||||
self._stock_codes,
|
||||
)
|
||||
|
||||
result: Dict[str, np.ndarray] = {}
|
||||
for name in factor_names:
|
||||
result[name] = self._to_matrix(df, name)
|
||||
|
||||
self._engine.clear()
|
||||
print("[local_engine] 批量计算完成")
|
||||
return result
|
||||
|
||||
def evaluate_single(
|
||||
self,
|
||||
name: str,
|
||||
formula: str,
|
||||
) -> np.ndarray:
|
||||
"""计算单个因子。
|
||||
|
||||
Args:
|
||||
name: 因子名称
|
||||
formula: 本地 DSL 公式字符串
|
||||
|
||||
Returns:
|
||||
(M, T) numpy 矩阵
|
||||
"""
|
||||
self._engine.add_factor(name, formula)
|
||||
df = self._engine.compute(
|
||||
[name],
|
||||
self._start_date,
|
||||
self._end_date,
|
||||
self._stock_codes,
|
||||
)
|
||||
matrix = self._to_matrix(df, name)
|
||||
self._engine.clear()
|
||||
return matrix
|
||||
|
||||
def evaluate_returns(
|
||||
self,
|
||||
periods: int = 1,
|
||||
) -> np.ndarray:
|
||||
"""计算收益率矩阵,用于后续 IC / quintile 分析。
|
||||
|
||||
Args:
|
||||
periods: 收益周期,默认为 1
|
||||
|
||||
Returns:
|
||||
(M, T) 的收益率矩阵
|
||||
"""
|
||||
name = "__returns"
|
||||
formula = f"ts_pct_change(close, {periods})"
|
||||
self._engine.add_factor(name, formula)
|
||||
df = self._engine.compute(
|
||||
[name],
|
||||
self._start_date,
|
||||
self._end_date,
|
||||
self._stock_codes,
|
||||
)
|
||||
matrix = self._to_matrix(df, name)
|
||||
self._engine.clear()
|
||||
return matrix
|
||||
|
||||
def _to_matrix(
|
||||
self,
|
||||
df: pl.DataFrame,
|
||||
factor_name: str,
|
||||
) -> np.ndarray:
|
||||
"""将 Polars 长表 pivot 为 (M, T) numpy 矩阵。
|
||||
|
||||
Args:
|
||||
df: FactorEngine.compute 返回的长表
|
||||
factor_name: 要提取的因子列名
|
||||
|
||||
Returns:
|
||||
(asset, time) 的 numpy 矩阵,缺失值填充 np.nan
|
||||
"""
|
||||
if "ts_code" not in df.columns or "trade_date" not in df.columns:
|
||||
raise ValueError(
|
||||
f"DataFrame 缺少必需的 ts_code 或 trade_date 列,当前列: {df.columns}"
|
||||
)
|
||||
|
||||
if factor_name not in df.columns:
|
||||
raise ValueError(
|
||||
f"DataFrame 中未找到因子列 '{factor_name}',当前列: {df.columns}"
|
||||
)
|
||||
|
||||
pivot_df = df.pivot(
|
||||
index="ts_code",
|
||||
on="trade_date",
|
||||
values=factor_name,
|
||||
).sort("ts_code")
|
||||
|
||||
# 移除 index 列,保留数据并转为 float64 numpy 矩阵
|
||||
data = pivot_df.drop("ts_code").to_numpy()
|
||||
return data.astype(np.float64)
|
||||
155
tests/test_factorminer_local_engine.py
Normal file
155
tests/test_factorminer_local_engine.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""LocalFactorEvaluator 单元测试。
|
||||
|
||||
使用 MagicMock 模拟 FactorEngine,避免依赖真实数据库。
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
import pytest
|
||||
|
||||
from src.factorminer.evaluation.local_engine import LocalFactorEvaluator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_engine():
|
||||
"""提供 mock 的 FactorEngine 实例。"""
|
||||
with patch("src.factorminer.evaluation.local_engine.FactorEngine") as mock_cls:
|
||||
instance = MagicMock()
|
||||
mock_cls.return_value = instance
|
||||
yield instance
|
||||
|
||||
|
||||
def test_evaluate_single(mock_engine):
|
||||
"""测试单个因子计算并正确 pivot 为矩阵。"""
|
||||
mock_engine.compute.return_value = pl.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ", "000001.SZ", "000002.SZ", "000002.SZ"],
|
||||
"trade_date": ["20240101", "20240102", "20240101", "20240102"],
|
||||
"alpha": [1.0, 2.0, 3.0, 4.0],
|
||||
}
|
||||
)
|
||||
evaluator = LocalFactorEvaluator("20240101", "20240102")
|
||||
result = evaluator.evaluate_single("alpha", "cs_rank(close)")
|
||||
|
||||
mock_engine.add_factor.assert_called_once_with("alpha", "cs_rank(close)")
|
||||
mock_engine.compute.assert_called_once_with(["alpha"], "20240101", "20240102", None)
|
||||
mock_engine.clear.assert_called_once()
|
||||
|
||||
assert result.shape == (2, 2)
|
||||
np.testing.assert_array_equal(
|
||||
result,
|
||||
np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float64),
|
||||
)
|
||||
|
||||
|
||||
def test_evaluate_empty_specs(mock_engine):
|
||||
"""测试空 specs 直接返回空字典。"""
|
||||
evaluator = LocalFactorEvaluator("20240101", "20240102")
|
||||
result = evaluator.evaluate([])
|
||||
assert result == {}
|
||||
mock_engine.compute.assert_not_called()
|
||||
|
||||
|
||||
def test_evaluate_batch(mock_engine):
|
||||
"""测试批量计算多个因子。"""
|
||||
mock_engine.compute.return_value = pl.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ", "000001.SZ", "000002.SZ", "000002.SZ"],
|
||||
"trade_date": ["20240101", "20240102", "20240101", "20240102"],
|
||||
"alpha1": [1.0, 2.0, 3.0, 4.0],
|
||||
"alpha2": [5.0, 6.0, 7.0, 8.0],
|
||||
}
|
||||
)
|
||||
evaluator = LocalFactorEvaluator("20240101", "20240102")
|
||||
result = evaluator.evaluate(
|
||||
[
|
||||
("alpha1", "cs_rank(close)"),
|
||||
("alpha2", "cs_rank(vol)"),
|
||||
]
|
||||
)
|
||||
|
||||
assert set(result.keys()) == {"alpha1", "alpha2"}
|
||||
np.testing.assert_array_equal(
|
||||
result["alpha1"],
|
||||
np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float64),
|
||||
)
|
||||
np.testing.assert_array_equal(
|
||||
result["alpha2"],
|
||||
np.array([[5.0, 6.0], [7.0, 8.0]], dtype=np.float64),
|
||||
)
|
||||
mock_engine.clear.assert_called_once()
|
||||
|
||||
|
||||
def test_evaluate_returns(mock_engine):
|
||||
"""测试收益率矩阵计算。"""
|
||||
mock_engine.compute.return_value = pl.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ", "000001.SZ", "000002.SZ", "000002.SZ"],
|
||||
"trade_date": ["20240101", "20240102", "20240101", "20240102"],
|
||||
"__returns": [0.01, 0.02, -0.01, 0.03],
|
||||
}
|
||||
)
|
||||
evaluator = LocalFactorEvaluator("20240101", "20240102")
|
||||
result = evaluator.evaluate_returns(periods=5)
|
||||
|
||||
mock_engine.add_factor.assert_called_once_with(
|
||||
"__returns", "ts_pct_change(close, 5)"
|
||||
)
|
||||
assert result.shape == (2, 2)
|
||||
np.testing.assert_array_equal(
|
||||
result,
|
||||
np.array([[0.01, 0.02], [-0.01, 0.03]], dtype=np.float64),
|
||||
)
|
||||
mock_engine.clear.assert_called_once()
|
||||
|
||||
|
||||
def test_evaluate_with_nan(mock_engine):
|
||||
"""测试缺失值正确填充为 np.nan。"""
|
||||
mock_engine.compute.return_value = pl.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ", "000001.SZ", "000002.SZ"],
|
||||
"trade_date": ["20240101", "20240102", "20240101"],
|
||||
"alpha": [1.0, 2.0, 3.0],
|
||||
}
|
||||
)
|
||||
evaluator = LocalFactorEvaluator("20240101", "20240102")
|
||||
result = evaluator.evaluate_single("alpha", "cs_rank(close)")
|
||||
|
||||
assert result.shape == (2, 2)
|
||||
assert np.isnan(result[1, 1])
|
||||
assert result[0, 0] == 1.0
|
||||
assert result[0, 1] == 2.0
|
||||
assert result[1, 0] == 3.0
|
||||
|
||||
|
||||
def test_stock_codes_filter(mock_engine):
|
||||
"""测试传入股票代码列表时正确透传给 FactorEngine。"""
|
||||
mock_engine.compute.return_value = pl.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ", "000001.SZ"],
|
||||
"trade_date": ["20240101", "20240102"],
|
||||
"alpha": [1.0, 2.0],
|
||||
}
|
||||
)
|
||||
evaluator = LocalFactorEvaluator("20240101", "20240102", stock_codes=["000001.SZ"])
|
||||
result = evaluator.evaluate_single("alpha", "close")
|
||||
|
||||
mock_engine.compute.assert_called_once_with(
|
||||
["alpha"], "20240101", "20240102", ["000001.SZ"]
|
||||
)
|
||||
assert result.shape == (1, 2)
|
||||
|
||||
|
||||
def test_to_matrix_missing_factor_column(mock_engine):
|
||||
"""测试 DataFrame 缺少目标因子列时抛出 ValueError。"""
|
||||
mock_engine.compute.return_value = pl.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ"],
|
||||
"trade_date": ["20240101"],
|
||||
}
|
||||
)
|
||||
evaluator = LocalFactorEvaluator("20240101", "20240102")
|
||||
with pytest.raises(ValueError, match="未找到因子列 'alpha'"):
|
||||
evaluator.evaluate_single("alpha", "close")
|
||||
Reference in New Issue
Block a user