feat(training): 添加数据过滤器支持及 ST 股票过滤

- 新增 filters.py 模块,实现 BaseFilter 抽象类和 STFilter 过滤器
- 在 Trainer 中支持 filters 参数,可在股票池筛选之前执行数据过滤
- 更新 training/__init__.py 导出 BaseFilter 和 STFilter
- 在 regression.py 中集成 STFilter,用于过滤 ST 股票
This commit is contained in:
2026-03-04 21:14:39 +08:00
parent f1687dadf3
commit af5c96cd53
4 changed files with 168 additions and 1 deletions

View File

@@ -13,6 +13,7 @@ from src.factors import FactorEngine
from src.training import ( from src.training import (
DateSplitter, DateSplitter,
LightGBMModel, LightGBMModel,
STFilter,
StandardScaler, StandardScaler,
StockFilterConfig, StockFilterConfig,
StockPoolManager, StockPoolManager,
@@ -223,11 +224,17 @@ def train_regression_model():
data_router=engine.router, # 从 FactorEngine 获取数据路由器 data_router=engine.router, # 从 FactorEngine 获取数据路由器
) )
# 8.5 创建 ST 股票过滤器(在股票池筛选之前执行)
st_filter = STFilter(
data_router=engine.router,
)
# 9. 创建训练器 # 9. 创建训练器
trainer = Trainer( trainer = Trainer(
model=model, model=model,
pool_manager=pool_manager, pool_manager=pool_manager,
processors=processors, processors=processors,
filters=[st_filter], # 在股票池筛选之前过滤 ST 股票
splitter=splitter, splitter=splitter,
target_col=target_col, target_col=target_col,
feature_cols=feature_cols, feature_cols=feature_cols,

View File

@@ -33,6 +33,9 @@ from src.training.components.processors import (
# 模型 # 模型
from src.training.components.models import LightGBMModel from src.training.components.models import LightGBMModel
# 数据过滤器
from src.training.components.filters import BaseFilter, STFilter
# 训练核心 # 训练核心
from src.training.core import StockPoolManager, Trainer from src.training.core import StockPoolManager, Trainer
@@ -57,6 +60,9 @@ __all__ = [
"StandardScaler", "StandardScaler",
"CrossSectionalStandardScaler", "CrossSectionalStandardScaler",
"Winsorizer", "Winsorizer",
# 数据过滤器
"BaseFilter",
"STFilter",
# 模型 # 模型
"LightGBMModel", "LightGBMModel",
# 训练核心 # 训练核心

View File

@@ -0,0 +1,142 @@
"""数据过滤器组件
提供股票数据过滤功能,在因子计算后、市值筛选前执行。
与 Processor 不同Filter 是无状态的筛选操作。
"""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Set
import polars as pl
if TYPE_CHECKING:
from src.factors.engine.data_router import DataRouter
class BaseFilter(ABC):
"""数据过滤器基类
Filter 用于从数据中移除不符合条件的行(股票)。
与 Processor 不同:
- Filter 是无状态的,不需要 fit
- Filter 删除整行数据,而不是变换列值
- Filter 每日独立执行
"""
name: str = ""
@abstractmethod
def filter(self, data: pl.DataFrame) -> pl.DataFrame:
"""执行过滤
Args:
data: 输入数据,必须包含 ts_code 和 trade_date 列
Returns:
过滤后的数据
"""
raise NotImplementedError
class STFilter(BaseFilter):
"""ST 股票过滤器
过滤掉每日的 ST 股票(包括 ST、*ST、S*ST、SST 等)。
从 stock_st 表获取每日 ST 股票列表进行过滤。
Attributes:
data_router: 数据路由器,用于获取 stock_st 表数据
code_col: 股票代码列名
date_col: 日期列名
"""
name = "st_filter"
def __init__(
self,
data_router: "DataRouter",
code_col: str = "ts_code",
date_col: str = "trade_date",
):
"""初始化 ST 过滤器
Args:
data_router: 数据路由器,用于查询 stock_st 表
code_col: 股票代码列名
date_col: 日期列名
"""
self.data_router = data_router
self.code_col = code_col
self.date_col = date_col
# 缓存:{date: set(stock_codes)}
self._st_cache: dict = {}
def filter(self, data: pl.DataFrame) -> pl.DataFrame:
"""过滤 ST 股票
按日期分组,每日独立从 stock_st 表获取 ST 列表并过滤。
Args:
data: 因子计算后的数据,包含 ts_code 和 trade_date
Returns:
过滤后的数据(不含 ST 股票)
"""
dates = data.select(self.date_col).unique().sort(self.date_col)
result_frames = []
for date in dates.to_series():
# 获取当日数据
daily_data = data.filter(pl.col(self.date_col) == date)
daily_codes = daily_data.select(self.code_col).to_series().to_list()
# 获取当日 ST 股票列表
st_codes = self._get_st_codes_for_date(date)
# 过滤掉 ST 股票
daily_filtered = daily_data.filter(~pl.col(self.code_col).is_in(st_codes))
result_frames.append(daily_filtered)
# 打印过滤信息
n_removed = len(daily_codes) - len(daily_filtered)
if n_removed > 0:
print(f" [{date}] 过滤 {n_removed} 只 ST 股票")
return pl.concat(result_frames)
def _get_st_codes_for_date(self, date: str) -> Set[str]:
"""从 stock_st 表获取指定日期的 ST 股票代码
Args:
date: 日期 "YYYYMMDD"
Returns:
ST 股票代码集合
"""
# 检查缓存
if date in self._st_cache:
return self._st_cache[date]
try:
from src.factors.engine.data_spec import DataSpec
# 查询 stock_st 表获取当日所有 ST 股票
data_specs = [DataSpec("stock_st", [self.code_col])]
df = self.data_router.fetch_data(
data_specs=data_specs,
start_date=date,
end_date=date,
stock_codes=None, # 获取当日全部 ST 股票
)
# 提取 ST 股票代码
st_codes = set(df[self.code_col].to_list()) if len(df) > 0 else set()
# 缓存结果
self._st_cache[date] = st_codes
return st_codes
except Exception as e:
print(f"[警告] 获取 {date} ST 股票列表失败: {e}")
return set()

View File

@@ -3,7 +3,7 @@
整合数据处理、模型训练、预测的完整流程。 整合数据处理、模型训练、预测的完整流程。
""" """
from typing import List, Optional from typing import TYPE_CHECKING, List, Optional
import polars as pl import polars as pl
@@ -11,6 +11,9 @@ from src.training.components.base import BaseModel, BaseProcessor
from src.training.components.splitters import DateSplitter from src.training.components.splitters import DateSplitter
from src.training.core.stock_pool_manager import StockPoolManager from src.training.core.stock_pool_manager import StockPoolManager
if TYPE_CHECKING:
from src.training.components.filters import BaseFilter
class Trainer: class Trainer:
"""训练器主类 """训练器主类
@@ -29,6 +32,7 @@ class Trainer:
model: BaseModel, model: BaseModel,
pool_manager: Optional[StockPoolManager] = None, pool_manager: Optional[StockPoolManager] = None,
processors: Optional[List[BaseProcessor]] = None, processors: Optional[List[BaseProcessor]] = None,
filters: Optional[List["BaseFilter"]] = None,
splitter: Optional[DateSplitter] = None, splitter: Optional[DateSplitter] = None,
target_col: str = "target", target_col: str = "target",
feature_cols: Optional[List[str]] = None, feature_cols: Optional[List[str]] = None,
@@ -41,6 +45,7 @@ class Trainer:
model: 模型实例 model: 模型实例
pool_manager: 股票池管理器None 表示不筛选 pool_manager: 股票池管理器None 表示不筛选
processors: 数据处理器列表 processors: 数据处理器列表
filters: 数据过滤器列表(在股票池筛选之前执行)
splitter: 数据划分器 splitter: 数据划分器
target_col: 目标变量列名 target_col: 目标变量列名
feature_cols: 特征列名列表 feature_cols: 特征列名列表
@@ -50,6 +55,7 @@ class Trainer:
self.model = model self.model = model
self.pool_manager = pool_manager self.pool_manager = pool_manager
self.processors = processors or [] self.processors = processors or []
self.filters = filters or []
self.splitter = splitter self.splitter = splitter
self.target_col = target_col self.target_col = target_col
self.feature_cols = feature_cols or [] self.feature_cols = feature_cols or []
@@ -80,6 +86,12 @@ class Trainer:
Returns: Returns:
self (支持链式调用) self (支持链式调用)
""" """
# 0. 数据过滤(在股票池筛选之前)
if self.filters:
print("[过滤] 应用数据过滤器...")
for filter_ in self.filters:
data = filter_.filter(data)
# 1. 股票池筛选(每日独立) # 1. 股票池筛选(每日独立)
if self.pool_manager: if self.pool_manager:
print("[筛选] 每日独立筛选股票池...") print("[筛选] 每日独立筛选股票池...")