Files
ProStock/tests/training/test_stock_pool_manager.py
liaozhaorun 88fa848b96 refactor(training): 重构股票池管理 API 并更新训练流程
- 移除 StockFilterConfig/MarketCapSelectorConfig,改用 StockPoolManager + filter_func
- Trainer 支持 train/val/test 三分法划分
- 更新 regression.ipynb 适配新 API
- 删除已弃用的 test_selectors.py,后续补充 StockPoolManager 测试
2026-03-09 22:33:41 +08:00

453 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""测试 StockPoolManager
验证新的自定义函数和因子筛选功能。
重点测试:临时因子隔离(只删除新生成的因子,保留原本存在的)。
"""
from unittest.mock import Mock, patch
import polars as pl
import pytest
from src.training.core.stock_pool_manager import StockPoolManager
class TestStockPoolManagerBasic:
    """Basic filtering scenarios for ``StockPoolManager``."""

    def test_basic_filter_with_columns(self):
        """Filter on a declared required column that the input already has."""

        def keep_large_caps(frame: pl.DataFrame) -> pl.Series:
            return frame["total_mv"] > 50

        # Stand-in for the data router; this test never configures it.
        router_stub = Mock()
        manager = StockPoolManager(
            filter_func=keep_large_caps,
            required_columns=["total_mv"],
            data_router=router_stub,
        )

        # Four stocks on a single trade date.
        codes = ["000001.SZ", "000002.SZ", "000003.SZ", "000004.SZ"]
        data = pl.DataFrame(
            {
                "ts_code": codes,
                "trade_date": ["20240101"] * 4,
                "close": [10.0, 20.0, 30.0, 40.0],
                "total_mv": [100.0, 30.0, 80.0, 20.0],
            }
        )

        # Original author's note: no further mocking is needed because
        # total_mv is already part of the input frame.
        result = manager.filter_and_select_daily(data)

        # The output must expose exactly the input columns.
        assert result.columns == data.columns

        # Only the stocks with total_mv > 50 remain.
        surviving_codes = result["ts_code"].to_list()
        assert len(result) == 2
        assert "000001.SZ" in surviving_codes
        assert "000003.SZ" in surviving_codes

    def test_filter_without_required_columns(self):
        """Filter using only columns that are already in the input frame."""

        def close_above_25(frame: pl.DataFrame) -> pl.Series:
            return frame["close"] > 25

        manager = StockPoolManager(filter_func=close_above_25)
        data = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
                "trade_date": ["20240101"] * 3,
                "close": [10.0, 30.0, 20.0],
            }
        )

        result = manager.filter_and_select_daily(data)

        # Exactly one stock has close > 25.
        assert len(result) == 1
        first_code = result["ts_code"][0]
        assert first_code == "000002.SZ"
        assert result.columns == data.columns

    def test_empty_result(self):
        """A filter that rejects every row yields an empty frame."""

        def reject_everything(frame: pl.DataFrame) -> pl.Series:
            # No close price in the fixture comes anywhere near this value.
            return frame["close"] > 9999

        manager = StockPoolManager(filter_func=reject_everything)
        data = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ"],
                "trade_date": ["20240101"] * 2,
                "close": [10.0, 20.0],
            }
        )

        result = manager.filter_and_select_daily(data)

        assert len(result) == 0
        # The column layout must survive even when no rows do.
        assert result.columns == data.columns
class TestStockPoolManagerDailyIndependence:
    """Checks that filtering is evaluated separately for each trade date."""

    def test_daily_independence(self):
        """A per-day median threshold must be derived from that day's rows only."""

        def filter_func(df: pl.DataFrame) -> pl.Series:
            # Keep rows whose close is at or above the median of the frame
            # handed to the filter. The expectations below only hold when
            # that frame contains a single trade date.
            median = df["close"].median()
            return df["close"] >= median

        manager = StockPoolManager(filter_func=filter_func)

        # Two trade dates, the same four stocks on each.
        data = pl.DataFrame(
            {
                "ts_code": [
                    # date 1
                    "000001.SZ",
                    "000002.SZ",
                    "000003.SZ",
                    "000004.SZ",
                    # date 2
                    "000001.SZ",
                    "000002.SZ",
                    "000003.SZ",
                    "000004.SZ",
                ],
                "trade_date": [
                    "20240101",
                    "20240101",
                    "20240101",
                    "20240101",
                    "20240102",
                    "20240102",
                    "20240102",
                    "20240102",
                ],
                "close": [
                    10.0,
                    20.0,
                    30.0,
                    40.0,  # date 1 keeps 30 and 40
                    5.0,
                    15.0,
                    25.0,
                    35.0,  # date 2 keeps 25 and 35
                ],
            }
        )

        result = manager.filter_and_select_daily(data)

        day1 = result.filter(pl.col("trade_date") == "20240101")
        day2 = result.filter(pl.col("trade_date") == "20240102")

        # Date 1: median of (10, 20, 30, 40) is 25, so closes 30 and 40 stay.
        assert len(day1) == 2
        assert set(day1["ts_code"].to_list()) == {"000003.SZ", "000004.SZ"}

        # Date 2: median of (5, 15, 25, 35) is 20, so closes 25 and 35 stay.
        assert len(day2) == 2
        assert set(day2["ts_code"].to_list()) == {"000003.SZ", "000004.SZ"}

    def test_uneven_daily_distribution(self):
        """Dates with different numbers of stocks are each filtered correctly."""

        def filter_func(df: pl.DataFrame) -> pl.Series:
            return df["close"] > 15

        manager = StockPoolManager(filter_func=filter_func)
        data = pl.DataFrame(
            {
                "ts_code": [
                    "000001.SZ",
                    "000002.SZ",  # date 1: two stocks
                    "000001.SZ",
                    "000002.SZ",
                    "000003.SZ",
                    "000004.SZ",  # date 2: four stocks
                ],
                "trade_date": [
                    "20240101",
                    "20240101",
                    "20240102",
                    "20240102",
                    "20240102",
                    "20240102",
                ],
                "close": [10.0, 20.0, 5.0, 15.0, 25.0, 35.0],
            }
        )

        result = manager.filter_and_select_daily(data)

        # Date 1: only 000002.SZ qualifies (20 > 15).
        day1 = result.filter(pl.col("trade_date") == "20240101")
        assert len(day1) == 1
        # Check identity as well as count, so a result with the right size
        # but the wrong stock cannot pass.
        assert day1["ts_code"].to_list() == ["000002.SZ"]

        # Date 2: 000003.SZ (25) and 000004.SZ (35) qualify; 15 is not > 15.
        day2 = result.filter(pl.col("trade_date") == "20240102")
        assert len(day2) == 2
        assert set(day2["ts_code"].to_list()) == {"000003.SZ", "000004.SZ"}
class TestStockPoolManagerFactorIsolation:
    """Factor isolation tests (the core scenarios of this module).

    Columns created only for filtering must be removed from the result,
    while columns that were already in the input must be kept.
    """

    @patch.object(StockPoolManager, "_compute_factors")
    def test_filter_with_factors(self, mock_compute):
        """Filter on a computed factor and verify the temporary column is dropped."""
        # The patched _compute_factors returns the input plus the factor column.
        mock_compute.return_value = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
                "trade_date": ["20240101"] * 3,
                "close": [11.0, 9.5, 10.8],
                "momentum_20": [0.1, -0.05, 0.08],  # first and third exceed 0.05
            }
        )

        def filter_func(df: pl.DataFrame) -> pl.Series:
            return df["momentum_20"] > 0.05

        manager = StockPoolManager(
            filter_func=filter_func,
            required_factors={"momentum_20": "(close / ts_delay(close, 20)) - 1"},
        )

        # The input frame does not contain momentum_20.
        data = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
                "trade_date": ["20240101"] * 3,
                "close": [11.0, 9.5, 10.8],
            }
        )

        result = manager.filter_and_select_daily(data)

        # Output columns equal the input columns, i.e. momentum_20 is gone.
        assert result.columns == data.columns
        assert "momentum_20" not in result.columns

        # momentum_20 > 0.05 holds for 000001.SZ (0.1) and 000003.SZ (0.08).
        assert len(result) == 2
        assert "000001.SZ" in result["ts_code"].to_list()
        assert "000003.SZ" in result["ts_code"].to_list()

        # The factor was missing from the input, so it had to be computed.
        mock_compute.assert_called_once()

    @patch.object(StockPoolManager, "_compute_factors")
    def test_preserve_existing_factors(self, mock_compute):
        """Factors already present in the input must not be removed (core test)."""
        # The patched _compute_factors returns the input plus the new roe
        # column; momentum_20 was already part of the input.
        mock_compute.return_value = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
                "trade_date": ["20240101"] * 3,
                "close": [11.0, 9.5, 10.8],
                "momentum_20": [0.1, -0.05, 0.08],  # pre-existing
                "roe": [0.12, 0.08, 0.15],  # newly generated; second is < 0.1
            }
        )

        def filter_func(df: pl.DataFrame) -> pl.Series:
            return (df["momentum_20"] > 0.05) & (df["roe"] > 0.1)

        manager = StockPoolManager(
            filter_func=filter_func,
            # Only roe is declared as a factor to generate; momentum_20 is
            # supplied by the input frame.
            required_factors={"roe": "n_income / equity"},
        )

        # The input frame already carries momentum_20.
        data = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
                "trade_date": ["20240101"] * 3,
                "close": [11.0, 9.5, 10.8],
                "momentum_20": [0.1, -0.05, 0.08],  # pre-existing factor
            }
        )

        result = manager.filter_and_select_daily(data)

        # Key assertions:
        # 1. momentum_20 is kept (it was in the input).
        assert "momentum_20" in result.columns
        # 2. roe is dropped (it was generated for this call).
        assert "roe" not in result.columns
        # 3. The columns match the input exactly.
        assert result.columns == data.columns

        # momentum_20 > 0.05: 000001.SZ (0.1), 000003.SZ (0.08)
        # roe > 0.1:          000001.SZ (0.12), 000003.SZ (0.15)
        # Intersection:       000001.SZ, 000003.SZ
        assert len(result) == 2
        # Check identity too, so a wrong pair of rows cannot pass on count alone.
        assert set(result["ts_code"].to_list()) == {"000001.SZ", "000003.SZ"}

        # roe was missing from the input, so it had to be computed.
        mock_compute.assert_called_once()

    def test_no_factor_computation_when_all_exist(self):
        """When every required factor is already present, FactorEngine is not used."""

        def filter_func(df: pl.DataFrame) -> pl.Series:
            return (df["factor_a"] > 0.5) & (df["factor_b"] < 0.3)

        manager = StockPoolManager(
            filter_func=filter_func,
            required_factors={
                "factor_a": "some_expr_a",
                "factor_b": "some_expr_b",
            },
        )

        # The input frame already contains every declared factor.
        data = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ"],
                "trade_date": ["20240101"] * 2,
                "close": [10.0, 20.0],
                "factor_a": [0.6, 0.4],  # pre-existing
                "factor_b": [0.2, 0.5],  # pre-existing
            }
        )

        # NOTE(review): this patches FactorEngine in its defining module. The
        # guard is only effective if StockPoolManager looks the class up from
        # that module at call time - confirm against the implementation.
        with patch("src.factors.engine.factor_engine.FactorEngine") as mock_engine:
            result = manager.filter_and_select_daily(data)
            # FactorEngine must not be instantiated when nothing is missing.
            mock_engine.assert_not_called()

        # Only 000001.SZ satisfies factor_a > 0.5 and factor_b < 0.3.
        assert len(result) == 1
        assert result["ts_code"][0] == "000001.SZ"
        assert result.columns == data.columns
class TestStockPoolManagerEdgeCases:
    """Edge-case scenarios for ``StockPoolManager``."""

    def test_single_date(self):
        """Input covering a single trade date."""

        def close_above_15(frame: pl.DataFrame) -> pl.Series:
            return frame["close"] > 15

        manager = StockPoolManager(filter_func=close_above_15)
        codes = ["000001.SZ", "000002.SZ", "000003.SZ"]
        data = pl.DataFrame(
            {
                "ts_code": codes,
                "trade_date": ["20240101"] * 3,
                "close": [10.0, 20.0, 30.0],
            }
        )

        result = manager.filter_and_select_daily(data)

        assert len(result) == 2
        assert result.columns == data.columns

    def test_single_stock_per_day(self):
        """Each trade date holds exactly one stock."""

        def keep_all(frame: pl.DataFrame) -> pl.Series:
            # Every close in the fixture is positive, so nothing is dropped.
            return frame["close"] > 0

        manager = StockPoolManager(filter_func=keep_all)
        dates = ["20240101", "20240102", "20240103"]
        data = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
                "trade_date": dates,
                "close": [10.0, 20.0, 30.0],
            }
        )

        result = manager.filter_and_select_daily(data)

        assert len(result) == 3
        assert result.columns == data.columns

    def test_filter_all_out_one_day(self):
        """One trade date loses all of its rows while another keeps them."""

        def close_above_100(frame: pl.DataFrame) -> pl.Series:
            # Deliberately high threshold.
            return frame["close"] > 100

        manager = StockPoolManager(filter_func=close_above_100)
        data = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ"] * 2,
                "trade_date": ["20240101"] * 2 + ["20240102"] * 2,
                # The first two closes (date 1) fall below the threshold;
                # the last two (date 2) are above it.
                "close": [10.0, 20.0, 150.0, 200.0],
            }
        )

        result = manager.filter_and_select_daily(data)

        # Only the date-2 rows are left.
        assert len(result) == 2
        remaining_dates = result["trade_date"].to_list()
        assert all(day == "20240102" for day in remaining_dates)

    def test_column_order_preserved(self):
        """The output keeps the input's column order."""

        def close_above_15(frame: pl.DataFrame) -> pl.Series:
            return frame["close"] > 15

        manager = StockPoolManager(filter_func=close_above_15)
        data = pl.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ", "000003.SZ"],
                "trade_date": ["20240101"] * 3,
                "open": [9.0, 19.0, 29.0],
                "high": [11.0, 21.0, 31.0],
                "low": [8.0, 18.0, 28.0],
                "close": [10.0, 20.0, 30.0],
                "volume": [1000, 2000, 3000],
            }
        )

        result = manager.filter_and_select_daily(data)

        # Compare as lists so that ordering is part of the check.
        expected_order = list(data.columns)
        assert list(result.columns) == expected_order
if __name__ == "__main__":
    # Allow running this file directly: hand it to pytest in verbose mode.
    cli_args = [__file__, "-v"]
    pytest.main(cli_args)