Files
ProStock/tests/training/test_selectors.py
liaozhaorun 6b63c428d9 feat(training): 实现股票池选择器配置
- 新增 StockFilterConfig:支持按代码前缀过滤创业板/科创板/北交所
- 新增 MarketCapSelectorConfig:配置市值选择参数(数量、排序、列名)
- 添加参数验证(n>0, 列名非空)
- 在 components 模块导出配置类
- 添加 15 个单元测试覆盖各种场景
2026-03-03 22:10:36 +08:00

184 lines
5.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""测试股票池选择器配置
验证 StockFilterConfig 和 MarketCapSelectorConfig 功能。
"""
import pytest
from src.training.components.selectors import (
MarketCapSelectorConfig,
StockFilterConfig,
)
class TestStockFilterConfig:
    """Tests for ``StockFilterConfig``.

    Covers the default values of the four exclusion switches, explicit
    overrides, and ``filter_codes`` results for several switch combinations
    (including empty input and an input where nothing survives).
    """

    # Names of the boolean exclusion switches exposed by the config.
    _FLAGS = ("exclude_cyb", "exclude_kcb", "exclude_bj", "exclude_st")

    @staticmethod
    def _filtered(codes, *, cyb=True, kcb=True, bj=True, st=True):
        """Build a config with every switch passed explicitly and filter *codes*."""
        cfg = StockFilterConfig(
            exclude_cyb=cyb,
            exclude_kcb=kcb,
            exclude_bj=bj,
            exclude_st=st,
        )
        return cfg.filter_codes(codes)

    def test_default_values(self):
        """Every exclusion switch is True when no arguments are given."""
        cfg = StockFilterConfig()
        for flag in self._FLAGS:
            assert getattr(cfg, flag) is True, flag

    def test_custom_values(self):
        """Explicit keyword arguments override every switch."""
        cfg = StockFilterConfig(**dict.fromkeys(self._FLAGS, False))
        for flag in self._FLAGS:
            assert getattr(cfg, flag) is False, flag

    def test_filter_codes_exclude_all(self):
        """With every switch on, only the main-board code is kept."""
        universe = [
            "000001.SZ",  # main board - kept
            "300001.SZ",  # ChiNext (cyb) - excluded
            "688001.SH",  # STAR Market (kcb) - excluded
            "830001.BJ",  # Beijing Stock Exchange, starts with 8 - excluded
            "430001.BJ",  # Beijing Stock Exchange, starts with 4 - excluded
        ]
        assert self._filtered(universe) == ["000001.SZ"]

    def test_filter_codes_allow_cyb(self):
        """Turning off the ChiNext switch keeps ChiNext codes."""
        universe = ["000001.SZ", "300001.SZ", "688001.SH"]
        assert self._filtered(universe, cyb=False) == ["000001.SZ", "300001.SZ"]

    def test_filter_codes_allow_kcb(self):
        """Turning off the STAR Market switch keeps STAR Market codes."""
        universe = ["000001.SZ", "300001.SZ", "688001.SH"]
        assert self._filtered(universe, kcb=False) == ["000001.SZ", "688001.SH"]

    def test_filter_codes_allow_bj(self):
        """Turning off the Beijing switch keeps both 8xxxxx and 4xxxxx BJ codes."""
        universe = ["000001.SZ", "300001.SZ", "830001.BJ", "430001.BJ"]
        expected = ["000001.SZ", "830001.BJ", "430001.BJ"]
        assert self._filtered(universe, bj=False) == expected

    def test_filter_codes_allow_all(self):
        """With every switch off, the input comes back unchanged and in order."""
        universe = [
            "000001.SZ",
            "300001.SZ",
            "688001.SH",
            "830001.BJ",
            "430001.BJ",
        ]
        kept = self._filtered(universe, cyb=False, kcb=False, bj=False, st=False)
        assert kept == universe

    def test_filter_codes_empty_list(self):
        """An empty input produces an empty output."""
        assert StockFilterConfig().filter_codes([]) == []

    def test_filter_codes_no_matching(self):
        """With default switches, an input of only excluded boards yields nothing."""
        excluded_only = ["300001.SZ", "688001.SH", "830001.BJ"]
        assert StockFilterConfig().filter_codes(excluded_only) == []
class TestMarketCapSelectorConfig:
    """Tests for ``MarketCapSelectorConfig``: defaults, overrides and validation."""

    # Regex fragments matched against the ValueError messages raised by the config.
    _N_ERROR = "n 必须是正整数"
    _COL_ERROR = "market_cap_col 不能为空"

    def test_default_values(self):
        """Check the default enabled/n/ascending/market_cap_col values."""
        cfg = MarketCapSelectorConfig()
        assert cfg.enabled is True
        assert cfg.n == 100
        assert cfg.ascending is False
        assert cfg.market_cap_col == "total_mv"

    def test_custom_values(self):
        """Explicit keyword arguments override every default."""
        cfg = MarketCapSelectorConfig(
            enabled=False,
            n=50,
            ascending=True,
            market_cap_col="circ_mv",
        )
        assert cfg.enabled is False
        assert cfg.n == 50
        assert cfg.ascending is True
        assert cfg.market_cap_col == "circ_mv"

    def test_invalid_n_zero(self):
        """n == 0 is rejected with a ValueError."""
        with pytest.raises(ValueError, match=self._N_ERROR):
            MarketCapSelectorConfig(n=0)

    def test_invalid_n_negative(self):
        """A negative n is rejected with a ValueError."""
        with pytest.raises(ValueError, match=self._N_ERROR):
            MarketCapSelectorConfig(n=-1)

    def test_invalid_empty_market_cap_col(self):
        """An empty market-cap column name is rejected with a ValueError."""
        with pytest.raises(ValueError, match=self._COL_ERROR):
            MarketCapSelectorConfig(market_cap_col="")

    def test_large_n(self):
        """A large n value is accepted and stored as given."""
        assert MarketCapSelectorConfig(n=5000).n == 5000
if __name__ == "__main__":
    # Allow running this module directly. pytest.main() returns the session's
    # exit code; pass it to SystemExit so test failures produce a non-zero
    # process status instead of being silently reported as success.
    raise SystemExit(pytest.main([__file__, "-v"]))