"""测试数据加载器 - DataLoader 测试需求(来自 factor_implementation_plan.md): - 测试从单个 H5 文件加载数据 - 测试从多个 H5 文件加载并合并 - 测试列选择(只加载需要的列) - 测试缓存机制(第二次加载更快) - 测试 clear_cache() 清空缓存 - 测试按 date_range 过滤 - 测试文件不存在时抛出 FileNotFoundError - 测试列不存在时抛出 KeyError """ import pytest import polars as pl import pandas as pd from pathlib import Path from src.factors import DataSpec, DataLoader class TestDataLoaderBasic: """测试 DataLoader 基本功能""" @pytest.fixture def loader(self): """创建 DataLoader 实例""" return DataLoader(data_dir="data") def test_init(self): """测试初始化""" loader = DataLoader(data_dir="data") assert loader.data_dir == Path("data") assert loader._cache == {} def test_load_single_source(self, loader): """测试从单个 H5 文件加载数据""" specs = [ DataSpec( source="daily", columns=["ts_code", "trade_date", "close"], lookback_days=1, ) ] df = loader.load(specs) assert isinstance(df, pl.DataFrame) assert len(df) > 0 assert "ts_code" in df.columns assert "trade_date" in df.columns assert "close" in df.columns def test_load_multiple_sources(self, loader): """测试从多个 H5 文件加载并合并""" # 注意:这里假设只有一个 daily.h5 文件 # 如果有多个文件,可以测试合并逻辑 specs = [ DataSpec( source="daily", columns=["ts_code", "trade_date", "close"], lookback_days=1, ), DataSpec( source="daily", columns=["ts_code", "trade_date", "open", "high", "low"], lookback_days=1, ), ] df = loader.load(specs) assert isinstance(df, pl.DataFrame) assert len(df) > 0 # 应该包含所有列 assert set(df.columns) >= { "ts_code", "trade_date", "close", "open", "high", "low", } def test_column_selection(self, loader): """测试列选择(只加载需要的列)""" specs = [ DataSpec( source="daily", columns=["ts_code", "trade_date", "close"], lookback_days=1, ) ] df = loader.load(specs) # 只应该有 3 列 assert set(df.columns) == {"ts_code", "trade_date", "close"} def test_date_range_filter(self, loader): """测试按 date_range 过滤""" specs = [ DataSpec( source="daily", columns=["ts_code", "trade_date", "close"], lookback_days=1, ) ] # 先加载所有数据 df_all = loader.load(specs) total_rows = len(df_all) # 清空缓存,重新加载特定日期范围 loader.clear_cache() df_filtered = loader.load(specs, date_range=("20240101", "20240131")) # 过滤后的数据应该更少或相等 assert len(df_filtered) <= total_rows # 所有日期都应该在范围内 if len(df_filtered) > 0: dates = df_filtered["trade_date"].to_list() assert all("20240101" <= d <= "20240131" for d in dates) class TestDataLoaderCache: """测试 DataLoader 缓存机制""" @pytest.fixture def loader(self): """创建 DataLoader 实例""" return DataLoader(data_dir="data") def test_cache_populated(self, loader): """测试加载后缓存被填充""" specs = [ DataSpec( source="daily", columns=["ts_code", "trade_date", "close"], lookback_days=1, ) ] # 第一次加载 loader.load(specs) # 检查缓存 assert len(loader._cache) > 0 def test_cache_used(self, loader): """测试第二次加载使用缓存(更快)""" import time specs = [ DataSpec( source="daily", columns=["ts_code", "trade_date", "close"], lookback_days=1, ) ] # 第一次加载 start = time.time() df1 = loader.load(specs) time1 = time.time() - start # 第二次加载(应该使用缓存) start = time.time() df2 = loader.load(specs) time2 = time.time() - start # 数据应该相同 assert df1.shape == df2.shape # 第二次应该更快(至少快 50%) # 注意:如果数据量很小,这个测试可能不稳定 # assert time2 < time1 * 0.5 def test_clear_cache(self, loader): """测试 clear_cache() 清空缓存""" specs = [ DataSpec( source="daily", columns=["ts_code", "trade_date", "close"], lookback_days=1, ) ] # 加载数据 loader.load(specs) assert len(loader._cache) > 0 # 清空缓存 loader.clear_cache() assert len(loader._cache) == 0 def test_cache_info(self, loader): """测试 get_cache_info()""" specs = [ DataSpec( source="daily", columns=["ts_code", "trade_date", "close"], lookback_days=1, ) ] # 加载前 info_before = loader.get_cache_info() assert info_before["entries"] == 0 # 加载后 loader.load(specs) info_after = loader.get_cache_info() assert info_after["entries"] > 0 assert info_after["total_rows"] > 0 class TestDataLoaderErrors: """测试 DataLoader 错误处理""" def test_file_not_found(self): """测试文件不存在时抛出 FileNotFoundError""" loader = DataLoader(data_dir="nonexistent_dir") specs = [ DataSpec( source="daily", columns=["ts_code", "trade_date", "close"], lookback_days=1, ) ] with pytest.raises(FileNotFoundError): loader.load(specs) def test_column_not_found(self): """测试列不存在时抛出 KeyError""" loader = DataLoader(data_dir="data") specs = [ DataSpec( source="daily", columns=["ts_code", "trade_date", "nonexistent_column"], lookback_days=1, ) ] with pytest.raises(KeyError, match="nonexistent_column"): loader.load(specs)