feat(training): 实现股票池选择器配置
- 新增 StockFilterConfig:支持按代码前缀过滤创业板/科创板/北交所 - 新增 MarketCapSelectorConfig:配置市值选择参数(数量、排序、列名) - 添加参数验证(n>0, 列名非空) - 在 components 模块导出配置类 - 添加 15 个单元测试覆盖各种场景
This commit is contained in:
@@ -9,8 +9,16 @@ from src.training.components.base import BaseModel, BaseProcessor
|
||||
# 数据划分器
|
||||
from src.training.components.splitters import DateSplitter
|
||||
|
||||
# 股票池选择器配置
|
||||
from src.training.components.selectors import (
|
||||
MarketCapSelectorConfig,
|
||||
StockFilterConfig,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseModel",
|
||||
"BaseProcessor",
|
||||
"DateSplitter",
|
||||
"StockFilterConfig",
|
||||
"MarketCapSelectorConfig",
|
||||
]
|
||||
|
||||
81
src/training/components/selectors.py
Normal file
81
src/training/components/selectors.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""股票池选择器配置
|
||||
|
||||
提供股票过滤和市值选择的配置类。
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class StockFilterConfig:
|
||||
"""股票过滤器配置
|
||||
|
||||
用于过滤掉不需要的股票(如创业板、科创板等)。
|
||||
基于股票代码进行过滤,不依赖外部数据。
|
||||
|
||||
Attributes:
|
||||
exclude_cyb: 是否排除创业板(300xxx)
|
||||
exclude_kcb: 是否排除科创板(688xxx)
|
||||
exclude_bj: 是否排除北交所(8xxxxxx, 4xxxxxx)
|
||||
exclude_st: 是否排除ST股票(需要外部数据支持)
|
||||
"""
|
||||
|
||||
exclude_cyb: bool = True
|
||||
exclude_kcb: bool = True
|
||||
exclude_bj: bool = True
|
||||
exclude_st: bool = True
|
||||
|
||||
def filter_codes(self, codes: List[str]) -> List[str]:
|
||||
"""应用过滤条件,返回过滤后的股票代码列表
|
||||
|
||||
Args:
|
||||
codes: 原始股票代码列表
|
||||
|
||||
Returns:
|
||||
过滤后的股票代码列表
|
||||
|
||||
Note:
|
||||
ST 股票过滤需要额外数据,在 StockPoolManager 中处理。
|
||||
此方法仅基于代码前缀进行过滤。
|
||||
"""
|
||||
result = []
|
||||
for code in codes:
|
||||
# 排除创业板(300xxx)
|
||||
if self.exclude_cyb and code.startswith("300"):
|
||||
continue
|
||||
# 排除科创板(688xxx)
|
||||
if self.exclude_kcb and code.startswith("688"):
|
||||
continue
|
||||
# 排除北交所(8xxxxxx 或 4xxxxxx)
|
||||
if self.exclude_bj and (code.startswith("8") or code.startswith("4")):
|
||||
continue
|
||||
result.append(code)
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class MarketCapSelectorConfig:
|
||||
"""市值选择器配置
|
||||
|
||||
每日独立选择市值最大或最小的 n 只股票。
|
||||
市值数据从 daily_basic 表独立获取,仅用于筛选。
|
||||
|
||||
Attributes:
|
||||
enabled: 是否启用选择
|
||||
n: 选择前 n 只
|
||||
ascending: False=最大市值, True=最小市值
|
||||
market_cap_col: 市值列名(来自 daily_basic)
|
||||
"""
|
||||
|
||||
enabled: bool = True
|
||||
n: int = 100
|
||||
ascending: bool = False
|
||||
market_cap_col: str = "total_mv"
|
||||
|
||||
def __post_init__(self):
|
||||
"""验证配置参数"""
|
||||
if self.n <= 0:
|
||||
raise ValueError(f"n 必须是正整数,得到: {self.n}")
|
||||
if not self.market_cap_col:
|
||||
raise ValueError("market_cap_col 不能为空")
|
||||
Reference in New Issue
Block a user