feat(training): 新增 TabM SetRank 模型并支持任务注入
- 添加 TabMSetRankModel 实现集合排序训练
- TabMRankTask 支持通过 model_class 注入兼容模型
- 启用 common.py 中的流动性因子
This commit is contained in:
@@ -12,6 +12,7 @@ from src.training.components.models.tabm_rank_model import (
|
||||
EnsembleListNetLoss,
|
||||
EnsembleLambdaLoss,
|
||||
)
|
||||
from src.training.components.models.tabm_setrank_model import TabMSetRankModel
|
||||
from src.training.components.models.cross_section_sampler import CrossSectionSampler
|
||||
from src.training.components.models.ensemble_quant_loss import EnsembleQuantLoss
|
||||
|
||||
@@ -21,6 +22,7 @@ __all__ = [
|
||||
"TabPFNModel",
|
||||
"TabMModel",
|
||||
"TabMRankModel",
|
||||
"TabMSetRankModel",
|
||||
"EnsembleListNetLoss",
|
||||
"EnsembleLambdaLoss",
|
||||
"CrossSectionSampler",
|
||||
|
||||
489
src/training/components/models/tabm_setrank_model.py
Normal file
489
src/training/components/models/tabm_setrank_model.py
Normal file
@@ -0,0 +1,489 @@
|
||||
"""TabM + SetRank 排序模型实现 (TabM_SetRank)
|
||||
|
||||
基于 TabM 特征提取 + SetRank 组内注意力建模。
|
||||
引入 ListNet/LambdaLoss 列表级排序损失,支持 AMP 混合精度与显存优化。
|
||||
适用于股票未来收益率的截面排序预测。
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader, TensorDataset, Sampler
|
||||
from torch.cuda.amp import GradScaler
|
||||
from tabm import TabM
|
||||
|
||||
from src.training.components.base import BaseModel
|
||||
from src.training.registry import register_model
|
||||
from src.training.components.models.tabm_rank_model import (
|
||||
EnsembleListNetLoss,
|
||||
EnsembleLambdaLoss,
|
||||
GroupSampler,
|
||||
)
|
||||
|
||||
|
||||
class SetRankHead(nn.Module):
    """Standard SetRank intra-group contextual attention head.

    Input:  ``[N, E]`` ensemble predictions from TabM.
    Output: ``[N, E]`` ensemble predictions after cross-sectional interaction.
    """

    def __init__(
        self,
        d_in: int,  # number of TabM ensemble members (e.g. 32)
        n_heads: int = 4,
        d_ff: int = 128,  # hidden feature width of the SetRank head
        dropout: float = 0.1,
    ):
        super().__init__()
        # 1. Lift raw ensemble logits into a higher-dimensional semantic space.
        self.feature_proj = nn.Linear(d_in, d_ff)

        # 2. Multi-head attention (the core of SetRank).
        #    Deliberately no positional encoding: a stock cross-section is an
        #    unordered set, so the head must stay permutation-equivariant.
        self.attn = nn.MultiheadAttention(
            embed_dim=d_ff, num_heads=n_heads, dropout=dropout, batch_first=True
        )
        self.norm1 = nn.LayerNorm(d_ff)

        # 3. Feed-forward network (standard Transformer FFN structure).
        self.ffn = nn.Sequential(
            nn.Linear(d_ff, d_ff * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff * 2, d_ff),
        )
        self.norm2 = nn.LayerNorm(d_ff)

        # 4. Project back to the ensemble dimension: [N, d_ff] -> [N, E].
        self.output_proj = nn.Linear(d_ff, d_in)

        self._init_weights()

    def _init_weights(self):
        """Xavier-init every linear layer; identity-init every LayerNorm."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
            elif isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus a learned cross-sectional correction, shape [N, E]."""
        # x: [N, E].  Project up, then add a batch axis so the whole daily
        # cross-section is one sequence: B=1 (one day), L=N (stocks), D=d_ff.
        h = self.feature_proj(x).unsqueeze(0)

        # Self-attention over the set of stocks, residual block 1.
        attn_out, _ = self.attn(h, h, h)
        h = self.norm1(h + attn_out)

        # Residual block 2 (FFN).
        h = self.norm2(h + self.ffn(h))

        # Drop the batch axis and project back to the ensemble dimension.
        correction = self.output_proj(h.squeeze(0))

        # Residual link: keep TabM's original scores and let SetRank supply
        # only a correction term — this greatly stabilises training.
        return x + correction
|
||||
|
||||
|
||||
class TabMSetRankNet(nn.Module):
    """TabM feature extractor followed by a SetRank cross-sectional head."""

    def __init__(
        self,
        n_features: int,
        ensemble_size: int,
        n_blocks: int,
        d_block: int,
        dropout: float,
        setrank_heads: int,
        setrank_hidden: int,
        setrank_dropout: float,
    ):
        super().__init__()
        # Per-stock TabM backbone (k independent ensemble members).
        self.tabm = TabM.make(
            n_num_features=n_features,
            cat_cardinalities=[],
            d_out=1,
            n_blocks=n_blocks,
            d_block=d_block,
            dropout=dropout,
            k=ensemble_size,
        )
        # Cross-sectional attention corrector over the ensemble scores.
        self.setrank_head = SetRankHead(
            d_in=ensemble_size,
            n_heads=setrank_heads,
            d_ff=setrank_hidden,
            dropout=setrank_dropout,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map features [N, F] to ensemble ranking scores [N, E]."""
        # 1. Independent per-sample feature extraction -> [N, E].
        per_stock = self.tabm(x).squeeze(-1)
        # 2. Cross-sectional context interaction / correction -> [N, E].
        return self.setrank_head(per_stock)
|
||||
|
||||
|
||||
@register_model("tabm_setrank_rank")
class TabMSetRankModel(BaseModel):
    """TabM + SetRank learning-to-rank model.

    A single ``nn.Module`` (``TabMSetRankNet``) bundles the TabM backbone with
    the SetRank head.  Supports AMP mixed precision and memory-friendly
    group-wise (per cross-section) data loading.  Listwise losses
    (ListNet / LambdaLoss) are reused from ``tabm_rank_model``.
    """

    name = "tabm_setrank_rank"

    def __init__(self, params: Optional[Dict[str, Any]] = None):
        """Build the network from ``params``.

        Required params:
            n_features: input feature dimension.
        Optional params (defaults): ensemble_size=32, n_blocks=3, d_block=256,
            dropout=0.1, setrank_heads=4, setrank_hidden=128,
            setrank_dropout=0.1, use_amp=True,
            loss_type in {"listnet", "lambda", "weighted_listnet"},
            plus loss-specific knobs (lambda_sigma, ndcg_weight_power,
            topk_weight).

        Raises:
            ValueError: if ``n_features`` is missing from ``params``.
        """
        self.params = params or {}
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.training_history_: Dict[str, List[float]] = {
            "train_loss": [],
            "val_ndcg": [],
        }
        self.feature_names_: Optional[List[str]] = None

        # BUGFIX: only enable AMP when actually running on CUDA.  On CPU,
        # autocast("cuda", ...) would be a silent no-op and
        # GradScaler(enabled=True) would just warn and disable itself.
        self.use_amp = (
            bool(self.params.get("use_amp", True)) and self.device.type == "cuda"
        )
        self.scaler = GradScaler(enabled=self.use_amp)

        # Select the listwise ranking loss.
        loss_type = self.params.get("loss_type", "listnet")
        if loss_type == "lambda":
            self.criterion = EnsembleLambdaLoss(
                sigma=self.params.get("lambda_sigma", 1.0),
                ndcg_weight_power=self.params.get("ndcg_weight_power", 1.0),
            )
        elif loss_type == "weighted_listnet":
            self.criterion = EnsembleListNetLoss(
                topk_weight=self.params.get("topk_weight", 5.0)
            )
        else:
            # Plain ListNet: uniform top-k weighting.
            self.criterion = EnsembleListNetLoss(topk_weight=1.0)

        n_features = self.params.get("n_features")
        if n_features is None:
            raise ValueError("TabMSetRankModel 初始化需要在 params 中传入 'n_features'")
        ensemble_size = self.params.get("ensemble_size", 32)
        self.model = TabMSetRankNet(
            n_features=n_features,
            ensemble_size=ensemble_size,
            n_blocks=self.params.get("n_blocks", 3),
            d_block=self.params.get("d_block", 256),
            dropout=self.params.get("dropout", 0.1),
            setrank_heads=self.params.get("setrank_heads", 4),
            setrank_hidden=self.params.get("setrank_hidden", 128),
            setrank_dropout=self.params.get("setrank_dropout", 0.1),
        ).to(self.device)

    def _make_loader(
        self,
        X: np.ndarray,
        y: Optional[np.ndarray] = None,
        group: Optional[np.ndarray] = None,
        shuffle_groups: bool = False,
    ) -> DataLoader:
        """Create a DataLoader, optionally batching by query/group (one
        cross-section per batch via ``GroupSampler``)."""
        X_tensor = torch.from_numpy(X)
        if y is not None:
            dataset = TensorDataset(X_tensor, torch.from_numpy(y))
        else:
            dataset = TensorDataset(X_tensor)

        if group is not None:
            # One batch == one group (e.g. one trading day's cross-section).
            sampler = GroupSampler(group, shuffle_groups=shuffle_groups)
            return DataLoader(
                dataset,
                batch_sampler=sampler,
                num_workers=self.params.get("num_workers", 0),
                pin_memory=self.params.get("pin_memory", False),
            )
        else:
            batch_size = self.params.get("batch_size", 2048)
            return DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=False,
                num_workers=self.params.get("num_workers", 0),
                pin_memory=self.params.get("pin_memory", False),
            )

    def _validate_ndcg(self, val_loader: DataLoader, k: Optional[int] = None) -> float:
        """Validate the model with NDCG (mean over groups); 0.0 if no usable group."""
        from sklearn.metrics import ndcg_score

        assert self.model is not None, "模型未训练,无法验证"
        self.model.eval()
        ndcg_list = []

        with torch.no_grad():
            for batch in val_loader:
                if len(batch) != 2:
                    continue
                bx, by = batch
                bx = bx.to(self.device)
                by = by.cpu().numpy()
                # NDCG is undefined for single-element groups.
                if len(by) <= 1:
                    continue

                with torch.amp.autocast("cuda", enabled=self.use_amp):
                    scores = self.model(bx)  # [N, E]

                # Average over the ensemble dimension -> final rank score [N].
                preds = scores.mean(dim=1).cpu().numpy()
                try:
                    ndcg_list.append(ndcg_score([by], [preds], k=k))
                except ValueError:
                    # Skip groups where sklearn rejects the labels
                    # (e.g. negative relevance values).
                    pass
        return float(np.mean(ndcg_list)) if ndcg_list else 0.0

    def fit(
        self,
        X: pl.DataFrame,
        y: pl.Series,
        group: Optional[np.ndarray] = None,
        eval_set: Optional[Tuple] = None,
    ) -> "TabMSetRankModel":
        """Train the ranking model.

        Args:
            X: training features.
            y: training labels (relevance / quantile gains).
            group: per-group sizes; their sum must equal ``len(y)``.
                Defaults to a single group covering all samples.
            eval_set: optional ``(X_val, y_val, group_val)`` triple for
                NDCG-based early stopping.

        Raises:
            ValueError: on group-size mismatch or feature-dimension mismatch.
        """
        self.feature_names_ = list(X.columns)
        X_np = X.to_numpy().astype(np.float32)
        y_np = y.to_numpy().astype(np.float32)

        if group is None:
            group = np.array([len(y_np)])
        if group.sum() != len(y_np):
            raise ValueError(
                f"group 数组的和 ({group.sum()}) 必须等于样本数 ({len(y_np)})"
            )

        train_loader = self._make_loader(X_np, y_np, group=group, shuffle_groups=True)
        val_loader = None
        if eval_set is not None:
            X_val, y_val, group_val = eval_set
            X_val_np = (
                X_val.to_numpy().astype(np.float32)
                if isinstance(X_val, pl.DataFrame)
                else X_val
            )
            y_val_np = (
                y_val.to_numpy().astype(np.float32)
                if isinstance(y_val, pl.Series)
                else y_val
            )
            if group_val is None:
                group_val = np.array([len(y_val_np)])
            val_loader = self._make_loader(
                X_val_np, y_val_np, group=group_val, shuffle_groups=False
            )

        # Guard against a network built for a different feature width.
        n_features = X_np.shape[1]
        expected_n_features = self.params.get("n_features")
        if expected_n_features is not None and n_features != expected_n_features:
            raise ValueError(
                f"输入特征维度 ({n_features}) 与初始化时指定的 n_features ({expected_n_features}) 不一致"
            )

        optimizer = optim.AdamW(
            self.model.parameters(),
            lr=self.params.get("learning_rate", 1e-3),
            weight_decay=self.params.get("weight_decay", 1e-5),
        )
        epochs = self.params.get("epochs", 50)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=epochs, eta_min=1e-6
        )
        # "early_stopping_round" kept as a legacy alias.
        early_stopping_patience = self.params.get(
            "early_stopping_patience",
            self.params.get("early_stopping_round", 10),
        )
        best_val_ndcg = -float("inf")
        patience_counter = 0
        best_model_state = None
        ndcg_k = self.params.get("ndcg_k", None)

        print(f"[TabM+SetRank] 开始训练... 设备: {self.device}, AMP: {self.use_amp}")

        for epoch in range(epochs):
            self.model.train()
            train_loss = 0.0
            n_train_batches = 0

            for batch in train_loader:
                if len(batch) != 2:
                    continue
                bx, by = batch[0].to(self.device), batch[1].to(self.device)

                optimizer.zero_grad()
                with torch.amp.autocast("cuda", enabled=self.use_amp):
                    scores = self.model(bx)  # [N, E]
                    loss = self.criterion(scores, by)

                # AMP-safe backward: unscale before clipping so the clip
                # threshold applies to true gradient magnitudes.
                self.scaler.scale(loss).backward()
                self.scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    max_norm=self.params.get("max_grad_norm", 1.0),
                )
                self.scaler.step(optimizer)
                self.scaler.update()

                train_loss += loss.item()
                n_train_batches += 1

            avg_train_loss = train_loss / max(n_train_batches, 1)
            self.training_history_["train_loss"].append(avg_train_loss)

            if val_loader is not None:
                val_ndcg = self._validate_ndcg(val_loader, k=ndcg_k)
                self.training_history_["val_ndcg"].append(val_ndcg)
                if val_ndcg > best_val_ndcg:
                    best_val_ndcg = val_ndcg
                    patience_counter = 0
                    # Snapshot best weights on CPU to keep GPU memory free.
                    best_model_state = {
                        k: v.cpu().clone() for k, v in self.model.state_dict().items()
                    }
                else:
                    patience_counter += 1

                if (epoch + 1) % 5 == 0 or epoch == 0:
                    print(
                        f"[TabM+SetRank] Epoch {epoch + 1}/{epochs} | "
                        f"Loss: {avg_train_loss:.4f} | "
                        f"Val NDCG: {val_ndcg:.4f} (Best: {best_val_ndcg:.4f})"
                    )
                if patience_counter >= early_stopping_patience:
                    print(f"[TabM+SetRank] 早停于 epoch {epoch + 1}")
                    break
            else:
                if (epoch + 1) % 5 == 0 or epoch == 0:
                    print(
                        f"[TabM+SetRank] Epoch {epoch + 1}/{epochs} | "
                        f"Loss: {avg_train_loss:.4f}"
                    )
            scheduler.step()

        # BUGFIX: explicit None check (a dict truthiness test would also work,
        # but this states the intent and survives empty state dicts).
        if best_model_state is not None:
            self.model.load_state_dict(best_model_state)
            print(f"[TabM+SetRank] 恢复最佳权重 (Val NDCG: {best_val_ndcg:.4f})")
        return self

    def predict(
        self, X: pl.DataFrame, group: Optional[np.ndarray] = None
    ) -> np.ndarray:
        """Predict ranking scores (ensemble mean), shape ``[len(X)]``."""
        if self.model is None:
            raise RuntimeError("模型未训练,请先调用fit()")
        if self.feature_names_:
            missing = [c for c in self.feature_names_ if c not in X.columns]
            if missing:
                raise ValueError(f"缺失特征: {missing}")
            # Re-order columns to match the training layout.
            X = X.select(self.feature_names_)

        X_np = X.to_numpy().astype(np.float32)
        loader = self._make_loader(X_np, group=group, shuffle_groups=False)
        self.model.eval()
        all_preds = []

        with torch.no_grad():
            for batch in loader:
                bx = batch[0].to(self.device)
                with torch.amp.autocast("cuda", enabled=self.use_amp):
                    scores = self.model(bx)  # [N, E]

                # Average over the ensemble dimension -> final rank score [N].
                preds = scores.mean(dim=1)
                all_preds.append(preds.cpu().numpy())
        # BUGFIX: np.concatenate raises on an empty list (empty X input).
        if not all_preds:
            return np.empty(0, dtype=np.float32)
        return np.concatenate(all_preds)

    def get_evals_result(self) -> Optional[Dict[str, List[float]]]:
        """Return the per-epoch training/validation history."""
        return self.training_history_

    def feature_importance(self) -> None:
        """TabM has no built-in feature-importance computation; return None."""
        return None

    def save(self, path: str | Path) -> None:
        """Save weights to ``<path>.pt`` and metadata to ``<path>.meta``."""
        if self.model is None:
            raise RuntimeError("模型未训练,无法保存")
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(self.model.state_dict(), path.with_suffix(".pt"))
        with open(path.with_suffix(".meta"), "wb") as f:
            pickle.dump(
                {
                    "params": self.params,
                    "feature_names": self.feature_names_,
                    "training_history": self.training_history_,
                    "device": str(self.device),
                },
                f,
            )

    @classmethod
    def load(cls, path: str | Path) -> "TabMSetRankModel":
        """Load a model saved by :meth:`save`.

        NOTE: uses ``pickle`` — only load files from trusted sources.
        """
        path = Path(path)
        with open(path.with_suffix(".meta"), "rb") as f:
            meta = pickle.load(f)
        meta_params = meta["params"]
        feature_names = meta["feature_names"]
        assert feature_names is not None, "加载的模型缺少 feature_names,无法重建网络"
        # Rebuild the network with the feature width implied by the metadata.
        meta_params["n_features"] = len(feature_names)
        instance = cls(meta_params)
        instance.feature_names_ = feature_names
        instance.training_history_ = meta["training_history"]
        ckpt = torch.load(path.with_suffix(".pt"), map_location=instance.device)
        instance.model.load_state_dict(ckpt)
        return instance

    def evaluate_ndcg(
        self,
        X: pl.DataFrame,
        y: pl.Series,
        group: np.ndarray,
        k: Optional[int] = None,
    ) -> float:
        """Evaluate mean NDCG@k over the given groups (0.0 if none usable)."""
        from sklearn.metrics import ndcg_score

        y_pred = self.predict(X)
        y_true_array = y.to_numpy() if isinstance(y, pl.Series) else y
        y_true_list = []
        y_score_list = []
        start_idx = 0
        # Slice predictions/labels into contiguous per-group segments.
        for group_size in group:
            end_idx = start_idx + group_size
            y_true_list.append(y_true_array[start_idx:end_idx])
            y_score_list.append(y_pred[start_idx:end_idx])
            start_idx = end_idx

        ndcg_scores = []
        for yt, yp in zip(y_true_list, y_score_list):
            if len(yt) > 1:
                try:
                    ndcg_scores.append(ndcg_score([yt], [yp], k=k))
                except ValueError:
                    pass
        return float(np.mean(ndcg_scores)) if ndcg_scores else 0.0
|
||||
@@ -7,7 +7,7 @@
|
||||
- 支持 NDCG@k 评估
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
|
||||
@@ -18,7 +18,7 @@ from src.training.components.models.tabm_rank_model import TabMRankModel
|
||||
class TabMRankTask(BaseTask):
|
||||
"""TabM 排序学习任务
|
||||
|
||||
使用 TabMRankModel 进行排序学习训练。
|
||||
使用 TabMRankModel(或兼容模型)进行排序学习训练。
|
||||
将连续收益率转换为分位数标签进行训练。
|
||||
支持指数化增益标签以增强 Top-K 关注。
|
||||
"""
|
||||
@@ -30,6 +30,7 @@ class TabMRankTask(BaseTask):
|
||||
n_quantiles: int = 20,
|
||||
label_transform: Optional[str] = None,
|
||||
label_scale: float = 20.0,
|
||||
model_class: Optional[Type] = None,
|
||||
):
|
||||
"""初始化排序学习任务
|
||||
|
||||
@@ -41,11 +42,14 @@ class TabMRankTask(BaseTask):
|
||||
- None: 标准分位数标签 (0, 1, ..., n_quantiles-1)
|
||||
- "exponential": 指数化增益: 2^(rank/scale) - 1
|
||||
label_scale: 指数变换的缩放因子,用于控制增益幅度
|
||||
model_class: 模型类,默认 TabMRankModel。可传入兼容接口的其他模型,
|
||||
如 TabMSetRankModel。
|
||||
"""
|
||||
super().__init__(model_params, label_name)
|
||||
self.n_quantiles = n_quantiles
|
||||
self.label_transform = label_transform
|
||||
self.label_scale = label_scale
|
||||
self.model_class = model_class or TabMRankModel
|
||||
|
||||
def prepare_labels(self, data: Dict[str, Dict]) -> Dict[str, Dict]:
|
||||
"""准备标签(转换为分位数标签,可选指数化增益变换)
|
||||
@@ -135,7 +139,8 @@ class TabMRankTask(BaseTask):
|
||||
train_data: 训练数据
|
||||
val_data: 验证数据
|
||||
"""
|
||||
self.model = TabMRankModel(params=self.model_params)
|
||||
self.model_params["n_features"] = train_data["X"].shape[1]
|
||||
self.model = self.model_class(params=self.model_params)
|
||||
|
||||
self.model.fit(
|
||||
train_data["X"],
|
||||
|
||||
Reference in New Issue
Block a user