feat(training): 添加 LightGBM LambdaRank 排序学习功能
新增基于 LambdaRank 的排序学习模型,用于股票排序预测任务: - 实现 LightGBMLambdaRankModel 模型类,支持分位数标签转换 - 提供完整的训练流程和 NDCG 评估指标 - 添加实验 Notebook 演示排序学习全流程
This commit is contained in:
@@ -15,6 +15,8 @@ A股量化投资框架 - Python 项目,用于量化股票投资分析。
|
||||
|
||||
**⚠️ 重要:本项目强制使用 uv 作为 Python 包管理器和运行工具。禁止直接使用 `python` 或 `pip` 命令。**
|
||||
|
||||
**测试规则:** 当修改或查看 `tests/` 目录下的代码时,必须使用 pytest 命令进行测试验证。
|
||||
|
||||
```bash
|
||||
# 安装依赖(必须使用 uv)
|
||||
uv pip install -e .
|
||||
|
||||
1607
src/experiment/learn_to_rank.ipynb
Normal file
1607
src/experiment/learn_to_rank.ipynb
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -4,5 +4,6 @@
|
||||
"""
|
||||
|
||||
from src.training.components.models.lightgbm import LightGBMModel
|
||||
from src.training.components.models.lightgbm_lambdarank import LightGBMLambdaRankModel
|
||||
|
||||
__all__ = ["LightGBMModel"]
|
||||
__all__ = ["LightGBMModel", "LightGBMLambdaRankModel"]
|
||||
|
||||
406
src/training/components/models/lightgbm_lambdarank.py
Normal file
406
src/training/components/models/lightgbm_lambdarank.py
Normal file
@@ -0,0 +1,406 @@
|
||||
"""LightGBM LambdaRank 排序学习模型
|
||||
|
||||
提供 LightGBM LambdaRank 模型的实现,支持学习排序任务。
|
||||
适用于将股票按未来收益率进行排序的场景。
|
||||
"""
|
||||
|
||||
from typing import Any, Optional, List
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import polars as pl
|
||||
|
||||
from src.training.components.base import BaseModel
|
||||
from src.training.registry import register_model
|
||||
|
||||
|
||||
@register_model("lightgbm_lambdarank")
class LightGBMLambdaRankModel(BaseModel):
    """LightGBM LambdaRank learning-to-rank model.

    Wraps LightGBM's ``lambdarank`` objective for ranking tasks, e.g.
    ordering stocks by future return. Labels are expected to be integer
    relevance grades (quantile labels 0, 1, 2, ...), and samples are grouped
    into queries (typically one query per trade date).

    Attributes:
        name: Model name, ``"lightgbm_lambdarank"``.
        params: LightGBM parameter dict.
        model: Trained LightGBM ``Booster``; ``None`` before ``fit()``.
        feature_names_: Feature-name list captured at ``fit()`` time.
    """

    name = "lightgbm_lambdarank"

    def __init__(
        self,
        params: Optional[dict] = None,
        learning_rate: float = 0.05,
        num_leaves: int = 31,
        n_estimators: int = 100,
        min_data_in_leaf: int = 20,
        ndcg_at: Optional[List[int]] = None,
        early_stopping_rounds: int = 50,
        **kwargs,
    ):
        """Initialize the LambdaRank model.

        Parameters can be supplied in two ways:
        1. A single ``params`` dict (recommended).
        2. Individual keyword arguments (backward compatible).

        Args:
            params: LightGBM parameter dict; used as-is (copied) if given.
            learning_rate: Learning rate, default 0.05.
            num_leaves: Number of leaves, default 31.
            n_estimators: Number of boosting rounds, default 100.
            min_data_in_leaf: Minimum samples per leaf, default 20.
            ndcg_at: List of k values for NDCG evaluation,
                default ``[1, 5, 10, 20]``.
            early_stopping_rounds: Early-stopping patience, default 50.
            **kwargs: Additional LightGBM parameters.

        Examples:
            >>> # Option 1: params dict (recommended)
            >>> model = LightGBMLambdaRankModel(params={
            ...     "objective": "lambdarank",
            ...     "metric": "ndcg",
            ...     "ndcg_at": [1, 5, 10, 20],
            ...     "num_leaves": 31,
            ...     "learning_rate": 0.05,
            ...     "n_estimators": 1000,
            ... })
        """
        if ndcg_at is None:
            ndcg_at = [1, 5, 10, 20]

        if params is not None:
            # Option 1: use the caller's dict. Copy so the original is
            # never mutated, then fill in the ranking-specific defaults.
            self.params = dict(params)
            self.params.setdefault("objective", "lambdarank")
            self.params.setdefault("metric", "ndcg")
            self.params.setdefault("verbose", -1)
            # n_estimators / early_stopping_rounds are lgb.train()-level
            # options, not Booster params, so pop them out of the dict.
            self.n_estimators = self.params.pop("n_estimators", n_estimators)
            self.early_stopping_rounds = self.params.pop(
                "early_stopping_rounds", early_stopping_rounds
            )
        else:
            # Option 2: build the params dict from the keyword arguments.
            self.params = {
                "objective": "lambdarank",
                "metric": "ndcg",
                "ndcg_at": ndcg_at,
                "num_leaves": num_leaves,
                "learning_rate": learning_rate,
                "min_data_in_leaf": min_data_in_leaf,
                "verbose": -1,
                **kwargs,
            }
            self.n_estimators = n_estimators
            self.early_stopping_rounds = early_stopping_rounds

        self.model = None
        self.feature_names_: Optional[list] = None

    def fit(
        self,
        X: pl.DataFrame,
        y: pl.Series,
        group: Optional[np.ndarray] = None,
        eval_set: Optional[tuple] = None,
    ) -> "LightGBMLambdaRankModel":
        """Train the ranking model.

        Args:
            X: Feature matrix (Polars DataFrame).
            y: Target variable (Polars Series); should be quantile labels
                (0, 1, 2, ...).
            group: Array of query sizes. E.g. ``[10, 15, 20]`` means the
                first query has 10 samples, the second 15, the third 20.
                If ``None``, all samples are assumed to form one query.
            eval_set: Validation tuple ``(X_val, y_val, group_val)`` used
                for early stopping. Early stopping is only enabled when
                this is provided.

        Returns:
            self (supports chaining).

        Raises:
            ImportError: lightgbm is not installed.
            RuntimeError: Training failed.
            ValueError: Invalid ``group`` argument.
        """
        try:
            import lightgbm as lgb
        except ImportError:
            raise ImportError(
                "使用 LightGBMLambdaRankModel 需要安装 lightgbm: pip install lightgbm"
            )

        # Remember feature names for feature_importance()/metadata.
        self.feature_names_ = X.columns

        X_np = X.to_numpy()
        y_np = y.to_numpy()

        # Normalize and validate the group array.
        if group is None:
            # No grouping given: treat all samples as a single query.
            group = np.array([len(y_np)])
        elif not isinstance(group, np.ndarray):
            group = np.array(group)
        if group.sum() != len(y_np):
            raise ValueError(
                f"group 数组的和 ({group.sum()}) 必须等于样本数 ({len(y_np)})"
            )

        train_data = lgb.Dataset(X_np, label=y_np, group=group)

        valid_sets = [train_data]
        if eval_set is not None:
            X_val, y_val, group_val = eval_set
            X_val_np = X_val.to_numpy() if isinstance(X_val, pl.DataFrame) else X_val
            y_val_np = y_val.to_numpy() if isinstance(y_val, pl.Series) else y_val

            if group_val is None:
                group_val = np.array([len(y_val_np)])
            elif not isinstance(group_val, np.ndarray):
                group_val = np.array(group_val)
            # Robustness fix: validate the eval group sizes the same way as
            # the training group (the original skipped this check).
            if group_val.sum() != len(y_val_np):
                raise ValueError(
                    f"group_val 数组的和 ({group_val.sum()}) 必须等于验证集样本数 ({len(y_val_np)})"
                )

            val_data = lgb.Dataset(X_val_np, label=y_val_np, group=group_val)
            valid_sets.append(val_data)

        # Bug fix: only enable early stopping when a real validation set is
        # available. The original always registered the callback, so with no
        # eval_set early stopping was driven by the training metric.
        callbacks = []
        if eval_set is not None:
            callbacks.append(
                lgb.early_stopping(stopping_rounds=self.early_stopping_rounds)
            )

        self.model = lgb.train(
            self.params,
            train_data,
            num_boost_round=self.n_estimators,
            valid_sets=valid_sets,
            callbacks=callbacks,
        )

        return self

    def predict(self, X: pl.DataFrame) -> np.ndarray:
        """Predict ranking scores.

        Args:
            X: Feature matrix (Polars DataFrame).

        Returns:
            Predicted scores (numpy ndarray); higher means ranked earlier.

        Raises:
            RuntimeError: The model has not been trained yet.
        """
        if self.model is None:
            raise RuntimeError("模型尚未训练,请先调用 fit()")

        return self.model.predict(X.to_numpy())

    def feature_importance(self) -> Optional[pd.Series]:
        """Return feature importances.

        Returns:
            Gain-based importance Series, or ``None`` if untrained.
        """
        if self.model is None or self.feature_names_ is None:
            return None

        importance = self.model.feature_importance(importance_type="gain")
        return pd.Series(importance, index=self.feature_names_)

    def save(self, path: str) -> None:
        """Save the model (native LightGBM format).

        Uses LightGBM's native format instead of pickle, so the model can
        be loaded across environments. Metadata (feature names, params,
        training options) is written alongside as ``<path>.meta.json``.

        Args:
            path: Destination path.

        Raises:
            RuntimeError: The model has not been trained yet.
        """
        if self.model is None:
            raise RuntimeError("模型尚未训练,无法保存")

        self.model.save_model(path)

        import json

        meta_path = path + ".meta.json"
        with open(meta_path, "w") as f:
            json.dump(
                {
                    "feature_names": self.feature_names_,
                    "params": self.params,
                    "n_estimators": self.n_estimators,
                    "early_stopping_rounds": self.early_stopping_rounds,
                },
                f,
            )

    @classmethod
    def load(cls, path: str) -> "LightGBMLambdaRankModel":
        """Load a model saved in native LightGBM format.

        Args:
            path: Model file path.

        Returns:
            The loaded LightGBMLambdaRankModel instance.
        """
        import lightgbm as lgb
        import json

        instance = cls()
        instance.model = lgb.Booster(model_file=path)

        # Restore metadata; missing metadata is tolerated (feature_names_
        # simply stays None).
        meta_path = path + ".meta.json"
        try:
            with open(meta_path, "r") as f:
                meta = json.load(f)
            instance.feature_names_ = meta.get("feature_names")
            instance.params = meta.get("params", instance.params)
            instance.n_estimators = meta.get("n_estimators", instance.n_estimators)
            instance.early_stopping_rounds = meta.get(
                "early_stopping_rounds", instance.early_stopping_rounds
            )
        except FileNotFoundError:
            pass

        return instance

    @staticmethod
    def prepare_group_from_dates(
        df: pl.DataFrame,
        date_col: str = "trade_date",
    ) -> np.ndarray:
        """Build a group array from a date column.

        Groups rows by date, treating each date as one query.

        Args:
            df: DataFrame containing the date column.
            date_col: Date column name, default "trade_date".

        Returns:
            Group array with the sample count per date.

        Example:
            >>> df = pl.DataFrame({
            ...     "trade_date": ["20240101", "20240101", "20240102", "20240102", "20240102"],
            ...     "feature": [1, 2, 3, 4, 5]
            ... })
            >>> group = LightGBMLambdaRankModel.prepare_group_from_dates(df)
            >>> print(group)  # array([2, 3])
        """
        # Count samples per date; maintain_order keeps the row order in
        # sync with the dataset fed to lgb.Dataset.
        # NOTE(review): pl.count() is deprecated in newer Polars in favor of
        # pl.len() — kept as-is to match the project's pinned version; confirm.
        group_counts = df.group_by(date_col, maintain_order=True).agg(
            pl.count().alias("count")
        )
        return group_counts["count"].to_numpy()

    @staticmethod
    def convert_to_quantile_labels(
        df: pl.DataFrame,
        label_col: str,
        date_col: str = "trade_date",
        n_quantiles: int = 20,
        new_col_name: Optional[str] = None,
    ) -> pl.DataFrame:
        """Convert a continuous label into per-date quantile labels.

        Quantile buckets are computed independently per date, yielding
        labels 0, 1, ..., n_quantiles-1. A larger label means a larger
        original value (ranked earlier).

        Args:
            df: Input DataFrame.
            label_col: Raw label column (e.g. "future_return_5").
            date_col: Date column name, default "trade_date".
            n_quantiles: Number of quantile buckets, default 20.
            new_col_name: Name of the new column; defaults to
                ``f"{label_col}_rank"``.

        Returns:
            DataFrame with the quantile-label column added.

        Example:
            >>> df = pl.DataFrame({
            ...     "trade_date": ["20240101"] * 5 + ["20240102"] * 5,
            ...     "future_return_5": [0.01, 0.02, 0.03, 0.04, 0.05,
            ...                         0.02, 0.03, 0.04, 0.05, 0.06]
            ... })
            >>> df = LightGBMLambdaRankModel.convert_to_quantile_labels(
            ...     df, "future_return_5", n_quantiles=5
            ... )
            >>> print(df["future_return_5_rank"])  # [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
        """
        if new_col_name is None:
            new_col_name = f"{label_col}_rank"

        # qcut per date; qcut yields a Categorical, so to_physical()
        # converts it to the underlying integer codes (0, 1, 2, ...).
        return df.with_columns(
            pl.col(label_col)
            .qcut(n_quantiles)
            .over(date_col)
            .to_physical()
            .cast(pl.Int64)
            .alias(new_col_name)
        )

    def evaluate_ndcg(
        self,
        X: pl.DataFrame,
        y: pl.Series,
        group: np.ndarray,
        k: Optional[int] = None,
    ) -> float:
        """Evaluate mean NDCG over queries.

        Args:
            X: Feature matrix.
            y: True labels.
            group: Group (query-size) array.
            k: The k in NDCG@k; ``None`` uses all positions.

        Returns:
            Mean NDCG over queries with >= 2 samples and non-constant
            labels; 0.0 if no query qualifies.
        """
        from sklearn.metrics import ndcg_score

        y_pred = self.predict(X)

        # Perf fix: convert the Series once; the original called
        # y.to_numpy() inside the per-group loop.
        y_true = y.to_numpy() if isinstance(y, pl.Series) else np.asarray(y)
        if not isinstance(group, np.ndarray):
            group = np.array(group)

        ndcg_scores = []
        start_idx = 0
        for group_size in group:
            end_idx = start_idx + int(group_size)
            true_slice = y_true[start_idx:end_idx]
            score_slice = y_pred[start_idx:end_idx]
            start_idx = end_idx

            # NDCG needs at least 2 samples per query.
            if len(true_slice) <= 1:
                continue
            try:
                ndcg_scores.append(ndcg_score([true_slice], [score_slice], k=k))
            except ValueError:
                # All labels identical within the query: skip it.
                pass

        # Plain Python float (the original returned np.float64).
        return float(np.mean(ndcg_scores)) if ndcg_scores else 0.0
|
||||
@@ -29,7 +29,7 @@ class NullFiller(BaseProcessor):
|
||||
fill_value: 当 strategy="value" 时使用的填充值
|
||||
by_date: 是否按日期独立计算统计量(仅对 mean/median 有效)
|
||||
date_col: 日期列名
|
||||
exclude_cols: 不参与填充的列名列表
|
||||
feature_cols: 参与填充的特征列名列表
|
||||
stats_: 存储学习到的统计量(全局模式)
|
||||
"""
|
||||
|
||||
@@ -37,24 +37,24 @@ class NullFiller(BaseProcessor):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
feature_cols: List[str],
|
||||
strategy: Literal["zero", "mean", "median", "value"] = "zero",
|
||||
fill_value: Optional[float] = None,
|
||||
by_date: bool = True,
|
||||
date_col: str = "trade_date",
|
||||
exclude_cols: Optional[List[str]] = None,
|
||||
):
|
||||
"""初始化缺失值填充处理器
|
||||
|
||||
Args:
|
||||
feature_cols: 参与填充的特征列名列表
|
||||
strategy: 填充策略,默认 "zero"
|
||||
- "zero": 填充0
|
||||
- "mean": 填充均值
|
||||
- "median": 填充中值
|
||||
- "value": 填充指定数值(需配合 fill_value)
|
||||
fill_value: 当 strategy="value" 时的填充值,默认为 None
|
||||
by_date: 是否每天独立计算统计量,默认 False(全局统计量)
|
||||
by_date: 是否每天独立计算统计量,默认 True(截面统计量)
|
||||
date_col: 日期列名,默认 "trade_date"
|
||||
exclude_cols: 不参与填充的列名列表,默认为 ["ts_code", "trade_date"]
|
||||
|
||||
Raises:
|
||||
ValueError: 策略无效或 fill_value 未提供时
|
||||
@@ -67,12 +67,12 @@ class NullFiller(BaseProcessor):
|
||||
if strategy == "value" and fill_value is None:
|
||||
raise ValueError("当 strategy='value' 时,必须提供 fill_value")
|
||||
|
||||
self.feature_cols = feature_cols
|
||||
self.strategy = strategy
|
||||
self.fill_value = fill_value
|
||||
self.by_date = by_date
|
||||
self.date_col = date_col
|
||||
self.exclude_cols = exclude_cols or ["ts_code", "trade_date"]
|
||||
self.stats_: dict = {}
|
||||
self.stats_ = {}
|
||||
|
||||
def fit(self, X: pl.DataFrame) -> "NullFiller":
|
||||
"""学习统计量(仅在全局模式下)
|
||||
@@ -87,17 +87,12 @@ class NullFiller(BaseProcessor):
|
||||
self
|
||||
"""
|
||||
if not self.by_date and self.strategy in ("mean", "median"):
|
||||
numeric_cols = [
|
||||
c
|
||||
for c in X.columns
|
||||
if c not in self.exclude_cols and X[c].dtype.is_numeric()
|
||||
]
|
||||
|
||||
for col in numeric_cols:
|
||||
if self.strategy == "mean":
|
||||
self.stats_[col] = X[col].mean()
|
||||
else: # median
|
||||
self.stats_[col] = X[col].median()
|
||||
for col in self.feature_cols:
|
||||
if col in X.columns and X[col].dtype.is_numeric():
|
||||
if self.strategy == "mean":
|
||||
self.stats_[col] = X[col].mean() or 0.0
|
||||
else: # median
|
||||
self.stats_[col] = X[col].median() or 0.0
|
||||
|
||||
return self
|
||||
|
||||
@@ -125,15 +120,9 @@ class NullFiller(BaseProcessor):
|
||||
|
||||
def _fill_with_zero(self, X: pl.DataFrame) -> pl.DataFrame:
|
||||
"""使用0填充缺失值"""
|
||||
numeric_cols = [
|
||||
c
|
||||
for c in X.columns
|
||||
if c not in self.exclude_cols and X[c].dtype.is_numeric()
|
||||
]
|
||||
|
||||
expressions = []
|
||||
for col in X.columns:
|
||||
if col in numeric_cols:
|
||||
if col in self.feature_cols and X[col].dtype.is_numeric():
|
||||
expr = pl.col(col).fill_null(0).alias(col)
|
||||
expressions.append(expr)
|
||||
else:
|
||||
@@ -143,15 +132,9 @@ class NullFiller(BaseProcessor):
|
||||
|
||||
def _fill_with_value(self, X: pl.DataFrame) -> pl.DataFrame:
|
||||
"""使用指定值填充缺失值"""
|
||||
numeric_cols = [
|
||||
c
|
||||
for c in X.columns
|
||||
if c not in self.exclude_cols and X[c].dtype.is_numeric()
|
||||
]
|
||||
|
||||
expressions = []
|
||||
for col in X.columns:
|
||||
if col in numeric_cols:
|
||||
if col in self.feature_cols and X[col].dtype.is_numeric():
|
||||
expr = pl.col(col).fill_null(self.fill_value).alias(col)
|
||||
expressions.append(expr)
|
||||
else:
|
||||
@@ -174,15 +157,16 @@ class NullFiller(BaseProcessor):
|
||||
|
||||
def _fill_by_date(self, X: pl.DataFrame) -> pl.DataFrame:
|
||||
"""使用每天截面统计量填充"""
|
||||
numeric_cols = [
|
||||
c
|
||||
for c in X.columns
|
||||
if c not in self.exclude_cols and X[c].dtype.is_numeric()
|
||||
# 确定需要处理的数值列
|
||||
target_cols = [
|
||||
col
|
||||
for col in self.feature_cols
|
||||
if col in X.columns and X[col].dtype.is_numeric()
|
||||
]
|
||||
|
||||
# 计算每天的统计量
|
||||
stat_exprs = []
|
||||
for col in numeric_cols:
|
||||
for col in target_cols:
|
||||
if self.strategy == "mean":
|
||||
stat_exprs.append(
|
||||
pl.col(col).mean().over(self.date_col).alias(f"{col}_stat")
|
||||
@@ -198,7 +182,7 @@ class NullFiller(BaseProcessor):
|
||||
# 使用统计量填充缺失值
|
||||
fill_exprs = []
|
||||
for col in X.columns:
|
||||
if col in numeric_cols:
|
||||
if col in target_cols:
|
||||
expr = pl.col(col).fill_null(pl.col(f"{col}_stat")).alias(col)
|
||||
fill_exprs.append(expr)
|
||||
else:
|
||||
@@ -219,22 +203,22 @@ class StandardScaler(BaseProcessor):
|
||||
适用于需要全局统计量的场景。
|
||||
|
||||
Attributes:
|
||||
exclude_cols: 不参与标准化的列名列表
|
||||
feature_cols: 参与标准化的特征列名列表
|
||||
mean_: 学习到的均值字典 {列名: 均值}
|
||||
std_: 学习到的标准差字典 {列名: 标准差}
|
||||
"""
|
||||
|
||||
name = "standard_scaler"
|
||||
|
||||
def __init__(self, exclude_cols: Optional[List[str]] = None):
|
||||
def __init__(self, feature_cols: List[str]):
|
||||
"""初始化标准化处理器
|
||||
|
||||
Args:
|
||||
exclude_cols: 不参与标准化的列名列表,默认为 ["ts_code", "trade_date"]
|
||||
feature_cols: 参与标准化的特征列名列表
|
||||
"""
|
||||
self.exclude_cols = exclude_cols or ["ts_code", "trade_date"]
|
||||
self.mean_: dict = {}
|
||||
self.std_: dict = {}
|
||||
self.feature_cols = feature_cols
|
||||
self.mean_ = {}
|
||||
self.std_ = {}
|
||||
|
||||
def fit(self, X: pl.DataFrame) -> "StandardScaler":
|
||||
"""计算均值和标准差(仅在训练集上)
|
||||
@@ -245,15 +229,13 @@ class StandardScaler(BaseProcessor):
|
||||
Returns:
|
||||
self
|
||||
"""
|
||||
numeric_cols = [
|
||||
c
|
||||
for c in X.columns
|
||||
if c not in self.exclude_cols and X[c].dtype.is_numeric()
|
||||
]
|
||||
|
||||
for col in numeric_cols:
|
||||
self.mean_[col] = X[col].mean()
|
||||
self.std_[col] = X[col].std()
|
||||
for col in self.feature_cols:
|
||||
if col in X.columns and X[col].dtype.is_numeric():
|
||||
col_mean = X[col].mean()
|
||||
col_std = X[col].std()
|
||||
if col_mean is not None and col_std is not None:
|
||||
self.mean_[col] = col_mean
|
||||
self.std_[col] = col_std
|
||||
|
||||
return self
|
||||
|
||||
@@ -291,7 +273,7 @@ class CrossSectionalStandardScaler(BaseProcessor):
|
||||
- 公式:z = (x - mean_today) / std_today
|
||||
|
||||
Attributes:
|
||||
exclude_cols: 不参与标准化的列名列表
|
||||
feature_cols: 参与标准化的特征列名列表
|
||||
date_col: 日期列名
|
||||
"""
|
||||
|
||||
@@ -299,16 +281,16 @@ class CrossSectionalStandardScaler(BaseProcessor):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
exclude_cols: Optional[List[str]] = None,
|
||||
feature_cols: List[str],
|
||||
date_col: str = "trade_date",
|
||||
):
|
||||
"""初始化截面标准化处理器
|
||||
|
||||
Args:
|
||||
exclude_cols: 不参与标准化的列名列表,默认为 ["ts_code", "trade_date"]
|
||||
feature_cols: 参与标准化的特征列名列表
|
||||
date_col: 日期列名
|
||||
"""
|
||||
self.exclude_cols = exclude_cols or ["ts_code", "trade_date"]
|
||||
self.feature_cols = feature_cols
|
||||
self.date_col = date_col
|
||||
|
||||
def transform(self, X: pl.DataFrame) -> pl.DataFrame:
|
||||
@@ -323,16 +305,10 @@ class CrossSectionalStandardScaler(BaseProcessor):
|
||||
Returns:
|
||||
标准化后的数据
|
||||
"""
|
||||
numeric_cols = [
|
||||
c
|
||||
for c in X.columns
|
||||
if c not in self.exclude_cols and X[c].dtype.is_numeric()
|
||||
]
|
||||
|
||||
# 构建表达式列表
|
||||
expressions = []
|
||||
for col in X.columns:
|
||||
if col in numeric_cols:
|
||||
if col in self.feature_cols and X[col].dtype.is_numeric():
|
||||
# 截面标准化:每天独立计算均值和标准差
|
||||
# 避免除以0,当std为0时设为1
|
||||
expr = (
|
||||
@@ -355,6 +331,7 @@ class Winsorizer(BaseProcessor):
|
||||
也可以截面截断(每天独立处理)。
|
||||
|
||||
Attributes:
|
||||
feature_cols: 参与缩尾的特征列名列表
|
||||
lower: 下分位数(如0.01表示1%分位数)
|
||||
upper: 上分位数(如0.99表示99%分位数)
|
||||
by_date: True=每天独立缩尾, False=全局缩尾
|
||||
@@ -366,6 +343,7 @@ class Winsorizer(BaseProcessor):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
feature_cols: List[str],
|
||||
lower: float = 0.01,
|
||||
upper: float = 0.99,
|
||||
by_date: bool = False,
|
||||
@@ -374,6 +352,7 @@ class Winsorizer(BaseProcessor):
|
||||
"""初始化缩尾处理器
|
||||
|
||||
Args:
|
||||
feature_cols: 参与缩尾的特征列名列表
|
||||
lower: 下分位数,默认0.01
|
||||
upper: 上分位数,默认0.99
|
||||
by_date: 每天独立缩尾,默认False(全局缩尾)
|
||||
@@ -387,11 +366,12 @@ class Winsorizer(BaseProcessor):
|
||||
f"lower ({lower}) 必须小于 upper ({upper}),且都在 [0, 1] 范围内"
|
||||
)
|
||||
|
||||
self.feature_cols = feature_cols
|
||||
self.lower = lower
|
||||
self.upper = upper
|
||||
self.by_date = by_date
|
||||
self.date_col = date_col
|
||||
self.bounds_: dict = {}
|
||||
self.bounds_ = {}
|
||||
|
||||
def fit(self, X: pl.DataFrame) -> "Winsorizer":
|
||||
"""学习分位数边界(仅在全局模式下)
|
||||
@@ -403,12 +383,12 @@ class Winsorizer(BaseProcessor):
|
||||
self
|
||||
"""
|
||||
if not self.by_date:
|
||||
numeric_cols = [c for c in X.columns if X[c].dtype.is_numeric()]
|
||||
for col in numeric_cols:
|
||||
self.bounds_[col] = {
|
||||
"lower": X[col].quantile(self.lower),
|
||||
"upper": X[col].quantile(self.upper),
|
||||
}
|
||||
for col in self.feature_cols:
|
||||
if col in X.columns and X[col].dtype.is_numeric():
|
||||
self.bounds_[col] = {
|
||||
"lower": X[col].quantile(self.lower),
|
||||
"upper": X[col].quantile(self.upper),
|
||||
}
|
||||
return self
|
||||
|
||||
def transform(self, X: pl.DataFrame) -> pl.DataFrame:
|
||||
@@ -440,16 +420,21 @@ class Winsorizer(BaseProcessor):
|
||||
|
||||
def _transform_by_date(self, X: pl.DataFrame) -> pl.DataFrame:
|
||||
"""每日独立缩尾"""
|
||||
numeric_cols = [c for c in X.columns if X[c].dtype.is_numeric()]
|
||||
# 确定需要处理的数值列
|
||||
target_cols = [
|
||||
col
|
||||
for col in self.feature_cols
|
||||
if col in X.columns and X[col].dtype.is_numeric()
|
||||
]
|
||||
|
||||
# 先计算每天的分位数
|
||||
lower_exprs = [
|
||||
pl.col(col).quantile(self.lower).over(self.date_col).alias(f"{col}_lower")
|
||||
for col in numeric_cols
|
||||
for col in target_cols
|
||||
]
|
||||
upper_exprs = [
|
||||
pl.col(col).quantile(self.upper).over(self.date_col).alias(f"{col}_upper")
|
||||
for col in numeric_cols
|
||||
for col in target_cols
|
||||
]
|
||||
|
||||
# 添加分位数列
|
||||
@@ -458,7 +443,7 @@ class Winsorizer(BaseProcessor):
|
||||
# 执行缩尾
|
||||
clip_exprs = []
|
||||
for col in X.columns:
|
||||
if col in numeric_cols:
|
||||
if col in target_cols:
|
||||
clipped = (
|
||||
pl.col(col)
|
||||
.clip(pl.col(f"{col}_lower"), pl.col(f"{col}_upper"))
|
||||
|
||||
Reference in New Issue
Block a user