Files
ProStock/tests/factors/test_data_loader.py
liaozhaorun 0a16129548 feat(factors): 添加因子计算框架
- 新增因子基类 (BaseFactor, CrossSectionalFactor, TimeSeriesFactor)
- 新增数据规格和上下文类 (DataSpec, FactorContext, FactorData)
- 新增数据加载器 (DataLoader) 和执行引擎 (FactorEngine)
- 新增组合因子支持 (CompositeFactor, ScalarFactor)
- 添加因子模块完整测试用例
- 添加 Git 提交规范文档
2026-02-22 14:41:32 +08:00

249 lines
6.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""测试数据加载器 - DataLoader
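Test requirements (from factor_implementation_plan.md):
- Load data from a single H5 file
- Load and merge data from multiple H5 files
- Column selection (only the requested columns are loaded)
- Caching (the second load is faster)
- clear_cache() empties the cache
- Filtering by date_range
- FileNotFoundError is raised when the file does not exist
- KeyError is raised when a column does not exist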
"""
import pytest
import polars as pl
import pandas as pd
from pathlib import Path
from src.factors import DataSpec, DataLoader
class TestDataLoaderBasic:
    """Exercise the basic loading behaviour of DataLoader."""

    @staticmethod
    def _daily_specs(*columns):
        """Return a one-element spec list for the "daily" source.

        The spec requests exactly ``columns`` with ``lookback_days=1``.
        """
        return [DataSpec(source="daily", columns=list(columns), lookback_days=1)]

    @pytest.fixture
    def loader(self):
        """Provide a DataLoader constructed with ``data_dir="data"``."""
        return DataLoader(data_dir="data")

    def test_init(self):
        """A new loader exposes data_dir as a Path and starts with an empty cache."""
        fresh = DataLoader(data_dir="data")

        assert fresh.data_dir == Path("data")
        assert fresh._cache == {}

    def test_load_single_source(self, loader):
        """Loading a single spec yields a non-empty polars frame with the requested columns."""
        requested = ("ts_code", "trade_date", "close")

        frame = loader.load(self._daily_specs(*requested))

        assert isinstance(frame, pl.DataFrame)
        assert len(frame) > 0
        for name in requested:
            assert name in frame.columns

    def test_load_multiple_sources(self, loader):
        """Loading several specs yields one frame carrying every requested column."""
        # NOTE: this assumes only a single daily.h5 file is available, so both
        # specs use the "daily" source. With more files the merge logic could
        # be tested across sources.
        close_spec = DataSpec(
            source="daily",
            columns=["ts_code", "trade_date", "close"],
            lookback_days=1,
        )
        ohl_spec = DataSpec(
            source="daily",
            columns=["ts_code", "trade_date", "open", "high", "low"],
            lookback_days=1,
        )

        merged = loader.load([close_spec, ohl_spec])

        assert isinstance(merged, pl.DataFrame)
        assert len(merged) > 0
        # Every column requested by either spec must be present.
        expected = {"ts_code", "trade_date", "close", "open", "high", "low"}
        assert expected.issubset(merged.columns)

    def test_column_selection(self, loader):
        """Only the columns named in the spec appear in the result."""
        frame = loader.load(self._daily_specs("ts_code", "trade_date", "close"))

        # Exactly the three requested columns, nothing more.
        assert {"ts_code", "trade_date", "close"} == set(frame.columns)

    def test_date_range_filter(self, loader):
        """Passing date_range never returns more rows, and every row lies inside the range."""
        specs = self._daily_specs("ts_code", "trade_date", "close")

        # Unfiltered load gives the upper bound on the row count.
        unfiltered_rows = len(loader.load(specs))

        # Drop the cache before requesting the restricted range.
        loader.clear_cache()
        windowed = loader.load(specs, date_range=("20240101", "20240131"))

        assert len(windowed) <= unfiltered_rows

        # Only inspect the dates when the filter actually returned rows.
        if len(windowed) > 0:
            trade_dates = windowed["trade_date"].to_list()
            assert all("20240101" <= day <= "20240131" for day in trade_dates)
class TestDataLoaderCache:
    """Verify the caching behaviour of DataLoader."""

    @staticmethod
    def _close_specs():
        """Return a fresh spec list requesting ts_code/trade_date/close from "daily"."""
        return [
            DataSpec(
                source="daily",
                columns=["ts_code", "trade_date", "close"],
                lookback_days=1,
            )
        ]

    @pytest.fixture
    def loader(self):
        """Provide a DataLoader constructed with ``data_dir="data"``."""
        return DataLoader(data_dir="data")

    def test_cache_populated(self, loader):
        """The cache holds at least one entry after a load."""
        loader.load(self._close_specs())

        assert len(loader._cache) > 0

    def test_cache_used(self, loader):
        """An identical second load returns the same shape and adds no cache entries.

        The original version timed both loads but never asserted on the
        timings, because wall-clock comparisons are unstable on small data
        sets. The cache size is checked instead, which is deterministic.
        """
        specs = self._close_specs()

        first = loader.load(specs)
        entries_after_first = len(loader._cache)

        # Second, identical request: expected to be served from the cache.
        second = loader.load(specs)

        assert first.shape == second.shape
        # A request already cached must not create additional entries.
        assert len(loader._cache) == entries_after_first

    def test_clear_cache(self, loader):
        """clear_cache() leaves the cache empty after it has been populated."""
        loader.load(self._close_specs())
        assert len(loader._cache) > 0

        loader.clear_cache()
        assert len(loader._cache) == 0

    def test_cache_info(self, loader):
        """get_cache_info() reports zero entries before a load and positive counts after."""
        # Before any load the cache is reported as empty.
        info_before = loader.get_cache_info()
        assert info_before["entries"] == 0

        # After a load both the entry count and the row count are positive.
        loader.load(self._close_specs())
        info_after = loader.get_cache_info()
        assert info_after["entries"] > 0
        assert info_after["total_rows"] > 0
class TestDataLoaderErrors:
    """Check the exceptions DataLoader raises on invalid requests."""

    def test_file_not_found(self):
        """Loading from a data directory that does not exist raises FileNotFoundError."""
        request = [
            DataSpec(
                source="daily",
                columns=["ts_code", "trade_date", "close"],
                lookback_days=1,
            )
        ]
        broken_loader = DataLoader(data_dir="nonexistent_dir")

        with pytest.raises(FileNotFoundError):
            broken_loader.load(request)

    def test_column_not_found(self):
        """Requesting an unknown column raises KeyError that mentions the column name."""
        unknown_column = DataSpec(
            source="daily",
            columns=["ts_code", "trade_date", "nonexistent_column"],
            lookback_days=1,
        )
        data_loader = DataLoader(data_dir="data")

        with pytest.raises(KeyError, match="nonexistent_column"):
            data_loader.load([unknown_column])