feat(training): 添加数据过滤器支持及 ST 股票过滤

- 新增 filters.py 模块,实现 BaseFilter 抽象类和 STFilter 过滤器
- 在 Trainer 中支持 filters 参数,可在股票池筛选之前执行数据过滤
- 更新 training/__init__.py 导出 BaseFilter 和 STFilter
- 在 regression.py 中集成 STFilter,用于过滤 ST 股票
This commit is contained in:
2026-03-04 21:14:39 +08:00
parent f1687dadf3
commit af5c96cd53
4 changed files with 168 additions and 1 deletions

View File

@@ -0,0 +1,142 @@
"""数据过滤器组件
提供股票数据过滤功能,在因子计算后、市值筛选前执行。
与 Processor 不同Filter 是无状态的筛选操作。
"""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Set
import polars as pl
if TYPE_CHECKING:
from src.factors.engine.data_router import DataRouter
class BaseFilter(ABC):
"""数据过滤器基类
Filter 用于从数据中移除不符合条件的行(股票)。
与 Processor 不同:
- Filter 是无状态的,不需要 fit
- Filter 删除整行数据,而不是变换列值
- Filter 每日独立执行
"""
name: str = ""
@abstractmethod
def filter(self, data: pl.DataFrame) -> pl.DataFrame:
"""执行过滤
Args:
data: 输入数据,必须包含 ts_code 和 trade_date 列
Returns:
过滤后的数据
"""
raise NotImplementedError
class STFilter(BaseFilter):
"""ST 股票过滤器
过滤掉每日的 ST 股票(包括 ST、*ST、S*ST、SST 等)。
从 stock_st 表获取每日 ST 股票列表进行过滤。
Attributes:
data_router: 数据路由器,用于获取 stock_st 表数据
code_col: 股票代码列名
date_col: 日期列名
"""
name = "st_filter"
def __init__(
self,
data_router: "DataRouter",
code_col: str = "ts_code",
date_col: str = "trade_date",
):
"""初始化 ST 过滤器
Args:
data_router: 数据路由器,用于查询 stock_st 表
code_col: 股票代码列名
date_col: 日期列名
"""
self.data_router = data_router
self.code_col = code_col
self.date_col = date_col
# 缓存:{date: set(stock_codes)}
self._st_cache: dict = {}
def filter(self, data: pl.DataFrame) -> pl.DataFrame:
"""过滤 ST 股票
按日期分组,每日独立从 stock_st 表获取 ST 列表并过滤。
Args:
data: 因子计算后的数据,包含 ts_code 和 trade_date
Returns:
过滤后的数据(不含 ST 股票)
"""
dates = data.select(self.date_col).unique().sort(self.date_col)
result_frames = []
for date in dates.to_series():
# 获取当日数据
daily_data = data.filter(pl.col(self.date_col) == date)
daily_codes = daily_data.select(self.code_col).to_series().to_list()
# 获取当日 ST 股票列表
st_codes = self._get_st_codes_for_date(date)
# 过滤掉 ST 股票
daily_filtered = daily_data.filter(~pl.col(self.code_col).is_in(st_codes))
result_frames.append(daily_filtered)
# 打印过滤信息
n_removed = len(daily_codes) - len(daily_filtered)
if n_removed > 0:
print(f" [{date}] 过滤 {n_removed} 只 ST 股票")
return pl.concat(result_frames)
def _get_st_codes_for_date(self, date: str) -> Set[str]:
"""从 stock_st 表获取指定日期的 ST 股票代码
Args:
date: 日期 "YYYYMMDD"
Returns:
ST 股票代码集合
"""
# 检查缓存
if date in self._st_cache:
return self._st_cache[date]
try:
from src.factors.engine.data_spec import DataSpec
# 查询 stock_st 表获取当日所有 ST 股票
data_specs = [DataSpec("stock_st", [self.code_col])]
df = self.data_router.fetch_data(
data_specs=data_specs,
start_date=date,
end_date=date,
stock_codes=None, # 获取当日全部 ST 股票
)
# 提取 ST 股票代码
st_codes = set(df[self.code_col].to_list()) if len(df) > 0 else set()
# 缓存结果
self._st_cache[date] = st_codes
return st_codes
except Exception as e:
print(f"[警告] 获取 {date} ST 股票列表失败: {e}")
return set()