feat(training): 添加数据过滤器支持及 ST 股票过滤
- 新增 filters.py 模块,实现 BaseFilter 抽象类和 STFilter 过滤器
- 在 Trainer 中支持 filters 参数,可在股票池筛选之前执行数据过滤
- 更新 training/__init__.py,导出 BaseFilter 和 STFilter
- 在 regression.py 中集成 STFilter,用于过滤 ST 股票
This commit is contained in:
@@ -13,6 +13,7 @@ from src.factors import FactorEngine
|
|||||||
from src.training import (
|
from src.training import (
|
||||||
DateSplitter,
|
DateSplitter,
|
||||||
LightGBMModel,
|
LightGBMModel,
|
||||||
|
STFilter,
|
||||||
StandardScaler,
|
StandardScaler,
|
||||||
StockFilterConfig,
|
StockFilterConfig,
|
||||||
StockPoolManager,
|
StockPoolManager,
|
||||||
@@ -223,11 +224,17 @@ def train_regression_model():
|
|||||||
data_router=engine.router, # 从 FactorEngine 获取数据路由器
|
data_router=engine.router, # 从 FactorEngine 获取数据路由器
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 8.5 创建 ST 股票过滤器(在股票池筛选之前执行)
|
||||||
|
st_filter = STFilter(
|
||||||
|
data_router=engine.router,
|
||||||
|
)
|
||||||
|
|
||||||
# 9. 创建训练器
|
# 9. 创建训练器
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
model=model,
|
model=model,
|
||||||
pool_manager=pool_manager,
|
pool_manager=pool_manager,
|
||||||
processors=processors,
|
processors=processors,
|
||||||
|
filters=[st_filter], # 在股票池筛选之前过滤 ST 股票
|
||||||
splitter=splitter,
|
splitter=splitter,
|
||||||
target_col=target_col,
|
target_col=target_col,
|
||||||
feature_cols=feature_cols,
|
feature_cols=feature_cols,
|
||||||
|
|||||||
@@ -33,6 +33,9 @@ from src.training.components.processors import (
|
|||||||
# 模型
|
# 模型
|
||||||
from src.training.components.models import LightGBMModel
|
from src.training.components.models import LightGBMModel
|
||||||
|
|
||||||
|
# 数据过滤器
|
||||||
|
from src.training.components.filters import BaseFilter, STFilter
|
||||||
|
|
||||||
# 训练核心
|
# 训练核心
|
||||||
from src.training.core import StockPoolManager, Trainer
|
from src.training.core import StockPoolManager, Trainer
|
||||||
|
|
||||||
@@ -57,6 +60,9 @@ __all__ = [
|
|||||||
"StandardScaler",
|
"StandardScaler",
|
||||||
"CrossSectionalStandardScaler",
|
"CrossSectionalStandardScaler",
|
||||||
"Winsorizer",
|
"Winsorizer",
|
||||||
|
# 数据过滤器
|
||||||
|
"BaseFilter",
|
||||||
|
"STFilter",
|
||||||
# 模型
|
# 模型
|
||||||
"LightGBMModel",
|
"LightGBMModel",
|
||||||
# 训练核心
|
# 训练核心
|
||||||
|
|||||||
142
src/training/components/filters.py
Normal file
142
src/training/components/filters.py
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
"""数据过滤器组件
|
||||||
|
|
||||||
|
提供股票数据过滤功能,在因子计算后、市值筛选前执行。
|
||||||
|
与 Processor 不同,Filter 是无状态的筛选操作。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import TYPE_CHECKING, Set
|
||||||
|
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.factors.engine.data_router import DataRouter
|
||||||
|
|
||||||
|
|
||||||
|
class BaseFilter(ABC):
    """Abstract base class for row-level data filters.

    A filter removes the rows (stocks) that do not satisfy some condition.
    Compared with a Processor:

    - a filter is stateless and has no fit step;
    - a filter drops whole rows instead of transforming column values;
    - a filter is applied to each day independently.
    """

    # Identifier of the filter; concrete subclasses override it.
    name: str = ""

    @abstractmethod
    def filter(self, data: pl.DataFrame) -> pl.DataFrame:
        """Apply the filter to ``data``.

        Args:
            data: Input frame; it must contain the ``ts_code`` and
                ``trade_date`` columns.

        Returns:
            The frame that remains after filtering.
        """
        raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class STFilter(BaseFilter):
    """Filter that removes ST-flagged stocks day by day.

    For every distinct date found in the input, the stock codes listed in the
    ``stock_st`` table for that date (ST, *ST, S*ST, SST, ...) are fetched
    through the data router and the matching rows are dropped.

    Successful lookups are memoised per date in ``_st_cache``, so the router
    is queried at most once per date for the lifetime of the instance.

    Attributes:
        data_router: Data router used to read the ``stock_st`` table.
        code_col: Name of the stock-code column.
        date_col: Name of the date column.
    """

    name = "st_filter"

    def __init__(
        self,
        data_router: "DataRouter",
        code_col: str = "ts_code",
        date_col: str = "trade_date",
    ):
        """Initialise the ST filter.

        Args:
            data_router: Data router used to query the ``stock_st`` table.
            code_col: Name of the stock-code column.
            date_col: Name of the date column.
        """
        self.data_router = data_router
        self.code_col = code_col
        self.date_col = date_col
        # Per-date memo: {date: set of ST stock codes}.
        self._st_cache: dict = {}

    def filter(self, data: pl.DataFrame) -> pl.DataFrame:
        """Drop the rows that belong to ST stocks.

        The input is processed one date at a time (dates in ascending order);
        each day's ST list is resolved independently and the surviving rows
        of all days are concatenated in that date order.

        Args:
            data: Frame containing at least the ``date_col`` and ``code_col``
                columns.

        Returns:
            The input without the ST rows.  An input with no rows is returned
            unchanged.
        """
        dates = data.select(self.date_col).unique().sort(self.date_col)

        result_frames = []
        for date in dates.to_series():
            daily_data = data.filter(pl.col(self.date_col) == date)
            st_codes = self._get_st_codes_for_date(date)

            if st_codes:
                # is_in() gets a list: a plain sequence is the safest
                # collection type to hand to polars.
                daily_filtered = daily_data.filter(
                    ~pl.col(self.code_col).is_in(list(st_codes))
                )
            else:
                # Nothing to remove for this day (or the lookup failed).
                daily_filtered = daily_data

            result_frames.append(daily_filtered)

            # Report how many rows were removed for this day.
            n_removed = daily_data.height - daily_filtered.height
            if n_removed > 0:
                print(f" [{date}] 过滤 {n_removed} 只 ST 股票")

        if not result_frames:
            # pl.concat() rejects an empty list; an input without any date
            # has nothing to filter, so hand it back as is.
            return data
        return pl.concat(result_frames)

    def _get_st_codes_for_date(self, date: str) -> Set[str]:
        """Return the ST stock codes recorded in ``stock_st`` for one date.

        The lookup is best effort: if anything goes wrong a warning is
        printed and an empty set is returned, which means no stock is
        filtered for that date.  Failed lookups are not cached, so a later
        call for the same date queries the router again.

        Args:
            date: Date to look up, documented by the caller as "YYYYMMDD".

        Returns:
            Set of ST stock codes for ``date`` (possibly empty).
        """
        if date in self._st_cache:
            return self._st_cache[date]

        try:
            from src.factors.engine.data_spec import DataSpec

            # stock_codes=None requests every ST stock of that day.
            df = self.data_router.fetch_data(
                data_specs=[DataSpec("stock_st", [self.code_col])],
                start_date=date,
                end_date=date,
                stock_codes=None,
            )
            raw_codes = df[self.code_col].to_list() if len(df) > 0 else []
        except Exception as e:
            print(f"[警告] 获取 {date} ST 股票列表失败: {e}")
            return set()

        # Null codes cannot identify a stock; keep them out of the set.
        st_codes = {code for code in raw_codes if code is not None}
        self._st_cache[date] = st_codes
        return st_codes
|
||||||
@@ -3,7 +3,7 @@
|
|||||||
整合数据处理、模型训练、预测的完整流程。
|
整合数据处理、模型训练、预测的完整流程。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
import polars as pl
|
import polars as pl
|
||||||
|
|
||||||
@@ -11,6 +11,9 @@ from src.training.components.base import BaseModel, BaseProcessor
|
|||||||
from src.training.components.splitters import DateSplitter
|
from src.training.components.splitters import DateSplitter
|
||||||
from src.training.core.stock_pool_manager import StockPoolManager
|
from src.training.core.stock_pool_manager import StockPoolManager
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.training.components.filters import BaseFilter
|
||||||
|
|
||||||
|
|
||||||
class Trainer:
|
class Trainer:
|
||||||
"""训练器主类
|
"""训练器主类
|
||||||
@@ -29,6 +32,7 @@ class Trainer:
|
|||||||
model: BaseModel,
|
model: BaseModel,
|
||||||
pool_manager: Optional[StockPoolManager] = None,
|
pool_manager: Optional[StockPoolManager] = None,
|
||||||
processors: Optional[List[BaseProcessor]] = None,
|
processors: Optional[List[BaseProcessor]] = None,
|
||||||
|
filters: Optional[List["BaseFilter"]] = None,
|
||||||
splitter: Optional[DateSplitter] = None,
|
splitter: Optional[DateSplitter] = None,
|
||||||
target_col: str = "target",
|
target_col: str = "target",
|
||||||
feature_cols: Optional[List[str]] = None,
|
feature_cols: Optional[List[str]] = None,
|
||||||
@@ -41,6 +45,7 @@ class Trainer:
|
|||||||
model: 模型实例
|
model: 模型实例
|
||||||
pool_manager: 股票池管理器,None 表示不筛选
|
pool_manager: 股票池管理器,None 表示不筛选
|
||||||
processors: 数据处理器列表
|
processors: 数据处理器列表
|
||||||
|
filters: 数据过滤器列表(在股票池筛选之前执行)
|
||||||
splitter: 数据划分器
|
splitter: 数据划分器
|
||||||
target_col: 目标变量列名
|
target_col: 目标变量列名
|
||||||
feature_cols: 特征列名列表
|
feature_cols: 特征列名列表
|
||||||
@@ -50,6 +55,7 @@ class Trainer:
|
|||||||
self.model = model
|
self.model = model
|
||||||
self.pool_manager = pool_manager
|
self.pool_manager = pool_manager
|
||||||
self.processors = processors or []
|
self.processors = processors or []
|
||||||
|
self.filters = filters or []
|
||||||
self.splitter = splitter
|
self.splitter = splitter
|
||||||
self.target_col = target_col
|
self.target_col = target_col
|
||||||
self.feature_cols = feature_cols or []
|
self.feature_cols = feature_cols or []
|
||||||
@@ -80,6 +86,12 @@ class Trainer:
|
|||||||
Returns:
|
Returns:
|
||||||
self (支持链式调用)
|
self (支持链式调用)
|
||||||
"""
|
"""
|
||||||
|
# 0. 数据过滤(在股票池筛选之前)
|
||||||
|
if self.filters:
|
||||||
|
print("[过滤] 应用数据过滤器...")
|
||||||
|
for filter_ in self.filters:
|
||||||
|
data = filter_.filter(data)
|
||||||
|
|
||||||
# 1. 股票池筛选(每日独立)
|
# 1. 股票池筛选(每日独立)
|
||||||
if self.pool_manager:
|
if self.pool_manager:
|
||||||
print("[筛选] 每日独立筛选股票池...")
|
print("[筛选] 每日独立筛选股票池...")
|
||||||
|
|||||||
Reference in New Issue
Block a user