diff --git a/src/training/components/__init__.py b/src/training/components/__init__.py index 762d8d2..9b2ebe4 100644 --- a/src/training/components/__init__.py +++ b/src/training/components/__init__.py @@ -9,8 +9,16 @@ from src.training.components.base import BaseModel, BaseProcessor # 数据划分器 from src.training.components.splitters import DateSplitter +# 股票池选择器配置 +from src.training.components.selectors import ( + MarketCapSelectorConfig, + StockFilterConfig, +) + __all__ = [ "BaseModel", "BaseProcessor", "DateSplitter", + "StockFilterConfig", + "MarketCapSelectorConfig", ] diff --git a/src/training/components/selectors.py b/src/training/components/selectors.py new file mode 100644 index 0000000..b22daaa --- /dev/null +++ b/src/training/components/selectors.py @@ -0,0 +1,81 @@ +"""股票池选择器配置 + +提供股票过滤和市值选择的配置类。 +""" + +from dataclasses import dataclass +from typing import List, Optional + + +@dataclass +class StockFilterConfig: + """股票过滤器配置 + + 用于过滤掉不需要的股票(如创业板、科创板等)。 + 基于股票代码进行过滤,不依赖外部数据。 + + Attributes: + exclude_cyb: 是否排除创业板(300xxx) + exclude_kcb: 是否排除科创板(688xxx) + exclude_bj: 是否排除北交所(8xxxxxx, 4xxxxxx) + exclude_st: 是否排除ST股票(需要外部数据支持) + """ + + exclude_cyb: bool = True + exclude_kcb: bool = True + exclude_bj: bool = True + exclude_st: bool = True + + def filter_codes(self, codes: List[str]) -> List[str]: + """应用过滤条件,返回过滤后的股票代码列表 + + Args: + codes: 原始股票代码列表 + + Returns: + 过滤后的股票代码列表 + + Note: + ST 股票过滤需要额外数据,在 StockPoolManager 中处理。 + 此方法仅基于代码前缀进行过滤。 + """ + result = [] + for code in codes: + # 排除创业板(300xxx) + if self.exclude_cyb and code.startswith("300"): + continue + # 排除科创板(688xxx) + if self.exclude_kcb and code.startswith("688"): + continue + # 排除北交所(8xxxxxx 或 4xxxxxx) + if self.exclude_bj and (code.startswith("8") or code.startswith("4")): + continue + result.append(code) + return result + + +@dataclass +class MarketCapSelectorConfig: + """市值选择器配置 + + 每日独立选择市值最大或最小的 n 只股票。 + 市值数据从 daily_basic 表独立获取,仅用于筛选。 + + Attributes: + enabled: 是否启用选择 + n: 选择前 n 只 + ascending: False=最大市值, True=最小市值 + market_cap_col: 市值列名(来自 daily_basic) + """ + + enabled: bool = True + n: int = 100 + ascending: bool = False + market_cap_col: str = "total_mv" + + def __post_init__(self): + """验证配置参数""" + if self.n <= 0: + raise ValueError(f"n 必须是正整数,得到: {self.n}") + if not self.market_cap_col: + raise ValueError("market_cap_col 不能为空") diff --git a/tests/training/test_selectors.py b/tests/training/test_selectors.py new file mode 100644 index 0000000..faa9fb8 --- /dev/null +++ b/tests/training/test_selectors.py @@ -0,0 +1,183 @@ +"""测试股票池选择器配置 + +验证 StockFilterConfig 和 MarketCapSelectorConfig 功能。 +""" + +import pytest + +from src.training.components.selectors import ( + MarketCapSelectorConfig, + StockFilterConfig, +) + + +class TestStockFilterConfig: + """StockFilterConfig 测试类""" + + def test_default_values(self): + """测试默认值""" + config = StockFilterConfig() + assert config.exclude_cyb is True + assert config.exclude_kcb is True + assert config.exclude_bj is True + assert config.exclude_st is True + + def test_custom_values(self): + """测试自定义值""" + config = StockFilterConfig( + exclude_cyb=False, + exclude_kcb=False, + exclude_bj=False, + exclude_st=False, + ) + assert config.exclude_cyb is False + assert config.exclude_kcb is False + assert config.exclude_bj is False + assert config.exclude_st is False + + def test_filter_codes_exclude_all(self): + """测试排除所有类型""" + config = StockFilterConfig( + exclude_cyb=True, + exclude_kcb=True, + exclude_bj=True, + exclude_st=True, + ) + codes = [ + "000001.SZ", # 主板 - 保留 + "300001.SZ", # 创业板 - 排除 + "688001.SH", # 科创板 - 排除 + "830001.BJ", # 北交所(8开头)- 排除 + "430001.BJ", # 北交所(4开头)- 排除 + ] + result = config.filter_codes(codes) + assert result == ["000001.SZ"] + + def test_filter_codes_allow_cyb(self): + """测试允许创业板""" + config = StockFilterConfig( + exclude_cyb=False, + exclude_kcb=True, + exclude_bj=True, + exclude_st=True, + ) + codes = [ + "000001.SZ", + "300001.SZ", + "688001.SH", + ] + result = config.filter_codes(codes) + assert result == ["000001.SZ", "300001.SZ"] + + def test_filter_codes_allow_kcb(self): + """测试允许科创板""" + config = StockFilterConfig( + exclude_cyb=True, + exclude_kcb=False, + exclude_bj=True, + exclude_st=True, + ) + codes = [ + "000001.SZ", + "300001.SZ", + "688001.SH", + ] + result = config.filter_codes(codes) + assert result == ["000001.SZ", "688001.SH"] + + def test_filter_codes_allow_bj(self): + """测试允许北交所""" + config = StockFilterConfig( + exclude_cyb=True, + exclude_kcb=True, + exclude_bj=False, + exclude_st=True, + ) + codes = [ + "000001.SZ", + "300001.SZ", + "830001.BJ", + "430001.BJ", + ] + result = config.filter_codes(codes) + assert result == ["000001.SZ", "830001.BJ", "430001.BJ"] + + def test_filter_codes_allow_all(self): + """测试允许所有类型""" + config = StockFilterConfig( + exclude_cyb=False, + exclude_kcb=False, + exclude_bj=False, + exclude_st=False, + ) + codes = [ + "000001.SZ", + "300001.SZ", + "688001.SH", + "830001.BJ", + "430001.BJ", + ] + result = config.filter_codes(codes) + assert result == codes + + def test_filter_codes_empty_list(self): + """测试空列表""" + config = StockFilterConfig() + result = config.filter_codes([]) + assert result == [] + + def test_filter_codes_no_matching(self): + """测试全部排除""" + config = StockFilterConfig() + codes = ["300001.SZ", "688001.SH", "830001.BJ"] + result = config.filter_codes(codes) + assert result == [] + + +class TestMarketCapSelectorConfig: + """MarketCapSelectorConfig 测试类""" + + def test_default_values(self): + """测试默认值""" + config = MarketCapSelectorConfig() + assert config.enabled is True + assert config.n == 100 + assert config.ascending is False + assert config.market_cap_col == "total_mv" + + def test_custom_values(self): + """测试自定义值""" + config = MarketCapSelectorConfig( + enabled=False, + n=50, + ascending=True, + market_cap_col="circ_mv", + ) + assert config.enabled is False + assert config.n == 50 + assert config.ascending is True + assert config.market_cap_col == "circ_mv" + + def test_invalid_n_zero(self): + """测试无效的 n=0""" + with pytest.raises(ValueError, match="n 必须是正整数"): + MarketCapSelectorConfig(n=0) + + def test_invalid_n_negative(self): + """测试无效的负数 n""" + with pytest.raises(ValueError, match="n 必须是正整数"): + MarketCapSelectorConfig(n=-1) + + def test_invalid_empty_market_cap_col(self): + """测试空的市值列名""" + with pytest.raises(ValueError, match="market_cap_col 不能为空"): + MarketCapSelectorConfig(market_cap_col="") + + def test_large_n(self): + """测试大的 n 值""" + config = MarketCapSelectorConfig(n=5000) + assert config.n == 5000 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])