refactor(training): 重构股票池管理 API 并更新训练流程

- 移除 StockFilterConfig/MarketCapSelectorConfig,改用 StockPoolManager + filter_func
- Trainer 支持 train/val/test 三分法划分
- 更新 regression.ipynb 适配新 API
- 删除已弃用的 test_selectors.py,后续补充 StockPoolManager 测试
This commit is contained in:
2026-03-09 22:33:41 +08:00
parent a464ef70c0
commit 88fa848b96
10 changed files with 1110 additions and 804 deletions

View File

@@ -9,11 +9,8 @@ from src.training.components.base import BaseModel, BaseProcessor
# 数据划分器
from src.training.components.splitters import DateSplitter
# 股票池选择器配置
from src.training.components.selectors import (
MarketCapSelectorConfig,
StockFilterConfig,
)
# 股票池选择器配置(已迁移到 StockPoolManager
# from src.training.components.selectors import ... # 已删除
# 数据处理器
from src.training.components.processors import (
@@ -29,8 +26,8 @@ __all__ = [
"BaseModel",
"BaseProcessor",
"DateSplitter",
"StockFilterConfig",
"MarketCapSelectorConfig",
# "StockFilterConfig", # 已删除
# "MarketCapSelectorConfig", # 已删除
"StandardScaler",
"CrossSectionalStandardScaler",
"Winsorizer",

View File

@@ -1,81 +1,20 @@
"""股票池选择器配置
提供股票过滤和市值选择的配置类
此模块目前为空,股票池筛选功能已迁移到 StockPoolManager
所有筛选逻辑通过传入自定义函数实现。
"""
from dataclasses import dataclass
from typing import List, Optional
@dataclass
class StockFilterConfig:
"""股票过滤器配置
用于过滤掉不需要的股票(如创业板、科创板等)。
基于股票代码进行过滤,不依赖外部数据。
Attributes:
exclude_cyb: 是否排除创业板300xxx, 301xxx
exclude_kcb: 是否排除科创板688xxx
exclude_bj: 是否排除北交所(.BJ 后缀)
exclude_st: 是否排除ST股票需要外部数据支持
"""
exclude_cyb: bool = True
exclude_kcb: bool = True
exclude_bj: bool = True
exclude_st: bool = True
def filter_codes(self, codes: List[str]) -> List[str]:
"""应用过滤条件,返回过滤后的股票代码列表
Args:
codes: 原始股票代码列表
Returns:
过滤后的股票代码列表
Note:
ST 股票过滤需要额外数据,在 StockPoolManager 中处理。
此方法仅基于代码前缀进行过滤。
"""
result = []
for code in codes:
# 排除创业板300xxx, 301xxx
if self.exclude_cyb and code.startswith(("300", "301")):
continue
# 排除科创板688xxx
if self.exclude_kcb and code.startswith("688"):
continue
# 排除北交所(.BJ 后缀)
if self.exclude_bj and code.endswith(".BJ"):
continue
result.append(code)
return result
@dataclass
class MarketCapSelectorConfig:
"""市值选择器配置
每日独立选择市值最大或最小的 n 只股票。
市值数据从 daily_basic 表独立获取,仅用于筛选。
Attributes:
enabled: 是否启用选择
n: 选择前 n 只
ascending: False=最大市值, True=最小市值
market_cap_col: 市值列名(来自 daily_basic
"""
enabled: bool = True
n: int = 100
ascending: bool = False
market_cap_col: str = "total_mv"
def __post_init__(self):
"""验证配置参数"""
if self.n <= 0:
raise ValueError(f"n 必须是正整数,得到: {self.n}")
if not self.market_cap_col:
raise ValueError("market_cap_col 不能为空")
# 旧配置类已删除:
# - StockFilterConfig (使用 filter_func 替代)
# - MarketCapSelectorConfig (使用 filter_func + required_factors 替代)
#
# 新的使用方式:
# from src.training import StockPoolManager
#
# def my_filter(df: pl.DataFrame) -> pl.Series:
# return df["total_mv"] > 1e9
#
# pool_manager = StockPoolManager(
# filter_func=my_filter,
# required_columns=["total_mv"],
# )