"""Factors 模块测试 - Phase 1: 数据类型定义测试 测试范围: - DataSpec: 数据需求规格的创建和验证 - FactorContext: 计算上下文的创建 - FactorData: 数据容器的基本操作 - HDF5 数据读取: 验证能正确读取 daily.h5 文件 """ import pytest import polars as pl import pandas as pd from pathlib import Path from src.factors import DataSpec, FactorContext, FactorData class TestDataSpec: """测试 DataSpec 数据需求规格""" def test_valid_dataspec_creation(self): """测试有效的 DataSpec 创建""" spec = DataSpec( source="daily", columns=["ts_code", "trade_date", "close"], lookback_days=5 ) assert spec.source == "daily" assert spec.columns == ["ts_code", "trade_date", "close"] assert spec.lookback_days == 5 def test_dataspec_default_lookback(self): """测试 DataSpec 默认值 lookback_days=1""" spec = DataSpec(source="daily", columns=["ts_code", "trade_date", "close"]) assert spec.lookback_days == 1 def test_dataspec_frozen_immutable(self): """测试 DataSpec 是 frozen(不可变)""" spec = DataSpec( source="daily", columns=["ts_code", "trade_date", "close"], lookback_days=5 ) with pytest.raises(FrozenInstanceError): spec.source = "other" def test_dataspec_lookback_less_than_1_raises(self): """测试 lookback_days < 1 时抛出 ValueError""" with pytest.raises(ValueError, match="lookback_days must be >= 1"): DataSpec( source="daily", columns=["ts_code", "trade_date", "close"], lookback_days=0, ) with pytest.raises(ValueError, match="lookback_days must be >= 1"): DataSpec( source="daily", columns=["ts_code", "trade_date", "close"], lookback_days=-1, ) def test_dataspec_missing_required_columns_raises(self): """测试缺少 ts_code 或 trade_date 时抛出 ValueError""" # 缺少 ts_code with pytest.raises(ValueError, match="columns must contain"): DataSpec(source="daily", columns=["trade_date", "close"], lookback_days=5) # 缺少 trade_date with pytest.raises(ValueError, match="columns must contain"): DataSpec(source="daily", columns=["ts_code", "close"], lookback_days=5) # 两者都缺少 with pytest.raises(ValueError, match="columns must contain"): DataSpec(source="daily", columns=["close", "open", "high"], lookback_days=5) def test_dataspec_empty_source_raises(self): """测试空 source 时抛出 ValueError""" with pytest.raises(ValueError, match="source cannot be empty string"): DataSpec( source="", columns=["ts_code", "trade_date", "close"], lookback_days=5 ) class TestFactorContext: """测试 FactorContext 计算上下文""" def test_default_creation(self): """测试默认值创建""" ctx = FactorContext() assert ctx.current_date is None assert ctx.current_stock is None assert ctx.trade_dates is None def test_full_creation(self): """测试完整参数创建""" ctx = FactorContext( current_date="20240101", current_stock="000001.SZ", trade_dates=["20240101", "20240102", "20240103"], ) assert ctx.current_date == "20240101" assert ctx.current_stock == "000001.SZ" assert ctx.trade_dates == ["20240101", "20240102", "20240103"] def test_partial_creation(self): """测试部分参数创建""" ctx = FactorContext(current_date="20240101") assert ctx.current_date == "20240101" assert ctx.current_stock is None assert ctx.trade_dates is None def test_dataclass_methods(self): """测试 dataclass 自动生成的方法""" ctx = FactorContext(current_date="20240101") # __repr__ assert "FactorContext" in repr(ctx) assert "20240101" in repr(ctx) # __eq__ ctx2 = FactorContext(current_date="20240101") assert ctx == ctx2 ctx3 = FactorContext(current_date="20240102") assert ctx != ctx3 class TestFactorData: """测试 FactorData 数据容器""" @pytest.fixture def sample_df(self): """创建示例 DataFrame""" return pl.DataFrame( { "ts_code": ["000001.SZ", "000002.SZ", "000001.SZ", "000002.SZ"], "trade_date": ["20240101", "20240101", "20240102", "20240102"], "close": [10.0, 20.0, 10.5, 20.5], "volume": [1000, 2000, 1100, 2100], } ) @pytest.fixture def cs_context(self): """截面因子上下文""" return FactorContext(current_date="20240101") @pytest.fixture def ts_context(self): """时序因子上下文""" return FactorContext(current_stock="000001.SZ") def test_get_column(self, sample_df, cs_context): """测试 get_column 返回正确的 Series""" data = FactorData(sample_df, cs_context) close_series = data.get_column("close") assert isinstance(close_series, pl.Series) assert close_series.to_list() == [10.0, 20.0, 10.5, 20.5] def test_get_column_keyerror(self, sample_df, cs_context): """测试 get_column 列不存在时抛出 KeyError""" data = FactorData(sample_df, cs_context) with pytest.raises(KeyError, match="Column 'nonexistent' not found"): data.get_column("nonexistent") def test_filter_by_date(self, sample_df, cs_context): """测试 filter_by_date 返回正确的过滤结果""" data = FactorData(sample_df, cs_context) filtered = data.filter_by_date("20240101") assert len(filtered) == 2 assert filtered.to_polars()["ts_code"].to_list() == ["000001.SZ", "000002.SZ"] assert filtered.to_polars()["close"].to_list() == [10.0, 20.0] def test_filter_by_date_empty_result(self, sample_df, cs_context): """测试 filter_by_date 日期不存在时返回空的 FactorData""" data = FactorData(sample_df, cs_context) filtered = data.filter_by_date("20241231") assert len(filtered) == 0 assert isinstance(filtered, FactorData) def test_get_cross_section(self, sample_df, cs_context): """测试 get_cross_section 返回 current_date 当天的数据""" data = FactorData(sample_df, cs_context) cs = data.get_cross_section() assert len(cs) == 2 assert cs["ts_code"].to_list() == ["000001.SZ", "000002.SZ"] def test_get_cross_section_no_date_raises(self, sample_df, ts_context): """测试 get_cross_section current_date 为 None 时抛出 ValueError""" data = FactorData(sample_df, ts_context) with pytest.raises(ValueError, match="current_date is not set"): data.get_cross_section() def test_to_polars(self, sample_df, cs_context): """测试 to_polars 返回原始 DataFrame""" data = FactorData(sample_df, cs_context) df = data.to_polars() assert isinstance(df, pl.DataFrame) assert df.shape == sample_df.shape assert df.columns == sample_df.columns def test_context_property(self, sample_df, cs_context): """测试 context 属性返回正确的上下文""" data = FactorData(sample_df, cs_context) assert data.context == cs_context assert data.context.current_date == "20240101" def test_len(self, sample_df, cs_context): """测试 __len__ 返回正确的行数""" data = FactorData(sample_df, cs_context) assert len(data) == 4 def test_repr(self, sample_df, cs_context): """测试 __repr__ 返回可读字符串""" data = FactorData(sample_df, cs_context) repr_str = repr(data) assert "FactorData" in repr_str assert "rows=4" in repr_str assert "date=20240101" in repr_str class TestHDF5DataAccess: """测试 HDF5 数据读取功能""" def test_daily_h5_file_exists(self): """测试 daily.h5 文件存在""" data_path = Path("data/daily.h5") assert data_path.exists(), f"daily.h5 文件不存在: {data_path.absolute()}" def test_daily_h5_can_read_with_pandas(self): """测试能用 pandas 读取 daily.h5""" data_path = Path("data/daily.h5") df = pd.read_hdf(data_path, key="/daily") assert df is not None assert len(df) > 0 assert "ts_code" in df.columns assert "trade_date" in df.columns assert "close" in df.columns def test_daily_h5_columns(self): """测试 daily.h5 包含预期的列""" data_path = Path("data/daily.h5") df = pd.read_hdf(data_path, key="/daily") expected_columns = [ "trade_date", "ts_code", "open", "high", "low", "close", "pre_close", "change", "pct_chg", "vol", "amount", "turnover_rate", "volume_ratio", ] for col in expected_columns: assert col in df.columns, f"列 {col} 不存在于 daily.h5" def test_daily_h5_date_format(self): """测试 daily.h5 日期格式正确""" data_path = Path("data/daily.h5") df = pd.read_hdf(data_path, key="/daily") # 检查日期格式是 YYYYMMDD 字符串 sample_date = df["trade_date"].iloc[0] assert isinstance(sample_date, str), "日期应该是字符串格式" assert len(sample_date) == 8, "日期应该是 8 位字符串 (YYYYMMDD)" assert sample_date.isdigit(), "日期应该只包含数字" def test_daily_h5_stock_format(self): """测试 daily.h5 股票代码格式正确""" data_path = Path("data/daily.h5") df = pd.read_hdf(data_path, key="/daily") # 检查股票代码格式如 "000001.SZ" sample_code = df["ts_code"].iloc[0] assert isinstance(sample_code, str), "股票代码应该是字符串" assert "." in sample_code, "股票代码应该包含交易所后缀" assert sample_code.endswith((".SZ", ".SH", ".BJ")), ( "股票代码应该以交易所后缀结尾" ) def test_daily_h5_to_polars(self): """测试将 daily.h5 数据转换为 Polars""" data_path = Path("data/daily.h5") pdf = pd.read_hdf(data_path, key="/daily") # 转换为 Polars df = pl.from_pandas(pdf) assert isinstance(df, pl.DataFrame) assert len(df) > 0 assert "ts_code" in df.columns assert "trade_date" in df.columns def test_daily_h5_sample_data_with_factors(self): """测试用 daily.h5 真实数据创建 FactorData""" data_path = Path("data/daily.h5") pdf = pd.read_hdf(data_path, key="/daily") # 取前 100 行作为示例 sample_pdf = pdf.head(100) df = pl.from_pandas(sample_pdf) # 创建 FactorData ctx = FactorContext(current_date=df["trade_date"][0]) data = FactorData(df, ctx) # 验证基本操作 assert len(data) == 100 assert "close" in data.to_polars().columns # 测试 get_column close_prices = data.get_column("close") assert len(close_prices) == 100 # 测试 filter_by_date first_date = df["trade_date"][0] filtered = data.filter_by_date(first_date) assert len(filtered) > 0 # 导入 FrozenInstanceError try: from dataclasses import FrozenInstanceError except ImportError: # Python < 3.10 compatibility FrozenInstanceError = AttributeError