"""Tests for the data loader (``DataLoader``).

Test requirements (from factor_implementation_plan.md):

- Load data from DuckDB.
- Load from multiple specs and merge the results.
- Column selection (only the requested columns are loaded).
- Caching (a second identical load is served from the cache).
- ``clear_cache()`` empties the cache.
- Filtering by ``date_range``.
- Handling of a table that does not exist.
- A missing column raises ``KeyError``.

The tests use three months of real data (January-March 2024) read from the
``data`` directory, so that directory must be populated for them to pass.
"""

import time
from pathlib import Path

import pandas as pd
import polars as pl
import pytest

from src.factors import DataSpec, DataLoader


# Shared test window: three months of data (inclusive, "YYYYMMDD" strings).
TEST_START_DATE = "20240101"
TEST_END_DATE = "20240331"
TEST_DATE_RANGE = (TEST_START_DATE, TEST_END_DATE)

# Column set used by most tests.
CLOSE_COLUMNS = ["ts_code", "trade_date", "close"]


def _make_specs(*column_groups, source="daily"):
    """Return one ``DataSpec`` on *source* per column group.

    Every spec uses ``lookback_days=1``, as all tests in this module do. Each
    column list is copied so specs never share a mutable list with the
    module-level constants.
    """
    return [
        DataSpec(source=source, columns=list(columns), lookback_days=1)
        for columns in column_groups
    ]


def _assert_dates_within(df, start, end):
    """Assert that every ``trade_date`` in *df* lies in ``[start, end]``.

    NOTE(review): values are compared against "YYYYMMDD" strings, which
    assumes ``trade_date`` is a string column — confirm against DataLoader.
    """
    dates = df["trade_date"].to_list()
    outside = [d for d in dates if not (start <= d <= end)]
    assert not outside, f"trade_date outside {start}..{end}: {outside[:5]}"


@pytest.fixture
def loader():
    """Provide a new ``DataLoader`` over ``data`` for each test."""
    return DataLoader(data_dir="data")


class TestDataLoaderBasic:
    """Basic DataLoader behaviour."""

    # Kept as class attributes for backward compatibility.
    TEST_START_DATE = TEST_START_DATE
    TEST_END_DATE = TEST_END_DATE

    def test_init(self):
        """The constructor stores ``data_dir`` as a Path and starts with an empty cache."""
        loader = DataLoader(data_dir="data")
        assert loader.data_dir == Path("data")
        assert loader._cache == {}

    def test_load_single_source(self, loader):
        """Load a single spec from DuckDB."""
        df = loader.load(_make_specs(CLOSE_COLUMNS), date_range=TEST_DATE_RANGE)

        assert isinstance(df, pl.DataFrame)
        assert len(df) > 0
        for column in CLOSE_COLUMNS:
            assert column in df.columns

    def test_load_with_date_range(self, loader):
        """Every row of the three-month load falls inside the requested range."""
        specs = _make_specs(["ts_code", "trade_date", "close", "open", "high", "low"])
        df = loader.load(specs, date_range=TEST_DATE_RANGE)

        assert isinstance(df, pl.DataFrame)
        assert len(df) > 0
        _assert_dates_within(df, TEST_START_DATE, TEST_END_DATE)
        dates = df["trade_date"].to_list()
        print(f"[TEST] Loaded {len(df)} rows from {min(dates)} to {max(dates)}")

    def test_load_multiple_specs(self, loader):
        """Several specs on the same source are loaded and merged."""
        specs = _make_specs(
            ["ts_code", "trade_date", "close"],
            ["ts_code", "trade_date", "open", "high", "low"],
        )
        df = loader.load(specs, date_range=TEST_DATE_RANGE)

        assert isinstance(df, pl.DataFrame)
        assert len(df) > 0
        # The merged frame must contain the union of both specs' columns.
        assert set(df.columns) >= {
            "ts_code",
            "trade_date",
            "close",
            "open",
            "high",
            "low",
        }

    def test_column_selection(self, loader):
        """Only the requested columns are returned."""
        df = loader.load(_make_specs(CLOSE_COLUMNS), date_range=TEST_DATE_RANGE)

        assert set(df.columns) == set(CLOSE_COLUMNS)

    def test_date_range_filter(self, loader):
        """A one-month range returns a non-empty subset of the three-month load."""
        specs = _make_specs(CLOSE_COLUMNS)
        total_rows = len(loader.load(specs, date_range=TEST_DATE_RANGE))

        # Start the one-month load from an empty cache.
        loader.clear_cache()
        january = ("20240101", "20240131")
        df_filtered = loader.load(specs, date_range=january)

        # Without this, an empty result would make the range check vacuous.
        assert len(df_filtered) > 0
        assert len(df_filtered) <= total_rows
        _assert_dates_within(df_filtered, *january)


class TestDataLoaderCache:
    """DataLoader caching behaviour."""

    # Kept as class attributes for backward compatibility.
    TEST_START_DATE = TEST_START_DATE
    TEST_END_DATE = TEST_END_DATE

    def test_cache_populated(self, loader):
        """The cache is non-empty after a load."""
        loader.load(_make_specs(CLOSE_COLUMNS), date_range=TEST_DATE_RANGE)

        assert len(loader._cache) > 0

    def test_cache_used(self, loader):
        """A second identical load is served from the cache."""
        specs = _make_specs(CLOSE_COLUMNS)

        start = time.perf_counter()
        df1 = loader.load(specs, date_range=TEST_DATE_RANGE)
        time1 = time.perf_counter() - start
        entries_after_first = loader.get_cache_info()["entries"]

        start = time.perf_counter()
        df2 = loader.load(specs, date_range=TEST_DATE_RANGE)
        time2 = time.perf_counter() - start

        assert df1.shape == df2.shape
        # Deterministic cache-hit check: the first load populated the cache
        # and the identical second load added no new entries. This replaces
        # a wall-clock "time2 < time1" assertion, which was flaky.
        assert entries_after_first > 0
        assert loader.get_cache_info()["entries"] == entries_after_first
        print(f"[TEST] First load: {time1:.3f}s, cached load: {time2:.3f}s")

    def test_clear_cache(self, loader):
        """``clear_cache()`` removes every cache entry."""
        loader.load(_make_specs(CLOSE_COLUMNS), date_range=TEST_DATE_RANGE)
        assert len(loader._cache) > 0

        loader.clear_cache()
        assert len(loader._cache) == 0

    def test_cache_info(self, loader):
        """``get_cache_info()`` reports entries and rows before and after a load."""
        assert loader.get_cache_info()["entries"] == 0

        loader.load(_make_specs(CLOSE_COLUMNS), date_range=TEST_DATE_RANGE)
        info_after = loader.get_cache_info()
        assert info_after["entries"] > 0
        assert info_after["total_rows"] > 0


class TestDataLoaderErrors:
    """DataLoader error handling."""

    def test_table_not_exists(self, loader):
        """Loading from a table that does not exist raises.

        The concrete exception type is deliberately not pinned here.
        """
        specs = _make_specs(CLOSE_COLUMNS, source="nonexistent_table")

        with pytest.raises(Exception):
            loader.load(specs)

    def test_column_not_found(self, loader):
        """Requesting a column that does not exist raises ``KeyError`` naming it."""
        specs = _make_specs(["ts_code", "trade_date", "nonexistent_column"])

        with pytest.raises(KeyError, match="nonexistent_column"):
            loader.load(specs)


if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])