"""Probe trainer.

Runs multi-task training (regression + classification) and supports early
stopping on a validation set.
"""
|
||
|
||
from typing import Optional, Tuple
|
||
|
||
import numpy as np
|
||
import polars as pl
|
||
|
||
from src.experiment.probe_selection.lightgbm_classifier import LightGBMClassifier
|
||
from src.training.components.models.lightgbm import LightGBMModel
|
||
|
||
|
||
def split_validation_by_date(
    df: pl.DataFrame,
    date_col: str = "trade_date",
    val_ratio: float = 0.15,
) -> Tuple[pl.DataFrame, pl.DataFrame]:
    """Split a frame into training and validation parts along the date axis.

    The distinct values of ``date_col`` are sorted, and the last
    ``int(n_dates * val_ratio)`` of them (but never fewer than one) are used
    as validation dates. Rows whose date is one of those go to the validation
    frame; rows whose date is not among them go to the training frame.

    Args:
        df: Input data.
        date_col: Name of the date column.
        val_ratio: Fraction of the distinct dates assigned to validation.

    Returns:
        A ``(train_df, val_df)`` tuple.
    """
    unique_dates = df[date_col].unique().sort()

    # Hold out at least one date even when the ratio truncates to zero.
    holdout_count = max(1, int(len(unique_dates) * val_ratio))
    holdout_dates = unique_dates[-holdout_count:]

    # One membership expression drives both halves of the split.
    in_holdout = pl.col(date_col).is_in(holdout_dates)
    val_df = df.filter(in_holdout)
    train_df = df.filter(~in_holdout)

    return train_df, val_df
|
||
|
||
|
||
def create_classification_target(
    df: pl.DataFrame,
    return_col: str,
    date_col: str = "trade_date",
    new_col_name: str = "target_class",
) -> pl.DataFrame:
    """Turn a return column into a binary label relative to the per-date median.

    For every row the return is compared with the median of ``return_col``
    computed over the rows sharing the same ``date_col`` value. The comparison
    is strict (``>``), and the boolean result is cast to ``Int8``.

    Rationale carried over from the original documentation: labelling against
    the same-day cross-section targets stocks that beat that day's market, and
    avoids the class imbalance that bull or bear markets would cause.

    Args:
        df: Input data.
        return_col: Name of the return column.
        date_col: Name of the date column used as the grouping key.
        new_col_name: Name of the label column to add.

    Returns:
        ``df`` with the label column added under ``new_col_name``.
    """
    returns = pl.col(return_col)
    # Median of the returns within each date's cross-section.
    same_day_median = returns.median().over(date_col)
    label = (returns > same_day_median).cast(pl.Int8).alias(new_col_name)
    return df.with_columns(label)
|
||
|
||
|
||
class ProbeTrainer:
    """Probe trainer that fits one regression and one classification model.

    Both models are trained on the same date-based train/validation split,
    and the validation part is passed to each model's ``fit`` as ``eval_set``.
    The default parameter sets contain ``early_stopping_round``; how the model
    wrappers consume it is not visible here.
    """

    def __init__(
        self,
        regression_params: Optional[dict] = None,
        classification_params: Optional[dict] = None,
        validation_ratio: float = 0.15,
        random_state: int = 42,
    ):
        """Initialise the probe trainer.

        Args:
            regression_params: Parameters for the regression model. A falsy
                value (``None`` or an empty dict) selects the built-in
                defaults.
            classification_params: Parameters for the classification model.
                A falsy value selects the built-in defaults.
            validation_ratio: Fraction of distinct dates held out for
                validation.
            random_state: Random seed.
        """
        self.regression_params = regression_params or {
            "objective": "regression",
            "metric": "mae",
            "n_estimators": 500,
            "learning_rate": 0.05,
            "early_stopping_round": 50,
            "verbose": -1,
        }
        self.classification_params = classification_params or {
            "objective": "binary",
            "metric": "auc",
            "n_estimators": 500,
            "learning_rate": 0.05,
            "early_stopping_round": 50,
            "verbose": -1,
        }
        self.validation_ratio = validation_ratio
        # NOTE(review): random_state is stored but nothing in this class
        # forwards it to either model - confirm whether it should be merged
        # into the parameter dicts.
        self.random_state = random_state

        # Fitted model wrappers; both stay None until fit() has run.
        self.regression_model: Optional[LightGBMModel] = None
        self.classification_model: Optional[LightGBMClassifier] = None
        # Split sizes, feature count and (when available) best iterations.
        self.training_info: dict = {}

    def fit(
        self,
        df: "pl.DataFrame",
        feature_cols: list[str],
        target_col_regression: str,
        target_col_classification: Optional[str] = None,
        date_col: str = "trade_date",
    ) -> "ProbeTrainer":
        """Train the regression model and then the classification model.

        Args:
            df: Training data (described by the original documentation as
                including the noise features).
            feature_cols: Feature column names (including the noise columns).
            target_col_regression: Name of the regression target column.
            target_col_classification: Name of the classification target
                column. When omitted, a per-date median label is derived from
                ``target_col_regression``.
            date_col: Name of the date column.

        Returns:
            ``self``, so calls can be chained.
        """
        # Time-based split: the most recent dates become the validation set.
        train_df, val_df = split_validation_by_date(
            df, date_col, self.validation_ratio
        )

        self.training_info = {
            "train_size": len(train_df),
            "val_size": len(val_df),
            "n_features": len(feature_cols),
        }

        self._fit_regression(train_df, val_df, feature_cols, target_col_regression)

        if target_col_classification is None:
            # Derive the label separately on each split. The column name is
            # passed explicitly so it cannot drift from the helper's default.
            target_col_classification = "target_class"
            train_df = create_classification_target(
                train_df,
                target_col_regression,
                date_col,
                new_col_name=target_col_classification,
            )
            val_df = create_classification_target(
                val_df,
                target_col_regression,
                date_col,
                new_col_name=target_col_classification,
            )

        self._fit_classification(
            train_df, val_df, feature_cols, target_col_classification
        )

        return self

    @staticmethod
    def _select_xy(frame: "pl.DataFrame", feature_cols: list[str], target_col: str):
        """Return the feature frame and the target series of ``frame``."""
        return frame.select(feature_cols), frame.select(target_col).to_series()

    def _record_best_iteration(self, info_key: str, wrapper) -> None:
        """Store ``wrapper.model.best_iteration`` under ``info_key`` if present."""
        fitted = wrapper.model
        if hasattr(fitted, "best_iteration"):
            self.training_info[info_key] = fitted.best_iteration

    def _fit_regression(
        self,
        train_df: "pl.DataFrame",
        val_df: "pl.DataFrame",
        feature_cols: list[str],
        target_col: str,
    ):
        """Train the regression model with the validation split as eval set."""
        X_train, y_train = self._select_xy(train_df, feature_cols, target_col)
        X_val, y_val = self._select_xy(val_df, feature_cols, target_col)

        self.regression_model = LightGBMModel(params=self.regression_params)
        self.regression_model.fit(
            X_train,
            y_train,
            eval_set=(X_val, y_val),
        )

        # Keep the early-stopping result when the fitted model exposes one.
        self._record_best_iteration("regression_best_iter", self.regression_model)

    def _fit_classification(
        self,
        train_df: "pl.DataFrame",
        val_df: "pl.DataFrame",
        feature_cols: list[str],
        target_col: str,
    ):
        """Train the classification model with the validation split as eval set."""
        X_train, y_train = self._select_xy(train_df, feature_cols, target_col)
        X_val, y_val = self._select_xy(val_df, feature_cols, target_col)

        self.classification_model = LightGBMClassifier(
            params=self.classification_params
        )
        self.classification_model.fit(
            X_train,
            y_train,
            eval_set=(X_val, y_val),
        )

        # Keep the early-stopping result when the fitted model exposes one.
        self._record_best_iteration(
            "classification_best_iter", self.classification_model
        )

    def get_feature_importance(
        self, importance_type: str = "gain"
    ) -> Tuple[Optional[dict], Optional[dict]]:
        """Return the feature importances of both models.

        Args:
            importance_type: Importance type; only ``"gain"`` is accepted.

        Returns:
            A ``(regression_importance, classification_importance)`` tuple.
            Each entry is the result of ``to_dict()`` on the model's
            importance object, or ``None`` when that model is not fitted or
            reports no importance.

        Raises:
            ValueError: If ``importance_type`` is not ``"gain"``.
        """
        # Raised explicitly rather than asserted: assert statements are
        # removed under ``python -O``, which would silently drop this guard.
        if importance_type != "gain":
            raise ValueError(
                "必须使用 importance_type='gain',split 会被噪音欺骗"
            )

        reg_importance = None
        cls_importance = None

        if self.regression_model is not None:
            # NOTE(review): importance_type is not forwarded here, unlike the
            # classifier call below - confirm the regression wrapper's default.
            imp = self.regression_model.feature_importance()
            if imp is not None:
                reg_importance = imp.to_dict()

        if self.classification_model is not None:
            imp = self.classification_model.feature_importance(importance_type)
            if imp is not None:
                cls_importance = imp.to_dict()

        return reg_importance, cls_importance

    def get_training_info(self) -> dict:
        """Return the training-info dictionary.

        Returns:
            The dictionary held in ``self.training_info`` (the same object,
            not a copy).
        """
        return self.training_info
|