feat(training): 实现 Trainer 模块化重构 (Trainer V2)

- 新增 FactorManager 组件:统一管理多种来源因子
- 新增 DataPipeline 组件:完整数据处理流程(注册、过滤、划分、预处理)
- 新增 Task 策略组件:BaseTask 抽象基类、RegressionTask、RankTask
- 新增 ResultAnalyzer 组件:特征重要性分析和结果组装
- 新增 TrainerV2:作为纯调度引擎协调各组件
- 支持回归和排序学习两种训练模式
- 采用组合模式解耦训练流程,消除代码重复
This commit is contained in:
2026-03-24 23:35:31 +08:00
parent bace4cc5f4
commit e41a128ca3
13 changed files with 4045 additions and 1509 deletions

View File

@@ -43,6 +43,12 @@ from src.training.utils import check_data_quality
# 配置
from src.training.config import TrainingConfig
# 新增:模块化 Trainer 组件
from src.training.factor_manager import FactorManager
from src.training.pipeline import DataPipeline
from src.training.result_analyzer import ResultAnalyzer
from src.training.tasks import BaseTask, RegressionTask, RankTask
__all__ = [
# 基础抽象类
"BaseModel",
@@ -74,4 +80,11 @@ __all__ = [
"check_data_quality",
# 配置
"TrainingConfig",
# 新增:模块化 Trainer 组件
"FactorManager",
"DataPipeline",
"ResultAnalyzer",
"BaseTask",
"RegressionTask",
"RankTask",
]

View File

@@ -185,131 +185,6 @@ class LightGBMLambdaRankModel(BaseModel):
return None
return self.model.best_score
def plot_metric(
    self,
    metric: Optional[str] = None,
    figsize: tuple = (10, 6),
    title: Optional[str] = None,
    ax=None,
):
    """Plot the per-iteration curve of one recorded evaluation metric.

    Args:
        metric: Name of the metric to draw, e.g. 'ndcg@5'. When omitted, the
            first recorded "train" metric whose name contains "ndcg" is used,
            falling back to the first recorded "train" metric.
        figsize: Figure size used when a new figure has to be created.
        title: Plot title; a default title is derived from the metric name.
        ax: Existing matplotlib Axes to draw on; a new one is created if None.

    Returns:
        The matplotlib Axes that was drawn on.

    Raises:
        RuntimeError: If the model is not trained or no evaluation results
            were recorded.
        ValueError: If no metric is available or the requested one is unknown.
    """
    if self.model is None:
        raise RuntimeError("模型尚未训练,请先调用 fit()")
    if not self.evals_result_:
        raise RuntimeError("没有可用的评估结果")

    # Heavy plotting dependencies are imported only once they are needed.
    import lightgbm as lgb
    import matplotlib.pyplot as plt

    train_scores = self.evals_result_.get("train", {})
    recorded = list(train_scores.keys())

    if metric is None:
        ndcg_like = [name for name in recorded if "ndcg" in name.lower()]
        if ndcg_like:
            metric = ndcg_like[0]
        elif recorded:
            metric = recorded[0]
        else:
            raise ValueError("没有可用的评估指标")

    if metric not in train_scores:
        raise ValueError(f"指标 '{metric}' 不存在。可用的指标: {recorded}")

    if ax is None:
        ax = plt.subplots(figsize=figsize)[1]

    lgb.plot_metric(self.evals_result_, metric=metric, ax=ax)

    if title is None:
        title = f"Training Metric ({metric.upper()}) over Iterations"
    ax.set_title(title, fontsize=12, fontweight="bold")
    return ax
def plot_all_metrics(
    self,
    metrics: Optional[list] = None,
    figsize: tuple = (14, 10),
    max_cols: int = 2,
):
    """Plot several recorded training metrics on a grid of subplots.

    Args:
        metrics: Metric names to draw. When omitted, up to four recorded
            "train" metrics are used, preferring names containing "ndcg".
        figsize: Size of the whole figure.
        max_cols: Maximum number of subplots per row; must be at least 1.

    Returns:
        The matplotlib Figure holding all subplots.

    Raises:
        RuntimeError: If the model is not trained or no evaluation results
            were recorded.
        ValueError: If ``max_cols`` is smaller than 1 or no metric is available.
    """
    if self.model is None:
        raise RuntimeError("模型尚未训练,请先调用 fit()")
    if not self.evals_result_:
        raise RuntimeError("没有可用的评估结果")
    # A non-positive column count would otherwise surface as ZeroDivisionError.
    if max_cols < 1:
        raise ValueError(f"max_cols 必须 >= 1,当前为 {max_cols}")

    available_metrics = list(self.evals_result_.get("train", {}).keys())
    if metrics is None:
        ndcg_metrics = [m for m in available_metrics if "ndcg" in m.lower()]
        metrics = (ndcg_metrics or available_metrics)[:4]
    if not metrics:
        raise ValueError("没有可用的评估指标")

    # Imported late so argument validation works without a plotting backend.
    import matplotlib.pyplot as plt

    n_metrics = len(metrics)
    n_cols = min(max_cols, n_metrics)
    n_rows = (n_metrics + n_cols - 1) // n_cols

    # squeeze=False always yields a 2-D array of Axes, so a single flatten()
    # covers the 1x1, single-row and multi-row layouts alike.
    fig, axes_grid = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    axes = axes_grid.flatten()

    for ax, metric in zip(axes, metrics):
        if metric in available_metrics:
            self.plot_metric(metric=metric, ax=ax)
            ax.set_title(f"{metric.upper()}", fontsize=11, fontweight="bold")
        else:
            ax.text(
                0.5,
                0.5,
                f"Metric '{metric}' not found",
                ha="center",
                va="center",
                transform=ax.transAxes,
            )

    # Hide the unused cells of the last row.
    for ax in axes[n_metrics:]:
        ax.axis("off")

    plt.tight_layout()
    return fig
def feature_importance(self) -> Optional[pd.Series]:
"""返回特征重要性

View File

@@ -0,0 +1,163 @@
"""因子管理器
管理多种来源的因子:
- metadata 中注册的因子
- DSL 表达式定义的因子
- Label 因子
- 排除的因子列表
"""
from typing import Dict, List, Optional, Any
import polars as pl
from src.factors import FactorEngine
class FactorManager:
    """Register factors from several sources with a FactorEngine.

    Sources handled:
    1. Factors already present in metadata (referenced by name).
    2. Factors defined by DSL expressions (registered dynamically).
    3. The label factor (defined by a DSL expression).
    4. An exclusion list that is removed from the final feature list.

    Attributes:
        selected_factors: Names of factors selected from metadata.
        factor_definitions: DSL-defined factors as {name: dsl_expression}.
        label_factor: Label definition as {name: dsl_expression}.
        excluded_factors: Names removed from the final feature list.
        registered_factors: Feature names produced by the latest call to
            ``register_to_engine``.
        failed_factors: {name: error message} for every registration that
            raised during the latest call to ``register_to_engine``.
    """

    def __init__(
        self,
        selected_factors: List[str],
        factor_definitions: Dict[str, str],
        label_factor: Dict[str, str],
        excluded_factors: Optional[List[str]] = None,
    ):
        """Initialise the manager.

        Args:
            selected_factors: Names of factors selected from metadata.
            factor_definitions: DSL-defined factors as {name: dsl_expression}.
            label_factor: Label definition as {name: dsl_expression}.
            excluded_factors: Names to remove from the final feature list.
        """
        self.selected_factors = selected_factors or []
        self.factor_definitions = factor_definitions or {}
        self.label_factor = label_factor or {}
        self.excluded_factors = excluded_factors or []
        self.registered_factors: List[str] = []
        self.failed_factors: Dict[str, str] = {}

    @staticmethod
    def _try_add(engine, name: str, *dsl: str) -> Optional[Exception]:
        """Call ``engine.add_factor`` and return the exception it raised, if any."""
        try:
            engine.add_factor(name, *dsl)
        except Exception as exc:  # best effort: one bad factor must not stop the rest
            return exc
        return None

    def register_to_engine(
        self,
        engine: FactorEngine,
        verbose: bool = True,
    ) -> List[str]:
        """Register every configured factor with ``engine``.

        Registration order:
        1. Metadata factors (``engine.add_factor(name)``).
        2. DSL-defined factors (``engine.add_factor(name, expr)``), skipping
           names on the exclusion list.
        3. Label factors (registered, but never part of the feature list).
        4. Removal of excluded names from the feature list.

        A registration that raises is skipped and recorded in
        ``failed_factors`` so the failure stays visible even when
        ``verbose`` is False.

        Args:
            engine: FactorEngine instance to register into.
            verbose: Whether to print progress information.

        Returns:
            The final feature names, in registration order, without
            duplicates and without excluded names.
        """
        excluded = set(self.excluded_factors)
        self.failed_factors = {}
        feature_cols: List[str] = []
        seen = set()

        def keep(name: str) -> None:
            # A name listed in both the metadata selection and the DSL
            # definitions must appear only once; duplicate column names
            # would break downstream column selection.
            if name not in seen:
                seen.add(name)
                feature_cols.append(name)

        if verbose:
            print("\n" + "=" * 80)
            print("因子注册")
            print("=" * 80)

        # Step 1: factors referenced by name from metadata.
        if verbose:
            print(f"\n[1/4] 从 metadata 注册 {len(self.selected_factors)} 个因子...")
        for factor_name in self.selected_factors:
            error = self._try_add(engine, factor_name)
            if error is None:
                keep(factor_name)
                if verbose:
                    print(f" [OK] {factor_name}")
            else:
                self.failed_factors[factor_name] = str(error)
                if verbose:
                    print(f" [FAIL] {factor_name}: {error}")

        # Step 2: factors defined by DSL expressions.
        if self.factor_definitions:
            if verbose:
                print(f"\n[2/4] 注册 {len(self.factor_definitions)} 个 DSL 定义因子...")
            for factor_name, dsl_expr in self.factor_definitions.items():
                if factor_name in excluded:
                    continue
                error = self._try_add(engine, factor_name, dsl_expr)
                if error is None:
                    keep(factor_name)
                    if verbose:
                        print(f" [OK] {factor_name}: {dsl_expr[:50]}...")
                else:
                    self.failed_factors[factor_name] = str(error)
                    if verbose:
                        print(f" [FAIL] {factor_name}: {error}")

        # Step 3: label factors are registered but are not features.
        if self.label_factor:
            if verbose:
                print("\n[3/4] 注册 Label 因子...")
            for factor_name, dsl_expr in self.label_factor.items():
                error = self._try_add(engine, factor_name, dsl_expr)
                if error is None:
                    if verbose:
                        print(f" ✓ Label: {factor_name}")
                else:
                    self.failed_factors[factor_name] = str(error)
                    if verbose:
                        print(f" ✗ Label {factor_name}: {error}")

        # Step 4: drop excluded names (only metadata factors can still be
        # present here, DSL factors were already skipped in step 2).
        if self.excluded_factors:
            if verbose:
                print(f"\n[4/4] 排除 {len(self.excluded_factors)} 个因子...")
            original_count = len(feature_cols)
            feature_cols = [f for f in feature_cols if f not in excluded]
            excluded_count = original_count - len(feature_cols)
            if verbose:
                print(f" 排除 {excluded_count} 个因子")
                for f in self.excluded_factors:
                    if f in self.selected_factors or f in self.factor_definitions:
                        print(f" - {f}")

        self.registered_factors = feature_cols

        if verbose:
            print(f"\n[结果] 最终特征数: {len(feature_cols)}")
            print("=" * 80)

        return feature_cols

    def get_feature_cols(self) -> List[str]:
        """Return the feature names produced by the latest registration.

        Returns:
            List of feature column names (empty before registration).
        """
        return self.registered_factors

    def get_label_col(self) -> Optional[str]:
        """Return the label column name.

        Returns:
            The first key of ``label_factor``, or None when no label is
            configured.
        """
        return next(iter(self.label_factor), None)

309
src/training/pipeline.py Normal file
View File

@@ -0,0 +1,309 @@
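"""Data pipeline.
End-to-end data processing flow:
1. Factor registration and data preparation
2. Apply filters (STFilter, etc.)
3. Stock-pool selection (custom function)
4. Data quality check
5. Data split (train/val/test)
6. Preprocessing (fit_transform/transform)
"""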
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
import polars as pl
import numpy as np
from src.factors import FactorEngine
from src.training.factor_manager import FactorManager
from src.training.components.base import BaseProcessor
from src.training.core.stock_pool_manager import StockPoolManager
class DataPipeline:
    """Run the full data preparation flow and return per-split data.

    Steps, in order: factor registration and computation, class-based
    filters, optional stock-pool selection, a null-value report, date-based
    splitting, and preprocessing (fit on train, transform on other splits).

    Attributes:
        factor_manager: Factor manager used to register factors.
        filters: Filter objects exposing ``filter(data)`` (e.g. STFilter).
        stock_pool_filter_func: Optional function-style stock-pool selector.
        processor_configs: Processor configuration as (class, kwargs) pairs.
        stock_pool_required_columns: Extra columns the stock-pool selector needs.
        fitted_processors: Processors fitted on the train split (filled by
            ``prepare_data``).
    """

    def __init__(
        self,
        factor_manager: FactorManager,
        processor_configs: List[Tuple[Type[BaseProcessor], Dict[str, Any]]],
        filters: Optional[List[Any]] = None,
        stock_pool_filter_func: Optional[Callable] = None,
        stock_pool_required_columns: Optional[List[str]] = None,
    ):
        """Initialise the pipeline.

        Args:
            factor_manager: Factor manager instance.
            processor_configs: List of (ProcessorClass, kwargs) pairs, e.g.
                [(NullFiller, {"strategy": "mean"}),
                 (Winsorizer, {"lower": 0.01, "upper": 0.99})].
            filters: Filter objects exposing ``filter(data)``.
            stock_pool_filter_func: Function-style stock-pool selector.
            stock_pool_required_columns: Extra columns needed by the selector.
        """
        self.factor_manager = factor_manager
        self.processor_configs = processor_configs or []
        self.filters = filters or []
        self.stock_pool_filter_func = stock_pool_filter_func
        self.stock_pool_required_columns = stock_pool_required_columns or []
        self.fitted_processors: List[BaseProcessor] = []

    def prepare_data(
        self,
        engine: FactorEngine,
        date_range: Dict[str, Tuple[str, str]],
        label_name: str,
        verbose: bool = True,
    ) -> Dict[str, Dict[str, Any]]:
        """Run the whole flow and return the per-split data dictionary.

        Args:
            engine: FactorEngine instance used to register and compute factors.
            date_range: {"train": (start, end), "val": ..., "test": ...}.
                Only "train" is mandatory; every provided split is honoured.
            label_name: Name of the label column.
            verbose: Whether to print progress information.

        Returns:
            {split_name: {"X", "y", "raw_data", "feature_cols"}}.

        Raises:
            KeyError: If ``date_range`` has no "train" entry.
        """
        # Fail fast: the train split is needed to fit the processors.
        if "train" not in date_range:
            raise KeyError("date_range must contain a 'train' split")

        if verbose:
            print("\n" + "=" * 80)
            print("数据流水线")
            print("=" * 80)

        # Step 1: register factors and compute them over the full span.
        if verbose:
            print("\n[1/6] 注册因子并准备数据...")
        feature_cols = self.factor_manager.register_to_engine(engine, verbose=verbose)

        # Span every provided split instead of hard-requiring val and test,
        # which the split and preprocess steps already treat as optional.
        all_start = min(start for start, _ in date_range.values())
        all_end = max(end for _, end in date_range.values())

        data = engine.compute(
            factor_names=feature_cols + [label_name],
            start_date=all_start,
            end_date=all_end,
        )
        if verbose:
            print(f" 原始数据规模: {data.shape}")
            print(f" 特征数: {len(feature_cols)}")

        # Step 2: class-based filters.
        if self.filters:
            if verbose:
                print(f"\n[2/6] 应用过滤器({len(self.filters)}个)...")
            for filter_obj in self.filters:
                data_before = len(data)
                data = filter_obj.filter(data)
                data_after = len(data)
                if verbose:
                    print(f" {filter_obj.__class__.__name__}:")
                    print(f" 过滤前: {data_before}, 过滤后: {data_after}")
                    print(f" 删除: {data_before - data_after}")

        # Step 3: function-style stock-pool selection.
        if self.stock_pool_filter_func:
            if verbose:
                print("\n[3/6] 股票池筛选...")
            data_before = len(data)
            pool_manager = StockPoolManager(
                filter_func=self.stock_pool_filter_func,
                required_columns=self.stock_pool_required_columns,
                data_router=engine.router,
            )
            data = pool_manager.filter_and_select_daily(data)
            data_after = len(data)
            if verbose:
                print(f" 筛选前: {data_before}, 筛选后: {data_after}")
                print(f" 删除: {data_before - data_after}")

        # Step 4: null-value report.
        if verbose:
            print("\n[4/6] 数据质量检查...")
        self._check_data_quality(data, feature_cols, verbose=verbose)

        # Step 5: split by date range.
        if verbose:
            print("\n[5/6] 数据划分...")
        split_data = self._split_data(
            data, date_range, feature_cols, label_name, verbose=verbose
        )

        # Step 6: fit processors on train, apply to the other splits.
        if verbose:
            print("\n[6/6] 数据预处理...")
        split_data = self._preprocess(split_data, feature_cols, verbose=verbose)

        if verbose:
            print("\n" + "=" * 80)
            print("数据流水线完成")
            print("=" * 80)

        return split_data

    def _check_data_quality(
        self,
        data: pl.DataFrame,
        feature_cols: List[str],
        verbose: bool = True,
    ) -> None:
        """Report null counts for the first ten feature columns.

        Args:
            data: Data frame to inspect.
            feature_cols: Feature column names.
            verbose: Whether to print the report; nothing is reported otherwise.
        """
        null_counts = {}
        for col in feature_cols[:10]:  # only the first 10 features are inspected
            null_count = data[col].null_count()
            if null_count > 0:
                null_counts[col] = null_count

        if null_counts and verbose:
            print(" [警告] 发现缺失值仅显示前10个特征:")
            for col, count in list(null_counts.items())[:5]:
                pct = count / len(data) * 100
                print(f" {col}: {count} ({pct:.2f}%)")

    def _split_data(
        self,
        data: pl.DataFrame,
        date_range: Dict[str, Tuple[str, str]],
        feature_cols: List[str],
        label_name: str,
        verbose: bool = True,
    ) -> Dict[str, Dict[str, Any]]:
        """Split ``data`` into one entry per key of ``date_range``.

        Args:
            data: Full data frame; must contain a "trade_date" column.
            date_range: {split_name: (start, end)}, bounds inclusive.
            feature_cols: Feature column names.
            label_name: Label column name.
            verbose: Whether to print the size of each split.

        Returns:
            {split_name: {"X", "y", "raw_data", "feature_cols"}}.
        """
        result = {}
        for split_name, (start, end) in date_range.items():
            mask = (data["trade_date"] >= start) & (data["trade_date"] <= end)
            split_df = data.filter(mask)
            result[split_name] = {
                "X": split_df.select(feature_cols),
                "y": split_df[label_name],
                "raw_data": split_df,
                "feature_cols": feature_cols,
            }
            if verbose:
                print(f" {split_name}: {len(split_df)} 条记录")
        return result

    def _preprocess(
        self,
        split_data: Dict[str, Dict[str, Any]],
        feature_cols: List[str],
        verbose: bool = True,
    ) -> Dict[str, Dict[str, Any]]:
        """Fit processors on the train split and apply them everywhere else.

        The train split goes through ``fit_transform``; every other split
        goes through ``transform`` with the processors fitted on train.

        Args:
            split_data: Output of ``_split_data`` (updated in place).
            feature_cols: Feature column names, passed to each processor as
                the ``feature_cols`` keyword argument.
            verbose: Whether to print progress information.

        Returns:
            The same dictionary with "raw_data", "X" and "y" refreshed.
        """
        # Reset first so a run without processors never exposes stale ones.
        self.fitted_processors = []
        if not self.processor_configs:
            return split_data

        processors = [
            proc_class(**{**proc_kwargs, "feature_cols": feature_cols})
            for proc_class, proc_kwargs in self.processor_configs
        ]

        if verbose:
            print(" 训练集预处理fit_transform...")
        train_entry = split_data["train"]
        train_df = train_entry["raw_data"]
        for processor in processors:
            train_df = processor.fit_transform(train_df)
            self.fitted_processors.append(processor)
        self._refresh_split(train_entry, train_df, feature_cols)

        # Every non-train split is transformed, not only "val" and "test":
        # _split_data creates an entry for each key of date_range.
        for split_name, entry in split_data.items():
            if split_name == "train":
                continue
            if verbose:
                print(f" {split_name}集预处理transform...")
            split_df = entry["raw_data"]
            for processor in self.fitted_processors:
                split_df = processor.transform(split_df)
            self._refresh_split(entry, split_df, feature_cols)

        return split_data

    @staticmethod
    def _refresh_split(
        entry: Dict[str, Any],
        frame: pl.DataFrame,
        feature_cols: List[str],
    ) -> None:
        """Store ``frame`` as the split's raw data and re-derive X and y from it."""
        label_col = entry["y"].name
        entry["raw_data"] = frame
        entry["X"] = frame.select(feature_cols)
        entry["y"] = frame[label_col]

    def get_fitted_processors(self) -> List[BaseProcessor]:
        """Return the processors fitted on the train split.

        Returns:
            List of fitted processors (used when saving the model).
        """
        return self.fitted_processors

View File

@@ -0,0 +1,191 @@
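"""Result analyzer.
Post-training analysis and result handling:
1. Feature-importance analysis (Top N, zero-contribution features)
2. Result assembly (daily Top N)
3. Result saving
"""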
from typing import Any, Dict, List, Optional
import os
import polars as pl
import pandas as pd
import numpy as np
class ResultAnalyzer:
    """Analyse training results, build the daily Top-N table and save it."""

    def analyze_feature_importance(
        self,
        model,
        feature_cols: List[str],
        top_n: int = 20,
        verbose: bool = True,
    ) -> Dict[str, Any]:
        """Rank features by importance and optionally print a report.

        Args:
            model: Trained model exposing ``feature_importance()``, which
                returns a Series indexed by feature name, or None.
            feature_cols: Feature column names (currently unused; the index
                of the importance Series is authoritative).
            top_n: Number of leading features to report.
            verbose: Whether to print the report.

        Returns:
            Empty dict when the model provides no importance; otherwise a
            dict with "importance" (sorted descending), "importance_pct",
            "zero_importance_features" and "top_n".
        """
        importance = model.feature_importance()
        if importance is None:
            if verbose:
                print("[警告] 无法获取特征重要性")
            return {}

        importance_sorted = importance.sort_values(ascending=False)

        # When every importance is zero the share is defined as 0% instead of
        # dividing by zero and reporting NaN.
        total_importance = importance_sorted.sum()
        if total_importance > 0:
            importance_pct = (importance_sorted / total_importance * 100).round(2)
        else:
            importance_pct = importance_sorted * 0.0

        zero_importance_features = importance_sorted[
            importance_sorted == 0
        ].index.tolist()

        if verbose:
            print("\n" + "=" * 80)
            print("特征重要性分析")
            print("=" * 80)

            print(f"\nTop {top_n} 特征:")
            print("-" * 80)
            print(f"{'排名':<6}{'特征名':<35}{'重要性':<15}{'占比':<10}")
            print("-" * 80)

            # Both Series share the same order, so they are walked in lockstep.
            top_scores = importance_sorted.head(top_n)
            top_shares = importance_pct.head(top_n)
            for i, ((feature, score), pct) in enumerate(
                zip(top_scores.items(), top_shares), 1
            ):
                if pct >= 10:
                    marker = " [高贡献]"
                elif pct >= 1:
                    marker = " [中贡献]"
                else:
                    marker = " [低贡献]"
                print(f"{i:<6}{feature:<35}{score:<15.2f}{pct:<8.2f}%{marker}")

            if zero_importance_features:
                print("\n" + "-" * 80)
                print(f"[警告] 贡献为0的特征({len(zero_importance_features)} 个):")
                for i, feature in enumerate(zero_importance_features, 1):
                    print(f" {i}. {feature}")

            print("\n" + "=" * 80)
            print("统计摘要:")
            print("-" * 80)
            print(f" 特征总数: {len(importance_sorted)}")
            print(
                f" 有贡献特征数: {len(importance_sorted) - len(zero_importance_features)}"
            )
            print(f" 零贡献特征数: {len(zero_importance_features)}")
            if len(importance_sorted) > 0:
                print(
                    f" 零贡献占比: {len(zero_importance_features) / len(importance_sorted) * 100:.1f}%"
                )
            print(f" Top {top_n} 累计占比: {importance_pct.head(top_n).sum():.1f}%")
            print("=" * 80)

        return {
            "importance": importance_sorted,
            "importance_pct": importance_pct,
            "zero_importance_features": zero_importance_features,
            "top_n": importance_sorted.head(top_n),
        }

    def assemble_results(
        self,
        test_data: Dict[str, Any],
        predictions: np.ndarray,
        top_n: int = 50,
        verbose: bool = True,
    ) -> pl.DataFrame:
        """Build the daily Top-N table from predictions.

        Args:
            test_data: Test split dictionary; its "raw_data" frame must
                contain a "trade_date" column and have one row per prediction.
            predictions: Predicted scores aligned with the rows of "raw_data".
            top_n: Number of rows kept per trade date.
            verbose: Whether to print summary information.

        Returns:
            Frame with the ``top_n`` highest-scored rows of every trade date,
            dates ascending and scores descending within a date. The columns
            are those of "raw_data" plus "prediction".
        """
        raw_data = test_data["raw_data"]
        results = raw_data.with_columns([pl.Series("prediction", predictions)])

        # One sort plus a grouped head replaces a full-frame filter per date,
        # and also copes with an empty test set (pl.concat([]) would raise).
        topn_results = (
            results.sort(["trade_date", "prediction"], descending=[False, True])
            .group_by("trade_date", maintain_order=True)
            .head(top_n)
            # The grouped head may move the key column to the front.
            .select(results.columns)
        )

        if verbose:
            print(f"\n生成每日 Top {top_n} 股票列表:")
            print(f" 交易日数: {results['trade_date'].n_unique()}")
            print(f" 总推荐数: {len(topn_results)}")

        return topn_results

    def save_results(
        self,
        results: pl.DataFrame,
        output_path: str,
        verbose: bool = True,
    ) -> None:
        """Write the results as CSV with columns date, score, ts_code.

        Args:
            results: Frame with "trade_date", "prediction" and "ts_code".
                "trade_date" is sliced as a YYYYMMDD string.
            output_path: Destination CSV path; missing parent directories
                are created.
            verbose: Whether to print the path and row count.
        """
        # Reformat YYYYMMDD as YYYY-MM-DD and fix the column order.
        formatted = results.select(
            [
                (
                    pl.col("trade_date").str.slice(0, 4)
                    + "-"
                    + pl.col("trade_date").str.slice(4, 2)
                    + "-"
                    + pl.col("trade_date").str.slice(6, 2)
                ).alias("date"),
                pl.col("prediction").alias("score"),
                pl.col("ts_code"),
            ]
        )

        # A bare filename has an empty dirname, and os.makedirs("") raises.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        formatted.write_csv(output_path, include_header=True)

        if verbose:
            print(f" 保存路径: {output_path}")
            print(f" 保存行数: {len(formatted)}")

View File

@@ -0,0 +1,14 @@
"""Tasks 模块
提供各种训练任务的实现。
"""
from src.training.tasks.base import BaseTask
from src.training.tasks.regression_task import RegressionTask
from src.training.tasks.rank_task import RankTask
# Names re-exported as the public API of the tasks package.
__all__ = [
"BaseTask",
"RegressionTask",
"RankTask",
]

View File

@@ -0,0 +1,79 @@
"""任务抽象基类
定义 Task 接口,所有具体任务必须实现此接口。
"""
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
import numpy as np
class BaseTask(ABC):
    """Abstract interface shared by every training task.

    Concrete tasks (regression, learning-to-rank, ...) subclass this and
    implement label preparation, fitting and prediction.

    Attributes:
        label_name: Name of the label column.
        model_params: Dictionary of model parameters.
        model: Underlying model; None until a subclass assigns it.
    """

    def __init__(self, model_params: Dict[str, Any], label_name: str):
        """Store the task configuration.

        Args:
            model_params: Dictionary of model parameters.
            label_name: Name of the label column.
        """
        self.label_name = label_name
        self.model_params = model_params
        self.model = None

    @abstractmethod
    def prepare_labels(self, data: Dict[str, Dict]) -> Dict[str, Dict]:
        """Apply task-specific label handling.

        Args:
            data: Data dictionary keyed by split.

        Returns:
            The processed data dictionary.
        """
        raise NotImplementedError

    @abstractmethod
    def fit(self, train_data: Dict, val_data: Dict) -> None:
        """Train the model.

        Args:
            train_data: Training split, {"X": DataFrame, "y": Series, ...}.
            val_data: Validation split.
        """
        raise NotImplementedError

    @abstractmethod
    def predict(self, test_data: Dict) -> np.ndarray:
        """Produce predictions for a split.

        Args:
            test_data: Test split dictionary.

        Returns:
            Array of predictions.
        """
        raise NotImplementedError

    def get_model(self):
        """Return the underlying model (None before a subclass sets it)."""
        return self.model

    def plot_training_metrics(self) -> None:
        """Optional hook for plotting training curves; does nothing here."""
        return None

View File

@@ -0,0 +1,198 @@
"""排序学习任务实现
实现排序学习任务的训练流程:
- Label 转换为分位数标签
- 生成 group 数组
- 使用 LightGBM LambdaRank
- 支持 NDCG@k 评估
"""
from typing import Any, Dict, List, Optional
import numpy as np
import polars as pl
from src.training.tasks.base import BaseTask
from src.training.components.models.lightgbm_lambdarank import LightGBMLambdaRankModel
class RankTask(BaseTask):
    """Learning-to-rank task backed by LightGBM LambdaRank.

    Continuous labels are converted into per-date quantile grades and a
    group array (rows per trade date) is produced for the ranker.
    """

    def __init__(
        self,
        model_params: Dict[str, Any],
        label_name: str = "future_return_5",
        n_quantiles: int = 20,
    ):
        """Initialise the ranking task.

        Args:
            model_params: LightGBM parameter dictionary.
            label_name: Name of the label column.
            n_quantiles: Number of quantile grades per trade date.
        """
        super().__init__(model_params, label_name)
        self.n_quantiles = n_quantiles

    def prepare_labels(self, data: Dict[str, Dict]) -> Dict[str, Dict]:
        """Convert labels into per-date quantile grades and build groups.

        For each of the "train", "val" and "test" splits present in ``data``,
        the split's "raw_data" gains a ``<label_name>_rank`` column, "y" is
        replaced by it, the original label is kept as "y_raw", and "groups"
        holds the number of rows per trade date.

        When the split carries "feature_cols", rows are first sorted by
        "trade_date" (stable) and "X" is re-derived from the sorted frame.
        The group array describes consecutive runs of rows, so rows of one
        date have to be contiguous and aligned with X and y.

        Args:
            data: Data dictionary keyed by split.

        Returns:
            The same dictionary, updated in place.
        """
        rank_col = f"{self.label_name}_rank"
        for split in ("train", "val", "test"):
            if split not in data:
                continue
            entry = data[split]
            df = entry["raw_data"]

            # Only reorder when X can be rebuilt from the same frame;
            # otherwise X and y would drift apart.
            feature_cols = entry.get("feature_cols")
            if feature_cols is not None:
                df = df.sort("trade_date", maintain_order=True)

            # Per-date rank -> fraction of the date's rows -> integer grade
            # in [0, n_quantiles - 1].
            df_ranked = (
                df.with_columns(
                    pl.col(self.label_name)
                    .rank(method="min")
                    .over("trade_date")
                    .alias("_rank")
                )
                .with_columns(
                    (
                        (pl.col("_rank") - 1)
                        / pl.len().over("trade_date")
                        * self.n_quantiles
                    )
                    .floor()
                    .cast(pl.Int64)
                    .clip(0, self.n_quantiles - 1)
                    .alias(rank_col)
                )
                .drop("_rank")
            )

            entry["raw_data"] = df_ranked
            if feature_cols is not None:
                entry["X"] = df_ranked.select(feature_cols)
            entry["y"] = df_ranked[rank_col]
            entry["y_raw"] = df_ranked[self.label_name]  # keep the original values
            entry["groups"] = self._compute_group_array(df_ranked, "trade_date")
        return data

    def _compute_group_array(
        self,
        df: pl.DataFrame,
        date_col: str = "trade_date",
    ) -> np.ndarray:
        """Count rows per date, in order of first appearance.

        Args:
            df: Data frame.
            date_col: Name of the date column.

        Returns:
            Array with the number of rows for each date.
        """
        # pl.len() matches the expression already used in prepare_labels;
        # pl.count() is the deprecated spelling.
        group_counts = df.group_by(date_col, maintain_order=True).agg(
            pl.len().alias("count")
        )
        return group_counts["count"].to_numpy()

    def fit(self, train_data: Dict, val_data: Optional[Dict]) -> None:
        """Train the ranking model.

        Args:
            train_data: Training split with "X", "y" and "groups".
            val_data: Validation split with the same keys, or a falsy value
                to train without an evaluation set.
        """
        eval_set = None
        if val_data:
            eval_set = (val_data["X"], val_data["y"], val_data["groups"])

        self.model = LightGBMLambdaRankModel(params=self.model_params)
        self.model.fit(
            train_data["X"],
            train_data["y"],
            group=train_data["groups"],
            eval_set=eval_set,
        )

    def predict(self, test_data: Dict) -> np.ndarray:
        """Produce ranking scores.

        Args:
            test_data: Split dictionary with "X".

        Returns:
            Array of predicted scores.
        """
        return self.model.predict(test_data["X"])

    def evaluate_ndcg(
        self,
        test_data: Dict,
        k_list: Optional[List[int]] = None,
    ) -> Dict[str, float]:
        """Compute the mean per-date NDCG@k.

        Relevance is taken from "y_raw" when it is non-negative. Raw returns
        are frequently negative, which ``ndcg_score`` does not support (recent
        scikit-learn releases raise ValueError), so in that case the
        non-negative quantile grades in "y" are used instead.

        Args:
            test_data: Split dictionary with "X", "y", "y_raw" and "groups".
            k_list: Cut-offs to evaluate; defaults to [1, 5, 10, 20].

        Returns:
            {"ndcg@k": mean score}. A cut-off with no scorable group maps to 0.0.
        """
        if k_list is None:
            k_list = [1, 5, 10, 20]

        from sklearn.metrics import ndcg_score

        # Converted once, not once per group.
        y_true = test_data["y_raw"].to_numpy()
        if (y_true < 0).any():
            y_true = test_data["y"].to_numpy()
        y_pred = self.predict(test_data)

        # Slice both arrays into consecutive per-date chunks.
        bounds = np.concatenate(([0], np.cumsum(test_data["groups"])))
        chunks = [
            (y_true[start:end], y_pred[start:end])
            for start, end in zip(bounds[:-1], bounds[1:])
        ]

        results = {}
        for k in k_list:
            ndcg_scores = []
            for yt, yp in chunks:
                if len(yt) > 1:  # NDCG needs at least two documents
                    try:
                        ndcg_scores.append(ndcg_score([yt], [yp], k=k))
                    except ValueError:
                        # Groups that cannot be scored are left out of the mean.
                        pass
            results[f"ndcg@{k}"] = float(np.mean(ndcg_scores)) if ndcg_scores else 0.0
        return results

    def plot_training_metrics(self) -> None:
        """Plot the training metric curves (NDCG); failures only print a warning."""
        if self.model and hasattr(self.model, "model") and self.model.model:
            try:
                import lightgbm as lgb

                # lgb.plot_metric accepts a recorded-evaluation dict or an
                # LGBMModel. Prefer the wrapper's evals_result_ when it
                # exists and fall back to the fitted model object otherwise.
                # NOTE(review): assumes the wrapper exposes evals_result_ - confirm.
                evals_result = getattr(self.model, "evals_result_", None)
                lgb.plot_metric(evals_result if evals_result else self.model.model)
            except Exception as e:
                print(f"[警告] 无法绘制训练曲线: {e}")

View File

@@ -0,0 +1,86 @@
"""回归任务实现
实现回归任务的训练流程:
- Label 无需转换(保持连续值)
- 使用 LightGBM 回归模型
- 支持 MAE/RMSE 评估
"""
from typing import Any, Dict, Optional
import numpy as np
import polars as pl
from src.training.tasks.base import BaseTask
from src.training.components.models.lightgbm import LightGBMModel
class RegressionTask(BaseTask):
    """Regression task backed by the LightGBM regression wrapper."""

    def __init__(
        self,
        model_params: Dict[str, Any],
        label_name: str = "future_return_5",
    ):
        """Initialise the regression task.

        Args:
            model_params: LightGBM parameter dictionary.
            label_name: Name of the label column.
        """
        super().__init__(model_params, label_name)
        self.evals_result: Optional[Dict] = None

    def prepare_labels(self, data: Dict[str, Dict]) -> Dict[str, Dict]:
        """Return ``data`` unchanged; regression uses the continuous label.

        Args:
            data: Data dictionary keyed by split.

        Returns:
            The same data dictionary.
        """
        return data

    def fit(self, train_data: Dict, val_data: Optional[Dict]) -> None:
        """Train the regression model.

        Args:
            train_data: Training split, {"X": DataFrame, "y": Series}.
            val_data: Validation split with "X" and "y". A falsy value, or
                an "X" that is missing or None, trains without an
                evaluation set.
        """
        # Decide on the eval set before touching val_data["X"], so a missing
        # validation split does not fail with a TypeError.
        eval_set = None
        if val_data and val_data.get("X") is not None:
            eval_set = (val_data["X"], val_data["y"])

        self.model = LightGBMModel(params=self.model_params)
        self.model.fit(train_data["X"], train_data["y"], eval_set=eval_set)

        # NOTE(review): presumably the wrapper records metrics in
        # evals_result_ like the LambdaRank wrapper does - confirm.
        self.evals_result = getattr(self.model, "evals_result_", None)

    def predict(self, test_data: Dict) -> np.ndarray:
        """Produce predictions.

        Args:
            test_data: Split dictionary with "X".

        Returns:
            Array of predictions.
        """
        return self.model.predict(test_data["X"])

    def plot_training_metrics(self) -> None:
        """Plot the training metric curves; failures only print a warning."""
        if self.model and hasattr(self.model, "model") and self.model.model:
            try:
                import lightgbm as lgb

                # lgb.plot_metric accepts a recorded-evaluation dict or an
                # LGBMModel, so the recorded metrics are preferred when they
                # are available.
                evals_result = self.evals_result or getattr(
                    self.model, "evals_result_", None
                )
                lgb.plot_metric(evals_result if evals_result else self.model.model)
            except Exception as e:
                print(f"[警告] 无法绘制训练曲线: {e}")

211
src/training/trainer_v2.py Normal file
View File

@@ -0,0 +1,211 @@
"""训练调度引擎
协调 FactorManager、DataPipeline、Task 和 ResultAnalyzer 完成训练流程。
"""
from typing import Any, Callable, Dict, List, Optional, Tuple
import os
from datetime import datetime
import polars as pl
from src.factors import FactorEngine
from src.training.pipeline import DataPipeline
from src.training.tasks.base import BaseTask
from src.training.result_analyzer import ResultAnalyzer
class Trainer:
    """Scheduling engine that drives one complete training run.

    Steps:
    1. Prepare data (DataPipeline)
    2. Prepare labels (Task)
    3. Train the model (Task)
    4. Plot training metrics (Task)
    5. Predict (Task)
    6. Analyse results (ResultAnalyzer)
    7. Save results

    Attributes:
        data_pipeline: Data pipeline.
        task: Task instance (RegressionTask / RankTask).
        analyzer: Result analyzer.
        output_config: Output configuration. Keys read: "top_n",
            "importance_top_n", "save_predictions", "save_model",
            "output_dir", "output_filename", "model_save_path".
        verbose: Whether to print progress information.
        results: Result frame of the latest run, None before any run.
        metrics: Evaluation scores of the latest run (e.g. NDCG@k); empty
            when the task offers no ``evaluate_ndcg``.
    """

    def __init__(
        self,
        data_pipeline: DataPipeline,
        task: BaseTask,
        analyzer: Optional[ResultAnalyzer] = None,
        output_config: Optional[Dict[str, Any]] = None,
        verbose: bool = True,
    ):
        """Initialise the trainer.

        Args:
            data_pipeline: Data pipeline instance.
            task: Task instance (RegressionTask or RankTask).
            analyzer: Result analyzer; a new ResultAnalyzer is created if omitted.
            output_config: Output configuration dictionary.
            verbose: Whether to print progress information.
        """
        self.data_pipeline = data_pipeline
        self.task = task
        self.analyzer = analyzer or ResultAnalyzer()
        self.output_config = output_config or {}
        self.verbose = verbose
        self.results: Optional[pl.DataFrame] = None
        self.metrics: Dict[str, float] = {}

    def run(
        self,
        engine: FactorEngine,
        date_range: Dict[str, Tuple[str, str]],
    ) -> pl.DataFrame:
        """Execute the full training flow.

        Args:
            engine: FactorEngine instance.
            date_range: Date ranges per split:
                {
                    "train": (start_date, end_date),
                    "val": (start_date, end_date),
                    "test": (start_date, end_date),
                }

        Returns:
            The assembled result frame (also stored in ``results``).
        """
        verbose = self.verbose
        self.metrics = {}

        if verbose:
            print("\n" + "=" * 80)
            print(f"开始训练: {self.task.__class__.__name__}")
            print("=" * 80)

        # Step 1: prepare data.
        if verbose:
            print("\n[Step 1/7] 准备数据...")
        data = self.data_pipeline.prepare_data(
            engine=engine,
            date_range=date_range,
            label_name=self.task.label_name,
            verbose=verbose,
        )

        # Step 2: prepare labels.
        if verbose:
            print("\n[Step 2/7] 处理标签...")
        data = self.task.prepare_labels(data)

        # Step 3: train. A missing validation split is passed as None and
        # left to the task to handle, instead of failing here with KeyError.
        if verbose:
            print("\n[Step 3/7] 训练模型...")
        self.task.fit(data["train"], data.get("val"))

        # Step 4: plot training metrics.
        if verbose:
            print("\n[Step 4/7] 绘制训练指标...")
        self.task.plot_training_metrics()

        # Step 5: predict.
        if verbose:
            print("\n[Step 5/7] 生成预测...")
        predictions = self.task.predict(data["test"])

        # Step 6: analyse.
        if verbose:
            print("\n[Step 6/7] 分析结果...")
        self.analyzer.analyze_feature_importance(
            model=self.task.get_model(),
            feature_cols=data["test"]["feature_cols"],
            top_n=self.output_config.get("importance_top_n", 20),
            verbose=verbose,
        )

        # NDCG is specific to ranking tasks; the scores are kept on the
        # instance so they are not lost when verbose is off.
        if hasattr(self.task, "evaluate_ndcg"):
            self.metrics = self.task.evaluate_ndcg(data["test"])
            if verbose:
                print("\nNDCG 评估结果:")
                for metric, score in self.metrics.items():
                    print(f" {metric}: {score:.4f}")

        self.results = self.analyzer.assemble_results(
            test_data=data["test"],
            predictions=predictions,
            top_n=self.output_config.get("top_n", 50),
            verbose=verbose,
        )

        # Step 7: save.
        if verbose:
            print("\n[Step 7/7] 保存结果...")
        if self.output_config.get("save_predictions", True):
            self._save_predictions()
        if self.output_config.get("save_model", False):
            self._save_model()

        if verbose:
            print("\n" + "=" * 80)
            print("训练完成!")
            print("=" * 80)

        return self.results

    def _save_predictions(self) -> None:
        """Write ``results`` to ``<output_dir>/<output_filename>`` via the analyzer."""
        output_dir = self.output_config.get("output_dir", "experiment/output")
        output_filename = self.output_config.get("output_filename", "output.csv")
        self.analyzer.save_results(
            results=self.results,
            output_path=os.path.join(output_dir, output_filename),
            verbose=self.verbose,
        )

    def _save_model(self) -> None:
        """Save the task's model to ``output_config["model_save_path"]``.

        Does nothing when no path is configured.
        """
        model_save_path = self.output_config.get("model_save_path")
        if not model_save_path:
            return

        # A bare filename has an empty dirname, and os.makedirs("") raises.
        model_dir = os.path.dirname(model_save_path)
        if model_dir:
            os.makedirs(model_dir, exist_ok=True)

        self.task.get_model().save(model_save_path)

        if self.verbose:
            print(f" 模型保存路径: {model_save_path}")

    def get_results(self) -> Optional[pl.DataFrame]:
        """Return the result frame of the latest run.

        Returns:
            The result frame, or None if ``run`` has not been called.
        """
        return self.results

    def get_task(self) -> BaseTask:
        """Return the task instance.

        Returns:
            The task instance.
        """
        return self.task