From c143815443f3e775a1a3f939e7b76913ff977761 Mon Sep 17 00:00:00 2001 From: liaozhaorun <1300336796@qq.com> Date: Wed, 1 Apr 2026 00:20:05 +0800 Subject: [PATCH] =?UTF-8?q?feat(training):=20TabM=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E9=87=8F=E5=8C=96=E4=BA=A4=E6=98=93=E4=BC=98=E5=8C=96=20-=20?= =?UTF-8?q?=E6=96=B0=E5=A2=9E=20CrossSectionSampler=20=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E6=88=AA=E9=9D=A2=E6=95=B0=E6=8D=AE=E9=87=87=E6=A0=B7=EF=BC=88?= =?UTF-8?q?=E6=8C=89=E4=BA=A4=E6=98=93=E6=97=A5=E6=89=B9=E5=A4=84=E7=90=86?= =?UTF-8?q?=EF=BC=89=20-=20=E6=96=B0=E5=A2=9E=20EnsembleQuantLoss=20(Huber?= =?UTF-8?q?=20+=20IC)=20=E6=9B=BF=E4=BB=A3=20MSE=20=E4=BD=9C=E4=B8=BA?= =?UTF-8?q?=E6=8D=9F=E5=A4=B1=E5=87=BD=E6=95=B0=20-=20=E9=87=8D=E6=9E=84?= =?UTF-8?q?=20TabMModel=20=E6=94=AF=E6=8C=81=E9=87=8F=E5=8C=96=E5=9C=BA?= =?UTF-8?q?=E6=99=AF=EF=BC=9ARank=20IC=20=E4=BD=9C=E4=B8=BA=E9=AA=8C?= =?UTF-8?q?=E8=AF=81=E6=8C=87=E6=A0=87=E3=80=81CosineAnnealingLR=E5=AD=A6?= =?UTF-8?q?=E4=B9=A0=E7=8E=87=E8=B0=83=E5=BA=A6=E3=80=81=E6=A2=AF=E5=BA=A6?= =?UTF-8?q?=E8=A3=81=E5=89=AA=20-=20=E6=94=AF=E6=8C=81=20date=5Fcol=20?= =?UTF-8?q?=E5=8F=82=E6=95=B0=E5=92=8C=E7=89=B9=E5=BE=81=E5=AF=B9=E9=BD=90?= =?UTF-8?q?=20-=20=E6=9B=B4=E6=96=B0=E5=AE=9E=E9=AA=8C=E9=85=8D=E7=BD=AE?= =?UTF-8?q?=20batch=5Fsize=202048=20=E5=92=8C=20weight=5Fdecay=20=E7=AD=89?= =?UTF-8?q?=E8=B6=85=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/experiment/tabm_regression.py | 2 +- src/training/components/models/__init__.py | 11 +- .../models/cross_section_sampler.py | 59 +++++ .../components/models/ensemble_quant_loss.py | 86 +++++++ src/training/components/models/tabm_model.py | 210 +++++++++++++----- src/training/tasks/tabm_regression_task.py | 5 +- tests/test_tabm_integration.py | 2 +- tests/training/test_cross_section_sampler.py | 79 +++++++ tests/training/test_ensemble_quant_loss.py | 98 ++++++++ 9 files changed, 492 insertions(+), 60 deletions(-) create mode 100644 src/training/components/models/cross_section_sampler.py create mode 100644 src/training/components/models/ensemble_quant_loss.py create mode 100644 tests/training/test_cross_section_sampler.py create mode 100644 tests/training/test_ensemble_quant_loss.py diff --git a/src/experiment/tabm_regression.py b/src/experiment/tabm_regression.py index 16f04a6..a998771 100644 --- a/src/experiment/tabm_regression.py +++ b/src/experiment/tabm_regression.py @@ -260,7 +260,7 @@ MODEL_PARAMS = { # ==================== 集成机制 ==================== "ensemble_size": 32, # 内置集成大小(模拟 32 个模型集成) # ==================== 训练参数 ==================== - "batch_size": 1024, # 批次大小 + "batch_size": 2048, # 批次大小 "learning_rate": 1e-3, # 学习率 "weight_decay": 1e-5, # 权重衰减 "epochs": 100, # 训练轮数 diff --git a/src/training/components/models/__init__.py b/src/training/components/models/__init__.py index 648f4ec..907e040 100644 --- a/src/training/components/models/__init__.py +++ b/src/training/components/models/__init__.py @@ -7,5 +7,14 @@ from src.training.components.models.lightgbm import LightGBMModel from src.training.components.models.lightgbm_lambdarank import LightGBMLambdaRankModel from src.training.components.models.tabpfn_model import TabPFNModel from src.training.components.models.tabm_model import TabMModel +from src.training.components.models.cross_section_sampler import CrossSectionSampler +from src.training.components.models.ensemble_quant_loss import EnsembleQuantLoss -__all__ = ["LightGBMModel", "LightGBMLambdaRankModel", "TabPFNModel", "TabMModel"] +__all__ = [ + "LightGBMModel", + "LightGBMLambdaRankModel", + "TabPFNModel", + "TabMModel", + "CrossSectionSampler", + "EnsembleQuantLoss", +] diff --git a/src/training/components/models/cross_section_sampler.py b/src/training/components/models/cross_section_sampler.py new file mode 100644 index 0000000..3d9e28a --- /dev/null +++ b/src/training/components/models/cross_section_sampler.py @@ -0,0 +1,59 @@ +"""截面数据采样器 + +用于量化交易场景,确保每个批次包含同一天的所有股票数据。 +实现横截面批处理,避免全局随机打乱。 +""" + +import numpy as np +from torch.utils.data import Sampler + + +class CrossSectionSampler(Sampler): + """截面数据采样器 + + 保证每次产出的 indices 属于同一个交易日。 + 适用于量化选股场景:模型每轮前向传播面对的都是当天的全市场股票。 + + Attributes: + date_to_indices: 日期到索引列表的映射 + unique_dates: 唯一日期列表 + shuffle_days: 是否打乱日期顺序 + """ + + def __init__(self, dates: np.ndarray, shuffle_days: bool = True): + """初始化采样器 + + Args: + dates: 日期数组,每个元素对应一行数据 + shuffle_days: 是否打乱日期的训练顺序,但同一天的数据始终在一起 + """ + # 记录每个日期对应的所有行索引 + self.date_to_indices = {} + for idx, date in enumerate(dates): + date_str = str(date) + if date_str not in self.date_to_indices: + self.date_to_indices[date_str] = [] + self.date_to_indices[date_str].append(idx) + + self.unique_dates = list(self.date_to_indices.keys()) + self.shuffle_days = shuffle_days + + def __iter__(self): + """迭代生成批次索引 + + Yields: + list: 同一日期的所有样本索引 + """ + dates = self.unique_dates.copy() + if self.shuffle_days: + np.random.shuffle(dates) # 打乱日期的训练顺序 + + for date in dates: + indices = self.date_to_indices[date].copy() + # 在截面内打乱股票顺序,防止顺序带来的隐性 bias + np.random.shuffle(indices) + yield indices + + def __len__(self): + """返回批次数量(等于日期数量)""" + return len(self.unique_dates) diff --git a/src/training/components/models/ensemble_quant_loss.py b/src/training/components/models/ensemble_quant_loss.py new file mode 100644 index 0000000..a7fe564 --- /dev/null +++ b/src/training/components/models/ensemble_quant_loss.py @@ -0,0 +1,86 @@ +"""量化专用组合损失函数 + +结合 Huber Loss(处理极值)和 IC Loss(优化排序), +针对 TabM 的集成输出设计,为每个集成成员独立计算损失以保持多样性。 +""" + +import torch +import torch.nn as nn + + +class EnsembleQuantLoss(nn.Module): + """Ensemble 量化组合损失函数 + + 组合 Huber Loss 和 IC Loss: + - Huber Loss: 对收益率极值鲁棒,避免梯度爆炸 + - IC Loss: 直接优化预测与目标的秩相关性 + + 针对 TabM 的集成输出设计,为每个集成成员独立计算损失, + 保持集成多样性。 + + 注意:std() 使用 unbiased=False 以保持数学一致性 + (IC 计算使用除以 N,而非 N-1) + + Attributes: + alpha: Huber Loss 权重,(1-alpha) 为 IC Loss 权重 + huber: HuberLoss 实例 + ensemble_size: 集成成员数量 + """ + + def __init__(self, alpha: float = 0.5, ensemble_size: int = 32): + """初始化损失函数 + + Args: + alpha: Huber Loss 权重,范围 [0, 1] + ensemble_size: TabM 集成大小 + """ + super().__init__() + self.alpha = alpha + self.huber = nn.HuberLoss(reduction="mean") + self.ensemble_size = ensemble_size + + def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """计算组合损失 + + Args: + preds: 预测值,形状 [Batch(单日股票数), Ensemble_size] + target: 目标值,形状 [Batch(单日股票数)] + + Returns: + 标量损失值 + """ + total_loss = 0.0 + + # 为了稳定,如果单日截面股票数太少,直接回退到 Huber + if preds.shape[0] < 10: + for i in range(self.ensemble_size): + total_loss += self.huber(preds[:, i], target) + return total_loss / self.ensemble_size + + # 预计算目标的标准化值(在所有集成成员间复用) + target_mean = target.mean() + # 【关键】使用 unbiased=False 保持数学一致性 + target_std = target.std(unbiased=False) + 1e-8 + target_norm = (target - target_mean) / target_std + + # 对每个 ensemble 成员独立计算组合 Loss + for i in range(self.ensemble_size): + pred_i = preds[:, i] + + # 1. Huber Loss (处理极值比 MSE 更好) + h_loss = self.huber(pred_i, target) + + # 2. IC Loss (皮尔逊相关系数) + pred_mean = pred_i.mean() + # 【关键】使用 unbiased=False 保持数学一致性 + pred_std = pred_i.std(unbiased=False) + 1e-8 + pred_norm = (pred_i - pred_mean) / pred_std + + # 相关系数 + ic = (pred_norm * target_norm).mean() + # 损失函数化:希望 IC 越大越好,所以用 1 - IC + ic_loss = 1.0 - ic + + total_loss += self.alpha * h_loss + (1.0 - self.alpha) * ic_loss + + return total_loss / self.ensemble_size diff --git a/src/training/components/models/tabm_model.py b/src/training/components/models/tabm_model.py index d0ee247..f8063a5 100644 --- a/src/training/components/models/tabm_model.py +++ b/src/training/components/models/tabm_model.py @@ -10,6 +10,7 @@ import pickle import numpy as np import polars as pl +import scipy.stats as stats import torch import torch.nn as nn import torch.optim as optim @@ -17,6 +18,8 @@ from torch.utils.data import DataLoader, TensorDataset from tabm import TabM from src.training.components.base import BaseModel +from src.training.components.models.cross_section_sampler import CrossSectionSampler +from src.training.components.models.ensemble_quant_loss import EnsembleQuantLoss from src.training.registry import register_model @@ -69,14 +72,23 @@ class TabMModel(BaseModel): self.criterion = nn.MSELoss() def _make_loader( - self, X: np.ndarray, y: Optional[np.ndarray] = None, shuffle: bool = False + self, + X: np.ndarray, + y: Optional[np.ndarray] = None, + dates: Optional[np.ndarray] = None, + shuffle_days: bool = False, ) -> DataLoader: """创建DataLoader + 支持两种模式: + 1. 截面模式 (dates provided): 每天的数据作为一个独立批次 + 2. 普通模式 (dates is None): 使用固定batch_size + Args: X: 特征数组 [N, n_features] y: 标签数组 [N] 或 None - shuffle: 是否打乱数据 + dates: 日期数组 [N],用于截面采样 + shuffle_days: 是否打乱日期顺序 Returns: DataLoader实例 @@ -86,106 +98,159 @@ class TabMModel(BaseModel): else: dataset = TensorDataset(torch.from_numpy(X)) - batch_size = self.params.get("batch_size", 1024) - return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle) + if dates is not None: + # 使用截面 Sampler + sampler = CrossSectionSampler(dates, shuffle_days=shuffle_days) + # 注意:使用 batch_sampler 时,不需要指定 batch_size 和 shuffle + return DataLoader(dataset, batch_sampler=sampler) + else: + # 预测时如果没有 date,退化为普通 Loader + batch_size = self.params.get("batch_size", 1024) + return DataLoader(dataset, batch_size=batch_size, shuffle=False) def _validate(self, val_loader: DataLoader) -> float: - """验证模型 + """验证模型(使用 Rank IC 作为指标) Args: val_loader: 验证数据加载器 Returns: - 平均验证损失 + 平均 Rank IC(斯皮尔曼秩相关系数) """ self.model.eval() - total_loss = 0.0 - n_batches = 0 + ic_list = [] with torch.no_grad(): for batch in val_loader: - if len(batch) == 2: - bx, by = batch - bx, by = bx.to(self.device), by.to(self.device) - else: - bx = batch[0].to(self.device) - by = None + if len(batch) != 2: + continue + + bx, by = batch + bx = bx.to(self.device) + by = by.cpu().numpy() + + # 如果截面股票数太少,跳过 + if len(by) < 10: + continue - # 预测时取集成成员均值 outputs = self.model(bx) # [B, E, 1] - preds = outputs.mean(dim=1).squeeze(-1) # [B] + # 预测时取均值 + preds = outputs.mean(dim=1).squeeze(-1).cpu().numpy() - if by is not None: - loss = self.criterion(preds, by).item() - total_loss += loss - n_batches += 1 + # 计算斯皮尔曼秩相关系数 (Rank IC) + rank_ic, _ = stats.spearmanr(preds, by) - return total_loss / max(n_batches, 1) + # 如果算不出 rank_ic,记为 0 + if np.isnan(rank_ic): + rank_ic = 0.0 + + ic_list.append(rank_ic) + + # 返回验证集的平均 Rank IC + return np.mean(ic_list) if len(ic_list) > 0 else 0.0 def fit( - self, X: pl.DataFrame, y: pl.Series, eval_set: Optional[tuple] = None + self, + X: pl.DataFrame, + y: pl.Series, + eval_set: Optional[tuple] = None, + date_col: Optional[str] = None, ) -> "TabMModel": - """训练TabM模型 + """训练TabM模型(量化优化版本) 训练策略: - 1. 对所有集成成员独立计算Loss,保持多样性 - 2. 验证和预测时取ensemble成员均值 + 1. 支持截面数据采样(按交易日批处理) + 2. 使用 EnsembleQuantLoss (Huber + IC) 替代 MSE + 3. CosineAnnealingLR 学习率调度 + 4. 梯度裁剪防止异常数据影响 + 5. 以 Rank IC 作为验证指标和早停依据 Args: - X: 训练特征DataFrame + X: 训练特征DataFrame,可包含日期列 y: 训练标签Series eval_set: 验证集元组 (X_val, y_val),可选 + date_col: 日期列名称,用于截面采样。如果为None,使用普通批次 Returns: self """ - # 保存特征名称 - self.feature_names_ = X.columns + # 【修复】初始化 val_ic 训练历史 + if "val_ic" not in self.training_history_: + self.training_history_["val_ic"] = [] + + # 保存特征名称并处理日期列 + self.feature_names_ = list(X.columns) + dates = None + if date_col and date_col in X.columns: + dates = X[date_col].to_numpy() + X = X.drop(date_col) + self.feature_names_ = [c for c in self.feature_names_ if c != date_col] + self.params["date_col"] = date_col # 保存供 predict 使用 # 【关键】数据类型强制转换为float32 - # PyTorch对float64支持较差,避免使用Polars/Numpy默认类型 X_np = X.to_numpy().astype(np.float32) y_np = y.to_numpy().astype(np.float32) - # 创建DataLoader - train_loader = self._make_loader(X_np, y_np, shuffle=True) + # 创建DataLoader(支持截面采样) + train_loader = self._make_loader(X_np, y_np, dates=dates, shuffle_days=True) + val_loader = None + val_dates = None if eval_set is not None: X_val, y_val = eval_set + # 处理验证集的日期列 + if date_col and date_col in X_val.columns: + val_dates = X_val[date_col].to_numpy() + X_val = X_val.drop(date_col) X_val_np = X_val.to_numpy().astype(np.float32) y_val_np = y_val.to_numpy().astype(np.float32) - val_loader = self._make_loader(X_val_np, y_val_np, shuffle=False) + val_loader = self._make_loader( + X_val_np, y_val_np, dates=val_dates, shuffle_days=False + ) n_features = X_np.shape[1] ensemble_size = self.params.get("ensemble_size", 32) - # 初始化TabM模型,使用TabM.make()自动填充默认参数 + # 初始化TabM模型 self.model = TabM.make( n_num_features=n_features, cat_cardinalities=[], - d_out=1, # 回归任务输出维度为1 + d_out=1, n_blocks=self.params.get("n_blocks", 3), d_block=self.params.get("d_block", 256), dropout=self.params.get("dropout", 0.1), - k=ensemble_size, # 集成大小 + k=ensemble_size, ).to(self.device) - # 优化器 + # 优化器(增加 weight_decay 用于正则化) optimizer = optim.AdamW( self.model.parameters(), lr=self.params.get("learning_rate", 1e-3), - weight_decay=self.params.get("weight_decay", 1e-5), + weight_decay=self.params.get("weight_decay", 1e-4), + ) + + # 学习率调度器:CosineAnnealingLR 平滑衰减 + epochs = self.params.get("epochs", 50) + scheduler = optim.lr_scheduler.CosineAnnealingLR( + optimizer, T_max=epochs, eta_min=1e-6 + ) + + # 使用 EnsembleQuantLoss 替代 MSE + self.criterion = EnsembleQuantLoss( + alpha=self.params.get("loss_alpha", 0.5), ensemble_size=ensemble_size ) # 训练参数 - epochs = self.params.get("epochs", 50) early_stopping_patience = self.params.get("early_stopping_patience", 10) # 训练循环 - best_val_loss = float("inf") + best_val_ic = -float("inf") # Rank IC 越大越好 patience_counter = 0 + best_model_state = None print(f"[TabM] 开始训练... 设备: {self.device}, 集成大小: {ensemble_size}") + if date_col: + print(f"[TabM] 使用截面采样,日期列: {date_col}") for epoch in range(epochs): # 训练阶段 @@ -193,22 +258,30 @@ class TabMModel(BaseModel): train_loss = 0.0 n_train_batches = 0 - for bx, by in train_loader: - bx, by = bx.to(self.device), by.to(self.device) + for batch in train_loader: + if len(batch) == 2: + bx, by = batch + bx, by = bx.to(self.device), by.to(self.device) + else: + continue optimizer.zero_grad() # 前向传播 - # outputs形状: [Batch, Ensemble, 1] - outputs = self.model(bx) + outputs = self.model(bx) # [B, E, 1] outputs_squeezed = outputs.squeeze(-1) # [B, E] - # 【关键】针对所有集成成员计算Loss - # 不先取均值,让每个集成成员独立收敛,保持集成多样性 - by_expanded = by.unsqueeze(-1).expand(-1, ensemble_size) # [B, E] - loss = self.criterion(outputs_squeezed, by_expanded) + # 使用 EnsembleQuantLoss + loss = self.criterion(outputs_squeezed, by) loss.backward() + + # 【关键】梯度裁剪,防止某天的异常收益率拉崩模型 + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), + max_norm=self.params.get("max_grad_norm", 1.0), + ) + optimizer.step() train_loss += loss.item() @@ -217,15 +290,19 @@ class TabMModel(BaseModel): avg_train_loss = train_loss / max(n_train_batches, 1) self.training_history_["train_loss"].append(avg_train_loss) - # 验证阶段 + # 验证阶段(使用 Rank IC 作为指标) if val_loader is not None: - val_loss = self._validate(val_loader) - self.training_history_["val_loss"].append(val_loss) + val_ic = self._validate(val_loader) + self.training_history_["val_ic"].append(val_ic) - # 早停逻辑 - if val_loss < best_val_loss: - best_val_loss = val_loss + # 早停逻辑:Rank IC 越大越好 + if val_ic > best_val_ic: + best_val_ic = val_ic patience_counter = 0 + # 保存最佳模型权重 + best_model_state = { + k: v.cpu().clone() for k, v in self.model.state_dict().items() + } else: patience_counter += 1 @@ -233,7 +310,7 @@ class TabMModel(BaseModel): print( f"[TabM] Epoch {epoch + 1}/{epochs} | " f"Train Loss: {avg_train_loss:.6f} | " - f"Val Loss: {val_loss:.6f}" + f"Val Rank IC: {val_ic:.4f} (Best: {best_val_ic:.4f})" ) if patience_counter >= early_stopping_patience: @@ -246,11 +323,19 @@ class TabMModel(BaseModel): f"Train Loss: {avg_train_loss:.6f}" ) + # 更新学习率 + scheduler.step() + + # 恢复最佳模型 + if best_model_state is not None: + self.model.load_state_dict(best_model_state) + print(f"[TabM] 恢复最佳模型 (Val Rank IC: {best_val_ic:.4f})") + print(f"[TabM] 训练完成") return self def predict(self, X: pl.DataFrame) -> np.ndarray: - """生成预测 + """生成预测(自动处理 date_col 和特征对齐) 预测时对ensemble_size个成员取均值,获得稳定结果。 @@ -263,9 +348,22 @@ class TabMModel(BaseModel): if self.model is None: raise RuntimeError("模型未训练,请先调用fit()") + # 自动移除 date_col + date_col = self.params.get("date_col") + if date_col and date_col in X.columns: + X = X.drop(date_col) + + # 【关键修复】特征对齐:检查缺失列并显式抛出异常 + if self.feature_names_: + missing_cols = [c for c in self.feature_names_ if c not in X.columns] + if missing_cols: + raise ValueError(f"预测数据缺失训练特征: {missing_cols}") + X = X.select(self.feature_names_) + # 数据类型转换 X_np = X.to_numpy().astype(np.float32) - loader = self._make_loader(X_np, shuffle=False) + # 【修复】使用正确的参数名 shuffle_days(而非 shuffle) + loader = self._make_loader(X_np, shuffle_days=False) self.model.eval() all_preds = [] diff --git a/src/training/tasks/tabm_regression_task.py b/src/training/tasks/tabm_regression_task.py index 63f42dc..8280d19 100644 --- a/src/training/tasks/tabm_regression_task.py +++ b/src/training/tasks/tabm_regression_task.py @@ -96,8 +96,11 @@ class TabMRegressionTask(BaseTask): X_val = val_data["X"] y_val = val_data["y"] + # 支持 date_col 参数 + date_col = self.model_params.get("date_col", None) + # 训练模型 - self.model.fit(X=X_train, y=y_train, eval_set=(X_val, y_val)) + self.model.fit(X=X_train, y=y_train, eval_set=(X_val, y_val), date_col=date_col) print("[TabMRegressionTask] 训练完成") diff --git a/tests/test_tabm_integration.py b/tests/test_tabm_integration.py index 294904a..0566680 100644 --- a/tests/test_tabm_integration.py +++ b/tests/test_tabm_integration.py @@ -272,7 +272,7 @@ class TestTabMIntegration: # 4. 验证训练历史 model = task.get_model() assert len(model.training_history_["train_loss"]) > 0 - assert len(model.training_history_["val_loss"]) > 0 + assert len(model.training_history_["val_ic"]) > 0 # 5. 预测 predictions = task.predict({"X": X_test}) diff --git a/tests/training/test_cross_section_sampler.py b/tests/training/test_cross_section_sampler.py new file mode 100644 index 0000000..a22b7c8 --- /dev/null +++ b/tests/training/test_cross_section_sampler.py @@ -0,0 +1,79 @@ +"""截面数据采样器单元测试""" + +import numpy as np +import pytest +import torch +from torch.utils.data import TensorDataset + +from src.training.components.models.cross_section_sampler import CrossSectionSampler + + +class TestCrossSectionSampler: + """截面采样器单元测试""" + + def test_basic_functionality(self): + """测试基本功能:按日期分组""" + dates = np.array(["20240101", "20240101", "20240102", "20240102", "20240103"]) + sampler = CrossSectionSampler(dates, shuffle_days=False) + + # 应该有3个日期批次 + assert len(sampler) == 3 + + # 获取所有批次 + batches = list(sampler) + + # 验证每个批次包含同一天的数据 + for batch in batches: + batch_dates = [dates[i] for i in batch] + assert len(set(batch_dates)) == 1, "批次内日期不一致" + + def test_shuffle_days(self): + """测试日期打乱功能""" + np.random.seed(42) + dates = np.array(["20240101"] * 5 + ["20240102"] * 5 + ["20240103"] * 5) + + # 多次采样,验证日期顺序会变化 + orders = [] + for _ in range(10): + batches = list(CrossSectionSampler(dates, shuffle_days=True)) + date_order = [dates[batch[0]] for batch in batches] + orders.append(tuple(date_order)) + + # 应该有不同的顺序出现 + assert len(set(orders)) > 1, "日期顺序未被打乱" + + def test_internal_shuffle(self): + """测试截面内股票顺序打乱""" + np.random.seed(42) + dates = np.array(["20240101"] * 10) + + # 多次获取同一批次 + indices_list = [] + for _ in range(5): + sampler = CrossSectionSampler(dates, shuffle_days=False) + batch = next(iter(sampler)) + indices_list.append(list(batch)) + + # 应该有不同顺序 + assert len(set(tuple(x) for x in indices_list)) > 1, "截面内顺序未被打乱" + + def test_with_dataloader(self): + """测试与 DataLoader 集成""" + dates = np.array(["20240101", "20240101", "20240102", "20240102"]) + X = torch.randn(4, 5) + y = torch.randn(4) + + dataset = TensorDataset(X, y) + sampler = CrossSectionSampler(dates, shuffle_days=False) + loader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler) + + batches = list(loader) + assert len(batches) == 2 # 2个日期 + + for bx, by in batches: + assert bx.shape[0] == 2 # 每个日期2个样本 + assert by.shape[0] == 2 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/training/test_ensemble_quant_loss.py b/tests/training/test_ensemble_quant_loss.py new file mode 100644 index 0000000..a511e30 --- /dev/null +++ b/tests/training/test_ensemble_quant_loss.py @@ -0,0 +1,98 @@ +"""EnsembleQuantLoss 单元测试""" + +import numpy as np +import pytest +import torch +import torch.nn as nn + +from src.training.components.models.ensemble_quant_loss import EnsembleQuantLoss + + +class TestEnsembleQuantLoss: + """EnsembleQuantLoss 单元测试""" + + def test_initialization(self): + """测试初始化""" + loss_fn = EnsembleQuantLoss(alpha=0.7, ensemble_size=16) + + assert loss_fn.alpha == 0.7 + assert loss_fn.ensemble_size == 16 + assert isinstance(loss_fn.huber, nn.HuberLoss) + + def test_output_shape(self): + """测试输出形状和类型""" + loss_fn = EnsembleQuantLoss(alpha=0.5, ensemble_size=4) + + # 创建模拟数据: 20只股票, 4个集成成员 + preds = torch.randn(20, 4) + target = torch.randn(20) + + loss = loss_fn(preds, target) + + # 验证输出是标量 + assert loss.shape == torch.Size([]) + assert isinstance(loss.item(), float) + + def test_small_batch_fallback(self): + """测试小批次回退到 Huber""" + loss_fn = EnsembleQuantLoss(alpha=0.5, ensemble_size=4) + + # 少于10只股票的批次 + preds = torch.randn(5, 4) + target = torch.randn(5) + + loss = loss_fn(preds, target) + + # 应该正常返回loss + assert not torch.isnan(loss) + assert loss.item() > 0 + + def test_huber_component(self): + """测试 Huber 损失组件""" + loss_fn = EnsembleQuantLoss(alpha=1.0, ensemble_size=4) # 纯 Huber + + preds = torch.randn(50, 4) + target = torch.randn(50) + + loss = loss_fn(preds, target) + + # 手动计算期望的 Huber 损失 + huber = nn.HuberLoss(reduction="mean") + expected_loss = 0 + for i in range(4): + expected_loss += huber(preds[:, i], target) + expected_loss /= 4 + + assert torch.allclose(loss, expected_loss, rtol=1e-5) + + def test_ic_component(self): + """测试 IC 损失组件""" + loss_fn = EnsembleQuantLoss(alpha=0.0, ensemble_size=1) # 纯 IC + + # 创建完全相关的预测和目标 + target = torch.randn(50) + preds = target.unsqueeze(1) # 完美相关 + + loss = loss_fn(preds, target) + + # 完美相关时 IC=1,所以 IC loss = 0 + # 但由于 std 计算和数值精度,可能不完全为0 + assert loss.item() < 0.1 + + def test_gradient_flow(self): + """测试梯度可以正常回传""" + loss_fn = EnsembleQuantLoss(alpha=0.5, ensemble_size=4) + + preds = torch.randn(50, 4, requires_grad=True) + target = torch.randn(50) + + loss = loss_fn(preds, target) + loss.backward() + + # 验证梯度存在且非零 + assert preds.grad is not None + assert not torch.all(preds.grad == 0) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])