Files
ProStock/tests/factors/test_data_loader.py
liaozhaorun e58b39970c feat: HDF5迁移至DuckDB存储
- 新增DuckDB Storage与ThreadSafeStorage实现
- 新增db_manager模块支持增量同步策略
- DataLoader与Sync模块适配DuckDB
- 补充迁移相关文档与测试
- 修复README文档链接
2026-02-23 00:07:21 +08:00

285 lines
8.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""测试数据加载器 - DataLoader
测试需求(来自 factor_implementation_plan.md):
- 测试从 DuckDB 加载数据
- 测试从多个查询加载并合并
- 测试列选择(只加载需要的列)
- 测试缓存机制(第二次加载更快)
- 测试 clear_cache() 清空缓存
- 测试按 date_range 过滤
- 测试表不存在时的处理
- 测试列不存在时抛出 KeyError
使用 3 个月的真实数据进行测试 (2024年1月-3月)
"""
import pytest
import polars as pl
import pandas as pd
from pathlib import Path
from src.factors import DataSpec, DataLoader
class TestDataLoaderBasic:
    """Basic DataLoader tests: construction, loading, column selection, date filtering."""

    # Three-month window used to bound how much data each test loads.
    TEST_START_DATE = "20240101"
    TEST_END_DATE = "20240331"

    @pytest.fixture
    def loader(self):
        """Return a DataLoader constructed with ``data_dir="data"``."""
        return DataLoader(data_dir="data")

    @staticmethod
    def _daily_specs(*column_groups):
        """Return one ``daily`` DataSpec (``lookback_days=1``) per column group."""
        return [
            DataSpec(source="daily", columns=list(group), lookback_days=1)
            for group in column_groups
        ]

    def _load_window(self, loader, specs):
        """Load ``specs`` over the class-level three-month date range."""
        window = (self.TEST_START_DATE, self.TEST_END_DATE)
        return loader.load(specs, date_range=window)

    def test_init(self):
        """A new loader stores ``data_dir`` as a Path and starts with an empty cache."""
        fresh = DataLoader(data_dir="data")
        assert fresh.data_dir == Path("data")
        assert fresh._cache == {}

    def test_load_single_source(self, loader):
        """A single ``daily`` spec yields a non-empty polars frame with the requested columns."""
        wanted = ["ts_code", "trade_date", "close"]
        frame = self._load_window(loader, self._daily_specs(wanted))

        assert isinstance(frame, pl.DataFrame)
        assert len(frame) > 0
        for column in ("ts_code", "trade_date", "close"):
            assert column in frame.columns

    def test_load_with_date_range(self, loader):
        """Every loaded ``trade_date`` lies inside the requested three-month range."""
        wanted = ["ts_code", "trade_date", "close", "open", "high", "low"]
        frame = self._load_window(loader, self._daily_specs(wanted))

        assert isinstance(frame, pl.DataFrame)
        assert len(frame) > 0

        # The frame is known to be non-empty here, so min()/max() below are safe.
        dates = frame["trade_date"].to_list()
        lower, upper = self.TEST_START_DATE, self.TEST_END_DATE
        assert all(lower <= day <= upper for day in dates)
        print(f"[TEST] Loaded {len(frame)} rows from {min(dates)} to {max(dates)}")

    def test_load_multiple_specs(self, loader):
        """Two specs on the same source are merged into one frame holding all columns."""
        specs = self._daily_specs(
            ["ts_code", "trade_date", "close"],
            ["ts_code", "trade_date", "open", "high", "low"],
        )
        frame = self._load_window(loader, specs)

        assert isinstance(frame, pl.DataFrame)
        assert len(frame) > 0
        # The merged result must contain at least the union of requested columns.
        required = {"ts_code", "trade_date", "close", "open", "high", "low"}
        assert set(frame.columns) >= required

    def test_column_selection(self, loader):
        """Only the requested columns are present in the result."""
        wanted = ["ts_code", "trade_date", "close"]
        frame = self._load_window(loader, self._daily_specs(wanted))

        # Exactly these three columns, nothing more.
        expected = {"ts_code", "trade_date", "close"}
        assert set(frame.columns) == expected

    def test_date_range_filter(self, loader):
        """A one-month sub-range returns no more rows than the full three months."""
        specs = self._daily_specs(["ts_code", "trade_date", "close"])

        # Baseline: the full three-month window.
        full_frame = self._load_window(loader, specs)
        total_rows = len(full_frame)

        # Clear the cache, then reload with a one-month range.
        loader.clear_cache()
        first_day, last_day = "20240101", "20240131"
        month_frame = loader.load(specs, date_range=(first_day, last_day))

        # The narrower range must yield the same number of rows or fewer.
        assert len(month_frame) <= total_rows

        # Nothing further to verify when no rows came back.
        if len(month_frame) == 0:
            return
        dates = month_frame["trade_date"].to_list()
        assert all(first_day <= day <= last_day for day in dates)
class TestDataLoaderCache:
    """Tests for the DataLoader cache: population, reuse, clearing and reporting."""

    # Three-month window used to bound how much data each test loads.
    TEST_START_DATE = "20240101"
    TEST_END_DATE = "20240331"

    @pytest.fixture
    def loader(self):
        """Return a DataLoader constructed with ``data_dir="data"``."""
        return DataLoader(data_dir="data")

    @staticmethod
    def _close_specs():
        """Return the single ``daily`` spec (ts_code, trade_date, close) used by every test here."""
        return [
            DataSpec(
                source="daily",
                columns=["ts_code", "trade_date", "close"],
                lookback_days=1,
            )
        ]

    def test_cache_populated(self, loader):
        """After a load, the loader's ``_cache`` holds at least one entry."""
        specs = self._close_specs()

        # First load.
        loader.load(specs, date_range=(self.TEST_START_DATE, self.TEST_END_DATE))

        # Inspect the cache.
        assert len(loader._cache) > 0

    def test_cache_used(self, loader):
        """A second identical load returns the same shape and completes faster."""
        import time

        specs = self._close_specs()
        date_range = (self.TEST_START_DATE, self.TEST_END_DATE)

        # time.perf_counter() is monotonic and high-resolution. time.time() is a
        # wall clock that can be coarse or step backwards, so two fast loads
        # could measure as equal and make the strict "<" comparison fail.

        # First load.
        start = time.perf_counter()
        df1 = loader.load(specs, date_range=date_range)
        time1 = time.perf_counter() - start

        # Second load (expected to be served from the cache).
        start = time.perf_counter()
        df2 = loader.load(specs, date_range=date_range)
        time2 = time.perf_counter() - start

        # Both loads must return frames of the same shape.
        assert df1.shape == df2.shape

        # The second load should be faster.
        print(f"[TEST] First load: {time1:.3f}s, cached load: {time2:.3f}s")
        assert time2 < time1, "Cached load should be faster"

    def test_clear_cache(self, loader):
        """``clear_cache()`` leaves ``_cache`` empty."""
        specs = self._close_specs()

        # Load data so that the cache is non-empty.
        loader.load(specs, date_range=(self.TEST_START_DATE, self.TEST_END_DATE))
        assert len(loader._cache) > 0

        # Clear the cache.
        loader.clear_cache()
        assert len(loader._cache) == 0

    def test_cache_info(self, loader):
        """``get_cache_info()`` reports zero entries before a load and positive counts after."""
        specs = self._close_specs()

        # Before loading.
        info_before = loader.get_cache_info()
        assert info_before["entries"] == 0

        # After loading.
        loader.load(specs, date_range=(self.TEST_START_DATE, self.TEST_END_DATE))
        info_after = loader.get_cache_info()
        assert info_after["entries"] > 0
        assert info_after["total_rows"] > 0
class TestDataLoaderErrors:
    """Error-handling tests for DataLoader."""

    def test_table_not_exists(self):
        """Loading from a source named ``nonexistent_table`` must raise."""
        loader = DataLoader(data_dir="data")
        missing_table = DataSpec(
            source="nonexistent_table",
            columns=["ts_code", "trade_date", "close"],
            lookback_days=1,
        )
        specs = [missing_table]

        # This test accepts any exception type, but it does not accept an
        # empty DataFrame being returned instead of an error.
        with pytest.raises(Exception):
            loader.load(specs)

    def test_column_not_found(self):
        """Requesting an unknown column raises a KeyError that names the column."""
        loader = DataLoader(data_dir="data")
        bad_column = DataSpec(
            source="daily",
            columns=["ts_code", "trade_date", "nonexistent_column"],
            lookback_days=1,
        )
        specs = [bad_column]

        with pytest.raises(KeyError, match="nonexistent_column"):
            loader.load(specs)
if __name__ == "__main__":
    # Allow running this file directly. pytest.main() returns the exit code,
    # so pass it on as the process exit status; otherwise the script would
    # exit 0 even when tests fail.
    raise SystemExit(pytest.main([__file__, "-v", "-s"]))