125 lines
3.8 KiB
Python
125 lines
3.8 KiB
Python
|
|
"""Tests for LocalFactorEvaluator."""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from typing import Dict, List, Tuple
|
||
|
|
|
||
|
|
import numpy as np
|
||
|
|
import pytest
|
||
|
|
|
||
|
|
from src.factorminer.evaluation.local_engine import LocalFactorEvaluator
|
||
|
|
|
||
|
|
|
||
|
|
class TestLocalFactorEvaluator:
    """Basic functional tests for LocalFactorEvaluator.

    Tests that depend on local market data call the evaluator through
    ``_call_or_skip``: if the data call raises, the test is skipped, but
    assertions on the returned values run outside that guard so a genuine
    failure is still reported as a failure rather than as a skip.
    """

    # Date window shared by every test (values taken from the original tests).
    START_DATE = "20200101"
    END_DATE = "20200131"

    @classmethod
    def _make_evaluator(cls, **kwargs: object) -> "LocalFactorEvaluator":
        """Build an evaluator over the shared date window.

        Extra keyword arguments are forwarded unchanged to
        ``LocalFactorEvaluator``.
        """
        return LocalFactorEvaluator(
            start_date=cls.START_DATE,
            end_date=cls.END_DATE,
            **kwargs,
        )

    @staticmethod
    def _call_or_skip(func, *args, **kwargs):
        """Return ``func(*args, **kwargs)``, skipping the test if it raises.

        Only the call itself is guarded. Callers must make their assertions
        on the returned value *after* this helper returns, so that an
        ``AssertionError`` is never converted into a skip.
        """
        try:
            return func(*args, **kwargs)
        except Exception as e:  # broad on purpose: loader error types are not visible here
            # NOTE(review): assumed to mean the local dataset is unavailable;
            # a genuine bug raised by the evaluator would also be skipped.
            pytest.skip(f"数据不存在: {e}")

    def test_init(self) -> None:
        """Check that the constructor stores its arguments and creates an engine."""
        evaluator = self._make_evaluator(stock_codes=None)
        assert evaluator.start_date == self.START_DATE
        assert evaluator.end_date == self.END_DATE
        assert evaluator.stock_codes is None
        assert evaluator.engine is not None

    def test_evaluate_empty_specs(self) -> None:
        """Check that evaluating an empty spec list returns an empty dict."""
        evaluator = self._make_evaluator()
        assert evaluator.evaluate([]) == {}

    def test_evaluate_returns_shape(self) -> None:
        """Check that evaluate_returns yields a numpy array."""
        evaluator = self._make_evaluator()
        returns = self._call_or_skip(evaluator.evaluate_returns, periods=1)
        # Verify a numpy array is returned.
        # NOTE(review): the test name suggests a shape check, but only the type
        # is asserted; confirm the expected shape and pin it here.
        assert isinstance(returns, np.ndarray)

    def test_evaluate_single_basic(self) -> None:
        """Check that evaluating the raw ``close`` field yields a 2-D ndarray."""
        evaluator = self._make_evaluator()
        result = self._call_or_skip(evaluator.evaluate_single, "close", "close")
        assert isinstance(result, np.ndarray)
        # The result should be a 2-D matrix.
        assert result.ndim == 2

    def test_evaluate_pct_change(self) -> None:
        """Check that a one-period return expression yields a 2-D ndarray."""
        evaluator = self._make_evaluator()
        result = self._call_or_skip(
            evaluator.evaluate_single,
            "pct_change",
            "close / ts_delay(close, 1) - 1",
        )
        assert isinstance(result, np.ndarray)
        assert result.ndim == 2

    def test_pivot_to_matrix_structure(self) -> None:
        """Check that _pivot_to_matrix maps the requested column to a 2-D ndarray."""
        pl = pytest.importorskip("polars")
        evaluator = self._make_evaluator()

        # Long-format test data: two stocks x two trade dates.
        df = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000001.SZ", "000002.SZ", "000002.SZ"],
                "trade_date": ["20200101", "20200102", "20200101", "20200102"],
                "factor1": [1.0, 2.0, 3.0, 4.0],
            }
        )

        result = evaluator._pivot_to_matrix(df, ["factor1"])

        assert "factor1" in result
        assert isinstance(result["factor1"], np.ndarray)
        assert result["factor1"].ndim == 2

    def test_batch_evaluate(self) -> None:
        """Check that evaluate() on several specs returns an entry for each one."""
        evaluator = self._make_evaluator()
        specs: List[Tuple[str, str]] = [
            ("close", "close"),
            ("open", "open"),
        ]

        result = self._call_or_skip(evaluator.evaluate, specs)

        assert isinstance(result, dict)
        assert "close" in result
        assert "open" in result
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Propagate pytest's exit status so that running this file as a script
    # exits non-zero when any test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))
|