refactor(training): 重构股票池管理 API 并更新训练流程
- 移除 StockFilterConfig/MarketCapSelectorConfig,改用 StockPoolManager + filter_func - Trainer 支持 train/val/test 三分法划分 - 更新 regression.ipynb 适配新 API - 删除已弃用的 test_selectors.py,后续补充 StockPoolManager 测试
This commit is contained in:
@@ -9,11 +9,8 @@ from src.training.components.base import BaseModel, BaseProcessor
|
||||
# 数据划分器
|
||||
from src.training.components.splitters import DateSplitter
|
||||
|
||||
# 股票池选择器配置
|
||||
from src.training.components.selectors import (
|
||||
MarketCapSelectorConfig,
|
||||
StockFilterConfig,
|
||||
)
|
||||
# 股票池选择器配置(已迁移到 StockPoolManager)
|
||||
# from src.training.components.selectors import ... # 已删除
|
||||
|
||||
# 数据处理器
|
||||
from src.training.components.processors import (
|
||||
@@ -29,8 +26,8 @@ __all__ = [
|
||||
"BaseModel",
|
||||
"BaseProcessor",
|
||||
"DateSplitter",
|
||||
"StockFilterConfig",
|
||||
"MarketCapSelectorConfig",
|
||||
# "StockFilterConfig", # 已删除
|
||||
# "MarketCapSelectorConfig", # 已删除
|
||||
"StandardScaler",
|
||||
"CrossSectionalStandardScaler",
|
||||
"Winsorizer",
|
||||
|
||||
@@ -1,81 +1,20 @@
|
||||
"""股票池选择器配置
|
||||
|
||||
提供股票过滤和市值选择的配置类。
|
||||
此模块目前为空,股票池筛选功能已迁移到 StockPoolManager。
|
||||
所有筛选逻辑通过传入自定义函数实现。
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class StockFilterConfig:
|
||||
"""股票过滤器配置
|
||||
|
||||
用于过滤掉不需要的股票(如创业板、科创板等)。
|
||||
基于股票代码进行过滤,不依赖外部数据。
|
||||
|
||||
Attributes:
|
||||
exclude_cyb: 是否排除创业板(300xxx, 301xxx)
|
||||
exclude_kcb: 是否排除科创板(688xxx)
|
||||
exclude_bj: 是否排除北交所(.BJ 后缀)
|
||||
exclude_st: 是否排除ST股票(需要外部数据支持)
|
||||
"""
|
||||
|
||||
exclude_cyb: bool = True
|
||||
exclude_kcb: bool = True
|
||||
exclude_bj: bool = True
|
||||
exclude_st: bool = True
|
||||
|
||||
def filter_codes(self, codes: List[str]) -> List[str]:
|
||||
"""应用过滤条件,返回过滤后的股票代码列表
|
||||
|
||||
Args:
|
||||
codes: 原始股票代码列表
|
||||
|
||||
Returns:
|
||||
过滤后的股票代码列表
|
||||
|
||||
Note:
|
||||
ST 股票过滤需要额外数据,在 StockPoolManager 中处理。
|
||||
此方法仅基于代码前缀进行过滤。
|
||||
"""
|
||||
result = []
|
||||
for code in codes:
|
||||
# 排除创业板(300xxx, 301xxx)
|
||||
if self.exclude_cyb and code.startswith(("300", "301")):
|
||||
continue
|
||||
# 排除科创板(688xxx)
|
||||
if self.exclude_kcb and code.startswith("688"):
|
||||
continue
|
||||
# 排除北交所(.BJ 后缀)
|
||||
if self.exclude_bj and code.endswith(".BJ"):
|
||||
continue
|
||||
result.append(code)
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class MarketCapSelectorConfig:
|
||||
"""市值选择器配置
|
||||
|
||||
每日独立选择市值最大或最小的 n 只股票。
|
||||
市值数据从 daily_basic 表独立获取,仅用于筛选。
|
||||
|
||||
Attributes:
|
||||
enabled: 是否启用选择
|
||||
n: 选择前 n 只
|
||||
ascending: False=最大市值, True=最小市值
|
||||
market_cap_col: 市值列名(来自 daily_basic)
|
||||
"""
|
||||
|
||||
enabled: bool = True
|
||||
n: int = 100
|
||||
ascending: bool = False
|
||||
market_cap_col: str = "total_mv"
|
||||
|
||||
def __post_init__(self):
|
||||
"""验证配置参数"""
|
||||
if self.n <= 0:
|
||||
raise ValueError(f"n 必须是正整数,得到: {self.n}")
|
||||
if not self.market_cap_col:
|
||||
raise ValueError("market_cap_col 不能为空")
|
||||
# 旧配置类已删除:
|
||||
# - StockFilterConfig (使用 filter_func 替代)
|
||||
# - MarketCapSelectorConfig (使用 filter_func + required_factors 替代)
|
||||
#
|
||||
# 新的使用方式:
|
||||
# from src.training import StockPoolManager
|
||||
#
|
||||
# def my_filter(df: pl.DataFrame) -> pl.Series:
|
||||
# return df["total_mv"] > 1e9
|
||||
#
|
||||
# pool_manager = StockPoolManager(
|
||||
# filter_func=my_filter,
|
||||
# required_columns=["total_mv"],
|
||||
# )
|
||||
|
||||
Reference in New Issue
Block a user