From 1fa4ff954494d1168a34cd9b7fe4c8aec2decfe8 Mon Sep 17 00:00:00 2001 From: liaozhaorun <1300336796@qq.com> Date: Sun, 5 Apr 2026 19:01:08 +0800 Subject: [PATCH] =?UTF-8?q?feat(training):=20TabM=20=E6=8E=92=E5=BA=8F?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E6=9E=B6=E6=9E=84=E4=BC=98=E5=8C=96=E4=B8=8E?= =?UTF-8?q?=20Rank-Gauss=20=E6=A0=87=E7=AD=BE=E5=B7=A5=E7=A8=8B=20-=20TabM?= =?UTF-8?q?SetRank:=20=E5=B0=86=20TabM=20=E8=BE=93=E5=87=BA=E6=94=B9?= =?UTF-8?q?=E4=B8=BA=E9=9A=90=E8=97=8F=E5=B1=82=E7=89=B9=E5=BE=81=EF=BC=8C?= =?UTF-8?q?=E7=BB=8F=20SetRankHead=20=E4=BA=A4=E4=BA=92=E5=90=8E=E9=80=9A?= =?UTF-8?q?=E8=BF=87=20final=5Fmlp=20=E8=BE=93=E5=87=BA=20Ensemble=20?= =?UTF-8?q?=E6=8E=92=E5=BA=8F=E5=88=86=20-=20SetRankHead=20=E5=BC=95?= =?UTF-8?q?=E5=85=A5=E5=8F=AF=E5=AD=A6=E4=B9=A0=E6=AE=8B=E5=B7=AE=E7=BC=A9?= =?UTF-8?q?=E6=94=BE=E5=9B=A0=E5=AD=90=EF=BC=88Zero-init=EF=BC=89=E4=B8=8E?= =?UTF-8?q?=20Pre-Norm=20=E7=BB=93=E6=9E=84=EF=BC=8C=E6=8F=90=E5=8D=87?= =?UTF-8?q?=E8=AE=AD=E7=BB=83=E7=A8=B3=E5=AE=9A=E6=80=A7=20-=20TabMRankTas?= =?UTF-8?q?k=20=E6=96=B0=E5=A2=9E=20Rank-Gauss=20=E8=BF=9E=E7=BB=AD?= =?UTF-8?q?=E6=A0=87=E7=AD=BE=E5=8F=98=E6=8D=A2=EF=BC=8C=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E6=A0=87=E5=87=86=E5=88=86=E4=BD=8D=E6=95=B0/=E6=8C=87?= =?UTF-8?q?=E6=95=B0=E5=A2=9E=E7=9B=8A/Rank-Gauss=20=E4=B8=89=E7=A7=8D?= =?UTF-8?q?=E6=A0=87=E7=AD=BE=E6=A8=A1=E5=BC=8F=20-=20=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=20NDCG=20=E8=AF=84=E4=BC=B0=E5=9C=A8=E8=B4=9F=E5=80=BC?= =?UTF-8?q?=E6=A0=87=E7=AD=BE=E4=B8=8B=E7=9A=84=E8=AE=A1=E7=AE=97=E9=97=AE?= =?UTF-8?q?=E9=A2=98=20-=20=E8=B0=83=E6=95=B4=E5=AE=9E=E9=AA=8C=E8=84=9A?= =?UTF-8?q?=E6=9C=AC=E8=B6=85=E5=8F=82=E6=95=B0=EF=BC=88dropout=E3=80=81hi?= =?UTF-8?q?dden=20dim=E3=80=81weight=20decay=EF=BC=89=E5=8F=8A=E6=8E=92?= =?UTF-8?q?=E9=99=A4=E5=9B=A0=E5=AD=90=E5=88=97=E8=A1=A8=20-=20=E8=BF=81?= =?UTF-8?q?=E7=A7=BB=E5=BA=9F=E5=BC=83=E7=9A=84=20torch.cuda.amp=20?= =?UTF-8?q?=E5=88=B0=20torch.amp=EF=BC=8C=E5=B9=B6=E5=B0=86=E6=95=B0?= 
=?UTF-8?q?=E6=8D=AE=E9=A2=84=E5=8A=A0=E8=BD=BD=E8=87=B3=20GPU=20=E5=87=8F?= =?UTF-8?q?=E5=B0=91=E5=BE=AA=E7=8E=AF=E6=8B=B7=E8=B4=9D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/experiment/learn_to_rank.py | 19 ++- src/experiment/regression.py | 46 +++---- src/experiment/tabm_rank_train.py | 12 +- src/experiment/tabm_setrank_train.py | 8 +- .../components/models/tabm_rank_model.py | 7 +- .../components/models/tabm_setrank_model.py | 103 +++++++++++----- src/training/tasks/tabm_rank_task.py | 115 +++++++++++++----- 7 files changed, 205 insertions(+), 105 deletions(-) diff --git a/src/experiment/learn_to_rank.py b/src/experiment/learn_to_rank.py index 98107eb..cb4cb65 100644 --- a/src/experiment/learn_to_rank.py +++ b/src/experiment/learn_to_rank.py @@ -54,9 +54,22 @@ N_QUANTILES = 20 # 排除的因子列表 EXCLUDED_FACTORS = [ - # 'debt_to_equity', - # 'GTJA_alpha016', - # 'GTJA_alpha141', + "amivest_liq_20", + "atr_price_impact", + "hui_heubel_ratio", + "corwin_schultz_spread_20", + "roll_spread_20", + "gibbs_effective_spread", + "overnight_illiq_20", + "illiq_volatility_20", + "amount_cv_20", + "amount_skewness_20", + "low_vol_days_20", + "liquidity_shock_momentum", + "downside_illiq_20", + "upside_illiq_20", + "illiq_asymmetry_20", + "pastor_stambaugh_proxy" ] diff --git a/src/experiment/regression.py b/src/experiment/regression.py index 4bc56e5..c661780 100644 --- a/src/experiment/regression.py +++ b/src/experiment/regression.py @@ -52,36 +52,22 @@ TRAINING_TYPE = "regression" # 排除的因子列表 EXCLUDED_FACTORS = [ - # 'GTJA_alpha016', - # 'volatility_20', - # 'current_ratio', - # 'GTJA_alpha001', - # 'GTJA_alpha141', - # 'GTJA_alpha129', - # 'GTJA_alpha164', - # 'amivest_liq_20', - # 'GTJA_alpha012', - # 'debt_to_equity', - # 'turnover_deviation', - # 'GTJA_alpha073', - # 'GTJA_alpha043', - # 'GTJA_alpha032', - # 'GTJA_alpha028', - # 'GTJA_alpha090', - # 'GTJA_alpha108', - # 'GTJA_alpha105', - # 'GTJA_alpha091', - # 
'GTJA_alpha119', - # 'GTJA_alpha104', - # 'GTJA_alpha163', - # 'GTJA_alpha157', - # 'cost_skewness', - # 'GTJA_alpha176', - # 'chip_transition', - # 'amount_skewness_20', - # 'GTJA_alpha148', - # 'mean_median_dev', - # 'downside_illiq_20', + # "amivest_liq_20", + # "atr_price_impact", + # "hui_heubel_ratio", + # "corwin_schultz_spread_20", + # "roll_spread_20", + # "gibbs_effective_spread", + # "overnight_illiq_20", + # "illiq_volatility_20", + # "amount_cv_20", + # "amount_skewness_20", + # "low_vol_days_20", + # "liquidity_shock_momentum", + # "downside_illiq_20", + # "upside_illiq_20", + # "illiq_asymmetry_20", + # "pastor_stambaugh_proxy" ] # 模型参数配置 diff --git a/src/experiment/tabm_rank_train.py b/src/experiment/tabm_rank_train.py index b360006..b43fc81 100644 --- a/src/experiment/tabm_rank_train.py +++ b/src/experiment/tabm_rank_train.py @@ -46,12 +46,16 @@ TRAINING_TYPE = "tabm_rank" # Label 配置(从 common.py 统一导入) # LABEL_NAME 和 LABEL_FACTOR 已在 common.py 中绑定,只需从 common 导入 -# 分位数配置(提高分辨率以更好地区分头部) +# 分位数配置(分桶模式下使用;Rank-Gauss 模式下不使用,但保留兼容性) N_QUANTILES = 50 -# 【Top-K 优化】标签工程配置 - 默认启用平方增益 -LABEL_TRANSFORM = "exponential" # 启用平方增益标签 (rank^2) -LABEL_SCALE = 20.0 # 保留参数(当前未使用,平方变换不需要缩放) +# 标签工程配置 +# 可选值: +# - "rank_gauss": Rank-Gauss 连续化标签(推荐,神经网络更友好) +# - "exponential": 指数化增益标签 (rank^2) +# - None: 标准分位数标签 (0, 1, ..., n_quantiles-1) +LABEL_TRANSFORM = "rank_gauss" +LABEL_SCALE = 20.0 # 保留参数(rank_gauss / exponential 下均未使用) # 排除的因子列表 EXCLUDED_FACTORS = ["GTJA_alpha041", "GTJA_alpha127"] diff --git a/src/experiment/tabm_setrank_train.py b/src/experiment/tabm_setrank_train.py index 37e1295..18fb243 100644 --- a/src/experiment/tabm_setrank_train.py +++ b/src/experiment/tabm_setrank_train.py @@ -61,7 +61,7 @@ MODEL_PARAMS = { # ==================== MLP 结构 ==================== "n_blocks": 3, "d_block": 256, - "dropout": 0.5, + "dropout": 0.3, # ==================== 集成机制 ==================== "ensemble_size": 32, @@ -71,9 +71,9 @@ MODEL_PARAMS = { "setrank_heads": 4, # 
【优化1】将隐藏维度从 128 降到 64。 # 截面特征对比不需要那么宽的维度,太宽会导致模型记忆当天特有的无效噪音。 - "setrank_hidden": 128, + "setrank_hidden": 256, # 【优化2】增大 SetRank 层的 Dropout - "setrank_dropout": 0.5, + "setrank_dropout": 0.3, # ==================== AMP 与显存优化 ==================== "use_amp": True, @@ -85,7 +85,7 @@ MODEL_PARAMS = { "learning_rate": 5e-4, # 【优化4】核心操作!将 L2 惩罚(权重衰减)放大 10 倍甚至 100 倍! # 带有 Attention 的网络极容易对某些特定股票产生依赖,强烈的 Weight Decay 能逼迫模型关注全局特征。 - "weight_decay": 1e-5, # 原为 1e-5,现改为 1e-3 + "weight_decay": 1e-4, # 原为 1e-5,现改为 1e-4 "epochs": 150, # 不需要 500 次,从图中看 150 绝对够了 diff --git a/src/training/components/models/tabm_rank_model.py b/src/training/components/models/tabm_rank_model.py index 741b583..da7cb04 100644 --- a/src/training/components/models/tabm_rank_model.py +++ b/src/training/components/models/tabm_rank_model.py @@ -739,7 +739,12 @@ class TabMRankModel(BaseModel): for y_true, y_score in zip(y_true_list, y_score_list): if len(y_true) > 1: try: - score = ndcg_score([y_true], [y_score], k=k) + # 若标签存在负值(如 Rank-Gauss),平移到正值区间再计算 NDCG + if y_true.min() < 0: + y_true_shifted = y_true - y_true.min() + 1e-5 + score = ndcg_score([y_true_shifted], [y_score], k=k) + else: + score = ndcg_score([y_true], [y_score], k=k) ndcg_scores.append(score) except ValueError: pass diff --git a/src/training/components/models/tabm_setrank_model.py b/src/training/components/models/tabm_setrank_model.py index 933eaeb..b6fc964 100644 --- a/src/training/components/models/tabm_setrank_model.py +++ b/src/training/components/models/tabm_setrank_model.py @@ -16,7 +16,11 @@ import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.utils.data import DataLoader, TensorDataset, Sampler -from torch.cuda.amp import GradScaler + +# 注:torch.cuda.amp.autocast 已被废弃,请统一使用 torch.amp.autocast('cuda', ...) 
+# 同时从子模块显式导入以兼容各类 IDE 的类型提示 +from torch.amp.grad_scaler import GradScaler +from torch.amp.autocast_mode import autocast from tabm import TabM from src.training.components.base import BaseModel @@ -65,15 +69,23 @@ class SetRankHead(nn.Module): # 4. 恢复集成维度: [N, d_ff] -> [N, E] self.output_proj = nn.Linear(d_ff, d_in) + # 5. 可学习的残差缩放因子,初始化为 0,训练初期退化为纯 TabM + self.res_scale = nn.Parameter(torch.zeros(1)) + self._init_weights() def _init_weights(self): for m in self.modules(): if isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) elif isinstance(m, nn.LayerNorm): nn.init.ones_(m.weight) nn.init.zeros_(m.bias) + # Zero-init 输出层,确保训练启动时为恒等映射 + nn.init.zeros_(self.output_proj.weight) + nn.init.zeros_(self.output_proj.bias) def forward(self, x: torch.Tensor) -> torch.Tensor: # x 初始维度: [N, E] @@ -85,15 +97,13 @@ class SetRankHead(nn.Module): # B=1 (代表一天的截面), L=N (股票数量), D=d_ff h = h.unsqueeze(0) - # Attention 交互 - attn_out, _ = self.attn(h, h, h) + # Pre-Norm Attention 交互 + attn_out, _ = self.attn(self.norm1(h), self.norm1(h), self.norm1(h)) + h = h + attn_out - # 残差块 1 - h = self.norm1(h + attn_out) - - # 残差块 2 (FFN) - ffn_out = self.ffn(h) - h = self.norm2(h + ffn_out) + # Pre-Norm FFN + ffn_out = self.ffn(self.norm2(h)) + h = h + ffn_out # 降维并去除 Batch 维度: [1, N, d_ff] -> [N, d_ff] h = h.squeeze(0) @@ -101,8 +111,8 @@ class SetRankHead(nn.Module): # 输出回 Ensemble 维度: [N, E] out_logits = self.output_proj(h) - # 引入残差连接,保留原始 TabM 的预测分,仅用 SetRank 提供修正项 (极大地增加训练稳定性) - return x + out_logits + # 带有可学习缩放因子的残差连接,初始阶段退化为纯 TabM + return x + self.res_scale * out_logits class TabMSetRankNet(nn.Module): @@ -116,32 +126,50 @@ class TabMSetRankNet(nn.Module): setrank_heads: int, setrank_hidden: int, setrank_dropout: float, + mlp_hidden: Optional[int] = None, ): super().__init__() - # TabM 特征提取器 + mlp_hidden = mlp_hidden or d_block + + # TabM 特征提取器 (作为隐藏层,输出维度为 d_block) self.tabm = TabM.make( n_num_features=n_features, 
cat_cardinalities=[], - d_out=1, + d_out=d_block, n_blocks=n_blocks, d_block=d_block, dropout=dropout, k=ensemble_size, ) - # 截面注意力修正器 + # 截面注意力修正器 (处理 TabM 隐藏特征) self.setrank_head = SetRankHead( - d_in=ensemble_size, + d_in=d_block, n_heads=setrank_heads, d_ff=setrank_hidden, dropout=setrank_dropout, ) - def forward(self, x: torch.Tensor) -> torch.Tensor: - # 1. 独立特征提取 -> [N, E] - tabm_out = self.tabm(x).squeeze(-1) + # 最终两次 MLP 输出 Ensemble 排序分 + self.final_mlp = nn.Sequential( + nn.Linear(d_block, mlp_hidden), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(mlp_hidden, ensemble_size), + ) - # 2. 截面上下文交互修正 -> [N, E] - scores = self.setrank_head(tabm_out) + def forward(self, x: torch.Tensor) -> torch.Tensor: + # 1. TabM 独立特征提取 -> [N, E, d_block] + tabm_out = self.tabm(x) + + # 2. 对 ensemble 维度取平均,得到原始隐藏特征 -> [N, d_block] + tabm_hidden = tabm_out.mean(dim=1) + + # 3. 截面上下文交互修正 -> [N, d_block] + # SetRankHead 内部实现残差: x + res_scale * transform(x) + setrank_hidden = self.setrank_head(tabm_hidden) + + # 4. 
两次 MLP 输出最终 Ensemble 排序分 -> [N, E] + scores = self.final_mlp(setrank_hidden) return scores @@ -193,6 +221,7 @@ class TabMSetRankModel(BaseModel): setrank_heads=self.params.get("setrank_heads", 4), setrank_hidden=self.params.get("setrank_hidden", 128), setrank_dropout=self.params.get("setrank_dropout", 0.1), + mlp_hidden=self.params.get("mlp_hidden", None), ).to(self.device) def _make_loader( @@ -202,10 +231,13 @@ class TabMSetRankModel(BaseModel): group: Optional[np.ndarray] = None, shuffle_groups: bool = False, ) -> DataLoader: - """创建 DataLoader (支持 Query/Group 截面打包)""" - X_tensor = torch.from_numpy(X) + """创建 DataLoader (支持 Query/Group 截面打包) + + 数据在一开始就直接放到 GPU,避免训练/预测循环中反复拷贝。 + """ + X_tensor = torch.from_numpy(X).to(self.device) if y is not None: - y_tensor = torch.from_numpy(y) + y_tensor = torch.from_numpy(y).to(self.device) dataset = TensorDataset(X_tensor, y_tensor) else: dataset = TensorDataset(X_tensor) @@ -241,12 +273,11 @@ class TabMSetRankModel(BaseModel): if len(batch) != 2: continue bx, by = batch - bx = bx.to(self.device) by = by.cpu().numpy() if len(by) <= 1: continue - with torch.amp.autocast("cuda", enabled=self.use_amp): + with autocast("cuda", enabled=self.use_amp): scores = self.model(bx) # [N, E] # 对 Ensemble 维度求平均,得到最终排序分 [N] @@ -331,10 +362,10 @@ class TabMSetRankModel(BaseModel): for batch in train_loader: if len(batch) != 2: continue - bx, by = batch[0].to(self.device), batch[1].to(self.device) + bx, by = batch optimizer.zero_grad() - with torch.amp.autocast("cuda", enabled=self.use_amp): + with autocast("cuda", enabled=self.use_amp): scores = self.model(bx) # [N, E] loss = self.criterion(scores, by) @@ -390,9 +421,19 @@ class TabMSetRankModel(BaseModel): def predict( self, X: pl.DataFrame, group: Optional[np.ndarray] = None ) -> np.ndarray: - """预测排序分数""" + """预测排序分数 + + 对于 TabM + SetRank 模型,group 参数是必需的。 + SetRankHead 中的截面注意力机制假设每个 batch 就是一天的所有股票, + 若缺失 group 退化为普通 batch_size=2048 预测,会导致跨天样本错误交互。 + """ if self.model is None: raise 
RuntimeError("模型未训练,请先调用fit()") + if group is None: + raise ValueError( + "TabMSetRankModel.predict() 需要提供 group 数组," + "确保每个 batch 为完整的一天截面。请通过 test_data['groups'] 传入。" + ) if self.feature_names_: missing = [c for c in self.feature_names_ if c not in X.columns] if missing: @@ -406,8 +447,8 @@ class TabMSetRankModel(BaseModel): with torch.no_grad(): for batch in loader: - bx = batch[0].to(self.device) - with torch.amp.autocast("cuda", enabled=self.use_amp): + bx = batch[0] + with autocast("cuda", enabled=self.use_amp): scores = self.model(bx) # [N, E] # 对 Ensemble 维度求平均,得到最终排序分 [N] @@ -468,7 +509,7 @@ class TabMSetRankModel(BaseModel): """评估 NDCG 指标""" from sklearn.metrics import ndcg_score - y_pred = self.predict(X) + y_pred = self.predict(X, group=group) y_true_array = y.to_numpy() if isinstance(y, pl.Series) else y y_true_list = [] y_score_list = [] diff --git a/src/training/tasks/tabm_rank_task.py b/src/training/tasks/tabm_rank_task.py index 0c68e93..cd86187 100644 --- a/src/training/tasks/tabm_rank_task.py +++ b/src/training/tasks/tabm_rank_task.py @@ -51,11 +51,44 @@ class TabMRankTask(BaseTask): self.label_scale = label_scale self.model_class = model_class or TabMRankModel - def prepare_labels(self, data: Dict[str, Dict]) -> Dict[str, Dict]: - """准备标签(转换为分位数标签,可选指数化增益变换) + @staticmethod + def _rank_gauss(arr: np.ndarray) -> np.ndarray: + """对一维数组做 Rank-Gauss 变换 - 将连续收益率转换为分位数标签,并生成 group 数组。 - 支持指数化增益变换以增强头部样本的区分度。 + 将数组值转换为基于排名的标准正态分布值,结果近似服从 N(0, 1)。 + + Args: + arr: 输入一维数组 + + Returns: + Rank-Gauss 变换后的数组 + """ + from scipy.special import erfinv + from scipy.stats import rankdata + + n = len(arr) + if n <= 1: + return np.zeros(n, dtype=np.float64) + + # 1. 获取平均排名 (1-based) + ranks = rankdata(arr, method="average") + # 2. 映射到 (0, 1),使用 (ranks - 0.5) / n 避免边界 + p = (ranks - 0.5) / n + # 3. 映射到 (-1, 1) + p = 2.0 * p - 1.0 + # 4. 裁剪边界避免 erfinv 输入越界 + eps = 1e-6 + p = np.clip(p, -1.0 + eps, 1.0 - eps) + # 5. 
Rank-Gauss: erfinv(p) * sqrt(2) + return erfinv(p) * np.sqrt(2.0) + + def prepare_labels(self, data: Dict[str, Dict]) -> Dict[str, Dict]: + """准备标签(转换为分位数标签或 Rank-Gauss 连续标签) + + 支持三种模式: + - 标准分位数标签 (0, 1, ..., n_quantiles-1) + - 指数化增益变换 (rank^2) + - Rank-Gauss 连续标签(截面正态化,不依赖分桶数) Args: data: 数据字典 @@ -68,40 +101,58 @@ class TabMRankTask(BaseTask): continue df = data[split]["raw_data"] - - # 分位数转换 rank_col = f"{self.label_name}_rank" - # 1. 基础分位数标签 (0 到 n_quantiles-1) - df_ranked = df.with_columns( - pl.col(self.label_name) - .rank(method="min") - .over("trade_date") - .alias("_rank") - ).with_columns( - ((pl.col("_rank") - 1) / pl.len().over("trade_date") * self.n_quantiles) - .floor() - .cast(pl.Int64) - .clip(0, self.n_quantiles - 1) - .alias("_base_rank") - ) - - # 2. 【Top-K 优化】可选指数化增益变换 - if self.label_transform == "exponential": - # 平方变换: rank^2 - # 例如 rank=0 -> 0, rank=10 -> 100, rank=19 -> 361 - # 效果:高分样本与低分样本的差距被平方级拉大 - df_ranked = df_ranked.with_columns( - (pl.col("_base_rank").cast(pl.Float64) ** 2).alias(rank_col) + if self.label_transform == "rank_gauss": + # Rank-Gauss 连续化标签:先截面高斯化,再缩放到 (0, 1) 全局正区间 + df_ranked = df.with_columns( + pl.col(self.label_name) + .map_batches( + lambda s: pl.Series(self._rank_gauss(s.to_numpy())), + is_elementwise=False, + ) + .map_batches( + lambda s: pl.Series((s.to_numpy() + 4.0) / 8.0).clip(0.0, 1.0), + is_elementwise=False, + ) + .over("trade_date") + .alias(rank_col) ) else: - # 标准分位数标签 - df_ranked = df_ranked.with_columns( - pl.col("_base_rank").cast(pl.Float64).alias(rank_col) + # 1. 基础分位数标签 (0 到 n_quantiles-1) + df_ranked = df.with_columns( + pl.col(self.label_name) + .rank(method="min") + .over("trade_date") + .alias("_rank") + ).with_columns( + ( + (pl.col("_rank") - 1) + / pl.len().over("trade_date") + * self.n_quantiles + ) + .floor() + .cast(pl.Int64) + .clip(0, self.n_quantiles - 1) + .alias("_base_rank") ) - # 清理临时列 - df_ranked = df_ranked.drop(["_rank", "_base_rank"]) + # 2. 
【Top-K 优化】可选指数化增益变换 + if self.label_transform == "exponential": + # 平方变换: rank^2 + # 例如 rank=0 -> 0, rank=10 -> 100, rank=19 -> 361 + # 效果:高分样本与低分样本的差距被平方级拉大 + df_ranked = df_ranked.with_columns( + (pl.col("_base_rank").cast(pl.Float64) ** 2).alias(rank_col) + ) + else: + # 标准分位数标签 + df_ranked = df_ranked.with_columns( + pl.col("_base_rank").cast(pl.Float64).alias(rank_col) + ) + + # 清理临时列 + df_ranked = df_ranked.drop(["_rank", "_base_rank"]) # 更新数据 data[split]["raw_data"] = df_ranked