feat(probe-selection): 添加探针法因子筛选模块

This commit is contained in:
2026-03-14 22:50:32 +08:00
parent bdf937086f
commit 5541373ded
9 changed files with 1491 additions and 0 deletions

View File

@@ -0,0 +1,253 @@
"""探针训练器
执行多任务训练(回归 + 分类),支持验证集早停。
"""
from typing import Optional, Tuple
import numpy as np
import polars as pl
from src.experiment.probe_selection.lightgbm_classifier import LightGBMClassifier
from src.training.components.models.lightgbm import LightGBMModel
def split_validation_by_date(
    df: pl.DataFrame,
    date_col: str = "trade_date",
    val_ratio: float = 0.15,
) -> Tuple[pl.DataFrame, pl.DataFrame]:
    """Split a frame chronologically into training and validation parts.

    The distinct values of ``date_col`` are sorted, and the last
    ``max(1, int(n_dates * val_ratio))`` of them form the validation
    period. Rows whose date is not in that period go to the training part.

    Args:
        df: Input data containing ``date_col``.
        date_col: Name of the date column.
        val_ratio: Fraction of the distinct dates assigned to validation.

    Returns:
        A ``(train_df, val_df)`` tuple.
    """
    ordered_dates = df[date_col].unique().sort()

    # Hold out at least one date even when the ratio rounds down to zero.
    holdout_count = max(1, int(len(ordered_dates) * val_ratio))
    holdout_dates = ordered_dates[-holdout_count:]

    # One membership expression drives both halves of the split.
    in_holdout = pl.col(date_col).is_in(holdout_dates)
    val_df = df.filter(in_holdout)
    train_df = df.filter(~in_holdout)
    return train_df, val_df
def create_classification_target(
    df: pl.DataFrame,
    return_col: str,
    date_col: str = "trade_date",
    new_col_name: str = "target_class",
) -> pl.DataFrame:
    """Label each row by whether its return beats the same-day median.

    The label is 1 when the row's ``return_col`` value is strictly greater
    than the median of ``return_col`` over all rows sharing its
    ``date_col`` value, and 0 otherwise. Comparing against the
    cross-sectional median (rather than against zero) is meant to keep
    the two classes balanced in both rising and falling markets.

    Args:
        df: Input data.
        return_col: Name of the return column.
        date_col: Name of the date column used to form the cross-sections.
        new_col_name: Name of the label column to add.

    Returns:
        ``df`` with the Int8 label column ``new_col_name`` added.
    """
    returns = pl.col(return_col)
    same_day_median = returns.median().over(date_col)
    beats_median = (returns > same_day_median).cast(pl.Int8)
    return df.with_columns(beats_median.alias(new_col_name))
class ProbeTrainer:
    """Probe trainer that fits one regression and one classification model.

    Both models are trained on the same feature columns. The most recent
    dates (controlled by ``validation_ratio``) are held out and passed to
    each model's ``fit`` as ``eval_set``; the default parameter sets ask
    for ``early_stopping_round=50``.
    """

    # Names under which LightGBM accepts a seed; if the caller already set
    # one of them, ``random_state`` must not be added as a conflicting alias.
    _SEED_PARAM_ALIASES = ("seed", "random_seed", "random_state")

    def __init__(
        self,
        regression_params: Optional[dict] = None,
        classification_params: Optional[dict] = None,
        validation_ratio: float = 0.15,
        random_state: int = 42,
    ):
        """Initialise the trainer.

        Args:
            regression_params: Parameters for the regression model. When
                omitted (or empty) the built-in defaults are used.
            classification_params: Parameters for the classification model.
                When omitted (or empty) the built-in defaults are used.
            validation_ratio: Fraction of distinct dates held out for
                validation.
            random_state: Random seed. It is written into both parameter
                dicts as ``random_state`` unless the dict already carries
                a seed under one of the LightGBM seed aliases.
        """
        # Caller dicts are copied so that seeding them below never mutates
        # an object the caller still owns.
        self.regression_params = (
            dict(regression_params)
            if regression_params
            else {
                "objective": "regression",
                "metric": "mae",
                "n_estimators": 500,
                "learning_rate": 0.05,
                "early_stopping_round": 50,
                "verbose": -1,
            }
        )
        self.classification_params = (
            dict(classification_params)
            if classification_params
            else {
                "objective": "binary",
                "metric": "auc",
                "n_estimators": 500,
                "learning_rate": 0.05,
                "early_stopping_round": 50,
                "verbose": -1,
            }
        )
        self.validation_ratio = validation_ratio
        self.random_state = random_state

        # Previously the seed was stored but never reached the models.
        for params in (self.regression_params, self.classification_params):
            if not any(alias in params for alias in self._SEED_PARAM_ALIASES):
                params["random_state"] = random_state

        self.regression_model: Optional[LightGBMModel] = None
        self.classification_model: Optional[LightGBMClassifier] = None
        self.training_info: dict = {}

    def fit(
        self,
        df: pl.DataFrame,
        feature_cols: list[str],
        target_col_regression: str,
        target_col_classification: Optional[str] = None,
        date_col: str = "trade_date",
    ) -> "ProbeTrainer":
        """Train the regression model and then the classification model.

        Args:
            df: Training data (including the noise features).
            feature_cols: Feature column names (including the noise ones).
            target_col_regression: Name of the regression target column.
            target_col_classification: Name of the classification target
                column. When ``None``, a ``target_class`` column is derived
                from the regression target via
                ``create_classification_target``.
            date_col: Name of the date column used for the time split.

        Returns:
            ``self``, to allow call chaining.

        Raises:
            ValueError: If the time split leaves the training or the
                validation partition without rows.
        """
        # Time-ordered split: the latest dates become the validation set.
        train_df, val_df = split_validation_by_date(
            df, date_col, self.validation_ratio
        )
        self.training_info = {
            "train_size": len(train_df),
            "val_size": len(val_df),
            "n_features": len(feature_cols),
        }

        # Fail early with a readable message instead of letting the model
        # library choke on an empty partition (e.g. only one distinct date).
        if len(train_df) == 0 or len(val_df) == 0:
            raise ValueError(
                f"Time split on '{date_col}' produced an empty partition "
                f"(train rows: {len(train_df)}, "
                f"validation rows: {len(val_df)}); provide at least two "
                "distinct dates and a validation_ratio below 1."
            )

        self._fit_regression(
            train_df, val_df, feature_cols, target_col_regression
        )

        if target_col_classification is None:
            # Derive the cross-sectional-median label on each partition.
            train_df = create_classification_target(
                train_df, target_col_regression, date_col
            )
            val_df = create_classification_target(
                val_df, target_col_regression, date_col
            )
            target_col_classification = "target_class"

        self._fit_classification(
            train_df, val_df, feature_cols, target_col_classification
        )
        return self

    @staticmethod
    def _select_xy(
        frame: pl.DataFrame,
        feature_cols: list[str],
        target_col: str,
    ):
        """Return the feature frame and the target series of ``frame``."""
        return frame.select(feature_cols), frame.select(target_col).to_series()

    def _record_best_iteration(self, key: str, fitted_model) -> None:
        """Store ``fitted_model.best_iteration`` under ``key`` if it exists."""
        if hasattr(fitted_model, "best_iteration"):
            self.training_info[key] = fitted_model.best_iteration

    def _fit_regression(
        self,
        train_df: pl.DataFrame,
        val_df: pl.DataFrame,
        feature_cols: list[str],
        target_col: str,
    ):
        """Train the regression model with the validation set as eval_set."""
        X_train, y_train = self._select_xy(train_df, feature_cols, target_col)
        X_val, y_val = self._select_xy(val_df, feature_cols, target_col)

        self.regression_model = LightGBMModel(params=self.regression_params)
        self.regression_model.fit(
            X_train,
            y_train,
            eval_set=(X_val, y_val),
        )
        self._record_best_iteration(
            "regression_best_iter", self.regression_model.model
        )

    def _fit_classification(
        self,
        train_df: pl.DataFrame,
        val_df: pl.DataFrame,
        feature_cols: list[str],
        target_col: str,
    ):
        """Train the classification model with the validation set as eval_set."""
        X_train, y_train = self._select_xy(train_df, feature_cols, target_col)
        X_val, y_val = self._select_xy(val_df, feature_cols, target_col)

        self.classification_model = LightGBMClassifier(
            params=self.classification_params
        )
        self.classification_model.fit(
            X_train,
            y_train,
            eval_set=(X_val, y_val),
        )
        self._record_best_iteration(
            "classification_best_iter", self.classification_model.model
        )

    def get_feature_importance(
        self, importance_type: str = "gain"
    ) -> Tuple[Optional[dict], Optional[dict]]:
        """Return the feature importances of both models.

        Args:
            importance_type: Importance type; only ``"gain"`` is accepted.

        Returns:
            A ``(regression_importance, classification_importance)`` tuple.
            Each entry is a ``{feature_name: importance_value}`` dict, or
            ``None`` when the model is not trained or reports nothing.

        Raises:
            AssertionError: If ``importance_type`` is not ``"gain"``.
        """
        # Raised explicitly rather than via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type is unchanged.
        if importance_type != "gain":
            raise AssertionError(
                "必须使用 importance_type='gain'split 会被噪音欺骗"
            )

        reg_importance = None
        cls_importance = None

        if self.regression_model is not None:
            # NOTE(review): importance_type is not forwarded here; presumably
            # the regression wrapper reports gain by default - confirm.
            imp = self.regression_model.feature_importance()
            if imp is not None:
                reg_importance = imp.to_dict()

        if self.classification_model is not None:
            imp = self.classification_model.feature_importance(importance_type)
            if imp is not None:
                cls_importance = imp.to_dict()

        return reg_importance, cls_importance

    def get_training_info(self) -> dict:
        """Return the training-info dict filled in by ``fit``.

        Returns:
            The dict holding partition sizes, feature count and, when the
            fitted models expose it, each model's best iteration.
        """
        return self.training_info