diff --git a/AGENTS.md b/AGENTS.md
index 22db931..4f65dc2 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -864,12 +864,13 @@ class CustomProcessor(BaseProcessor):
 如果 LSP 检测报错,必须按照以下流程处理:
 
 1. **问题定位**
-   - 报错必定是由基础格式错误引起:缩进错误、引号括号不匹配、代码格式错误等
+   - 报错通常由基础格式错误或兼容性问题引起:缩进错误、引号括号不匹配、代码格式错误等
    - 必须读取对应的代码行,精确定位错误
 
 2. **修复方式**
    - ✅ **必须**:读取报错文件,检查具体代码行
    - ✅ **必须**:修复格式错误(缩进、括号匹配、引号闭合等)
+   - ✅ **必须**:判断报错是否影响代码运行;若不影响运行,需先询问确认后方可忽略
    - ❌ **禁止**:删除文件重新修改
    - ❌ **禁止**:自行 rollback 文件
    - ❌ **禁止**:新建文件重新修改
diff --git a/src/experiment/common.py b/src/experiment/common.py
index 791ec0d..e68e4a9 100644
--- a/src/experiment/common.py
+++ b/src/experiment/common.py
@@ -271,22 +271,22 @@ SELECTED_FACTORS = [
 
     "pivot_reversion",
     "chip_transition",
-    # "amivest_liq_20",
-    # "atr_price_impact",
-    # "hui_heubel_ratio",
-    # "corwin_schultz_spread_20",
-    # "roll_spread_20",
-    # "gibbs_effective_spread",
-    # "overnight_illiq_20",
-    # "illiq_volatility_20",
-    # "amount_cv_20",
-    # "amount_skewness_20",
-    # "low_vol_days_20",
-    # "liquidity_shock_momentum",
-    # "downside_illiq_20",
-    # "upside_illiq_20",
-    # "illiq_asymmetry_20",
-    # "pastor_stambaugh_proxy"
+    "amivest_liq_20",
+    "atr_price_impact",
+    "hui_heubel_ratio",
+    "corwin_schultz_spread_20",
+    "roll_spread_20",
+    "gibbs_effective_spread",
+    "overnight_illiq_20",
+    "illiq_volatility_20",
+    "amount_cv_20",
+    "amount_skewness_20",
+    "low_vol_days_20",
+    "liquidity_shock_momentum",
+    "downside_illiq_20",
+    "upside_illiq_20",
+    "illiq_asymmetry_20",
+    "pastor_stambaugh_proxy"
 ]
 
 # 因子定义字典(完整因子库,用于存放尚未注册到metadata的因子)
diff --git a/src/experiment/tabm_setrank_train.py b/src/experiment/tabm_setrank_train.py
new file mode 100644
index 0000000..37e1295
--- /dev/null
+++ b/src/experiment/tabm_setrank_train.py
@@ -0,0 +1,201 @@
+"""TabM + SetRank 排序学习训练流程
+
+使用模块化 Trainer 架构,基于 TabMSetRankModel 实现排序学习。
+引入 SetRank 组内注意力头,其余配置与 tabm_rank_train.py 对齐。
+"""
+
+import os
+
+from src.factors import FactorEngine
+from src.training import (
+    FactorManager,
+    DataPipeline,
+    NullFiller,
+    Winsorizer,
+    CrossSectionalStandardScaler,
+)
+from src.training.tasks.tabm_rank_task import TabMRankTask
+from src.training.core.trainer_v2 import Trainer
+from src.training.components.filters import STFilter
+from src.training.components.models import TabMSetRankModel
+from src.experiment.common import (
+    SELECTED_FACTORS,
+    FACTOR_DEFINITIONS,
+    LABEL_NAME,
+    LABEL_FACTOR,
+    TRAIN_START,
+    TRAIN_END,
+    VAL_START,
+    VAL_END,
+    TEST_START,
+    TEST_END,
+    stock_pool_filter,
+    STOCK_FILTER_REQUIRED_COLUMNS,
+    OUTPUT_DIR,
+    SAVE_PREDICTIONS,
+    SAVE_MODEL,
+    get_model_save_path,
+    save_model_with_factors,
+    TOP_N,
+    TRAIN_SKIP_DAYS,
+)
+
+# 训练类型标识
+TRAINING_TYPE = "tabm_setrank_rank"
+
+# %%
+# Label 配置(从 common.py 统一导入)
+
+# 分位数配置(提高分辨率以更好地区分头部)
+N_QUANTILES = 50
+
+# 【Top-K 优化】标签工程配置 - 默认启用指数化增益
+LABEL_TRANSFORM = "exponential"  # 启用指数化增益标签(见 TabMRankTask: 2^(rank/scale) - 1)
+LABEL_SCALE = 20.0  # 指数变换的缩放因子,控制增益幅度
+
+# 排除的因子列表
+EXCLUDED_FACTORS = ["GTJA_alpha041", "GTJA_alpha127"]
+
+# TabM + SetRank 模型参数配置
+MODEL_PARAMS = {
+    # ==================== MLP 结构 ====================
+    "n_blocks": 3,
+    "d_block": 256,
+    "dropout": 0.5,
+
+    # ==================== 集成机制 ====================
+    "ensemble_size": 32,
+
+    # ==================== SetRank 头 (降维防过拟合) ====================
+    "use_setrank": True,
+    "setrank_heads": 4,
+    # 【优化1】将隐藏维度从 128 降到 64。
+    # 截面特征对比不需要那么宽的维度,太宽会导致模型记忆当天特有的无效噪音。
+    "setrank_hidden": 64,
+    # 【优化2】增大 SetRank 层的 Dropout
+    "setrank_dropout": 0.5,
+
+    # ==================== AMP 与显存优化 ====================
+    "use_amp": True,
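+    # 以下 num_workers / pin_memory 会透传给 DataLoader(见 TabMSetRankModel._make_loader)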
"num_workers": 0, + "pin_memory": False, + + # ==================== 训练参数 (强正则化) ==================== + # 【优化3】稍微调低学习率,让模型在接近最优点时不要走得太快(防震荡) + "learning_rate": 5e-4, + # 【优化4】核心操作!将 L2 惩罚(权重衰减)放大 10 倍甚至 100 倍! + # 带有 Attention 的网络极容易对某些特定股票产生依赖,强烈的 Weight Decay 能逼迫模型关注全局特征。 + "weight_decay": 1e-5, # 原为 1e-5,现改为 1e-3 + + "epochs": 150, # 不需要 500 次,从图中看 150 绝对够了 + + # ==================== 早停 ==================== + "early_stopping_round": 30, # 耐心值 30 足矣 + + # ==================== NDCG 评估 ==================== + "ndcg_k": 20, + + # ==================== 损失函数配置 ==================== + "loss_type": "lambda", + "lambda_sigma": 1.0, + # 【优化5】稍微放大 DeltaNDCG 的权重幂次,让模型在排错 Top 5 股票时受到更严厉的惩罚 + "ndcg_weight_power": 1.0, +} + +# 日期范围配置 +date_range = { + "train": (TRAIN_START, TRAIN_END), + "val": (VAL_START, VAL_END), + "test": (TEST_START, TEST_END), +} + +# 输出配置 +output_config = { + "output_dir": OUTPUT_DIR, + "output_filename": "tabm_setrank_rank_output.csv", + "save_predictions": SAVE_PREDICTIONS, + "save_model": SAVE_MODEL, + "model_save_path": get_model_save_path(TRAINING_TYPE), + "top_n": TOP_N, +} + + +def main(): + """主函数""" + print("\n" + "=" * 80) + print("TabM + SetRank 排序学习训练") + print("=" * 80) + + # 1. 创建 FactorEngine + print("\n[1] 创建 FactorEngine") + engine = FactorEngine() + + # 2. 创建 FactorManager + print("\n[2] 创建 FactorManager") + factor_manager = FactorManager( + selected_factors=SELECTED_FACTORS, + factor_definitions=FACTOR_DEFINITIONS, + label_factor=LABEL_FACTOR, + excluded_factors=EXCLUDED_FACTORS, + ) + + # 3. 创建 DataPipeline + print("\n[3] 创建 DataPipeline") + pipeline = DataPipeline( + factor_manager=factor_manager, + processor_configs=[ + (Winsorizer, {"lower": 0.01, "upper": 0.99}), + (NullFiller, {"strategy": "mean"}), + (CrossSectionalStandardScaler, {}), + ], + filters=[STFilter(data_router=engine.router)], + stock_pool_filter_func=stock_pool_filter, + stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS, + train_skip_days=TRAIN_SKIP_DAYS, + ) + + # 4. 创建 TabMRankTask(注入 TabMSetRankModel) + print("\n[4] 创建 TabMRankTask(TabMSetRankModel)") + task = TabMRankTask( + model_class=TabMSetRankModel, + model_params=MODEL_PARAMS, + label_name=LABEL_NAME, + n_quantiles=N_QUANTILES, + label_transform=LABEL_TRANSFORM, + label_scale=LABEL_SCALE, + ) + + # 5. 创建 Trainer + print("\n[5] 创建 Trainer") + trainer = Trainer( + data_pipeline=pipeline, + task=task, + output_config=output_config, + verbose=True, + ) + + # 6. 执行训练 + print("\n[6] 执行训练") + results = trainer.run(engine=engine, date_range=date_range) + + # 7. 
+    # 7. 保存模型和因子信息(如果启用)
+    if SAVE_MODEL:
+        print("\n[7] 保存模型和因子信息")
+        save_model_with_factors(
+            model=task.get_model(),
+            model_path=output_config["model_save_path"],
+            selected_factors=SELECTED_FACTORS,
+            factor_definitions=FACTOR_DEFINITIONS,
+            fitted_processors=pipeline.get_fitted_processors(),
+        )
+
+    print("\n" + "=" * 80)
+    print("训练流程完成!")
+    print(f"结果保存路径: {os.path.join(OUTPUT_DIR, 'tabm_setrank_rank_output.csv')}")
+    print("=" * 80)
+
+    return results
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/training/components/models/__init__.py b/src/training/components/models/__init__.py
index e757708..2131517 100644
--- a/src/training/components/models/__init__.py
+++ b/src/training/components/models/__init__.py
@@ -12,6 +12,7 @@ from src.training.components.models.tabm_rank_model import (
     EnsembleListNetLoss,
     EnsembleLambdaLoss,
 )
+from src.training.components.models.tabm_setrank_model import TabMSetRankModel
 from src.training.components.models.cross_section_sampler import CrossSectionSampler
 from src.training.components.models.ensemble_quant_loss import EnsembleQuantLoss
 
@@ -21,6 +22,7 @@ __all__ = [
     "TabPFNModel",
     "TabMModel",
     "TabMRankModel",
+    "TabMSetRankModel",
     "EnsembleListNetLoss",
     "EnsembleLambdaLoss",
     "CrossSectionSampler",
diff --git a/src/training/components/models/tabm_setrank_model.py b/src/training/components/models/tabm_setrank_model.py
new file mode 100644
index 0000000..933eaeb
--- /dev/null
+++ b/src/training/components/models/tabm_setrank_model.py
@@ -0,0 +1,489 @@
+"""TabM + SetRank 排序模型实现 (TabM_SetRank)
+
+基于 TabM 特征提取 + SetRank 组内注意力建模。
+引入 ListNet/LambdaLoss 列表级排序损失,支持 AMP 混合精度与显存优化。
+适用于股票未来收益率的截面排序预测。
+"""
+
+from typing import Dict, Any, List, Optional, Tuple
+from pathlib import Path
+import pickle
+
+import numpy as np
+import polars as pl
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader, TensorDataset, Sampler
+from torch.cuda.amp import GradScaler
+from tabm import TabM
+
+from src.training.components.base import BaseModel
+from src.training.registry import register_model
+from src.training.components.models.tabm_rank_model import (
+    EnsembleListNetLoss,
+    EnsembleLambdaLoss,
+    GroupSampler,
+)
+
+
+class SetRankHead(nn.Module):
+    """标准的 SetRank 组内上下文注意力头
+
+    输入: [N, E] (TabM 集成预测)
+    输出: [N, E] (经过截面交互后的集成预测)
+    """
+
+    def __init__(
+        self,
+        d_in: int,  # TabM 的 Ensemble 数量 (例如 32)
+        n_heads: int = 4,
+        d_ff: int = 128,  # SetRank 的隐藏特征维度
+        dropout: float = 0.1,
+    ):
+        super().__init__()
+        # 1. 特征升维: 将 Logits 映射为高维语义特征
+        self.feature_proj = nn.Linear(d_in, d_ff)
+
+        # 2. 多头注意力层 (SetRank 核心)
+        # 注意:不要加入位置编码 (Positional Encoding),因为股票截面是一个无序集合(Set),必须保持排列等变性
+        self.attn = nn.MultiheadAttention(
+            embed_dim=d_ff, num_heads=n_heads, dropout=dropout, batch_first=True
+        )
+        self.norm1 = nn.LayerNorm(d_ff)
+
+        # 3. 前馈神经网络 (标准 Transformer FFN 结构)
+        self.ffn = nn.Sequential(
+            nn.Linear(d_ff, d_ff * 2),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Linear(d_ff * 2, d_ff),
+        )
+        self.norm2 = nn.LayerNorm(d_ff)
+
+        # 4. 恢复集成维度: [N, d_ff] -> [N, E]
+        self.output_proj = nn.Linear(d_ff, d_in)
+
+        self._init_weights()
+
+    def _init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Linear):
+                nn.init.xavier_uniform_(m.weight)
+            elif isinstance(m, nn.LayerNorm):
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # x 初始维度: [N, E]
+
+        # 升维: [N, d_ff]
+        h = self.feature_proj(x)
+
+        # 增加 Batch 维度以适配 MultiheadAttention: [1, N, d_ff]
+        # B=1 (代表一天的截面), L=N (股票数量), D=d_ff
+        h = h.unsqueeze(0)
+
+        # Attention 交互
+        attn_out, _ = self.attn(h, h, h)
+
+        # 残差块 1
+        h = self.norm1(h + attn_out)
+
+        # 残差块 2 (FFN)
+        ffn_out = self.ffn(h)
+        h = self.norm2(h + ffn_out)
+
+        # 降维并去除 Batch 维度: [1, N, d_ff] -> [N, d_ff]
+        h = h.squeeze(0)
+
+        # 输出回 Ensemble 维度: [N, E]
+        out_logits = self.output_proj(h)
+
+        # 引入残差连接,保留原始 TabM 的预测分,仅用 SetRank 提供修正项 (极大地增加训练稳定性)
+        return x + out_logits
+
+
+class TabMSetRankNet(nn.Module):
+    def __init__(
+        self,
+        n_features: int,
+        ensemble_size: int,
+        n_blocks: int,
+        d_block: int,
+        dropout: float,
+        setrank_heads: int,
+        setrank_hidden: int,
+        setrank_dropout: float,
+    ):
+        super().__init__()
+        # TabM 特征提取器
+        self.tabm = TabM.make(
+            n_num_features=n_features,
+            cat_cardinalities=[],
+            d_out=1,
+            n_blocks=n_blocks,
+            d_block=d_block,
+            dropout=dropout,
+            k=ensemble_size,
+        )
+        # 截面注意力修正器
+        self.setrank_head = SetRankHead(
+            d_in=ensemble_size,
+            n_heads=setrank_heads,
+            d_ff=setrank_hidden,
+            dropout=setrank_dropout,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # 1. 独立特征提取 -> [N, E]
+        tabm_out = self.tabm(x).squeeze(-1)
+
+        # 2. 截面上下文交互修正 -> [N, E]
+        scores = self.setrank_head(tabm_out)
+        return scores
+
+
+@register_model("tabm_setrank_rank")
+class TabMSetRankModel(BaseModel):
+    """TabM + SetRank 学习排序模型
+
+    使用统一的 nn.Module (TabMSetRankNet) 封装 TabM 与 SetRankHead。
+    支持 AMP 混合精度与显存优化。
+    """
+
+    name = "tabm_setrank_rank"
+
+    def __init__(self, params: Optional[Dict[str, Any]] = None):
+        self.params = params or {}
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.training_history_: Dict[str, List[float]] = {
+            "train_loss": [],
+            "val_ndcg": [],
+        }
+        self.feature_names_: Optional[List[str]] = None
+
+        self.use_amp = self.params.get("use_amp", True)
+        self.scaler = GradScaler(enabled=self.use_amp)
+
+        loss_type = self.params.get("loss_type", "listnet")
+        if loss_type == "lambda":
+            self.criterion = EnsembleLambdaLoss(
+                sigma=self.params.get("lambda_sigma", 1.0),
+                ndcg_weight_power=self.params.get("ndcg_weight_power", 1.0),
+            )
+        elif loss_type == "weighted_listnet":
+            self.criterion = EnsembleListNetLoss(
+                topk_weight=self.params.get("topk_weight", 5.0)
+            )
+        else:
+            self.criterion = EnsembleListNetLoss(topk_weight=1.0)
+
+        n_features = self.params.get("n_features")
+        if n_features is None:
+            raise ValueError("TabMSetRankModel 初始化需要在 params 中传入 'n_features'")
+        ensemble_size = self.params.get("ensemble_size", 32)
+        self.model = TabMSetRankNet(
+            n_features=n_features,
+            ensemble_size=ensemble_size,
+            n_blocks=self.params.get("n_blocks", 3),
+            d_block=self.params.get("d_block", 256),
+            dropout=self.params.get("dropout", 0.1),
+            setrank_heads=self.params.get("setrank_heads", 4),
+            setrank_hidden=self.params.get("setrank_hidden", 128),
+            setrank_dropout=self.params.get("setrank_dropout", 0.1),
+        ).to(self.device)
+
+    def _make_loader(
+        self,
+        X: np.ndarray,
+        y: Optional[np.ndarray] = None,
+        group: Optional[np.ndarray] = None,
+        shuffle_groups: bool = False,
+    ) -> DataLoader:
+        """创建 DataLoader (支持 Query/Group 截面打包)"""
+        X_tensor = torch.from_numpy(X)
+        if y is not None:
+            y_tensor = torch.from_numpy(y)
+            dataset = TensorDataset(X_tensor, y_tensor)
+        else:
+            dataset = TensorDataset(X_tensor)
+
+        if group is not None:
+            sampler = GroupSampler(group, shuffle_groups=shuffle_groups)
+            return DataLoader(
+                dataset,
+                batch_sampler=sampler,
+                num_workers=self.params.get("num_workers", 0),
+                pin_memory=self.params.get("pin_memory", False),
+            )
+        else:
+            batch_size = self.params.get("batch_size", 2048)
+            return DataLoader(
+                dataset,
+                batch_size=batch_size,
+                shuffle=False,
+                num_workers=self.params.get("num_workers", 0),
+                pin_memory=self.params.get("pin_memory", False),
+            )
+
+    def _validate_ndcg(self, val_loader: DataLoader, k: Optional[int] = None) -> float:
+        """验证模型 (使用 NDCG 排序指标)"""
+        from sklearn.metrics import ndcg_score
+
+        assert self.model is not None, "模型未训练,无法验证"
+        self.model.eval()
+        ndcg_list = []
+
+        with torch.no_grad():
+            for batch in val_loader:
+                if len(batch) != 2:
+                    continue
+                bx, by = batch
+                bx = bx.to(self.device)
+                by = by.cpu().numpy()
+                if len(by) <= 1:
+                    continue
+
+                with torch.amp.autocast("cuda", enabled=self.use_amp):
+                    scores = self.model(bx)  # [N, E]
+
+                # 对 Ensemble 维度求平均,得到最终排序分 [N]
+                preds = scores.mean(dim=1).cpu().numpy()
+                try:
+                    ndcg_list.append(ndcg_score([by], [preds], k=k))
+                except ValueError:
+                    pass
+        return float(np.mean(ndcg_list)) if ndcg_list else 0.0
+
+    def fit(
+        self,
+        X: pl.DataFrame,
+        y: pl.Series,
+        group: Optional[np.ndarray] = None,
+        eval_set: Optional[Tuple] = None,
+    ) -> "TabMSetRankModel":
+        """训练排序模型"""
+        self.feature_names_ = list(X.columns)
+        X_np = X.to_numpy().astype(np.float32)
+        y_np = y.to_numpy().astype(np.float32)
+
+        if group is None:
+            group = np.array([len(y_np)])
+        if group.sum() != len(y_np):
+            raise ValueError(
+                f"group 数组的和 ({group.sum()}) 必须等于样本数 ({len(y_np)})"
+            )
+
+        train_loader = self._make_loader(X_np, y_np, group=group, shuffle_groups=True)
+        val_loader = None
+        if eval_set is not None:
+            X_val, y_val, group_val = eval_set
+            X_val_np = (
+                X_val.to_numpy().astype(np.float32)
+                if isinstance(X_val, pl.DataFrame)
+                else X_val
+            )
+            y_val_np = (
+                y_val.to_numpy().astype(np.float32)
+                if isinstance(y_val, pl.Series)
+                else y_val
+            )
+            if group_val is None:
+                group_val = np.array([len(y_val_np)])
+            val_loader = self._make_loader(
+                X_val_np, y_val_np, group=group_val, shuffle_groups=False
+            )
+
+        n_features = X_np.shape[1]
+        expected_n_features = self.params.get("n_features")
+        if expected_n_features is not None and n_features != expected_n_features:
+            raise ValueError(
+                f"输入特征维度 ({n_features}) 与初始化时指定的 n_features ({expected_n_features}) 不一致"
+            )
+
+        optimizer = optim.AdamW(
+            self.model.parameters(),
+            lr=self.params.get("learning_rate", 1e-3),
+            weight_decay=self.params.get("weight_decay", 1e-5),
+        )
+        epochs = self.params.get("epochs", 50)
+        scheduler = optim.lr_scheduler.CosineAnnealingLR(
+            optimizer, T_max=epochs, eta_min=1e-6
+        )
+        early_stopping_patience = self.params.get(
+            "early_stopping_patience",
+            self.params.get("early_stopping_round", 10),
+        )
+        best_val_ndcg = -float("inf")
+        patience_counter = 0
+        best_model_state = None
+        ndcg_k = self.params.get("ndcg_k", None)
+
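+        # 训练主循环:GroupSampler 使每个 batch 恰为一个 group(同一截面),排序损失在组内计算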
设备: {self.device}, AMP: {self.use_amp}") + + for epoch in range(epochs): + self.model.train() + train_loss = 0.0 + n_train_batches = 0 + + for batch in train_loader: + if len(batch) != 2: + continue + bx, by = batch[0].to(self.device), batch[1].to(self.device) + + optimizer.zero_grad() + with torch.amp.autocast("cuda", enabled=self.use_amp): + scores = self.model(bx) # [N, E] + loss = self.criterion(scores, by) + + self.scaler.scale(loss).backward() + self.scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), + max_norm=self.params.get("max_grad_norm", 1.0), + ) + self.scaler.step(optimizer) + self.scaler.update() + + train_loss += loss.item() + n_train_batches += 1 + + avg_train_loss = train_loss / max(n_train_batches, 1) + self.training_history_["train_loss"].append(avg_train_loss) + + if val_loader is not None: + val_ndcg = self._validate_ndcg(val_loader, k=ndcg_k) + self.training_history_["val_ndcg"].append(val_ndcg) + if val_ndcg > best_val_ndcg: + best_val_ndcg = val_ndcg + patience_counter = 0 + best_model_state = { + k: v.cpu().clone() for k, v in self.model.state_dict().items() + } + else: + patience_counter += 1 + + if (epoch + 1) % 5 == 0 or epoch == 0: + print( + f"[TabM+SetRank] Epoch {epoch + 1}/{epochs} | " + f"Loss: {avg_train_loss:.4f} | " + f"Val NDCG: {val_ndcg:.4f} (Best: {best_val_ndcg:.4f})" + ) + if patience_counter >= early_stopping_patience: + print(f"[TabM+SetRank] 早停于 epoch {epoch + 1}") + break + else: + if (epoch + 1) % 5 == 0 or epoch == 0: + print( + f"[TabM+SetRank] Epoch {epoch + 1}/{epochs} | " + f"Loss: {avg_train_loss:.4f}" + ) + scheduler.step() + + if best_model_state: + self.model.load_state_dict(best_model_state) + print(f"[TabM+SetRank] 恢复最佳权重 (Val NDCG: {best_val_ndcg:.4f})") + return self + + def predict( + self, X: pl.DataFrame, group: Optional[np.ndarray] = None + ) -> np.ndarray: + """预测排序分数""" + if self.model is None: + raise RuntimeError("模型未训练,请先调用fit()") + if self.feature_names_: + missing = [c for c in self.feature_names_ if c not in X.columns] + if missing: + raise ValueError(f"缺失特征: {missing}") + X = X.select(self.feature_names_) + + X_np = X.to_numpy().astype(np.float32) + loader = self._make_loader(X_np, group=group, shuffle_groups=False) + self.model.eval() + all_preds = [] + + with torch.no_grad(): + for batch in loader: + bx = batch[0].to(self.device) + with torch.amp.autocast("cuda", enabled=self.use_amp): + scores = self.model(bx) # [N, E] + + # 对 Ensemble 维度求平均,得到最终排序分 [N] + preds = scores.mean(dim=1) + all_preds.append(preds.cpu().numpy()) + return np.concatenate(all_preds) + + def get_evals_result(self) -> Optional[Dict[str, List[float]]]: + """获取训练评估结果""" + return self.training_history_ + + def feature_importance(self) -> None: + """TabM 没有内置特征重要性计算,返回 None。""" + return None + + def save(self, path: str | Path) -> None: + """保存模型""" + if self.model is None: + raise RuntimeError("模型未训练,无法保存") + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + torch.save(self.model.state_dict(), path.with_suffix(".pt")) + with open(path.with_suffix(".meta"), "wb") as f: + pickle.dump( + { + "params": self.params, + "feature_names": self.feature_names_, + "training_history": self.training_history_, + "device": str(self.device), + }, + f, + ) + + @classmethod + def load(cls, path: str | Path) -> "TabMSetRankModel": + """加载模型""" + path = Path(path) + with open(path.with_suffix(".meta"), "rb") as f: + meta = pickle.load(f) + meta_params = meta["params"] + feature_names = meta["feature_names"] + 
+        assert feature_names is not None, "加载的模型缺少 feature_names,无法重建网络"
+        meta_params["n_features"] = len(feature_names)
+        instance = cls(meta_params)
+        instance.feature_names_ = feature_names
+        instance.training_history_ = meta["training_history"]
+        ckpt = torch.load(path.with_suffix(".pt"), map_location=instance.device)
+        instance.model.load_state_dict(ckpt)
+        return instance
+
+    def evaluate_ndcg(
+        self,
+        X: pl.DataFrame,
+        y: pl.Series,
+        group: np.ndarray,
+        k: Optional[int] = None,
+    ) -> float:
+        """评估 NDCG 指标"""
+        from sklearn.metrics import ndcg_score
+
+        y_pred = self.predict(X)
+        y_true_array = y.to_numpy() if isinstance(y, pl.Series) else y
+        y_true_list = []
+        y_score_list = []
+        start_idx = 0
+        for group_size in group:
+            end_idx = start_idx + group_size
+            y_true_list.append(y_true_array[start_idx:end_idx])
+            y_score_list.append(y_pred[start_idx:end_idx])
+            start_idx = end_idx
+
+        ndcg_scores = []
+        for yt, yp in zip(y_true_list, y_score_list):
+            if len(yt) > 1:
+                try:
+                    ndcg_scores.append(ndcg_score([yt], [yp], k=k))
+                except ValueError:
+                    pass
+        return float(np.mean(ndcg_scores)) if ndcg_scores else 0.0
diff --git a/src/training/tasks/tabm_rank_task.py b/src/training/tasks/tabm_rank_task.py
index 2006b51..0c68e93 100644
--- a/src/training/tasks/tabm_rank_task.py
+++ b/src/training/tasks/tabm_rank_task.py
@@ -7,7 +7,7 @@
 - 支持 NDCG@k 评估
 """
 
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Type
 
 import numpy as np
 import polars as pl
@@ -18,7 +18,7 @@ from src.training.components.models.tabm_rank_model import TabMRankModel
 class TabMRankTask(BaseTask):
     """TabM 排序学习任务
 
-    使用 TabMRankModel 进行排序学习训练。
+    使用 TabMRankModel(或兼容模型)进行排序学习训练。
     将连续收益率转换为分位数标签进行训练。
     支持指数化增益标签以增强 Top-K 关注。
     """
@@ -30,6 +30,7 @@ class TabMRankTask(BaseTask):
         n_quantiles: int = 20,
         label_transform: Optional[str] = None,
         label_scale: float = 20.0,
+        model_class: Optional[Type] = None,
     ):
         """初始化排序学习任务
 
@@ -41,11 +42,14 @@
             - None: 标准分位数标签 (0, 1, ..., n_quantiles-1)
             - "exponential": 指数化增益: 2^(rank/scale) - 1
             label_scale: 指数变换的缩放因子,用于控制增益幅度
+            model_class: 模型类,默认 TabMRankModel。可传入兼容接口的其他模型,
+                如 TabMSetRankModel。
         """
         super().__init__(model_params, label_name)
         self.n_quantiles = n_quantiles
         self.label_transform = label_transform
         self.label_scale = label_scale
+        self.model_class = model_class or TabMRankModel
 
     def prepare_labels(self, data: Dict[str, Dict]) -> Dict[str, Dict]:
         """准备标签(转换为分位数标签,可选指数化增益变换)
@@ -135,7 +139,8 @@
             train_data: 训练数据
             val_data: 验证数据
         """
-        self.model = TabMRankModel(params=self.model_params)
+        self.model_params["n_features"] = train_data["X"].shape[1]
+        self.model = self.model_class(params=self.model_params)
         self.model.fit(
             train_data["X"],
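补充:下面给出 TabMSetRankModel 的最小调用示意,便于审阅 fit / predict 的接口约定(group 之和必须等于样本数)。示例中的特征列名 f0..f4、随机数据与超参数均为虚构,仅作说明,不属于本次提交;实际训练请使用新增的 src/experiment/tabm_setrank_train.py 入口。

```python
import numpy as np
import polars as pl

from src.training.components.models import TabMSetRankModel

# 构造两个玩具截面(group):共 8 个样本、5 个特征,列名为虚构的 f0..f4
X = pl.DataFrame({f"f{i}": np.random.randn(8).astype(np.float32) for i in range(5)})
y = pl.Series("label", np.random.rand(8).astype(np.float32))
group = np.array([4, 4])  # 每个 group 的样本数,之和必须等于 len(y)

model = TabMSetRankModel(
    params={
        "n_features": 5,    # 必填:__init__ 需要特征维度来构建 TabMSetRankNet
        "ensemble_size": 8,
        "loss_type": "lambda",
        "epochs": 2,        # 玩具示例,仅跑两轮
        "use_amp": False,   # 无 GPU 环境下建议关闭混合精度
    }
)
model.fit(X, y, group=group)
scores = model.predict(X, group=group)  # 形状 [8],为 Ensemble 均值后的排序分
print(scores.shape)
```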