diff --git a/src/experiment/regression.py b/src/experiment/regression.py
index 734552b..bea1e03 100644
--- a/src/experiment/regression.py
+++ b/src/experiment/regression.py
@@ -13,6 +13,7 @@ from src.factors import FactorEngine
 from src.training import (
     DateSplitter,
     LightGBMModel,
+    STFilter,
     StandardScaler,
     StockFilterConfig,
     StockPoolManager,
@@ -223,11 +224,17 @@ def train_regression_model():
         data_router=engine.router,  # 从 FactorEngine 获取数据路由器
     )
 
+    # 8.5 Create the ST stock filter (runs before the stock-pool selection)
+    st_filter = STFilter(
+        data_router=engine.router,
+    )
+
     # 9. 创建训练器
     trainer = Trainer(
         model=model,
         pool_manager=pool_manager,
         processors=processors,
+        filters=[st_filter],  # drop ST stocks before the stock-pool selection
         splitter=splitter,
         target_col=target_col,
         feature_cols=feature_cols,
diff --git a/src/training/__init__.py b/src/training/__init__.py
index ac2fa66..e885e2c 100644
--- a/src/training/__init__.py
+++ b/src/training/__init__.py
@@ -33,6 +33,9 @@ from src.training.components.processors import (
 # 模型
 from src.training.components.models import LightGBMModel
 
+# Data filters
+from src.training.components.filters import BaseFilter, STFilter
+
 # 训练核心
 from src.training.core import StockPoolManager, Trainer
 
@@ -57,6 +60,9 @@ __all__ = [
     "StandardScaler",
     "CrossSectionalStandardScaler",
     "Winsorizer",
+    # Data filters
+    "BaseFilter",
+    "STFilter",
     # 模型
     "LightGBMModel",
     # 训练核心
diff --git a/src/training/components/filters.py b/src/training/components/filters.py
new file mode 100644
index 0000000..d4f165c
--- /dev/null
+++ b/src/training/components/filters.py
@@ -0,0 +1,154 @@
+"""Data filter components.
+
+Filters drop stock rows from the data. They run after factor computation and
+before the market-cap based stock-pool selection. Unlike a Processor, a
+Filter needs no fitting: it only selects rows.
+"""
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Dict, Set
+
+import polars as pl
+
+if TYPE_CHECKING:
+    from src.factors.engine.data_router import DataRouter
+
+
+class BaseFilter(ABC):
+    """Base class for data filters.
+
+    A filter removes the rows (stocks) that do not meet a condition.
+    Differences from a Processor:
+    - a filter is never fitted;
+    - a filter drops whole rows instead of transforming column values;
+    - a filter is applied to each trading day independently.
+    """
+
+    name: str = ""
+
+    @abstractmethod
+    def filter(self, data: pl.DataFrame) -> pl.DataFrame:
+        """Apply the filter.
+
+        Args:
+            data: Input data; must contain the ts_code and trade_date columns.
+
+        Returns:
+            The filtered data.
+        """
+        raise NotImplementedError
+
+
+class STFilter(BaseFilter):
+    """Filter that removes ST stocks.
+
+    Drops, for every trading day, the stocks flagged as ST on that day
+    (ST, *ST, S*ST, SST, ...). The daily ST list is read from the
+    stock_st table through the data router.
+
+    Attributes:
+        data_router: Data router used to query the stock_st table.
+        code_col: Name of the stock-code column.
+        date_col: Name of the trade-date column.
+    """
+
+    name = "st_filter"
+
+    def __init__(
+        self,
+        data_router: "DataRouter",
+        code_col: str = "ts_code",
+        date_col: str = "trade_date",
+    ):
+        """Initialise the ST filter.
+
+        Args:
+            data_router: Data router used to query the stock_st table.
+            code_col: Name of the stock-code column.
+            date_col: Name of the trade-date column.
+        """
+        self.data_router = data_router
+        self.code_col = code_col
+        self.date_col = date_col
+        # Per-date cache of successful lookups: {date: set of ST codes}.
+        self._st_cache: Dict[str, Set[str]] = {}
+
+    def filter(self, data: pl.DataFrame) -> pl.DataFrame:
+        """Remove ST stocks from the data, one trading day at a time.
+
+        For each distinct date (ascending) the day's ST codes are looked up
+        and the matching rows are dropped; the per-day frames are then
+        concatenated in that date order.
+
+        Args:
+            data: Factor data containing the code and date columns.
+
+        Returns:
+            The data without ST stocks. Empty input is returned unchanged.
+        """
+        # pl.concat() raises on an empty list of frames, so return early
+        # when there is nothing to filter.
+        if data.height == 0:
+            return data
+
+        dates = data.select(self.date_col).unique().sort(self.date_col)
+
+        result_frames = []
+        for date in dates.to_series():
+            daily_data = data.filter(pl.col(self.date_col) == date)
+            st_codes = self._get_st_codes_for_date(date)
+
+            if st_codes:
+                daily_filtered = daily_data.filter(
+                    ~pl.col(self.code_col).is_in(list(st_codes))
+                )
+            else:
+                # No ST codes known for this day: keep every row.
+                daily_filtered = daily_data
+            result_frames.append(daily_filtered)
+
+            # Report how many rows were dropped for this day.
+            n_removed = daily_data.height - daily_filtered.height
+            if n_removed > 0:
+                print(f" [{date}] 过滤 {n_removed} 只 ST 股票")
+
+        return pl.concat(result_frames)
+
+    def _get_st_codes_for_date(self, date: str) -> Set[str]:
+        """Return the ST stock codes listed in stock_st for one day.
+
+        Successful lookups are cached per date. A failed lookup prints a
+        warning and returns an empty set without caching it, so the same
+        date is queried again on the next call.
+
+        Args:
+            date: Trading date "YYYYMMDD", used as both start and end date
+                of the query.
+
+        Returns:
+            Set of ST stock codes; empty when the query returned no rows
+            or failed.
+        """
+        if date in self._st_cache:
+            return self._st_cache[date]
+
+        try:
+            from src.factors.engine.data_spec import DataSpec
+
+            # Query the stock_st table for every ST stock of the day.
+            data_specs = [DataSpec("stock_st", [self.code_col])]
+            df = self.data_router.fetch_data(
+                data_specs=data_specs,
+                start_date=date,
+                end_date=date,
+                stock_codes=None,  # no restriction: all ST stocks of the day
+            )
+            st_codes = set(df[self.code_col].to_list()) if len(df) > 0 else set()
+        except Exception as e:
+            # Best effort: a failed lookup must not abort the whole run, so
+            # warn and keep all of the day's rows.
+            print(f"[警告] 获取 {date} ST 股票列表失败: {e}")
+            return set()
+
+        self._st_cache[date] = st_codes
+        return st_codes
diff --git a/src/training/core/trainer.py b/src/training/core/trainer.py
index c33c174..96e28f7 100644
--- a/src/training/core/trainer.py
+++ b/src/training/core/trainer.py
@@ -3,7 +3,7 @@
 整合数据处理、模型训练、预测的完整流程。
 """
 
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional
 
 import polars as pl
 
@@ -11,6 +11,9 @@ from src.training.components.base import BaseModel, BaseProcessor
 from src.training.components.splitters import DateSplitter
 from src.training.core.stock_pool_manager import StockPoolManager
 
+if TYPE_CHECKING:
+    from src.training.components.filters import BaseFilter
+
 
 class Trainer:
     """训练器主类
@@ -29,6 +32,7 @@ class Trainer:
         model: BaseModel,
         pool_manager: Optional[StockPoolManager] = None,
         processors: Optional[List[BaseProcessor]] = None,
+        filters: Optional[List["BaseFilter"]] = None,
         splitter: Optional[DateSplitter] = None,
         target_col: str = "target",
         feature_cols: Optional[List[str]] = None,
@@ -41,6 +45,7 @@ class Trainer:
             model: 模型实例
             pool_manager: 股票池管理器,None 表示不筛选
             processors: 数据处理器列表
+            filters: Data filters, applied before the stock-pool selection
             splitter: 数据划分器
             target_col: 目标变量列名
             feature_cols: 特征列名列表
@@ -50,6 +55,7 @@ class Trainer:
         self.model = model
         self.pool_manager = pool_manager
         self.processors = processors or []
+        self.filters = filters or []
         self.splitter = splitter
         self.target_col = target_col
         self.feature_cols = feature_cols or []
@@ -80,6 +86,12 @@ class Trainer:
         Returns:
             self (支持链式调用)
         """
+        # 0. Data filtering (before the stock-pool selection)
+        if self.filters:
+            print("[过滤] 应用数据过滤器...")
+            for filter_ in self.filters:
+                data = filter_.filter(data)
+
         # 1. 股票池筛选(每日独立)
         if self.pool_manager:
             print("[筛选] 每日独立筛选股票池...")