From 6b63c428d9ca3dcf0a64b3cb7e04f8fc61d08e55 Mon Sep 17 00:00:00 2001 From: liaozhaorun <1300336796@qq.com> Date: Tue, 3 Mar 2026 22:10:36 +0800 Subject: [PATCH] =?UTF-8?q?feat(training):=20=E5=AE=9E=E7=8E=B0=E8=82=A1?= =?UTF-8?q?=E7=A5=A8=E6=B1=A0=E9=80=89=E6=8B=A9=E5=99=A8=E9=85=8D=E7=BD=AE?= =?UTF-8?q?=20-=20=E6=96=B0=E5=A2=9E=20StockFilterConfig=EF=BC=9A=E6=94=AF?= =?UTF-8?q?=E6=8C=81=E6=8C=89=E4=BB=A3=E7=A0=81=E5=89=8D=E7=BC=80=E8=BF=87?= =?UTF-8?q?=E6=BB=A4=E5=88=9B=E4=B8=9A=E6=9D=BF/=E7=A7=91=E5=88=9B?= =?UTF-8?q?=E6=9D=BF/=E5=8C=97=E4=BA=A4=E6=89=80=20-=20=E6=96=B0=E5=A2=9E?= =?UTF-8?q?=20MarketCapSelectorConfig=EF=BC=9A=E9=85=8D=E7=BD=AE=E5=B8=82?= =?UTF-8?q?=E5=80=BC=E9=80=89=E6=8B=A9=E5=8F=82=E6=95=B0=EF=BC=88=E6=95=B0?= =?UTF-8?q?=E9=87=8F=E3=80=81=E6=8E=92=E5=BA=8F=E3=80=81=E5=88=97=E5=90=8D?= =?UTF-8?q?=EF=BC=89=20-=20=E6=B7=BB=E5=8A=A0=E5=8F=82=E6=95=B0=E9=AA=8C?= =?UTF-8?q?=E8=AF=81=EF=BC=88n>0,=20=E5=88=97=E5=90=8D=E9=9D=9E=E7=A9=BA?= =?UTF-8?q?=EF=BC=89=20-=20=E5=9C=A8=20components=20=E6=A8=A1=E5=9D=97?= =?UTF-8?q?=E5=AF=BC=E5=87=BA=E9=85=8D=E7=BD=AE=E7=B1=BB=20-=20=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=2015=20=E4=B8=AA=E5=8D=95=E5=85=83=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E8=A6=86=E7=9B=96=E5=90=84=E7=A7=8D=E5=9C=BA=E6=99=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/training/components/__init__.py | 8 ++ src/training/components/selectors.py | 81 ++++++++++++ tests/training/test_selectors.py | 183 +++++++++++++++++++++++++++ 3 files changed, 272 insertions(+) create mode 100644 src/training/components/selectors.py create mode 100644 tests/training/test_selectors.py diff --git a/src/training/components/__init__.py b/src/training/components/__init__.py index 762d8d2..9b2ebe4 100644 --- a/src/training/components/__init__.py +++ b/src/training/components/__init__.py @@ -9,8 +9,16 @@ from src.training.components.base import BaseModel, BaseProcessor # 数据划分器 from src.training.components.splitters import DateSplitter +# 股票池选择器配置 +from src.training.components.selectors import ( + MarketCapSelectorConfig, + StockFilterConfig, +) + __all__ = [ "BaseModel", "BaseProcessor", "DateSplitter", + "StockFilterConfig", + "MarketCapSelectorConfig", ] diff --git a/src/training/components/selectors.py b/src/training/components/selectors.py new file mode 100644 index 0000000..b22daaa --- /dev/null +++ b/src/training/components/selectors.py @@ -0,0 +1,81 @@ +"""股票池选择器配置 + +提供股票过滤和市值选择的配置类。 +""" + +from dataclasses import dataclass +from typing import List, Optional + + +@dataclass +class StockFilterConfig: + """股票过滤器配置 + + 用于过滤掉不需要的股票(如创业板、科创板等)。 + 基于股票代码进行过滤,不依赖外部数据。 + + Attributes: + exclude_cyb: 是否排除创业板(300xxx) + exclude_kcb: 是否排除科创板(688xxx) + exclude_bj: 是否排除北交所(8xxxxxx, 4xxxxxx) + exclude_st: 是否排除ST股票(需要外部数据支持) + """ + + exclude_cyb: bool = True + exclude_kcb: bool = True + exclude_bj: bool = True + exclude_st: bool = True + + def filter_codes(self, codes: List[str]) -> List[str]: + """应用过滤条件,返回过滤后的股票代码列表 + + Args: + codes: 原始股票代码列表 + + Returns: + 过滤后的股票代码列表 + + Note: + ST 股票过滤需要额外数据,在 StockPoolManager 中处理。 + 此方法仅基于代码前缀进行过滤。 + """ + result = [] + for code in codes: + # 排除创业板(300xxx) + if self.exclude_cyb and code.startswith("300"): + continue + # 排除科创板(688xxx) + if self.exclude_kcb and code.startswith("688"): + continue + # 排除北交所(8xxxxxx 或 4xxxxxx) + if self.exclude_bj and (code.startswith("8") or code.startswith("4")): + continue + result.append(code) + return result + + +@dataclass +class MarketCapSelectorConfig: + """市值选择器配置 + + 每日独立选择市值最大或最小的 n 只股票。 + 市值数据从 daily_basic 表独立获取,仅用于筛选。 + + Attributes: + enabled: 是否启用选择 + n: 选择前 n 只 + ascending: False=最大市值, True=最小市值 + market_cap_col: 市值列名(来自 daily_basic) + """ + + enabled: bool = True + n: int = 100 + ascending: bool = False + market_cap_col: str = "total_mv" + + def __post_init__(self): + """验证配置参数""" + if self.n <= 0: + raise ValueError(f"n 必须是正整数,得到: {self.n}") + if not self.market_cap_col: + raise ValueError("market_cap_col 不能为空") diff --git a/tests/training/test_selectors.py b/tests/training/test_selectors.py new file mode 100644 index 0000000..faa9fb8 --- /dev/null +++ b/tests/training/test_selectors.py @@ -0,0 +1,183 @@ +"""测试股票池选择器配置 + +验证 StockFilterConfig 和 MarketCapSelectorConfig 功能。 +""" + +import pytest + +from src.training.components.selectors import ( + MarketCapSelectorConfig, + StockFilterConfig, +) + + +class TestStockFilterConfig: + """StockFilterConfig 测试类""" + + def test_default_values(self): + """测试默认值""" + config = StockFilterConfig() + assert config.exclude_cyb is True + assert config.exclude_kcb is True + assert config.exclude_bj is True + assert config.exclude_st is True + + def test_custom_values(self): + """测试自定义值""" + config = StockFilterConfig( + exclude_cyb=False, + exclude_kcb=False, + exclude_bj=False, + exclude_st=False, + ) + assert config.exclude_cyb is False + assert config.exclude_kcb is False + assert config.exclude_bj is False + assert config.exclude_st is False + + def test_filter_codes_exclude_all(self): + """测试排除所有类型""" + config = StockFilterConfig( + exclude_cyb=True, + exclude_kcb=True, + exclude_bj=True, + exclude_st=True, + ) + codes = [ + "000001.SZ", # 主板 - 保留 + "300001.SZ", # 创业板 - 排除 + "688001.SH", # 科创板 - 排除 + "830001.BJ", # 北交所(8开头)- 排除 + "430001.BJ", # 北交所(4开头)- 排除 + ] + result = config.filter_codes(codes) + assert result == ["000001.SZ"] + + def test_filter_codes_allow_cyb(self): + """测试允许创业板""" + config = StockFilterConfig( + exclude_cyb=False, + exclude_kcb=True, + exclude_bj=True, + exclude_st=True, + ) + codes = [ + "000001.SZ", + "300001.SZ", + "688001.SH", + ] + result = config.filter_codes(codes) + assert result == ["000001.SZ", "300001.SZ"] + + def test_filter_codes_allow_kcb(self): + """测试允许科创板""" + config = StockFilterConfig( + exclude_cyb=True, + exclude_kcb=False, + exclude_bj=True, + exclude_st=True, + ) + codes = [ + "000001.SZ", + "300001.SZ", + "688001.SH", + ] + result = config.filter_codes(codes) + assert result == ["000001.SZ", "688001.SH"] + + def test_filter_codes_allow_bj(self): + """测试允许北交所""" + config = StockFilterConfig( + exclude_cyb=True, + exclude_kcb=True, + exclude_bj=False, + exclude_st=True, + ) + codes = [ + "000001.SZ", + "300001.SZ", + "830001.BJ", + "430001.BJ", + ] + result = config.filter_codes(codes) + assert result == ["000001.SZ", "830001.BJ", "430001.BJ"] + + def test_filter_codes_allow_all(self): + """测试允许所有类型""" + config = StockFilterConfig( + exclude_cyb=False, + exclude_kcb=False, + exclude_bj=False, + exclude_st=False, + ) + codes = [ + "000001.SZ", + "300001.SZ", + "688001.SH", + "830001.BJ", + "430001.BJ", + ] + result = config.filter_codes(codes) + assert result == codes + + def test_filter_codes_empty_list(self): + """测试空列表""" + config = StockFilterConfig() + result = config.filter_codes([]) + assert result == [] + + def test_filter_codes_no_matching(self): + """测试全部排除""" + config = StockFilterConfig() + codes = ["300001.SZ", "688001.SH", "830001.BJ"] + result = config.filter_codes(codes) + assert result == [] + + +class TestMarketCapSelectorConfig: + """MarketCapSelectorConfig 测试类""" + + def test_default_values(self): + """测试默认值""" + config = MarketCapSelectorConfig() + assert config.enabled is True + assert config.n == 100 + assert config.ascending is False + assert config.market_cap_col == "total_mv" + + def test_custom_values(self): + """测试自定义值""" + config = MarketCapSelectorConfig( + enabled=False, + n=50, + ascending=True, + market_cap_col="circ_mv", + ) + assert config.enabled is False + assert config.n == 50 + assert config.ascending is True + assert config.market_cap_col == "circ_mv" + + def test_invalid_n_zero(self): + """测试无效的 n=0""" + with pytest.raises(ValueError, match="n 必须是正整数"): + MarketCapSelectorConfig(n=0) + + def test_invalid_n_negative(self): + """测试无效的负数 n""" + with pytest.raises(ValueError, match="n 必须是正整数"): + MarketCapSelectorConfig(n=-1) + + def test_invalid_empty_market_cap_col(self): + """测试空的市值列名""" + with pytest.raises(ValueError, match="market_cap_col 不能为空"): + MarketCapSelectorConfig(market_cap_col="") + + def test_large_n(self): + """测试大的 n 值""" + config = MarketCapSelectorConfig(n=5000) + assert config.n == 5000 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])