feat(training): 添加 LightGBM LambdaRank 排序学习功能

新增基于 LambdaRank 的排序学习模型,用于股票排序预测任务:
- 实现 LightGBMLambdaRankModel 模型类,支持分位数标签转换
- 提供完整的训练流程和 NDCG 评估指标
- 添加实验 Notebook 演示排序学习全流程
This commit is contained in:
2026-03-10 22:23:44 +08:00
parent f1811815e7
commit e6c3a918c7
6 changed files with 2366 additions and 366 deletions

View File

@@ -15,6 +15,8 @@ A股量化投资框架 - Python 项目,用于量化股票投资分析。
**⚠️ 重要:本项目强制使用 uv 作为 Python 包管理器和运行工具。禁止直接使用 `python` 或 `pip` 命令。**
**测试规则:** 当修改或查看 `tests/` 目录下的代码时,必须使用 pytest 命令进行测试验证。
```bash
# 安装依赖(必须使用 uv)
uv pip install -e .

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -4,5 +4,6 @@
"""
from src.training.components.models.lightgbm import LightGBMModel
from src.training.components.models.lightgbm_lambdarank import LightGBMLambdaRankModel
__all__ = ["LightGBMModel"]
__all__ = ["LightGBMModel", "LightGBMLambdaRankModel"]

View File

@@ -0,0 +1,406 @@
"""LightGBM LambdaRank 排序学习模型
提供 LightGBM LambdaRank 模型的实现,支持学习排序任务。
适用于将股票按未来收益率进行排序的场景。
"""
from typing import Any, Optional, List
import numpy as np
import pandas as pd
import polars as pl
from src.training.components.base import BaseModel
from src.training.registry import register_model
@register_model("lightgbm_lambdarank")
class LightGBMLambdaRankModel(BaseModel):
    """LightGBM LambdaRank learning-to-rank model.

    Wraps LightGBM's ``lambdarank`` objective for ranking tasks, e.g.
    ordering stocks by expected future return. Continuous labels should be
    converted to per-date quantile ranks first (see
    ``convert_to_quantile_labels``).

    Attributes:
        name: Registry name, "lightgbm_lambdarank".
        params: LightGBM parameter dict passed to ``lgb.train``.
        model: Trained LightGBM Booster, or None before ``fit``.
        feature_names_: Feature column names captured at fit time.
    """

    name = "lightgbm_lambdarank"

    def __init__(
        self,
        params: Optional[dict] = None,
        learning_rate: float = 0.05,
        num_leaves: int = 31,
        n_estimators: int = 100,
        min_data_in_leaf: int = 20,
        ndcg_at: Optional[List[int]] = None,
        early_stopping_rounds: int = 50,
        **kwargs,
    ):
        """Initialize the LambdaRank model.

        Parameters can be supplied in two ways:
        1. A full ``params`` dict (preferred).
        2. Individual keyword arguments (backward compatible).

        Args:
            params: LightGBM parameter dict; used verbatim when provided.
            learning_rate: Learning rate, default 0.05.
            num_leaves: Number of leaves, default 31.
            n_estimators: Number of boosting rounds, default 100.
            min_data_in_leaf: Minimum samples per leaf, default 20.
            ndcg_at: k values for NDCG evaluation, default [1, 5, 10, 20].
            early_stopping_rounds: Early-stopping patience, default 50.
            **kwargs: Additional LightGBM parameters (individual-args mode).

        Examples:
            >>> # Mode 1: via params dict (preferred)
            >>> model = LightGBMLambdaRankModel(params={
            ...     "objective": "lambdarank",
            ...     "metric": "ndcg",
            ...     "ndcg_at": [1, 5, 10, 20],
            ...     "num_leaves": 31,
            ...     "learning_rate": 0.05,
            ...     "n_estimators": 1000,
            ... })
        """
        if ndcg_at is None:
            ndcg_at = [1, 5, 10, 20]
        if params is not None:
            # Mode 1: copy the caller's dict so we never mutate it, then
            # fill in ranking defaults only where the caller left gaps.
            self.params = dict(params)
            self.params.setdefault("objective", "lambdarank")
            self.params.setdefault("metric", "ndcg")
            self.params.setdefault("verbose", -1)
            # n_estimators / early_stopping_rounds are lgb.train()-level
            # options, not Booster params, so pop them out of the dict.
            self.n_estimators = self.params.pop("n_estimators", n_estimators)
            self.early_stopping_rounds = self.params.pop(
                "early_stopping_rounds", early_stopping_rounds
            )
        else:
            # Mode 2: build the param dict from the individual arguments.
            self.params = {
                "objective": "lambdarank",
                "metric": "ndcg",
                "ndcg_at": ndcg_at,
                "num_leaves": num_leaves,
                "learning_rate": learning_rate,
                "min_data_in_leaf": min_data_in_leaf,
                "verbose": -1,
                **kwargs,
            }
            self.n_estimators = n_estimators
            self.early_stopping_rounds = early_stopping_rounds
        self.model = None
        self.feature_names_: Optional[list] = None

    def fit(
        self,
        X: pl.DataFrame,
        y: pl.Series,
        group: Optional[np.ndarray] = None,
        eval_set: Optional[tuple] = None,
    ) -> "LightGBMLambdaRankModel":
        """Train the ranking model.

        Args:
            X: Feature matrix (Polars DataFrame).
            y: Target (Polars Series); expected to be quantile labels
                (0, 1, 2, ...).
            group: Per-query sample counts, e.g. [10, 15, 20] means the
                first query has 10 samples, the second 15, the third 20.
                If None, all samples are treated as one query.
            eval_set: Validation tuple (X_val, y_val, group_val) used for
                early stopping.

        Returns:
            self (chainable).

        Raises:
            ImportError: lightgbm is not installed.
            ValueError: group does not sum to the number of samples.
        """
        try:
            import lightgbm as lgb
        except ImportError:
            raise ImportError(
                "使用 LightGBMLambdaRankModel 需要安装 lightgbm: pip install lightgbm"
            )
        # Remember feature names for feature_importance() and persistence.
        self.feature_names_ = list(X.columns)
        X_np = X.to_numpy()
        y_np = y.to_numpy()
        if group is None:
            # No grouping supplied: assume a single query over all samples.
            group = np.array([len(y_np)])
        if not isinstance(group, np.ndarray):
            group = np.array(group)
        if group.sum() != len(y_np):
            raise ValueError(
                f"group 数组的和 ({group.sum()}) 必须等于样本数 ({len(y_np)})"
            )
        train_data = lgb.Dataset(X_np, label=y_np, group=group)
        valid_sets = [train_data]
        callbacks = []
        if eval_set is not None:
            X_val, y_val, group_val = eval_set
            X_val_np = X_val.to_numpy() if isinstance(X_val, pl.DataFrame) else X_val
            y_val_np = y_val.to_numpy() if isinstance(y_val, pl.Series) else y_val
            if group_val is None:
                group_val = np.array([len(y_val_np)])
            elif not isinstance(group_val, np.ndarray):
                group_val = np.array(group_val)
            val_data = lgb.Dataset(X_val_np, label=y_val_np, group=group_val)
            valid_sets.append(val_data)
            # Bugfix: enable early stopping only when a held-out validation
            # set exists. The previous code always registered the callback,
            # which made it monitor the training metric when no eval_set was
            # given — defeating the purpose of early stopping.
            callbacks.append(
                lgb.early_stopping(stopping_rounds=self.early_stopping_rounds)
            )
        self.model = lgb.train(
            self.params,
            train_data,
            num_boost_round=self.n_estimators,
            valid_sets=valid_sets,
            callbacks=callbacks,
        )
        return self

    def predict(self, X: pl.DataFrame) -> np.ndarray:
        """Predict ranking scores.

        Args:
            X: Feature matrix (Polars DataFrame).

        Returns:
            Predicted scores (numpy ndarray); higher means ranked earlier.

        Raises:
            RuntimeError: the model has not been trained yet.
        """
        if self.model is None:
            raise RuntimeError("模型尚未训练,请先调用 fit()")
        return self.model.predict(X.to_numpy())

    def feature_importance(self) -> Optional[pd.Series]:
        """Return gain-based feature importances.

        Returns:
            Importance Series indexed by feature name, or None if the
            model has not been trained.
        """
        if self.model is None or self.feature_names_ is None:
            return None
        importance = self.model.feature_importance(importance_type="gain")
        return pd.Series(importance, index=self.feature_names_)

    def save(self, path: str) -> None:
        """Save the model in LightGBM's native format (no pickle).

        A JSON sidecar file ``<path>.meta.json`` stores feature names and
        training metadata so the model can be reloaded in any environment.

        Args:
            path: Destination path.

        Raises:
            RuntimeError: the model has not been trained yet.
        """
        if self.model is None:
            raise RuntimeError("模型尚未训练,无法保存")
        self.model.save_model(path)
        import json

        meta_path = path + ".meta.json"
        with open(meta_path, "w") as f:
            json.dump(
                {
                    "feature_names": self.feature_names_,
                    "params": self.params,
                    "n_estimators": self.n_estimators,
                    "early_stopping_rounds": self.early_stopping_rounds,
                },
                f,
            )

    @classmethod
    def load(cls, path: str) -> "LightGBMLambdaRankModel":
        """Load a model saved in LightGBM's native format.

        Args:
            path: Model file path.

        Returns:
            A restored LightGBMLambdaRankModel instance.
        """
        import lightgbm as lgb
        import json

        instance = cls()
        instance.model = lgb.Booster(model_file=path)
        meta_path = path + ".meta.json"
        try:
            with open(meta_path, "r") as f:
                meta = json.load(f)
            instance.feature_names_ = meta.get("feature_names")
            instance.params = meta.get("params", instance.params)
            instance.n_estimators = meta.get("n_estimators", instance.n_estimators)
            instance.early_stopping_rounds = meta.get(
                "early_stopping_rounds", instance.early_stopping_rounds
            )
        except FileNotFoundError:
            # No sidecar metadata: keep defaults; feature_names_ stays None.
            pass
        return instance

    @staticmethod
    def prepare_group_from_dates(
        df: pl.DataFrame,
        date_col: str = "trade_date",
    ) -> np.ndarray:
        """Build a group array from a date column (one query per date).

        Args:
            df: DataFrame containing the date column.
            date_col: Date column name, default "trade_date".

        Returns:
            Group array with the sample count of each date, in first-seen
            order.

        Example:
            >>> df = pl.DataFrame({
            ...     "trade_date": ["20240101", "20240101", "20240102", "20240102", "20240102"],
            ...     "feature": [1, 2, 3, 4, 5]
            ... })
            >>> group = LightGBMLambdaRankModel.prepare_group_from_dates(df)
            >>> print(group)  # array([2, 3])
        """
        # Count samples per date; maintain_order preserves row order so the
        # group array lines up with the (date-sorted) training matrix.
        # NOTE(review): pl.count() is deprecated in newer Polars in favor of
        # pl.len() — confirm the pinned Polars version before switching.
        group_counts = df.group_by(date_col, maintain_order=True).agg(
            pl.count().alias("count")
        )
        return group_counts["count"].to_numpy()

    @staticmethod
    def convert_to_quantile_labels(
        df: pl.DataFrame,
        label_col: str,
        date_col: str = "trade_date",
        n_quantiles: int = 20,
        new_col_name: Optional[str] = None,
    ) -> pl.DataFrame:
        """Convert a continuous label into per-date quantile labels.

        Each date is bucketed independently into 0..n_quantiles-1; a larger
        label means a larger original value (ranked earlier).

        Args:
            df: Input DataFrame.
            label_col: Continuous label column (e.g. "future_return_5").
            date_col: Date column name, default "trade_date".
            n_quantiles: Number of quantile buckets, default 20.
            new_col_name: Output column name; defaults to f"{label_col}_rank".

        Returns:
            DataFrame with the quantile-label column appended.

        Example:
            >>> df = pl.DataFrame({
            ...     "trade_date": ["20240101"] * 5 + ["20240102"] * 5,
            ...     "future_return_5": [0.01, 0.02, 0.03, 0.04, 0.05,
            ...                         0.02, 0.03, 0.04, 0.05, 0.06]
            ... })
            >>> df = LightGBMLambdaRankModel.convert_to_quantile_labels(
            ...     df, "future_return_5", n_quantiles=5
            ... )
            >>> print(df["future_return_5_rank"])  # [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
        """
        if new_col_name is None:
            new_col_name = f"{label_col}_rank"
        # qcut yields a Categorical; to_physical() maps it to integer codes
        # (0, 1, 2, ...) which we cast to Int64 for the label column.
        # NOTE(review): qcut can fail on dates with many tied values
        # (duplicate bin edges) — verify against real label distributions.
        return df.with_columns(
            pl.col(label_col)
            .qcut(n_quantiles)
            .over(date_col)
            .to_physical()
            .cast(pl.Int64)
            .alias(new_col_name)
        )

    def evaluate_ndcg(
        self,
        X: pl.DataFrame,
        y: pl.Series,
        group: np.ndarray,
        k: Optional[int] = None,
    ) -> float:
        """Evaluate mean NDCG over queries.

        Args:
            X: Feature matrix.
            y: True labels.
            group: Per-query sample counts.
            k: k for NDCG@k; None uses all positions.

        Returns:
            Mean NDCG over scoreable queries, 0.0 if none qualify.
        """
        from sklearn.metrics import ndcg_score

        y_pred = self.predict(X)
        # Hoisted out of the loop: the original converted y to numpy once
        # per group, an accidental O(groups * n) conversion cost.
        y_np = y.to_numpy()
        ndcg_scores = []
        start_idx = 0
        for group_size in group:
            end_idx = start_idx + group_size
            y_true = y_np[start_idx:end_idx]
            y_score = y_pred[start_idx:end_idx]
            start_idx = end_idx
            if len(y_true) > 1:  # NDCG needs at least 2 samples per query
                try:
                    ndcg_scores.append(ndcg_score([y_true], [y_score], k=k))
                except ValueError:
                    # All labels identical within this query — skip it.
                    pass
        return float(np.mean(ndcg_scores)) if ndcg_scores else 0.0

View File

@@ -29,7 +29,7 @@ class NullFiller(BaseProcessor):
fill_value: 当 strategy="value" 时使用的填充值
by_date: 是否按日期独立计算统计量(仅对 mean/median 有效)
date_col: 日期列名
exclude_cols: 参与填充的列名列表
feature_cols: 参与填充的特征列名列表
stats_: 存储学习到的统计量(全局模式)
"""
@@ -37,24 +37,24 @@ class NullFiller(BaseProcessor):
def __init__(
self,
feature_cols: List[str],
strategy: Literal["zero", "mean", "median", "value"] = "zero",
fill_value: Optional[float] = None,
by_date: bool = True,
date_col: str = "trade_date",
exclude_cols: Optional[List[str]] = None,
):
"""初始化缺失值填充处理器
Args:
feature_cols: 参与填充的特征列名列表
strategy: 填充策略,默认 "zero"
- "zero": 填充0
- "mean": 填充均值
- "median": 填充中值
- "value": 填充指定数值(需配合 fill_value
fill_value: 当 strategy="value" 时的填充值,默认为 None
by_date: 是否每天独立计算统计量,默认 False全局统计量)
by_date: 是否每天独立计算统计量,默认 True截面统计量)
date_col: 日期列名,默认 "trade_date"
exclude_cols: 不参与填充的列名列表,默认为 ["ts_code", "trade_date"]
Raises:
ValueError: 策略无效或 fill_value 未提供时
@@ -67,12 +67,12 @@ class NullFiller(BaseProcessor):
if strategy == "value" and fill_value is None:
raise ValueError("当 strategy='value' 时,必须提供 fill_value")
self.feature_cols = feature_cols
self.strategy = strategy
self.fill_value = fill_value
self.by_date = by_date
self.date_col = date_col
self.exclude_cols = exclude_cols or ["ts_code", "trade_date"]
self.stats_: dict = {}
self.stats_ = {}
def fit(self, X: pl.DataFrame) -> "NullFiller":
"""学习统计量(仅在全局模式下)
@@ -87,17 +87,12 @@ class NullFiller(BaseProcessor):
self
"""
if not self.by_date and self.strategy in ("mean", "median"):
numeric_cols = [
c
for c in X.columns
if c not in self.exclude_cols and X[c].dtype.is_numeric()
]
for col in numeric_cols:
if self.strategy == "mean":
self.stats_[col] = X[col].mean()
else: # median
self.stats_[col] = X[col].median()
for col in self.feature_cols:
if col in X.columns and X[col].dtype.is_numeric():
if self.strategy == "mean":
self.stats_[col] = X[col].mean() or 0.0
else: # median
self.stats_[col] = X[col].median() or 0.0
return self
@@ -125,15 +120,9 @@ class NullFiller(BaseProcessor):
def _fill_with_zero(self, X: pl.DataFrame) -> pl.DataFrame:
"""使用0填充缺失值"""
numeric_cols = [
c
for c in X.columns
if c not in self.exclude_cols and X[c].dtype.is_numeric()
]
expressions = []
for col in X.columns:
if col in numeric_cols:
if col in self.feature_cols and X[col].dtype.is_numeric():
expr = pl.col(col).fill_null(0).alias(col)
expressions.append(expr)
else:
@@ -143,15 +132,9 @@ class NullFiller(BaseProcessor):
def _fill_with_value(self, X: pl.DataFrame) -> pl.DataFrame:
"""使用指定值填充缺失值"""
numeric_cols = [
c
for c in X.columns
if c not in self.exclude_cols and X[c].dtype.is_numeric()
]
expressions = []
for col in X.columns:
if col in numeric_cols:
if col in self.feature_cols and X[col].dtype.is_numeric():
expr = pl.col(col).fill_null(self.fill_value).alias(col)
expressions.append(expr)
else:
@@ -174,15 +157,16 @@ class NullFiller(BaseProcessor):
def _fill_by_date(self, X: pl.DataFrame) -> pl.DataFrame:
"""使用每天截面统计量填充"""
numeric_cols = [
c
for c in X.columns
if c not in self.exclude_cols and X[c].dtype.is_numeric()
# 确定需要处理的数值列
target_cols = [
col
for col in self.feature_cols
if col in X.columns and X[col].dtype.is_numeric()
]
# 计算每天的统计量
stat_exprs = []
for col in numeric_cols:
for col in target_cols:
if self.strategy == "mean":
stat_exprs.append(
pl.col(col).mean().over(self.date_col).alias(f"{col}_stat")
@@ -198,7 +182,7 @@ class NullFiller(BaseProcessor):
# 使用统计量填充缺失值
fill_exprs = []
for col in X.columns:
if col in numeric_cols:
if col in target_cols:
expr = pl.col(col).fill_null(pl.col(f"{col}_stat")).alias(col)
fill_exprs.append(expr)
else:
@@ -219,22 +203,22 @@ class StandardScaler(BaseProcessor):
适用于需要全局统计量的场景。
Attributes:
exclude_cols: 参与标准化的列名列表
feature_cols: 参与标准化的特征列名列表
mean_: 学习到的均值字典 {列名: 均值}
std_: 学习到的标准差字典 {列名: 标准差}
"""
name = "standard_scaler"
def __init__(self, exclude_cols: Optional[List[str]] = None):
def __init__(self, feature_cols: List[str]):
"""初始化标准化处理器
Args:
exclude_cols: 参与标准化的列名列表,默认为 ["ts_code", "trade_date"]
feature_cols: 参与标准化的特征列名列表
"""
self.exclude_cols = exclude_cols or ["ts_code", "trade_date"]
self.mean_: dict = {}
self.std_: dict = {}
self.feature_cols = feature_cols
self.mean_ = {}
self.std_ = {}
def fit(self, X: pl.DataFrame) -> "StandardScaler":
"""计算均值和标准差(仅在训练集上)
@@ -245,15 +229,13 @@ class StandardScaler(BaseProcessor):
Returns:
self
"""
numeric_cols = [
c
for c in X.columns
if c not in self.exclude_cols and X[c].dtype.is_numeric()
]
for col in numeric_cols:
self.mean_[col] = X[col].mean()
self.std_[col] = X[col].std()
for col in self.feature_cols:
if col in X.columns and X[col].dtype.is_numeric():
col_mean = X[col].mean()
col_std = X[col].std()
if col_mean is not None and col_std is not None:
self.mean_[col] = col_mean
self.std_[col] = col_std
return self
@@ -291,7 +273,7 @@ class CrossSectionalStandardScaler(BaseProcessor):
- 公式z = (x - mean_today) / std_today
Attributes:
exclude_cols: 参与标准化的列名列表
feature_cols: 参与标准化的特征列名列表
date_col: 日期列名
"""
@@ -299,16 +281,16 @@ class CrossSectionalStandardScaler(BaseProcessor):
def __init__(
self,
exclude_cols: Optional[List[str]] = None,
feature_cols: List[str],
date_col: str = "trade_date",
):
"""初始化截面标准化处理器
Args:
exclude_cols: 参与标准化的列名列表,默认为 ["ts_code", "trade_date"]
feature_cols: 参与标准化的特征列名列表
date_col: 日期列名
"""
self.exclude_cols = exclude_cols or ["ts_code", "trade_date"]
self.feature_cols = feature_cols
self.date_col = date_col
def transform(self, X: pl.DataFrame) -> pl.DataFrame:
@@ -323,16 +305,10 @@ class CrossSectionalStandardScaler(BaseProcessor):
Returns:
标准化后的数据
"""
numeric_cols = [
c
for c in X.columns
if c not in self.exclude_cols and X[c].dtype.is_numeric()
]
# 构建表达式列表
expressions = []
for col in X.columns:
if col in numeric_cols:
if col in self.feature_cols and X[col].dtype.is_numeric():
# 截面标准化:每天独立计算均值和标准差
# 避免除以0当std为0时设为1
expr = (
@@ -355,6 +331,7 @@ class Winsorizer(BaseProcessor):
也可以截面截断(每天独立处理)。
Attributes:
feature_cols: 参与缩尾的特征列名列表
lower: 下分位数如0.01表示1%分位数)
upper: 上分位数如0.99表示99%分位数)
by_date: True=每天独立缩尾, False=全局缩尾
@@ -366,6 +343,7 @@ class Winsorizer(BaseProcessor):
def __init__(
self,
feature_cols: List[str],
lower: float = 0.01,
upper: float = 0.99,
by_date: bool = False,
@@ -374,6 +352,7 @@ class Winsorizer(BaseProcessor):
"""初始化缩尾处理器
Args:
feature_cols: 参与缩尾的特征列名列表
lower: 下分位数默认0.01
upper: 上分位数默认0.99
by_date: 每天独立缩尾默认False全局缩尾
@@ -387,11 +366,12 @@ class Winsorizer(BaseProcessor):
f"lower ({lower}) 必须小于 upper ({upper}),且都在 [0, 1] 范围内"
)
self.feature_cols = feature_cols
self.lower = lower
self.upper = upper
self.by_date = by_date
self.date_col = date_col
self.bounds_: dict = {}
self.bounds_ = {}
def fit(self, X: pl.DataFrame) -> "Winsorizer":
"""学习分位数边界(仅在全局模式下)
@@ -403,12 +383,12 @@ class Winsorizer(BaseProcessor):
self
"""
if not self.by_date:
numeric_cols = [c for c in X.columns if X[c].dtype.is_numeric()]
for col in numeric_cols:
self.bounds_[col] = {
"lower": X[col].quantile(self.lower),
"upper": X[col].quantile(self.upper),
}
for col in self.feature_cols:
if col in X.columns and X[col].dtype.is_numeric():
self.bounds_[col] = {
"lower": X[col].quantile(self.lower),
"upper": X[col].quantile(self.upper),
}
return self
def transform(self, X: pl.DataFrame) -> pl.DataFrame:
@@ -440,16 +420,21 @@ class Winsorizer(BaseProcessor):
def _transform_by_date(self, X: pl.DataFrame) -> pl.DataFrame:
"""每日独立缩尾"""
numeric_cols = [c for c in X.columns if X[c].dtype.is_numeric()]
# 确定需要处理的数值列
target_cols = [
col
for col in self.feature_cols
if col in X.columns and X[col].dtype.is_numeric()
]
# 先计算每天的分位数
lower_exprs = [
pl.col(col).quantile(self.lower).over(self.date_col).alias(f"{col}_lower")
for col in numeric_cols
for col in target_cols
]
upper_exprs = [
pl.col(col).quantile(self.upper).over(self.date_col).alias(f"{col}_upper")
for col in numeric_cols
for col in target_cols
]
# 添加分位数列
@@ -458,7 +443,7 @@ class Winsorizer(BaseProcessor):
# 执行缩尾
clip_exprs = []
for col in X.columns:
if col in numeric_cols:
if col in target_cols:
clipped = (
pl.col(col)
.clip(pl.col(f"{col}_lower"), pl.col(f"{col}_upper"))