feat(training): 添加数据过滤器支持及 ST 股票过滤
- 新增 filters.py 模块,实现 BaseFilter 抽象类和 STFilter 过滤器
- 在 Trainer 中支持 filters 参数,可在股票池筛选之前执行数据过滤
- 更新 training/__init__.py 导出 BaseFilter 和 STFilter
- 在 regression.py 中集成 STFilter,用于过滤 ST 股票
This commit is contained in:
@@ -13,6 +13,7 @@ from src.factors import FactorEngine
|
||||
from src.training import (
|
||||
DateSplitter,
|
||||
LightGBMModel,
|
||||
STFilter,
|
||||
StandardScaler,
|
||||
StockFilterConfig,
|
||||
StockPoolManager,
|
||||
@@ -223,11 +224,17 @@ def train_regression_model():
|
||||
data_router=engine.router, # 从 FactorEngine 获取数据路由器
|
||||
)
|
||||
|
||||
# 8.5 创建 ST 股票过滤器(在股票池筛选之前执行)
|
||||
st_filter = STFilter(
|
||||
data_router=engine.router,
|
||||
)
|
||||
|
||||
# 9. 创建训练器
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
pool_manager=pool_manager,
|
||||
processors=processors,
|
||||
filters=[st_filter], # 在股票池筛选之前过滤 ST 股票
|
||||
splitter=splitter,
|
||||
target_col=target_col,
|
||||
feature_cols=feature_cols,
|
||||
|
||||
@@ -33,6 +33,9 @@ from src.training.components.processors import (
|
||||
# 模型
|
||||
from src.training.components.models import LightGBMModel
|
||||
|
||||
# 数据过滤器
|
||||
from src.training.components.filters import BaseFilter, STFilter
|
||||
|
||||
# 训练核心
|
||||
from src.training.core import StockPoolManager, Trainer
|
||||
|
||||
@@ -57,6 +60,9 @@ __all__ = [
|
||||
"StandardScaler",
|
||||
"CrossSectionalStandardScaler",
|
||||
"Winsorizer",
|
||||
# 数据过滤器
|
||||
"BaseFilter",
|
||||
"STFilter",
|
||||
# 模型
|
||||
"LightGBMModel",
|
||||
# 训练核心
|
||||
|
||||
142
src/training/components/filters.py
Normal file
142
src/training/components/filters.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""数据过滤器组件
|
||||
|
||||
提供股票数据过滤功能,在因子计算后、市值筛选前执行。
|
||||
与 Processor 不同,Filter 是无状态的筛选操作。
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Set
|
||||
|
||||
import polars as pl
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.factors.engine.data_router import DataRouter
|
||||
|
||||
|
||||
class BaseFilter(ABC):
    """Abstract base class for data filters.

    A filter removes rows (stocks) that do not satisfy some condition.
    Intended contract, in contrast to a Processor:

    - a filter needs no ``fit`` step;
    - a filter drops whole rows instead of transforming column values;
    - a filter is meant to be applied to each trading day independently.

    Subclasses must implement :meth:`filter`.
    """

    # Short identifier for the filter; concrete subclasses override it.
    name: str = ""

    @abstractmethod
    def filter(self, data: pl.DataFrame) -> pl.DataFrame:
        """Apply the filter and return the surviving rows.

        Args:
            data: Input frame; expected to contain the ``ts_code`` and
                ``trade_date`` columns.

        Returns:
            The filtered frame.

        Raises:
            NotImplementedError: Always, when the base implementation is
                invoked directly.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
class STFilter(BaseFilter):
    """Filter that removes ST stocks, one trading day at a time.

    For every distinct date in the input, the set of ST stock codes is
    fetched from the ``stock_st`` table through ``data_router`` and rows
    whose code is in that set are dropped.

    Attributes:
        data_router: Data router used to query the ``stock_st`` table.
        code_col: Name of the stock-code column.
        date_col: Name of the trade-date column.
    """

    name = "st_filter"

    def __init__(
        self,
        data_router: "DataRouter",
        code_col: str = "ts_code",
        date_col: str = "trade_date",
    ):
        """Initialise the ST filter.

        Args:
            data_router: Data router used to query the ``stock_st`` table.
            code_col: Name of the stock-code column.
            date_col: Name of the trade-date column.
        """
        self.data_router = data_router
        self.code_col = code_col
        self.date_col = date_col
        # Cache of successful lookups: {date: set of ST stock codes}.
        self._st_cache: dict = {}

    def filter(self, data: pl.DataFrame) -> pl.DataFrame:
        """Drop ST stocks from ``data``.

        The distinct dates are processed in ascending order; each day's
        rows are filtered against that day's ST list and the per-day
        results are concatenated in that same order. A line is printed
        for every day on which at least one row was removed.

        Args:
            data: Frame containing the ``code_col`` and ``date_col``
                columns.

        Returns:
            The rows of ``data`` whose code is not in the ST list of
            their date. An input without any rows is returned unchanged.
        """
        dates = data.select(self.date_col).unique().sort(self.date_col)

        result_frames = []
        for date in dates.to_series():
            # Rows belonging to the current trading day.
            daily_data = data.filter(pl.col(self.date_col) == date)

            # ST codes for that day (cached after the first successful fetch).
            st_codes = self._get_st_codes_for_date(date)

            daily_filtered = daily_data.filter(
                ~pl.col(self.code_col).is_in(st_codes)
            )
            result_frames.append(daily_filtered)

            # Report how many rows were dropped; counted on the frames
            # directly instead of materialising a Python list of codes.
            n_removed = len(daily_data) - len(daily_filtered)
            if n_removed > 0:
                print(f"  [{date}] 过滤 {n_removed} 只 ST 股票")

        if not result_frames:
            # No dates means no rows; pl.concat rejects an empty list, so
            # hand the (empty) input back as-is.
            return data

        return pl.concat(result_frames)

    def _get_st_codes_for_date(self, date: str) -> Set[str]:
        """Return the ST stock codes recorded in ``stock_st`` for ``date``.

        Successful lookups are cached per date. A failed lookup is
        reported with a printed warning, yields an empty set, and is not
        cached, so the next call for the same date retries the query.

        Args:
            date: Trade date to look up (expected as "YYYYMMDD").

        Returns:
            Set of ST stock codes for that date; empty when the table has
            no rows for the date or the query failed.
        """
        if date in self._st_cache:
            return self._st_cache[date]

        try:
            from src.factors.engine.data_spec import DataSpec

            # Query stock_st for a single day, across all stocks.
            data_specs = [DataSpec("stock_st", [self.code_col])]
            df = self.data_router.fetch_data(
                data_specs=data_specs,
                start_date=date,
                end_date=date,
                stock_codes=None,  # None requests every ST stock of the day
            )

            st_codes = set(df[self.code_col].to_list()) if len(df) > 0 else set()

            self._st_cache[date] = st_codes
            return st_codes

        except Exception as e:
            # Deliberately best-effort: a failed lookup must not abort the
            # whole filtering pass, so warn and treat the day as ST-free.
            print(f"[警告] 获取 {date} ST 股票列表失败: {e}")
            return set()
|
||||
@@ -3,7 +3,7 @@
|
||||
整合数据处理、模型训练、预测的完整流程。
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import polars as pl
|
||||
|
||||
@@ -11,6 +11,9 @@ from src.training.components.base import BaseModel, BaseProcessor
|
||||
from src.training.components.splitters import DateSplitter
|
||||
from src.training.core.stock_pool_manager import StockPoolManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.training.components.filters import BaseFilter
|
||||
|
||||
|
||||
class Trainer:
|
||||
"""训练器主类
|
||||
@@ -29,6 +32,7 @@ class Trainer:
|
||||
model: BaseModel,
|
||||
pool_manager: Optional[StockPoolManager] = None,
|
||||
processors: Optional[List[BaseProcessor]] = None,
|
||||
filters: Optional[List["BaseFilter"]] = None,
|
||||
splitter: Optional[DateSplitter] = None,
|
||||
target_col: str = "target",
|
||||
feature_cols: Optional[List[str]] = None,
|
||||
@@ -41,6 +45,7 @@ class Trainer:
|
||||
model: 模型实例
|
||||
pool_manager: 股票池管理器,None 表示不筛选
|
||||
processors: 数据处理器列表
|
||||
filters: 数据过滤器列表(在股票池筛选之前执行)
|
||||
splitter: 数据划分器
|
||||
target_col: 目标变量列名
|
||||
feature_cols: 特征列名列表
|
||||
@@ -50,6 +55,7 @@ class Trainer:
|
||||
self.model = model
|
||||
self.pool_manager = pool_manager
|
||||
self.processors = processors or []
|
||||
self.filters = filters or []
|
||||
self.splitter = splitter
|
||||
self.target_col = target_col
|
||||
self.feature_cols = feature_cols or []
|
||||
@@ -80,6 +86,12 @@ class Trainer:
|
||||
Returns:
|
||||
self (支持链式调用)
|
||||
"""
|
||||
# 0. 数据过滤(在股票池筛选之前)
|
||||
if self.filters:
|
||||
print("[过滤] 应用数据过滤器...")
|
||||
for filter_ in self.filters:
|
||||
data = filter_.filter(data)
|
||||
|
||||
# 1. 股票池筛选(每日独立)
|
||||
if self.pool_manager:
|
||||
print("[筛选] 每日独立筛选股票池...")
|
||||
|
||||
Reference in New Issue
Block a user