feat(training): 添加 LightGBM LambdaRank 排序学习功能
新增基于 LambdaRank 的排序学习模型,用于股票排序预测任务: - 实现 LightGBMLambdaRankModel 模型类,支持分位数标签转换 - 提供完整的训练流程和 NDCG 评估指标 - 添加实验 Notebook 演示排序学习全流程
This commit is contained in:
@@ -15,6 +15,8 @@ A股量化投资框架 - Python 项目,用于量化股票投资分析。
|
|||||||
|
|
||||||
**⚠️ 重要:本项目强制使用 uv 作为 Python 包管理器和运行工具。禁止直接使用 `python` 或 `pip` 命令。**
|
**⚠️ 重要:本项目强制使用 uv 作为 Python 包管理器和运行工具。禁止直接使用 `python` 或 `pip` 命令。**
|
||||||
|
|
||||||
|
**测试规则:** 当修改或查看 `tests/` 目录下的代码时,必须使用 pytest 命令进行测试验证。
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 安装依赖(必须使用 uv)
|
# 安装依赖(必须使用 uv)
|
||||||
uv pip install -e .
|
uv pip install -e .
|
||||||
|
|||||||
1607
src/experiment/learn_to_rank.ipynb
Normal file
1607
src/experiment/learn_to_rank.ipynb
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -4,5 +4,6 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from src.training.components.models.lightgbm import LightGBMModel
|
from src.training.components.models.lightgbm import LightGBMModel
|
||||||
|
from src.training.components.models.lightgbm_lambdarank import LightGBMLambdaRankModel
|
||||||
|
|
||||||
__all__ = ["LightGBMModel"]
|
__all__ = ["LightGBMModel", "LightGBMLambdaRankModel"]
|
||||||
|
|||||||
406
src/training/components/models/lightgbm_lambdarank.py
Normal file
406
src/training/components/models/lightgbm_lambdarank.py
Normal file
@@ -0,0 +1,406 @@
|
|||||||
|
"""LightGBM LambdaRank 排序学习模型
|
||||||
|
|
||||||
|
提供 LightGBM LambdaRank 模型的实现,支持学习排序任务。
|
||||||
|
适用于将股票按未来收益率进行排序的场景。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Optional, List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
from src.training.components.base import BaseModel
|
||||||
|
from src.training.registry import register_model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model("lightgbm_lambdarank")
class LightGBMLambdaRankModel(BaseModel):
    """LightGBM LambdaRank learning-to-rank model.

    Wraps LightGBM's ``lambdarank`` objective for ranking tasks — e.g.
    ranking stocks by future return. Continuous labels are expected to be
    converted to per-query quantile grades (0, 1, 2, ...) before training;
    see :meth:`convert_to_quantile_labels`.

    Attributes:
        name: Model name, ``"lightgbm_lambdarank"``.
        params: LightGBM parameter dict passed to ``lgb.train``.
        model: Trained LightGBM Booster; ``None`` before :meth:`fit`.
        feature_names_: Feature names captured from ``X`` at fit time.
    """

    name = "lightgbm_lambdarank"

    def __init__(
        self,
        params: Optional[dict] = None,
        learning_rate: float = 0.05,
        num_leaves: int = 31,
        n_estimators: int = 100,
        min_data_in_leaf: int = 20,
        ndcg_at: Optional[List[int]] = None,
        early_stopping_rounds: int = 50,
        **kwargs,
    ):
        """Initialize the LambdaRank model.

        Parameters may be supplied in two ways:

        1. A complete ``params`` dict (preferred).
        2. Individual keyword arguments (backward compatible).

        Args:
            params: LightGBM parameter dict; used as-is when provided.
            learning_rate: Learning rate, default 0.05.
            num_leaves: Number of leaves, default 31.
            n_estimators: Number of boosting rounds, default 100.
            min_data_in_leaf: Minimum samples per leaf, default 20.
            ndcg_at: k values for NDCG evaluation, default [1, 5, 10, 20].
            early_stopping_rounds: Early-stopping patience, default 50.
            **kwargs: Extra LightGBM parameters (only used when ``params``
                is not given, matching the original behavior).

        Examples:
            >>> model = LightGBMLambdaRankModel(params={
            ...     "objective": "lambdarank",
            ...     "metric": "ndcg",
            ...     "ndcg_at": [1, 5, 10, 20],
            ...     "num_leaves": 31,
            ...     "learning_rate": 0.05,
            ...     "n_estimators": 1000,
            ... })
        """
        if ndcg_at is None:
            ndcg_at = [1, 5, 10, 20]

        if params is not None:
            # Way 1: use the provided dict. Copy it so the caller's dict
            # is never mutated by the setdefault/pop calls below.
            self.params = dict(params)
            self.params.setdefault("objective", "lambdarank")
            self.params.setdefault("metric", "ndcg")
            self.params.setdefault("verbose", -1)
            # n_estimators / early_stopping_rounds are lgb.train() options,
            # not Booster params, so pop them out of the param dict.
            self.n_estimators = self.params.pop("n_estimators", n_estimators)
            self.early_stopping_rounds = self.params.pop(
                "early_stopping_rounds", early_stopping_rounds
            )
        else:
            # Way 2: build params from the individual keyword arguments.
            self.params = {
                "objective": "lambdarank",
                "metric": "ndcg",
                "ndcg_at": ndcg_at,
                "num_leaves": num_leaves,
                "learning_rate": learning_rate,
                "min_data_in_leaf": min_data_in_leaf,
                "verbose": -1,
                **kwargs,
            }
            self.n_estimators = n_estimators
            self.early_stopping_rounds = early_stopping_rounds

        self.model = None
        self.feature_names_: Optional[list] = None

    def fit(
        self,
        X: pl.DataFrame,
        y: pl.Series,
        group: Optional[np.ndarray] = None,
        eval_set: Optional[tuple] = None,
    ) -> "LightGBMLambdaRankModel":
        """Train the ranking model.

        Args:
            X: Feature matrix (Polars DataFrame).
            y: Target labels (Polars Series) — quantile grades (0, 1, 2, ...).
            group: Per-query sample counts, e.g. [10, 15, 20] means the
                first query has 10 samples, the second 15, the third 20.
                When None, all samples are treated as a single query.
            eval_set: Optional (X_val, y_val, group_val) tuple; enables
                early stopping on the validation set.

        Returns:
            self (supports chaining).

        Raises:
            ImportError: lightgbm is not installed.
            ValueError: ``group`` does not sum to the number of samples.
        """
        try:
            import lightgbm as lgb
        except ImportError as err:
            # Chain the original error so the root cause stays visible.
            raise ImportError(
                "使用 LightGBMLambdaRankModel 需要安装 lightgbm: pip install lightgbm"
            ) from err

        # Remember feature names for feature_importance() / save().
        self.feature_names_ = X.columns

        X_np = X.to_numpy()
        y_np = y.to_numpy()

        # Default: one query containing every sample.
        if group is None:
            group = np.array([len(y_np)])
        elif not isinstance(group, np.ndarray):
            group = np.array(group)
        if group.sum() != len(y_np):
            raise ValueError(
                f"group 数组的和 ({group.sum()}) 必须等于样本数 ({len(y_np)})"
            )

        train_data = lgb.Dataset(X_np, label=y_np, group=group)

        valid_sets = [train_data]
        callbacks = []
        if eval_set is not None:
            X_val, y_val, group_val = eval_set
            X_val_np = X_val.to_numpy() if isinstance(X_val, pl.DataFrame) else X_val
            y_val_np = y_val.to_numpy() if isinstance(y_val, pl.Series) else y_val

            if group_val is None:
                group_val = np.array([len(y_val_np)])
            elif not isinstance(group_val, np.ndarray):
                group_val = np.array(group_val)

            val_data = lgb.Dataset(X_val_np, label=y_val_np, group=group_val)
            valid_sets.append(val_data)
            # BUGFIX: only enable early stopping when a validation set is
            # present — LightGBM's early_stopping callback requires at least
            # one evaluation dataset besides the training set and raises
            # otherwise, so the unconditional callback broke fit() calls
            # without eval_set.
            callbacks.append(
                lgb.early_stopping(stopping_rounds=self.early_stopping_rounds)
            )

        self.model = lgb.train(
            self.params,
            train_data,
            num_boost_round=self.n_estimators,
            valid_sets=valid_sets,
            callbacks=callbacks,
        )

        return self

    def predict(self, X: pl.DataFrame) -> np.ndarray:
        """Predict ranking scores.

        Args:
            X: Feature matrix (Polars DataFrame).

        Returns:
            Predicted scores (numpy ndarray); higher means ranked earlier.

        Raises:
            RuntimeError: the model has not been trained yet.
        """
        if self.model is None:
            raise RuntimeError("模型尚未训练,请先调用 fit()")

        return self.model.predict(X.to_numpy())

    def feature_importance(self) -> Optional[pd.Series]:
        """Return gain-based feature importances.

        Returns:
            Importance Series indexed by feature name, or None when the
            model has not been trained.
        """
        if self.model is None or self.feature_names_ is None:
            return None

        importance = self.model.feature_importance(importance_type="gain")
        return pd.Series(importance, index=self.feature_names_)

    def save(self, path: str) -> None:
        """Save the model in LightGBM's native format.

        The native format does not depend on pickle and can be loaded in a
        different environment. Feature names and training options are
        written to a sidecar ``<path>.meta.json`` file.

        Args:
            path: Destination file path.

        Raises:
            RuntimeError: the model has not been trained yet.
        """
        if self.model is None:
            raise RuntimeError("模型尚未训练,无法保存")

        self.model.save_model(path)

        # Persist feature names and other metadata alongside the model.
        import json

        meta_path = path + ".meta.json"
        with open(meta_path, "w") as f:
            json.dump(
                {
                    "feature_names": self.feature_names_,
                    "params": self.params,
                    "n_estimators": self.n_estimators,
                    "early_stopping_rounds": self.early_stopping_rounds,
                },
                f,
            )

    @classmethod
    def load(cls, path: str) -> "LightGBMLambdaRankModel":
        """Load a model saved in LightGBM's native format.

        Args:
            path: Model file path.

        Returns:
            A restored LightGBMLambdaRankModel instance.
        """
        import lightgbm as lgb
        import json

        instance = cls()
        instance.model = lgb.Booster(model_file=path)

        # Restore metadata written by save(); missing file is tolerated
        # (feature_names_ stays None in that case).
        meta_path = path + ".meta.json"
        try:
            with open(meta_path, "r") as f:
                meta = json.load(f)
            instance.feature_names_ = meta.get("feature_names")
            instance.params = meta.get("params", instance.params)
            instance.n_estimators = meta.get("n_estimators", instance.n_estimators)
            instance.early_stopping_rounds = meta.get(
                "early_stopping_rounds", instance.early_stopping_rounds
            )
        except FileNotFoundError:
            pass

        return instance

    @staticmethod
    def prepare_group_from_dates(
        df: pl.DataFrame,
        date_col: str = "trade_date",
    ) -> np.ndarray:
        """Build a group array from a date column.

        Each distinct date becomes one query; the returned array holds the
        sample count per date, in first-appearance order.

        Args:
            df: DataFrame containing the date column.
            date_col: Date column name, default "trade_date".

        Returns:
            Group array of per-date sample counts.

        Example:
            >>> df = pl.DataFrame({
            ...     "trade_date": ["20240101", "20240101", "20240102", "20240102", "20240102"],
            ...     "feature": [1, 2, 3, 4, 5]
            ... })
            >>> group = LightGBMLambdaRankModel.prepare_group_from_dates(df)
            >>> print(group)  # array([2, 3])
        """
        # maintain_order=True keeps the groups aligned with row order, which
        # is required for LightGBM's group semantics.
        # NOTE(review): pl.count() is deprecated in newer Polars in favor of
        # pl.len() — keep pl.count() here to match the project's pinned version.
        group_counts = df.group_by(date_col, maintain_order=True).agg(
            pl.count().alias("count")
        )
        return group_counts["count"].to_numpy()

    @staticmethod
    def convert_to_quantile_labels(
        df: pl.DataFrame,
        label_col: str,
        date_col: str = "trade_date",
        n_quantiles: int = 20,
        new_col_name: Optional[str] = None,
    ) -> pl.DataFrame:
        """Convert a continuous label into per-date quantile grades.

        Quantiles are computed independently for every date, producing
        labels 0, 1, ..., n_quantiles-1; larger means a larger original
        value (ranked earlier).

        Args:
            df: Input DataFrame.
            label_col: Continuous label column (e.g. "future_return_5").
            date_col: Date column name, default "trade_date".
            n_quantiles: Number of quantile bins, default 20.
            new_col_name: Output column name, defaults to ``<label_col>_rank``.

        Returns:
            DataFrame with the quantile-label column appended.

        Example:
            >>> df = pl.DataFrame({
            ...     "trade_date": ["20240101"] * 5 + ["20240102"] * 5,
            ...     "future_return_5": [0.01, 0.02, 0.03, 0.04, 0.05,
            ...                         0.02, 0.03, 0.04, 0.05, 0.06]
            ... })
            >>> df = LightGBMLambdaRankModel.convert_to_quantile_labels(
            ...     df, "future_return_5", n_quantiles=5
            ... )
            >>> print(df["future_return_5_rank"])  # [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
        """
        if new_col_name is None:
            new_col_name = f"{label_col}_rank"

        # qcut yields a Categorical; to_physical() exposes the underlying
        # integer codes (0, 1, 2, ...), computed per date via over().
        return df.with_columns(
            pl.col(label_col)
            .qcut(n_quantiles)
            .over(date_col)
            .to_physical()
            .cast(pl.Int64)
            .alias(new_col_name)
        )

    def evaluate_ndcg(
        self,
        X: pl.DataFrame,
        y: pl.Series,
        group: np.ndarray,
        k: Optional[int] = None,
    ) -> float:
        """Evaluate mean NDCG over all queries.

        Args:
            X: Feature matrix.
            y: True labels.
            group: Per-query sample counts.
            k: NDCG@k cutoff; None uses all positions.

        Returns:
            Mean NDCG score (0.0 when no query is scorable).
        """
        from sklearn.metrics import ndcg_score

        y_pred = self.predict(X)
        # PERF: convert the Series once — the original called y.to_numpy()
        # inside the loop, re-converting the whole Series per query.
        y_np = y.to_numpy()

        ndcg_scores = []
        start_idx = 0
        for group_size in group:
            end_idx = start_idx + group_size
            y_true = y_np[start_idx:end_idx]
            y_score = y_pred[start_idx:end_idx]
            start_idx = end_idx

            if len(y_true) > 1:  # NDCG needs at least 2 samples per query
                try:
                    ndcg_scores.append(ndcg_score([y_true], [y_score], k=k))
                except ValueError:
                    # All labels identical in this query — skip it.
                    pass

        return float(np.mean(ndcg_scores)) if ndcg_scores else 0.0
@@ -29,7 +29,7 @@ class NullFiller(BaseProcessor):
|
|||||||
fill_value: 当 strategy="value" 时使用的填充值
|
fill_value: 当 strategy="value" 时使用的填充值
|
||||||
by_date: 是否按日期独立计算统计量(仅对 mean/median 有效)
|
by_date: 是否按日期独立计算统计量(仅对 mean/median 有效)
|
||||||
date_col: 日期列名
|
date_col: 日期列名
|
||||||
exclude_cols: 不参与填充的列名列表
|
feature_cols: 参与填充的特征列名列表
|
||||||
stats_: 存储学习到的统计量(全局模式)
|
stats_: 存储学习到的统计量(全局模式)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -37,24 +37,24 @@ class NullFiller(BaseProcessor):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
feature_cols: List[str],
|
||||||
strategy: Literal["zero", "mean", "median", "value"] = "zero",
|
strategy: Literal["zero", "mean", "median", "value"] = "zero",
|
||||||
fill_value: Optional[float] = None,
|
fill_value: Optional[float] = None,
|
||||||
by_date: bool = True,
|
by_date: bool = True,
|
||||||
date_col: str = "trade_date",
|
date_col: str = "trade_date",
|
||||||
exclude_cols: Optional[List[str]] = None,
|
|
||||||
):
|
):
|
||||||
"""初始化缺失值填充处理器
|
"""初始化缺失值填充处理器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
feature_cols: 参与填充的特征列名列表
|
||||||
strategy: 填充策略,默认 "zero"
|
strategy: 填充策略,默认 "zero"
|
||||||
- "zero": 填充0
|
- "zero": 填充0
|
||||||
- "mean": 填充均值
|
- "mean": 填充均值
|
||||||
- "median": 填充中值
|
- "median": 填充中值
|
||||||
- "value": 填充指定数值(需配合 fill_value)
|
- "value": 填充指定数值(需配合 fill_value)
|
||||||
fill_value: 当 strategy="value" 时的填充值,默认为 None
|
fill_value: 当 strategy="value" 时的填充值,默认为 None
|
||||||
by_date: 是否每天独立计算统计量,默认 False(全局统计量)
|
by_date: 是否每天独立计算统计量,默认 True(截面统计量)
|
||||||
date_col: 日期列名,默认 "trade_date"
|
date_col: 日期列名,默认 "trade_date"
|
||||||
exclude_cols: 不参与填充的列名列表,默认为 ["ts_code", "trade_date"]
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: 策略无效或 fill_value 未提供时
|
ValueError: 策略无效或 fill_value 未提供时
|
||||||
@@ -67,12 +67,12 @@ class NullFiller(BaseProcessor):
|
|||||||
if strategy == "value" and fill_value is None:
|
if strategy == "value" and fill_value is None:
|
||||||
raise ValueError("当 strategy='value' 时,必须提供 fill_value")
|
raise ValueError("当 strategy='value' 时,必须提供 fill_value")
|
||||||
|
|
||||||
|
self.feature_cols = feature_cols
|
||||||
self.strategy = strategy
|
self.strategy = strategy
|
||||||
self.fill_value = fill_value
|
self.fill_value = fill_value
|
||||||
self.by_date = by_date
|
self.by_date = by_date
|
||||||
self.date_col = date_col
|
self.date_col = date_col
|
||||||
self.exclude_cols = exclude_cols or ["ts_code", "trade_date"]
|
self.stats_ = {}
|
||||||
self.stats_: dict = {}
|
|
||||||
|
|
||||||
def fit(self, X: pl.DataFrame) -> "NullFiller":
|
def fit(self, X: pl.DataFrame) -> "NullFiller":
|
||||||
"""学习统计量(仅在全局模式下)
|
"""学习统计量(仅在全局模式下)
|
||||||
@@ -87,17 +87,12 @@ class NullFiller(BaseProcessor):
|
|||||||
self
|
self
|
||||||
"""
|
"""
|
||||||
if not self.by_date and self.strategy in ("mean", "median"):
|
if not self.by_date and self.strategy in ("mean", "median"):
|
||||||
numeric_cols = [
|
for col in self.feature_cols:
|
||||||
c
|
if col in X.columns and X[col].dtype.is_numeric():
|
||||||
for c in X.columns
|
|
||||||
if c not in self.exclude_cols and X[c].dtype.is_numeric()
|
|
||||||
]
|
|
||||||
|
|
||||||
for col in numeric_cols:
|
|
||||||
if self.strategy == "mean":
|
if self.strategy == "mean":
|
||||||
self.stats_[col] = X[col].mean()
|
self.stats_[col] = X[col].mean() or 0.0
|
||||||
else: # median
|
else: # median
|
||||||
self.stats_[col] = X[col].median()
|
self.stats_[col] = X[col].median() or 0.0
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@@ -125,15 +120,9 @@ class NullFiller(BaseProcessor):
|
|||||||
|
|
||||||
def _fill_with_zero(self, X: pl.DataFrame) -> pl.DataFrame:
|
def _fill_with_zero(self, X: pl.DataFrame) -> pl.DataFrame:
|
||||||
"""使用0填充缺失值"""
|
"""使用0填充缺失值"""
|
||||||
numeric_cols = [
|
|
||||||
c
|
|
||||||
for c in X.columns
|
|
||||||
if c not in self.exclude_cols and X[c].dtype.is_numeric()
|
|
||||||
]
|
|
||||||
|
|
||||||
expressions = []
|
expressions = []
|
||||||
for col in X.columns:
|
for col in X.columns:
|
||||||
if col in numeric_cols:
|
if col in self.feature_cols and X[col].dtype.is_numeric():
|
||||||
expr = pl.col(col).fill_null(0).alias(col)
|
expr = pl.col(col).fill_null(0).alias(col)
|
||||||
expressions.append(expr)
|
expressions.append(expr)
|
||||||
else:
|
else:
|
||||||
@@ -143,15 +132,9 @@ class NullFiller(BaseProcessor):
|
|||||||
|
|
||||||
def _fill_with_value(self, X: pl.DataFrame) -> pl.DataFrame:
|
def _fill_with_value(self, X: pl.DataFrame) -> pl.DataFrame:
|
||||||
"""使用指定值填充缺失值"""
|
"""使用指定值填充缺失值"""
|
||||||
numeric_cols = [
|
|
||||||
c
|
|
||||||
for c in X.columns
|
|
||||||
if c not in self.exclude_cols and X[c].dtype.is_numeric()
|
|
||||||
]
|
|
||||||
|
|
||||||
expressions = []
|
expressions = []
|
||||||
for col in X.columns:
|
for col in X.columns:
|
||||||
if col in numeric_cols:
|
if col in self.feature_cols and X[col].dtype.is_numeric():
|
||||||
expr = pl.col(col).fill_null(self.fill_value).alias(col)
|
expr = pl.col(col).fill_null(self.fill_value).alias(col)
|
||||||
expressions.append(expr)
|
expressions.append(expr)
|
||||||
else:
|
else:
|
||||||
@@ -174,15 +157,16 @@ class NullFiller(BaseProcessor):
|
|||||||
|
|
||||||
def _fill_by_date(self, X: pl.DataFrame) -> pl.DataFrame:
|
def _fill_by_date(self, X: pl.DataFrame) -> pl.DataFrame:
|
||||||
"""使用每天截面统计量填充"""
|
"""使用每天截面统计量填充"""
|
||||||
numeric_cols = [
|
# 确定需要处理的数值列
|
||||||
c
|
target_cols = [
|
||||||
for c in X.columns
|
col
|
||||||
if c not in self.exclude_cols and X[c].dtype.is_numeric()
|
for col in self.feature_cols
|
||||||
|
if col in X.columns and X[col].dtype.is_numeric()
|
||||||
]
|
]
|
||||||
|
|
||||||
# 计算每天的统计量
|
# 计算每天的统计量
|
||||||
stat_exprs = []
|
stat_exprs = []
|
||||||
for col in numeric_cols:
|
for col in target_cols:
|
||||||
if self.strategy == "mean":
|
if self.strategy == "mean":
|
||||||
stat_exprs.append(
|
stat_exprs.append(
|
||||||
pl.col(col).mean().over(self.date_col).alias(f"{col}_stat")
|
pl.col(col).mean().over(self.date_col).alias(f"{col}_stat")
|
||||||
@@ -198,7 +182,7 @@ class NullFiller(BaseProcessor):
|
|||||||
# 使用统计量填充缺失值
|
# 使用统计量填充缺失值
|
||||||
fill_exprs = []
|
fill_exprs = []
|
||||||
for col in X.columns:
|
for col in X.columns:
|
||||||
if col in numeric_cols:
|
if col in target_cols:
|
||||||
expr = pl.col(col).fill_null(pl.col(f"{col}_stat")).alias(col)
|
expr = pl.col(col).fill_null(pl.col(f"{col}_stat")).alias(col)
|
||||||
fill_exprs.append(expr)
|
fill_exprs.append(expr)
|
||||||
else:
|
else:
|
||||||
@@ -219,22 +203,22 @@ class StandardScaler(BaseProcessor):
|
|||||||
适用于需要全局统计量的场景。
|
适用于需要全局统计量的场景。
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
exclude_cols: 不参与标准化的列名列表
|
feature_cols: 参与标准化的特征列名列表
|
||||||
mean_: 学习到的均值字典 {列名: 均值}
|
mean_: 学习到的均值字典 {列名: 均值}
|
||||||
std_: 学习到的标准差字典 {列名: 标准差}
|
std_: 学习到的标准差字典 {列名: 标准差}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name = "standard_scaler"
|
name = "standard_scaler"
|
||||||
|
|
||||||
def __init__(self, exclude_cols: Optional[List[str]] = None):
|
def __init__(self, feature_cols: List[str]):
|
||||||
"""初始化标准化处理器
|
"""初始化标准化处理器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
exclude_cols: 不参与标准化的列名列表,默认为 ["ts_code", "trade_date"]
|
feature_cols: 参与标准化的特征列名列表
|
||||||
"""
|
"""
|
||||||
self.exclude_cols = exclude_cols or ["ts_code", "trade_date"]
|
self.feature_cols = feature_cols
|
||||||
self.mean_: dict = {}
|
self.mean_ = {}
|
||||||
self.std_: dict = {}
|
self.std_ = {}
|
||||||
|
|
||||||
def fit(self, X: pl.DataFrame) -> "StandardScaler":
|
def fit(self, X: pl.DataFrame) -> "StandardScaler":
|
||||||
"""计算均值和标准差(仅在训练集上)
|
"""计算均值和标准差(仅在训练集上)
|
||||||
@@ -245,15 +229,13 @@ class StandardScaler(BaseProcessor):
|
|||||||
Returns:
|
Returns:
|
||||||
self
|
self
|
||||||
"""
|
"""
|
||||||
numeric_cols = [
|
for col in self.feature_cols:
|
||||||
c
|
if col in X.columns and X[col].dtype.is_numeric():
|
||||||
for c in X.columns
|
col_mean = X[col].mean()
|
||||||
if c not in self.exclude_cols and X[c].dtype.is_numeric()
|
col_std = X[col].std()
|
||||||
]
|
if col_mean is not None and col_std is not None:
|
||||||
|
self.mean_[col] = col_mean
|
||||||
for col in numeric_cols:
|
self.std_[col] = col_std
|
||||||
self.mean_[col] = X[col].mean()
|
|
||||||
self.std_[col] = X[col].std()
|
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@@ -291,7 +273,7 @@ class CrossSectionalStandardScaler(BaseProcessor):
|
|||||||
- 公式:z = (x - mean_today) / std_today
|
- 公式:z = (x - mean_today) / std_today
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
exclude_cols: 不参与标准化的列名列表
|
feature_cols: 参与标准化的特征列名列表
|
||||||
date_col: 日期列名
|
date_col: 日期列名
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -299,16 +281,16 @@ class CrossSectionalStandardScaler(BaseProcessor):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
exclude_cols: Optional[List[str]] = None,
|
feature_cols: List[str],
|
||||||
date_col: str = "trade_date",
|
date_col: str = "trade_date",
|
||||||
):
|
):
|
||||||
"""初始化截面标准化处理器
|
"""初始化截面标准化处理器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
exclude_cols: 不参与标准化的列名列表,默认为 ["ts_code", "trade_date"]
|
feature_cols: 参与标准化的特征列名列表
|
||||||
date_col: 日期列名
|
date_col: 日期列名
|
||||||
"""
|
"""
|
||||||
self.exclude_cols = exclude_cols or ["ts_code", "trade_date"]
|
self.feature_cols = feature_cols
|
||||||
self.date_col = date_col
|
self.date_col = date_col
|
||||||
|
|
||||||
def transform(self, X: pl.DataFrame) -> pl.DataFrame:
|
def transform(self, X: pl.DataFrame) -> pl.DataFrame:
|
||||||
@@ -323,16 +305,10 @@ class CrossSectionalStandardScaler(BaseProcessor):
|
|||||||
Returns:
|
Returns:
|
||||||
标准化后的数据
|
标准化后的数据
|
||||||
"""
|
"""
|
||||||
numeric_cols = [
|
|
||||||
c
|
|
||||||
for c in X.columns
|
|
||||||
if c not in self.exclude_cols and X[c].dtype.is_numeric()
|
|
||||||
]
|
|
||||||
|
|
||||||
# 构建表达式列表
|
# 构建表达式列表
|
||||||
expressions = []
|
expressions = []
|
||||||
for col in X.columns:
|
for col in X.columns:
|
||||||
if col in numeric_cols:
|
if col in self.feature_cols and X[col].dtype.is_numeric():
|
||||||
# 截面标准化:每天独立计算均值和标准差
|
# 截面标准化:每天独立计算均值和标准差
|
||||||
# 避免除以0,当std为0时设为1
|
# 避免除以0,当std为0时设为1
|
||||||
expr = (
|
expr = (
|
||||||
@@ -355,6 +331,7 @@ class Winsorizer(BaseProcessor):
|
|||||||
也可以截面截断(每天独立处理)。
|
也可以截面截断(每天独立处理)。
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
|
feature_cols: 参与缩尾的特征列名列表
|
||||||
lower: 下分位数(如0.01表示1%分位数)
|
lower: 下分位数(如0.01表示1%分位数)
|
||||||
upper: 上分位数(如0.99表示99%分位数)
|
upper: 上分位数(如0.99表示99%分位数)
|
||||||
by_date: True=每天独立缩尾, False=全局缩尾
|
by_date: True=每天独立缩尾, False=全局缩尾
|
||||||
@@ -366,6 +343,7 @@ class Winsorizer(BaseProcessor):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
feature_cols: List[str],
|
||||||
lower: float = 0.01,
|
lower: float = 0.01,
|
||||||
upper: float = 0.99,
|
upper: float = 0.99,
|
||||||
by_date: bool = False,
|
by_date: bool = False,
|
||||||
@@ -374,6 +352,7 @@ class Winsorizer(BaseProcessor):
|
|||||||
"""初始化缩尾处理器
|
"""初始化缩尾处理器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
feature_cols: 参与缩尾的特征列名列表
|
||||||
lower: 下分位数,默认0.01
|
lower: 下分位数,默认0.01
|
||||||
upper: 上分位数,默认0.99
|
upper: 上分位数,默认0.99
|
||||||
by_date: 每天独立缩尾,默认False(全局缩尾)
|
by_date: 每天独立缩尾,默认False(全局缩尾)
|
||||||
@@ -387,11 +366,12 @@ class Winsorizer(BaseProcessor):
|
|||||||
f"lower ({lower}) 必须小于 upper ({upper}),且都在 [0, 1] 范围内"
|
f"lower ({lower}) 必须小于 upper ({upper}),且都在 [0, 1] 范围内"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.feature_cols = feature_cols
|
||||||
self.lower = lower
|
self.lower = lower
|
||||||
self.upper = upper
|
self.upper = upper
|
||||||
self.by_date = by_date
|
self.by_date = by_date
|
||||||
self.date_col = date_col
|
self.date_col = date_col
|
||||||
self.bounds_: dict = {}
|
self.bounds_ = {}
|
||||||
|
|
||||||
def fit(self, X: pl.DataFrame) -> "Winsorizer":
|
def fit(self, X: pl.DataFrame) -> "Winsorizer":
|
||||||
"""学习分位数边界(仅在全局模式下)
|
"""学习分位数边界(仅在全局模式下)
|
||||||
@@ -403,8 +383,8 @@ class Winsorizer(BaseProcessor):
|
|||||||
self
|
self
|
||||||
"""
|
"""
|
||||||
if not self.by_date:
|
if not self.by_date:
|
||||||
numeric_cols = [c for c in X.columns if X[c].dtype.is_numeric()]
|
for col in self.feature_cols:
|
||||||
for col in numeric_cols:
|
if col in X.columns and X[col].dtype.is_numeric():
|
||||||
self.bounds_[col] = {
|
self.bounds_[col] = {
|
||||||
"lower": X[col].quantile(self.lower),
|
"lower": X[col].quantile(self.lower),
|
||||||
"upper": X[col].quantile(self.upper),
|
"upper": X[col].quantile(self.upper),
|
||||||
@@ -440,16 +420,21 @@ class Winsorizer(BaseProcessor):
|
|||||||
|
|
||||||
def _transform_by_date(self, X: pl.DataFrame) -> pl.DataFrame:
|
def _transform_by_date(self, X: pl.DataFrame) -> pl.DataFrame:
|
||||||
"""每日独立缩尾"""
|
"""每日独立缩尾"""
|
||||||
numeric_cols = [c for c in X.columns if X[c].dtype.is_numeric()]
|
# 确定需要处理的数值列
|
||||||
|
target_cols = [
|
||||||
|
col
|
||||||
|
for col in self.feature_cols
|
||||||
|
if col in X.columns and X[col].dtype.is_numeric()
|
||||||
|
]
|
||||||
|
|
||||||
# 先计算每天的分位数
|
# 先计算每天的分位数
|
||||||
lower_exprs = [
|
lower_exprs = [
|
||||||
pl.col(col).quantile(self.lower).over(self.date_col).alias(f"{col}_lower")
|
pl.col(col).quantile(self.lower).over(self.date_col).alias(f"{col}_lower")
|
||||||
for col in numeric_cols
|
for col in target_cols
|
||||||
]
|
]
|
||||||
upper_exprs = [
|
upper_exprs = [
|
||||||
pl.col(col).quantile(self.upper).over(self.date_col).alias(f"{col}_upper")
|
pl.col(col).quantile(self.upper).over(self.date_col).alias(f"{col}_upper")
|
||||||
for col in numeric_cols
|
for col in target_cols
|
||||||
]
|
]
|
||||||
|
|
||||||
# 添加分位数列
|
# 添加分位数列
|
||||||
@@ -458,7 +443,7 @@ class Winsorizer(BaseProcessor):
|
|||||||
# 执行缩尾
|
# 执行缩尾
|
||||||
clip_exprs = []
|
clip_exprs = []
|
||||||
for col in X.columns:
|
for col in X.columns:
|
||||||
if col in numeric_cols:
|
if col in target_cols:
|
||||||
clipped = (
|
clipped = (
|
||||||
pl.col(col)
|
pl.col(col)
|
||||||
.clip(pl.col(f"{col}_lower"), pl.col(f"{col}_upper"))
|
.clip(pl.col(f"{col}_lower"), pl.col(f"{col}_upper"))
|
||||||
|
|||||||
Reference in New Issue
Block a user