feat(training): 实现股票池选择器配置
- 新增 StockFilterConfig:支持按代码前缀过滤创业板/科创板/北交所 - 新增 MarketCapSelectorConfig:配置市值选择参数(数量、排序、列名) - 添加参数验证(n>0, 列名非空) - 在 components 模块导出配置类 - 添加 15 个单元测试覆盖各种场景
This commit is contained in:
183
tests/training/test_selectors.py
Normal file
183
tests/training/test_selectors.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""测试股票池选择器配置
|
||||
|
||||
验证 StockFilterConfig 和 MarketCapSelectorConfig 功能。
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from src.training.components.selectors import (
|
||||
MarketCapSelectorConfig,
|
||||
StockFilterConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestStockFilterConfig:
|
||||
"""StockFilterConfig 测试类"""
|
||||
|
||||
def test_default_values(self):
|
||||
"""测试默认值"""
|
||||
config = StockFilterConfig()
|
||||
assert config.exclude_cyb is True
|
||||
assert config.exclude_kcb is True
|
||||
assert config.exclude_bj is True
|
||||
assert config.exclude_st is True
|
||||
|
||||
def test_custom_values(self):
|
||||
"""测试自定义值"""
|
||||
config = StockFilterConfig(
|
||||
exclude_cyb=False,
|
||||
exclude_kcb=False,
|
||||
exclude_bj=False,
|
||||
exclude_st=False,
|
||||
)
|
||||
assert config.exclude_cyb is False
|
||||
assert config.exclude_kcb is False
|
||||
assert config.exclude_bj is False
|
||||
assert config.exclude_st is False
|
||||
|
||||
def test_filter_codes_exclude_all(self):
|
||||
"""测试排除所有类型"""
|
||||
config = StockFilterConfig(
|
||||
exclude_cyb=True,
|
||||
exclude_kcb=True,
|
||||
exclude_bj=True,
|
||||
exclude_st=True,
|
||||
)
|
||||
codes = [
|
||||
"000001.SZ", # 主板 - 保留
|
||||
"300001.SZ", # 创业板 - 排除
|
||||
"688001.SH", # 科创板 - 排除
|
||||
"830001.BJ", # 北交所(8开头)- 排除
|
||||
"430001.BJ", # 北交所(4开头)- 排除
|
||||
]
|
||||
result = config.filter_codes(codes)
|
||||
assert result == ["000001.SZ"]
|
||||
|
||||
def test_filter_codes_allow_cyb(self):
|
||||
"""测试允许创业板"""
|
||||
config = StockFilterConfig(
|
||||
exclude_cyb=False,
|
||||
exclude_kcb=True,
|
||||
exclude_bj=True,
|
||||
exclude_st=True,
|
||||
)
|
||||
codes = [
|
||||
"000001.SZ",
|
||||
"300001.SZ",
|
||||
"688001.SH",
|
||||
]
|
||||
result = config.filter_codes(codes)
|
||||
assert result == ["000001.SZ", "300001.SZ"]
|
||||
|
||||
def test_filter_codes_allow_kcb(self):
|
||||
"""测试允许科创板"""
|
||||
config = StockFilterConfig(
|
||||
exclude_cyb=True,
|
||||
exclude_kcb=False,
|
||||
exclude_bj=True,
|
||||
exclude_st=True,
|
||||
)
|
||||
codes = [
|
||||
"000001.SZ",
|
||||
"300001.SZ",
|
||||
"688001.SH",
|
||||
]
|
||||
result = config.filter_codes(codes)
|
||||
assert result == ["000001.SZ", "688001.SH"]
|
||||
|
||||
def test_filter_codes_allow_bj(self):
|
||||
"""测试允许北交所"""
|
||||
config = StockFilterConfig(
|
||||
exclude_cyb=True,
|
||||
exclude_kcb=True,
|
||||
exclude_bj=False,
|
||||
exclude_st=True,
|
||||
)
|
||||
codes = [
|
||||
"000001.SZ",
|
||||
"300001.SZ",
|
||||
"830001.BJ",
|
||||
"430001.BJ",
|
||||
]
|
||||
result = config.filter_codes(codes)
|
||||
assert result == ["000001.SZ", "830001.BJ", "430001.BJ"]
|
||||
|
||||
def test_filter_codes_allow_all(self):
|
||||
"""测试允许所有类型"""
|
||||
config = StockFilterConfig(
|
||||
exclude_cyb=False,
|
||||
exclude_kcb=False,
|
||||
exclude_bj=False,
|
||||
exclude_st=False,
|
||||
)
|
||||
codes = [
|
||||
"000001.SZ",
|
||||
"300001.SZ",
|
||||
"688001.SH",
|
||||
"830001.BJ",
|
||||
"430001.BJ",
|
||||
]
|
||||
result = config.filter_codes(codes)
|
||||
assert result == codes
|
||||
|
||||
def test_filter_codes_empty_list(self):
|
||||
"""测试空列表"""
|
||||
config = StockFilterConfig()
|
||||
result = config.filter_codes([])
|
||||
assert result == []
|
||||
|
||||
def test_filter_codes_no_matching(self):
|
||||
"""测试全部排除"""
|
||||
config = StockFilterConfig()
|
||||
codes = ["300001.SZ", "688001.SH", "830001.BJ"]
|
||||
result = config.filter_codes(codes)
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestMarketCapSelectorConfig:
|
||||
"""MarketCapSelectorConfig 测试类"""
|
||||
|
||||
def test_default_values(self):
|
||||
"""测试默认值"""
|
||||
config = MarketCapSelectorConfig()
|
||||
assert config.enabled is True
|
||||
assert config.n == 100
|
||||
assert config.ascending is False
|
||||
assert config.market_cap_col == "total_mv"
|
||||
|
||||
def test_custom_values(self):
|
||||
"""测试自定义值"""
|
||||
config = MarketCapSelectorConfig(
|
||||
enabled=False,
|
||||
n=50,
|
||||
ascending=True,
|
||||
market_cap_col="circ_mv",
|
||||
)
|
||||
assert config.enabled is False
|
||||
assert config.n == 50
|
||||
assert config.ascending is True
|
||||
assert config.market_cap_col == "circ_mv"
|
||||
|
||||
def test_invalid_n_zero(self):
|
||||
"""测试无效的 n=0"""
|
||||
with pytest.raises(ValueError, match="n 必须是正整数"):
|
||||
MarketCapSelectorConfig(n=0)
|
||||
|
||||
def test_invalid_n_negative(self):
|
||||
"""测试无效的负数 n"""
|
||||
with pytest.raises(ValueError, match="n 必须是正整数"):
|
||||
MarketCapSelectorConfig(n=-1)
|
||||
|
||||
def test_invalid_empty_market_cap_col(self):
|
||||
"""测试空的市值列名"""
|
||||
with pytest.raises(ValueError, match="market_cap_col 不能为空"):
|
||||
MarketCapSelectorConfig(market_cap_col="")
|
||||
|
||||
def test_large_n(self):
|
||||
"""测试大的 n 值"""
|
||||
config = MarketCapSelectorConfig(n=5000)
|
||||
assert config.n == 5000
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
Reference in New Issue
Block a user