Files
ProStock/src/training/components/models/cross_section_sampler.py

60 lines
1.9 KiB
Python
Raw Normal View History

"""截面数据采样器
用于量化交易场景确保每个批次包含同一天的所有股票数据
实现横截面批处理避免全局随机打乱
"""
import numpy as np
from torch.utils.data import Sampler
class CrossSectionSampler(Sampler):
"""截面数据采样器
保证每次产出的 indices 属于同一个交易日
适用于量化选股场景模型每轮前向传播面对的都是当天的全市场股票
Attributes:
date_to_indices: 日期到索引列表的映射
unique_dates: 唯一日期列表
shuffle_days: 是否打乱日期顺序
"""
def __init__(self, dates: np.ndarray, shuffle_days: bool = True):
"""初始化采样器
Args:
dates: 日期数组每个元素对应一行数据
shuffle_days: 是否打乱日期的训练顺序但同一天的数据始终在一起
"""
# 记录每个日期对应的所有行索引
self.date_to_indices = {}
for idx, date in enumerate(dates):
date_str = str(date)
if date_str not in self.date_to_indices:
self.date_to_indices[date_str] = []
self.date_to_indices[date_str].append(idx)
self.unique_dates = list(self.date_to_indices.keys())
self.shuffle_days = shuffle_days
def __iter__(self):
"""迭代生成批次索引
Yields:
list: 同一日期的所有样本索引
"""
dates = self.unique_dates.copy()
if self.shuffle_days:
np.random.shuffle(dates) # 打乱日期的训练顺序
for date in dates:
indices = self.date_to_indices[date].copy()
# 在截面内打乱股票顺序,防止顺序带来的隐性 bias
np.random.shuffle(indices)
yield indices
def __len__(self):
"""返回批次数量(等于日期数量)"""
return len(self.unique_dates)