Files
ProStock/src/experiment/probe_selection/probe_trainer.py

254 lines
7.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""探针训练器
执行多任务训练(回归 + 分类),支持验证集早停。
"""
from typing import Optional, Tuple
import numpy as np
import polars as pl
from src.experiment.probe_selection.lightgbm_classifier import LightGBMClassifier
from src.training.components.models.lightgbm import LightGBMModel
def split_validation_by_date(
df: pl.DataFrame,
date_col: str = "trade_date",
val_ratio: float = 0.15,
) -> Tuple[pl.DataFrame, pl.DataFrame]:
"""按时间切分训练集和验证集(最近日期作为验证集)
Args:
df: 输入数据
date_col: 日期列名
val_ratio: 验证集比例
Returns:
(train_df, val_df) 元组
"""
dates = df[date_col].unique().sort()
n_dates = len(dates)
n_val_dates = max(1, int(n_dates * val_ratio))
val_dates = dates[-n_val_dates:]
train_df = df.filter(~pl.col(date_col).is_in(val_dates))
val_df = df.filter(pl.col(date_col).is_in(val_dates))
return train_df, val_df
def create_classification_target(
df: pl.DataFrame,
return_col: str,
date_col: str = "trade_date",
new_col_name: str = "target_class",
) -> pl.DataFrame:
"""将收益率转换为截面中位数分类标签
优势:预测跑赢当天市场平均水平的股票,真正有 Alpha 的因子
避免:牛熊市不平衡导致的分类失效
Args:
df: 输入数据
return_col: 收益率列名
date_col: 日期列名
new_col_name: 新列名
Returns:
添加了分类标签的 DataFrame
"""
return df.with_columns(
(pl.col(return_col) > pl.col(return_col).median().over(date_col))
.cast(pl.Int8)
.alias(new_col_name)
)
class ProbeTrainer:
"""探针训练器
执行多任务训练(回归 + 分类),基于验证集早停。
"""
def __init__(
self,
regression_params: Optional[dict] = None,
classification_params: Optional[dict] = None,
validation_ratio: float = 0.15,
random_state: int = 42,
):
"""初始化探针训练器
Args:
regression_params: 回归模型参数
classification_params: 分类模型参数
validation_ratio: 验证集比例
random_state: 随机种子
"""
self.regression_params = regression_params or {
"objective": "regression",
"metric": "mae",
"n_estimators": 500,
"learning_rate": 0.05,
"early_stopping_round": 50,
"verbose": -1,
}
self.classification_params = classification_params or {
"objective": "binary",
"metric": "auc",
"n_estimators": 500,
"learning_rate": 0.05,
"early_stopping_round": 50,
"verbose": -1,
}
self.validation_ratio = validation_ratio
self.random_state = random_state
self.regression_model: Optional[LightGBMModel] = None
self.classification_model: Optional[LightGBMClassifier] = None
self.training_info: dict = {}
def fit(
self,
df: pl.DataFrame,
feature_cols: list[str],
target_col_regression: str,
target_col_classification: Optional[str] = None,
date_col: str = "trade_date",
) -> "ProbeTrainer":
"""训练回归和分类模型
Args:
df: 训练数据(包含噪音特征)
feature_cols: 特征列名列表(包含噪音)
target_col_regression: 回归目标列名
target_col_classification: 分类目标列名(如不传则自动生成)
date_col: 日期列名
Returns:
self
"""
# 切分训练集和验证集(按时间)
train_df, val_df = split_validation_by_date(df, date_col, self.validation_ratio)
self.training_info = {
"train_size": len(train_df),
"val_size": len(val_df),
"n_features": len(feature_cols),
}
# 训练回归模型
self._fit_regression(train_df, val_df, feature_cols, target_col_regression)
# 准备分类目标
if target_col_classification is None:
# 自动生成截面中位数分类目标
train_df = create_classification_target(
train_df, target_col_regression, date_col
)
val_df = create_classification_target(
val_df, target_col_regression, date_col
)
target_col_classification = "target_class"
# 训练分类模型
self._fit_classification(
train_df, val_df, feature_cols, target_col_classification
)
return self
def _fit_regression(
self,
train_df: pl.DataFrame,
val_df: pl.DataFrame,
feature_cols: list[str],
target_col: str,
):
"""训练回归模型"""
X_train = train_df.select(feature_cols)
y_train = train_df.select(target_col).to_series()
X_val = val_df.select(feature_cols)
y_val = val_df.select(target_col).to_series()
self.regression_model = LightGBMModel(params=self.regression_params)
self.regression_model.fit(
X_train,
y_train,
eval_set=(X_val, y_val),
)
# 获取早停信息
if hasattr(self.regression_model.model, "best_iteration"):
self.training_info["regression_best_iter"] = (
self.regression_model.model.best_iteration
)
def _fit_classification(
self,
train_df: pl.DataFrame,
val_df: pl.DataFrame,
feature_cols: list[str],
target_col: str,
):
"""训练分类模型"""
X_train = train_df.select(feature_cols)
y_train = train_df.select(target_col).to_series()
X_val = val_df.select(feature_cols)
y_val = val_df.select(target_col).to_series()
self.classification_model = LightGBMClassifier(
params=self.classification_params
)
self.classification_model.fit(
X_train,
y_train,
eval_set=(X_val, y_val),
)
# 获取早停信息
if hasattr(self.classification_model.model, "best_iteration"):
self.training_info["classification_best_iter"] = (
self.classification_model.model.best_iteration
)
def get_feature_importance(
self, importance_type: str = "gain"
) -> Tuple[Optional[dict], Optional[dict]]:
"""获取两个模型的特征重要性
Args:
importance_type: 重要性类型,必须传入 "gain"
Returns:
(regression_importance, classification_importance) 元组
每个重要性为 {feature_name: importance_value} 字典
"""
assert importance_type == "gain", (
"必须使用 importance_type='gain'split 会被噪音欺骗"
)
reg_importance = None
cls_importance = None
if self.regression_model is not None:
imp = self.regression_model.feature_importance()
if imp is not None:
reg_importance = imp.to_dict()
if self.classification_model is not None:
imp = self.classification_model.feature_importance(importance_type)
if imp is not None:
cls_importance = imp.to_dict()
return reg_importance, cls_importance
def get_training_info(self) -> dict:
"""获取训练信息
Returns:
训练信息字典
"""
return self.training_info