Files
ProStock/tests/factors/test_data_spec.py
liaozhaorun 0a16129548 feat(factors): 添加因子计算框架
- 新增因子基类 (BaseFactor, CrossSectionalFactor, TimeSeriesFactor)
- 新增数据规格和上下文类 (DataSpec, FactorContext, FactorData)
- 新增数据加载器 (DataLoader) 和执行引擎 (FactorEngine)
- 新增组合因子支持 (CompositeFactor, ScalarFactor)
- 添加因子模块完整测试用例
- 添加 Git 提交规范文档
2026-02-22 14:41:32 +08:00

329 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Factors 模块测试 - Phase 1: 数据类型定义测试
测试范围:
- DataSpec: 数据需求规格的创建和验证
- FactorContext: 计算上下文的创建
- FactorData: 数据容器的基本操作
- HDF5 数据读取: 验证能正确读取 daily.h5 文件
"""
from dataclasses import FrozenInstanceError
from pathlib import Path

import pandas as pd
import polars as pl
import pytest

from src.factors import DataSpec, FactorContext, FactorData
class TestDataSpec:
    """Tests for ``DataSpec`` creation and validation."""

    def test_valid_dataspec_creation(self):
        """A fully specified DataSpec keeps every field exactly as given."""
        spec = DataSpec(
            source="daily", columns=["ts_code", "trade_date", "close"], lookback_days=5
        )
        assert spec.source == "daily"
        assert spec.columns == ["ts_code", "trade_date", "close"]
        assert spec.lookback_days == 5

    def test_dataspec_default_lookback(self):
        """``lookback_days`` defaults to 1 when it is omitted."""
        spec = DataSpec(source="daily", columns=["ts_code", "trade_date", "close"])
        assert spec.lookback_days == 1

    def test_dataspec_frozen_immutable(self):
        """DataSpec is frozen: assigning to a field raises FrozenInstanceError."""
        spec = DataSpec(
            source="daily", columns=["ts_code", "trade_date", "close"], lookback_days=5
        )
        with pytest.raises(FrozenInstanceError):
            spec.source = "other"

    # Each invalid value is its own test case, so one failing value cannot
    # mask another and the report names the exact value that broke.
    @pytest.mark.parametrize("lookback_days", [0, -1])
    def test_dataspec_lookback_less_than_1_raises(self, lookback_days):
        """``lookback_days < 1`` is rejected with ValueError."""
        with pytest.raises(ValueError, match="lookback_days must be >= 1"):
            DataSpec(
                source="daily",
                columns=["ts_code", "trade_date", "close"],
                lookback_days=lookback_days,
            )

    @pytest.mark.parametrize(
        "columns",
        [
            pytest.param(["trade_date", "close"], id="missing-ts_code"),
            pytest.param(["ts_code", "close"], id="missing-trade_date"),
            pytest.param(["close", "open", "high"], id="missing-both"),
        ],
    )
    def test_dataspec_missing_required_columns_raises(self, columns):
        """``columns`` lacking ts_code and/or trade_date raises ValueError."""
        with pytest.raises(ValueError, match="columns must contain"):
            DataSpec(source="daily", columns=columns, lookback_days=5)

    def test_dataspec_empty_source_raises(self):
        """An empty ``source`` string raises ValueError."""
        with pytest.raises(ValueError, match="source cannot be empty string"):
            DataSpec(
                source="", columns=["ts_code", "trade_date", "close"], lookback_days=5
            )
class TestFactorContext:
    """Tests for the ``FactorContext`` computation context."""

    def test_default_creation(self):
        """Constructed without arguments, every field is None."""
        context = FactorContext()
        for field_name in ("current_date", "current_stock", "trade_dates"):
            assert getattr(context, field_name) is None

    def test_full_creation(self):
        """Every constructor argument is stored unchanged."""
        dates = ["20240101", "20240102", "20240103"]
        context = FactorContext(
            current_date="20240101",
            current_stock="000001.SZ",
            trade_dates=dates,
        )
        assert (context.current_date, context.current_stock) == (
            "20240101",
            "000001.SZ",
        )
        assert context.trade_dates == ["20240101", "20240102", "20240103"]

    def test_partial_creation(self):
        """Fields not passed to the constructor remain None."""
        context = FactorContext(current_date="20240101")
        assert context.current_date == "20240101"
        assert context.current_stock is None and context.trade_dates is None

    def test_dataclass_methods(self):
        """``repr`` and equality are the dataclass-generated ones."""
        base = FactorContext(current_date="20240101")

        text = repr(base)
        assert "FactorContext" in text
        assert "20240101" in text

        # Same field values compare equal; a different date does not.
        assert base == FactorContext(current_date="20240101")
        assert base != FactorContext(current_date="20240102")
class TestFactorData:
    """Tests for the ``FactorData`` container."""

    @pytest.fixture
    def sample_df(self):
        """Two stocks (000001.SZ, 000002.SZ) over two trading days."""
        codes = ["000001.SZ", "000002.SZ"]
        return pl.DataFrame(
            {
                "ts_code": codes * 2,
                "trade_date": ["20240101"] * 2 + ["20240102"] * 2,
                "close": [10.0, 20.0, 10.5, 20.5],
                "volume": [1000, 2000, 1100, 2100],
            }
        )

    @pytest.fixture
    def cs_context(self):
        """Cross-sectional context whose current date is 20240101."""
        return FactorContext(current_date="20240101")

    @pytest.fixture
    def ts_context(self):
        """Time-series context for 000001.SZ; current_date is left unset."""
        return FactorContext(current_stock="000001.SZ")

    def test_get_column(self, sample_df, cs_context):
        """get_column returns the named column as a polars Series."""
        column = FactorData(sample_df, cs_context).get_column("close")
        assert isinstance(column, pl.Series)
        assert column.to_list() == [10.0, 20.0, 10.5, 20.5]

    def test_get_column_keyerror(self, sample_df, cs_context):
        """Requesting an unknown column raises KeyError."""
        factor_data = FactorData(sample_df, cs_context)
        with pytest.raises(KeyError, match="Column 'nonexistent' not found"):
            factor_data.get_column("nonexistent")

    def test_filter_by_date(self, sample_df, cs_context):
        """filter_by_date keeps only the rows of the requested date."""
        subset = FactorData(sample_df, cs_context).filter_by_date("20240101")
        assert len(subset) == 2
        frame = subset.to_polars()
        assert frame["ts_code"].to_list() == ["000001.SZ", "000002.SZ"]
        assert frame["close"].to_list() == [10.0, 20.0]

    def test_filter_by_date_empty_result(self, sample_df, cs_context):
        """An absent date yields an empty FactorData, not an error."""
        subset = FactorData(sample_df, cs_context).filter_by_date("20241231")
        assert len(subset) == 0
        assert isinstance(subset, FactorData)

    def test_get_cross_section(self, sample_df, cs_context):
        """get_cross_section returns the rows of the context's current_date."""
        snapshot = FactorData(sample_df, cs_context).get_cross_section()
        assert len(snapshot) == 2
        assert snapshot["ts_code"].to_list() == ["000001.SZ", "000002.SZ"]

    def test_get_cross_section_no_date_raises(self, sample_df, ts_context):
        """get_cross_section raises ValueError when current_date is None."""
        factor_data = FactorData(sample_df, ts_context)
        with pytest.raises(ValueError, match="current_date is not set"):
            factor_data.get_cross_section()

    def test_to_polars(self, sample_df, cs_context):
        """to_polars returns a DataFrame shaped like the input."""
        frame = FactorData(sample_df, cs_context).to_polars()
        assert isinstance(frame, pl.DataFrame)
        assert frame.shape == sample_df.shape
        assert frame.columns == sample_df.columns

    def test_context_property(self, sample_df, cs_context):
        """The context property exposes the context given at construction."""
        ctx = FactorData(sample_df, cs_context).context
        assert ctx == cs_context
        assert ctx.current_date == "20240101"

    def test_len(self, sample_df, cs_context):
        """len() reports the number of rows."""
        assert len(FactorData(sample_df, cs_context)) == 4

    def test_repr(self, sample_df, cs_context):
        """repr mentions the class name, the row count and the current date."""
        text = repr(FactorData(sample_df, cs_context))
        for fragment in ("FactorData", "rows=4", "date=20240101"):
            assert fragment in text
class TestHDF5DataAccess:
    """Tests that read the real ``data/daily.h5`` file.

    The path is relative, so it is resolved against the current working
    directory (presumably the project root when pytest runs -- TODO confirm).
    """

    DATA_PATH = Path("data/daily.h5")
    H5_KEY = "/daily"

    @pytest.fixture(scope="class")
    def daily_pdf(self):
        """Read the daily table once for the whole class.

        The file can be large, so it is read once instead of once per test.
        If the file is missing, dependent tests are skipped. The failure is then
        reported once, by ``test_daily_h5_file_exists``, instead of every test
        erroring with FileNotFoundError.

        Returns:
            The pandas DataFrame stored under ``H5_KEY``.
        """
        if not self.DATA_PATH.exists():
            pytest.skip(f"daily.h5 not found: {self.DATA_PATH.absolute()}")
        return pd.read_hdf(self.DATA_PATH, key=self.H5_KEY)

    def test_daily_h5_file_exists(self):
        """daily.h5 is present at the expected path."""
        data_path = self.DATA_PATH
        assert data_path.exists(), f"daily.h5 文件不存在: {data_path.absolute()}"

    def test_daily_h5_can_read_with_pandas(self, daily_pdf):
        """pandas can read daily.h5 and it has the core columns."""
        assert daily_pdf is not None
        assert len(daily_pdf) > 0
        for col in ("ts_code", "trade_date", "close"):
            assert col in daily_pdf.columns

    def test_daily_h5_columns(self, daily_pdf):
        """daily.h5 contains every expected daily-bar column."""
        expected_columns = [
            "trade_date",
            "ts_code",
            "open",
            "high",
            "low",
            "close",
            "pre_close",
            "change",
            "pct_chg",
            "vol",
            "amount",
            "turnover_rate",
            "volume_ratio",
        ]
        for col in expected_columns:
            assert col in daily_pdf.columns, f"{col} 不存在于 daily.h5"

    def test_daily_h5_date_format(self, daily_pdf):
        """trade_date values are 8-digit YYYYMMDD strings (first row sampled)."""
        # Guard first so an empty table fails clearly instead of with IndexError.
        assert len(daily_pdf) > 0
        sample_date = daily_pdf["trade_date"].iloc[0]
        assert isinstance(sample_date, str), "日期应该是字符串格式"
        assert len(sample_date) == 8, "日期应该是 8 位字符串 (YYYYMMDD)"
        assert sample_date.isdigit(), "日期应该只包含数字"

    def test_daily_h5_stock_format(self, daily_pdf):
        """ts_code values look like "000001.SZ" (first row sampled)."""
        assert len(daily_pdf) > 0
        sample_code = daily_pdf["ts_code"].iloc[0]
        assert isinstance(sample_code, str), "股票代码应该是字符串"
        assert "." in sample_code, "股票代码应该包含交易所后缀"
        assert sample_code.endswith((".SZ", ".SH", ".BJ")), (
            "股票代码应该以交易所后缀结尾"
        )

    def test_daily_h5_to_polars(self, daily_pdf):
        """The pandas data converts cleanly to a polars DataFrame."""
        df = pl.from_pandas(daily_pdf)
        assert isinstance(df, pl.DataFrame)
        assert len(df) > 0
        assert "ts_code" in df.columns
        assert "trade_date" in df.columns

    def test_daily_h5_sample_data_with_factors(self, daily_pdf):
        """FactorData works on a slice of the real daily data."""
        # Take up to the first 100 rows. Assertions use the real sample size
        # so a smaller file does not fail for the wrong reason.
        sample_pdf = daily_pdf.head(100)
        n_rows = len(sample_pdf)
        assert n_rows > 0
        df = pl.from_pandas(sample_pdf)

        first_date = df["trade_date"][0]
        data = FactorData(df, FactorContext(current_date=first_date))

        assert len(data) == n_rows
        assert "close" in data.to_polars().columns
        assert len(data.get_column("close")) == n_rows
        # The first row's date is present, so the filter cannot be empty.
        assert len(data.filter_by_date(first_date)) > 0
# ``FrozenInstanceError`` is imported with the other dependencies at the top of
# the module. ``dataclasses.FrozenInstanceError`` exists in every Python version
# that ships ``dataclasses`` (3.7+). The former fallback to ``AttributeError``,
# labelled "Python < 3.10 compatibility", could never run and has been dropped.