"""截面数据采样器 用于量化交易场景,确保每个批次包含同一天的所有股票数据。 实现横截面批处理,避免全局随机打乱。 """ import numpy as np from torch.utils.data import Sampler class CrossSectionSampler(Sampler): """截面数据采样器 保证每次产出的 indices 属于同一个交易日。 适用于量化选股场景:模型每轮前向传播面对的都是当天的全市场股票。 Attributes: date_to_indices: 日期到索引列表的映射 unique_dates: 唯一日期列表 shuffle_days: 是否打乱日期顺序 """ def __init__(self, dates: np.ndarray, shuffle_days: bool = True): """初始化采样器 Args: dates: 日期数组,每个元素对应一行数据 shuffle_days: 是否打乱日期的训练顺序,但同一天的数据始终在一起 """ # 记录每个日期对应的所有行索引 self.date_to_indices = {} for idx, date in enumerate(dates): date_str = str(date) if date_str not in self.date_to_indices: self.date_to_indices[date_str] = [] self.date_to_indices[date_str].append(idx) self.unique_dates = list(self.date_to_indices.keys()) self.shuffle_days = shuffle_days def __iter__(self): """迭代生成批次索引 Yields: list: 同一日期的所有样本索引 """ dates = self.unique_dates.copy() if self.shuffle_days: np.random.shuffle(dates) # 打乱日期的训练顺序 for date in dates: indices = self.date_to_indices[date].copy() # 在截面内打乱股票顺序,防止顺序带来的隐性 bias np.random.shuffle(indices) yield indices def __len__(self): """返回批次数量(等于日期数量)""" return len(self.unique_dates)