60 lines
1.9 KiB
Python
60 lines
1.9 KiB
Python
|
|
"""截面数据采样器
|
|||
|
|
|
|||
|
|
用于量化交易场景,确保每个批次包含同一天的所有股票数据。
|
|||
|
|
实现横截面批处理,避免全局随机打乱。
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import numpy as np
|
|||
|
|
from torch.utils.data import Sampler
|
|||
|
|
|
|||
|
|
|
|||
|
|
class CrossSectionSampler(Sampler):
|
|||
|
|
"""截面数据采样器
|
|||
|
|
|
|||
|
|
保证每次产出的 indices 属于同一个交易日。
|
|||
|
|
适用于量化选股场景:模型每轮前向传播面对的都是当天的全市场股票。
|
|||
|
|
|
|||
|
|
Attributes:
|
|||
|
|
date_to_indices: 日期到索引列表的映射
|
|||
|
|
unique_dates: 唯一日期列表
|
|||
|
|
shuffle_days: 是否打乱日期顺序
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
def __init__(self, dates: np.ndarray, shuffle_days: bool = True):
|
|||
|
|
"""初始化采样器
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
dates: 日期数组,每个元素对应一行数据
|
|||
|
|
shuffle_days: 是否打乱日期的训练顺序,但同一天的数据始终在一起
|
|||
|
|
"""
|
|||
|
|
# 记录每个日期对应的所有行索引
|
|||
|
|
self.date_to_indices = {}
|
|||
|
|
for idx, date in enumerate(dates):
|
|||
|
|
date_str = str(date)
|
|||
|
|
if date_str not in self.date_to_indices:
|
|||
|
|
self.date_to_indices[date_str] = []
|
|||
|
|
self.date_to_indices[date_str].append(idx)
|
|||
|
|
|
|||
|
|
self.unique_dates = list(self.date_to_indices.keys())
|
|||
|
|
self.shuffle_days = shuffle_days
|
|||
|
|
|
|||
|
|
def __iter__(self):
|
|||
|
|
"""迭代生成批次索引
|
|||
|
|
|
|||
|
|
Yields:
|
|||
|
|
list: 同一日期的所有样本索引
|
|||
|
|
"""
|
|||
|
|
dates = self.unique_dates.copy()
|
|||
|
|
if self.shuffle_days:
|
|||
|
|
np.random.shuffle(dates) # 打乱日期的训练顺序
|
|||
|
|
|
|||
|
|
for date in dates:
|
|||
|
|
indices = self.date_to_indices[date].copy()
|
|||
|
|
# 在截面内打乱股票顺序,防止顺序带来的隐性 bias
|
|||
|
|
np.random.shuffle(indices)
|
|||
|
|
yield indices
|
|||
|
|
|
|||
|
|
def __len__(self):
|
|||
|
|
"""返回批次数量(等于日期数量)"""
|
|||
|
|
return len(self.unique_dates)
|