"""Tests for the stock-pool selector configuration.

Covers the behavior of StockFilterConfig and MarketCapSelectorConfig.
"""

import pytest

from src.training.components.selectors import (
    MarketCapSelectorConfig,
    StockFilterConfig,
)


class TestStockFilterConfig:
    """Tests for StockFilterConfig: flag defaults, overrides and filter_codes()."""

    # The four boolean exclusion switches, in constructor keyword order.
    _FLAGS = ("exclude_cyb", "exclude_kcb", "exclude_bj", "exclude_st")

    # One representative code per board the filter distinguishes.
    _MAIN = "000001.SZ"  # main board - kept by every configuration below
    _CYB = "300001.SZ"  # ChiNext (创业板)
    _KCB = "688001.SH"  # STAR Market (科创板)
    _BJ_8 = "830001.BJ"  # Beijing Stock Exchange, code starting with 8
    _BJ_4 = "430001.BJ"  # Beijing Stock Exchange, code starting with 4

    @classmethod
    def _config_with(cls, **overrides):
        """Build a StockFilterConfig with every flag passed as True except *overrides*."""
        flags = dict.fromkeys(cls._FLAGS, True)
        flags.update(overrides)
        return StockFilterConfig(**flags)

    def test_default_values(self):
        """Every exclusion flag defaults to True."""
        cfg = StockFilterConfig()
        for flag in self._FLAGS:
            assert getattr(cfg, flag) is True, flag

    def test_custom_values(self):
        """Explicitly passed False values are stored on the instance."""
        cfg = StockFilterConfig(**dict.fromkeys(self._FLAGS, False))
        for flag in self._FLAGS:
            assert getattr(cfg, flag) is False, flag

    def test_filter_codes_exclude_all(self):
        """With every exclusion on, only the main-board code survives."""
        universe = [self._MAIN, self._CYB, self._KCB, self._BJ_8, self._BJ_4]
        kept = self._config_with().filter_codes(universe)
        assert kept == [self._MAIN]

    def test_filter_codes_allow_cyb(self):
        """Disabling exclude_cyb lets ChiNext codes through."""
        kept = self._config_with(exclude_cyb=False).filter_codes(
            [self._MAIN, self._CYB, self._KCB]
        )
        assert kept == [self._MAIN, self._CYB]

    def test_filter_codes_allow_kcb(self):
        """Disabling exclude_kcb lets STAR Market codes through."""
        kept = self._config_with(exclude_kcb=False).filter_codes(
            [self._MAIN, self._CYB, self._KCB]
        )
        assert kept == [self._MAIN, self._KCB]

    def test_filter_codes_allow_bj(self):
        """Disabling exclude_bj lets both 8xxxxx and 4xxxxx BSE codes through."""
        kept = self._config_with(exclude_bj=False).filter_codes(
            [self._MAIN, self._CYB, self._BJ_8, self._BJ_4]
        )
        assert kept == [self._MAIN, self._BJ_8, self._BJ_4]

    def test_filter_codes_allow_all(self):
        """With every exclusion off, the input passes through unchanged."""
        universe = [self._MAIN, self._CYB, self._KCB, self._BJ_8, self._BJ_4]
        permissive = StockFilterConfig(**dict.fromkeys(self._FLAGS, False))
        assert permissive.filter_codes(universe) == universe

    def test_filter_codes_empty_list(self):
        """An empty input yields an empty result."""
        assert StockFilterConfig().filter_codes([]) == []

    def test_filter_codes_no_matching(self):
        """When every code is excluded by the defaults, the result is empty."""
        assert StockFilterConfig().filter_codes([self._CYB, self._KCB, self._BJ_8]) == []


class TestMarketCapSelectorConfig:
    """Tests for MarketCapSelectorConfig: defaults, overrides and validation."""

    def test_default_values(self):
        """Defaults: enabled, n=100, descending, ranked by ``total_mv``."""
        config = MarketCapSelectorConfig()
        assert config.enabled is True
        assert config.n == 100
        assert config.ascending is False
        assert config.market_cap_col == "total_mv"

    def test_custom_values(self):
        """Explicitly passed values are stored on the instance."""
        config = MarketCapSelectorConfig(
            enabled=False,
            n=50,
            ascending=True,
            market_cap_col="circ_mv",
        )
        assert config.enabled is False
        assert config.n == 50
        assert config.ascending is True
        assert config.market_cap_col == "circ_mv"

    def test_invalid_n_zero(self):
        """n=0 is rejected ("n must be a positive integer")."""
        with pytest.raises(ValueError, match="n 必须是正整数"):
            MarketCapSelectorConfig(n=0)

    def test_invalid_n_negative(self):
        """A negative n is rejected ("n must be a positive integer")."""
        with pytest.raises(ValueError, match="n 必须是正整数"):
            MarketCapSelectorConfig(n=-1)

    def test_min_n(self):
        """n=1, the smallest positive integer, is accepted (lower boundary of the check above)."""
        config = MarketCapSelectorConfig(n=1)
        assert config.n == 1

    def test_invalid_empty_market_cap_col(self):
        """An empty market-cap column name is rejected ("market_cap_col must not be empty")."""
        with pytest.raises(ValueError, match="market_cap_col 不能为空"):
            MarketCapSelectorConfig(market_cap_col="")

    def test_large_n(self):
        """A large n is accepted as-is."""
        config = MarketCapSelectorConfig(n=5000)
        assert config.n == 5000


if __name__ == "__main__":
    # Propagate pytest's exit code so a script run reports failures via its
    # exit status. Without this, the process always exits with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))