Files
ProStock/src/training/components/models/cross_section_sampler.py
liaozhaorun c143815443 feat(training): TabM模型量化交易优化
- 新增 CrossSectionSampler 支持截面数据采样(按交易日批处理)
- 新增 EnsembleQuantLoss (Huber + IC) 替代 MSE 作为损失函数
- 重构 TabMModel 支持量化场景:Rank IC 作为验证指标、CosineAnnealingLR学习率调度、梯度裁剪
- 支持 date_col 参数和特征对齐
- 更新实验配置 batch_size 2048 和 weight_decay 等超参数
2026-04-01 00:20:05 +08:00

60 lines
1.9 KiB
Python

"""截面数据采样器
用于量化交易场景,确保每个批次包含同一天的所有股票数据。
实现横截面批处理,避免全局随机打乱。
"""
import numpy as np
from torch.utils.data import Sampler
class CrossSectionSampler(Sampler):
"""截面数据采样器
保证每次产出的 indices 属于同一个交易日。
适用于量化选股场景:模型每轮前向传播面对的都是当天的全市场股票。
Attributes:
date_to_indices: 日期到索引列表的映射
unique_dates: 唯一日期列表
shuffle_days: 是否打乱日期顺序
"""
def __init__(self, dates: np.ndarray, shuffle_days: bool = True):
"""初始化采样器
Args:
dates: 日期数组,每个元素对应一行数据
shuffle_days: 是否打乱日期的训练顺序,但同一天的数据始终在一起
"""
# 记录每个日期对应的所有行索引
self.date_to_indices = {}
for idx, date in enumerate(dates):
date_str = str(date)
if date_str not in self.date_to_indices:
self.date_to_indices[date_str] = []
self.date_to_indices[date_str].append(idx)
self.unique_dates = list(self.date_to_indices.keys())
self.shuffle_days = shuffle_days
def __iter__(self):
"""迭代生成批次索引
Yields:
list: 同一日期的所有样本索引
"""
dates = self.unique_dates.copy()
if self.shuffle_days:
np.random.shuffle(dates) # 打乱日期的训练顺序
for date in dates:
indices = self.date_to_indices[date].copy()
# 在截面内打乱股票顺序,防止顺序带来的隐性 bias
np.random.shuffle(indices)
yield indices
def __len__(self):
"""返回批次数量(等于日期数量)"""
return len(self.unique_dates)