- 新增 CrossSectionSampler 支持截面数据采样(按交易日批处理) - 新增 EnsembleQuantLoss (Huber + IC) 替代 MSE 作为损失函数 - 重构 TabMModel 支持量化场景:Rank IC 作为验证指标、CosineAnnealingLR学习率调度、梯度裁剪 - 支持 date_col 参数和特征对齐 - 更新实验配置 batch_size 2048 和 weight_decay 等超参数
60 lines
1.9 KiB
Python
60 lines
1.9 KiB
Python
"""截面数据采样器
|
|
|
|
用于量化交易场景,确保每个批次包含同一天的所有股票数据。
|
|
实现横截面批处理,避免全局随机打乱。
|
|
"""
|
|
|
|
import numpy as np
|
|
from torch.utils.data import Sampler
|
|
|
|
|
|
class CrossSectionSampler(Sampler):
|
|
"""截面数据采样器
|
|
|
|
保证每次产出的 indices 属于同一个交易日。
|
|
适用于量化选股场景:模型每轮前向传播面对的都是当天的全市场股票。
|
|
|
|
Attributes:
|
|
date_to_indices: 日期到索引列表的映射
|
|
unique_dates: 唯一日期列表
|
|
shuffle_days: 是否打乱日期顺序
|
|
"""
|
|
|
|
def __init__(self, dates: np.ndarray, shuffle_days: bool = True):
|
|
"""初始化采样器
|
|
|
|
Args:
|
|
dates: 日期数组,每个元素对应一行数据
|
|
shuffle_days: 是否打乱日期的训练顺序,但同一天的数据始终在一起
|
|
"""
|
|
# 记录每个日期对应的所有行索引
|
|
self.date_to_indices = {}
|
|
for idx, date in enumerate(dates):
|
|
date_str = str(date)
|
|
if date_str not in self.date_to_indices:
|
|
self.date_to_indices[date_str] = []
|
|
self.date_to_indices[date_str].append(idx)
|
|
|
|
self.unique_dates = list(self.date_to_indices.keys())
|
|
self.shuffle_days = shuffle_days
|
|
|
|
def __iter__(self):
|
|
"""迭代生成批次索引
|
|
|
|
Yields:
|
|
list: 同一日期的所有样本索引
|
|
"""
|
|
dates = self.unique_dates.copy()
|
|
if self.shuffle_days:
|
|
np.random.shuffle(dates) # 打乱日期的训练顺序
|
|
|
|
for date in dates:
|
|
indices = self.date_to_indices[date].copy()
|
|
# 在截面内打乱股票顺序,防止顺序带来的隐性 bias
|
|
np.random.shuffle(indices)
|
|
yield indices
|
|
|
|
def __len__(self):
|
|
"""返回批次数量(等于日期数量)"""
|
|
return len(self.unique_dates)
|