# File: ProStock/tests/factors/test_data_loader.py
# Repository view metadata: 249 lines, 6.8 KiB, Python

"""Tests for the data loader (DataLoader).

Test requirements come from factor_implementation_plan.md:
- Loading data from a single H5 file
- Loading from multiple H5 files and merging the results
- Column selection loads only the requested columns
- Caching: a second load is served from the cache (wall-clock speed is not
  asserted by these tests)
- clear_cache() empties the cache
- Filtering by date_range
- FileNotFoundError when the source file does not exist
- KeyError when a requested column does not exist

NOTE(review): these tests read real data through DataLoader(data_dir="data"),
a relative path -- presumably resolved against the current working directory,
so pytest should be run from the project root (confirm). Only the "daily"
source is used, so the "multiple H5 files" case is currently exercised with
two specs that both point at that one source.
"""
import pytest
import polars as pl
import pandas as pd
from pathlib import Path
from src.factors import DataSpec, DataLoader
class TestDataLoaderBasic:
    """Core DataLoader behaviour: construction, loading, column selection and
    date-range filtering against the relative ``data`` directory."""

    @staticmethod
    def _daily_specs(*columns):
        """Build a one-element spec list reading ``columns`` from ``daily``."""
        return [DataSpec(source="daily", columns=list(columns), lookback_days=1)]

    @pytest.fixture
    def loader(self):
        """Provide a fresh DataLoader rooted at ``data``."""
        return DataLoader(data_dir="data")

    def test_init(self):
        """A new loader stores data_dir as a Path and starts with an empty cache."""
        fresh = DataLoader(data_dir="data")
        assert fresh.data_dir == Path("data")
        assert fresh._cache == {}

    def test_load_single_source(self, loader):
        """A single spec against one H5 source yields a non-empty polars frame
        containing every requested column."""
        frame = loader.load(self._daily_specs("ts_code", "trade_date", "close"))

        assert isinstance(frame, pl.DataFrame)
        assert len(frame) > 0
        for name in ("ts_code", "trade_date", "close"):
            assert name in frame.columns

    def test_load_multiple_sources(self, loader):
        """Several specs are loaded and merged into one frame.

        Only the ``daily`` source is assumed to exist, so both specs target it
        with different column subsets; with more H5 files available this could
        exercise a genuine multi-file merge.
        """
        specs = self._daily_specs("ts_code", "trade_date", "close") + self._daily_specs(
            "ts_code", "trade_date", "open", "high", "low"
        )
        frame = loader.load(specs)

        assert isinstance(frame, pl.DataFrame)
        assert len(frame) > 0
        # Every column requested by either spec must survive the merge.
        expected = {"ts_code", "trade_date", "close", "open", "high", "low"}
        assert expected.issubset(frame.columns)

    def test_column_selection(self, loader):
        """Only the requested columns are returned."""
        frame = loader.load(self._daily_specs("ts_code", "trade_date", "close"))

        # Exactly these three columns -- nothing extra is loaded.
        assert set(frame.columns) == {"ts_code", "trade_date", "close"}

    def test_date_range_filter(self, loader):
        """A date_range restricts the result to rows whose trade_date lies
        within the inclusive bounds."""
        specs = self._daily_specs("ts_code", "trade_date", "close")
        total_rows = len(loader.load(specs))

        # Start the ranged load from an empty cache.
        loader.clear_cache()
        lo, hi = "20240101", "20240131"
        subset = loader.load(specs, date_range=(lo, hi))

        # Filtering can never add rows.
        assert len(subset) <= total_rows
        if len(subset):
            assert all(lo <= day <= hi for day in subset["trade_date"].to_list())
class TestDataLoaderCache:
    """DataLoader caching: population, reuse, clearing and cache statistics."""

    @staticmethod
    def _daily_specs(*columns):
        """Build a one-element spec list reading ``columns`` from ``daily``."""
        return [DataSpec(source="daily", columns=list(columns), lookback_days=1)]

    @pytest.fixture
    def loader(self):
        """Provide a fresh DataLoader rooted at ``data``."""
        return DataLoader(data_dir="data")

    def test_cache_populated(self, loader):
        """A successful load leaves at least one entry in the cache."""
        loader.load(self._daily_specs("ts_code", "trade_date", "close"))

        assert len(loader._cache) > 0

    def test_cache_used(self, loader):
        """Repeating an identical load reuses the cache and returns the same data.

        Wall-clock timing is deliberately not asserted: with small data sets
        the gap between a cold and a cached load is too small to measure
        reliably, so a speed threshold would be flaky. Instead this checks
        that the second load returns identical data and creates no new cache
        entries.
        """
        # Local import keeps the dependency scoped to the one test needing it.
        from polars.testing import assert_frame_equal

        specs = self._daily_specs("ts_code", "trade_date", "close")

        first = loader.load(specs)
        keys_after_first = set(loader._cache)

        second = loader.load(specs)

        # Same columns, dtypes and values -- not merely the same shape.
        assert_frame_equal(first, second)
        # An identical request must not add cache entries beyond those the
        # first load created.
        assert set(loader._cache) == keys_after_first

    def test_clear_cache(self, loader):
        """clear_cache() removes every cached entry."""
        loader.load(self._daily_specs("ts_code", "trade_date", "close"))
        assert len(loader._cache) > 0

        loader.clear_cache()
        assert len(loader._cache) == 0

    def test_cache_info(self, loader):
        """get_cache_info() reports zero entries before any load, and non-zero
        entry and row counts afterwards."""
        info_before = loader.get_cache_info()
        assert info_before["entries"] == 0

        loader.load(self._daily_specs("ts_code", "trade_date", "close"))
        info_after = loader.get_cache_info()
        assert info_after["entries"] > 0
        assert info_after["total_rows"] > 0
class TestDataLoaderErrors:
    """Error reporting of DataLoader for missing files and unknown columns."""

    @staticmethod
    def _daily_specs(*columns):
        """Build a one-element spec list reading ``columns`` from ``daily``."""
        return [DataSpec(source="daily", columns=list(columns), lookback_days=1)]

    def test_file_not_found(self):
        """Loading from a data_dir that does not exist raises FileNotFoundError."""
        missing_dir_loader = DataLoader(data_dir="nonexistent_dir")
        specs = self._daily_specs("ts_code", "trade_date", "close")

        with pytest.raises(FileNotFoundError):
            missing_dir_loader.load(specs)

    def test_column_not_found(self):
        """Requesting an unknown column raises a KeyError that names it."""
        real_loader = DataLoader(data_dir="data")
        bad_column = "nonexistent_column"
        specs = self._daily_specs("ts_code", "trade_date", bad_column)

        with pytest.raises(KeyError, match=bad_column):
            real_loader.load(specs)