"""探针训练器 执行多任务训练(回归 + 分类),支持验证集早停。 """ from typing import Optional, Tuple import numpy as np import polars as pl from src.experiment.probe_selection.lightgbm_classifier import LightGBMClassifier from src.training.components.models.lightgbm import LightGBMModel def split_validation_by_date( df: pl.DataFrame, date_col: str = "trade_date", val_ratio: float = 0.15, ) -> Tuple[pl.DataFrame, pl.DataFrame]: """按时间切分训练集和验证集(最近日期作为验证集) Args: df: 输入数据 date_col: 日期列名 val_ratio: 验证集比例 Returns: (train_df, val_df) 元组 """ dates = df[date_col].unique().sort() n_dates = len(dates) n_val_dates = max(1, int(n_dates * val_ratio)) val_dates = dates[-n_val_dates:] train_df = df.filter(~pl.col(date_col).is_in(val_dates)) val_df = df.filter(pl.col(date_col).is_in(val_dates)) return train_df, val_df def create_classification_target( df: pl.DataFrame, return_col: str, date_col: str = "trade_date", new_col_name: str = "target_class", ) -> pl.DataFrame: """将收益率转换为截面中位数分类标签 优势:预测跑赢当天市场平均水平的股票,真正有 Alpha 的因子 避免:牛熊市不平衡导致的分类失效 Args: df: 输入数据 return_col: 收益率列名 date_col: 日期列名 new_col_name: 新列名 Returns: 添加了分类标签的 DataFrame """ return df.with_columns( (pl.col(return_col) > pl.col(return_col).median().over(date_col)) .cast(pl.Int8) .alias(new_col_name) ) class ProbeTrainer: """探针训练器 执行多任务训练(回归 + 分类),基于验证集早停。 """ def __init__( self, regression_params: Optional[dict] = None, classification_params: Optional[dict] = None, validation_ratio: float = 0.15, random_state: int = 42, ): """初始化探针训练器 Args: regression_params: 回归模型参数 classification_params: 分类模型参数 validation_ratio: 验证集比例 random_state: 随机种子 """ self.regression_params = regression_params or { "objective": "regression", "metric": "mae", "n_estimators": 500, "learning_rate": 0.05, "early_stopping_round": 50, "verbose": -1, } self.classification_params = classification_params or { "objective": "binary", "metric": "auc", "n_estimators": 500, "learning_rate": 0.05, "early_stopping_round": 50, "verbose": -1, } self.validation_ratio = validation_ratio self.random_state = random_state self.regression_model: Optional[LightGBMModel] = None self.classification_model: Optional[LightGBMClassifier] = None self.training_info: dict = {} def fit( self, df: pl.DataFrame, feature_cols: list[str], target_col_regression: str, target_col_classification: Optional[str] = None, date_col: str = "trade_date", ) -> "ProbeTrainer": """训练回归和分类模型 Args: df: 训练数据(包含噪音特征) feature_cols: 特征列名列表(包含噪音) target_col_regression: 回归目标列名 target_col_classification: 分类目标列名(如不传则自动生成) date_col: 日期列名 Returns: self """ # 切分训练集和验证集(按时间) train_df, val_df = split_validation_by_date(df, date_col, self.validation_ratio) self.training_info = { "train_size": len(train_df), "val_size": len(val_df), "n_features": len(feature_cols), } # 训练回归模型 self._fit_regression(train_df, val_df, feature_cols, target_col_regression) # 准备分类目标 if target_col_classification is None: # 自动生成截面中位数分类目标 train_df = create_classification_target( train_df, target_col_regression, date_col ) val_df = create_classification_target( val_df, target_col_regression, date_col ) target_col_classification = "target_class" # 训练分类模型 self._fit_classification( train_df, val_df, feature_cols, target_col_classification ) return self def _fit_regression( self, train_df: pl.DataFrame, val_df: pl.DataFrame, feature_cols: list[str], target_col: str, ): """训练回归模型""" X_train = train_df.select(feature_cols) y_train = train_df.select(target_col).to_series() X_val = val_df.select(feature_cols) y_val = val_df.select(target_col).to_series() self.regression_model = LightGBMModel(params=self.regression_params) self.regression_model.fit( X_train, y_train, eval_set=(X_val, y_val), ) # 获取早停信息 if hasattr(self.regression_model.model, "best_iteration"): self.training_info["regression_best_iter"] = ( self.regression_model.model.best_iteration ) def _fit_classification( self, train_df: pl.DataFrame, val_df: pl.DataFrame, feature_cols: list[str], target_col: str, ): """训练分类模型""" X_train = train_df.select(feature_cols) y_train = train_df.select(target_col).to_series() X_val = val_df.select(feature_cols) y_val = val_df.select(target_col).to_series() self.classification_model = LightGBMClassifier( params=self.classification_params ) self.classification_model.fit( X_train, y_train, eval_set=(X_val, y_val), ) # 获取早停信息 if hasattr(self.classification_model.model, "best_iteration"): self.training_info["classification_best_iter"] = ( self.classification_model.model.best_iteration ) def get_feature_importance( self, importance_type: str = "gain" ) -> Tuple[Optional[dict], Optional[dict]]: """获取两个模型的特征重要性 Args: importance_type: 重要性类型,必须传入 "gain" Returns: (regression_importance, classification_importance) 元组 每个重要性为 {feature_name: importance_value} 字典 """ assert importance_type == "gain", ( "必须使用 importance_type='gain',split 会被噪音欺骗" ) reg_importance = None cls_importance = None if self.regression_model is not None: imp = self.regression_model.feature_importance() if imp is not None: reg_importance = imp.to_dict() if self.classification_model is not None: imp = self.classification_model.feature_importance(importance_type) if imp is not None: cls_importance = imp.to_dict() return reg_importance, cls_importance def get_training_info(self) -> dict: """获取训练信息 Returns: 训练信息字典 """ return self.training_info