feat(probe-selection): 添加探针法因子筛选模块
This commit is contained in:
253
src/experiment/probe_selection/probe_trainer.py
Normal file
253
src/experiment/probe_selection/probe_trainer.py
Normal file
@@ -0,0 +1,253 @@
|
||||
"""探针训练器
|
||||
|
||||
执行多任务训练(回归 + 分类),支持验证集早停。
|
||||
"""
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
|
||||
from src.experiment.probe_selection.lightgbm_classifier import LightGBMClassifier
|
||||
from src.training.components.models.lightgbm import LightGBMModel
|
||||
|
||||
|
||||
def split_validation_by_date(
|
||||
df: pl.DataFrame,
|
||||
date_col: str = "trade_date",
|
||||
val_ratio: float = 0.15,
|
||||
) -> Tuple[pl.DataFrame, pl.DataFrame]:
|
||||
"""按时间切分训练集和验证集(最近日期作为验证集)
|
||||
|
||||
Args:
|
||||
df: 输入数据
|
||||
date_col: 日期列名
|
||||
val_ratio: 验证集比例
|
||||
|
||||
Returns:
|
||||
(train_df, val_df) 元组
|
||||
"""
|
||||
dates = df[date_col].unique().sort()
|
||||
n_dates = len(dates)
|
||||
n_val_dates = max(1, int(n_dates * val_ratio))
|
||||
|
||||
val_dates = dates[-n_val_dates:]
|
||||
train_df = df.filter(~pl.col(date_col).is_in(val_dates))
|
||||
val_df = df.filter(pl.col(date_col).is_in(val_dates))
|
||||
|
||||
return train_df, val_df
|
||||
|
||||
|
||||
def create_classification_target(
|
||||
df: pl.DataFrame,
|
||||
return_col: str,
|
||||
date_col: str = "trade_date",
|
||||
new_col_name: str = "target_class",
|
||||
) -> pl.DataFrame:
|
||||
"""将收益率转换为截面中位数分类标签
|
||||
|
||||
优势:预测跑赢当天市场平均水平的股票,真正有 Alpha 的因子
|
||||
避免:牛熊市不平衡导致的分类失效
|
||||
|
||||
Args:
|
||||
df: 输入数据
|
||||
return_col: 收益率列名
|
||||
date_col: 日期列名
|
||||
new_col_name: 新列名
|
||||
|
||||
Returns:
|
||||
添加了分类标签的 DataFrame
|
||||
"""
|
||||
return df.with_columns(
|
||||
(pl.col(return_col) > pl.col(return_col).median().over(date_col))
|
||||
.cast(pl.Int8)
|
||||
.alias(new_col_name)
|
||||
)
|
||||
|
||||
|
||||
class ProbeTrainer:
|
||||
"""探针训练器
|
||||
|
||||
执行多任务训练(回归 + 分类),基于验证集早停。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
regression_params: Optional[dict] = None,
|
||||
classification_params: Optional[dict] = None,
|
||||
validation_ratio: float = 0.15,
|
||||
random_state: int = 42,
|
||||
):
|
||||
"""初始化探针训练器
|
||||
|
||||
Args:
|
||||
regression_params: 回归模型参数
|
||||
classification_params: 分类模型参数
|
||||
validation_ratio: 验证集比例
|
||||
random_state: 随机种子
|
||||
"""
|
||||
self.regression_params = regression_params or {
|
||||
"objective": "regression",
|
||||
"metric": "mae",
|
||||
"n_estimators": 500,
|
||||
"learning_rate": 0.05,
|
||||
"early_stopping_round": 50,
|
||||
"verbose": -1,
|
||||
}
|
||||
self.classification_params = classification_params or {
|
||||
"objective": "binary",
|
||||
"metric": "auc",
|
||||
"n_estimators": 500,
|
||||
"learning_rate": 0.05,
|
||||
"early_stopping_round": 50,
|
||||
"verbose": -1,
|
||||
}
|
||||
self.validation_ratio = validation_ratio
|
||||
self.random_state = random_state
|
||||
|
||||
self.regression_model: Optional[LightGBMModel] = None
|
||||
self.classification_model: Optional[LightGBMClassifier] = None
|
||||
self.training_info: dict = {}
|
||||
|
||||
def fit(
|
||||
self,
|
||||
df: pl.DataFrame,
|
||||
feature_cols: list[str],
|
||||
target_col_regression: str,
|
||||
target_col_classification: Optional[str] = None,
|
||||
date_col: str = "trade_date",
|
||||
) -> "ProbeTrainer":
|
||||
"""训练回归和分类模型
|
||||
|
||||
Args:
|
||||
df: 训练数据(包含噪音特征)
|
||||
feature_cols: 特征列名列表(包含噪音)
|
||||
target_col_regression: 回归目标列名
|
||||
target_col_classification: 分类目标列名(如不传则自动生成)
|
||||
date_col: 日期列名
|
||||
|
||||
Returns:
|
||||
self
|
||||
"""
|
||||
# 切分训练集和验证集(按时间)
|
||||
train_df, val_df = split_validation_by_date(df, date_col, self.validation_ratio)
|
||||
|
||||
self.training_info = {
|
||||
"train_size": len(train_df),
|
||||
"val_size": len(val_df),
|
||||
"n_features": len(feature_cols),
|
||||
}
|
||||
|
||||
# 训练回归模型
|
||||
self._fit_regression(train_df, val_df, feature_cols, target_col_regression)
|
||||
|
||||
# 准备分类目标
|
||||
if target_col_classification is None:
|
||||
# 自动生成截面中位数分类目标
|
||||
train_df = create_classification_target(
|
||||
train_df, target_col_regression, date_col
|
||||
)
|
||||
val_df = create_classification_target(
|
||||
val_df, target_col_regression, date_col
|
||||
)
|
||||
target_col_classification = "target_class"
|
||||
|
||||
# 训练分类模型
|
||||
self._fit_classification(
|
||||
train_df, val_df, feature_cols, target_col_classification
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def _fit_regression(
|
||||
self,
|
||||
train_df: pl.DataFrame,
|
||||
val_df: pl.DataFrame,
|
||||
feature_cols: list[str],
|
||||
target_col: str,
|
||||
):
|
||||
"""训练回归模型"""
|
||||
X_train = train_df.select(feature_cols)
|
||||
y_train = train_df.select(target_col).to_series()
|
||||
X_val = val_df.select(feature_cols)
|
||||
y_val = val_df.select(target_col).to_series()
|
||||
|
||||
self.regression_model = LightGBMModel(params=self.regression_params)
|
||||
self.regression_model.fit(
|
||||
X_train,
|
||||
y_train,
|
||||
eval_set=(X_val, y_val),
|
||||
)
|
||||
|
||||
# 获取早停信息
|
||||
if hasattr(self.regression_model.model, "best_iteration"):
|
||||
self.training_info["regression_best_iter"] = (
|
||||
self.regression_model.model.best_iteration
|
||||
)
|
||||
|
||||
def _fit_classification(
|
||||
self,
|
||||
train_df: pl.DataFrame,
|
||||
val_df: pl.DataFrame,
|
||||
feature_cols: list[str],
|
||||
target_col: str,
|
||||
):
|
||||
"""训练分类模型"""
|
||||
X_train = train_df.select(feature_cols)
|
||||
y_train = train_df.select(target_col).to_series()
|
||||
X_val = val_df.select(feature_cols)
|
||||
y_val = val_df.select(target_col).to_series()
|
||||
|
||||
self.classification_model = LightGBMClassifier(
|
||||
params=self.classification_params
|
||||
)
|
||||
self.classification_model.fit(
|
||||
X_train,
|
||||
y_train,
|
||||
eval_set=(X_val, y_val),
|
||||
)
|
||||
|
||||
# 获取早停信息
|
||||
if hasattr(self.classification_model.model, "best_iteration"):
|
||||
self.training_info["classification_best_iter"] = (
|
||||
self.classification_model.model.best_iteration
|
||||
)
|
||||
|
||||
def get_feature_importance(
|
||||
self, importance_type: str = "gain"
|
||||
) -> Tuple[Optional[dict], Optional[dict]]:
|
||||
"""获取两个模型的特征重要性
|
||||
|
||||
Args:
|
||||
importance_type: 重要性类型,必须传入 "gain"
|
||||
|
||||
Returns:
|
||||
(regression_importance, classification_importance) 元组
|
||||
每个重要性为 {feature_name: importance_value} 字典
|
||||
"""
|
||||
assert importance_type == "gain", (
|
||||
"必须使用 importance_type='gain',split 会被噪音欺骗"
|
||||
)
|
||||
|
||||
reg_importance = None
|
||||
cls_importance = None
|
||||
|
||||
if self.regression_model is not None:
|
||||
imp = self.regression_model.feature_importance()
|
||||
if imp is not None:
|
||||
reg_importance = imp.to_dict()
|
||||
|
||||
if self.classification_model is not None:
|
||||
imp = self.classification_model.feature_importance(importance_type)
|
||||
if imp is not None:
|
||||
cls_importance = imp.to_dict()
|
||||
|
||||
return reg_importance, cls_importance
|
||||
|
||||
def get_training_info(self) -> dict:
|
||||
"""获取训练信息
|
||||
|
||||
Returns:
|
||||
训练信息字典
|
||||
"""
|
||||
return self.training_info
|
||||
Reference in New Issue
Block a user