# File: ProStock/tests/training/test_selectors.py
# Repository-viewer listing metadata: 184 lines, 5.3 KiB, Python.

"""Tests for the stock-pool selector configuration.

Verifies the behaviour of StockFilterConfig and MarketCapSelectorConfig.
"""
import pytest
from src.training.components.selectors import (
MarketCapSelectorConfig,
StockFilterConfig,
)
class TestStockFilterConfig:
    """Tests for StockFilterConfig: flag defaults and ``filter_codes``.

    The sample codes used below follow the labels given in
    ``test_filter_codes_exclude_all``:

    * ``000001.SZ`` - main board
    * ``300001.SZ`` - ChiNext board (``exclude_cyb``)
    * ``688001.SH`` - STAR Market (``exclude_kcb``)
    * ``830001.BJ`` / ``430001.BJ`` - Beijing Stock Exchange, codes starting
      with 8 and 4 respectively (``exclude_bj``)

    NOTE(review): none of the ``filter_codes`` tests changes its expected
    output based on ``exclude_st``; that flag is only checked as an attribute
    here - confirm whether ST filtering is covered elsewhere.
    """

    def test_default_values(self):
        """Every exclusion flag defaults to True."""
        config = StockFilterConfig()

        assert config.exclude_cyb is True
        assert config.exclude_kcb is True
        assert config.exclude_bj is True
        assert config.exclude_st is True

    def test_custom_values(self):
        """Flags passed to the constructor are stored unchanged."""
        config = StockFilterConfig(
            exclude_cyb=False,
            exclude_kcb=False,
            exclude_bj=False,
            exclude_st=False,
        )

        assert config.exclude_cyb is False
        assert config.exclude_kcb is False
        assert config.exclude_bj is False
        assert config.exclude_st is False

    def test_filter_codes_exclude_all(self):
        """With every exclusion enabled only the main-board code is kept."""
        config = StockFilterConfig(
            exclude_cyb=True,
            exclude_kcb=True,
            exclude_bj=True,
            exclude_st=True,
        )
        codes = [
            "000001.SZ",  # main board - kept
            "300001.SZ",  # ChiNext - excluded
            "688001.SH",  # STAR Market - excluded
            "830001.BJ",  # Beijing Stock Exchange, starts with 8 - excluded
            "430001.BJ",  # Beijing Stock Exchange, starts with 4 - excluded
        ]

        result = config.filter_codes(codes)

        assert result == ["000001.SZ"]

    def test_filter_codes_allow_cyb(self):
        """Allowing ChiNext keeps the 300xxx code; STAR Market is still dropped."""
        config = StockFilterConfig(
            exclude_cyb=False,
            exclude_kcb=True,
            exclude_bj=True,
            exclude_st=True,
        )
        codes = [
            "000001.SZ",
            "300001.SZ",
            "688001.SH",
        ]

        result = config.filter_codes(codes)

        assert result == ["000001.SZ", "300001.SZ"]

    def test_filter_codes_allow_kcb(self):
        """Allowing STAR Market keeps the 688xxx code; ChiNext is still dropped."""
        config = StockFilterConfig(
            exclude_cyb=True,
            exclude_kcb=False,
            exclude_bj=True,
            exclude_st=True,
        )
        codes = [
            "000001.SZ",
            "300001.SZ",
            "688001.SH",
        ]

        result = config.filter_codes(codes)

        assert result == ["000001.SZ", "688001.SH"]

    def test_filter_codes_allow_bj(self):
        """Allowing the Beijing exchange keeps both .BJ codes (8xx and 4xx)."""
        config = StockFilterConfig(
            exclude_cyb=True,
            exclude_kcb=True,
            exclude_bj=False,
            exclude_st=True,
        )
        codes = [
            "000001.SZ",
            "300001.SZ",
            "830001.BJ",
            "430001.BJ",
        ]

        result = config.filter_codes(codes)

        assert result == ["000001.SZ", "830001.BJ", "430001.BJ"]

    def test_filter_codes_allow_all(self):
        """With every exclusion disabled the input comes back unchanged."""
        config = StockFilterConfig(
            exclude_cyb=False,
            exclude_kcb=False,
            exclude_bj=False,
            exclude_st=False,
        )
        codes = [
            "000001.SZ",
            "300001.SZ",
            "688001.SH",
            "830001.BJ",
            "430001.BJ",
        ]

        result = config.filter_codes(codes)

        assert result == codes

    def test_filter_codes_empty_list(self):
        """An empty input list yields an empty result."""
        config = StockFilterConfig()

        result = config.filter_codes([])

        assert result == []

    def test_filter_codes_no_matching(self):
        """Default flags drop every code when none is on the main board."""
        config = StockFilterConfig()
        codes = ["300001.SZ", "688001.SH", "830001.BJ"]

        result = config.filter_codes(codes)

        assert result == []
class TestMarketCapSelectorConfig:
    """Tests for MarketCapSelectorConfig: defaults, overrides and validation.

    The ``match=`` arguments below are the (Chinese) error-message fragments
    the constructor is expected to raise with; they are runtime text and must
    stay exactly as written.
    """

    def test_default_values(self):
        """Defaults: enabled, n=100, descending order, ``total_mv`` column."""
        config = MarketCapSelectorConfig()

        assert config.enabled is True
        assert config.n == 100
        assert config.ascending is False
        assert config.market_cap_col == "total_mv"

    def test_custom_values(self):
        """Values passed to the constructor are stored unchanged."""
        config = MarketCapSelectorConfig(
            enabled=False,
            n=50,
            ascending=True,
            market_cap_col="circ_mv",
        )

        assert config.enabled is False
        assert config.n == 50
        assert config.ascending is True
        assert config.market_cap_col == "circ_mv"

    def test_invalid_n_zero(self):
        """n=0 is rejected with ValueError."""
        # Message fragment means "n must be a positive integer".
        with pytest.raises(ValueError, match="n 必须是正整数"):
            MarketCapSelectorConfig(n=0)

    def test_invalid_n_negative(self):
        """A negative n is rejected with ValueError."""
        # Message fragment means "n must be a positive integer".
        with pytest.raises(ValueError, match="n 必须是正整数"):
            MarketCapSelectorConfig(n=-1)

    def test_invalid_empty_market_cap_col(self):
        """An empty market-cap column name is rejected with ValueError."""
        # Message fragment means "market_cap_col must not be empty".
        with pytest.raises(ValueError, match="market_cap_col 不能为空"):
            MarketCapSelectorConfig(market_cap_col="")

    def test_large_n(self):
        """A large n (5000) is accepted and stored as given."""
        config = MarketCapSelectorConfig(n=5000)

        assert config.n == 5000
if __name__ == "__main__":
    # Allow running this file directly. pytest.main() returns the exit status
    # instead of exiting, so hand it to SystemExit; otherwise a direct run
    # would exit 0 even when tests fail.
    raise SystemExit(pytest.main([__file__, "-v"]))