feat(factors): 添加因子计算框架

- 新增因子基类 (BaseFactor, CrossSectionalFactor, TimeSeriesFactor)
- 新增数据规格和上下文类 (DataSpec, FactorContext, FactorData)
- 新增数据加载器 (DataLoader) 和执行引擎 (FactorEngine)
- 新增组合因子支持 (CompositeFactor, ScalarFactor)
- 添加因子模块完整测试用例
- 添加 Git 提交规范文档
This commit is contained in:
2026-02-22 14:41:32 +08:00
parent 9965ce5706
commit 0a16129548
21 changed files with 7064 additions and 748 deletions

View File

@@ -0,0 +1,328 @@
"""Factors 模块测试 - Phase 1: 数据类型定义测试
测试范围:
- DataSpec: 数据需求规格的创建和验证
- FactorContext: 计算上下文的创建
- FactorData: 数据容器的基本操作
- HDF5 数据读取: 验证能正确读取 daily.h5 文件
"""
import pytest
import polars as pl
import pandas as pd
from pathlib import Path
from src.factors import DataSpec, FactorContext, FactorData
class TestDataSpec:
"""测试 DataSpec 数据需求规格"""
def test_valid_dataspec_creation(self):
"""测试有效的 DataSpec 创建"""
spec = DataSpec(
source="daily", columns=["ts_code", "trade_date", "close"], lookback_days=5
)
assert spec.source == "daily"
assert spec.columns == ["ts_code", "trade_date", "close"]
assert spec.lookback_days == 5
def test_dataspec_default_lookback(self):
"""测试 DataSpec 默认值 lookback_days=1"""
spec = DataSpec(source="daily", columns=["ts_code", "trade_date", "close"])
assert spec.lookback_days == 1
def test_dataspec_frozen_immutable(self):
"""测试 DataSpec 是 frozen不可变"""
spec = DataSpec(
source="daily", columns=["ts_code", "trade_date", "close"], lookback_days=5
)
with pytest.raises(FrozenInstanceError):
spec.source = "other"
def test_dataspec_lookback_less_than_1_raises(self):
"""测试 lookback_days < 1 时抛出 ValueError"""
with pytest.raises(ValueError, match="lookback_days must be >= 1"):
DataSpec(
source="daily",
columns=["ts_code", "trade_date", "close"],
lookback_days=0,
)
with pytest.raises(ValueError, match="lookback_days must be >= 1"):
DataSpec(
source="daily",
columns=["ts_code", "trade_date", "close"],
lookback_days=-1,
)
def test_dataspec_missing_required_columns_raises(self):
"""测试缺少 ts_code 或 trade_date 时抛出 ValueError"""
# 缺少 ts_code
with pytest.raises(ValueError, match="columns must contain"):
DataSpec(source="daily", columns=["trade_date", "close"], lookback_days=5)
# 缺少 trade_date
with pytest.raises(ValueError, match="columns must contain"):
DataSpec(source="daily", columns=["ts_code", "close"], lookback_days=5)
# 两者都缺少
with pytest.raises(ValueError, match="columns must contain"):
DataSpec(source="daily", columns=["close", "open", "high"], lookback_days=5)
def test_dataspec_empty_source_raises(self):
"""测试空 source 时抛出 ValueError"""
with pytest.raises(ValueError, match="source cannot be empty string"):
DataSpec(
source="", columns=["ts_code", "trade_date", "close"], lookback_days=5
)
class TestFactorContext:
"""测试 FactorContext 计算上下文"""
def test_default_creation(self):
"""测试默认值创建"""
ctx = FactorContext()
assert ctx.current_date is None
assert ctx.current_stock is None
assert ctx.trade_dates is None
def test_full_creation(self):
"""测试完整参数创建"""
ctx = FactorContext(
current_date="20240101",
current_stock="000001.SZ",
trade_dates=["20240101", "20240102", "20240103"],
)
assert ctx.current_date == "20240101"
assert ctx.current_stock == "000001.SZ"
assert ctx.trade_dates == ["20240101", "20240102", "20240103"]
def test_partial_creation(self):
"""测试部分参数创建"""
ctx = FactorContext(current_date="20240101")
assert ctx.current_date == "20240101"
assert ctx.current_stock is None
assert ctx.trade_dates is None
def test_dataclass_methods(self):
"""测试 dataclass 自动生成的方法"""
ctx = FactorContext(current_date="20240101")
# __repr__
assert "FactorContext" in repr(ctx)
assert "20240101" in repr(ctx)
# __eq__
ctx2 = FactorContext(current_date="20240101")
assert ctx == ctx2
ctx3 = FactorContext(current_date="20240102")
assert ctx != ctx3
class TestFactorData:
"""测试 FactorData 数据容器"""
@pytest.fixture
def sample_df(self):
"""创建示例 DataFrame"""
return pl.DataFrame(
{
"ts_code": ["000001.SZ", "000002.SZ", "000001.SZ", "000002.SZ"],
"trade_date": ["20240101", "20240101", "20240102", "20240102"],
"close": [10.0, 20.0, 10.5, 20.5],
"volume": [1000, 2000, 1100, 2100],
}
)
@pytest.fixture
def cs_context(self):
"""截面因子上下文"""
return FactorContext(current_date="20240101")
@pytest.fixture
def ts_context(self):
"""时序因子上下文"""
return FactorContext(current_stock="000001.SZ")
def test_get_column(self, sample_df, cs_context):
"""测试 get_column 返回正确的 Series"""
data = FactorData(sample_df, cs_context)
close_series = data.get_column("close")
assert isinstance(close_series, pl.Series)
assert close_series.to_list() == [10.0, 20.0, 10.5, 20.5]
def test_get_column_keyerror(self, sample_df, cs_context):
"""测试 get_column 列不存在时抛出 KeyError"""
data = FactorData(sample_df, cs_context)
with pytest.raises(KeyError, match="Column 'nonexistent' not found"):
data.get_column("nonexistent")
def test_filter_by_date(self, sample_df, cs_context):
"""测试 filter_by_date 返回正确的过滤结果"""
data = FactorData(sample_df, cs_context)
filtered = data.filter_by_date("20240101")
assert len(filtered) == 2
assert filtered.to_polars()["ts_code"].to_list() == ["000001.SZ", "000002.SZ"]
assert filtered.to_polars()["close"].to_list() == [10.0, 20.0]
def test_filter_by_date_empty_result(self, sample_df, cs_context):
"""测试 filter_by_date 日期不存在时返回空的 FactorData"""
data = FactorData(sample_df, cs_context)
filtered = data.filter_by_date("20241231")
assert len(filtered) == 0
assert isinstance(filtered, FactorData)
def test_get_cross_section(self, sample_df, cs_context):
"""测试 get_cross_section 返回 current_date 当天的数据"""
data = FactorData(sample_df, cs_context)
cs = data.get_cross_section()
assert len(cs) == 2
assert cs["ts_code"].to_list() == ["000001.SZ", "000002.SZ"]
def test_get_cross_section_no_date_raises(self, sample_df, ts_context):
"""测试 get_cross_section current_date 为 None 时抛出 ValueError"""
data = FactorData(sample_df, ts_context)
with pytest.raises(ValueError, match="current_date is not set"):
data.get_cross_section()
def test_to_polars(self, sample_df, cs_context):
"""测试 to_polars 返回原始 DataFrame"""
data = FactorData(sample_df, cs_context)
df = data.to_polars()
assert isinstance(df, pl.DataFrame)
assert df.shape == sample_df.shape
assert df.columns == sample_df.columns
def test_context_property(self, sample_df, cs_context):
"""测试 context 属性返回正确的上下文"""
data = FactorData(sample_df, cs_context)
assert data.context == cs_context
assert data.context.current_date == "20240101"
def test_len(self, sample_df, cs_context):
"""测试 __len__ 返回正确的行数"""
data = FactorData(sample_df, cs_context)
assert len(data) == 4
def test_repr(self, sample_df, cs_context):
"""测试 __repr__ 返回可读字符串"""
data = FactorData(sample_df, cs_context)
repr_str = repr(data)
assert "FactorData" in repr_str
assert "rows=4" in repr_str
assert "date=20240101" in repr_str
class TestHDF5DataAccess:
"""测试 HDF5 数据读取功能"""
def test_daily_h5_file_exists(self):
"""测试 daily.h5 文件存在"""
data_path = Path("data/daily.h5")
assert data_path.exists(), f"daily.h5 文件不存在: {data_path.absolute()}"
def test_daily_h5_can_read_with_pandas(self):
"""测试能用 pandas 读取 daily.h5"""
data_path = Path("data/daily.h5")
df = pd.read_hdf(data_path, key="/daily")
assert df is not None
assert len(df) > 0
assert "ts_code" in df.columns
assert "trade_date" in df.columns
assert "close" in df.columns
def test_daily_h5_columns(self):
"""测试 daily.h5 包含预期的列"""
data_path = Path("data/daily.h5")
df = pd.read_hdf(data_path, key="/daily")
expected_columns = [
"trade_date",
"ts_code",
"open",
"high",
"low",
"close",
"pre_close",
"change",
"pct_chg",
"vol",
"amount",
"turnover_rate",
"volume_ratio",
]
for col in expected_columns:
assert col in df.columns, f"{col} 不存在于 daily.h5"
def test_daily_h5_date_format(self):
"""测试 daily.h5 日期格式正确"""
data_path = Path("data/daily.h5")
df = pd.read_hdf(data_path, key="/daily")
# 检查日期格式是 YYYYMMDD 字符串
sample_date = df["trade_date"].iloc[0]
assert isinstance(sample_date, str), "日期应该是字符串格式"
assert len(sample_date) == 8, "日期应该是 8 位字符串 (YYYYMMDD)"
assert sample_date.isdigit(), "日期应该只包含数字"
def test_daily_h5_stock_format(self):
"""测试 daily.h5 股票代码格式正确"""
data_path = Path("data/daily.h5")
df = pd.read_hdf(data_path, key="/daily")
# 检查股票代码格式如 "000001.SZ"
sample_code = df["ts_code"].iloc[0]
assert isinstance(sample_code, str), "股票代码应该是字符串"
assert "." in sample_code, "股票代码应该包含交易所后缀"
assert sample_code.endswith((".SZ", ".SH", ".BJ")), (
"股票代码应该以交易所后缀结尾"
)
def test_daily_h5_to_polars(self):
"""测试将 daily.h5 数据转换为 Polars"""
data_path = Path("data/daily.h5")
pdf = pd.read_hdf(data_path, key="/daily")
# 转换为 Polars
df = pl.from_pandas(pdf)
assert isinstance(df, pl.DataFrame)
assert len(df) > 0
assert "ts_code" in df.columns
assert "trade_date" in df.columns
def test_daily_h5_sample_data_with_factors(self):
"""测试用 daily.h5 真实数据创建 FactorData"""
data_path = Path("data/daily.h5")
pdf = pd.read_hdf(data_path, key="/daily")
# 取前 100 行作为示例
sample_pdf = pdf.head(100)
df = pl.from_pandas(sample_pdf)
# 创建 FactorData
ctx = FactorContext(current_date=df["trade_date"][0])
data = FactorData(df, ctx)
# 验证基本操作
assert len(data) == 100
assert "close" in data.to_polars().columns
# 测试 get_column
close_prices = data.get_column("close")
assert len(close_prices) == 100
# 测试 filter_by_date
first_date = df["trade_date"][0]
filtered = data.filter_by_date(first_date)
assert len(filtered) > 0
# 导入 FrozenInstanceError
try:
from dataclasses import FrozenInstanceError
except ImportError:
# Python < 3.10 compatibility
FrozenInstanceError = AttributeError