- 新增因子基类 (BaseFactor, CrossSectionalFactor, TimeSeriesFactor)
- 新增数据规格和上下文类 (DataSpec, FactorContext, FactorData)
- 新增数据加载器 (DataLoader) 和执行引擎 (FactorEngine)
- 新增组合因子支持 (CompositeFactor, ScalarFactor)
- 添加因子模块完整测试用例
- 添加 Git 提交规范文档
329 lines
12 KiB
Python
329 lines
12 KiB
Python
"""Factors 模块测试 - Phase 1: 数据类型定义测试

测试范围:
- DataSpec: 数据需求规格的创建和验证
- FactorContext: 计算上下文的创建
- FactorData: 数据容器的基本操作
- HDF5 数据读取: 验证能正确读取 daily.h5 文件
"""
|
||
|
||
import pytest
|
||
import polars as pl
|
||
import pandas as pd
|
||
from pathlib import Path
|
||
|
||
from src.factors import DataSpec, FactorContext, FactorData
|
||
|
||
|
||
class TestDataSpec:
    """Tests for DataSpec, the data-requirement specification."""

    def test_valid_dataspec_creation(self):
        """A DataSpec built from valid arguments exposes them unchanged."""
        built = DataSpec(
            lookback_days=5,
            columns=["ts_code", "trade_date", "close"],
            source="daily",
        )
        expected = {
            "source": "daily",
            "columns": ["ts_code", "trade_date", "close"],
            "lookback_days": 5,
        }
        for field_name, value in expected.items():
            assert getattr(built, field_name) == value

    def test_dataspec_default_lookback(self):
        """Omitting lookback_days yields the default of 1."""
        built = DataSpec(columns=["ts_code", "trade_date", "close"], source="daily")
        assert built.lookback_days == 1

    def test_dataspec_frozen_immutable(self):
        """Assigning to a field of a DataSpec raises FrozenInstanceError."""
        frozen_spec = DataSpec(
            lookback_days=5,
            columns=["ts_code", "trade_date", "close"],
            source="daily",
        )
        with pytest.raises(FrozenInstanceError):
            setattr(frozen_spec, "source", "other")

    def test_dataspec_lookback_less_than_1_raises(self):
        """lookback_days below 1 (zero, then negative) raises ValueError."""
        for bad_lookback in (0, -1):
            with pytest.raises(ValueError, match="lookback_days must be >= 1"):
                DataSpec(
                    source="daily",
                    columns=["ts_code", "trade_date", "close"],
                    lookback_days=bad_lookback,
                )

    def test_dataspec_missing_required_columns_raises(self):
        """Leaving out ts_code and/or trade_date raises ValueError."""
        incomplete_column_sets = (
            ["trade_date", "close"],  # ts_code missing
            ["ts_code", "close"],  # trade_date missing
            ["close", "open", "high"],  # both missing
        )
        for column_set in incomplete_column_sets:
            with pytest.raises(ValueError, match="columns must contain"):
                DataSpec(source="daily", columns=column_set, lookback_days=5)

    def test_dataspec_empty_source_raises(self):
        """An empty source string raises ValueError."""
        with pytest.raises(ValueError, match="source cannot be empty string"):
            DataSpec(
                lookback_days=5,
                columns=["ts_code", "trade_date", "close"],
                source="",
            )
|
||
|
||
|
||
class TestFactorContext:
    """Tests for FactorContext, the computation context."""

    def test_default_creation(self):
        """A context built with no arguments leaves every field as None."""
        empty_ctx = FactorContext()
        for field_name in ("current_date", "current_stock", "trade_dates"):
            assert getattr(empty_ctx, field_name) is None

    def test_full_creation(self):
        """A context built with every argument exposes each of them."""
        full_ctx = FactorContext(
            trade_dates=["20240101", "20240102", "20240103"],
            current_stock="000001.SZ",
            current_date="20240101",
        )
        expected = {
            "current_date": "20240101",
            "current_stock": "000001.SZ",
            "trade_dates": ["20240101", "20240102", "20240103"],
        }
        for field_name, value in expected.items():
            assert getattr(full_ctx, field_name) == value

    def test_partial_creation(self):
        """Fields that are not passed stay None while the given one is kept."""
        dated_ctx = FactorContext(current_date="20240101")
        assert dated_ctx.current_date == "20240101"
        for unset_field in ("current_stock", "trade_dates"):
            assert getattr(dated_ctx, unset_field) is None

    def test_dataclass_methods(self):
        """The repr and equality behaviour of FactorContext work as expected."""
        base_ctx = FactorContext(current_date="20240101")

        # __repr__ mentions both the class name and the field value.
        for fragment in ("FactorContext", "20240101"):
            assert fragment in repr(base_ctx)

        # __eq__ compares by field values.
        same_ctx = FactorContext(current_date="20240101")
        assert base_ctx == same_ctx
        other_ctx = FactorContext(current_date="20240102")
        assert base_ctx != other_ctx
|
||
|
||
|
||
class TestFactorData:
    """Tests for FactorData, the data container."""

    @pytest.fixture
    def sample_df(self):
        """Build a four-row frame: two stocks on each of two trade dates."""
        columns = {
            "ts_code": ["000001.SZ", "000002.SZ", "000001.SZ", "000002.SZ"],
            "trade_date": ["20240101", "20240101", "20240102", "20240102"],
            "close": [10.0, 20.0, 10.5, 20.5],
            "volume": [1000, 2000, 1100, 2100],
        }
        return pl.DataFrame(columns)

    @pytest.fixture
    def cs_context(self):
        """Context with only current_date set (cross-sectional use)."""
        return FactorContext(current_date="20240101")

    @pytest.fixture
    def ts_context(self):
        """Context with only current_stock set (time-series use)."""
        return FactorContext(current_stock="000001.SZ")

    def test_get_column(self, sample_df, cs_context):
        """get_column returns the requested column as a polars Series."""
        container = FactorData(sample_df, cs_context)
        closes = container.get_column("close")
        assert isinstance(closes, pl.Series)
        assert closes.to_list() == [10.0, 20.0, 10.5, 20.5]

    def test_get_column_keyerror(self, sample_df, cs_context):
        """get_column raises KeyError for a column that is not present."""
        container = FactorData(sample_df, cs_context)
        with pytest.raises(KeyError, match="Column 'nonexistent' not found"):
            container.get_column("nonexistent")

    def test_filter_by_date(self, sample_df, cs_context):
        """filter_by_date keeps only the rows of the requested date."""
        container = FactorData(sample_df, cs_context)
        filtered = container.filter_by_date("20240101")
        assert len(filtered) == 2
        expected_by_column = {
            "ts_code": ["000001.SZ", "000002.SZ"],
            "close": [10.0, 20.0],
        }
        for column, expected in expected_by_column.items():
            assert filtered.to_polars()[column].to_list() == expected

    def test_filter_by_date_empty_result(self, sample_df, cs_context):
        """filter_by_date with an absent date returns an empty FactorData."""
        container = FactorData(sample_df, cs_context)
        filtered = container.filter_by_date("20241231")
        assert len(filtered) == 0
        assert isinstance(filtered, FactorData)

    def test_get_cross_section(self, sample_df, cs_context):
        """get_cross_section returns the rows for the context's current_date."""
        container = FactorData(sample_df, cs_context)
        cross_section = container.get_cross_section()
        assert len(cross_section) == 2
        assert cross_section["ts_code"].to_list() == ["000001.SZ", "000002.SZ"]

    def test_get_cross_section_no_date_raises(self, sample_df, ts_context):
        """get_cross_section raises ValueError when current_date is None."""
        container = FactorData(sample_df, ts_context)
        with pytest.raises(ValueError, match="current_date is not set"):
            container.get_cross_section()

    def test_to_polars(self, sample_df, cs_context):
        """to_polars returns a polars frame with the input's shape and columns."""
        container = FactorData(sample_df, cs_context)
        frame = container.to_polars()
        assert isinstance(frame, pl.DataFrame)
        assert frame.shape == sample_df.shape
        assert frame.columns == sample_df.columns

    def test_context_property(self, sample_df, cs_context):
        """The context property exposes the context passed at construction."""
        container = FactorData(sample_df, cs_context)
        assert container.context == cs_context
        assert container.context.current_date == "20240101"

    def test_len(self, sample_df, cs_context):
        """len() of a FactorData equals its row count."""
        container = FactorData(sample_df, cs_context)
        assert len(container) == 4

    def test_repr(self, sample_df, cs_context):
        """repr() names the class, the row count and the current date."""
        container = FactorData(sample_df, cs_context)
        rendered = repr(container)
        for fragment in ("FactorData", "rows=4", "date=20240101"):
            assert fragment in rendered
|
||
|
||
|
||
class TestHDF5DataAccess:
    """Tests that the daily HDF5 store can be read and fed into FactorData.

    The store path is relative, so these tests must be run from a working
    directory that contains ``data/daily.h5``.
    """

    # Location of the store and the key of the table inside it, kept in one
    # place instead of being repeated in every test.
    DATA_PATH = Path("data/daily.h5")
    DATA_KEY = "/daily"

    @pytest.fixture(scope="class")
    def daily_pdf(self):
        """Load the daily table once and share it across this class.

        The tests below only read from the frame, so a single class-scoped
        load replaces one full ``pd.read_hdf`` call per test.

        Returns:
            The pandas object stored under ``DATA_KEY`` in ``DATA_PATH``.
        """
        return pd.read_hdf(self.DATA_PATH, key=self.DATA_KEY)

    def test_daily_h5_file_exists(self):
        """The daily.h5 file is present on disk."""
        assert self.DATA_PATH.exists(), (
            f"daily.h5 文件不存在: {self.DATA_PATH.absolute()}"
        )

    def test_daily_h5_can_read_with_pandas(self, daily_pdf):
        """pandas can read daily.h5 and the key columns are present."""
        assert daily_pdf is not None
        assert len(daily_pdf) > 0
        assert "ts_code" in daily_pdf.columns
        assert "trade_date" in daily_pdf.columns
        assert "close" in daily_pdf.columns

    def test_daily_h5_columns(self, daily_pdf):
        """daily.h5 contains every expected column."""
        expected_columns = [
            "trade_date",
            "ts_code",
            "open",
            "high",
            "low",
            "close",
            "pre_close",
            "change",
            "pct_chg",
            "vol",
            "amount",
            "turnover_rate",
            "volume_ratio",
        ]

        for col in expected_columns:
            assert col in daily_pdf.columns, f"列 {col} 不存在于 daily.h5"

    def test_daily_h5_date_format(self, daily_pdf):
        """The first trade_date value is an 8-digit YYYYMMDD string."""
        sample_date = daily_pdf["trade_date"].iloc[0]
        assert isinstance(sample_date, str), "日期应该是字符串格式"
        assert len(sample_date) == 8, "日期应该是 8 位字符串 (YYYYMMDD)"
        assert sample_date.isdigit(), "日期应该只包含数字"

    def test_daily_h5_stock_format(self, daily_pdf):
        """The first ts_code value is a string such as "000001.SZ"."""
        sample_code = daily_pdf["ts_code"].iloc[0]
        assert isinstance(sample_code, str), "股票代码应该是字符串"
        assert "." in sample_code, "股票代码应该包含交易所后缀"
        assert sample_code.endswith((".SZ", ".SH", ".BJ")), (
            "股票代码应该以交易所后缀结尾"
        )

    def test_daily_h5_to_polars(self, daily_pdf):
        """The pandas frame read from daily.h5 converts to a polars frame."""
        df = pl.from_pandas(daily_pdf)

        assert isinstance(df, pl.DataFrame)
        assert len(df) > 0
        assert "ts_code" in df.columns
        assert "trade_date" in df.columns

    def test_daily_h5_sample_data_with_factors(self, daily_pdf):
        """A FactorData built from real daily.h5 rows supports basic operations."""
        # Use the first 100 rows as a sample; head() returns a new frame, so
        # the shared fixture value is not modified.
        df = pl.from_pandas(daily_pdf.head(100))
        first_date = df["trade_date"][0]

        data = FactorData(df, FactorContext(current_date=first_date))

        # Basic container operations.
        assert len(data) == 100
        assert "close" in data.to_polars().columns

        # get_column
        close_prices = data.get_column("close")
        assert len(close_prices) == 100

        # filter_by_date: the first row's own date must match at least one row.
        filtered = data.filter_by_date(first_date)
        assert len(filtered) > 0
|
||
|
||
|
||
# Bind FrozenInstanceError for the frozen-dataclass test. Test functions look
# the name up when they run, so defining it at the end of the module works.
try:
    from dataclasses import FrozenInstanceError
except ImportError:
    # Fallback when the dataclasses module does not export the name.
    # NOTE(review): the original comment read "Python < 3.10 compatibility",
    # but the standard-library docs list FrozenInstanceError as part of
    # dataclasses itself and describe it as a subclass of AttributeError —
    # confirm whether this branch is still needed.
    FrozenInstanceError = AttributeError
|