"""Tests for StockPoolManager: custom filter functions and factor-based screening.

Main focus: temporary-factor isolation -- only factors generated during screening
are removed from the output; factor columns already present in the input are kept.
"""

from unittest.mock import Mock, patch

import polars as pl
import pytest

from src.training.core.stock_pool_manager import StockPoolManager


class TestStockPoolManagerBasic:
    """Basic single-date screening behaviour of StockPoolManager."""

    def test_basic_filter_with_columns(self):
        """Filter on a column declared in ``required_columns`` that the input already holds."""

        def keep_large_caps(frame: pl.DataFrame) -> pl.Series:
            return frame["total_mv"] > 50

        router = Mock()  # stand-in data router
        pool = StockPoolManager(
            filter_func=keep_large_caps,
            required_columns=["total_mv"],
            data_router=router,
        )

        frame = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ", "000003.SZ", "000004.SZ"],
                "trade_date": ["20240101"] * 4,
                "close": [10.0, 20.0, 30.0, 40.0],
                "total_mv": [100.0, 30.0, 80.0, 20.0],
            }
        )

        # No extra mocking needed: total_mv is already part of the input.
        selected = pool.filter_and_select_daily(frame)

        # The output schema mirrors the input.
        assert selected.columns == frame.columns

        # Only total_mv > 50 survives: 000001.SZ (100) and 000003.SZ (80).
        kept = selected["ts_code"].to_list()
        assert len(kept) == 2
        assert {"000001.SZ", "000003.SZ"} <= set(kept)

    def test_filter_without_required_columns(self):
        """Filter using only columns that already exist in the input frame."""

        def above_25(frame: pl.DataFrame) -> pl.Series:
            return frame["close"] > 25

        pool = StockPoolManager(filter_func=above_25)

        frame = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
                "trade_date": ["20240101"] * 3,
                "close": [10.0, 30.0, 20.0],
            }
        )

        selected = pool.filter_and_select_daily(frame)

        # Exactly one stock clears the threshold: 000002.SZ (close 30).
        assert selected["ts_code"].to_list() == ["000002.SZ"]
        assert selected.columns == frame.columns

    def test_empty_result(self):
        """A filter that rejects every row yields an empty frame with the input's columns."""

        def unreachable(frame: pl.DataFrame) -> pl.Series:
            # No close price in the fixture comes anywhere near this.
            return frame["close"] > 9999

        pool = StockPoolManager(filter_func=unreachable)

        frame = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ"],
                "trade_date": ["20240101"] * 2,
                "close": [10.0, 20.0],
            }
        )

        selected = pool.filter_and_select_daily(frame)

        assert selected.height == 0
        # Even an empty selection keeps the input's column layout.
        assert selected.columns == frame.columns
class TestStockPoolManagerDailyIndependence:
    """Screening must be applied to each trade date independently."""

    def test_daily_independence(self):
        """Each date is screened against its own cross-section, not the pooled data.

        Day 2 prices sit on a different level from day 1. A median over the pooled
        frame (72.5) would therefore keep nothing on day 1 and all four rows on
        day 2. Per-day medians (25 and 120) keep the top half of each day.
        """

        def filter_func(df: pl.DataFrame) -> pl.Series:
            # Keep the top half of the day's cross-section by close price.
            median = df["close"].median()
            return df["close"] >= median

        manager = StockPoolManager(filter_func=filter_func)

        # Two dates with the same four stocks.
        data = pl.DataFrame(
            {
                "ts_code": [
                    # Day 1
                    "000001.SZ",
                    "000002.SZ",
                    "000003.SZ",
                    "000004.SZ",
                    # Day 2
                    "000001.SZ",
                    "000002.SZ",
                    "000003.SZ",
                    "000004.SZ",
                ],
                "trade_date": ["20240101"] * 4 + ["20240102"] * 4,
                "close": [
                    # Day 1: median 25 -> keep 30, 40
                    10.0,
                    20.0,
                    30.0,
                    40.0,
                    # Day 2: median 120 -> keep 125, 135
                    105.0,
                    115.0,
                    125.0,
                    135.0,
                ],
            }
        )

        result = manager.filter_and_select_daily(data)

        day1 = result.filter(pl.col("trade_date") == "20240101")
        day2 = result.filter(pl.col("trade_date") == "20240102")

        # Day 1: close >= 25 (per-day median) -> 30 and 40.
        assert len(day1) == 2
        assert set(day1["ts_code"].to_list()) == {"000003.SZ", "000004.SZ"}

        # Day 2: close >= 120 (per-day median) -> 125 and 135.
        # A pooled median would have kept all four day-2 rows here.
        assert len(day2) == 2
        assert set(day2["ts_code"].to_list()) == {"000003.SZ", "000004.SZ"}

        assert len(result) == 4

    def test_uneven_daily_distribution(self):
        """Dates may carry different numbers of stocks."""

        def filter_func(df: pl.DataFrame) -> pl.Series:
            return df["close"] > 15

        manager = StockPoolManager(filter_func=filter_func)

        data = pl.DataFrame(
            {
                "ts_code": [
                    # Day 1: two stocks
                    "000001.SZ",
                    "000002.SZ",
                    # Day 2: four stocks
                    "000001.SZ",
                    "000002.SZ",
                    "000003.SZ",
                    "000004.SZ",
                ],
                "trade_date": ["20240101"] * 2 + ["20240102"] * 4,
                "close": [10.0, 20.0, 5.0, 15.0, 25.0, 35.0],
            }
        )

        result = manager.filter_and_select_daily(data)

        # Day 1: only 000002.SZ (20 > 15).
        day1 = result.filter(pl.col("trade_date") == "20240101")
        assert len(day1) == 1
        assert set(day1["ts_code"].to_list()) == {"000002.SZ"}

        # Day 2: 000003.SZ (25) and 000004.SZ (35); 15 is not strictly > 15.
        day2 = result.filter(pl.col("trade_date") == "20240102")
        assert len(day2) == 2
        assert set(day2["ts_code"].to_list()) == {"000003.SZ", "000004.SZ"}


class TestStockPoolManagerFactorIsolation:
    """Factor isolation (core tests): drop only factors generated during screening."""

    @patch.object(StockPoolManager, "_compute_factors")
    def test_filter_with_factors(self, mock_compute):
        """Screen on a computed factor and check the temporary factor column is dropped."""
        # Stubbed output of _compute_factors: the input plus the computed factor.
        mock_compute.return_value = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
                "trade_date": ["20240101"] * 3,
                "close": [11.0, 9.5, 10.8],
                "momentum_20": [0.1, -0.05, 0.08],  # rows 1 and 3 are > 0.05
            }
        )

        def filter_func(df: pl.DataFrame) -> pl.Series:
            return df["momentum_20"] > 0.05

        manager = StockPoolManager(
            filter_func=filter_func,
            required_factors={"momentum_20": "(close / ts_delay(close, 20)) - 1"},
        )

        # The input does NOT contain momentum_20.
        data = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
                "trade_date": ["20240101"] * 3,
                "close": [11.0, 9.5, 10.8],
            }
        )

        result = manager.filter_and_select_daily(data)

        # Output columns match the input: the temporary momentum_20 was dropped.
        assert result.columns == data.columns
        assert "momentum_20" not in result.columns

        # momentum_20 > 0.05: 000001.SZ (0.1) and 000003.SZ (0.08).
        assert len(result) == 2
        assert set(result["ts_code"].to_list()) == {"000001.SZ", "000003.SZ"}

        # The factor was missing from the input, so it had to be computed.
        mock_compute.assert_called_once()

    @patch.object(StockPoolManager, "_compute_factors")
    def test_preserve_existing_factors(self, mock_compute):
        """Factors already present in the input survive; newly computed ones are dropped."""
        # Stubbed output: momentum_20 came from the input, roe is computed now.
        mock_compute.return_value = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
                "trade_date": ["20240101"] * 3,
                "close": [11.0, 9.5, 10.8],
                "momentum_20": [0.1, -0.05, 0.08],  # pre-existing
                "roe": [0.12, 0.08, 0.15],  # computed now; row 2 is < 0.1
            }
        )

        def filter_func(df: pl.DataFrame) -> pl.Series:
            return (df["momentum_20"] > 0.05) & (df["roe"] > 0.1)

        manager = StockPoolManager(
            filter_func=filter_func,
            # Only roe is declared as generated here; momentum_20 is already in the input.
            required_factors={"roe": "n_income / equity"},
        )

        # The input already carries momentum_20.
        data = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
                "trade_date": ["20240101"] * 3,
                "close": [11.0, 9.5, 10.8],
                "momentum_20": [0.1, -0.05, 0.08],  # pre-existing factor
            }
        )

        result = manager.filter_and_select_daily(data)

        # 1. momentum_20 is kept (it was in the input).
        assert "momentum_20" in result.columns
        # 2. roe is dropped (it was generated during screening).
        assert "roe" not in result.columns
        # 3. Columns match the input exactly.
        assert result.columns == data.columns

        # momentum_20 > 0.05 -> {000001, 000003}; roe > 0.1 -> {000001, 000003}.
        assert len(result) == 2
        assert set(result["ts_code"].to_list()) == {"000001.SZ", "000003.SZ"}
        # The preserved factor keeps its per-stock values.
        assert dict(
            zip(result["ts_code"].to_list(), result["momentum_20"].to_list())
        ) == {"000001.SZ": 0.1, "000003.SZ": 0.08}

        # roe was missing from the input, so factors had to be computed.
        mock_compute.assert_called_once()

    @patch.object(StockPoolManager, "_compute_factors")
    def test_no_factor_computation_when_all_exist(self, mock_compute):
        """When every required factor is already in the input, nothing is computed."""

        def filter_func(df: pl.DataFrame) -> pl.Series:
            return (df["factor_a"] > 0.5) & (df["factor_b"] < 0.3)

        manager = StockPoolManager(
            filter_func=filter_func,
            required_factors={
                "factor_a": "some_expr_a",
                "factor_b": "some_expr_b",
            },
        )

        # The input already contains every required factor.
        data = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ"],
                "trade_date": ["20240101"] * 2,
                "close": [10.0, 20.0],
                "factor_a": [0.6, 0.4],  # pre-existing
                "factor_b": [0.2, 0.5],  # pre-existing
            }
        )

        with patch("src.factors.engine.factor_engine.FactorEngine") as mock_engine:
            result = manager.filter_and_select_daily(data)

        # Patching FactorEngine at its defining module only intercepts lookups made
        # through that module. A name bound via ``from ... import FactorEngine`` in
        # stock_pool_manager would bypass that patch and make the first check
        # vacuous. The check on _compute_factors (the hook the other tests stub)
        # does not depend on how FactorEngine is imported.
        mock_engine.assert_not_called()
        mock_compute.assert_not_called()

        # Only 000001.SZ satisfies factor_a > 0.5 and factor_b < 0.3.
        assert len(result) == 1
        assert result["ts_code"][0] == "000001.SZ"
        assert result.columns == data.columns


class TestStockPoolManagerEdgeCases:
    """Edge cases for daily screening."""

    def test_single_date(self):
        """Data covering a single trade date."""

        def filter_func(df: pl.DataFrame) -> pl.Series:
            return df["close"] > 15

        manager = StockPoolManager(filter_func=filter_func)

        data = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
                "trade_date": ["20240101"] * 3,
                "close": [10.0, 20.0, 30.0],
            }
        )

        result = manager.filter_and_select_daily(data)

        assert len(result) == 2
        assert result.columns == data.columns

    def test_single_stock_per_day(self):
        """Each trade date holds exactly one stock."""

        def filter_func(df: pl.DataFrame) -> pl.Series:
            return df["close"] > 0  # keeps every row

        manager = StockPoolManager(filter_func=filter_func)

        data = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
                "trade_date": ["20240101", "20240102", "20240103"],
                "close": [10.0, 20.0, 30.0],
            }
        )

        result = manager.filter_and_select_daily(data)

        assert len(result) == 3
        assert result.columns == data.columns

    def test_filter_all_out_one_day(self):
        """One date loses every stock while another keeps all of them."""

        def filter_func(df: pl.DataFrame) -> pl.Series:
            return df["close"] > 100  # high threshold

        manager = StockPoolManager(filter_func=filter_func)

        data = pl.DataFrame(
            {
                "ts_code": [
                    "000001.SZ",
                    "000002.SZ",
                    "000001.SZ",
                    "000002.SZ",
                ],
                "trade_date": [
                    "20240101",
                    "20240101",
                    "20240102",
                    "20240102",
                ],
                "close": [
                    10.0,
                    20.0,  # day 1: both filtered out
                    150.0,
                    200.0,  # day 2: both kept
                ],
            }
        )

        result = manager.filter_and_select_daily(data)

        # Only day-2 rows remain, with the input's schema.
        assert len(result) == 2
        assert (result["trade_date"] == "20240102").all()
        assert result.columns == data.columns

    def test_column_order_preserved(self):
        """The input's column order is preserved exactly."""

        def filter_func(df: pl.DataFrame) -> pl.Series:
            return df["close"] > 15

        manager = StockPoolManager(filter_func=filter_func)

        data = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
                "trade_date": ["20240101"] * 3,
                "open": [9.0, 19.0, 29.0],
                "high": [11.0, 21.0, 31.0],
                "low": [8.0, 18.0, 28.0],
                "close": [10.0, 20.0, 30.0],
                "volume": [1000, 2000, 3000],
            }
        )

        result = manager.filter_and_select_daily(data)

        # close > 15 keeps 000002.SZ and 000003.SZ.
        assert len(result) == 2
        # Column order is identical to the input.
        assert list(result.columns) == list(data.columns)


if __name__ == "__main__":
    pytest.main([__file__, "-v"])