feat(training): optimize the TabM model for quantitative trading

- Add CrossSectionSampler for cross-sectional sampling (one batch per trading day)
- Add EnsembleQuantLoss (Huber + IC) to replace MSE as the loss function
- Rework TabMModel for the quant setting: Rank IC as the validation metric, CosineAnnealingLR learning-rate scheduling, gradient clipping
- Support a date_col parameter and feature alignment at predict time (usage sketched below)
- Update the experiment config: batch_size 2048, weight_decay, and other hyperparameters
2026-04-01 00:20:05 +08:00
parent 36a3ccbcc8
commit c143815443
9 changed files with 492 additions and 60 deletions
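For orientation, a minimal sketch of how the reworked entry point is assumed to be called once this lands. The column name "trade_date", the toy frames, and the argument-free TabMModel() constructor are illustrative assumptions; only fit(..., date_col=...) and predict() come from the diffs below.

# Hedged usage sketch; "trade_date", the toy data, and the bare TabMModel()
# constructor are illustrative assumptions, not taken from the repo config.
import polars as pl
from src.training.components.models.tabm_model import TabMModel

X = pl.DataFrame({
    "trade_date": ["20240101"] * 3 + ["20240102"] * 3,
    "factor_a": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
    "factor_b": [1.0, 0.9, 0.8, 0.7, 0.6, 0.5],
})
y = pl.Series("label", [0.01, -0.02, 0.03, 0.00, 0.02, -0.01])

model = TabMModel()  # hyperparameters normally come from the experiment config
# date_col switches on cross-sectional batching: each trading day becomes one batch.
model.fit(X=X, y=y, date_col="trade_date")

# predict() drops the remembered date_col and re-aligns the training features.
preds = model.predict(X)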


@@ -260,7 +260,7 @@ MODEL_PARAMS = {
# ==================== Ensemble mechanism ====================
"ensemble_size": 32, # built-in ensemble size (emulates a 32-model ensemble)
# ==================== Training parameters ====================
"batch_size": 1024, # batch size
"batch_size": 2048, # batch size
"learning_rate": 1e-3, # learning rate
"weight_decay": 1e-5, # weight decay
"epochs": 100, # number of training epochs


@@ -7,5 +7,14 @@ from src.training.components.models.lightgbm import LightGBMModel
from src.training.components.models.lightgbm_lambdarank import LightGBMLambdaRankModel
from src.training.components.models.tabpfn_model import TabPFNModel
from src.training.components.models.tabm_model import TabMModel
from src.training.components.models.cross_section_sampler import CrossSectionSampler
from src.training.components.models.ensemble_quant_loss import EnsembleQuantLoss
__all__ = ["LightGBMModel", "LightGBMLambdaRankModel", "TabPFNModel", "TabMModel"]
__all__ = [
"LightGBMModel",
"LightGBMLambdaRankModel",
"TabPFNModel",
"TabMModel",
"CrossSectionSampler",
"EnsembleQuantLoss",
]


@@ -0,0 +1,59 @@
"""截面数据采样器
用于量化交易场景,确保每个批次包含同一天的所有股票数据。
实现横截面批处理,避免全局随机打乱。
"""
import numpy as np
from torch.utils.data import Sampler
class CrossSectionSampler(Sampler):
"""截面数据采样器
保证每次产出的 indices 属于同一个交易日。
适用于量化选股场景:模型每轮前向传播面对的都是当天的全市场股票。
Attributes:
date_to_indices: 日期到索引列表的映射
unique_dates: 唯一日期列表
shuffle_days: 是否打乱日期顺序
"""
def __init__(self, dates: np.ndarray, shuffle_days: bool = True):
"""初始化采样器
Args:
dates: 日期数组,每个元素对应一行数据
shuffle_days: 是否打乱日期的训练顺序,但同一天的数据始终在一起
"""
# 记录每个日期对应的所有行索引
self.date_to_indices = {}
for idx, date in enumerate(dates):
date_str = str(date)
if date_str not in self.date_to_indices:
self.date_to_indices[date_str] = []
self.date_to_indices[date_str].append(idx)
self.unique_dates = list(self.date_to_indices.keys())
self.shuffle_days = shuffle_days
def __iter__(self):
"""迭代生成批次索引
Yields:
list: 同一日期的所有样本索引
"""
dates = self.unique_dates.copy()
if self.shuffle_days:
np.random.shuffle(dates) # 打乱日期的训练顺序
for date in dates:
indices = self.date_to_indices[date].copy()
# 在截面内打乱股票顺序,防止顺序带来的隐性 bias
np.random.shuffle(indices)
yield indices
def __len__(self):
"""返回批次数量(等于日期数量)"""
return len(self.unique_dates)
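Since the sampler is consumed through DataLoader's batch_sampler hook (the same way _make_loader uses it further down), here is a small self-contained sketch; the synthetic dates and tensors are illustrative.

# Self-contained sketch of the intended wiring; the synthetic data is illustrative.
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

from src.training.components.models.cross_section_sampler import CrossSectionSampler

dates = np.array(["20240101"] * 3 + ["20240102"] * 2)  # one date per row
dataset = TensorDataset(torch.randn(5, 4), torch.randn(5))

sampler = CrossSectionSampler(dates, shuffle_days=True)
# With batch_sampler, each yielded index list becomes one batch,
# so batch_size and shuffle must not be passed to the DataLoader.
loader = DataLoader(dataset, batch_sampler=sampler)

for bx, by in loader:
    print(bx.shape[0])  # 3 and 2, in whichever day order the shuffle produced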


@@ -0,0 +1,86 @@
"""量化专用组合损失函数
结合 Huber Loss处理极值和 IC Loss优化排序
针对 TabM 的集成输出设计,为每个集成成员独立计算损失以保持多样性。
"""
import torch
import torch.nn as nn
class EnsembleQuantLoss(nn.Module):
"""Ensemble 量化组合损失函数
组合 Huber Loss 和 IC Loss:
- Huber Loss: 对收益率极值鲁棒,避免梯度爆炸
- IC Loss: 直接优化预测与目标的秩相关性
针对 TabM 的集成输出设计,为每个集成成员独立计算损失,
保持集成多样性。
注意std() 使用 unbiased=False 以保持数学一致性
(IC 计算使用除以 N而非 N-1)
Attributes:
alpha: Huber Loss 权重,(1-alpha) 为 IC Loss 权重
huber: HuberLoss 实例
ensemble_size: 集成成员数量
"""
def __init__(self, alpha: float = 0.5, ensemble_size: int = 32):
"""初始化损失函数
Args:
alpha: Huber Loss 权重,范围 [0, 1]
ensemble_size: TabM 集成大小
"""
super().__init__()
self.alpha = alpha
self.huber = nn.HuberLoss(reduction="mean")
self.ensemble_size = ensemble_size
def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""计算组合损失
Args:
preds: 预测值,形状 [Batch(单日股票数), Ensemble_size]
target: 目标值,形状 [Batch(单日股票数)]
Returns:
标量损失值
"""
total_loss = 0.0
# 为了稳定,如果单日截面股票数太少,直接回退到 Huber
if preds.shape[0] < 10:
for i in range(self.ensemble_size):
total_loss += self.huber(preds[:, i], target)
return total_loss / self.ensemble_size
# 预计算目标的标准化值(在所有集成成员间复用)
target_mean = target.mean()
# 【关键】使用 unbiased=False 保持数学一致性
target_std = target.std(unbiased=False) + 1e-8
target_norm = (target - target_mean) / target_std
# 对每个 ensemble 成员独立计算组合 Loss
for i in range(self.ensemble_size):
pred_i = preds[:, i]
# 1. Huber Loss (处理极值比 MSE 更好)
h_loss = self.huber(pred_i, target)
# 2. IC Loss (皮尔逊相关系数)
pred_mean = pred_i.mean()
# 【关键】使用 unbiased=False 保持数学一致性
pred_std = pred_i.std(unbiased=False) + 1e-8
pred_norm = (pred_i - pred_mean) / pred_std
# 相关系数
ic = (pred_norm * target_norm).mean()
# 损失函数化:希望 IC 越大越好,所以用 1 - IC
ic_loss = 1.0 - ic
total_loss += self.alpha * h_loss + (1.0 - self.alpha) * ic_loss
return total_loss / self.ensemble_size
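Written out as math (a sketch that mirrors the code above, including the unbiased=False, divide-by-N normalization), with y_hat^(i) the predictions of ensemble member i, y the targets over one cross-section of N stocks, sigma(.) the population standard deviation, and epsilon = 1e-8:

\mathrm{IC}_i = \frac{1}{N}\sum_{j=1}^{N}
    \frac{\hat{y}^{(i)}_j - \overline{\hat{y}^{(i)}}}{\sigma(\hat{y}^{(i)}) + \varepsilon}
    \cdot \frac{y_j - \bar{y}}{\sigma(y) + \varepsilon},
\qquad
\mathcal{L} = \frac{1}{k}\sum_{i=1}^{k}
    \left[ \alpha\,\mathrm{Huber}\!\left(\hat{y}^{(i)}, y\right)
         + (1-\alpha)\left(1 - \mathrm{IC}_i\right) \right]

For cross-sections with fewer than 10 stocks the IC term is dropped and the plain Huber loss, averaged over the k members, is returned instead, matching the fallback branch.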


@@ -10,6 +10,7 @@ import pickle
import numpy as np
import polars as pl
import scipy.stats as stats
import torch
import torch.nn as nn
import torch.optim as optim
@@ -17,6 +18,8 @@ from torch.utils.data import DataLoader, TensorDataset
from tabm import TabM
from src.training.components.base import BaseModel
from src.training.components.models.cross_section_sampler import CrossSectionSampler
from src.training.components.models.ensemble_quant_loss import EnsembleQuantLoss
from src.training.registry import register_model
@@ -69,14 +72,23 @@ class TabMModel(BaseModel):
self.criterion = nn.MSELoss()
def _make_loader(
self, X: np.ndarray, y: Optional[np.ndarray] = None, shuffle: bool = False
self,
X: np.ndarray,
y: Optional[np.ndarray] = None,
dates: Optional[np.ndarray] = None,
shuffle_days: bool = False,
) -> DataLoader:
"""创建DataLoader
支持两种模式:
1. 截面模式 (dates provided): 每天的数据作为一个独立批次
2. 普通模式 (dates is None): 使用固定batch_size
Args:
X: 特征数组 [N, n_features]
y: 标签数组 [N] 或 None
shuffle: 是否打乱数据
dates: 日期数组 [N],用于截面采样
shuffle_days: 是否打乱日期顺序
Returns:
DataLoader实例
@@ -86,106 +98,159 @@ class TabMModel(BaseModel):
else:
dataset = TensorDataset(torch.from_numpy(X))
if dates is not None:
# Use the cross-section sampler
sampler = CrossSectionSampler(dates, shuffle_days=shuffle_days)
# Note: with batch_sampler, batch_size and shuffle must not be passed
return DataLoader(dataset, batch_sampler=sampler)
else:
# If no dates are given (e.g. at predict time), fall back to a plain loader
batch_size = self.params.get("batch_size", 1024)
return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
return DataLoader(dataset, batch_size=batch_size, shuffle=False)
def _validate(self, val_loader: DataLoader) -> float:
"""验证模型
"""验证模型(使用 Rank IC 作为指标)
Args:
val_loader: 验证数据加载器
Returns:
平均验证损失
平均 Rank IC斯皮尔曼秩相关系数
"""
self.model.eval()
total_loss = 0.0
n_batches = 0
ic_list = []
with torch.no_grad():
for batch in val_loader:
if len(batch) == 2:
if len(batch) != 2:
continue
bx, by = batch
bx, by = bx.to(self.device), by.to(self.device)
else:
bx = batch[0].to(self.device)
by = None
bx = bx.to(self.device)
by = by.cpu().numpy()
# Skip if the daily cross-section has too few stocks
if len(by) < 10:
continue
# Take the ensemble-member mean for prediction
outputs = self.model(bx) # [B, E, 1]
preds = outputs.mean(dim=1).squeeze(-1) # [B]
# Take the mean for prediction
preds = outputs.mean(dim=1).squeeze(-1).cpu().numpy()
if by is not None:
loss = self.criterion(preds, by).item()
total_loss += loss
n_batches += 1
# Compute the Spearman rank correlation coefficient (Rank IC)
rank_ic, _ = stats.spearmanr(preds, by)
return total_loss / max(n_batches, 1)
# If rank_ic cannot be computed, record it as 0
if np.isnan(rank_ic):
rank_ic = 0.0
ic_list.append(rank_ic)
# Return the average Rank IC over the validation set
return np.mean(ic_list) if len(ic_list) > 0 else 0.0
def fit(
self, X: pl.DataFrame, y: pl.Series, eval_set: Optional[tuple] = None
self,
X: pl.DataFrame,
y: pl.Series,
eval_set: Optional[tuple] = None,
date_col: Optional[str] = None,
) -> "TabMModel":
"""训练TabM模型
"""训练TabM模型(量化优化版本)
训练策略:
1. 对所有集成成员独立计算Loss保持多样性
2. 验证和预测时取ensemble成员均值
1. 支持截面数据采样(按交易日批处理)
2. 使用 EnsembleQuantLoss (Huber + IC) 替代 MSE
3. CosineAnnealingLR 学习率调度
4. 梯度裁剪防止异常数据影响
5. 以 Rank IC 作为验证指标和早停依据
Args:
X: 训练特征DataFrame
X: 训练特征DataFrame,可包含日期列
y: 训练标签Series
eval_set: 验证集元组 (X_val, y_val),可选
date_col: 日期列名称用于截面采样。如果为None使用普通批次
Returns:
self
"""
# 保存特征名称
self.feature_names_ = X.columns
# 【修复】初始化 val_ic 训练历史
if "val_ic" not in self.training_history_:
self.training_history_["val_ic"] = []
# 保存特征名称并处理日期列
self.feature_names_ = list(X.columns)
dates = None
if date_col and date_col in X.columns:
dates = X[date_col].to_numpy()
X = X.drop(date_col)
self.feature_names_ = [c for c in self.feature_names_ if c != date_col]
self.params["date_col"] = date_col # 保存供 predict 使用
# Key: force the dtype to float32
# PyTorch handles float64 poorly; avoid the Polars/NumPy default dtypes
X_np = X.to_numpy().astype(np.float32)
y_np = y.to_numpy().astype(np.float32)
# Create the DataLoader
train_loader = self._make_loader(X_np, y_np, shuffle=True)
# Create the DataLoader (with cross-sectional sampling support)
train_loader = self._make_loader(X_np, y_np, dates=dates, shuffle_days=True)
val_loader = None
val_dates = None
if eval_set is not None:
X_val, y_val = eval_set
# Handle the date column of the validation set
if date_col and date_col in X_val.columns:
val_dates = X_val[date_col].to_numpy()
X_val = X_val.drop(date_col)
X_val_np = X_val.to_numpy().astype(np.float32)
y_val_np = y_val.to_numpy().astype(np.float32)
val_loader = self._make_loader(X_val_np, y_val_np, shuffle=False)
val_loader = self._make_loader(
X_val_np, y_val_np, dates=val_dates, shuffle_days=False
)
n_features = X_np.shape[1]
ensemble_size = self.params.get("ensemble_size", 32)
# Initialize the TabM model; TabM.make() fills in default parameters
# Initialize the TabM model
self.model = TabM.make(
n_num_features=n_features,
cat_cardinalities=[],
d_out=1, # output dimension is 1 for regression
d_out=1,
n_blocks=self.params.get("n_blocks", 3),
d_block=self.params.get("d_block", 256),
dropout=self.params.get("dropout", 0.1),
k=ensemble_size, # ensemble size
k=ensemble_size,
).to(self.device)
# Optimizer
# Optimizer (with weight_decay added for regularization)
optimizer = optim.AdamW(
self.model.parameters(),
lr=self.params.get("learning_rate", 1e-3),
weight_decay=self.params.get("weight_decay", 1e-5),
weight_decay=self.params.get("weight_decay", 1e-4),
)
# Learning-rate scheduler: CosineAnnealingLR for smooth decay
epochs = self.params.get("epochs", 50)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=epochs, eta_min=1e-6
)
# Use EnsembleQuantLoss instead of MSE
self.criterion = EnsembleQuantLoss(
alpha=self.params.get("loss_alpha", 0.5), ensemble_size=ensemble_size
)
# Training parameters
epochs = self.params.get("epochs", 50)
early_stopping_patience = self.params.get("early_stopping_patience", 10)
# Training loop
best_val_loss = float("inf")
best_val_ic = -float("inf") # larger Rank IC is better
patience_counter = 0
best_model_state = None
print(f"[TabM] 开始训练... 设备: {self.device}, 集成大小: {ensemble_size}")
if date_col:
print(f"[TabM] 使用截面采样,日期列: {date_col}")
for epoch in range(epochs):
# 训练阶段
@@ -193,22 +258,30 @@ class TabMModel(BaseModel):
train_loss = 0.0
n_train_batches = 0
for bx, by in train_loader:
for batch in train_loader:
if len(batch) == 2:
bx, by = batch
bx, by = bx.to(self.device), by.to(self.device)
else:
continue
optimizer.zero_grad()
# Forward pass
# outputs shape: [Batch, Ensemble, 1]
outputs = self.model(bx)
outputs = self.model(bx) # [B, E, 1]
outputs_squeezed = outputs.squeeze(-1) # [B, E]
# Key: compute the loss over all ensemble members
# Do not average first; let every member converge independently to keep the ensemble diverse
by_expanded = by.unsqueeze(-1).expand(-1, ensemble_size) # [B, E]
loss = self.criterion(outputs_squeezed, by_expanded)
# Use EnsembleQuantLoss
loss = self.criterion(outputs_squeezed, by)
loss.backward()
# Key: gradient clipping, so one day's abnormal returns cannot blow up the model
torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
max_norm=self.params.get("max_grad_norm", 1.0),
)
optimizer.step()
train_loss += loss.item()
@@ -217,15 +290,19 @@ class TabMModel(BaseModel):
avg_train_loss = train_loss / max(n_train_batches, 1)
self.training_history_["train_loss"].append(avg_train_loss)
# Validation phase
# Validation phase (using Rank IC as the metric)
if val_loader is not None:
val_loss = self._validate(val_loader)
self.training_history_["val_loss"].append(val_loss)
val_ic = self._validate(val_loader)
self.training_history_["val_ic"].append(val_ic)
# Early-stopping logic
if val_loss < best_val_loss:
best_val_loss = val_loss
# Early-stopping logic: larger Rank IC is better
if val_ic > best_val_ic:
best_val_ic = val_ic
patience_counter = 0
# Save the best model weights
best_model_state = {
k: v.cpu().clone() for k, v in self.model.state_dict().items()
}
else:
patience_counter += 1
@@ -233,7 +310,7 @@ class TabMModel(BaseModel):
print(
f"[TabM] Epoch {epoch + 1}/{epochs} | "
f"Train Loss: {avg_train_loss:.6f} | "
f"Val Loss: {val_loss:.6f}"
f"Val Rank IC: {val_ic:.4f} (Best: {best_val_ic:.4f})"
)
if patience_counter >= early_stopping_patience:
@@ -246,11 +323,19 @@ class TabMModel(BaseModel):
f"Train Loss: {avg_train_loss:.6f}"
)
# Update the learning rate
scheduler.step()
# Restore the best model
if best_model_state is not None:
self.model.load_state_dict(best_model_state)
print(f"[TabM] Restored the best model (Val Rank IC: {best_val_ic:.4f})")
print(f"[TabM] Training complete")
return self
def predict(self, X: pl.DataFrame) -> np.ndarray:
"""生成预测
"""生成预测(自动处理 date_col 和特征对齐)
预测时对ensemble_size个成员取均值获得稳定结果。
@@ -263,9 +348,22 @@ class TabMModel(BaseModel):
if self.model is None:
raise RuntimeError("Model not trained; call fit() first")
# Automatically drop date_col
date_col = self.params.get("date_col")
if date_col and date_col in X.columns:
X = X.drop(date_col)
# Key fix for feature alignment: check for missing columns and raise an explicit error
if self.feature_names_:
missing_cols = [c for c in self.feature_names_ if c not in X.columns]
if missing_cols:
raise ValueError(f"Prediction data is missing training features: {missing_cols}")
X = X.select(self.feature_names_)
# Convert the dtype
X_np = X.to_numpy().astype(np.float32)
loader = self._make_loader(X_np, shuffle=False)
# Fix: use the correct parameter name shuffle_days (not shuffle)
loader = self._make_loader(X_np, shuffle_days=False)
self.model.eval()
all_preds = []


@@ -96,8 +96,11 @@ class TabMRegressionTask(BaseTask):
X_val = val_data["X"]
y_val = val_data["y"]
# Support the date_col parameter
date_col = self.model_params.get("date_col", None)
# Train the model
self.model.fit(X=X_train, y=y_train, eval_set=(X_val, y_val))
self.model.fit(X=X_train, y=y_train, eval_set=(X_val, y_val), date_col=date_col)
print("[TabMRegressionTask] Training complete")


@@ -272,7 +272,7 @@ class TestTabMIntegration:
# 4. Verify the training history
model = task.get_model()
assert len(model.training_history_["train_loss"]) > 0
assert len(model.training_history_["val_loss"]) > 0
assert len(model.training_history_["val_ic"]) > 0
# 5. Predict
predictions = task.predict({"X": X_test})


@@ -0,0 +1,79 @@
"""截面数据采样器单元测试"""
import numpy as np
import pytest
import torch
from torch.utils.data import TensorDataset
from src.training.components.models.cross_section_sampler import CrossSectionSampler
class TestCrossSectionSampler:
"""截面采样器单元测试"""
def test_basic_functionality(self):
"""测试基本功能:按日期分组"""
dates = np.array(["20240101", "20240101", "20240102", "20240102", "20240103"])
sampler = CrossSectionSampler(dates, shuffle_days=False)
# 应该有3个日期批次
assert len(sampler) == 3
# 获取所有批次
batches = list(sampler)
# 验证每个批次包含同一天的数据
for batch in batches:
batch_dates = [dates[i] for i in batch]
assert len(set(batch_dates)) == 1, "批次内日期不一致"
def test_shuffle_days(self):
"""测试日期打乱功能"""
np.random.seed(42)
dates = np.array(["20240101"] * 5 + ["20240102"] * 5 + ["20240103"] * 5)
# 多次采样,验证日期顺序会变化
orders = []
for _ in range(10):
batches = list(CrossSectionSampler(dates, shuffle_days=True))
date_order = [dates[batch[0]] for batch in batches]
orders.append(tuple(date_order))
# 应该有不同的顺序出现
assert len(set(orders)) > 1, "日期顺序未被打乱"
def test_internal_shuffle(self):
"""测试截面内股票顺序打乱"""
np.random.seed(42)
dates = np.array(["20240101"] * 10)
# 多次获取同一批次
indices_list = []
for _ in range(5):
sampler = CrossSectionSampler(dates, shuffle_days=False)
batch = next(iter(sampler))
indices_list.append(list(batch))
# 应该有不同顺序
assert len(set(tuple(x) for x in indices_list)) > 1, "截面内顺序未被打乱"
def test_with_dataloader(self):
"""测试与 DataLoader 集成"""
dates = np.array(["20240101", "20240101", "20240102", "20240102"])
X = torch.randn(4, 5)
y = torch.randn(4)
dataset = TensorDataset(X, y)
sampler = CrossSectionSampler(dates, shuffle_days=False)
loader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler)
batches = list(loader)
assert len(batches) == 2 # 2个日期
for bx, by in batches:
assert bx.shape[0] == 2 # 每个日期2个样本
assert by.shape[0] == 2
if __name__ == "__main__":
pytest.main([__file__, "-v"])


@@ -0,0 +1,98 @@
"""EnsembleQuantLoss 单元测试"""
import numpy as np
import pytest
import torch
import torch.nn as nn
from src.training.components.models.ensemble_quant_loss import EnsembleQuantLoss
class TestEnsembleQuantLoss:
"""EnsembleQuantLoss 单元测试"""
def test_initialization(self):
"""测试初始化"""
loss_fn = EnsembleQuantLoss(alpha=0.7, ensemble_size=16)
assert loss_fn.alpha == 0.7
assert loss_fn.ensemble_size == 16
assert isinstance(loss_fn.huber, nn.HuberLoss)
def test_output_shape(self):
"""测试输出形状和类型"""
loss_fn = EnsembleQuantLoss(alpha=0.5, ensemble_size=4)
# 创建模拟数据: 20只股票, 4个集成成员
preds = torch.randn(20, 4)
target = torch.randn(20)
loss = loss_fn(preds, target)
# 验证输出是标量
assert loss.shape == torch.Size([])
assert isinstance(loss.item(), float)
def test_small_batch_fallback(self):
"""测试小批次回退到 Huber"""
loss_fn = EnsembleQuantLoss(alpha=0.5, ensemble_size=4)
# 少于10只股票的批次
preds = torch.randn(5, 4)
target = torch.randn(5)
loss = loss_fn(preds, target)
# 应该正常返回loss
assert not torch.isnan(loss)
assert loss.item() > 0
def test_huber_component(self):
"""测试 Huber 损失组件"""
loss_fn = EnsembleQuantLoss(alpha=1.0, ensemble_size=4) # 纯 Huber
preds = torch.randn(50, 4)
target = torch.randn(50)
loss = loss_fn(preds, target)
# 手动计算期望的 Huber 损失
huber = nn.HuberLoss(reduction="mean")
expected_loss = 0
for i in range(4):
expected_loss += huber(preds[:, i], target)
expected_loss /= 4
assert torch.allclose(loss, expected_loss, rtol=1e-5)
def test_ic_component(self):
"""测试 IC 损失组件"""
loss_fn = EnsembleQuantLoss(alpha=0.0, ensemble_size=1) # 纯 IC
# 创建完全相关的预测和目标
target = torch.randn(50)
preds = target.unsqueeze(1) # 完美相关
loss = loss_fn(preds, target)
# 完美相关时 IC=1所以 IC loss = 0
# 但由于 std 计算和数值精度可能不完全为0
assert loss.item() < 0.1
def test_gradient_flow(self):
"""测试梯度可以正常回传"""
loss_fn = EnsembleQuantLoss(alpha=0.5, ensemble_size=4)
preds = torch.randn(50, 4, requires_grad=True)
target = torch.randn(50)
loss = loss_fn(preds, target)
loss.backward()
# 验证梯度存在且非零
assert preds.grad is not None
assert not torch.all(preds.grad == 0)
if __name__ == "__main__":
pytest.main([__file__, "-v"])