feat(training): 实现股票池选择器配置
- 新增 StockFilterConfig:支持按代码前缀过滤创业板/科创板/北交所 - 新增 MarketCapSelectorConfig:配置市值选择参数(数量、排序、列名) - 添加参数验证(n>0, 列名非空) - 在 components 模块导出配置类 - 添加 15 个单元测试覆盖各种场景
This commit is contained in:
@@ -9,8 +9,16 @@ from src.training.components.base import BaseModel, BaseProcessor
|
|||||||
# 数据划分器
|
# 数据划分器
|
||||||
from src.training.components.splitters import DateSplitter
|
from src.training.components.splitters import DateSplitter
|
||||||
|
|
||||||
|
# 股票池选择器配置
|
||||||
|
from src.training.components.selectors import (
|
||||||
|
MarketCapSelectorConfig,
|
||||||
|
StockFilterConfig,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseModel",
|
"BaseModel",
|
||||||
"BaseProcessor",
|
"BaseProcessor",
|
||||||
"DateSplitter",
|
"DateSplitter",
|
||||||
|
"StockFilterConfig",
|
||||||
|
"MarketCapSelectorConfig",
|
||||||
]
|
]
|
||||||
|
|||||||
81
src/training/components/selectors.py
Normal file
81
src/training/components/selectors.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
"""股票池选择器配置
|
||||||
|
|
||||||
|
提供股票过滤和市值选择的配置类。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StockFilterConfig:
|
||||||
|
"""股票过滤器配置
|
||||||
|
|
||||||
|
用于过滤掉不需要的股票(如创业板、科创板等)。
|
||||||
|
基于股票代码进行过滤,不依赖外部数据。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
exclude_cyb: 是否排除创业板(300xxx)
|
||||||
|
exclude_kcb: 是否排除科创板(688xxx)
|
||||||
|
exclude_bj: 是否排除北交所(8xxxxxx, 4xxxxxx)
|
||||||
|
exclude_st: 是否排除ST股票(需要外部数据支持)
|
||||||
|
"""
|
||||||
|
|
||||||
|
exclude_cyb: bool = True
|
||||||
|
exclude_kcb: bool = True
|
||||||
|
exclude_bj: bool = True
|
||||||
|
exclude_st: bool = True
|
||||||
|
|
||||||
|
def filter_codes(self, codes: List[str]) -> List[str]:
|
||||||
|
"""应用过滤条件,返回过滤后的股票代码列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
codes: 原始股票代码列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
过滤后的股票代码列表
|
||||||
|
|
||||||
|
Note:
|
||||||
|
ST 股票过滤需要额外数据,在 StockPoolManager 中处理。
|
||||||
|
此方法仅基于代码前缀进行过滤。
|
||||||
|
"""
|
||||||
|
result = []
|
||||||
|
for code in codes:
|
||||||
|
# 排除创业板(300xxx)
|
||||||
|
if self.exclude_cyb and code.startswith("300"):
|
||||||
|
continue
|
||||||
|
# 排除科创板(688xxx)
|
||||||
|
if self.exclude_kcb and code.startswith("688"):
|
||||||
|
continue
|
||||||
|
# 排除北交所(8xxxxxx 或 4xxxxxx)
|
||||||
|
if self.exclude_bj and (code.startswith("8") or code.startswith("4")):
|
||||||
|
continue
|
||||||
|
result.append(code)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MarketCapSelectorConfig:
|
||||||
|
"""市值选择器配置
|
||||||
|
|
||||||
|
每日独立选择市值最大或最小的 n 只股票。
|
||||||
|
市值数据从 daily_basic 表独立获取,仅用于筛选。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
enabled: 是否启用选择
|
||||||
|
n: 选择前 n 只
|
||||||
|
ascending: False=最大市值, True=最小市值
|
||||||
|
market_cap_col: 市值列名(来自 daily_basic)
|
||||||
|
"""
|
||||||
|
|
||||||
|
enabled: bool = True
|
||||||
|
n: int = 100
|
||||||
|
ascending: bool = False
|
||||||
|
market_cap_col: str = "total_mv"
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
"""验证配置参数"""
|
||||||
|
if self.n <= 0:
|
||||||
|
raise ValueError(f"n 必须是正整数,得到: {self.n}")
|
||||||
|
if not self.market_cap_col:
|
||||||
|
raise ValueError("market_cap_col 不能为空")
|
||||||
183
tests/training/test_selectors.py
Normal file
183
tests/training/test_selectors.py
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
"""测试股票池选择器配置
|
||||||
|
|
||||||
|
验证 StockFilterConfig 和 MarketCapSelectorConfig 功能。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.training.components.selectors import (
|
||||||
|
MarketCapSelectorConfig,
|
||||||
|
StockFilterConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestStockFilterConfig:
|
||||||
|
"""StockFilterConfig 测试类"""
|
||||||
|
|
||||||
|
def test_default_values(self):
|
||||||
|
"""测试默认值"""
|
||||||
|
config = StockFilterConfig()
|
||||||
|
assert config.exclude_cyb is True
|
||||||
|
assert config.exclude_kcb is True
|
||||||
|
assert config.exclude_bj is True
|
||||||
|
assert config.exclude_st is True
|
||||||
|
|
||||||
|
def test_custom_values(self):
|
||||||
|
"""测试自定义值"""
|
||||||
|
config = StockFilterConfig(
|
||||||
|
exclude_cyb=False,
|
||||||
|
exclude_kcb=False,
|
||||||
|
exclude_bj=False,
|
||||||
|
exclude_st=False,
|
||||||
|
)
|
||||||
|
assert config.exclude_cyb is False
|
||||||
|
assert config.exclude_kcb is False
|
||||||
|
assert config.exclude_bj is False
|
||||||
|
assert config.exclude_st is False
|
||||||
|
|
||||||
|
def test_filter_codes_exclude_all(self):
|
||||||
|
"""测试排除所有类型"""
|
||||||
|
config = StockFilterConfig(
|
||||||
|
exclude_cyb=True,
|
||||||
|
exclude_kcb=True,
|
||||||
|
exclude_bj=True,
|
||||||
|
exclude_st=True,
|
||||||
|
)
|
||||||
|
codes = [
|
||||||
|
"000001.SZ", # 主板 - 保留
|
||||||
|
"300001.SZ", # 创业板 - 排除
|
||||||
|
"688001.SH", # 科创板 - 排除
|
||||||
|
"830001.BJ", # 北交所(8开头)- 排除
|
||||||
|
"430001.BJ", # 北交所(4开头)- 排除
|
||||||
|
]
|
||||||
|
result = config.filter_codes(codes)
|
||||||
|
assert result == ["000001.SZ"]
|
||||||
|
|
||||||
|
def test_filter_codes_allow_cyb(self):
|
||||||
|
"""测试允许创业板"""
|
||||||
|
config = StockFilterConfig(
|
||||||
|
exclude_cyb=False,
|
||||||
|
exclude_kcb=True,
|
||||||
|
exclude_bj=True,
|
||||||
|
exclude_st=True,
|
||||||
|
)
|
||||||
|
codes = [
|
||||||
|
"000001.SZ",
|
||||||
|
"300001.SZ",
|
||||||
|
"688001.SH",
|
||||||
|
]
|
||||||
|
result = config.filter_codes(codes)
|
||||||
|
assert result == ["000001.SZ", "300001.SZ"]
|
||||||
|
|
||||||
|
def test_filter_codes_allow_kcb(self):
|
||||||
|
"""测试允许科创板"""
|
||||||
|
config = StockFilterConfig(
|
||||||
|
exclude_cyb=True,
|
||||||
|
exclude_kcb=False,
|
||||||
|
exclude_bj=True,
|
||||||
|
exclude_st=True,
|
||||||
|
)
|
||||||
|
codes = [
|
||||||
|
"000001.SZ",
|
||||||
|
"300001.SZ",
|
||||||
|
"688001.SH",
|
||||||
|
]
|
||||||
|
result = config.filter_codes(codes)
|
||||||
|
assert result == ["000001.SZ", "688001.SH"]
|
||||||
|
|
||||||
|
def test_filter_codes_allow_bj(self):
|
||||||
|
"""测试允许北交所"""
|
||||||
|
config = StockFilterConfig(
|
||||||
|
exclude_cyb=True,
|
||||||
|
exclude_kcb=True,
|
||||||
|
exclude_bj=False,
|
||||||
|
exclude_st=True,
|
||||||
|
)
|
||||||
|
codes = [
|
||||||
|
"000001.SZ",
|
||||||
|
"300001.SZ",
|
||||||
|
"830001.BJ",
|
||||||
|
"430001.BJ",
|
||||||
|
]
|
||||||
|
result = config.filter_codes(codes)
|
||||||
|
assert result == ["000001.SZ", "830001.BJ", "430001.BJ"]
|
||||||
|
|
||||||
|
def test_filter_codes_allow_all(self):
|
||||||
|
"""测试允许所有类型"""
|
||||||
|
config = StockFilterConfig(
|
||||||
|
exclude_cyb=False,
|
||||||
|
exclude_kcb=False,
|
||||||
|
exclude_bj=False,
|
||||||
|
exclude_st=False,
|
||||||
|
)
|
||||||
|
codes = [
|
||||||
|
"000001.SZ",
|
||||||
|
"300001.SZ",
|
||||||
|
"688001.SH",
|
||||||
|
"830001.BJ",
|
||||||
|
"430001.BJ",
|
||||||
|
]
|
||||||
|
result = config.filter_codes(codes)
|
||||||
|
assert result == codes
|
||||||
|
|
||||||
|
def test_filter_codes_empty_list(self):
|
||||||
|
"""测试空列表"""
|
||||||
|
config = StockFilterConfig()
|
||||||
|
result = config.filter_codes([])
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
def test_filter_codes_no_matching(self):
|
||||||
|
"""测试全部排除"""
|
||||||
|
config = StockFilterConfig()
|
||||||
|
codes = ["300001.SZ", "688001.SH", "830001.BJ"]
|
||||||
|
result = config.filter_codes(codes)
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestMarketCapSelectorConfig:
|
||||||
|
"""MarketCapSelectorConfig 测试类"""
|
||||||
|
|
||||||
|
def test_default_values(self):
|
||||||
|
"""测试默认值"""
|
||||||
|
config = MarketCapSelectorConfig()
|
||||||
|
assert config.enabled is True
|
||||||
|
assert config.n == 100
|
||||||
|
assert config.ascending is False
|
||||||
|
assert config.market_cap_col == "total_mv"
|
||||||
|
|
||||||
|
def test_custom_values(self):
|
||||||
|
"""测试自定义值"""
|
||||||
|
config = MarketCapSelectorConfig(
|
||||||
|
enabled=False,
|
||||||
|
n=50,
|
||||||
|
ascending=True,
|
||||||
|
market_cap_col="circ_mv",
|
||||||
|
)
|
||||||
|
assert config.enabled is False
|
||||||
|
assert config.n == 50
|
||||||
|
assert config.ascending is True
|
||||||
|
assert config.market_cap_col == "circ_mv"
|
||||||
|
|
||||||
|
def test_invalid_n_zero(self):
|
||||||
|
"""测试无效的 n=0"""
|
||||||
|
with pytest.raises(ValueError, match="n 必须是正整数"):
|
||||||
|
MarketCapSelectorConfig(n=0)
|
||||||
|
|
||||||
|
def test_invalid_n_negative(self):
|
||||||
|
"""测试无效的负数 n"""
|
||||||
|
with pytest.raises(ValueError, match="n 必须是正整数"):
|
||||||
|
MarketCapSelectorConfig(n=-1)
|
||||||
|
|
||||||
|
def test_invalid_empty_market_cap_col(self):
|
||||||
|
"""测试空的市值列名"""
|
||||||
|
with pytest.raises(ValueError, match="market_cap_col 不能为空"):
|
||||||
|
MarketCapSelectorConfig(market_cap_col="")
|
||||||
|
|
||||||
|
def test_large_n(self):
|
||||||
|
"""测试大的 n 值"""
|
||||||
|
config = MarketCapSelectorConfig(n=5000)
|
||||||
|
assert config.n == 5000
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v"])
|
||||||
Reference in New Issue
Block a user