Files
ProStock/docs/plan/training_module_plan.md
liaozhaorun 192718095f feat(training): 实现训练模块核心组件(commits 6-9)
- StockPoolManager:每日独立筛选股票池,支持代码过滤和市值选择
- Trainer:整合训练完整流程,支持 processor 分阶段行为和模型持久化
- TrainingConfig:pydantic 配置管理,含必填字段和日期验证
- experiment 模块:预留结构
- 从计划中移除 metrics 组件
- 调整 commit 序号(7-10 → 6-9)
- 更新 training/__init__.py 导出所有公开 API
2026-03-03 22:57:01 +08:00

32 KiB
Raw Blame History

训练模块实现计划

1. 概述

本计划描述 ProStock 训练模块的完整实现方案,支持标准回归模型训练并输出预测结果。

1.1 目标

  • 提供简洁的训练流程
  • 支持标准回归模型LightGBM、CatBoost
  • 输出可解释的预测结果
  • 与现有因子引擎无缝集成

1.2 设计原则

  • 职责分离:训练流程与建模组件分离
  • 注册机制:使用装饰器实现即插即用
  • 配置驱动:所有参数通过配置类管理
  • 阶段感知Processor 在训练和测试阶段行为不同
  • 每日筛选:股票池每日独立筛选(市值动态变化)
  • 模型持久化:支持保存和加载训练好的模型

2. 模块结构

src/
├── training/                    # 训练模块(流程 + 组件)
│   ├── __init__.py              # 导出核心类
│   ├── core/                    # 训练流程核心
│   │   ├── __init__.py
│   │   ├── trainer.py          # Trainer 主类
│   │   └── stock_pool_manager.py  # 股票池管理器(每日独立筛选)
│   ├── components/             # 建模组件
│   │   ├── __init__.py
│   │   ├── base.py             # BaseModel, BaseProcessor 抽象基类
│   │   ├── splitters.py        # 时间序列划分策略(一次性划分)
│   │   ├── selectors.py        # 股票池选择器配置
│   │   ├── models/             # 模型实现
│   │   │   ├── __init__.py
│   │   │   ├── lightgbm.py     # LightGBM 回归模型
│   │   │   └── catboost.py     # CatBoost 回归模型
│   │   ├── processors/         # 数据处理器
│   │   │   ├── __init__.py
│   │   │   └── transforms.py   # 标准化(截面/时序)、缩尾
│   ├── config/                 # 配置管理
│   │   ├── __init__.py
│   │   └── config.py           # TrainingConfig (pydantic)
│   └── registry.py             # 组件注册中心
│
└── experiment/                 # 实验管理(预留结构,暂不实现)
    └── __init__.py

3. 核心组件设计

3.1 抽象基类 (components/base.py)

BaseModel

class BaseModel(ABC):
    """模型基类"""
    
    name: str = ""  # 模型名称
    
    def fit(self, X: pl.DataFrame, y: pl.Series) -> "BaseModel":
        """训练模型"""
        raise NotImplementedError
    
    def predict(self, X: pl.DataFrame) -> np.ndarray:
        """预测"""
        raise NotImplementedError
    
    def feature_importance(self) -> Optional[pd.Series]:
        """特征重要性"""
        return None
    
    def save(self, path: str) -> None:
        """保存模型到文件
        
        默认实现使用 pickle子类可覆盖
        """
        import pickle
        with open(path, 'wb') as f:
            pickle.dump(self, f)
    
    @classmethod
    def load(cls, path: str) -> "BaseModel":
        """从文件加载模型"""
        import pickle
        with open(path, 'rb') as f:
            return pickle.load(f)

BaseProcessor

class BaseProcessor(ABC):
    """数据处理器基类
    
    重要Processor 在不同阶段行为不同:
    - 训练阶段fit_transform学习参数并应用
    - 验证/测试阶段transform使用训练阶段学到的参数
    
    这意味着 Processor 实例会在训练后被保存,
    用于后续的验证和测试数据转换。
    """
    
    name: str = ""
    
    def fit(self, X: pl.DataFrame) -> "BaseProcessor":
        """学习参数(仅在训练阶段调用)"""
        return self
    
    def transform(self, X: pl.DataFrame) -> pl.DataFrame:
        """转换数据"""
        raise NotImplementedError
    
    def fit_transform(self, X: pl.DataFrame) -> pl.DataFrame:
        """拟合并转换(训练阶段使用)"""
        return self.fit(X).transform(X)

3.2 时间序列划分 (components/splitters.py)

设计说明:暂不实现滚动训练,采用一次性训练/测试划分。

DateSplitter

class DateSplitter:
    """基于日期范围的一次性划分
    
    将数据按日期划分为训练集和测试集,不滚动。
    
    示例:
        train_start: "20200101", train_end: "20221231"  (训练集3年)
        test_start: "20230101", test_end: "20231231"    (测试集1年)
        
    特点:
        - 一次性划分,不滚动
        - 训练集和测试集互不重叠
        - 基于实际日期范围,而非行数
    """
    
    def __init__(
        self,
        train_start: str,      # 训练期开始日期 "YYYYMMDD"
        train_end: str,        # 训练期结束日期 "YYYYMMDD"
        test_start: str,       # 测试期开始日期 "YYYYMMDD"
        test_end: str,         # 测试期结束日期 "YYYYMMDD"
    ):
        self.train_start = train_start
        self.train_end = train_end
        self.test_start = test_start
        self.test_end = test_end
    
    def split(self, data: pl.DataFrame) -> Tuple[pl.DataFrame, pl.DataFrame]:
        """划分数据为训练集和测试集"""
        train_data = data.filter(
            (pl.col("trade_date") >= self.train_start) & 
            (pl.col("trade_date") <= self.train_end)
        )
        test_data = data.filter(
            (pl.col("trade_date") >= self.test_start) & 
            (pl.col("trade_date") <= self.test_end)
        )
        return train_data, test_data

3.3 股票池选择器配置 (components/selectors.py)

设计说明:股票池每日独立筛选,市值选择需要配合 StockPoolManager 使用。

StockFilterConfig

@dataclass
class StockFilterConfig:
    """股票过滤器配置
    
    用于过滤掉不需要的股票(如创业板、科创板等)。
    基于股票代码进行过滤,不依赖外部数据。
    """
    
    exclude_cyb: bool = True      # 是否排除创业板300xxx
    exclude_kcb: bool = True      # 是否排除科创板688xxx
    exclude_bj: bool = True       # 是否排除北交所8xxxxxx, 4xxxxxx
    exclude_st: bool = True       # 是否排除ST股票
    
    def filter_codes(self, codes: List[str]) -> List[str]:
        """应用过滤条件,返回过滤后的股票代码列表"""
        result = []
        for code in codes:
            if self.exclude_cyb and code.startswith("300"):
                continue
            if self.exclude_kcb and code.startswith("688"):
                continue
            if self.exclude_bj and (code.startswith("8") or code.startswith("4")):
                continue
            # ST 股票过滤需要额外数据,在 StockPoolManager 中处理
            result.append(code)
        return result

MarketCapSelectorConfig

@dataclass
class MarketCapSelectorConfig:
    """市值选择器配置
    
    每日独立选择市值最大或最小的 n 只股票。
    市值数据从 daily_basic 表独立获取,仅用于筛选。
    """
    
    enabled: bool = True          # 是否启用选择
    n: int = 100                  # 选择前 n 只
    ascending: bool = False       # False=最大市值, True=最小市值
    market_cap_col: str = "total_mv"  # 市值列名(来自 daily_basic

3.4 股票池管理器 (core/stock_pool_manager.py)

设计说明:每日独立筛选股票池,市值数据从 daily_basic 表独立获取。

class StockPoolManager:
    """股票池管理器 - 每日独立筛选
    
    重要约束:
    1. 市值数据仅从 daily_basic 表获取,仅用于筛选
    2. 市值数据绝不混入特征矩阵
    3. 每日独立筛选(市值是动态变化的)
    
    处理流程(每日):
        当日所有股票
        代码过滤创业板、ST等
        查询 daily_basic 获取当日市值
        市值选择前N只
        返回当日选中股票列表
    """
    
    def __init__(
        self,
        filter_config: StockFilterConfig,
        selector_config: Optional[MarketCapSelectorConfig],
        data_router: DataRouter,  # 用于获取 daily_basic 数据
        code_col: str = "ts_code",
        date_col: str = "trade_date",
    ):
        self.filter_config = filter_config
        self.selector_config = selector_config
        self.data_router = data_router
        self.code_col = code_col
        self.date_col = date_col
    
    def filter_and_select_daily(self, data: pl.DataFrame) -> pl.DataFrame:
        """每日独立筛选股票池
        
        Args:
            data: 因子计算后的全市场数据,必须包含 trade_date 和 ts_code 列
            
        Returns:
            筛选后的数据,仅包含每日选中的股票
            
        注意:
            - 按日期分组处理
            - 市值数据从 daily_basic 独立获取
            - 保持市值数据与特征数据隔离
        """
        dates = data.select(self.date_col).unique().sort(self.date_col)
        
        result_frames = []
        for date in dates.to_series():
            # 获取当日数据
            daily_data = data.filter(pl.col(self.date_col) == date)
            daily_codes = daily_data.select(self.code_col).to_series().to_list()
            
            # 1. 代码过滤
            filtered_codes = self.filter_config.filter_codes(daily_codes)
            
            # 2. 市值选择(如果启用)
            if self.selector_config and self.selector_config.enabled:
                # 从 daily_basic 获取当日市值
                market_caps = self._get_market_caps_for_date(filtered_codes, date)
                selected_codes = self._select_by_market_cap(filtered_codes, market_caps)
            else:
                selected_codes = filtered_codes
            
            # 3. 保留当日选中的股票数据
            daily_selected = daily_data.filter(
                pl.col(self.code_col).is_in(selected_codes)
            )
            result_frames.append(daily_selected)
        
        return pl.concat(result_frames)
    
    def _get_market_caps_for_date(
        self, 
        codes: List[str], 
        date: str
    ) -> Dict[str, float]:
        """从 daily_basic 表获取指定日期的市值数据
        
        Args:
            codes: 股票代码列表
            date: 日期 "YYYYMMDD"
            
        Returns:
            {股票代码: 市值} 的字典
        """
        # 通过 data_router 查询 daily_basic 表
        pass
    
    def _select_by_market_cap(
        self, 
        codes: List[str], 
        market_caps: Dict[str, float]
    ) -> List[str]:
        """根据市值选择股票"""
        if not market_caps:
            return codes
        
        # 按市值排序并选择前N只
        sorted_codes = sorted(
            codes, 
            key=lambda c: market_caps.get(c, 0),
            reverse=not self.selector_config.ascending
        )
        return sorted_codes[:self.selector_config.n]

3.5 模型实现 (components/models/)

LightGBMModel

@register_model("lightgbm")
class LightGBMModel(BaseModel):
    """LightGBM 回归模型"""
    
    name = "lightgbm"
    
    def __init__(self, 
                 objective: str = "regression",
                 metric: str = "rmse",
                 num_leaves: int = 31,
                 learning_rate: float = 0.05,
                 n_estimators: int = 100,
                 **kwargs):
        self.params = {
            "objective": objective,
            "metric": metric,
            "num_leaves": num_leaves,
            "learning_rate": learning_rate,
            "n_estimators": n_estimators,
            **kwargs
        }
        self.model = None
    
    def fit(self, X: pl.DataFrame, y: pl.Series) -> "LightGBMModel":
        """训练模型"""
        import lightgbm as lgb
        
        # 转换为 numpy
        X_np = X.to_numpy()
        y_np = y.to_numpy()
        
        # 创建数据集
        train_data = lgb.Dataset(X_np, label=y_np)
        
        # 训练
        self.model = lgb.train(
            self.params,
            train_data,
            num_boost_round=self.params.get("n_estimators", 100)
        )
        return self
    
    def predict(self, X: pl.DataFrame) -> np.ndarray:
        """预测"""
        if self.model is None:
            raise RuntimeError("Model not fitted yet")
        X_np = X.to_numpy()
        return self.model.predict(X_np)
    
    def feature_importance(self) -> pd.Series:
        """返回特征重要性"""
        if self.model is None:
            return None
        importance = self.model.feature_importance(importance_type="gain")
        return pd.Series(importance, index=self.feature_names_)
    
    def save(self, path: str) -> None:
        """保存模型(使用 LightGBM 原生格式)"""
        if self.model is None:
            raise RuntimeError("Model not fitted yet")
        self.model.save_model(path)
    
    @classmethod
    def load(cls, path: str) -> "LightGBMModel":
        """加载模型"""
        import lightgbm as lgb
        instance = cls()
        instance.model = lgb.Booster(model_file=path)
        return instance

3.6 数据处理器 (components/processors/)

StandardScaler

@register_processor("standard_scaler")
class StandardScaler(BaseProcessor):
    """标准化处理器(时序标准化)
    
    在整个训练集上学习均值和标准差,
    然后应用到训练集和测试集。
    """
    
    name = "standard_scaler"
    
    def __init__(self, exclude_cols: List[str] = None):
        self.exclude_cols = exclude_cols or ["ts_code", "trade_date"]
        self.mean_ = {}
        self.std_ = {}
    
    def fit(self, X: pl.DataFrame) -> "StandardScaler":
        """计算均值和标准差(仅在训练集上)"""
        numeric_cols = [
            c for c in X.columns 
            if c not in self.exclude_cols and X[c].dtype.is_numeric()
        ]
        
        for col in numeric_cols:
            self.mean_[col] = X[col].mean()
            self.std_[col] = X[col].std()
        
        return self
    
    def transform(self, X: pl.DataFrame) -> pl.DataFrame:
        """标准化(使用训练集学到的参数)"""
        expressions = []
        for col in X.columns:
            if col in self.mean_:
                expr = ((pl.col(col) - self.mean_[col]) / self.std_[col]).alias(col)
                expressions.append(expr)
            else:
                expressions.append(pl.col(col))
        
        return X.select(expressions)

CrossSectionalStandardScaler

@register_processor("cs_standard_scaler")
class CrossSectionalStandardScaler(BaseProcessor):
    """截面标准化处理器
    
    每天独立进行标准化:对当天所有股票的某一因子进行标准化。
    
    特点:
        - 不需要 fit每天独立计算当天的均值和标准差
        - 适用于截面因子,消除市值等行业差异
        - 公式z = (x - mean_today) / std_today
    """
    
    name = "cs_standard_scaler"
    
    def __init__(self, exclude_cols: List[str] = None, date_col: str = "trade_date"):
        self.exclude_cols = exclude_cols or ["ts_code", "trade_date"]
        self.date_col = date_col
    
    def transform(self, X: pl.DataFrame) -> pl.DataFrame:
        """截面标准化
        
        按日期分组,每天独立计算均值和标准差并标准化。
        不需要 fit因为每天使用当天的统计量。
        """
        numeric_cols = [
            c for c in X.columns 
            if c not in self.exclude_cols and X[c].dtype.is_numeric()
        ]
        
        # 按日期分组标准化
        result = X.with_columns([
            pl.col(col).mean().over(self.date_col).alias(f"{col}_mean")
            for col in numeric_cols
        ] + [
            pl.col(col).std().over(self.date_col).alias(f"{col}_std")
            for col in numeric_cols
        ])
        
        # 计算标准化值
        for col in numeric_cols:
            result = result.with_columns([
                ((pl.col(col) - pl.col(f"{col}_mean")) / pl.col(f"{col}_std")).alias(col)
            ])
            # 删除中间列
            result = result.drop([f"{col}_mean", f"{col}_std"])
        
        return result

Winsorizer

@register_processor("winsorizer")
class Winsorizer(BaseProcessor):
    """缩尾处理器
    
    对每一列的极端值进行截断处理。
    可以全局截断(在整个训练集上学习分位数),
    也可以截面截断(每天独立处理)。
    """
    
    name = "winsorizer"
    
    def __init__(
        self, 
        lower: float = 0.01, 
        upper: float = 0.99,
        by_date: bool = False,  # True=每天独立缩尾, False=全局缩尾
        date_col: str = "trade_date"
    ):
        self.lower = lower
        self.upper = upper
        self.by_date = by_date
        self.date_col = date_col
        self.bounds_ = {}  # 存储分位数边界(全局模式)
    
    def fit(self, X: pl.DataFrame) -> "Winsorizer":
        """学习分位数边界(仅在全局模式下)"""
        if not self.by_date:
            numeric_cols = [c for c in X.columns if X[c].dtype.is_numeric()]
            for col in numeric_cols:
                self.bounds_[col] = {
                    "lower": X[col].quantile(self.lower),
                    "upper": X[col].quantile(self.upper)
                }
        return self
    
    def transform(self, X: pl.DataFrame) -> pl.DataFrame:
        """缩尾处理"""
        if self.by_date:
            # 每天独立缩尾
            return self._transform_by_date(X)
        else:
            # 全局缩尾
            return self._transform_global(X)
    
    def _transform_global(self, X: pl.DataFrame) -> pl.DataFrame:
        """全局缩尾(使用训练集学到的边界)"""
        expressions = []
        for col in X.columns:
            if col in self.bounds_:
                lower = self.bounds_[col]["lower"]
                upper = self.bounds_[col]["upper"]
                expr = pl.col(col).clip(lower, upper).alias(col)
                expressions.append(expr)
            else:
                expressions.append(pl.col(col))
        return X.select(expressions)
    
    def _transform_by_date(self, X: pl.DataFrame) -> pl.DataFrame:
        """每日独立缩尾"""
        # 按日期分组计算分位数并截断
        # Polars 实现...
        pass

4. 训练流程设计

4.1 Trainer 主类 (core/trainer.py)

class Trainer:
    """训练器主类
    
    整合数据处理、模型训练、评估的完整流程。
    
    关键设计:
    1. 因子先计算(全市场),再筛选股票池(每日独立)
    2. Processor 分阶段行为:训练集 fit_transform测试集 transform
    3. 一次性训练,不滚动
    4. 支持模型持久化
    """
    
    def __init__(
        self,
        model: BaseModel,
        pool_manager: Optional[StockPoolManager] = None,
        processors: List[BaseProcessor] = None,
        splitter: DateSplitter = None,
        target_col: str = "target",
        feature_cols: List[str] = None,
        persist_model: bool = False,
        model_save_path: Optional[str] = None,
    ):
        self.model = model
        self.pool_manager = pool_manager
        self.processors = processors or []
        self.splitter = splitter
        self.target_col = target_col
        self.feature_cols = feature_cols
        self.persist_model = persist_model
        self.model_save_path = model_save_path
        
        # 存储训练后的处理器
        self.fitted_processors: List[BaseProcessor] = []
        self.results: pl.DataFrame = None
    
    def train(self, data: pl.DataFrame) -> "Trainer":
        """执行训练流程
        
        流程:
        1. 股票池每日筛选(如果配置了 pool_manager
        2. 按日期划分训练集/测试集
        3. 训练集processors fit_transform
        4. 训练模型
        5. 测试集processors transform使用训练集学到的参数
        6. 预测并评估
        7. 持久化模型(如果启用)
        
        Args:
            data: 因子计算后的全市场数据
                  必须包含 ts_code 和 trade_date 列
        
        Returns:
            self (支持链式调用)
        """
        # 1. 股票池筛选(每日独立)
        if self.pool_manager:
            print("[筛选] 每日独立筛选股票池...")
            data = self.pool_manager.filter_and_select_daily(data)
        
        # 2. 划分训练/测试集
        print("[划分] 划分训练集和测试集...")
        train_data, test_data = self.splitter.split(data)
        
        # 3. 训练集processors fit_transform
        print("[处理] 处理训练集...")
        for processor in self.processors:
            train_data = processor.fit_transform(train_data)
            self.fitted_processors.append(processor)
        
        # 4. 训练模型
        print("[训练] 训练模型...")
        X_train = train_data.select(self.feature_cols)
        y_train = train_data.select(self.target_col).to_series()
        self.model.fit(X_train, y_train)
        
        # 5. 测试集processors transform
        print("[处理] 处理测试集...")
        for processor in self.fitted_processors:
            test_data = processor.transform(test_data)
        
        # 6. 预测
        print("[预测] 生成预测...")
        X_test = test_data.select(self.feature_cols)
        predictions = self.model.predict(X_test)
        
        # 7. 保存结果
        self.results = test_data.with_columns([
            pl.Series("prediction", predictions)
        ])
        
        # 8. 持久化模型
        if self.persist_model and self.model_save_path:
            print(f"[保存] 保存模型到 {self.model_save_path}...")
            self.save_model(self.model_save_path)
        
        return self
    
    def predict(self, data: pl.DataFrame) -> pl.DataFrame:
        """对新数据进行预测
        
        注意:新数据需要先经过股票池筛选,
        然后使用训练好的 processors 和 model 进行预测。
        """
        # 应用 processors
        for processor in self.fitted_processors:
            data = processor.transform(data)
        
        # 预测
        X = data.select(self.feature_cols)
        predictions = self.model.predict(X)
        
        return data.with_columns([pl.Series("prediction", predictions)])
    
    def get_results(self) -> pl.DataFrame:
        """获取所有预测结果"""
        return self.results
    
    def save_results(self, path: str):
        """保存预测结果到文件"""
        if self.results is not None:
            self.results.write_csv(path)
    
    def save_model(self, path: str):
        """保存模型"""
        self.model.save(path)

4.2 配置类 (config/config.py)

from pydantic import BaseSettings, Field
from typing import List, Dict, Optional
from dataclasses import dataclass

@dataclass
class StockFilterConfig:
    """股票过滤器配置"""
    exclude_cyb: bool = True      # 排除创业板
    exclude_kcb: bool = True      # 排除科创板
    exclude_bj: bool = True       # 排除北交所
    exclude_st: bool = True       # 排除ST股票

@dataclass
class MarketCapSelectorConfig:
    """市值选择器配置"""
    enabled: bool = True          # 是否启用
    n: int = 100                  # 选择前 n 只
    ascending: bool = False       # False=最大市值, True=最小市值
    market_cap_col: str = "total_mv"  # 市值列名(来自 daily_basic

@dataclass
class ProcessorConfig:
    """处理器配置"""
    name: str
    params: Dict = Field(default_factory=dict)

class TrainingConfig(BaseSettings):
    """训练配置类"""
    
    # === 数据配置(必填)===
    feature_cols: List[str] = Field(..., min_items=1)  # 特征列名,至少一个
    target_col: str = "target"      # 目标变量列名
    date_col: str = "trade_date"    # 日期列名
    code_col: str = "ts_code"       # 股票代码列名
    
    # === 日期划分(必填)===
    train_start: str = Field(..., description="训练期开始 YYYYMMDD")
    train_end: str = Field(..., description="训练期结束 YYYYMMDD")
    test_start: str = Field(..., description="测试期开始 YYYYMMDD")
    test_end: str = Field(..., description="测试期结束 YYYYMMDD")
    
    # === 股票池配置 ===
    stock_filter: StockFilterConfig = Field(
        default_factory=lambda: StockFilterConfig(
            exclude_cyb=True,
            exclude_kcb=True,
            exclude_bj=True,
            exclude_st=True,
        )
    )
    stock_selector: Optional[MarketCapSelectorConfig] = Field(
        default_factory=lambda: MarketCapSelectorConfig(
            enabled=True,
            n=100,
            ascending=False,
            market_cap_col="total_mv",
        )
    )
    # 注意:如果 stock_selector = None则跳过市值选择
    
    # === 模型配置 ===
    model_type: str = "lightgbm"
    model_params: Dict = Field(default_factory=dict)
    
    # === 处理器配置 ===
    processors: List[ProcessorConfig] = Field(
        default_factory=lambda: [
            ProcessorConfig(name="winsorizer", params={"lower": 0.01, "upper": 0.99}),
            ProcessorConfig(name="cs_standard_scaler", params={}),
        ]
    )
    
    # === 持久化配置 ===
    persist_model: bool = False           # 默认不持久化
    model_save_path: Optional[str] = None  # 持久化路径
    
    # === 输出配置 ===
    output_dir: str = "output"
    save_predictions: bool = True

5. 使用示例

5.1 基本用法

from src.training import Trainer, LightGBMModel, DateSplitter
from src.training.config import StockFilterConfig, MarketCapSelectorConfig
from src.training.core.stock_pool_manager import StockPoolManager
from src.training.components.processors import Winsorizer, CrossSectionalStandardScaler
from src.factors import FactorEngine
import polars as pl

# 1. 因子计算(全市场数据)
engine = FactorEngine()
all_data = engine.compute(
    factor_names=["factor1", "factor2", "factor3", "future_return_5d"],
    start_date="20200101",
    end_date="20231231",
    # 不指定 stock_codes获取全市场数据
)

# 2. 创建股票池管理器(每日独立筛选)
pool_manager = StockPoolManager(
    filter_config=StockFilterConfig(
        exclude_cyb=True,
        exclude_kcb=True,
        exclude_bj=True,
        exclude_st=True,
    ),
    selector_config=MarketCapSelectorConfig(
        enabled=True,
        n=100,
        ascending=False,  # 最大市值
    ),
    data_router=data_router,  # 用于获取 daily_basic 数据
)

# 3. 创建模型
model = LightGBMModel(
    objective="regression",
    n_estimators=100,
    learning_rate=0.05
)

# 4. 创建处理器
processors = [
    Winsorizer(lower=0.01, upper=0.99, by_date=False),  # 全局缩尾
    CrossSectionalStandardScaler(),  # 截面标准化(每天独立)
]

# 5. 创建划分器
splitter = DateSplitter(
    train_start="20200101",
    train_end="20221231",
    test_start="20230101",
    test_end="20231231"
)

# 6. 创建训练器
trainer = Trainer(
    model=model,
    pool_manager=pool_manager,  # 每日筛选股票池
    processors=processors,
    splitter=splitter,
    target_col="future_return_5d",
    feature_cols=["factor1", "factor2", "factor3"],
    persist_model=True,              # 启用持久化
    model_save_path="output/model.pkl",
)

# 7. 执行训练(传入全市场数据)
trainer.train(all_data)

# 8. 获取结果
results = trainer.get_results()  # 包含预测值

# 9. 保存结果
trainer.save_results("output/predictions.csv")

# 10. 加载模型并预测新数据
loaded_model = LightGBMModel.load("output/model.pkl")
new_predictions = loaded_model.predict(new_data)

5.2 使用配置驱动

from src.training.config import TrainingConfig, StockFilterConfig
from src.training import Trainer
from src.training.registry import ModelRegistry, ProcessorRegistry
from src.training.core.stock_pool_manager import StockPoolManager
from src.factors import FactorEngine

# 1. 配置(必填字段校验)
config = TrainingConfig(
    feature_cols=["factor1", "factor2", "factor3"],
    target_col="future_return_5d",
    train_start="20200101",
    train_end="20221231",
    test_start="20230101",
    test_end="20231231",
    stock_filter=StockFilterConfig(
        exclude_cyb=True,
        exclude_kcb=True,
    ),
    stock_selector=None,  # 跳过市值选择
    persist_model=False,  # 不持久化
)

# 2. 因子计算
engine = FactorEngine()
all_data = engine.compute(
    factor_names=config.feature_cols + [config.target_col],
    start_date=config.train_start,
    end_date=config.test_end,
)

# 3. 从配置创建组件
model = ModelRegistry.get_model(config.model_type)(**config.model_params)
processors = [
    ProcessorRegistry.get_processor(p.name)(**p.params)
    for p in config.processors
]

# 4. 创建股票池管理器
pool_manager = StockPoolManager(
    filter_config=config.stock_filter,
    selector_config=config.stock_selector,
    data_router=data_router,
)

# 5. 创建并运行训练器
trainer = Trainer(
    model=model,
    pool_manager=pool_manager,
    processors=processors,
    splitter=DateSplitter(
        train_start=config.train_start,
        train_end=config.train_end,
        test_start=config.test_start,
        test_end=config.test_end,
    ),
    target_col=config.target_col,
    feature_cols=config.feature_cols,
)

trainer.train(all_data)
results = trainer.get_results()

6. 实现顺序

按以下顺序实现和提交:

Commit 1: 基础架构

  • training/__init__.py
  • training/components/__init__.py
  • training/components/base.pyBaseModel, BaseProcessor含 save/load
  • training/registry.py(组件注册中心)

Commit 2: 数据划分

  • training/components/splitters.pyDateSplitter仅一次性划分

Commit 3: 股票池选择器配置

  • training/components/selectors.pyStockFilterConfig, MarketCapSelectorConfig

Commit 4: 数据处理器

  • training/components/processors/__init__.py
  • training/components/processors/transforms.py
    • Winsorizer
    • StandardScaler
    • CrossSectionalStandardScaler

Commit 5: LightGBM 模型

  • training/components/models/__init__.py
  • training/components/models/lightgbm.py(含 save_model/load_model

Commit 6: 股票池管理器

  • training/core/__init__.py
  • training/core/stock_pool_manager.py(每日独立筛选)

Commit 7: Trainer 训练器

  • training/core/trainer.py

Commit 8: 配置管理

  • training/config/__init__.py
  • training/config/config.pyTrainingConfig含必填校验

Commit 9: 预留实验模块

  • experiment/__init__.py

7. 注意事项

7.1 股票池处理顺序(每日)

当日所有股票数据
    ↓
代码过滤创业板、ST等
    ↓
查询 daily_basic 获取当日市值
    ↓
市值选择前N只
    ↓
返回当日选中股票
  • 每日独立筛选,市值动态变化
  • 市值数据仅从 daily_basic 获取
  • 市值数据绝不混入特征矩阵

7.2 Processor 阶段行为

Processor 训练集 测试集
StandardScaler fit_transform transform使用训练集参数
CrossSectionalStandardScaler transform transform每天独立
Winsorizer (global) fit_transform transform使用训练集参数
Winsorizer (by_date) transform transform每天独立

7.3 依赖关系

  • 使用 Polars 进行数据处理
  • LightGBM 用于模型训练
  • Pydantic 用于配置管理
  • 市值数据来自 daily_basic 表(独立数据源)

7.4 删除的功能

以下原计划在本次实现中删除:

  1. 特征选择processors/selectors.py
  2. 滚动训练WalkForward, ExpandingWindow
  3. 结果分析工具(复杂分析功能)
  4. validator.py, evaluator.py(已删除,不实现 metrics

7.5 新增功能

  1. StockPoolManager(每日独立筛选)
  2. 模型持久化save/load默认关闭
  3. 配置必填校验feature_cols, 日期范围)