2026-02-22 14:41:32 +08:00
|
|
|
|
"""测试数据加载器 - DataLoader

测试需求(来自 factor_implementation_plan.md):
- 测试从 DuckDB 加载数据
- 测试从多个查询加载并合并
- 测试列选择(只加载需要的列)
- 测试缓存机制(第二次加载更快)
- 测试 clear_cache() 清空缓存
- 测试按 date_range 过滤
- 测试表不存在时的处理
- 测试列不存在时抛出 KeyError

使用 3 个月的真实数据进行测试 (2024年1月-3月)
"""
|
|
|
|
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
|
|
import polars as pl
|
|
|
|
|
|
import pandas as pd
|
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
|
|
from src.factors import DataSpec, DataLoader
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestDataLoaderBasic:
    """Basic DataLoader behaviour: loading, merging specs, column selection, date filtering."""

    # Test data window: 3 months (2024-01 .. 2024-03), used to limit data volume.
    TEST_START_DATE = "20240101"
    TEST_END_DATE = "20240331"

    @pytest.fixture
    def loader(self):
        """Create a DataLoader reading from the relative ``data`` directory."""
        return DataLoader(data_dir="data")

    def test_init(self):
        """A new loader stores data_dir as a Path and starts with an empty cache."""
        loader = DataLoader(data_dir="data")
        assert loader.data_dir == Path("data")
        assert loader._cache == {}

    def test_load_single_source(self, loader):
        """Load a single DataSpec from DuckDB."""
        specs = [
            DataSpec(
                source="daily",
                columns=["ts_code", "trade_date", "close"],
                lookback_days=1,
            )
        ]

        # Restrict to the 3-month window to keep the load small.
        df = loader.load(specs, date_range=(self.TEST_START_DATE, self.TEST_END_DATE))

        assert isinstance(df, pl.DataFrame)
        assert len(df) > 0
        assert "ts_code" in df.columns
        assert "trade_date" in df.columns
        assert "close" in df.columns

    def test_load_with_date_range(self, loader):
        """Load the 3-month window and check that every row falls inside it."""
        specs = [
            DataSpec(
                source="daily",
                columns=["ts_code", "trade_date", "close", "open", "high", "low"],
                lookback_days=1,
            )
        ]

        df = loader.load(specs, date_range=(self.TEST_START_DATE, self.TEST_END_DATE))

        assert isinstance(df, pl.DataFrame)
        assert len(df) > 0

        # Non-emptiness is asserted above, so min()/max() below are safe without a guard.
        dates = df["trade_date"].to_list()
        # NOTE(review): lexicographic comparison assumes trade_date values are
        # "YYYYMMDD" strings, matching the format of the bounds — confirm.
        assert all(self.TEST_START_DATE <= d <= self.TEST_END_DATE for d in dates)
        print(f"[TEST] Loaded {len(df)} rows from {min(dates)} to {max(dates)}")

    def test_load_multiple_specs(self, loader):
        """Load two DataSpecs on the same source and check the result holds all requested columns."""
        specs = [
            DataSpec(
                source="daily",
                columns=["ts_code", "trade_date", "close"],
                lookback_days=1,
            ),
            DataSpec(
                source="daily",
                columns=["ts_code", "trade_date", "open", "high", "low"],
                lookback_days=1,
            ),
        ]

        df = loader.load(specs, date_range=(self.TEST_START_DATE, self.TEST_END_DATE))

        assert isinstance(df, pl.DataFrame)
        assert len(df) > 0
        # Every column requested by either spec must be present.
        assert set(df.columns) >= {
            "ts_code",
            "trade_date",
            "close",
            "open",
            "high",
            "low",
        }

    def test_column_selection(self, loader):
        """Only the requested columns are returned."""
        specs = [
            DataSpec(
                source="daily",
                columns=["ts_code", "trade_date", "close"],
                lookback_days=1,
            )
        ]

        df = loader.load(specs, date_range=(self.TEST_START_DATE, self.TEST_END_DATE))

        # Exactly these 3 columns, nothing more.
        assert set(df.columns) == {"ts_code", "trade_date", "close"}

    def test_date_range_filter(self, loader):
        """Filtering by date_range: a 1-month subset of the 3-month window."""
        specs = [
            DataSpec(
                source="daily",
                columns=["ts_code", "trade_date", "close"],
                lookback_days=1,
            )
        ]

        # Load the full 3-month window first.
        df_all = loader.load(
            specs, date_range=(self.TEST_START_DATE, self.TEST_END_DATE)
        )
        total_rows = len(df_all)

        # Clear the cache (presumably so the 1-month load is not served from the
        # cached 3-month result), then load January only.
        loader.clear_cache()
        df_filtered = loader.load(specs, date_range=("20240101", "20240131"))

        # The filtered result must not be larger than the full window.
        assert len(df_filtered) <= total_rows

        # Every returned date must lie inside the requested range.
        if len(df_filtered) > 0:
            dates = df_filtered["trade_date"].to_list()
            assert all("20240101" <= d <= "20240131" for d in dates)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestDataLoaderCache:
    """DataLoader caching: population, reuse, clearing and cache statistics."""

    # Test data window: 3 months (2024-01 .. 2024-03), used to limit data volume.
    TEST_START_DATE = "20240101"
    TEST_END_DATE = "20240331"

    @pytest.fixture
    def loader(self):
        """Create a DataLoader reading from the relative ``data`` directory."""
        return DataLoader(data_dir="data")

    def test_cache_populated(self, loader):
        """The cache is non-empty after a load."""
        specs = [
            DataSpec(
                source="daily",
                columns=["ts_code", "trade_date", "close"],
                lookback_days=1,
            )
        ]

        # First load.
        loader.load(specs, date_range=(self.TEST_START_DATE, self.TEST_END_DATE))

        # The cache must now hold at least one entry.
        assert len(loader._cache) > 0

    def test_cache_used(self, loader):
        """A second identical load is served from cache and is faster."""
        import time

        specs = [
            DataSpec(
                source="daily",
                columns=["ts_code", "trade_date", "close"],
                lookback_days=1,
            )
        ]

        # perf_counter is monotonic and high-resolution; time.time() is wall-clock
        # (adjustable, coarser) and unsuitable for measuring short intervals.
        # First load (expected to hit storage).
        start = time.perf_counter()
        df1 = loader.load(specs, date_range=(self.TEST_START_DATE, self.TEST_END_DATE))
        time1 = time.perf_counter() - start

        # Second load with identical arguments (expected to use the cache).
        start = time.perf_counter()
        df2 = loader.load(specs, date_range=(self.TEST_START_DATE, self.TEST_END_DATE))
        time2 = time.perf_counter() - start

        # Both loads must return data of the same shape.
        assert df1.shape == df2.shape

        # The cached load should be faster. This is a timing assertion and can
        # still be affected by system noise.
        print(f"[TEST] First load: {time1:.3f}s, cached load: {time2:.3f}s")
        assert time2 < time1, "Cached load should be faster"

    def test_clear_cache(self, loader):
        """clear_cache() empties the cache."""
        specs = [
            DataSpec(
                source="daily",
                columns=["ts_code", "trade_date", "close"],
                lookback_days=1,
            )
        ]

        # Populate the cache.
        loader.load(specs, date_range=(self.TEST_START_DATE, self.TEST_END_DATE))
        assert len(loader._cache) > 0

        # Clear it.
        loader.clear_cache()
        assert len(loader._cache) == 0

    def test_cache_info(self, loader):
        """get_cache_info() reports entries and total_rows before and after a load."""
        specs = [
            DataSpec(
                source="daily",
                columns=["ts_code", "trade_date", "close"],
                lookback_days=1,
            )
        ]

        # Before loading: no entries.
        info_before = loader.get_cache_info()
        assert info_before["entries"] == 0

        # After loading: at least one entry with rows.
        loader.load(specs, date_range=(self.TEST_START_DATE, self.TEST_END_DATE))
        info_after = loader.get_cache_info()
        assert info_after["entries"] > 0
        assert info_after["total_rows"] > 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestDataLoaderErrors:
    """DataLoader error handling for missing tables and missing columns."""

    # Same 3-month window as the other test classes, so an error path that does
    # query data does not scan the whole table.
    TEST_START_DATE = "20240101"
    TEST_END_DATE = "20240331"

    def test_table_not_exists(self):
        """Loading from a non-existent table raises."""
        loader = DataLoader(data_dir="data")
        specs = [
            DataSpec(
                source="nonexistent_table",
                columns=["ts_code", "trade_date", "close"],
                lookback_days=1,
            )
        ]

        # The loader is expected to raise (returning an empty DataFrame would fail
        # this test). The concrete exception type is not pinned here because it
        # depends on the backend — NOTE(review): narrow once the loader's
        # contract for missing tables is fixed.
        with pytest.raises(Exception):
            loader.load(specs, date_range=(self.TEST_START_DATE, self.TEST_END_DATE))

    def test_column_not_found(self):
        """Requesting a non-existent column raises a KeyError naming that column."""
        loader = DataLoader(data_dir="data")
        specs = [
            DataSpec(
                source="daily",
                columns=["ts_code", "trade_date", "nonexistent_column"],
                lookback_days=1,
            )
        ]

        with pytest.raises(KeyError, match="nonexistent_column"):
            loader.load(specs, date_range=(self.TEST_START_DATE, self.TEST_END_DATE))
|
2026-02-23 00:07:21 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this module's tests verbosely (-v) without output capture (-s), so the
    # [TEST] print lines are visible. Propagate pytest's exit code so a failing
    # run exits non-zero when the file is executed directly.
    raise SystemExit(pytest.main([__file__, "-v", "-s"]))
|