254 lines
7.6 KiB
Python
254 lines
7.6 KiB
Python
|
|
"""探针训练器
|
|||
|
|
|
|||
|
|
执行多任务训练(回归 + 分类),支持验证集早停。
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
from typing import Optional, Tuple
|
|||
|
|
|
|||
|
|
import numpy as np
|
|||
|
|
import polars as pl
|
|||
|
|
|
|||
|
|
from src.experiment.probe_selection.lightgbm_classifier import LightGBMClassifier
|
|||
|
|
from src.training.components.models.lightgbm import LightGBMModel
|
|||
|
|
|
|||
|
|
|
|||
|
|
def split_validation_by_date(
    df: pl.DataFrame,
    date_col: str = "trade_date",
    val_ratio: float = 0.15,
) -> Tuple[pl.DataFrame, pl.DataFrame]:
    """Split a frame chronologically into training and validation parts.

    The most recent distinct dates form the validation set; rows on every
    other date go to the training set.

    Args:
        df: Input data.
        date_col: Name of the date column.
        val_ratio: Fraction of the distinct dates assigned to validation.
            The resulting count is truncated to an integer and floored at
            one date.

    Returns:
        A ``(train_df, val_df)`` tuple.
    """
    unique_dates = df.get_column(date_col).unique().sort()

    # Truncate toward zero, but never ask for fewer than one held-out date.
    holdout_count = max(1, int(unique_dates.len() * val_ratio))
    holdout_dates = unique_dates.tail(holdout_count)

    # Build the membership test once and use it (negated) for both halves.
    in_holdout = pl.col(date_col).is_in(holdout_dates)
    return df.filter(~in_holdout), df.filter(in_holdout)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def create_classification_target(
    df: pl.DataFrame,
    return_col: str,
    date_col: str = "trade_date",
    new_col_name: str = "target_class",
) -> pl.DataFrame:
    """Turn returns into a label relative to the same-date median return.

    Advantage: the label targets stocks that beat the market's typical
    level on that day, i.e. factors with genuine alpha.
    Avoids: classification breaking down because bull or bear markets make
    the classes unbalanced.

    Args:
        df: Input data.
        return_col: Name of the return column.
        date_col: Name of the date column used to form each cross-section.
        new_col_name: Name of the label column to add.

    Returns:
        ``df`` with an additional ``Int8`` column ``new_col_name`` holding
        the result of ``return > median(return) within the same date``.
    """
    returns = pl.col(return_col)
    # Median computed per date, broadcast back onto every row of that date.
    same_day_median = returns.median().over(date_col)
    label = (returns > same_day_median).cast(pl.Int8)
    return df.with_columns(label.alias(new_col_name))
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ProbeTrainer:
    """Probe trainer.

    Trains two models on the same feature set -- a regressor and a binary
    classifier -- each evaluated against a validation set made of the most
    recent dates so that early stopping can be applied.
    """

    # Keys LightGBM documents for the master random seed (``seed`` plus its
    # aliases). Used to avoid overriding a seed the caller already chose.
    _SEED_KEYS = ("seed", "random_seed", "random_state")

    def __init__(
        self,
        regression_params: Optional[dict] = None,
        classification_params: Optional[dict] = None,
        validation_ratio: float = 0.15,
        random_state: int = 42,
    ):
        """Initialise the probe trainer.

        Args:
            regression_params: Parameters for the regression model. When
                omitted (or empty) a default MAE-regression setup is used.
            classification_params: Parameters for the classification model.
                When omitted (or empty) a default binary/AUC setup is used.
            validation_ratio: Fraction of distinct dates held out for
                validation.
            random_state: Random seed. It is added to both parameter sets
                under the ``random_state`` key unless the set already
                contains a seed (``seed``/``random_seed``/``random_state``).
        """
        # The supplied dicts are copied, so the caller's objects are never
        # mutated and later external changes do not leak into the trainer.
        self.regression_params = self._seeded_params(
            regression_params
            or {
                "objective": "regression",
                "metric": "mae",
                "n_estimators": 500,
                "learning_rate": 0.05,
                "early_stopping_round": 50,
                "verbose": -1,
            },
            random_state,
        )
        self.classification_params = self._seeded_params(
            classification_params
            or {
                "objective": "binary",
                "metric": "auc",
                "n_estimators": 500,
                "learning_rate": 0.05,
                "early_stopping_round": 50,
                "verbose": -1,
            },
            random_state,
        )
        self.validation_ratio = validation_ratio
        self.random_state = random_state

        # Populated by fit().
        self.regression_model: Optional[LightGBMModel] = None
        self.classification_model: Optional[LightGBMClassifier] = None
        self.training_info: dict = {}

    @staticmethod
    def _seeded_params(params: dict, random_state: int) -> dict:
        """Return a copy of ``params`` that carries a random seed."""
        seeded = dict(params)
        if not any(key in seeded for key in ProbeTrainer._SEED_KEYS):
            seeded["random_state"] = random_state
        return seeded

    @staticmethod
    def _features_and_target(
        df: pl.DataFrame,
        feature_cols: list[str],
        target_col: str,
    ):
        """Return ``(feature frame, target series)`` selected from ``df``."""
        return df.select(feature_cols), df.select(target_col).to_series()

    def fit(
        self,
        df: pl.DataFrame,
        feature_cols: list[str],
        target_col_regression: str,
        target_col_classification: Optional[str] = None,
        date_col: str = "trade_date",
    ) -> "ProbeTrainer":
        """Train the regression and classification models.

        Args:
            df: Training data (including the noise features).
            feature_cols: Feature column names (including noise columns).
            target_col_regression: Name of the regression target column.
            target_col_classification: Name of the classification target
                column. When omitted, a label is generated from the
                regression target via ``create_classification_target``.
            date_col: Name of the date column.

        Returns:
            self
        """
        # Time-based split into training and validation sets.
        train_df, val_df = split_validation_by_date(
            df, date_col, self.validation_ratio
        )

        self.training_info = {
            "train_size": len(train_df),
            "val_size": len(val_df),
            "n_features": len(feature_cols),
        }

        self._fit_regression(train_df, val_df, feature_cols, target_col_regression)

        if target_col_classification is None:
            # Auto-generate the cross-sectional median label. The column
            # name is passed explicitly so it cannot drift from the
            # helper's default.
            target_col_classification = "target_class"
            train_df = create_classification_target(
                train_df,
                target_col_regression,
                date_col,
                target_col_classification,
            )
            val_df = create_classification_target(
                val_df,
                target_col_regression,
                date_col,
                target_col_classification,
            )

        self._fit_classification(
            train_df, val_df, feature_cols, target_col_classification
        )

        return self

    def _fit_regression(
        self,
        train_df: pl.DataFrame,
        val_df: pl.DataFrame,
        feature_cols: list[str],
        target_col: str,
    ):
        """Train the regression model and record its best iteration."""
        X_train, y_train = self._features_and_target(
            train_df, feature_cols, target_col
        )
        X_val, y_val = self._features_and_target(val_df, feature_cols, target_col)

        self.regression_model = LightGBMModel(params=self.regression_params)
        self.regression_model.fit(
            X_train,
            y_train,
            eval_set=(X_val, y_val),
        )

        # Record early-stopping info when the underlying model exposes it.
        if hasattr(self.regression_model.model, "best_iteration"):
            self.training_info["regression_best_iter"] = (
                self.regression_model.model.best_iteration
            )

    def _fit_classification(
        self,
        train_df: pl.DataFrame,
        val_df: pl.DataFrame,
        feature_cols: list[str],
        target_col: str,
    ):
        """Train the classification model and record its best iteration."""
        X_train, y_train = self._features_and_target(
            train_df, feature_cols, target_col
        )
        X_val, y_val = self._features_and_target(val_df, feature_cols, target_col)

        self.classification_model = LightGBMClassifier(
            params=self.classification_params
        )
        self.classification_model.fit(
            X_train,
            y_train,
            eval_set=(X_val, y_val),
        )

        # Record early-stopping info when the underlying model exposes it.
        if hasattr(self.classification_model.model, "best_iteration"):
            self.training_info["classification_best_iter"] = (
                self.classification_model.model.best_iteration
            )

    def get_feature_importance(
        self, importance_type: str = "gain"
    ) -> Tuple[Optional[dict], Optional[dict]]:
        """Return the feature importances of both models.

        Args:
            importance_type: Importance type; must be ``"gain"``.

        Returns:
            A ``(regression_importance, classification_importance)`` tuple.
            Each entry is the ``to_dict()`` of the corresponding model's
            importance (intended as ``{feature_name: importance_value}``),
            or ``None`` when that model is not trained or reports no
            importance.

        Raises:
            AssertionError: If ``importance_type`` is not ``"gain"``.
        """
        if importance_type != "gain":
            # Raised explicitly rather than via ``assert`` so the check is
            # not stripped under ``python -O``; the exception type is kept
            # for backward compatibility. The message says: gain must be
            # used because split-based importance is fooled by noise.
            raise AssertionError(
                "必须使用 importance_type='gain',split 会被噪音欺骗"
            )

        reg_importance = None
        cls_importance = None

        if self.regression_model is not None:
            # NOTE(review): importance_type is not forwarded here;
            # presumably the regression wrapper defaults to gain -- confirm.
            imp = self.regression_model.feature_importance()
            if imp is not None:
                reg_importance = imp.to_dict()

        if self.classification_model is not None:
            imp = self.classification_model.feature_importance(importance_type)
            if imp is not None:
                cls_importance = imp.to_dict()

        return reg_importance, cls_importance

    def get_training_info(self) -> dict:
        """Return the training-info dictionary.

        Returns:
            The dictionary filled in by ``fit``: ``train_size``,
            ``val_size``, ``n_features`` and, when the models expose it,
            ``regression_best_iter`` / ``classification_best_iter``. It is
            empty before ``fit`` has been called.
        """
        return self.training_info
|