feat(factors): 添加因子计算框架
- 新增因子基类 (BaseFactor, CrossSectionalFactor, TimeSeriesFactor) - 新增数据规格和上下文类 (DataSpec, FactorContext, FactorData) - 新增数据加载器 (DataLoader) 和执行引擎 (FactorEngine) - 新增组合因子支持 (CompositeFactor, ScalarFactor) - 添加因子模块完整测试用例 - 添加 Git 提交规范文档
This commit is contained in:
2
tests/factors/__init__.py
Normal file
2
tests/factors/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# tests/factors/__init__.py
|
||||
"""Factors 模块测试包"""
|
||||
143
tests/factors/factor_test_report.md
Normal file
143
tests/factors/factor_test_report.md
Normal file
@@ -0,0 +1,143 @@
|
||||
# 因子真实数据测试报告
|
||||
|
||||
## 1. 测试概述
|
||||
|
||||
本测试使用 `daily.h5` 文件中的真实A股市场数据,对 ProStock 因子框架进行验证。测试对比了因子框架计算结果与 Polars 原生计算结果,验证因子计算的正确性。
|
||||
|
||||
### 测试数据
|
||||
- **数据源**: `data/daily.h5`
|
||||
- **时间范围**: 2024-01-01 至 2024-04-30
|
||||
- **股票数量**: 20只
|
||||
- **数据量**: 1,560条记录
|
||||
|
||||
### 测试因子类型
|
||||
1. **时序因子 (TimeSeriesFactor)**: 移动平均线 (MA)
|
||||
2. **截面因子 (CrossSectionalFactor)**: PE排名 (PE_Rank)
|
||||
3. **结合因子 (CompositeFactor)**: 标量组合 (0.5 * MA) 和因子加法 (MA5 + MA10)
|
||||
|
||||
---
|
||||
|
||||
## 2. 测试结果
|
||||
|
||||
### 2.1 时序因子测试 - MA(5)
|
||||
|
||||
```
|
||||
[时序因子 MA(5) 对比]
|
||||
样本股票: 000001.SZ
|
||||
有效数据点: 77
|
||||
最大差异: 0.000000000000000
|
||||
样本数据 (前5个):
|
||||
Polars: 10.022000, Factor: 10.022000, Diff: 0.000000000000000
|
||||
Polars: 10.046000, Factor: 10.046000, Diff: 0.000000000000000
|
||||
Polars: 10.056000, Factor: 10.056000, Diff: 0.000000000000000
|
||||
Polars: 10.072000, Factor: 10.072000, Diff: 0.000000000000000
|
||||
Polars: 10.078000, Factor: 10.078000, Diff: 0.000000000000000
|
||||
```
|
||||
|
||||
**结论**: ✅ **通过** - 因子框架计算的 MA(5) 与 Polars 原生计算完全一致
|
||||
|
||||
---
|
||||
|
||||
### 2.2 截面因子测试 - PE_Rank
|
||||
|
||||
```
|
||||
[截面因子 PE_Rank 对比]
|
||||
样本日期: 20240131
|
||||
股票数量: 20
|
||||
最大差异: 0.000000000000000
|
||||
样本数据 (前5个):
|
||||
000001.SZ: Polars: 0.050000, Factor: 0.050000
|
||||
000002.SZ: Polars: 0.550000, Factor: 0.550000
|
||||
000004.SZ: Polars: 0.300000, Factor: 0.300000
|
||||
000005.SZ: Polars: 0.100000, Factor: 0.100000
|
||||
000006.SZ: Polars: 0.400000, Factor: 0.400000
|
||||
```
|
||||
|
||||
**结论**: ✅ **通过** - 因子框架计算的 PE_Rank 与 Polars 原生计算完全一致
|
||||
|
||||
---
|
||||
|
||||
### 2.3 结合因子测试 - 0.5 * MA(5)
|
||||
|
||||
```
|
||||
[结合因子 0.5*MA(5) 对比]
|
||||
公式: 0.5 * MA(5)
|
||||
有效数据点: 77
|
||||
最大差异: 0.000000000000000
|
||||
样本数据 (前5个):
|
||||
Polars: 5.011000, Factor: 5.011000, Diff: 0.000000000000000
|
||||
Polars: 5.023000, Factor: 5.023000, Diff: 0.000000000000000
|
||||
Polars: 5.028000, Factor: 5.028000, Diff: 0.000000000000000
|
||||
Polars: 5.036000, Factor: 5.036000, Diff: 0.000000000000000
|
||||
Polars: 5.039000, Factor: 5.039000, Diff: 0.000000000000000
|
||||
```
|
||||
|
||||
**结论**: ✅ **通过** - 标量组合因子计算正确
|
||||
|
||||
---
|
||||
|
||||
### 2.4 结合因子测试 - MA(5) + MA(10)
|
||||
|
||||
```
|
||||
[结合因子 MA(5) + MA(10) 对比]
|
||||
有效数据点: 72
|
||||
最大差异: 0.000000000000000
|
||||
```
|
||||
|
||||
**结论**: ✅ **通过** - 因子加法组合计算正确
|
||||
|
||||
---
|
||||
|
||||
## 3. 综合测试汇总
|
||||
|
||||
```
|
||||
============================================================
|
||||
因子测试汇总
|
||||
============================================================
|
||||
MA(5): 最大差异 = 0.00e+00 通过
|
||||
MA(10): 最大差异 = 0.00e+00 通过
|
||||
MA(20): 最大差异 = 0.00e+00 通过
|
||||
PE_Rank: 最大差异 = 0.00e+00 通过
|
||||
============================================================
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. 测试结论
|
||||
|
||||
### 4.1 全部通过 ✅
|
||||
|
||||
所有5个测试用例均通过验证:
|
||||
|
||||
| 测试项目 | 因子类型 | 最大差异 | 状态 |
|
||||
|---------|---------|---------|------|
|
||||
| MA(5) | 时序因子 | 0.00e+00 | ✅ 通过 |
|
||||
| MA(10) | 时序因子 | 0.00e+00 | ✅ 通过 |
|
||||
| MA(20) | 时序因子 | 0.00e+00 | ✅ 通过 |
|
||||
| PE_Rank | 截面因子 | 0.00e+00 | ✅ 通过 |
|
||||
| 0.5 * MA(5) | 结合因子 | 0.00e+00 | ✅ 通过 |
|
||||
| MA(5) + MA(10) | 结合因子 | 0.00e+00 | ✅ 通过 |
|
||||
|
||||
### 4.2 关键发现
|
||||
|
||||
1. **计算精度**: 因子框架与 Polars 原生计算结果的差异为 0,表明计算精度完全一致
|
||||
2. **时序因子**: `TimeSeriesFactor` 基类正确实现了股票级别的时序计算
|
||||
3. **截面因子**: `CrossSectionalFactor` 基类正确实现了日期级别的截面计算
|
||||
4. **组合因子**: `ScalarFactor` 和 `CompositeFactor` 正确实现了标量运算和因子组合
|
||||
|
||||
### 4.3 验证结论
|
||||
|
||||
ProStock 因子框架的计算逻辑与 Polars 原生计算完全一致,框架设计正确,可以用于实际量化投资研究。
|
||||
|
||||
---
|
||||
|
||||
## 5. 测试环境
|
||||
|
||||
- Python: 3.13.2
|
||||
- Polars: 最新版本
|
||||
- Pytest: 9.0.2
|
||||
- 数据: daily.h5 (8,856,081 条记录)
|
||||
|
||||
---
|
||||
|
||||
*报告生成时间: 2026-02-22*
|
||||
406
tests/factors/test_base.py
Normal file
406
tests/factors/test_base.py
Normal file
@@ -0,0 +1,406 @@
|
||||
"""测试因子基类 - BaseFactor、CrossSectionalFactor、TimeSeriesFactor
|
||||
|
||||
测试需求(来自 factor_implementation_plan.md):
|
||||
- BaseFactor:
|
||||
- 测试有效子类创建通过验证
|
||||
- 测试缺少 `name` 时抛出 ValueError
|
||||
- 测试 `name` 为空字符串时抛出 ValueError
|
||||
- 测试缺少 `factor_type` 时抛出 ValueError
|
||||
- 测试无效的 `factor_type`(非 cs/ts)时抛出 ValueError
|
||||
- 测试缺少 `data_specs` 时抛出 ValueError
|
||||
- 测试 `data_specs` 为空列表时抛出 ValueError
|
||||
- 测试 `compute()` 抽象方法强制子类实现
|
||||
- 测试参数化初始化 `params` 正确存储
|
||||
- 测试 `_validate_params()` 被调用
|
||||
|
||||
- CrossSectionalFactor:
|
||||
- 测试 `factor_type` 自动设置为 "cross_sectional"
|
||||
- 测试子类必须实现 `compute()`
|
||||
- 测试 `compute()` 返回类型为 pl.Series
|
||||
|
||||
- TimeSeriesFactor:
|
||||
- 测试 `factor_type` 自动设置为 "time_series"
|
||||
- 测试子类必须实现 `compute()`
|
||||
- 测试 `compute()` 返回类型为 pl.Series
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import polars as pl
|
||||
|
||||
from src.factors import DataSpec, FactorContext, FactorData
|
||||
from src.factors.base import BaseFactor, CrossSectionalFactor, TimeSeriesFactor
|
||||
|
||||
|
||||
# ========== 测试数据准备 ==========
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_dataspec():
|
||||
"""创建一个示例 DataSpec"""
|
||||
return DataSpec(
|
||||
source="daily", columns=["ts_code", "trade_date", "close"], lookback_days=5
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_factor_data():
|
||||
"""创建一个示例 FactorData"""
|
||||
df = pl.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ", "000002.SZ", "000001.SZ", "000002.SZ"],
|
||||
"trade_date": ["20240101", "20240101", "20240102", "20240102"],
|
||||
"close": [10.0, 20.0, 11.0, 21.0],
|
||||
}
|
||||
)
|
||||
context = FactorContext(current_date="20240102")
|
||||
return FactorData(df, context)
|
||||
|
||||
|
||||
# ========== BaseFactor 测试 ==========
|
||||
|
||||
|
||||
class TestBaseFactorValidation:
|
||||
"""测试 BaseFactor 子类验证"""
|
||||
|
||||
def test_valid_cross_sectional_subclass(self, sample_dataspec):
|
||||
"""测试有效的截面因子子类创建通过验证"""
|
||||
|
||||
class ValidFactor(CrossSectionalFactor):
|
||||
name = "valid_cs"
|
||||
data_specs = [sample_dataspec]
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([1.0])
|
||||
|
||||
# 应该能成功创建实例
|
||||
factor = ValidFactor()
|
||||
assert factor.name == "valid_cs"
|
||||
assert factor.factor_type == "cross_sectional"
|
||||
|
||||
def test_valid_time_series_subclass(self, sample_dataspec):
|
||||
"""测试有效的时序因子子类创建通过验证"""
|
||||
|
||||
class ValidFactor(TimeSeriesFactor):
|
||||
name = "valid_ts"
|
||||
data_specs = [sample_dataspec]
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([1.0])
|
||||
|
||||
factor = ValidFactor()
|
||||
assert factor.name == "valid_ts"
|
||||
assert factor.factor_type == "time_series"
|
||||
|
||||
def test_missing_name(self, sample_dataspec):
|
||||
"""测试缺少 name 时抛出 ValueError"""
|
||||
with pytest.raises(ValueError, match="must define 'name'"):
|
||||
|
||||
class BadFactor(CrossSectionalFactor):
|
||||
# name = "" # 故意不定义
|
||||
data_specs = [sample_dataspec]
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([1.0])
|
||||
|
||||
def test_empty_name(self, sample_dataspec):
|
||||
"""测试 name 为空字符串时抛出 ValueError"""
|
||||
with pytest.raises(ValueError, match="must define 'name'"):
|
||||
|
||||
class BadFactor(CrossSectionalFactor):
|
||||
name = "" # 空字符串
|
||||
data_specs = [sample_dataspec]
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([1.0])
|
||||
|
||||
def test_missing_factor_type(self, sample_dataspec):
|
||||
"""测试缺少 factor_type 时抛出 ValueError"""
|
||||
with pytest.raises(ValueError, match="must define 'factor_type'"):
|
||||
|
||||
class BadFactor(BaseFactor):
|
||||
name = "bad_factor"
|
||||
# factor_type = "" # 故意不定义
|
||||
data_specs = [sample_dataspec]
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([1.0])
|
||||
|
||||
def test_invalid_factor_type(self, sample_dataspec):
|
||||
"""测试无效的 factor_type(非 cs/ts)时抛出 ValueError"""
|
||||
with pytest.raises(
|
||||
ValueError, match="must be 'cross_sectional' or 'time_series'"
|
||||
):
|
||||
|
||||
class BadFactor(BaseFactor):
|
||||
name = "bad_factor"
|
||||
factor_type = "invalid_type"
|
||||
data_specs = [sample_dataspec]
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([1.0])
|
||||
|
||||
def test_missing_data_specs(self):
|
||||
"""测试缺少 data_specs 时抛出 ValueError"""
|
||||
with pytest.raises(ValueError, match="must define 'data_specs'"):
|
||||
|
||||
class BadFactor(BaseFactor):
|
||||
name = "bad_factor"
|
||||
factor_type = "cross_sectional"
|
||||
# data_specs = [] # 故意不定义
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([1.0])
|
||||
|
||||
def test_empty_data_specs(self):
|
||||
"""测试 data_specs 为空列表时抛出 ValueError"""
|
||||
with pytest.raises(ValueError, match="cannot be empty"):
|
||||
|
||||
class BadFactor(CrossSectionalFactor):
|
||||
name = "bad_factor"
|
||||
data_specs = [] # 空列表
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([1.0])
|
||||
|
||||
|
||||
class TestBaseFactorCompute:
|
||||
"""测试 compute() 抽象方法"""
|
||||
|
||||
def test_compute_must_be_implemented_cs(self, sample_dataspec):
|
||||
"""测试截面因子子类必须实现 compute()"""
|
||||
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
|
||||
|
||||
class BadFactor(CrossSectionalFactor):
|
||||
name = "bad_cs"
|
||||
data_specs = [sample_dataspec]
|
||||
# 不实现 compute()
|
||||
|
||||
BadFactor()
|
||||
|
||||
def test_compute_must_be_implemented_ts(self, sample_dataspec):
|
||||
"""测试时序因子子类必须实现 compute()"""
|
||||
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
|
||||
|
||||
class BadFactor(TimeSeriesFactor):
|
||||
name = "bad_ts"
|
||||
data_specs = [sample_dataspec]
|
||||
# 不实现 compute()
|
||||
|
||||
BadFactor()
|
||||
|
||||
|
||||
class TestBaseFactorParams:
|
||||
"""测试参数化初始化"""
|
||||
|
||||
def test_params_stored_correctly(self, sample_dataspec):
|
||||
"""测试参数化初始化 params 正确存储"""
|
||||
|
||||
class ParamFactor(CrossSectionalFactor):
|
||||
name = "param_factor"
|
||||
data_specs = [sample_dataspec]
|
||||
|
||||
def __init__(self, period: int = 20, weight: float = 1.0):
|
||||
super().__init__(period=period, weight=weight)
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([1.0])
|
||||
|
||||
factor = ParamFactor(period=10, weight=0.5)
|
||||
assert factor.params["period"] == 10
|
||||
assert factor.params["weight"] == 0.5
|
||||
|
||||
def test_validate_params_called(self, sample_dataspec):
|
||||
"""测试 _validate_params() 被调用"""
|
||||
validated = []
|
||||
|
||||
class ValidatedFactor(CrossSectionalFactor):
|
||||
name = "validated_factor"
|
||||
data_specs = [sample_dataspec]
|
||||
|
||||
def __init__(self, period: int = 20):
|
||||
super().__init__(period=period)
|
||||
|
||||
def _validate_params(self):
|
||||
validated.append(True)
|
||||
if self.params.get("period", 0) <= 0:
|
||||
raise ValueError("period must be positive")
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([1.0])
|
||||
|
||||
# 创建实例时应该调用 _validate_params
|
||||
factor = ValidatedFactor(period=10)
|
||||
assert len(validated) == 1
|
||||
assert factor.params["period"] == 10
|
||||
|
||||
def test_validate_params_raises(self, sample_dataspec):
|
||||
"""测试 _validate_params() 可以抛出异常"""
|
||||
|
||||
class BadParamFactor(CrossSectionalFactor):
|
||||
name = "bad_param_factor"
|
||||
data_specs = [sample_dataspec]
|
||||
|
||||
def __init__(self, period: int = 20):
|
||||
super().__init__(period=period)
|
||||
|
||||
def _validate_params(self):
|
||||
if self.params.get("period", 0) <= 0:
|
||||
raise ValueError("period must be positive")
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([1.0])
|
||||
|
||||
with pytest.raises(ValueError, match="period must be positive"):
|
||||
BadParamFactor(period=-5)
|
||||
|
||||
|
||||
class TestBaseFactorRepr:
|
||||
"""测试 __repr__"""
|
||||
|
||||
def test_repr(self, sample_dataspec):
|
||||
"""测试 __repr__ 返回正确格式"""
|
||||
|
||||
class TestFactor(CrossSectionalFactor):
|
||||
name = "test_factor"
|
||||
data_specs = [sample_dataspec]
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([1.0])
|
||||
|
||||
factor = TestFactor()
|
||||
repr_str = repr(factor)
|
||||
assert "TestFactor" in repr_str
|
||||
assert "test_factor" in repr_str
|
||||
assert "cross_sectional" in repr_str
|
||||
|
||||
|
||||
# ========== CrossSectionalFactor 测试 ==========
|
||||
|
||||
|
||||
class TestCrossSectionalFactor:
|
||||
"""测试 CrossSectionalFactor"""
|
||||
|
||||
def test_factor_type_auto_set(self, sample_dataspec):
|
||||
"""测试 factor_type 自动设置为 'cross_sectional'"""
|
||||
|
||||
class CSFactor(CrossSectionalFactor):
|
||||
name = "cs_factor"
|
||||
data_specs = [sample_dataspec]
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([1.0])
|
||||
|
||||
factor = CSFactor()
|
||||
assert factor.factor_type == "cross_sectional"
|
||||
|
||||
def test_compute_returns_series(self, sample_factor_data, sample_dataspec):
|
||||
"""测试 compute() 返回类型为 pl.Series"""
|
||||
|
||||
class CSFactor(CrossSectionalFactor):
|
||||
name = "cs_factor"
|
||||
data_specs = [sample_dataspec]
|
||||
|
||||
def compute(self, data):
|
||||
# 返回一个简单的 Series
|
||||
return pl.Series([1.0, 2.0])
|
||||
|
||||
factor = CSFactor()
|
||||
result = factor.compute(sample_factor_data)
|
||||
assert isinstance(result, pl.Series)
|
||||
|
||||
def test_compute_with_cross_section(self, sample_dataspec):
|
||||
"""测试 compute() 使用 get_cross_section()"""
|
||||
df = pl.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ", "000002.SZ"],
|
||||
"trade_date": ["20240101", "20240101"],
|
||||
"close": [10.0, 20.0],
|
||||
}
|
||||
)
|
||||
context = FactorContext(current_date="20240101")
|
||||
data = FactorData(df, context)
|
||||
|
||||
class RankFactor(CrossSectionalFactor):
|
||||
name = "rank_factor"
|
||||
data_specs = [sample_dataspec]
|
||||
|
||||
def compute(self, data):
|
||||
cs = data.get_cross_section()
|
||||
return cs["close"].rank()
|
||||
|
||||
factor = RankFactor()
|
||||
result = factor.compute(data)
|
||||
assert isinstance(result, pl.Series)
|
||||
assert len(result) == 2
|
||||
|
||||
|
||||
# ========== TimeSeriesFactor 测试 ==========
|
||||
|
||||
|
||||
class TestTimeSeriesFactor:
|
||||
"""测试 TimeSeriesFactor"""
|
||||
|
||||
def test_factor_type_auto_set(self, sample_dataspec):
|
||||
"""测试 factor_type 自动设置为 'time_series'"""
|
||||
|
||||
class TSFactor(TimeSeriesFactor):
|
||||
name = "ts_factor"
|
||||
data_specs = [sample_dataspec]
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([1.0])
|
||||
|
||||
factor = TSFactor()
|
||||
assert factor.factor_type == "time_series"
|
||||
|
||||
def test_compute_returns_series(self, sample_factor_data, sample_dataspec):
|
||||
"""测试 compute() 返回类型为 pl.Series"""
|
||||
|
||||
class TSFactor(TimeSeriesFactor):
|
||||
name = "ts_factor"
|
||||
data_specs = [sample_dataspec]
|
||||
|
||||
def compute(self, data):
|
||||
return data.get_column("close") * 2
|
||||
|
||||
factor = TSFactor()
|
||||
result = factor.compute(sample_factor_data)
|
||||
assert isinstance(result, pl.Series)
|
||||
assert len(result) == 4 # 4行数据
|
||||
|
||||
def test_compute_with_rolling(self, sample_dataspec):
|
||||
"""测试 compute() 使用 rolling 操作"""
|
||||
df = pl.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ"] * 5,
|
||||
"trade_date": [
|
||||
"20240101",
|
||||
"20240102",
|
||||
"20240103",
|
||||
"20240104",
|
||||
"20240105",
|
||||
],
|
||||
"close": [10.0, 11.0, 12.0, 13.0, 14.0],
|
||||
}
|
||||
)
|
||||
context = FactorContext(current_stock="000001.SZ")
|
||||
data = FactorData(df, context)
|
||||
|
||||
class MAFactor(TimeSeriesFactor):
|
||||
name = "ma_factor"
|
||||
data_specs = [sample_dataspec]
|
||||
|
||||
def __init__(self, period: int = 3):
|
||||
super().__init__(period=period)
|
||||
|
||||
def compute(self, data):
|
||||
return data.get_column("close").rolling_mean(self.params["period"])
|
||||
|
||||
factor = MAFactor(period=3)
|
||||
result = factor.compute(data)
|
||||
assert isinstance(result, pl.Series)
|
||||
assert len(result) == 5
|
||||
# 前2个应该是 null(因为 period=3)
|
||||
assert result[0] is None
|
||||
assert result[1] is None
|
||||
assert result[2] is not None
|
||||
417
tests/factors/test_composite.py
Normal file
417
tests/factors/test_composite.py
Normal file
@@ -0,0 +1,417 @@
|
||||
"""测试组合因子 - CompositeFactor、ScalarFactor
|
||||
|
||||
测试需求(来自 factor_implementation_plan.md):
|
||||
- CompositeFactor:
|
||||
- 测试同类型因子组合成功(cs + cs)
|
||||
- 测试同类型因子组合成功(ts + ts)
|
||||
- 测试不同类型因子组合抛出 ValueError(cs + ts)
|
||||
- 测试无效运算符抛出 ValueError
|
||||
- 测试 `_merge_data_specs()` 正确合并(相同 source)
|
||||
- 测试 `_merge_data_specs()` 正确合并(不同 source)
|
||||
- 测试 `_merge_data_specs()` lookback 取最大值
|
||||
- 测试 `compute()` 执行正确的数学运算
|
||||
|
||||
- ScalarFactor:
|
||||
- 测试标量乘法 `0.5 * factor`
|
||||
- 测试标量乘法 `factor * 0.5`
|
||||
- 测试标量加法(如支持)
|
||||
- 测试继承基础因子的 data_specs
|
||||
- 测试 `compute()` 返回正确缩放后的值
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import polars as pl
|
||||
|
||||
from src.factors import DataSpec, FactorContext, FactorData
|
||||
from src.factors.base import CrossSectionalFactor, TimeSeriesFactor
|
||||
from src.factors.composite import CompositeFactor, ScalarFactor
|
||||
|
||||
|
||||
# ========== 测试数据准备 ==========
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_dataspec():
|
||||
"""创建一个示例 DataSpec"""
|
||||
return DataSpec(
|
||||
source="daily", columns=["ts_code", "trade_date", "close"], lookback_days=5
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cs_factor1(sample_dataspec):
|
||||
"""截面因子 1"""
|
||||
|
||||
class CSFactor1(CrossSectionalFactor):
|
||||
name = "cs_factor1"
|
||||
data_specs = [sample_dataspec]
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([1.0, 2.0])
|
||||
|
||||
return CSFactor1()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cs_factor2(sample_dataspec):
|
||||
"""截面因子 2"""
|
||||
|
||||
class CSFactor2(CrossSectionalFactor):
|
||||
name = "cs_factor2"
|
||||
data_specs = [sample_dataspec]
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([3.0, 4.0])
|
||||
|
||||
return CSFactor2()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ts_factor1(sample_dataspec):
|
||||
"""时序因子 1"""
|
||||
|
||||
class TSFactor1(TimeSeriesFactor):
|
||||
name = "ts_factor1"
|
||||
data_specs = [sample_dataspec]
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([10.0, 20.0, 30.0])
|
||||
|
||||
return TSFactor1()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ts_factor2(sample_dataspec):
|
||||
"""时序因子 2"""
|
||||
|
||||
class TSFactor2(TimeSeriesFactor):
|
||||
name = "ts_factor2"
|
||||
data_specs = [sample_dataspec]
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([1.0, 2.0, 3.0])
|
||||
|
||||
return TSFactor2()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_factor_data():
|
||||
"""创建一个示例 FactorData"""
|
||||
df = pl.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ", "000002.SZ", "000001.SZ", "000002.SZ"],
|
||||
"trade_date": ["20240101", "20240101", "20240102", "20240102"],
|
||||
"close": [10.0, 20.0, 11.0, 21.0],
|
||||
}
|
||||
)
|
||||
context = FactorContext(current_date="20240102")
|
||||
return FactorData(df, context)
|
||||
|
||||
|
||||
# ========== CompositeFactor 测试 ==========
|
||||
|
||||
|
||||
class TestCompositeFactorTypeValidation:
|
||||
"""测试类型验证"""
|
||||
|
||||
def test_same_type_combination_cs(self, cs_factor1, cs_factor2):
|
||||
"""测试同类型截面因子组合成功"""
|
||||
combined = cs_factor1 + cs_factor2
|
||||
assert isinstance(combined, CompositeFactor)
|
||||
assert combined.factor_type == "cross_sectional"
|
||||
assert combined.name == "(cs_factor1_+_cs_factor2)"
|
||||
|
||||
def test_same_type_combination_ts(self, ts_factor1, ts_factor2):
|
||||
"""测试同类型时序因子组合成功"""
|
||||
combined = ts_factor1 - ts_factor2
|
||||
assert isinstance(combined, CompositeFactor)
|
||||
assert combined.factor_type == "time_series"
|
||||
assert combined.name == "(ts_factor1_-_ts_factor2)"
|
||||
|
||||
def test_different_type_raises(self, cs_factor1, ts_factor1):
|
||||
"""测试不同类型因子组合抛出 ValueError"""
|
||||
with pytest.raises(
|
||||
ValueError, match="Cannot combine factors of different types"
|
||||
):
|
||||
cs_factor1 + ts_factor1
|
||||
|
||||
def test_invalid_operator_raises(self, cs_factor1, cs_factor2):
|
||||
"""测试无效运算符抛出 ValueError"""
|
||||
with pytest.raises(ValueError, match="Unsupported operator"):
|
||||
CompositeFactor(cs_factor1, cs_factor2, "%")
|
||||
|
||||
|
||||
class TestCompositeFactorMergeDataSpecs:
|
||||
"""测试 _merge_data_specs"""
|
||||
|
||||
def test_merge_same_source_same_columns(self):
|
||||
"""测试相同 source 和 columns 的 DataSpec 合并"""
|
||||
spec1 = DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=5)
|
||||
spec2 = DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=10)
|
||||
|
||||
class Factor1(CrossSectionalFactor):
|
||||
name = "f1"
|
||||
data_specs = [spec1]
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([1.0])
|
||||
|
||||
class Factor2(CrossSectionalFactor):
|
||||
name = "f2"
|
||||
data_specs = [spec2]
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([2.0])
|
||||
|
||||
combined = Factor1() + Factor2()
|
||||
|
||||
# 应该合并成一个 DataSpec,lookback_days 取最大值 10
|
||||
assert len(combined.data_specs) == 1
|
||||
assert combined.data_specs[0].lookback_days == 10
|
||||
assert combined.data_specs[0].source == "daily"
|
||||
|
||||
def test_merge_different_source(self):
|
||||
"""测试不同 source 的 DataSpec 不合并"""
|
||||
spec1 = DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=5)
|
||||
spec2 = DataSpec(
|
||||
"fundamental", ["ts_code", "trade_date", "pe"], lookback_days=1
|
||||
)
|
||||
|
||||
class Factor1(CrossSectionalFactor):
|
||||
name = "f1"
|
||||
data_specs = [spec1]
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([1.0])
|
||||
|
||||
class Factor2(CrossSectionalFactor):
|
||||
name = "f2"
|
||||
data_specs = [spec2]
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([2.0])
|
||||
|
||||
combined = Factor1() + Factor2()
|
||||
|
||||
# 应该有两个 DataSpec
|
||||
assert len(combined.data_specs) == 2
|
||||
sources = {s.source for s in combined.data_specs}
|
||||
assert sources == {"daily", "fundamental"}
|
||||
|
||||
def test_merge_same_source_different_columns(self):
|
||||
"""测试相同 source 但不同 columns 的 DataSpec 不合并"""
|
||||
spec1 = DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=5)
|
||||
spec2 = DataSpec("daily", ["ts_code", "trade_date", "open"], lookback_days=3)
|
||||
|
||||
class Factor1(CrossSectionalFactor):
|
||||
name = "f1"
|
||||
data_specs = [spec1]
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([1.0])
|
||||
|
||||
class Factor2(CrossSectionalFactor):
|
||||
name = "f2"
|
||||
data_specs = [spec2]
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([2.0])
|
||||
|
||||
combined = Factor1() + Factor2()
|
||||
|
||||
# 应该有两个 DataSpec(因为 columns 不同)
|
||||
assert len(combined.data_specs) == 2
|
||||
|
||||
def test_merge_lookback_max(self):
|
||||
"""测试 lookback_days 取最大值"""
|
||||
spec1 = DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=5)
|
||||
spec2 = DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=20)
|
||||
spec3 = DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=10)
|
||||
|
||||
class Factor1(CrossSectionalFactor):
|
||||
name = "f1"
|
||||
data_specs = [spec1]
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([1.0])
|
||||
|
||||
class Factor2(CrossSectionalFactor):
|
||||
name = "f2"
|
||||
data_specs = [spec2, spec3]
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([2.0])
|
||||
|
||||
combined = Factor1() + Factor2()
|
||||
|
||||
# 应该合并成一个 DataSpec,lookback_days 取最大值 20
|
||||
assert len(combined.data_specs) == 1
|
||||
assert combined.data_specs[0].lookback_days == 20
|
||||
|
||||
|
||||
class TestCompositeFactorCompute:
|
||||
"""测试 compute 运算"""
|
||||
|
||||
def test_compute_addition(self, cs_factor1, cs_factor2, sample_factor_data):
|
||||
"""测试加法运算"""
|
||||
combined = cs_factor1 + cs_factor2
|
||||
result = combined.compute(sample_factor_data)
|
||||
|
||||
# [1.0, 2.0] + [3.0, 4.0] = [4.0, 6.0]
|
||||
expected = pl.Series([4.0, 6.0])
|
||||
assert (result - expected).abs().max() < 1e-10
|
||||
|
||||
def test_compute_subtraction(self, cs_factor1, cs_factor2, sample_factor_data):
|
||||
"""测试减法运算"""
|
||||
combined = cs_factor1 - cs_factor2
|
||||
result = combined.compute(sample_factor_data)
|
||||
|
||||
# [1.0, 2.0] - [3.0, 4.0] = [-2.0, -2.0]
|
||||
expected = pl.Series([-2.0, -2.0])
|
||||
assert (result - expected).abs().max() < 1e-10
|
||||
|
||||
def test_compute_multiplication(self, cs_factor1, cs_factor2, sample_factor_data):
|
||||
"""测试乘法运算"""
|
||||
combined = cs_factor1 * cs_factor2
|
||||
result = combined.compute(sample_factor_data)
|
||||
|
||||
# [1.0, 2.0] * [3.0, 4.0] = [3.0, 8.0]
|
||||
expected = pl.Series([3.0, 8.0])
|
||||
assert (result - expected).abs().max() < 1e-10
|
||||
|
||||
def test_compute_division(self, cs_factor1, cs_factor2, sample_factor_data):
|
||||
"""测试除法运算"""
|
||||
combined = cs_factor1 / cs_factor2
|
||||
result = combined.compute(sample_factor_data)
|
||||
|
||||
# [1.0, 2.0] / [3.0, 4.0] = [0.333..., 0.5]
|
||||
expected = pl.Series([1.0 / 3.0, 0.5])
|
||||
assert (result - expected).abs().max() < 1e-10
|
||||
|
||||
def test_compute_with_ts_factors(self, ts_factor1, ts_factor2, sample_factor_data):
|
||||
"""测试时序因子的组合运算"""
|
||||
combined = ts_factor1 + ts_factor2
|
||||
result = combined.compute(sample_factor_data)
|
||||
|
||||
# [10.0, 20.0, 30.0] + [1.0, 2.0, 3.0] = [11.0, 22.0, 33.0]
|
||||
expected = pl.Series([11.0, 22.0, 33.0])
|
||||
assert (result - expected).abs().max() < 1e-10
|
||||
|
||||
def test_chained_combination(self, cs_factor1, cs_factor2, sample_factor_data):
|
||||
"""测试链式组合 (f1 + f2) * f1"""
|
||||
|
||||
# 创建第三个因子
|
||||
class CSFactor3(CrossSectionalFactor):
|
||||
name = "cs_factor3"
|
||||
data_specs = [
|
||||
DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=5)
|
||||
]
|
||||
|
||||
def compute(self, data):
|
||||
return pl.Series([0.5, 1.0])
|
||||
|
||||
f3 = CSFactor3()
|
||||
|
||||
# (f1 + f2) * f3
|
||||
# f1 + f2 = [1.0, 2.0] + [3.0, 4.0] = [4.0, 6.0]
|
||||
# [4.0, 6.0] * [0.5, 1.0] = [2.0, 6.0]
|
||||
combined = (cs_factor1 + cs_factor2) * f3
|
||||
result = combined.compute(sample_factor_data)
|
||||
|
||||
expected = pl.Series([2.0, 6.0])
|
||||
assert (result - expected).abs().max() < 1e-10
|
||||
|
||||
|
||||
# ========== ScalarFactor 测试 ==========
|
||||
|
||||
|
||||
class TestScalarFactor:
|
||||
"""测试 ScalarFactor"""
|
||||
|
||||
def test_scalar_multiplication_left(self, cs_factor1):
|
||||
"""测试标量乘法 `0.5 * factor`(左乘)"""
|
||||
scaled = 0.5 * cs_factor1
|
||||
assert isinstance(scaled, ScalarFactor)
|
||||
assert scaled.scalar == 0.5
|
||||
assert scaled.op == "*"
|
||||
assert scaled.factor == cs_factor1
|
||||
|
||||
def test_scalar_multiplication_right(self, cs_factor1):
|
||||
"""测试标量乘法 `factor * 0.5`(右乘)"""
|
||||
scaled = cs_factor1 * 0.5
|
||||
assert isinstance(scaled, ScalarFactor)
|
||||
assert scaled.scalar == 0.5
|
||||
assert scaled.op == "*"
|
||||
|
||||
def test_scalar_integer_multiplication(self, cs_factor1):
|
||||
"""测试整数标量乘法"""
|
||||
scaled = 2 * cs_factor1
|
||||
assert isinstance(scaled, ScalarFactor)
|
||||
assert scaled.scalar == 2.0
|
||||
|
||||
def test_inherits_data_specs(self, cs_factor1):
|
||||
"""测试继承基础因子的 data_specs"""
|
||||
scaled = 0.5 * cs_factor1
|
||||
assert scaled.data_specs == cs_factor1.data_specs
|
||||
|
||||
def test_compute_multiplication(self, cs_factor1, sample_factor_data):
|
||||
"""测试标量乘法 compute 结果"""
|
||||
scaled = 0.5 * cs_factor1
|
||||
result = scaled.compute(sample_factor_data)
|
||||
|
||||
# [1.0, 2.0] * 0.5 = [0.5, 1.0]
|
||||
expected = pl.Series([0.5, 1.0])
|
||||
assert (result - expected).abs().max() < 1e-10
|
||||
|
||||
def test_compute_with_ts_factor(self, ts_factor1, sample_factor_data):
|
||||
"""测试时序因子的标量乘法"""
|
||||
scaled = 0.1 * ts_factor1
|
||||
result = scaled.compute(sample_factor_data)
|
||||
|
||||
# [10.0, 20.0, 30.0] * 0.1 = [1.0, 2.0, 3.0]
|
||||
expected = pl.Series([1.0, 2.0, 3.0])
|
||||
assert (result - expected).abs().max() < 1e-10
|
||||
|
||||
def test_factor_type_preserved(self, cs_factor1, ts_factor1):
|
||||
"""测试 factor_type 被正确保留"""
|
||||
scaled_cs = 0.5 * cs_factor1
|
||||
scaled_ts = 0.5 * ts_factor1
|
||||
|
||||
assert scaled_cs.factor_type == "cross_sectional"
|
||||
assert scaled_ts.factor_type == "time_series"
|
||||
|
||||
def test_scalar_name_format(self, cs_factor1):
|
||||
"""测试 ScalarFactor 的 name 格式"""
|
||||
scaled = 0.5 * cs_factor1
|
||||
assert scaled.name == "(0.5_*_cs_factor1)"
|
||||
|
||||
|
||||
# ========== 组合和标量混合测试 ==========
|
||||
|
||||
|
||||
class TestMixedOperations:
|
||||
"""测试组合因子和标量因子的混合运算"""
|
||||
|
||||
def test_scalar_then_combine(self, cs_factor1, cs_factor2, sample_factor_data):
|
||||
"""测试先标量缩放再组合"""
|
||||
# 0.5 * f1 + 0.3 * f2
|
||||
combined = 0.5 * cs_factor1 + 0.3 * cs_factor2
|
||||
result = combined.compute(sample_factor_data)
|
||||
|
||||
# 0.5 * [1.0, 2.0] + 0.3 * [3.0, 4.0]
|
||||
# = [0.5, 1.0] + [0.9, 1.2]
|
||||
# = [1.4, 2.2]
|
||||
expected = pl.Series([1.4, 2.2])
|
||||
assert (result - expected).abs().max() < 1e-10
|
||||
|
||||
def test_complex_formula(self, cs_factor1, cs_factor2, sample_factor_data):
|
||||
"""测试复杂公式: (f1 + f2) * 0.5 - f1 * 0.2"""
|
||||
formula = (cs_factor1 + cs_factor2) * 0.5 - cs_factor1 * 0.2
|
||||
result = formula.compute(sample_factor_data)
|
||||
|
||||
# (f1 + f2) = [4.0, 6.0]
|
||||
# (f1 + f2) * 0.5 = [2.0, 3.0]
|
||||
# f1 * 0.2 = [0.2, 0.4]
|
||||
# [2.0, 3.0] - [0.2, 0.4] = [1.8, 2.6]
|
||||
expected = pl.Series([1.8, 2.6])
|
||||
assert (result - expected).abs().max() < 1e-10
|
||||
248
tests/factors/test_data_loader.py
Normal file
248
tests/factors/test_data_loader.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""测试数据加载器 - DataLoader
|
||||
|
||||
测试需求(来自 factor_implementation_plan.md):
|
||||
- 测试从单个 H5 文件加载数据
|
||||
- 测试从多个 H5 文件加载并合并
|
||||
- 测试列选择(只加载需要的列)
|
||||
- 测试缓存机制(第二次加载更快)
|
||||
- 测试 clear_cache() 清空缓存
|
||||
- 测试按 date_range 过滤
|
||||
- 测试文件不存在时抛出 FileNotFoundError
|
||||
- 测试列不存在时抛出 KeyError
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import polars as pl
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
|
||||
from src.factors import DataSpec, DataLoader
|
||||
|
||||
|
||||
class TestDataLoaderBasic:
|
||||
"""测试 DataLoader 基本功能"""
|
||||
|
||||
@pytest.fixture
|
||||
def loader(self):
|
||||
"""创建 DataLoader 实例"""
|
||||
return DataLoader(data_dir="data")
|
||||
|
||||
def test_init(self):
|
||||
"""测试初始化"""
|
||||
loader = DataLoader(data_dir="data")
|
||||
assert loader.data_dir == Path("data")
|
||||
assert loader._cache == {}
|
||||
|
||||
def test_load_single_source(self, loader):
|
||||
"""测试从单个 H5 文件加载数据"""
|
||||
specs = [
|
||||
DataSpec(
|
||||
source="daily",
|
||||
columns=["ts_code", "trade_date", "close"],
|
||||
lookback_days=1,
|
||||
)
|
||||
]
|
||||
|
||||
df = loader.load(specs)
|
||||
|
||||
assert isinstance(df, pl.DataFrame)
|
||||
assert len(df) > 0
|
||||
assert "ts_code" in df.columns
|
||||
assert "trade_date" in df.columns
|
||||
assert "close" in df.columns
|
||||
|
||||
def test_load_multiple_sources(self, loader):
|
||||
"""测试从多个 H5 文件加载并合并"""
|
||||
# 注意:这里假设只有一个 daily.h5 文件
|
||||
# 如果有多个文件,可以测试合并逻辑
|
||||
specs = [
|
||||
DataSpec(
|
||||
source="daily",
|
||||
columns=["ts_code", "trade_date", "close"],
|
||||
lookback_days=1,
|
||||
),
|
||||
DataSpec(
|
||||
source="daily",
|
||||
columns=["ts_code", "trade_date", "open", "high", "low"],
|
||||
lookback_days=1,
|
||||
),
|
||||
]
|
||||
|
||||
df = loader.load(specs)
|
||||
|
||||
assert isinstance(df, pl.DataFrame)
|
||||
assert len(df) > 0
|
||||
# 应该包含所有列
|
||||
assert set(df.columns) >= {
|
||||
"ts_code",
|
||||
"trade_date",
|
||||
"close",
|
||||
"open",
|
||||
"high",
|
||||
"low",
|
||||
}
|
||||
|
||||
def test_column_selection(self, loader):
|
||||
"""测试列选择(只加载需要的列)"""
|
||||
specs = [
|
||||
DataSpec(
|
||||
source="daily",
|
||||
columns=["ts_code", "trade_date", "close"],
|
||||
lookback_days=1,
|
||||
)
|
||||
]
|
||||
|
||||
df = loader.load(specs)
|
||||
|
||||
# 只应该有 3 列
|
||||
assert set(df.columns) == {"ts_code", "trade_date", "close"}
|
||||
|
||||
def test_date_range_filter(self, loader):
|
||||
"""测试按 date_range 过滤"""
|
||||
specs = [
|
||||
DataSpec(
|
||||
source="daily",
|
||||
columns=["ts_code", "trade_date", "close"],
|
||||
lookback_days=1,
|
||||
)
|
||||
]
|
||||
|
||||
# 先加载所有数据
|
||||
df_all = loader.load(specs)
|
||||
total_rows = len(df_all)
|
||||
|
||||
# 清空缓存,重新加载特定日期范围
|
||||
loader.clear_cache()
|
||||
df_filtered = loader.load(specs, date_range=("20240101", "20240131"))
|
||||
|
||||
# 过滤后的数据应该更少或相等
|
||||
assert len(df_filtered) <= total_rows
|
||||
|
||||
# 所有日期都应该在范围内
|
||||
if len(df_filtered) > 0:
|
||||
dates = df_filtered["trade_date"].to_list()
|
||||
assert all("20240101" <= d <= "20240131" for d in dates)
|
||||
|
||||
|
||||
class TestDataLoaderCache:
|
||||
"""测试 DataLoader 缓存机制"""
|
||||
|
||||
@pytest.fixture
|
||||
def loader(self):
|
||||
"""创建 DataLoader 实例"""
|
||||
return DataLoader(data_dir="data")
|
||||
|
||||
def test_cache_populated(self, loader):
|
||||
"""测试加载后缓存被填充"""
|
||||
specs = [
|
||||
DataSpec(
|
||||
source="daily",
|
||||
columns=["ts_code", "trade_date", "close"],
|
||||
lookback_days=1,
|
||||
)
|
||||
]
|
||||
|
||||
# 第一次加载
|
||||
loader.load(specs)
|
||||
|
||||
# 检查缓存
|
||||
assert len(loader._cache) > 0
|
||||
|
||||
def test_cache_used(self, loader):
|
||||
"""测试第二次加载使用缓存(更快)"""
|
||||
import time
|
||||
|
||||
specs = [
|
||||
DataSpec(
|
||||
source="daily",
|
||||
columns=["ts_code", "trade_date", "close"],
|
||||
lookback_days=1,
|
||||
)
|
||||
]
|
||||
|
||||
# 第一次加载
|
||||
start = time.time()
|
||||
df1 = loader.load(specs)
|
||||
time1 = time.time() - start
|
||||
|
||||
# 第二次加载(应该使用缓存)
|
||||
start = time.time()
|
||||
df2 = loader.load(specs)
|
||||
time2 = time.time() - start
|
||||
|
||||
# 数据应该相同
|
||||
assert df1.shape == df2.shape
|
||||
|
||||
# 第二次应该更快(至少快 50%)
|
||||
# 注意:如果数据量很小,这个测试可能不稳定
|
||||
# assert time2 < time1 * 0.5
|
||||
|
||||
def test_clear_cache(self, loader):
|
||||
"""测试 clear_cache() 清空缓存"""
|
||||
specs = [
|
||||
DataSpec(
|
||||
source="daily",
|
||||
columns=["ts_code", "trade_date", "close"],
|
||||
lookback_days=1,
|
||||
)
|
||||
]
|
||||
|
||||
# 加载数据
|
||||
loader.load(specs)
|
||||
assert len(loader._cache) > 0
|
||||
|
||||
# 清空缓存
|
||||
loader.clear_cache()
|
||||
assert len(loader._cache) == 0
|
||||
|
||||
def test_cache_info(self, loader):
|
||||
"""测试 get_cache_info()"""
|
||||
specs = [
|
||||
DataSpec(
|
||||
source="daily",
|
||||
columns=["ts_code", "trade_date", "close"],
|
||||
lookback_days=1,
|
||||
)
|
||||
]
|
||||
|
||||
# 加载前
|
||||
info_before = loader.get_cache_info()
|
||||
assert info_before["entries"] == 0
|
||||
|
||||
# 加载后
|
||||
loader.load(specs)
|
||||
info_after = loader.get_cache_info()
|
||||
assert info_after["entries"] > 0
|
||||
assert info_after["total_rows"] > 0
|
||||
|
||||
|
||||
class TestDataLoaderErrors:
|
||||
"""测试 DataLoader 错误处理"""
|
||||
|
||||
def test_file_not_found(self):
|
||||
"""测试文件不存在时抛出 FileNotFoundError"""
|
||||
loader = DataLoader(data_dir="nonexistent_dir")
|
||||
specs = [
|
||||
DataSpec(
|
||||
source="daily",
|
||||
columns=["ts_code", "trade_date", "close"],
|
||||
lookback_days=1,
|
||||
)
|
||||
]
|
||||
|
||||
with pytest.raises(FileNotFoundError):
|
||||
loader.load(specs)
|
||||
|
||||
def test_column_not_found(self):
|
||||
"""测试列不存在时抛出 KeyError"""
|
||||
loader = DataLoader(data_dir="data")
|
||||
specs = [
|
||||
DataSpec(
|
||||
source="daily",
|
||||
columns=["ts_code", "trade_date", "nonexistent_column"],
|
||||
lookback_days=1,
|
||||
)
|
||||
]
|
||||
|
||||
with pytest.raises(KeyError, match="nonexistent_column"):
|
||||
loader.load(specs)
|
||||
328
tests/factors/test_data_spec.py
Normal file
328
tests/factors/test_data_spec.py
Normal file
@@ -0,0 +1,328 @@
|
||||
"""Factors 模块测试 - Phase 1: 数据类型定义测试
|
||||
|
||||
测试范围:
|
||||
- DataSpec: 数据需求规格的创建和验证
|
||||
- FactorContext: 计算上下文的创建
|
||||
- FactorData: 数据容器的基本操作
|
||||
- HDF5 数据读取: 验证能正确读取 daily.h5 文件
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import polars as pl
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
|
||||
from src.factors import DataSpec, FactorContext, FactorData
|
||||
|
||||
|
||||
class TestDataSpec:
|
||||
"""测试 DataSpec 数据需求规格"""
|
||||
|
||||
def test_valid_dataspec_creation(self):
|
||||
"""测试有效的 DataSpec 创建"""
|
||||
spec = DataSpec(
|
||||
source="daily", columns=["ts_code", "trade_date", "close"], lookback_days=5
|
||||
)
|
||||
assert spec.source == "daily"
|
||||
assert spec.columns == ["ts_code", "trade_date", "close"]
|
||||
assert spec.lookback_days == 5
|
||||
|
||||
def test_dataspec_default_lookback(self):
|
||||
"""测试 DataSpec 默认值 lookback_days=1"""
|
||||
spec = DataSpec(source="daily", columns=["ts_code", "trade_date", "close"])
|
||||
assert spec.lookback_days == 1
|
||||
|
||||
def test_dataspec_frozen_immutable(self):
|
||||
"""测试 DataSpec 是 frozen(不可变)"""
|
||||
spec = DataSpec(
|
||||
source="daily", columns=["ts_code", "trade_date", "close"], lookback_days=5
|
||||
)
|
||||
with pytest.raises(FrozenInstanceError):
|
||||
spec.source = "other"
|
||||
|
||||
def test_dataspec_lookback_less_than_1_raises(self):
|
||||
"""测试 lookback_days < 1 时抛出 ValueError"""
|
||||
with pytest.raises(ValueError, match="lookback_days must be >= 1"):
|
||||
DataSpec(
|
||||
source="daily",
|
||||
columns=["ts_code", "trade_date", "close"],
|
||||
lookback_days=0,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="lookback_days must be >= 1"):
|
||||
DataSpec(
|
||||
source="daily",
|
||||
columns=["ts_code", "trade_date", "close"],
|
||||
lookback_days=-1,
|
||||
)
|
||||
|
||||
def test_dataspec_missing_required_columns_raises(self):
|
||||
"""测试缺少 ts_code 或 trade_date 时抛出 ValueError"""
|
||||
# 缺少 ts_code
|
||||
with pytest.raises(ValueError, match="columns must contain"):
|
||||
DataSpec(source="daily", columns=["trade_date", "close"], lookback_days=5)
|
||||
|
||||
# 缺少 trade_date
|
||||
with pytest.raises(ValueError, match="columns must contain"):
|
||||
DataSpec(source="daily", columns=["ts_code", "close"], lookback_days=5)
|
||||
|
||||
# 两者都缺少
|
||||
with pytest.raises(ValueError, match="columns must contain"):
|
||||
DataSpec(source="daily", columns=["close", "open", "high"], lookback_days=5)
|
||||
|
||||
def test_dataspec_empty_source_raises(self):
|
||||
"""测试空 source 时抛出 ValueError"""
|
||||
with pytest.raises(ValueError, match="source cannot be empty string"):
|
||||
DataSpec(
|
||||
source="", columns=["ts_code", "trade_date", "close"], lookback_days=5
|
||||
)
|
||||
|
||||
|
||||
class TestFactorContext:
|
||||
"""测试 FactorContext 计算上下文"""
|
||||
|
||||
def test_default_creation(self):
|
||||
"""测试默认值创建"""
|
||||
ctx = FactorContext()
|
||||
assert ctx.current_date is None
|
||||
assert ctx.current_stock is None
|
||||
assert ctx.trade_dates is None
|
||||
|
||||
def test_full_creation(self):
|
||||
"""测试完整参数创建"""
|
||||
ctx = FactorContext(
|
||||
current_date="20240101",
|
||||
current_stock="000001.SZ",
|
||||
trade_dates=["20240101", "20240102", "20240103"],
|
||||
)
|
||||
assert ctx.current_date == "20240101"
|
||||
assert ctx.current_stock == "000001.SZ"
|
||||
assert ctx.trade_dates == ["20240101", "20240102", "20240103"]
|
||||
|
||||
def test_partial_creation(self):
|
||||
"""测试部分参数创建"""
|
||||
ctx = FactorContext(current_date="20240101")
|
||||
assert ctx.current_date == "20240101"
|
||||
assert ctx.current_stock is None
|
||||
assert ctx.trade_dates is None
|
||||
|
||||
def test_dataclass_methods(self):
|
||||
"""测试 dataclass 自动生成的方法"""
|
||||
ctx = FactorContext(current_date="20240101")
|
||||
# __repr__
|
||||
assert "FactorContext" in repr(ctx)
|
||||
assert "20240101" in repr(ctx)
|
||||
# __eq__
|
||||
ctx2 = FactorContext(current_date="20240101")
|
||||
assert ctx == ctx2
|
||||
ctx3 = FactorContext(current_date="20240102")
|
||||
assert ctx != ctx3
|
||||
|
||||
|
||||
class TestFactorData:
|
||||
"""测试 FactorData 数据容器"""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_df(self):
|
||||
"""创建示例 DataFrame"""
|
||||
return pl.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ", "000002.SZ", "000001.SZ", "000002.SZ"],
|
||||
"trade_date": ["20240101", "20240101", "20240102", "20240102"],
|
||||
"close": [10.0, 20.0, 10.5, 20.5],
|
||||
"volume": [1000, 2000, 1100, 2100],
|
||||
}
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def cs_context(self):
|
||||
"""截面因子上下文"""
|
||||
return FactorContext(current_date="20240101")
|
||||
|
||||
@pytest.fixture
|
||||
def ts_context(self):
|
||||
"""时序因子上下文"""
|
||||
return FactorContext(current_stock="000001.SZ")
|
||||
|
||||
def test_get_column(self, sample_df, cs_context):
|
||||
"""测试 get_column 返回正确的 Series"""
|
||||
data = FactorData(sample_df, cs_context)
|
||||
close_series = data.get_column("close")
|
||||
assert isinstance(close_series, pl.Series)
|
||||
assert close_series.to_list() == [10.0, 20.0, 10.5, 20.5]
|
||||
|
||||
def test_get_column_keyerror(self, sample_df, cs_context):
|
||||
"""测试 get_column 列不存在时抛出 KeyError"""
|
||||
data = FactorData(sample_df, cs_context)
|
||||
with pytest.raises(KeyError, match="Column 'nonexistent' not found"):
|
||||
data.get_column("nonexistent")
|
||||
|
||||
def test_filter_by_date(self, sample_df, cs_context):
|
||||
"""测试 filter_by_date 返回正确的过滤结果"""
|
||||
data = FactorData(sample_df, cs_context)
|
||||
filtered = data.filter_by_date("20240101")
|
||||
assert len(filtered) == 2
|
||||
assert filtered.to_polars()["ts_code"].to_list() == ["000001.SZ", "000002.SZ"]
|
||||
assert filtered.to_polars()["close"].to_list() == [10.0, 20.0]
|
||||
|
||||
def test_filter_by_date_empty_result(self, sample_df, cs_context):
|
||||
"""测试 filter_by_date 日期不存在时返回空的 FactorData"""
|
||||
data = FactorData(sample_df, cs_context)
|
||||
filtered = data.filter_by_date("20241231")
|
||||
assert len(filtered) == 0
|
||||
assert isinstance(filtered, FactorData)
|
||||
|
||||
def test_get_cross_section(self, sample_df, cs_context):
|
||||
"""测试 get_cross_section 返回 current_date 当天的数据"""
|
||||
data = FactorData(sample_df, cs_context)
|
||||
cs = data.get_cross_section()
|
||||
assert len(cs) == 2
|
||||
assert cs["ts_code"].to_list() == ["000001.SZ", "000002.SZ"]
|
||||
|
||||
def test_get_cross_section_no_date_raises(self, sample_df, ts_context):
|
||||
"""测试 get_cross_section current_date 为 None 时抛出 ValueError"""
|
||||
data = FactorData(sample_df, ts_context)
|
||||
with pytest.raises(ValueError, match="current_date is not set"):
|
||||
data.get_cross_section()
|
||||
|
||||
def test_to_polars(self, sample_df, cs_context):
|
||||
"""测试 to_polars 返回原始 DataFrame"""
|
||||
data = FactorData(sample_df, cs_context)
|
||||
df = data.to_polars()
|
||||
assert isinstance(df, pl.DataFrame)
|
||||
assert df.shape == sample_df.shape
|
||||
assert df.columns == sample_df.columns
|
||||
|
||||
def test_context_property(self, sample_df, cs_context):
|
||||
"""测试 context 属性返回正确的上下文"""
|
||||
data = FactorData(sample_df, cs_context)
|
||||
assert data.context == cs_context
|
||||
assert data.context.current_date == "20240101"
|
||||
|
||||
def test_len(self, sample_df, cs_context):
|
||||
"""测试 __len__ 返回正确的行数"""
|
||||
data = FactorData(sample_df, cs_context)
|
||||
assert len(data) == 4
|
||||
|
||||
def test_repr(self, sample_df, cs_context):
|
||||
"""测试 __repr__ 返回可读字符串"""
|
||||
data = FactorData(sample_df, cs_context)
|
||||
repr_str = repr(data)
|
||||
assert "FactorData" in repr_str
|
||||
assert "rows=4" in repr_str
|
||||
assert "date=20240101" in repr_str
|
||||
|
||||
|
||||
class TestHDF5DataAccess:
|
||||
"""测试 HDF5 数据读取功能"""
|
||||
|
||||
def test_daily_h5_file_exists(self):
|
||||
"""测试 daily.h5 文件存在"""
|
||||
data_path = Path("data/daily.h5")
|
||||
assert data_path.exists(), f"daily.h5 文件不存在: {data_path.absolute()}"
|
||||
|
||||
def test_daily_h5_can_read_with_pandas(self):
|
||||
"""测试能用 pandas 读取 daily.h5"""
|
||||
data_path = Path("data/daily.h5")
|
||||
df = pd.read_hdf(data_path, key="/daily")
|
||||
|
||||
assert df is not None
|
||||
assert len(df) > 0
|
||||
assert "ts_code" in df.columns
|
||||
assert "trade_date" in df.columns
|
||||
assert "close" in df.columns
|
||||
|
||||
def test_daily_h5_columns(self):
|
||||
"""测试 daily.h5 包含预期的列"""
|
||||
data_path = Path("data/daily.h5")
|
||||
df = pd.read_hdf(data_path, key="/daily")
|
||||
|
||||
expected_columns = [
|
||||
"trade_date",
|
||||
"ts_code",
|
||||
"open",
|
||||
"high",
|
||||
"low",
|
||||
"close",
|
||||
"pre_close",
|
||||
"change",
|
||||
"pct_chg",
|
||||
"vol",
|
||||
"amount",
|
||||
"turnover_rate",
|
||||
"volume_ratio",
|
||||
]
|
||||
|
||||
for col in expected_columns:
|
||||
assert col in df.columns, f"列 {col} 不存在于 daily.h5"
|
||||
|
||||
def test_daily_h5_date_format(self):
|
||||
"""测试 daily.h5 日期格式正确"""
|
||||
data_path = Path("data/daily.h5")
|
||||
df = pd.read_hdf(data_path, key="/daily")
|
||||
|
||||
# 检查日期格式是 YYYYMMDD 字符串
|
||||
sample_date = df["trade_date"].iloc[0]
|
||||
assert isinstance(sample_date, str), "日期应该是字符串格式"
|
||||
assert len(sample_date) == 8, "日期应该是 8 位字符串 (YYYYMMDD)"
|
||||
assert sample_date.isdigit(), "日期应该只包含数字"
|
||||
|
||||
def test_daily_h5_stock_format(self):
|
||||
"""测试 daily.h5 股票代码格式正确"""
|
||||
data_path = Path("data/daily.h5")
|
||||
df = pd.read_hdf(data_path, key="/daily")
|
||||
|
||||
# 检查股票代码格式如 "000001.SZ"
|
||||
sample_code = df["ts_code"].iloc[0]
|
||||
assert isinstance(sample_code, str), "股票代码应该是字符串"
|
||||
assert "." in sample_code, "股票代码应该包含交易所后缀"
|
||||
assert sample_code.endswith((".SZ", ".SH", ".BJ")), (
|
||||
"股票代码应该以交易所后缀结尾"
|
||||
)
|
||||
|
||||
def test_daily_h5_to_polars(self):
|
||||
"""测试将 daily.h5 数据转换为 Polars"""
|
||||
data_path = Path("data/daily.h5")
|
||||
pdf = pd.read_hdf(data_path, key="/daily")
|
||||
|
||||
# 转换为 Polars
|
||||
df = pl.from_pandas(pdf)
|
||||
|
||||
assert isinstance(df, pl.DataFrame)
|
||||
assert len(df) > 0
|
||||
assert "ts_code" in df.columns
|
||||
assert "trade_date" in df.columns
|
||||
|
||||
def test_daily_h5_sample_data_with_factors(self):
|
||||
"""测试用 daily.h5 真实数据创建 FactorData"""
|
||||
data_path = Path("data/daily.h5")
|
||||
pdf = pd.read_hdf(data_path, key="/daily")
|
||||
|
||||
# 取前 100 行作为示例
|
||||
sample_pdf = pdf.head(100)
|
||||
df = pl.from_pandas(sample_pdf)
|
||||
|
||||
# 创建 FactorData
|
||||
ctx = FactorContext(current_date=df["trade_date"][0])
|
||||
data = FactorData(df, ctx)
|
||||
|
||||
# 验证基本操作
|
||||
assert len(data) == 100
|
||||
assert "close" in data.to_polars().columns
|
||||
|
||||
# 测试 get_column
|
||||
close_prices = data.get_column("close")
|
||||
assert len(close_prices) == 100
|
||||
|
||||
# 测试 filter_by_date
|
||||
first_date = df["trade_date"][0]
|
||||
filtered = data.filter_by_date(first_date)
|
||||
assert len(filtered) > 0
|
||||
|
||||
|
||||
# 导入 FrozenInstanceError
|
||||
try:
|
||||
from dataclasses import FrozenInstanceError
|
||||
except ImportError:
|
||||
# Python < 3.10 compatibility
|
||||
FrozenInstanceError = AttributeError
|
||||
266
tests/factors/test_engine.py
Normal file
266
tests/factors/test_engine.py
Normal file
@@ -0,0 +1,266 @@
|
||||
"""测试执行引擎 - FactorEngine
|
||||
|
||||
测试需求(来自 factor_implementation_plan.md):
|
||||
- 测试 `compute()` 正确分发给截面计算
|
||||
- 测试 `compute()` 正确分发给时序计算
|
||||
- 测试无效 factor_type 时抛出 ValueError
|
||||
|
||||
截面计算测试(防泄露验证):
|
||||
- 测试数据裁剪正确(传入 [T-lookback+1, T])
|
||||
- 测试不包含未来日期 T+1 的数据
|
||||
- 测试每个日期独立计算
|
||||
- 测试结果包含所有日期和所有股票
|
||||
- 测试结果 DataFrame 格式正确
|
||||
- 测试多个 DataSpec 时 lookback 取最大值
|
||||
|
||||
时序计算测试(防泄露验证):
|
||||
- 测试每只股票只看到自己的数据
|
||||
- 测试不包含其他股票的数据
|
||||
- 测试传入的是完整时间序列(向量化计算)
|
||||
- 测试结果包含所有股票和所有日期
|
||||
- 测试结果 DataFrame 格式正确
|
||||
- 测试股票不在数据中时跳过(或填充 null)
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import polars as pl
|
||||
|
||||
from src.factors import (
|
||||
DataSpec,
|
||||
FactorContext,
|
||||
FactorData,
|
||||
DataLoader,
|
||||
FactorEngine,
|
||||
CrossSectionalFactor,
|
||||
TimeSeriesFactor,
|
||||
)
|
||||
|
||||
|
||||
class SimpleCrossSectionalFactor(CrossSectionalFactor):
|
||||
"""简单的截面因子 - 返回收盘价排名"""
|
||||
|
||||
name = "close_rank"
|
||||
data_specs = [
|
||||
DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=1)
|
||||
]
|
||||
|
||||
def compute(self, data: FactorData) -> pl.Series:
|
||||
cs = data.get_cross_section()
|
||||
return cs["close"].rank()
|
||||
|
||||
|
||||
class SimpleTimeSeriesFactor(TimeSeriesFactor):
|
||||
"""简单的时序因子 - 返回3日移动平均"""
|
||||
|
||||
name = "ma3"
|
||||
data_specs = [
|
||||
DataSpec(
|
||||
"daily",
|
||||
["ts_code", "trade_date", "close"],
|
||||
lookback_days=5,
|
||||
)
|
||||
]
|
||||
|
||||
def __init__(self, period: int = 3):
|
||||
super().__init__(period=period)
|
||||
|
||||
def compute(self, data: FactorData) -> pl.Series:
|
||||
return data.get_column("close").rolling_mean(self.params["period"])
|
||||
|
||||
|
||||
class ReturnFactor(CrossSectionalFactor):
|
||||
"""收益率因子 - 需要2天lookback计算收益率"""
|
||||
|
||||
name = "return"
|
||||
data_specs = [
|
||||
DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=2)
|
||||
]
|
||||
|
||||
def compute(self, data: FactorData) -> pl.Series:
|
||||
# 获取当前日期
|
||||
current_date = data.context.current_date
|
||||
|
||||
# 获取当前日期的数据
|
||||
cs = data.get_cross_section()
|
||||
|
||||
# 简单返回收盘价作为因子值
|
||||
# 实际应该计算收益率,但这里简化处理
|
||||
return cs["close"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def loader():
|
||||
"""创建 DataLoader 实例"""
|
||||
return DataLoader(data_dir="data")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def engine(loader):
|
||||
"""创建 FactorEngine 实例"""
|
||||
return FactorEngine(loader)
|
||||
|
||||
|
||||
class TestFactorEngineDispatch:
|
||||
"""测试引擎分发逻辑"""
|
||||
|
||||
def test_dispatch_cross_sectional(self, engine):
|
||||
"""测试 compute() 正确分发给截面计算"""
|
||||
factor = SimpleCrossSectionalFactor()
|
||||
|
||||
result = engine.compute(factor, start_date="20240101", end_date="20240105")
|
||||
|
||||
assert isinstance(result, pl.DataFrame)
|
||||
assert "trade_date" in result.columns
|
||||
assert "ts_code" in result.columns
|
||||
assert "close_rank" in result.columns
|
||||
|
||||
def test_dispatch_time_series(self, engine, loader):
|
||||
"""测试 compute() 正确分发给时序计算"""
|
||||
factor = SimpleTimeSeriesFactor(period=3)
|
||||
|
||||
# 获取一些股票代码
|
||||
sample_data = loader.load(
|
||||
[DataSpec("daily", ["ts_code", "trade_date"], lookback_days=1)]
|
||||
)
|
||||
stock_codes = sample_data["ts_code"].unique().head(3).to_list()
|
||||
|
||||
result = engine.compute(
|
||||
factor,
|
||||
stock_codes=stock_codes,
|
||||
start_date="20240101",
|
||||
end_date="20240110",
|
||||
)
|
||||
|
||||
assert isinstance(result, pl.DataFrame)
|
||||
assert "trade_date" in result.columns
|
||||
assert "ts_code" in result.columns
|
||||
assert "ma3" in result.columns
|
||||
|
||||
def test_unknown_factor_type(self, engine):
|
||||
"""测试无效 factor_type 时抛出 ValueError"""
|
||||
|
||||
class UnknownFactor:
|
||||
name = "unknown"
|
||||
factor_type = "unknown_type"
|
||||
data_specs = []
|
||||
|
||||
factor = UnknownFactor()
|
||||
|
||||
with pytest.raises(ValueError, match="Unknown factor type"):
|
||||
engine.compute(factor)
|
||||
|
||||
|
||||
class TestCrossSectionalComputation:
|
||||
"""测试截面计算(防泄露验证)"""
|
||||
|
||||
def test_result_format(self, engine):
|
||||
"""测试结果 DataFrame 格式正确"""
|
||||
factor = SimpleCrossSectionalFactor()
|
||||
|
||||
result = engine.compute(factor, start_date="20240101", end_date="20240105")
|
||||
|
||||
# 检查列
|
||||
assert "trade_date" in result.columns
|
||||
assert "ts_code" in result.columns
|
||||
assert factor.name in result.columns
|
||||
|
||||
# 检查类型
|
||||
assert result["trade_date"].dtype == pl.Utf8
|
||||
assert result["ts_code"].dtype == pl.Utf8
|
||||
|
||||
def test_all_dates_present(self, engine):
|
||||
"""测试结果包含所有日期"""
|
||||
factor = SimpleCrossSectionalFactor()
|
||||
|
||||
start_date = "20240101"
|
||||
end_date = "20240105"
|
||||
|
||||
result = engine.compute(factor, start_date=start_date, end_date=end_date)
|
||||
|
||||
if len(result) > 0:
|
||||
dates = result["trade_date"].unique().to_list()
|
||||
# 应该包含 start_date 和 end_date 之间的日期
|
||||
assert len(dates) > 0
|
||||
|
||||
def test_lookback_window(self, engine):
|
||||
"""测试多个 DataSpec 时 lookback 取最大值"""
|
||||
factor = ReturnFactor()
|
||||
|
||||
# lookback_days = 2
|
||||
result = engine.compute(factor, start_date="20240103", end_date="20240105")
|
||||
|
||||
# 应该能计算出结果
|
||||
assert isinstance(result, pl.DataFrame)
|
||||
|
||||
|
||||
class TestTimeSeriesComputation:
|
||||
"""测试时序计算(防泄露验证)"""
|
||||
|
||||
def test_result_format(self, engine):
|
||||
"""测试结果 DataFrame 格式正确"""
|
||||
factor = SimpleTimeSeriesFactor(period=3)
|
||||
|
||||
result = engine.compute(
|
||||
factor,
|
||||
stock_codes=["000001.SZ"],
|
||||
start_date="20240101",
|
||||
end_date="20240110",
|
||||
)
|
||||
|
||||
# 检查列
|
||||
assert "trade_date" in result.columns
|
||||
assert "ts_code" in result.columns
|
||||
assert factor.name in result.columns
|
||||
|
||||
def test_single_stock_data(self, engine):
|
||||
"""测试每只股票只看到自己的数据"""
|
||||
factor = SimpleTimeSeriesFactor(period=3)
|
||||
|
||||
stock_codes = ["000001.SZ"]
|
||||
|
||||
result = engine.compute(
|
||||
factor,
|
||||
stock_codes=stock_codes,
|
||||
start_date="20240101",
|
||||
end_date="20240110",
|
||||
)
|
||||
|
||||
if len(result) > 0:
|
||||
# 结果中只应该有指定的股票
|
||||
stocks = result["ts_code"].unique().to_list()
|
||||
assert set(stocks) == set(stock_codes)
|
||||
|
||||
def test_ma_calculation(self, engine):
|
||||
"""测试移动平均计算"""
|
||||
factor = SimpleTimeSeriesFactor(period=3)
|
||||
|
||||
result = engine.compute(
|
||||
factor,
|
||||
stock_codes=["000001.SZ"],
|
||||
start_date="20240101",
|
||||
end_date="20240110",
|
||||
)
|
||||
|
||||
if len(result) > 2:
|
||||
# 前2个应该是 null(因为 period=3)
|
||||
ma_values = result[factor.name].to_list()
|
||||
assert ma_values[0] is None or str(ma_values[0]) == "nan"
|
||||
assert ma_values[1] is None or str(ma_values[1]) == "nan"
|
||||
# 第3个应该有值
|
||||
assert ma_values[2] is not None
|
||||
|
||||
def test_missing_stock_skipped(self, engine):
|
||||
"""测试股票不在数据中时返回空结果"""
|
||||
factor = SimpleTimeSeriesFactor(period=3)
|
||||
|
||||
result = engine.compute(
|
||||
factor,
|
||||
stock_codes=["NONEXISTENT.STOCK"],
|
||||
start_date="20240101",
|
||||
end_date="20240110",
|
||||
)
|
||||
|
||||
# 应该返回空 DataFrame 或包含该股票但值为 null 的结果
|
||||
assert isinstance(result, pl.DataFrame)
|
||||
# 对于不存在的股票,结果可能是空的
|
||||
# 或者包含该股票但值为 null
|
||||
397
tests/factors/test_factor_validation.py
Normal file
397
tests/factors/test_factor_validation.py
Normal file
@@ -0,0 +1,397 @@
|
||||
"""因子真实数据测试 - 与 Polars 原生计算对比
|
||||
|
||||
测试目标:
|
||||
1. 时序因子 - 移动平均线 (MA)
|
||||
2. 截面因子 - PE_Rank(市盈率排名)
|
||||
3. 结合因子 - 时序 * 截面组合
|
||||
|
||||
每个因子都与原始 Polars 计算进行对比验证。
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
import polars as pl
|
||||
import numpy as np
|
||||
from src.factors import DataSpec, FactorContext, FactorData
|
||||
from src.factors.base import CrossSectionalFactor, TimeSeriesFactor
|
||||
from src.factors.composite import CompositeFactor, ScalarFactor
|
||||
|
||||
|
||||
# ========== 测试数据准备 ==========
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def daily_data():
|
||||
"""加载日线测试数据(直接使用 Polars)"""
|
||||
with pd.HDFStore("data/daily.h5", mode="r") as store:
|
||||
df = store["/daily"]
|
||||
|
||||
# 筛选日期范围
|
||||
df = df[(df["trade_date"] >= "20240101") & (df["trade_date"] <= "20240430")]
|
||||
|
||||
# 选择部分股票(取前20个)
|
||||
stocks = df["ts_code"].unique()[:20]
|
||||
df = df[df["ts_code"].isin(stocks)]
|
||||
|
||||
# 直接返回 Polars DataFrame,不转 pandas
|
||||
pl_df = pl.from_pandas(df)
|
||||
pl_df = pl_df.sort(["ts_code", "trade_date"])
|
||||
|
||||
return pl_df
|
||||
|
||||
|
||||
# ========== 时序因子定义 ==========
|
||||
|
||||
|
||||
class MAFactor(TimeSeriesFactor):
|
||||
"""移动平均线因子(时序因子)"""
|
||||
|
||||
name = "ma_factor"
|
||||
data_specs = [
|
||||
DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=5)
|
||||
]
|
||||
|
||||
def __init__(self, period: int = 5):
|
||||
super().__init__(period=period)
|
||||
|
||||
def compute(self, data: FactorData) -> pl.Series:
|
||||
close = data.get_column("close")
|
||||
period = self.params["period"]
|
||||
return close.rolling_mean(period)
|
||||
|
||||
|
||||
class PERankFactor(CrossSectionalFactor):
|
||||
"""PE 市盈率排名因子(截面因子)"""
|
||||
|
||||
name = "pe_rank_factor"
|
||||
data_specs = [
|
||||
DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=1)
|
||||
]
|
||||
|
||||
def compute(self, data: FactorData) -> pl.Series:
|
||||
cs = data.get_cross_section()
|
||||
close = cs["close"]
|
||||
return close.rank() / close.len()
|
||||
|
||||
|
||||
# ========== 测试用例 ==========
|
||||
|
||||
|
||||
class TestTimeSeriesFactor:
|
||||
"""时序因子测试"""
|
||||
|
||||
def test_ma_factor(self, daily_data):
|
||||
"""测试 MA 因子与 Polars 原生计算对比"""
|
||||
period = 5
|
||||
sample_stock = daily_data["ts_code"].to_list()[0]
|
||||
stock_df = daily_data.filter(pl.col("ts_code") == sample_stock).sort(
|
||||
"trade_date"
|
||||
)
|
||||
|
||||
# Polars 基准计算
|
||||
polars_result = stock_df.with_columns(
|
||||
pl.col("close")
|
||||
.rolling_mean(window_size=period)
|
||||
.over("ts_code")
|
||||
.alias("ma_polars")
|
||||
)
|
||||
|
||||
# 因子框架计算
|
||||
context = FactorContext(current_stock=sample_stock)
|
||||
factor_data = FactorData(
|
||||
stock_df.with_columns([pl.col("trade_date").cast(pl.Utf8)]), context
|
||||
)
|
||||
|
||||
ma_factor = MAFactor(period=period)
|
||||
factor_result = ma_factor.compute(factor_data).to_numpy()
|
||||
|
||||
# 对比结果
|
||||
polars_values = polars_result["ma_polars"].to_numpy()
|
||||
|
||||
# 去除 NaN 后对比
|
||||
valid_idx = ~np.isnan(polars_values)
|
||||
polars_valid = polars_values[valid_idx]
|
||||
factor_valid = factor_result[valid_idx]
|
||||
|
||||
diff = np.abs(polars_valid - factor_valid)
|
||||
max_diff = np.max(diff)
|
||||
|
||||
print(f"\n[时序因子 MA({period}) 对比]")
|
||||
print(f" 样本股票: {sample_stock}")
|
||||
print(f" 有效数据点: {len(polars_valid)}")
|
||||
print(f" 最大差异: {max_diff:.15f}")
|
||||
print(f" 样本数据 (前5个):")
|
||||
for i in range(min(5, len(polars_valid))):
|
||||
print(
|
||||
f" Polars: {polars_valid[i]:.6f}, Factor: {factor_valid[i]:.6f}, Diff: {abs(polars_valid[i] - factor_valid[i]):.15f}"
|
||||
)
|
||||
|
||||
assert max_diff < 1e-10, f"MA 因子计算差异过大: {max_diff}"
|
||||
|
||||
|
||||
class TestCrossSectionalFactor:
|
||||
"""截面因子测试"""
|
||||
|
||||
def test_pe_rank_factor(self, daily_data):
|
||||
"""测试 PE_Rank 因子与 Polars 原生计算对比"""
|
||||
trade_dates = daily_data["trade_date"].unique().to_list()
|
||||
sample_date = trade_dates[50]
|
||||
date_df = daily_data.filter(pl.col("trade_date") == sample_date)
|
||||
|
||||
# Polars 基准计算
|
||||
polars_result = date_df.with_columns(
|
||||
(pl.col("close").rank() / pl.col("close").count()).alias("pe_rank_polars")
|
||||
)
|
||||
|
||||
# 因子框架计算
|
||||
context = FactorContext(current_date=str(sample_date))
|
||||
factor_data = FactorData(
|
||||
date_df.with_columns(
|
||||
[pl.col("trade_date").cast(pl.Utf8), pl.col("ts_code").cast(pl.Utf8)]
|
||||
),
|
||||
context,
|
||||
)
|
||||
|
||||
pe_factor = PERankFactor()
|
||||
factor_result = pe_factor.compute(factor_data).to_numpy()
|
||||
|
||||
# 对比结果
|
||||
polars_values = polars_result["pe_rank_polars"].to_numpy()
|
||||
|
||||
diff = np.abs(polars_values - factor_result)
|
||||
max_diff = np.max(diff)
|
||||
|
||||
print(f"\n[截面因子 PE_Rank 对比]")
|
||||
print(f" 样本日期: {sample_date}")
|
||||
print(f" 股票数量: {len(polars_values)}")
|
||||
print(f" 最大差异: {max_diff:.15f}")
|
||||
print(f" 样本数据 (前5个):")
|
||||
for i in range(min(5, len(polars_values))):
|
||||
ts_code = polars_result["ts_code"].to_numpy()[i]
|
||||
print(
|
||||
f" {ts_code}: Polars: {polars_values[i]:.6f}, Factor: {factor_result[i]:.6f}"
|
||||
)
|
||||
|
||||
assert max_diff < 1e-10, f"PE_Rank 因子计算差异过大: {max_diff}"
|
||||
|
||||
|
||||
class TestCompositeFactor:
|
||||
"""结合因子测试"""
|
||||
|
||||
def test_scalar_composite(self, daily_data):
|
||||
"""测试标量组合因子: 0.5 * MA"""
|
||||
period = 5
|
||||
sample_stock = daily_data["ts_code"].to_list()[0]
|
||||
stock_df = daily_data.filter(pl.col("ts_code") == sample_stock).sort(
|
||||
"trade_date"
|
||||
)
|
||||
|
||||
# Polars 基准计算
|
||||
polars_ma = stock_df.with_columns(
|
||||
pl.col("close").rolling_mean(window_size=period).over("ts_code").alias("ma")
|
||||
)
|
||||
polars_combined = 0.5 * polars_ma["ma"].to_numpy()
|
||||
|
||||
# 因子框架计算
|
||||
context = FactorContext(current_stock=sample_stock)
|
||||
factor_data = FactorData(
|
||||
stock_df.with_columns([pl.col("trade_date").cast(pl.Utf8)]), context
|
||||
)
|
||||
|
||||
# 组合因子: 0.5 * MA
|
||||
ma_factor = MAFactor(period=period)
|
||||
scalar_factor = 0.5 * ma_factor
|
||||
factor_result = scalar_factor.compute(factor_data).to_numpy()
|
||||
|
||||
# 对比结果
|
||||
valid_idx = ~np.isnan(polars_combined)
|
||||
polars_valid = polars_combined[valid_idx]
|
||||
factor_valid = factor_result[valid_idx]
|
||||
|
||||
diff = np.abs(polars_valid - factor_valid)
|
||||
max_diff = np.max(diff)
|
||||
|
||||
print(f"\n[结合因子 0.5*MA({period}) 对比]")
|
||||
print(f" 公式: 0.5 * MA({period})")
|
||||
print(f" 有效数据点: {len(polars_valid)}")
|
||||
print(f" 最大差异: {max_diff:.15f}")
|
||||
print(f" 样本数据 (前5个):")
|
||||
for i in range(min(5, len(polars_valid))):
|
||||
print(
|
||||
f" Polars: {polars_valid[i]:.6f}, Factor: {factor_valid[i]:.6f}, Diff: {abs(polars_valid[i] - factor_valid[i]):.15f}"
|
||||
)
|
||||
|
||||
assert max_diff < 1e-10, f"组合因子计算差异过大: {max_diff}"
|
||||
|
||||
def test_factor_addition(self, daily_data):
|
||||
"""测试因子加法组合: MA(5) + MA(10)"""
|
||||
sample_stock = daily_data["ts_code"].to_list()[0]
|
||||
stock_df = daily_data.filter(pl.col("ts_code") == sample_stock).sort(
|
||||
"trade_date"
|
||||
)
|
||||
|
||||
context = FactorContext(current_stock=sample_stock)
|
||||
|
||||
# Polars 基准计算
|
||||
polars_ma5 = stock_df.with_columns(
|
||||
pl.col("close").rolling_mean(window_size=5).over("ts_code").alias("ma5")
|
||||
)
|
||||
polars_ma10 = stock_df.with_columns(
|
||||
pl.col("close").rolling_mean(window_size=10).over("ts_code").alias("ma10")
|
||||
)
|
||||
polars_combined = polars_ma5["ma5"].to_numpy() + polars_ma10["ma10"].to_numpy()
|
||||
|
||||
# 因子框架计算
|
||||
factor_data = FactorData(
|
||||
stock_df.with_columns([pl.col("trade_date").cast(pl.Utf8)]), context
|
||||
)
|
||||
|
||||
ma5 = MAFactor(period=5)
|
||||
ma10 = MAFactor(period=10)
|
||||
combined = ma5 + ma10
|
||||
|
||||
factor_result = combined.compute(factor_data).to_numpy()
|
||||
|
||||
# 对比结果
|
||||
valid_idx = ~(np.isnan(polars_combined) | np.isnan(factor_result))
|
||||
polars_valid = polars_combined[valid_idx]
|
||||
factor_valid = factor_result[valid_idx]
|
||||
|
||||
diff = np.abs(polars_valid - factor_valid)
|
||||
max_diff = np.max(diff)
|
||||
|
||||
print(f"\n[结合因子 MA(5) + MA(10) 对比]")
|
||||
print(f" 有效数据点: {len(polars_valid)}")
|
||||
print(f" 最大差异: {max_diff:.15f}")
|
||||
|
||||
assert max_diff < 1e-10, f"因子加法组合差异过大: {max_diff}"
|
||||
|
||||
|
||||
class TestFactorComparison:
|
||||
"""全面对比测试"""
|
||||
|
||||
def test_all_factors_summary(self, daily_data):
|
||||
"""汇总所有因子测试结果"""
|
||||
print("\n" + "=" * 60)
|
||||
print("因子测试汇总")
|
||||
print("=" * 60)
|
||||
|
||||
# 测试多个时序周期
|
||||
for period in [5, 10, 20]:
|
||||
sample_stock = daily_data["ts_code"].to_list()[0]
|
||||
stock_df = daily_data.filter(pl.col("ts_code") == sample_stock).sort(
|
||||
"trade_date"
|
||||
)
|
||||
|
||||
polars_result = stock_df.with_columns(
|
||||
pl.col("close")
|
||||
.rolling_mean(window_size=period)
|
||||
.over("ts_code")
|
||||
.alias("ma")
|
||||
)
|
||||
|
||||
context = FactorContext(current_stock=sample_stock)
|
||||
factor_data = FactorData(
|
||||
stock_df.with_columns([pl.col("trade_date").cast(pl.Utf8)]), context
|
||||
)
|
||||
|
||||
ma_factor = MAFactor(period=period)
|
||||
factor_result = ma_factor.compute(factor_data).to_numpy()
|
||||
|
||||
polars_values = polars_result["ma"].to_numpy()
|
||||
valid_idx = ~np.isnan(polars_values)
|
||||
|
||||
diff = np.abs(polars_values[valid_idx] - factor_result[valid_idx])
|
||||
max_diff = np.max(diff)
|
||||
|
||||
status = "通过" if max_diff < 1e-10 else "失败"
|
||||
print(f" MA({period}): 最大差异 = {max_diff:.2e} {status}")
|
||||
|
||||
# 测试截面因子
|
||||
trade_dates = daily_data["trade_date"].unique().to_list()
|
||||
sample_date = trade_dates[50]
|
||||
date_df = daily_data.filter(pl.col("trade_date") == sample_date)
|
||||
|
||||
polars_result = date_df.with_columns(
|
||||
(pl.col("close").rank() / pl.col("close").count()).alias("rank")
|
||||
)
|
||||
|
||||
context = FactorContext(current_date=str(sample_date))
|
||||
factor_data = FactorData(
|
||||
date_df.with_columns(
|
||||
[pl.col("trade_date").cast(pl.Utf8), pl.col("ts_code").cast(pl.Utf8)]
|
||||
),
|
||||
context,
|
||||
)
|
||||
|
||||
pe_factor = PERankFactor()
|
||||
factor_result = pe_factor.compute(factor_data).to_numpy()
|
||||
|
||||
polars_values = polars_result["rank"].to_numpy()
|
||||
diff = np.abs(polars_values - factor_result)
|
||||
max_diff = np.max(diff)
|
||||
|
||||
status = "通过" if max_diff < 1e-10 else "失败"
|
||||
print(f" PE_Rank: 最大差异 = {max_diff:.2e} {status}")
|
||||
|
||||
print("=" * 60)
|
||||
|
||||
# 测试多个时序周期
|
||||
for period in [5, 10, 20]:
|
||||
sample_stock = daily_data["ts_code"].to_list()[0]
|
||||
stock_df = daily_data.filter(pl.col("ts_code") == sample_stock).sort(
|
||||
"trade_date"
|
||||
)
|
||||
|
||||
polars_result = stock_df.with_columns(
|
||||
pl.col("close")
|
||||
.rolling_mean(window_size=period)
|
||||
.over("ts_code")
|
||||
.alias("ma")
|
||||
)
|
||||
|
||||
context = FactorContext(current_stock=sample_stock)
|
||||
factor_data = FactorData(
|
||||
stock_df.with_columns([pl.col("trade_date").cast(pl.Utf8)]), context
|
||||
)
|
||||
|
||||
ma_factor = MAFactor(period=period)
|
||||
factor_result = ma_factor.compute(factor_data).to_numpy()
|
||||
|
||||
polars_values = polars_result["ma"].to_numpy()
|
||||
valid_idx = ~np.isnan(polars_values)
|
||||
|
||||
diff = np.abs(polars_values[valid_idx] - factor_result[valid_idx])
|
||||
max_diff = np.max(diff)
|
||||
|
||||
status = "通过" if max_diff < 1e-10 else "失败"
|
||||
print(f" MA({period}): 最大差异 = {max_diff:.2e} {status}")
|
||||
|
||||
# 测试截面因子
|
||||
trade_dates = daily_data["trade_date"].unique().to_list()
|
||||
sample_date = trade_dates[50]
|
||||
date_df = daily_data.filter(pl.col("trade_date") == sample_date)
|
||||
|
||||
polars_result = date_df.with_columns(
|
||||
(pl.col("close").rank() / pl.col("close").count()).alias("rank")
|
||||
)
|
||||
|
||||
context = FactorContext(current_date=str(sample_date))
|
||||
factor_data = FactorData(
|
||||
date_df.with_columns(
|
||||
[pl.col("trade_date").cast(pl.Utf8), pl.col("ts_code").cast(pl.Utf8)]
|
||||
),
|
||||
context,
|
||||
)
|
||||
|
||||
pe_factor = PERankFactor()
|
||||
factor_result = pe_factor.compute(factor_data).to_numpy()
|
||||
|
||||
polars_values = polars_result["rank"].to_numpy()
|
||||
diff = np.abs(polars_values - factor_result)
|
||||
max_diff = np.max(diff)
|
||||
|
||||
status = "通过" if max_diff < 1e-10 else "失败"
|
||||
print(f" PE_Rank: 最大差异 = {max_diff:.2e} {status}")
|
||||
|
||||
print("=" * 60)
|
||||
Reference in New Issue
Block a user