feat(training): add TabM learning-to-rank model support and improve the training pipeline

- Add TabMRankModel, TabMRankTask, and their accompanying loss functions and configuration
- Move DataQualityAnalyzer from the experiment module to the training module
- Remove the overly aggressive NaN/null hard-fill logic from the data processors
- Switch RankTask evaluation metrics to quantile labels instead of raw returns
- Update processor ordering and model hyperparameter configuration in the experiment scripts
src/experiment/README_TABM_TOPK.md (new file, 126 lines)
@@ -0,0 +1,126 @@
# TabM Top-K Optimization Guide

For the Top-K stock-selection scenario, where only the 5-10 stocks with the highest predicted scores are bought each day, TabM learning-to-rank has been optimized in three directions.

## Optimization Overview

### 1. Loss Function Optimization (Core)

**Weighted ListNet (recommended first attempt)**

- Principle: assign higher loss weights to samples with high-score labels
- Parameters: `loss_type="weighted_listnet"`, `topk_weight=5.0`
- Effect: head samples carry 5x the weight of tail samples, forcing the model to focus on the top of the ranking

**LambdaLoss (fine-grained Top-K optimization)**

- Principle: weights each pair's loss by DeltaNDCG, the change in NDCG from swapping the two samples' positions
- Parameters: `loss_type="lambda"`, `lambda_sigma=1.0`, `ndcg_weight_power=1.5`
- Effect: directly optimizes NDCG@K; suited to squeezing out maximum Top-K performance
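The weighted ListNet idea can be sketched in a few lines of NumPy (a minimal illustration, not the `TabMRankModel` implementation; the `weighted_listnet_loss` helper and its linear weighting scheme are assumptions for demonstration):

```python
import numpy as np

def softmax(x):
    z = x - x.max()
    e = np.exp(z)
    return e / e.sum()

def weighted_listnet_loss(preds, targets, topk_weight=1.0):
    """Cross-entropy between the label distribution and the predicted
    distribution; labels are up-weighted linearly toward topk_weight."""
    p_target = softmax(targets)
    if topk_weight > 1.0 and targets.max() > targets.min():
        # weight = 1 + (target - min) / (max - min) * (topk_weight - 1)
        w = 1.0 + (targets - targets.min()) / (targets.max() - targets.min()) * (topk_weight - 1.0)
        w = w * len(targets) / w.sum()  # normalize weights to sum to N
        p_target = p_target * w
    log_p = preds - preds.max()
    log_p = log_p - np.log(np.exp(log_p).sum())  # log_softmax, numerically stable
    return -(p_target * log_p).sum()

targets = np.array([3.0, 2.0, 1.0, 0.0])
good = np.array([2.0, 1.0, 0.5, 0.0])  # predicted order agrees with labels
bad = np.array([0.0, 0.5, 1.0, 2.0])   # predicted order reversed
assert weighted_listnet_loss(good, targets) < weighted_listnet_loss(bad, targets)
```

A correctly ordered prediction yields a lower loss than a reversed one, and raising `topk_weight` increases the penalty for misplacing high-label samples specifically.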
### 2. Label Engineering Optimization (Enhancement)

**Exponential gain transform**

- Formula: `Gain = 2^(rank/scale) - 1`
- Parameters: `label_transform="exponential"`, `label_scale=20.0`
- Effect: rank=0 → 0, rank=19 → ~0.93, rank=99 → ~29.9
- Purpose: widens the gap between high- and low-score samples, sharpening the separation at the head
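The stated values follow directly from the formula (the `exponential_gain` helper name is illustrative):

```python
import math

def exponential_gain(rank, scale=20.0):
    """Gain = 2^(rank/scale) - 1, the transform described above."""
    return 2.0 ** (rank / scale) - 1.0

assert exponential_gain(0) == 0.0
assert math.isclose(exponential_gain(19), 0.932, abs_tol=0.01)
assert math.isclose(exponential_gain(99), 29.9, abs_tol=0.2)
```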
### 3. Recommended Configuration Combinations

```python
# Configuration A: mild optimization (recommended first attempt)
MODEL_PARAMS = {
    "loss_type": "weighted_listnet",
    "topk_weight": 3.0,
    # ... other parameters
}
LABEL_TRANSFORM = None  # keep standard quantiles

# Configuration B: balanced optimization (trades off effect against stability)
MODEL_PARAMS = {
    "loss_type": "weighted_listnet",
    "topk_weight": 5.0,
    "ndcg_k": 20,  # track NDCG@20 during validation
}
LABEL_TRANSFORM = "exponential"
LABEL_SCALE = 20.0

# Configuration C: aggressive optimization (focused on Top-10)
MODEL_PARAMS = {
    "loss_type": "lambda",
    "lambda_sigma": 1.0,
    "ndcg_weight_power": 1.5,
    "ndcg_k": 10,
}
N_QUANTILES = 50  # higher quantile resolution
LABEL_TRANSFORM = "exponential"
LABEL_SCALE = 25.0
```
## Usage Example

Modify the configuration in `tabm_rank_train.py`:

```python
# Quantile configuration
N_QUANTILES = 30

# Label engineering configuration
LABEL_TRANSFORM = "exponential"  # enable exponential gains
LABEL_SCALE = 20.0

# Model parameter configuration
MODEL_PARAMS = {
    # ... base parameters ...

    # Top-K optimization parameters
    "loss_type": "weighted_listnet",  # or "lambda"
    "topk_weight": 5.0,        # only used by weighted_listnet
    "lambda_sigma": 1.0,       # only used by lambda
    "ndcg_weight_power": 1.0,  # only used by lambda
    "ndcg_k": 20,              # validation metric
}
```
## Parameter Reference

### Loss Function Parameters

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `loss_type` | str | "listnet" | Loss type: "listnet" / "weighted_listnet" / "lambda" |
| `topk_weight` | float | 5.0 | Head-weight coefficient (weighted_listnet); larger values focus more on the head |
| `lambda_sigma` | float | 1.0 | Sigmoid steepness (lambda) |
| `ndcg_weight_power` | float | 1.0 | DeltaNDCG weight exponent (lambda); >1 further amplifies the head |
| `ndcg_k` | int/None | None | NDCG@k computed during validation; None means global |

### Label Engineering Parameters

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `n_quantiles` | int | 20 | Number of quantiles; more quantiles give higher resolution |
| `label_transform` | str/None | None | Transform type: None / "exponential" |
| `label_scale` | float | 20.0 | Scale factor of the exponential transform; controls gain magnitude |
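For intuition on the `lambda` parameters, the DeltaNDCG weight for swapping two items can be computed directly (a minimal sketch; the `delta_ndcg` function name is illustrative):

```python
import math

def delta_ndcg(gain_i, gain_j, rank_i, rank_j, power=1.0):
    """|Gain_i - Gain_j| * |1/log2(rank_i+1) - 1/log2(rank_j+1)|, raised to `power`."""
    discount = abs(1.0 / math.log2(rank_i + 1) - 1.0 / math.log2(rank_j + 1))
    return (abs(gain_i - gain_j) * discount) ** power

# Swapping a high-gain item at rank 1 with a low-gain item at rank 50
# matters far more to NDCG than a swap deep in the tail:
head_swap = delta_ndcg(gain_i=31.0, gain_j=1.0, rank_i=1, rank_j=50)
tail_swap = delta_ndcg(gain_i=2.0, gain_j=1.0, rank_i=40, rank_j=50)
assert head_swap > tail_swap
```

This is why the lambda loss concentrates gradient on head samples; `ndcg_weight_power > 1` exaggerates that concentration further.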
## Implementation Recommendations

1. **Incremental optimization**:
   - Round 1: enable only `loss_type="weighted_listnet"`, `topk_weight=3.0`
   - Round 2: add label engineering with `label_transform="exponential"`
   - Round 3: try `loss_type="lambda"` for fine-grained optimization

2. **Metrics to monitor**:
   - Track NDCG@K (set K to the actual Top-K size)
   - Compare backtest returns across configurations
   - Watch that the training loss decreases steadily

3. **Caveats**:
   - LambdaLoss trains more slowly; each epoch takes longer
   - Exponential gains change the label distribution and may require a learning-rate adjustment
   - An overly large topk_weight can overfit the head samples
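NDCG@K, the metric recommended above, can be sketched as follows (a minimal reference implementation for monitoring purposes, not the project's evaluation code):

```python
import numpy as np

def ndcg_at_k(scores, gains, k=None):
    """NDCG@k: DCG of the ranking induced by `scores`, normalized by the ideal DCG."""
    order = np.argsort(-scores)        # indices sorted by descending predicted score
    ideal = np.sort(gains)[::-1]       # gains sorted descending (ideal ranking)
    if k is not None:
        order, ideal = order[:k], ideal[:k]
    discounts = 1.0 / np.log2(np.arange(len(ideal)) + 2)  # 1/log2(rank+1), rank from 1
    dcg = (gains[order] * discounts).sum()
    idcg = (ideal * discounts).sum()
    return dcg / idcg if idcg > 0 else 0.0

gains = np.array([3.0, 2.0, 1.0, 0.0])
assert ndcg_at_k(np.array([0.9, 0.5, 0.3, 0.1]), gains, k=2) == 1.0  # perfect order
assert ndcg_at_k(np.array([0.1, 0.3, 0.5, 0.9]), gains, k=2) < 1.0   # reversed order
```

Set `k` to your actual Top-K size (5 or 10 here) so the metric reflects what the strategy actually trades.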
## References

1. **ListNet**: "Learning to Rank: From Pairwise Approach to Listwise Approach" (Cao et al., 2007)
2. **LambdaRank**: "From RankNet to LambdaRank to LambdaMART: An Overview" (Burges, 2010)
3. **LambdaLoss**: "The LambdaLoss Framework for Ranking Metric Optimization" (Wang et al., 2018)
4. **Deep learning for stock selection**: "Deep Learning for Stock Selection" (Gu, Kelly, & Xiu, 2020)
@@ -270,6 +270,7 @@ SELECTED_FACTORS = [
    "bottom_cost_stability",
    "pivot_reversion",
    "chip_transition",

    # "amivest_liq_20",
    # "atr_price_impact",
    # "hui_heubel_ratio",
@@ -450,20 +451,41 @@ def prepare_data(
    return data


# =============================================================================
# Output configuration
# =============================================================================
OUTPUT_DIR = "output"
SAVE_PREDICTIONS = True

# Model saving configuration
SAVE_MODEL = False  # whether to save the model
MODEL_SAVE_DIR = "models"  # directory for saved models

# Top-N configuration: number of stocks recommended per day
TOP_N = 5  # can be adjusted to 10, 20, etc.

# Stock pool size configuration
STOCK_POOL_SIZE = 1000  # number of smallest-market-cap stocks to keep in the pool

# Training-data skip configuration
TRAIN_SKIP_DAYS = 300  # skip the first TRAIN_SKIP_DAYS days of training data to avoid the data shortage at the start of training


# =============================================================================
# Stock pool filtering configuration
# =============================================================================
def stock_pool_filter(df: pl.DataFrame) -> pl.Series:
def stock_pool_filter(df: pl.DataFrame, n_stocks: int = STOCK_POOL_SIZE) -> pl.Series:
    """Stock pool filter (single-day data).

    Filter criteria:
    1. Exclude ChiNext (codes starting with 300)
    2. Exclude the STAR Market (codes starting with 688)
    3. Exclude the Beijing Stock Exchange (codes starting with 8, 9, or 4)
    4. Select the 500 smallest stocks by market cap that day
    4. Select the n_stocks smallest stocks by market cap that day

    Args:
        df: single-day data frame
        n_stocks: number of stocks to select; defaults to STOCK_POOL_SIZE

    Returns:
        Boolean Series indicating which stocks are selected
@@ -477,9 +499,9 @@ def stock_pool_filter(df: pl.DataFrame) -> pl.Series:
        & ~df["ts_code"].str.starts_with("4")  # exclude the Beijing Stock Exchange
    )

    # Among the filtered stocks, select the 500 with the smallest free-float market cap
    # Among the filtered stocks, select the n_stocks with the smallest free-float market cap
    valid_df = df.filter(code_filter)
    n = min(1000, len(valid_df))
    n = min(n_stocks, len(valid_df))
    small_cap_codes = valid_df.sort("circ_mv").head(n)["ts_code"]

    # Return a boolean Series: whether each stock is among the selected ones
@@ -489,22 +511,6 @@ def stock_pool_filter(df: pl.DataFrame) -> pl.Series:
# Base columns required for filtering
STOCK_FILTER_REQUIRED_COLUMNS = ["circ_mv"]

# =============================================================================
# Output configuration
# =============================================================================
OUTPUT_DIR = "output"
SAVE_PREDICTIONS = True

# Model saving configuration
SAVE_MODEL = False  # whether to save the model
MODEL_SAVE_DIR = "models"  # directory for saved models

# Top-N configuration: number of stocks recommended per day
TOP_N = 5  # can be adjusted to 10, 20, etc.

# Training-data skip configuration
TRAIN_SKIP_DAYS = 300  # skip the first TRAIN_SKIP_DAYS days of training data to avoid the data shortage at the start of training


def get_output_path(model_type: str, test_start: str, test_end: str) -> str:
    """Generate the output file path.
@@ -54,27 +54,10 @@ N_QUANTILES = 20

# Excluded factor list
EXCLUDED_FACTORS = [
    'active_market_cap',
    'close_vwap_deviation',
    'sharpe_ratio_20',
    'upper_shadow_ratio',
    'volume_ratio_5_20',
    'GTJA_alpha090',
    'GTJA_alpha084',
    'GTJA_alpha066',
    'GTJA_alpha150',
    'GTJA_alpha148',
    'GTJA_alpha106',
    'GTJA_alpha109',
    'GTJA_alpha108',
    'GTJA_alpha176',
    'GTJA_alpha169',
    'GTJA_alpha156',
    'chip_dispersion_70',
    'winner_rate_cs_rank',
    'atr_price_impact',
    'low_vol_days_20',
    'liquidity_shock_momentum',
    # 'debt_to_equity',
    # 'GTJA_alpha016',
    # 'GTJA_alpha141',

]

# LambdaRank model parameter configuration
@@ -145,8 +128,8 @@ def main():
    pipeline = DataPipeline(
        factor_manager=factor_manager,
        processor_configs=[
            (NullFiller, {"strategy": "mean"}),
            (Winsorizer, {"lower": 0.01, "upper": 0.99}),
            (NullFiller, {"strategy": "mean"}),
            (CrossSectionalStandardScaler, {}),
        ],
        filters=[STFilter(data_router=engine.router)],
@@ -52,36 +52,36 @@ TRAINING_TYPE = "regression"

# Excluded factor list
EXCLUDED_FACTORS = [
    'GTJA_alpha016',
    'volatility_20',
    'current_ratio',
    'GTJA_alpha001',
    'GTJA_alpha141',
    'GTJA_alpha129',
    'GTJA_alpha164',
    'amivest_liq_20',
    'GTJA_alpha012',
    'debt_to_equity',
    'turnover_deviation',
    'GTJA_alpha073',
    'GTJA_alpha043',
    'GTJA_alpha032',
    'GTJA_alpha028',
    'GTJA_alpha090',
    'GTJA_alpha108',
    'GTJA_alpha105',
    'GTJA_alpha091',
    'GTJA_alpha119',
    'GTJA_alpha104',
    'GTJA_alpha163',
    'GTJA_alpha157',
    'cost_skewness',
    'GTJA_alpha176',
    'chip_transition',
    'amount_skewness_20',
    'GTJA_alpha148',
    'mean_median_dev',
    'downside_illiq_20',
    # 'GTJA_alpha016',
    # 'volatility_20',
    # 'current_ratio',
    # 'GTJA_alpha001',
    # 'GTJA_alpha141',
    # 'GTJA_alpha129',
    # 'GTJA_alpha164',
    # 'amivest_liq_20',
    # 'GTJA_alpha012',
    # 'debt_to_equity',
    # 'turnover_deviation',
    # 'GTJA_alpha073',
    # 'GTJA_alpha043',
    # 'GTJA_alpha032',
    # 'GTJA_alpha028',
    # 'GTJA_alpha090',
    # 'GTJA_alpha108',
    # 'GTJA_alpha105',
    # 'GTJA_alpha091',
    # 'GTJA_alpha119',
    # 'GTJA_alpha104',
    # 'GTJA_alpha163',
    # 'GTJA_alpha157',
    # 'cost_skewness',
    # 'GTJA_alpha176',
    # 'chip_transition',
    # 'amount_skewness_20',
    # 'GTJA_alpha148',
    # 'mean_median_dev',
    # 'downside_illiq_20',
]

# Model parameter configuration
@@ -153,15 +153,15 @@ def main():
    pipeline = DataPipeline(
        factor_manager=factor_manager,
        processor_configs=[
            (NullFiller, {"strategy": "mean"}),
            (Winsorizer, {"lower": 0.01, "upper": 0.99}),
            (NullFiller, {"strategy": "mean"}),
            (StandardScaler, {}),
            # (CrossSectionalStandardScaler, {}),
        ],
        label_processor_configs=[
            # Winsorize the label (remove extreme returns)
            (Winsorizer, {"lower": 0.05, "upper": 0.95}),
            # (StandardScaler, {}),
            (StandardScaler, {}),
        ],
        filters=[STFilter(data_router=engine.router)],
        stock_pool_filter_func=stock_pool_filter,
src/experiment/tabm_rank_train.py (new file, 176 lines)
@@ -0,0 +1,176 @@
"""TabM learning-to-rank training pipeline (modular version)

Uses the new modular Trainer architecture, built on TabMRankModel for
learning to rank. TabM uses a ListNet loss and supports ensembling.
"""

import os

from src.factors import FactorEngine
from src.training import (
    FactorManager,
    DataPipeline,
    NullFiller,
    Winsorizer,
    CrossSectionalStandardScaler,
)
from src.training.tasks.tabm_rank_task import TabMRankTask
from src.training.core.trainer_v2 import Trainer
from src.training.components.filters import STFilter
from src.experiment.common import (
    SELECTED_FACTORS,
    FACTOR_DEFINITIONS,
    LABEL_NAME,
    LABEL_FACTOR,
    TRAIN_START,
    TRAIN_END,
    VAL_START,
    VAL_END,
    TEST_START,
    TEST_END,
    stock_pool_filter,
    STOCK_FILTER_REQUIRED_COLUMNS,
    OUTPUT_DIR,
    SAVE_PREDICTIONS,
    SAVE_MODEL,
    get_model_save_path,
    save_model_with_factors,
    TOP_N,
    TRAIN_SKIP_DAYS,
)

# Training type identifier
TRAINING_TYPE = "tabm_rank"

# %%
# Label configuration (imported from common.py)
# LABEL_NAME and LABEL_FACTOR are already bound in common.py; just import them

# Quantile configuration (higher resolution to better separate the head)
N_QUANTILES = 50

# [Top-K optimization] Label engineering configuration - exponential gains enabled by default
LABEL_TRANSFORM = "exponential"  # enable exponential gain labels (2^(rank/scale) - 1)
LABEL_SCALE = 20.0  # scale factor of the exponential transform

# Excluded factor list
EXCLUDED_FACTORS = ["GTJA_alpha041", "GTJA_alpha127"]

# TabM Rank model parameters (all Top-K optimizations enabled, using LambdaLoss)
MODEL_PARAMS = {
    # ==================== MLP structure ====================
    "n_blocks": 4,  # number of MLP layers
    "d_block": 256,  # neurons per layer
    "dropout": 0.5,  # dropout rate
    # ==================== Ensemble mechanism ====================
    "ensemble_size": 32,  # built-in ensemble size (simulates an ensemble of 32 models)
    # ==================== Training parameters ====================
    "learning_rate": 1e-4,  # learning rate
    "weight_decay": 1e-5,  # weight decay
    "epochs": 500,  # number of epochs
    # ==================== Early stopping ====================
    "early_stopping_round": 50,  # early-stopping patience
    # NDCG evaluation - focus on Top-20
    "ndcg_k": 20,  # compute NDCG@20 during validation
    # [Top-K optimization] Loss configuration - use LambdaLoss
    "loss_type": "lambda",  # LambdaLoss for precise Top-K optimization
    "lambda_sigma": 1.0,  # sigmoid steepness
    "ndcg_weight_power": 1.0,  # DeltaNDCG weight exponent; >1 further amplifies the head
}

# Date range configuration
date_range = {
    "train": (TRAIN_START, TRAIN_END),
    "val": (VAL_START, VAL_END),
    "test": (TEST_START, TEST_END),
}

# Output configuration
output_config = {
    "output_dir": OUTPUT_DIR,
    "output_filename": "tabm_rank_output.csv",
    "save_predictions": SAVE_PREDICTIONS,
    "save_model": SAVE_MODEL,
    "model_save_path": get_model_save_path(TRAINING_TYPE),
    "top_n": TOP_N,
}


def main():
    """Entry point."""
    print("\n" + "=" * 80)
    print("TabM learning-to-rank training (modular version)")
    print("=" * 80)

    # 1. Create the FactorEngine
    print("\n[1] Creating FactorEngine")
    engine = FactorEngine()

    # 2. Create the FactorManager
    print("\n[2] Creating FactorManager")
    factor_manager = FactorManager(
        selected_factors=SELECTED_FACTORS,
        factor_definitions=FACTOR_DEFINITIONS,
        label_factor=LABEL_FACTOR,
        excluded_factors=EXCLUDED_FACTORS,
    )

    # 3. Create the DataPipeline
    print("\n[3] Creating DataPipeline")
    pipeline = DataPipeline(
        factor_manager=factor_manager,
        processor_configs=[
            (Winsorizer, {"lower": 0.01, "upper": 0.99}),
            (NullFiller, {"strategy": "mean"}),
            (CrossSectionalStandardScaler, {}),
        ],
        filters=[STFilter(data_router=engine.router)],
        stock_pool_filter_func=stock_pool_filter,
        stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
        train_skip_days=TRAIN_SKIP_DAYS,
    )

    # 4. Create the TabMRankTask
    print("\n[4] Creating TabMRankTask")
    task = TabMRankTask(
        model_params=MODEL_PARAMS,
        label_name=LABEL_NAME,
        n_quantiles=N_QUANTILES,
        label_transform=LABEL_TRANSFORM,
        label_scale=LABEL_SCALE,
    )

    # 5. Create the Trainer
    print("\n[5] Creating Trainer")
    trainer = Trainer(
        data_pipeline=pipeline,
        task=task,
        output_config=output_config,
        verbose=True,
    )

    # 6. Run training
    print("\n[6] Running training")
    results = trainer.run(engine=engine, date_range=date_range)

    # 7. Save the model and factor info (if enabled)
    if SAVE_MODEL:
        print("\n[7] Saving model and factor info")
        save_model_with_factors(
            model=task.get_model(),
            model_path=output_config["model_save_path"],
            selected_factors=SELECTED_FACTORS,
            factor_definitions=FACTOR_DEFINITIONS,
            fitted_processors=pipeline.get_fitted_processors(),
        )

    print("\n" + "=" * 80)
    print("Training pipeline finished!")
    print(f"Results saved to: {os.path.join(OUTPUT_DIR, 'tabm_rank_output.csv')}")
    print("=" * 80)

    return results


if __name__ == "__main__":
    main()
@@ -39,9 +39,8 @@ from src.experiment.common import (
    get_model_save_path,
    save_model_with_factors,
    TOP_N,
    TRAIN_SKIP_DAYS,
    TRAIN_SKIP_DAYS, STOCK_POOL_SIZE,
)
from src.experiment.data_quality_analyzer import DataQualityAnalyzer

# Training type identifier
TRAINING_TYPE = "tabm_regression"
@@ -54,201 +53,7 @@ TRAINING_TYPE = "tabm_regression"

# Excluded factor list (kept consistent with the LightGBM regression)
EXCLUDED_FACTORS = [
    # "GTJA_alpha001",
    # "GTJA_alpha002",
    # "GTJA_alpha003",
    # "GTJA_alpha004",
    # "GTJA_alpha005",
    # "GTJA_alpha006",
    # "GTJA_alpha007",
    # "GTJA_alpha008",
    # "GTJA_alpha009",
    # "GTJA_alpha010",
    # "GTJA_alpha011",
    # "GTJA_alpha012",
    # "GTJA_alpha013",
    # "GTJA_alpha014",
    # "GTJA_alpha015",
    # "GTJA_alpha016",
    # "GTJA_alpha017",
    # "GTJA_alpha018",
    # "GTJA_alpha019",
    # "GTJA_alpha020",
    # "GTJA_alpha022",
    # "GTJA_alpha023",
    # "GTJA_alpha024",
    # "GTJA_alpha025",
    # "GTJA_alpha026",
    # "GTJA_alpha027",
    # "GTJA_alpha028",
    # "GTJA_alpha029",
    # "GTJA_alpha031",
    # "GTJA_alpha032",
    # "GTJA_alpha033",
    # "GTJA_alpha034",
    # "GTJA_alpha035",
    # "GTJA_alpha036",
    # "GTJA_alpha037",
    # # "GTJA_alpha038",
    # "GTJA_alpha039",
    # "GTJA_alpha040",
    # "GTJA_alpha041",
    # "GTJA_alpha042",
    # "GTJA_alpha043",
    # "GTJA_alpha044",
    # "GTJA_alpha045",
    # "GTJA_alpha046",
    # "GTJA_alpha047",
    # "GTJA_alpha048",
    # "GTJA_alpha049",
    # "GTJA_alpha050",
    # "GTJA_alpha051",
    # "GTJA_alpha052",
    # "GTJA_alpha053",
    # "GTJA_alpha054",
    # "GTJA_alpha056",
    # "GTJA_alpha057",
    # "GTJA_alpha058",
    # "GTJA_alpha059",
    # "GTJA_alpha060",
    # "GTJA_alpha061",
    # "GTJA_alpha062",
    # "GTJA_alpha063",
    # "GTJA_alpha064",
    # "GTJA_alpha065",
    # "GTJA_alpha066",
    # "GTJA_alpha067",
    # "GTJA_alpha068",
    # "GTJA_alpha070",
    # "GTJA_alpha071",
    # "GTJA_alpha072",
    # "GTJA_alpha073",
    # "GTJA_alpha074",
    # "GTJA_alpha076",
    # "GTJA_alpha077",
    # "GTJA_alpha078",
    # "GTJA_alpha079",
    # "GTJA_alpha080",
    # "GTJA_alpha081",
    # "GTJA_alpha082",
    # "GTJA_alpha083",
    # "GTJA_alpha084",
    # "GTJA_alpha085",
    # "GTJA_alpha086",
    # "GTJA_alpha087",
    # "GTJA_alpha088",
    # "GTJA_alpha089",
    # "GTJA_alpha090",
    # "GTJA_alpha091",
    # "GTJA_alpha092",
    # "GTJA_alpha093",
    # "GTJA_alpha094",
    # "GTJA_alpha095",
    # "GTJA_alpha096",
    # "GTJA_alpha097",
    # "GTJA_alpha098",
    # "GTJA_alpha099",
    # "GTJA_alpha100",
    # "GTJA_alpha101",
    # "GTJA_alpha102",
    # "GTJA_alpha103",
    # "GTJA_alpha104",
    # "GTJA_alpha105",
    # "GTJA_alpha106",
    # "GTJA_alpha107",
    # "GTJA_alpha108",
    # "GTJA_alpha109",
    # "GTJA_alpha110",
    # "GTJA_alpha111",
    # "GTJA_alpha112",
    # # "GTJA_alpha113",
    # "GTJA_alpha114",
    # "GTJA_alpha115",
    # "GTJA_alpha117",
    # "GTJA_alpha118",
    # "GTJA_alpha119",
    # "GTJA_alpha120",
    # # "GTJA_alpha121",
    # "GTJA_alpha122",
    # "GTJA_alpha123",
    # "GTJA_alpha124",
    # "GTJA_alpha125",
    # "GTJA_alpha126",
    # "GTJA_alpha127",
    # "GTJA_alpha128",
    # "GTJA_alpha129",
    # "GTJA_alpha130",
    # "GTJA_alpha131",
    # "GTJA_alpha132",
    # "GTJA_alpha133",
    # "GTJA_alpha134",
    # "GTJA_alpha135",
    # "GTJA_alpha136",
    # # "GTJA_alpha138",
    # "GTJA_alpha139",
    # # "GTJA_alpha140",
    # "GTJA_alpha141",
    # "GTJA_alpha142",
    # "GTJA_alpha145",
    # # "GTJA_alpha146",
    # "GTJA_alpha148",
    # "GTJA_alpha150",
    # "GTJA_alpha151",
    # "GTJA_alpha152",
    # "GTJA_alpha153",
    # "GTJA_alpha154",
    # "GTJA_alpha155",
    # "GTJA_alpha156",
    # "GTJA_alpha157",
    # "GTJA_alpha158",
    # "GTJA_alpha159",
    # "GTJA_alpha160",
    # "GTJA_alpha161",
    # "GTJA_alpha162",
    # "GTJA_alpha163",
    # "GTJA_alpha164",
    # # "GTJA_alpha165",
    # "GTJA_alpha166",
    # "GTJA_alpha167",
    # "GTJA_alpha168",
    # "GTJA_alpha169",
    # "GTJA_alpha170",
    # "GTJA_alpha171",
    # "GTJA_alpha173",
    # "GTJA_alpha174",
    # "GTJA_alpha175",
    # "GTJA_alpha176",
    # "GTJA_alpha177",
    # "GTJA_alpha178",
    # "GTJA_alpha179",
    # "GTJA_alpha180",
    # # "GTJA_alpha183",
    # "GTJA_alpha184",
    # "GTJA_alpha185",
    # "GTJA_alpha187",
    # "GTJA_alpha188",
    # "GTJA_alpha189",
    # "GTJA_alpha191",
    # "chip_dispersion_90",
    # "chip_dispersion_70",
    # "cost_skewness",
    # "dispersion_change_20",
    # "price_to_avg_cost",
    # "price_to_median_cost",
    # "mean_median_dev",
    # "trap_pressure",
    # "bottom_profit",
    # "history_position",
    # "winner_rate_surge_5",
    # "winner_rate_cs_rank",
    # "winner_rate_dev_20",
    # "winner_rate_volatility",
    # "smart_money_accumulation",
    # "winner_vol_corr_20",
    # "cost_base_momentum",
    # "bottom_cost_stability",
    # "pivot_reversion",
    # "chip_transition",

]

# TabM model parameter configuration (from user-provided sample code)
@@ -256,11 +61,11 @@ MODEL_PARAMS = {
    # ==================== MLP structure ====================
    "n_blocks": 3,  # number of MLP layers
    "d_block": 256,  # neurons per layer
    "dropout": 0.3,  # dropout rate
    "dropout": 0.5,  # dropout rate
    # ==================== Ensemble mechanism ====================
    "ensemble_size": 32,  # built-in ensemble size (simulates an ensemble of 32 models)
    # ==================== Training parameters ====================
    "batch_size": 2048,  # batch size
    "batch_size": STOCK_POOL_SIZE * 5,  # batch size
    "learning_rate": 1e-3,  # learning rate
    "weight_decay": 1e-5,  # weight decay
    "epochs": 100,  # number of epochs
@@ -312,13 +117,14 @@ def main():
    pipeline = DataPipeline(
        factor_manager=factor_manager,
        processor_configs=[
            (NullFiller, {"strategy": "mean"}),
            (Winsorizer, {"lower": 0.01, "upper": 0.99}),  # winsorize the fat-tailed distributions first
            (NullFiller, {"strategy": "mean"}),
            (StandardScaler, {}),  # TabM requires standardized inputs
        ],
        label_processor_configs=[
            # Winsorize the label (remove extreme returns)
            (Winsorizer, {"lower": 0.05, "upper": 0.95}),
            (Winsorizer, {"lower": 0.01, "upper": 0.99}),
            (StandardScaler, {}),
        ],
        filters=[STFilter(data_router=engine.router)],
        stock_pool_filter_func=stock_pool_filter,
@@ -43,6 +43,12 @@ from src.training.core import StockPoolManager, Trainer
# Utility functions
from src.training.utils import check_data_quality

# Data quality analyzer
from src.training.data_quality_analyzer import (
    DataQualityAnalyzer,
    analyze_data_quality,
)

# Configuration
from src.training.config import TrainingConfig

@@ -85,6 +91,9 @@ __all__ = [
    "Trainer",
    # Utility functions
    "check_data_quality",
    # Data quality analyzer
    "DataQualityAnalyzer",
    "analyze_data_quality",
    # Configuration
    "TrainingConfig",
    # New: modular Trainer components (recommended)
@@ -7,6 +7,11 @@ from src.training.components.models.lightgbm import LightGBMModel
from src.training.components.models.lightgbm_lambdarank import LightGBMLambdaRankModel
from src.training.components.models.tabpfn_model import TabPFNModel
from src.training.components.models.tabm_model import TabMModel
from src.training.components.models.tabm_rank_model import (
    TabMRankModel,
    EnsembleListNetLoss,
    EnsembleLambdaLoss,
)
from src.training.components.models.cross_section_sampler import CrossSectionSampler
from src.training.components.models.ensemble_quant_loss import EnsembleQuantLoss

@@ -15,6 +20,9 @@ __all__ = [
    "LightGBMLambdaRankModel",
    "TabPFNModel",
    "TabMModel",
    "TabMRankModel",
    "EnsembleListNetLoss",
    "EnsembleLambdaLoss",
    "CrossSectionSampler",
    "EnsembleQuantLoss",
]
src/training/components/models/tabm_rank_model.py (new file, 747 lines)
@@ -0,0 +1,747 @@
|
||||
"""TabM 排序模型实现 (TabM Rank)
|
||||
|
||||
基于 TabM (Tabular Multilayer Perceptron with Ensembles) 架构
|
||||
引入 ListNet 列表级排序损失,实现类似 LambdaRank 的截面排序学习。
|
||||
适用于股票未来收益率的截面排序预测。
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
import scipy.stats as stats
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader, TensorDataset, Sampler
|
||||
from tabm import TabM
|
||||
|
||||
from src.training.components.base import BaseModel
|
||||
from src.training.registry import register_model
|
||||
|
||||
|
||||
class GroupSampler(Sampler):
|
||||
"""排序学习专用的分组采样器
|
||||
|
||||
确保每个 Batch 包含的是同一个 Query (如同一天) 的所有样本。
|
||||
这与 LightGBM 的 `group` 参数逻辑完全一致。
|
||||
"""
|
||||
|
||||
def __init__(self, group_counts: np.ndarray, shuffle_groups: bool = True):
|
||||
"""初始化分组采样器
|
||||
|
||||
Args:
|
||||
group_counts: 各组的样本数量数组,如 [100, 150, 120]
|
||||
shuffle_groups: 是否打乱组的训练顺序
|
||||
"""
|
||||
self.group_counts = group_counts
|
||||
self.shuffle_groups = shuffle_groups
|
||||
# 计算每组在原数组中的起始和结束边界
|
||||
self.boundaries = np.insert(np.cumsum(group_counts), 0, 0)
|
||||
self.num_groups = len(group_counts)
|
||||
|
||||
def __iter__(self):
|
||||
"""迭代生成批次索引
|
||||
|
||||
Yields:
|
||||
list: 同一组的所有样本索引
|
||||
"""
|
||||
group_indices = list(range(self.num_groups))
|
||||
if self.shuffle_groups:
|
||||
np.random.shuffle(group_indices)
|
||||
|
||||
for g_idx in group_indices:
|
||||
start = self.boundaries[g_idx]
|
||||
end = self.boundaries[g_idx + 1]
|
||||
# 返回该组(天)内所有样本的索引,作为一个完整的 batch
|
||||
yield list(range(start, end))
|
||||
|
||||
def __len__(self):
|
||||
"""返回组数量"""
|
||||
return self.num_groups
|
||||
|
||||
|
||||
class EnsembleListNetLoss(nn.Module):
|
||||
"""集成 ListNet 排序损失 (Listwise Ranking Loss)
|
||||
|
||||
基于交叉熵将截面上的相关度转化为概率分布。
|
||||
支持 TabM 的 Ensemble 维度。
|
||||
"""
|
||||
|
||||
def __init__(self, topk_weight: float = 1.0):
|
||||
"""初始化 ListNet 损失
|
||||
|
||||
Args:
|
||||
topk_weight: 头部样本权重系数,>1.0 时强化对高分标签的关注
|
||||
"""
|
||||
super().__init__()
|
||||
self.topk_weight = topk_weight
|
||||
|
||||
def forward(self, preds: torch.Tensor, targets: torch.Tensor):
|
||||
"""计算 ListNet 排序损失
|
||||
|
||||
Args:
|
||||
preds: [BatchSize, EnsembleSize] - 当前组所有样本的预测 logits
|
||||
targets: [BatchSize] - 当前组所有样本的真实排序标签 (0, 1, 2...)
|
||||
|
||||
Returns:
|
||||
标量损失值
|
||||
"""
|
||||
# 如果该组内样本太少,无法排序,直接返回 0
|
||||
if preds.size(0) <= 1:
|
||||
return torch.tensor(0.0, requires_grad=True, device=preds.device)
|
||||
|
||||
# [BatchSize] -> [BatchSize, EnsembleSize] 广播以对齐每个集成成员
|
||||
targets_expanded = targets.unsqueeze(1).expand_as(preds)
|
||||
|
||||
# 1. 计算真实标签的分数概率分布 (Softmax 使得高分权重成倍提升)
|
||||
targets_prob = F.softmax(targets_expanded, dim=0)
|
||||
|
||||
# 2. 【Top-K 优化】如果启用加权,给予高分标签更高权重
|
||||
if self.topk_weight > 1.0:
|
||||
# 基于标签值计算权重(标签越高,权重越大)
|
||||
# 归一化到 [1, topk_weight] 范围
|
||||
max_target = targets.max()
|
||||
min_target = targets.min()
|
||||
if max_target > min_target:
|
||||
# 线性插值:权重 = 1 + (target - min) / (max - min) * (topk_weight - 1)
|
||||
sample_weights = 1.0 + (targets - min_target) / (
|
||||
max_target - min_target
|
||||
) * (self.topk_weight - 1.0)
|
||||
sample_weights = sample_weights.unsqueeze(1).expand_as(preds)
|
||||
# 归一化权重使其和为样本数
|
||||
sample_weights = sample_weights * len(targets) / sample_weights.sum()
|
||||
# 应用权重到目标概率
|
||||
targets_prob = targets_prob * sample_weights
|
||||
|
||||
# 3. 计算预测值的对数概率分布 (log_softmax 数值上比 log(softmax) 更稳定)
|
||||
preds_log_prob = F.log_softmax(preds, dim=0)
|
||||
|
||||
# 4. 计算交叉熵损失 (在 Batch/组 维度求和)
|
||||
loss = -torch.sum(targets_prob * preds_log_prob, dim=0) # [EnsembleSize]
|
||||
|
||||
# 5. 对所有集成成员的 Loss 取平均
|
||||
return loss.mean()
|
||||
|
||||
|
||||
class EnsembleLambdaLoss(nn.Module):
|
||||
"""集成 LambdaLoss (支持 TabM 集成维度)
|
||||
|
||||
基于 Pairwise 排序损失,引入 DeltaNDCG 权重。
|
||||
参考: "The LambdaLoss Framework for Ranking Metric Optimization" (Google Research)
|
||||
|
||||
特点:
|
||||
- 计算每对样本交换位置后对 NDCG 的影响
|
||||
- 对头部样本(高 Gain 且排名靠前)给予更高权重
|
||||
- 更适合 Top-K 选股场景
|
||||
"""
|
||||
|
||||
def __init__(self, sigma: float = 1.0, ndcg_weight_power: float = 1.0):
|
||||
"""初始化 LambdaLoss
|
||||
|
||||
Args:
|
||||
sigma: Sigmoid 函数的陡峭程度,控制梯度大小
|
||||
ndcg_weight_power: DeltaNDCG 权重幂次,>1 时进一步放大头部效应
|
||||
"""
|
||||
super().__init__()
|
||||
self.sigma = sigma
|
||||
self.ndcg_weight_power = ndcg_weight_power
|
||||
|
||||
def forward(self, preds: torch.Tensor, targets: torch.Tensor):
|
||||
"""计算 LambdaLoss
|
||||
|
||||
Args:
|
||||
preds: [BatchSize, EnsembleSize] - 预测分
|
||||
targets: [BatchSize] - 相关性标签 (Gain)
|
||||
|
||||
Returns:
|
||||
标量损失值
|
||||
"""
|
||||
if preds.size(0) <= 1:
|
||||
return torch.tensor(0.0, requires_grad=True, device=preds.device)
|
||||
|
||||
# 1. 计算两两对之间的差值
|
||||
preds_diff = preds.unsqueeze(1) - preds.unsqueeze(0) # [B, B, E]
|
||||
|
||||
# 【性能优化】: target_diff 不需要 E 维度!直接保持 [B, B]
|
||||
target_diff = targets.unsqueeze(1) - targets.unsqueeze(0) # [B, B]
|
||||
|
||||
# 2. 掩码矩阵: [B, B, 1] 方便后续广播
|
||||
mask = (target_diff > 0).float().unsqueeze(2) # [B, B, 1]
|
||||
|
||||
# 3. 计算 Delta NDCG
|
||||
# DeltaNDCG = |Gain_i - Gain_j| * |1/log(rank_i+1) - 1/log(rank_j+1)|
|
||||
with torch.no_grad():
|
||||
# 【性能核爆优化】: 使用两次 argsort 完全消灭 for 循环
|
||||
# 第1次 argsort: 获得从大到小的索引
|
||||
# 第2次 argsort: 直接将索引反转为排名 (加上 1 就是真实名次)
|
||||
ranks = preds.argsort(dim=0, descending=True).argsort(dim=0) + 1
|
||||
ranks = ranks.float() # [B, E]
|
||||
|
||||
# 计算位置惩罚项 (log2 排名)
|
||||
log_rank = torch.log2(ranks + 1)
|
||||
inv_log_rank_diff = torch.abs(
|
||||
1.0 / log_rank.unsqueeze(1) - 1.0 / log_rank.unsqueeze(0)
|
||||
) # [B, B, E]
|
||||
|
||||
# DeltaNDCG = |Gain 差| * |位置惩罚差|
|
||||
# target_diff 是 [B, B],通过 unsqueeze 广播到 [B, B, 1]
|
||||
delta_ndcg = torch.abs(target_diff).unsqueeze(2) * inv_log_rank_diff
|
||||
|
||||
# 应用幂次调整权重分布
|
||||
if self.ndcg_weight_power != 1.0:
|
||||
delta_ndcg = torch.pow(delta_ndcg, self.ndcg_weight_power)
|
||||
|
||||
# 4. Pairwise Logistic Loss 并加权 Delta NDCG
|
||||
# loss = delta_ndcg * log(1 + exp(-sigma * (preds_i - preds_j)))
|
||||
pairwise_loss = F.binary_cross_entropy_with_logits(
|
||||
self.sigma * preds_diff,
|
||||
torch.ones_like(preds_diff),
|
||||
reduction="none",
|
||||
)
|
||||
|
||||
# 5. 应用掩码和权重
|
||||
weighted_loss = pairwise_loss * mask * delta_ndcg
|
||||
|
||||
# 避免全 0 导致的除零错误
|
||||
valid_pairs_count = mask.sum().clamp(min=1.0)
|
||||
mean_loss = weighted_loss.sum() / valid_pairs_count
|
||||
|
||||
return mean_loss
|
||||
|
||||
|
||||
@register_model("tabm_rank")
class TabMRankModel(BaseModel):
    """TabM learning-to-rank model.

    A ranking model built on the TabM architecture with ListNet-style losses.
    Intended for cross-sectional stock ranking, where future returns are
    converted to quantile labels before training.

    Features:
    - ListNet listwise ranking loss
    - Grouped training via a ``group`` parameter
    - NDCG as the validation metric
    - Interface-compatible with LightGBMLambdaRank
    """

    name = "tabm_rank"

    def __init__(self, params: Optional[Dict[str, Any]] = None):
        """Initialize the TabM Rank model.

        Args:
            params: parameter dict with the following keys:
                - ensemble_size: ensemble size (default: 32)
                - n_blocks: number of MLP blocks (default: 3)
                - d_block: units per block (default: 256)
                - dropout: dropout rate (default: 0.1)
                - batch_size: batch size (default: 2048, used only at predict time)
                - learning_rate: learning rate (default: 1e-3)
                - weight_decay: weight decay (default: 1e-5)
                - epochs: number of training epochs (default: 50)
                - early_stopping_round: early-stopping patience (default: 10)
                - max_grad_norm: gradient-clipping threshold (default: 1.0)
                - ndcg_k: k for NDCG@k; None means global NDCG (default: None)
                - loss_type: loss function type (default: "listnet")
                    - "listnet": standard ListNet loss
                    - "weighted_listnet": weighted ListNet; topk_weight boosts the head
                    - "lambda": LambdaLoss weighted by DeltaNDCG
                - topk_weight: head-sample weight for weighted_listnet (default: 5.0)
                - lambda_sigma: sigma parameter for LambdaLoss (default: 1.0)
                - ndcg_weight_power: power applied to DeltaNDCG weights (default: 1.0)
        """
        self.params = params or {}
        self.model = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.training_history_: Dict[str, List[float]] = {
            "train_loss": [],
            "val_ndcg": [],
        }
        self.feature_names_: Optional[List[str]] = None

        # Choose the loss function from the configuration
        loss_type = self.params.get("loss_type", "listnet")
        if loss_type == "lambda":
            self.criterion = EnsembleLambdaLoss(
                sigma=self.params.get("lambda_sigma", 1.0),
                ndcg_weight_power=self.params.get("ndcg_weight_power", 1.0),
            )
        elif loss_type == "weighted_listnet":
            self.criterion = EnsembleListNetLoss(
                topk_weight=self.params.get("topk_weight", 5.0)
            )
        else:  # "listnet"
            self.criterion = EnsembleListNetLoss(topk_weight=1.0)

    def _make_loader(
        self,
        X: np.ndarray,
        y: Optional[np.ndarray] = None,
        group: Optional[np.ndarray] = None,
        shuffle_groups: bool = False,
    ) -> DataLoader:
        """Create a DataLoader (supports query/group cross-sectional batching).

        Args:
            X: feature array [N, n_features]
            y: label array [N], or None
            group: group array giving the number of samples per group
            shuffle_groups: whether to shuffle group order

        Returns:
            A DataLoader instance.
        """
        # Performance note: when GPU memory allows, push the full dataset onto
        # the device up front to avoid per-batch PCIe transfers during training.
        X_tensor = torch.from_numpy(X).to(self.device)

        if y is not None:
            y_tensor = torch.from_numpy(y).to(self.device)
            dataset = TensorDataset(X_tensor, y_tensor)
        else:
            dataset = TensorDataset(X_tensor)

        if group is not None:
            # For training/validation, use GroupSampler: each batch is one query
            sampler = GroupSampler(group, shuffle_groups=shuffle_groups)
            return DataLoader(dataset, batch_sampler=sampler)
        else:
            # At predict time without groups, fall back to plain batching
            batch_size = self.params.get("batch_size", 2048)
            return DataLoader(dataset, batch_size=batch_size, shuffle=False)

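The `GroupSampler` used above is defined elsewhere in the codebase; as a sketch of the contract `_make_loader` relies on, a minimal batch sampler that yields one query (one trading day) per batch could look like this (class name and details are illustrative, not the project's implementation):

```python
import random
from typing import Iterator, List


class GroupBatchSampler:
    """Minimal sketch of a batch sampler where each batch is one query/day.

    `group` follows the LightGBM convention: group[i] is the number of
    consecutive samples belonging to query i.
    """

    def __init__(self, group: List[int], shuffle_groups: bool = False):
        self.boundaries = []
        start = 0
        for size in group:
            self.boundaries.append((start, start + size))
            start += size
        self.shuffle_groups = shuffle_groups

    def __iter__(self) -> Iterator[List[int]]:
        spans = list(self.boundaries)
        if self.shuffle_groups:
            random.shuffle(spans)  # shuffle day order, never rows within a day
        for start, end in spans:
            yield list(range(start, end))

    def __len__(self) -> int:
        return len(self.boundaries)


batches = list(GroupBatchSampler([2, 3]))
```

Passing such an object as `batch_sampler=` keeps each cross-section intact, which listwise losses require.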
    def _validate_ndcg(self, val_loader: DataLoader, k: Optional[int] = None) -> float:
        """Validate the model using the NDCG ranking metric.

        Args:
            val_loader: validation data loader
            k: k for NDCG@k; None computes global NDCG

        Returns:
            Mean NDCG score.
        """
        from sklearn.metrics import ndcg_score

        assert self.model is not None, "模型未训练,无法验证"

        self.model.eval()
        ndcg_list = []

        with torch.no_grad():
            for batch in val_loader:
                if len(batch) != 2:
                    continue

                bx, by = batch
                bx = bx.to(self.device)
                by = by.cpu().numpy()

                if len(by) <= 1:
                    continue

                outputs = self.model(bx)  # [B, E, 1]
                preds = outputs.mean(dim=1).squeeze(-1).cpu().numpy()  # [B]

                try:
                    # ndcg_score expects 2-D arrays of shape (1, n_samples)
                    score = ndcg_score([by], [preds], k=k)
                    ndcg_list.append(score)
                except ValueError:
                    pass

        return float(np.mean(ndcg_list)) if len(ndcg_list) > 0 else 0.0

    def fit(
        self,
        X: pl.DataFrame,
        y: pl.Series,
        group: Optional[np.ndarray] = None,
        eval_set: Optional[Tuple] = None,
    ) -> "TabMRankModel":
        """Train the ranking model.

        Args:
            X: training feature DataFrame
            y: training labels (Polars Series); expected to be quantile labels (0, 1, 2, ...)
            group: group array giving the number of samples per query
            eval_set: validation tuple (X_val, y_val, group_val) for early stopping

        Returns:
            self (supports chaining)

        Raises:
            ValueError: if the group parameter is invalid
        """
        self.feature_names_ = list(X.columns)

        X_np = X.to_numpy().astype(np.float32)
        y_np = y.to_numpy().astype(np.float32)

        # Validate and normalize the group parameter
        if group is None:
            group = np.array([len(y_np)])
        if group.sum() != len(y_np):
            raise ValueError(
                f"group 数组的和 ({group.sum()}) 必须等于样本数 ({len(y_np)})"
            )

        train_loader = self._make_loader(X_np, y_np, group=group, shuffle_groups=True)

        val_loader = None
        if eval_set is not None:
            X_val, y_val, group_val = eval_set
            X_val_np = (
                X_val.to_numpy().astype(np.float32)
                if isinstance(X_val, pl.DataFrame)
                else X_val
            )
            y_val_np = (
                y_val.to_numpy().astype(np.float32)
                if isinstance(y_val, pl.Series)
                else y_val
            )

            if group_val is None:
                group_val = np.array([len(y_val_np)])
            val_loader = self._make_loader(
                X_val_np, y_val_np, group=group_val, shuffle_groups=False
            )

        ensemble_size = self.params.get("ensemble_size", 32)
        n_features = X_np.shape[1]

        # Build the TabM model
        self.model = TabM.make(
            n_num_features=n_features,
            cat_cardinalities=[],
            d_out=1,
            n_blocks=self.params.get("n_blocks", 3),
            d_block=self.params.get("d_block", 256),
            dropout=self.params.get("dropout", 0.1),
            k=ensemble_size,
        ).to(self.device)

        optimizer = optim.AdamW(
            self.model.parameters(),
            lr=self.params.get("learning_rate", 1e-3),
            weight_decay=self.params.get("weight_decay", 1e-5),
        )

        epochs = self.params.get("epochs", 50)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=epochs, eta_min=1e-6
        )

        early_stopping_patience = self.params.get(
            "early_stopping_patience",
            self.params.get("early_stopping_round", 10),
        )
        best_val_ndcg = -float("inf")
        patience_counter = 0
        best_model_state = None
        ndcg_k = self.params.get("ndcg_k", None)  # None means global NDCG

        print(f"[TabMRank] 开始训练... 设备: {self.device}, 集成大小: {ensemble_size}")

        for epoch in range(epochs):
            # Training phase
            self.model.train()
            train_loss = 0.0
            n_train_batches = 0

            for batch in train_loader:
                if len(batch) != 2:
                    continue
                bx, by = batch[0], batch[1]

                optimizer.zero_grad()
                outputs = self.model(bx)  # [B, E, 1]
                outputs_squeezed = outputs.squeeze(-1)  # [B, E]

                # ListNet-style ranking loss
                loss = self.criterion(outputs_squeezed, by)
                loss.backward()

                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    max_norm=self.params.get("max_grad_norm", 1.0),
                )
                optimizer.step()

                train_loss += loss.item()
                n_train_batches += 1

            avg_train_loss = train_loss / max(n_train_batches, 1)
            self.training_history_["train_loss"].append(avg_train_loss)

            # Validation phase (NDCG-based)
            if val_loader is not None:
                val_ndcg = self._validate_ndcg(val_loader, k=ndcg_k)
                self.training_history_["val_ndcg"].append(val_ndcg)

                if val_ndcg > best_val_ndcg:
                    best_val_ndcg = val_ndcg
                    patience_counter = 0
                    best_model_state = {
                        k: v.cpu().clone() for k, v in self.model.state_dict().items()
                    }
                else:
                    patience_counter += 1

                if (epoch + 1) % 5 == 0 or epoch == 0:
                    print(
                        f"[TabMRank] Epoch {epoch + 1}/{epochs} | "
                        f"Train Loss (ListNet): {avg_train_loss:.4f} | "
                        f"Val NDCG: {val_ndcg:.4f} (Best: {best_val_ndcg:.4f})"
                    )

                if patience_counter >= early_stopping_patience:
                    print(f"[TabMRank] 触发早停,停止于 epoch {epoch + 1}")
                    break
            else:
                if (epoch + 1) % 5 == 0 or epoch == 0:
                    print(
                        f"[TabMRank] Epoch {epoch + 1}/{epochs} | "
                        f"Train Loss: {avg_train_loss:.4f}"
                    )

            scheduler.step()

        # Restore the best weights
        if best_model_state is not None:
            self.model.load_state_dict(best_model_state)
            print(f"[TabMRank] 已恢复最佳模型权重 (Val NDCG: {best_val_ndcg:.4f})")

        return self

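The patience bookkeeping in the loop above reduces to a small standalone sketch (the validation scores below are invented for illustration):

```python
# Hypothetical validation NDCG per epoch.
val_scores = [0.50, 0.55, 0.54, 0.57, 0.56, 0.55, 0.54, 0.53]
patience = 2

best_score, best_epoch, wait = float("-inf"), -1, 0
stopped_at = None
for epoch, score in enumerate(val_scores):
    if score > best_score:   # strict improvement resets the counter
        best_score, best_epoch, wait = score, epoch, 0
    else:
        wait += 1
        if wait >= patience:  # no improvement for `patience` epochs in a row
            stopped_at = epoch
            break
```

In `fit` the "checkpoint" taken at each new best is a CPU copy of the full `state_dict`, restored after the loop, so the returned model corresponds to `best_epoch` rather than the last epoch trained.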
    def predict(
        self, X: pl.DataFrame, group: Optional[np.ndarray] = None
    ) -> np.ndarray:
        """Predict ranking scores.

        Args:
            X: feature matrix (Polars DataFrame)
            group: group array giving the number of samples per query.
                If provided, GroupSampler is used so the prediction order
                matches the grouping.

        Returns:
            Predicted scores (numpy ndarray).

        Raises:
            RuntimeError: if called before the model is trained
            ValueError: if the prediction data is missing features
        """
        if self.model is None:
            raise RuntimeError("模型未训练,请先调用fit()")

        # Feature-alignment check
        if self.feature_names_:
            missing_cols = [c for c in self.feature_names_ if c not in X.columns]
            if missing_cols:
                raise ValueError(f"预测数据缺失特征: {missing_cols}")
            X = X.select(self.feature_names_)

        X_np = X.to_numpy().astype(np.float32)
        loader = self._make_loader(X_np, group=group, shuffle_groups=False)

        self.model.eval()
        all_preds = []

        with torch.no_grad():
            for batch in loader:
                bx = batch[0].to(self.device)
                outputs = self.model(bx)  # [B, E, 1]
                # At predict time, average the ensemble members for the final score
                preds = outputs.mean(dim=1).squeeze(-1)  # [B]
                all_preds.append(preds.cpu().numpy())

        return np.concatenate(all_preds)

    def get_evals_result(self) -> Optional[Dict[str, List[float]]]:
        """Return training evaluation results.

        Returns:
            Dict of evaluation results with train_loss and val_ndcg.
        """
        return self.training_history_

    def feature_importance(self) -> None:
        """Return feature importances.

        TabM has no built-in feature-importance computation, so this returns None.
        """
        return None

    def save(self, path: str | Path) -> None:
        """Save the model.

        Args:
            path: save path

        Raises:
            RuntimeError: if called before the model is trained
        """
        if self.model is None:
            raise RuntimeError("模型未训练,无法保存")

        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)

        # Save model weights
        model_path = path.with_suffix(".pt")
        torch.save(self.model.state_dict(), model_path)

        # Save metadata
        meta_path = path.with_suffix(".meta")
        meta = {
            "params": self.params,
            "feature_names": self.feature_names_,
            "training_history": self.training_history_,
            "device": str(self.device),
        }
        with open(meta_path, "wb") as f:
            pickle.dump(meta, f)

        print(f"[TabMRank] 模型保存到: {path}")

    @classmethod
    def load(cls, path: str | Path) -> "TabMRankModel":
        """Load a model.

        Args:
            path: model path (without extension)

        Returns:
            The loaded TabMRankModel instance.
        """
        path = Path(path)

        # Load metadata
        meta_path = path.with_suffix(".meta")
        with open(meta_path, "rb") as f:
            meta = pickle.load(f)

        # Create the instance
        instance = cls(meta["params"])
        instance.feature_names_ = meta["feature_names"]
        instance.training_history_ = meta["training_history"]

        # Rebuild the model architecture
        if instance.feature_names_ is not None:
            n_features = len(instance.feature_names_)
            ensemble_size = instance.params.get("ensemble_size", 32)

            instance.model = TabM.make(
                n_num_features=n_features,
                cat_cardinalities=[],
                d_out=1,
                n_blocks=instance.params.get("n_blocks", 3),
                d_block=instance.params.get("d_block", 256),
                dropout=instance.params.get("dropout", 0.1),
                k=ensemble_size,
            ).to(instance.device)

            # Load the weights
            model_path = path.with_suffix(".pt")
            instance.model.load_state_dict(
                torch.load(model_path, map_location=instance.device)
            )

        print(f"[TabMRank] 模型从 {path} 加载完成")
        return instance

    @staticmethod
    def prepare_group_from_dates(
        df: pl.DataFrame,
        date_col: str = "trade_date",
    ) -> np.ndarray:
        """Build a group array from a date column.

        Args:
            df: DataFrame containing the date column
            date_col: date column name, defaults to "trade_date"

        Returns:
            The group array.
        """
        group_counts = df.group_by(date_col, maintain_order=True).agg(
            pl.len().alias("count")
        )
        return group_counts["count"].to_numpy()

    @staticmethod
    def convert_to_quantile_labels(
        df: pl.DataFrame,
        label_col: str,
        date_col: str = "trade_date",
        n_quantiles: int = 20,
        new_col_name: Optional[str] = None,
    ) -> pl.DataFrame:
        """Convert a continuous label into quantile labels.

        Args:
            df: input DataFrame
            label_col: original label column name
            date_col: date column name, defaults to "trade_date"
            n_quantiles: number of quantiles, defaults to 20
            new_col_name: new column name, defaults to {label_col}_rank

        Returns:
            The DataFrame with the quantile-label column added.
        """
        if new_col_name is None:
            new_col_name = f"{label_col}_rank"

        return df.with_columns(
            pl.col(label_col)
            .qcut(n_quantiles)
            .over(date_col)
            .to_physical()
            .cast(pl.Int64)
            .alias(new_col_name)
        )

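For intuition, the per-day bucketing can be sketched in plain NumPy. This mirrors the rank-based `floor(rank / n * n_quantiles)` scheme used by `TabMRankTask.prepare_labels`; the `qcut`-based helper above may treat ties and duplicate bucket edges differently:

```python
import numpy as np


def quantile_labels(values: np.ndarray, n_quantiles: int = 20) -> np.ndarray:
    """Rank-based quantile labels for one cross-section (one trading day).

    Worst names get label 0, best names get n_quantiles - 1.
    Ties are broken ordinally here, unlike rank(method="min").
    """
    n = len(values)
    order = values.argsort(kind="stable")
    rank = np.empty(n, dtype=np.int64)
    rank[order] = np.arange(n)  # 0-based dense position, ascending by value
    labels = np.floor(rank / n * n_quantiles).astype(np.int64)
    return np.clip(labels, 0, n_quantiles - 1)


labels = quantile_labels(np.array([0.03, -0.01, 0.10, 0.00]), n_quantiles=4)
```

With four names and four buckets each name lands in its own bucket, ordered by return.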
    def evaluate_ndcg(
        self,
        X: pl.DataFrame,
        y: pl.Series,
        group: np.ndarray,
        k: Optional[int] = None,
    ) -> float:
        """Evaluate the NDCG metric.

        Args:
            X: feature matrix
            y: ground-truth labels
            group: group array
            k: k for NDCG@k

        Returns:
            NDCG score.
        """
        from sklearn.metrics import ndcg_score

        y_pred = self.predict(X)
        y_np = y.to_numpy()

        y_true_list = []
        y_score_list = []

        start_idx = 0
        for group_size in group:
            end_idx = start_idx + group_size
            y_true_list.append(y_np[start_idx:end_idx])
            y_score_list.append(y_pred[start_idx:end_idx])
            start_idx = end_idx

        ndcg_scores = []
        for y_true, y_score in zip(y_true_list, y_score_list):
            if len(y_true) > 1:
                try:
                    score = ndcg_score([y_true], [y_score], k=k)
                    ndcg_scores.append(score)
                except ValueError:
                    pass

        return float(np.mean(ndcg_scores)) if ndcg_scores else 0.0

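`evaluate_ndcg` delegates the metric to `sklearn.metrics.ndcg_score`; for reference, a minimal NDCG@k with the same linear-gain, log2-discount convention can be written directly (a sketch, not the project's code):

```python
import numpy as np


def ndcg_at_k(y_true: np.ndarray, y_score: np.ndarray, k: int) -> float:
    """Plain NDCG@k with linear gains and 1/log2(rank+1) discounts."""
    order = np.argsort(-y_score)            # predicted ranking, best first
    ideal = np.sort(y_true)[::-1]           # ideal ranking of the true gains
    n = min(k, len(y_true))
    disc = 1.0 / np.log2(np.arange(2, n + 2))
    dcg = float(np.sum(y_true[order][:n] * disc))
    idcg = float(np.sum(ideal[:n] * disc))
    return dcg / idcg if idcg > 0 else 0.0


perfect = ndcg_at_k(np.array([3.0, 2.0, 1.0, 0.0]), np.array([0.9, 0.5, 0.3, 0.1]), 2)
reversed_ = ndcg_at_k(np.array([3.0, 2.0, 1.0, 0.0]), np.array([0.1, 0.3, 0.5, 0.9]), 2)
```

A perfect ranking scores 1.0; a fully reversed ranking of this group scores well under 0.2, since the two top-gain items fall outside the top 2.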
@@ -289,7 +289,7 @@ class StandardScaler(BaseProcessor):
         return self
 
     def transform(self, X: pl.DataFrame) -> pl.DataFrame:
-        """Standardize (using parameters learned on the training set, with NaN protection)
+        """Standardize (using parameters learned on the training set)
 
         Args:
             X: data to transform
@@ -302,18 +302,11 @@ class StandardScaler(BaseProcessor):
             if col in self.mean_ and col in self.std_:
                 # Avoid division by zero
                 std_val = self.std_[col] if self.std_[col] != 0 else 1.0
-                # Critical fix: add fill_nan(0) insurance against NaN from the computation
-                expr = (
-                    ((pl.col(col) - self.mean_[col]) / std_val)
-                    .fill_nan(0)
-                    .fill_null(0)
-                    .alias(col)
-                )
+                expr = ((pl.col(col) - self.mean_[col]) / std_val).alias(col)
                 expressions.append(expr)
             elif col in self.feature_cols:
-                # Columns that should be processed but have no learned statistics:
-                # cast to float and fill both NaN and null
-                expr = pl.col(col).cast(pl.Float64).fill_nan(0).fill_null(0).alias(col)
+                # Columns that should be processed but have no learned statistics: cast to float
+                expr = pl.col(col).cast(pl.Float64).alias(col)
                 expressions.append(expr)
             else:
                 expressions.append(pl.col(col))
@@ -372,20 +365,14 @@ class CrossSectionalStandardScaler(BaseProcessor):
             if col in self.feature_cols and X[col].dtype.is_numeric():
                 # Cross-sectional standardization: mean/std computed per day
                 # Avoid division by zero: a std of 0 is padded by epsilon
-                # Critical fix: fill_nan then fill_null to guard against NaN from the computation
                 expr = (
                     (
                         (pl.col(col) - pl.col(col).mean().over(self.date_col))
                         / (pl.col(col).std().over(self.date_col) + 1e-10)
                     )
-                    .fill_nan(0)
-                    .fill_null(0)
-                    .alias(col)
-                )
+                ).alias(col)
                 expressions.append(expr)
             elif col in self.feature_cols:
-                # Columns that should be processed but have a mismatched type: cast to float and fill NaN and null
-                expr = pl.col(col).cast(pl.Float64).fill_nan(0).fill_null(0).alias(col)
+                # Columns that should be processed but have a mismatched type: cast to float
+                expr = pl.col(col).cast(pl.Float64).alias(col)
                 expressions.append(expr)
             else:
                 expressions.append(pl.col(col))
@@ -488,8 +475,8 @@ class Winsorizer(BaseProcessor):
                 expressions.append(expr)
             elif col in self.feature_cols:
                 # Columns that should be processed but have no learned bounds (all-NaN, boolean, etc.)
-                # Cast to float and fill with 0
-                expr = pl.col(col).cast(pl.Float64).fill_null(0).alias(col)
+                # Cast to float
+                expr = pl.col(col).cast(pl.Float64).alias(col)
                 expressions.append(expr)
             else:
                 expressions.append(pl.col(col))
@@ -522,11 +509,9 @@ class Winsorizer(BaseProcessor):
         clip_exprs = []
         for col in X.columns:
             if col in target_cols:
-                # Clip with the day's quantiles; fill 0 when the quantile is null (all-NaN day)
                 clipped = (
                     pl.col(col)
                     .clip(pl.col(f"{col}_lower"), pl.col(f"{col}_upper"))
-                    .fill_null(0)
                     .alias(col)
                 )
                 clip_exprs.append(clipped)
 
@@ -13,6 +13,7 @@ from src.factors import FactorEngine
 from src.training.pipeline import DataPipeline
 from src.training.tasks.base import BaseTask
 from src.training.result_analyzer import ResultAnalyzer
+from src.training.data_quality_analyzer import DataQualityAnalyzer
 
 
 class Trainer:
@@ -100,8 +101,6 @@ class Trainer:
         print("\n[Step 1.5/7] 数据质量分析...")
 
         try:
-            from src.experiment.data_quality_analyzer import DataQualityAnalyzer
-
            # Feature column names (taken from the training split)
             feature_cols = data["train"].get("feature_cols", [])
             label_name = self.task.label_name
 
@@ -1,6 +1,7 @@
 """Data-quality analysis module
 
 Provides data-quality checks, including:
+- Dataset date-range information
 - Missing-value statistics
 - Zero-value statistics
 - Per-date all-null column checks
@@ -19,6 +20,7 @@ class DataQualityAnalyzer:
     Attributes:
         feature_cols: list of feature column names
         label_col: label column name
+        date_col: date column name
         verbose: whether to print details
     """
 
@@ -26,6 +28,7 @@ class DataQualityAnalyzer:
         self,
         feature_cols: Optional[List[str]] = None,
         label_col: Optional[str] = None,
+        date_col: str = "trade_date",
         verbose: bool = True,
     ):
         """Initialize the data-quality analyzer
@@ -33,10 +36,12 @@ class DataQualityAnalyzer:
         Args:
             feature_cols: list of feature column names
             label_col: label column name
+            date_col: date column name, defaults to "trade_date"
             verbose: whether to print details
         """
         self.feature_cols = feature_cols or []
         self.label_col = label_col
+        self.date_col = date_col
         self.verbose = verbose
         self.analysis_results: Dict[str, Any] = {}
 
@@ -74,6 +79,10 @@ class DataQualityAnalyzer:
 
         self.analysis_results = {}
 
+        # Print the dataset overview first (date ranges and other basics)
+        if self.verbose:
+            self._print_dataset_overview(data, split_names)
+
         for split_name in split_names:
             if split_name not in data:
                 continue
@@ -96,6 +105,104 @@ class DataQualityAnalyzer:
 
         return self.analysis_results
 
+    def _print_dataset_overview(
+        self,
+        data: Dict[str, Dict[str, Any]],
+        split_names: List[str],
+    ) -> None:
+        """Print dataset overview information.
+
+        Includes start date, end date, and sample counts for each split.
+
+        Args:
+            data: data dict
+            split_names: list of split names
+        """
+        print("\n[数据集概览]")
+        print("-" * 40)
+
+        overview_data = []
+
+        for split_name in split_names:
+            if split_name not in data:
+                continue
+
+            split_data = data[split_name]
+            raw_df = split_data.get("raw_data")
+
+            if raw_df is None or len(raw_df) == 0:
+                overview_data.append(
+                    {
+                        "划分": split_name.upper(),
+                        "起始日期": "-",
+                        "终止日期": "-",
+                        "样本数": 0,
+                        "股票数": 0,
+                    }
+                )
+                continue
+
+            # Date range
+            if self.date_col in raw_df.columns:
+                dates = raw_df[self.date_col]
+                start_date = dates.min()
+                end_date = dates.max()
+                unique_dates = dates.n_unique()
+            else:
+                start_date = "-"
+                end_date = "-"
+                unique_dates = 0
+
+            # Number of stocks
+            if "ts_code" in raw_df.columns:
+                unique_stocks = raw_df["ts_code"].n_unique()
+            else:
+                unique_stocks = 0
+
+            overview_data.append(
+                {
+                    "划分": split_name.upper(),
+                    "起始日期": str(start_date),
+                    "终止日期": str(end_date),
+                    "交易日数": unique_dates,
+                    "样本数": len(raw_df),
+                    "股票数": unique_stocks,
+                }
+            )
+
+        # Print the table
+        if overview_data:
+            # Column widths
+            headers = ["划分", "起始日期", "终止日期", "交易日数", "样本数", "股票数"]
+            col_widths = {}
+            for header in headers:
+                max_data_len = max(
+                    len(str(row.get(header, ""))) for row in overview_data
+                )
+                col_widths[header] = max(len(header), max_data_len) + 2
+
+            # Header row
+            header_line = "  ".join(h.ljust(col_widths[h]) for h in headers)
+            print(f"  {header_line}")
+            print(f"  {'-' * (sum(col_widths.values()) + 2 * (len(headers) - 1))}")
+
+            # Data rows
+            for row in overview_data:
+                line = "  ".join(
+                    str(row.get(h, "")).ljust(col_widths[h]) for h in headers
+                )
+                print(f"  {line}")
+
     def _analyze_split(
         self,
         df: pl.DataFrame,
@@ -120,6 +227,16 @@ class DataQualityAnalyzer:
             "all_null_by_date": {},
         }
 
+        # Date range
+        if self.date_col in df.columns:
+            results["start_date"] = str(df[self.date_col].min())
+            results["end_date"] = str(df[self.date_col].max())
+            results["unique_dates"] = df[self.date_col].n_unique()
+
+        # Number of stocks
+        if "ts_code" in df.columns:
+            results["unique_stocks"] = df["ts_code"].n_unique()
+
         # 1. Analyze missing values in the feature columns
         null_stats = self._analyze_null_values(df, self.feature_cols)
         results["null_analysis"] = null_stats
@@ -249,7 +366,7 @@ class DataQualityAnalyzer:
             "issues": [],
         }
 
-        if "trade_date" not in df.columns:
+        if self.date_col not in df.columns:
             return results
 
         # Filter out columns not present in the table
@@ -267,7 +384,7 @@ class DataQualityAnalyzer:
         ]
         agg_exprs.append(pl.len().alias("total_rows"))
 
-        agg_lf = lf.group_by("trade_date").agg(agg_exprs)
+        agg_lf = lf.group_by(self.date_col).agg(agg_exprs)
 
         # Collect the result (agg_df usually has only a few hundred to a few thousand rows)
        agg_df = agg_lf.collect()
@@ -279,13 +396,13 @@ class DataQualityAnalyzer:
             # Find dates where the null count equals the total row count
             bad_dates = agg_df.filter(
                 (pl.col(null_col) == pl.col("total_rows")) & (pl.col("total_rows") > 0)
-            ).select(["trade_date", "total_rows"])
+            ).select([self.date_col, "total_rows"])
 
             if not bad_dates.is_empty():
                 for row in bad_dates.to_dicts():
                     issues.append(
                         {
-                            "date": row["trade_date"],
+                            "date": row[self.date_col],
                             "column": col,
                             "total_rows": row["total_rows"],
                         }
@@ -464,8 +581,20 @@ class DataQualityAnalyzer:
 
         for split_name, results in self.analysis_results.items():
             lines.append(f"\n[{split_name.upper()}]")
 
+            # Date-range information
+            if "start_date" in results and "end_date" in results:
+                lines.append(
+                    f"  日期范围: {results['start_date']} ~ {results['end_date']}"
+                )
+            if "unique_dates" in results:
+                lines.append(f"  交易日数: {results['unique_dates']}")
+
             lines.append(f"  总行数: {results['total_rows']:,}")
 
+            if "unique_stocks" in results:
+                lines.append(f"  股票数: {results['unique_stocks']}")
+
             null_stats = results.get("null_analysis", {})
             if null_stats.get("columns_with_null"):
                 lines.append(
@@ -493,6 +622,7 @@ def analyze_data_quality(
     data: Dict[str, Dict[str, Any]],
     feature_cols: Optional[List[str]] = None,
     label_col: Optional[str] = None,
+    date_col: str = "trade_date",
     verbose: bool = True,
 ) -> Dict[str, Any]:
     """Convenience function: run data-quality analysis
@@ -501,6 +631,7 @@ def analyze_data_quality(
         data: data dict
         feature_cols: list of feature column names
         label_col: label column name
+        date_col: date column name, defaults to "trade_date"
         verbose: whether to print details
 
     Returns:
@@ -509,6 +640,7 @@ def analyze_data_quality(
     analyzer = DataQualityAnalyzer(
         feature_cols=feature_cols,
         label_col=label_col,
+        date_col=date_col,
        verbose=verbose,
     )
     return analyzer.analyze(data)
@@ -7,10 +7,12 @@ from src.training.tasks.base import BaseTask
 from src.training.tasks.regression_task import RegressionTask
 from src.training.tasks.rank_task import RankTask
 from src.training.tasks.tabm_regression_task import TabMRegressionTask
+from src.training.tasks.tabm_rank_task import TabMRankTask
 
 __all__ = [
     "BaseTask",
     "RegressionTask",
     "RankTask",
     "TabMRegressionTask",
+    "TabMRankTask",
 ]
 
@@ -153,7 +153,6 @@ class RankTask(BaseTask):
         if k_list is None:
             k_list = [1, 5, 10, 20]
 
-        y_true = test_data["y_raw"]
         y_pred = self.predict(test_data)
         groups = test_data["groups"]
 
@@ -166,9 +165,13 @@ class RankTask(BaseTask):
         y_true_groups = []
         y_pred_groups = []
 
+        # Use the quantile labels y (0-19) as the ground-truth relevance, not raw returns y_raw.
+        # This matches the model's training target and avoids problems from negative raw returns.
+        y_true_array = test_data["y"].to_numpy()
+
         for group_size in groups:
             end_idx = start_idx + group_size
-            y_true_groups.append(y_true.to_numpy()[start_idx:end_idx])
+            y_true_groups.append(y_true_array[start_idx:end_idx])
             y_pred_groups.append(y_pred[start_idx:end_idx])
             start_idx = end_idx
 
249
src/training/tasks/tabm_rank_task.py
Normal file
@@ -0,0 +1,249 @@
"""TabM learning-to-rank task implementation.

Implements the TabM-based ranking training flow:
- Converts labels to quantile labels
- Builds the group array
- Uses TabMRankModel (ListNet-style losses)
- Supports NDCG@k evaluation
"""

from typing import Any, Dict, List, Optional
import numpy as np
import polars as pl

from src.training.tasks.base import BaseTask
from src.training.components.models.tabm_rank_model import TabMRankModel


class TabMRankTask(BaseTask):
    """TabM learning-to-rank task.

    Trains a TabMRankModel for ranking.
    Continuous returns are converted to quantile labels for training.
    Supports exponential-gain labels to sharpen Top-K focus.
    """

    def __init__(
        self,
        model_params: Dict[str, Any],
        label_name: str = "future_return_5",
        n_quantiles: int = 20,
        label_transform: Optional[str] = None,
        label_scale: float = 20.0,
    ):
        """Initialize the ranking task.

        Args:
            model_params: TabM parameter dict
            label_name: label column name
            n_quantiles: number of quantiles
            label_transform: label transform, one of:
                - None: standard quantile labels (0, 1, ..., n_quantiles-1)
                - "exponential": exponential gain: 2^(rank/scale) - 1
            label_scale: scale factor for the exponential transform, controls gain magnitude
        """
        super().__init__(model_params, label_name)
        self.n_quantiles = n_quantiles
        self.label_transform = label_transform
        self.label_scale = label_scale

    def prepare_labels(self, data: Dict[str, Dict]) -> Dict[str, Dict]:
        """Prepare labels (quantile labels, with optional exponential-gain transform).

        Converts continuous returns to quantile labels and builds the group array.
        The optional exponential-gain transform sharpens the separation of head samples.

        Args:
            data: data dict

        Returns:
            The data dict with y_rank and groups added.
        """
        for split in ["train", "val", "test"]:
            if split not in data:
                continue

            df = data[split]["raw_data"]

            # Quantile conversion
            rank_col = f"{self.label_name}_rank"

            # 1. Base quantile labels (0 to n_quantiles-1)
            df_ranked = df.with_columns(
                pl.col(self.label_name)
                .rank(method="min")
                .over("trade_date")
                .alias("_rank")
            ).with_columns(
                ((pl.col("_rank") - 1) / pl.len().over("trade_date") * self.n_quantiles)
                .floor()
                .cast(pl.Int64)
                .clip(0, self.n_quantiles - 1)
                .alias("_base_rank")
            )

            # 2. [Top-K optimization] optional exponential-gain transform
            if self.label_transform == "exponential":
                # Exponential gain: 2^(rank/scale) - 1, using label_scale.
                # E.g. with scale=20: rank=0 -> 0, rank=19 -> ~0.93, rank=99 -> ~30
                # Effect: the gap between head and tail samples grows exponentially.
                df_ranked = df_ranked.with_columns(
                    (
                        pl.lit(2.0).pow(
                            pl.col("_base_rank").cast(pl.Float64) / self.label_scale
                        )
                        - 1.0
                    ).alias(rank_col)
                )
            else:
                # Standard quantile labels
                df_ranked = df_ranked.with_columns(
                    pl.col("_base_rank").cast(pl.Float64).alias(rank_col)
                )

            # Drop temporary columns
            df_ranked = df_ranked.drop(["_rank", "_base_rank"])

            # Update the data dict
            data[split]["raw_data"] = df_ranked
            data[split]["y"] = df_ranked[rank_col]
            data[split]["y_raw"] = df_ranked[self.label_name]  # keep raw values

            # Build the group array
            data[split]["groups"] = self._compute_group_array(df_ranked, "trade_date")

        return data

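To see what the exponential transform does to the label spacing, here is a tiny sketch of the gain curve (values follow the 2^(rank/scale) - 1 formula from the docstring; `exp_gain` is a hypothetical helper, not part of the codebase):

```python
def exp_gain(rank: int, scale: float = 20.0) -> float:
    """Exponential gain 2**(rank/scale) - 1 used for Top-K emphasis."""
    return 2.0 ** (rank / scale) - 1.0


# With scale=20: the bottom bucket maps to 0, the top of a 20-bucket scheme
# to ~0.93, and the top of a 100-bucket scheme to ~30 -- the head of the
# distribution dominates the total gain.
gains = [round(exp_gain(r), 3) for r in (0, 10, 19, 99)]
```

Raising `label_scale` flattens the curve; lowering it (or raising `n_quantiles`) concentrates even more of the gain mass into the very top buckets.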
def _compute_group_array(
|
||||
self,
|
||||
df: pl.DataFrame,
|
||||
date_col: str = "trade_date",
|
||||
) -> np.ndarray:
|
||||
"""计算 group 数组
|
||||
|
||||
Args:
|
||||
df: 数据框
|
||||
date_col: 日期列名
|
||||
|
||||
Returns:
|
||||
group 数组(每个日期的样本数)
|
||||
"""
|
||||
group_counts = df.group_by(date_col, maintain_order=True).agg(
|
||||
pl.count().alias("count")
|
||||
)
|
||||
return group_counts["count"].to_numpy()
|
||||
|
||||
def fit(self, train_data: Dict, val_data: Dict) -> None:
|
||||
"""训练排序模型
|
||||
|
||||
Args:
|
||||
train_data: 训练数据
|
||||
val_data: 验证数据
|
||||
"""
|
||||
self.model = TabMRankModel(params=self.model_params)
|
||||
|
||||
self.model.fit(
|
||||
train_data["X"],
|
||||
train_data["y"],
|
||||
group=train_data["groups"],
|
||||
eval_set=(val_data["X"], val_data["y"], val_data["groups"])
|
||||
            if val_data
            else None,
        )

    def predict(self, test_data: Dict) -> np.ndarray:
        """Generate predictions.

        Args:
            test_data: Test data.

        Returns:
            Prediction results.
        """
        # Pass the groups parameter so the prediction order matches the
        # grouping, consistent with the validation logic.
        return self.model.predict(test_data["X"], group=test_data.get("groups"))

    def evaluate_ndcg(
        self,
        test_data: Dict,
        k_list: List[int] = None,
    ) -> Dict[str, float]:
        """Evaluate NDCG@k.

        Args:
            test_data: Test data.
            k_list: List of k values; defaults to [1, 5, 10, 20].

        Returns:
            Dict of NDCG scores {"ndcg@1": score, ...}.
        """
        if k_list is None:
            k_list = [1, 5, 10, 20]

        y_pred = self.predict(test_data)
        groups = test_data["groups"]

        from sklearn.metrics import ndcg_score

        results = {}

        # Split by group
        start_idx = 0
        y_true_groups = []
        y_pred_groups = []

        # Use the quantile labels y (0-19) as the true relevance scores rather
        # than the raw returns y_raw: this matches the model's learning target
        # and avoids issues with negative raw returns.
        y_true_array = test_data["y"].to_numpy()

        for group_size in groups:
            end_idx = start_idx + group_size
            y_true_groups.append(y_true_array[start_idx:end_idx])
            y_pred_groups.append(y_pred[start_idx:end_idx])
            start_idx = end_idx

        # Compute NDCG for each k
        for k in k_list:
            ndcg_scores = []
            for yt, yp in zip(y_true_groups, y_pred_groups):
                if len(yt) > 1:
                    try:
                        score = ndcg_score([yt], [yp], k=k)
                        ndcg_scores.append(score)
                    except ValueError:
                        pass

            results[f"ndcg@{k}"] = float(np.mean(ndcg_scores)) if ndcg_scores else 0.0

        return results

    def plot_training_metrics(self) -> None:
        """Plot training metric curves (loss and NDCG)."""
        if self.model and hasattr(self.model, "get_evals_result"):
            try:
                import matplotlib.pyplot as plt

                evals_result = self.model.get_evals_result()
                if not evals_result:
                    print("[Warning] No training metric data available to plot")
                    return

                fig, ax = plt.subplots(1, 2, figsize=(12, 4))

                # Plot training loss
                if "train_loss" in evals_result:
                    ax[0].plot(evals_result["train_loss"], label="Train Loss")
                    ax[0].set_xlabel("Epoch")
                    ax[0].set_ylabel("ListNet Loss")
                    ax[0].set_title("Training Loss")
                    ax[0].legend()
                    ax[0].grid(True)

                # Plot validation NDCG
                if "val_ndcg" in evals_result:
                    ax[1].plot(evals_result["val_ndcg"], label="Val NDCG")
                    ax[1].set_xlabel("Epoch")
                    ax[1].set_ylabel("NDCG")
                    ax[1].set_title("Validation NDCG")
                    ax[1].legend()
                    ax[1].grid(True)

                plt.tight_layout()
                plt.show()
            except Exception as e:
                print(f"[Warning] Unable to plot training curves: {e}")
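For reference, the per-group NDCG evaluation above can be sketched standalone. This is a minimal illustration, not the project API: the helper name `ndcg_by_group` is hypothetical, but it mirrors the same cumulative index walk over `groups` (a list of per-date group sizes) and the same use of `sklearn.metrics.ndcg_score`, including skipping groups with fewer than two items.

```python
import numpy as np
from sklearn.metrics import ndcg_score


def ndcg_by_group(y_true, y_pred, groups, k=10):
    """Mean NDCG@k over per-date groups (hypothetical helper).

    groups holds the size of each group; indices are walked cumulatively,
    so predictions must be in the same order as the grouping.
    """
    scores = []
    start = 0
    for size in groups:
        end = start + size
        yt, yp = y_true[start:end], y_pred[start:end]
        if len(yt) > 1:  # ndcg_score requires at least two documents
            scores.append(ndcg_score([yt], [yp], k=k))
        start = end
    return float(np.mean(scores)) if scores else 0.0


# Two groups of 3; predictions reproduce the true ordering, so NDCG@3 = 1.0
y_true = np.array([2, 1, 0, 0, 1, 2])
y_pred = np.array([0.9, 0.5, 0.1, 0.1, 0.5, 0.9])
print(ndcg_by_group(y_true, y_pred, [3, 3], k=3))  # → 1.0
```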
88
tests/test_tabm_rank_model.py
Normal file
@@ -0,0 +1,88 @@
"""Basic functional tests for TabMRankModel."""

import pytest
import numpy as np
import polars as pl
import torch

from src.training.components.models import TabMRankModel


class TestTabMRankModel:
    """Test suite for TabMRankModel."""

    def test_model_initialization(self):
        """Test model initialization."""
        model = TabMRankModel(params={"ensemble_size": 16})
        assert model.name == "tabm_rank"
        assert model.params["ensemble_size"] == 16
        assert model.device in [torch.device("cpu"), torch.device("cuda")]

    def test_prepare_group_from_dates(self):
        """Test building the group array from trade dates."""
        df = pl.DataFrame(
            {
                "trade_date": [
                    "20240101",
                    "20240101",
                    "20240102",
                    "20240102",
                    "20240102",
                ],
                "value": [1, 2, 3, 4, 5],
            }
        )
        group = TabMRankModel.prepare_group_from_dates(df)
        assert np.array_equal(group, np.array([2, 3]))

    def test_convert_to_quantile_labels(self):
        """Test conversion to quantile labels."""
        df = pl.DataFrame(
            {
                "trade_date": ["20240101"] * 5 + ["20240102"] * 5,
                "return": [0.1, 0.05, 0.0, -0.05, -0.1] + [0.2, 0.1, 0.0, -0.1, -0.2],
            }
        )
        result = TabMRankModel.convert_to_quantile_labels(df, "return", n_quantiles=5)
        assert "return_rank" in result.columns
        assert result["return_rank"].dtype == pl.Int64

    def test_save_load(self, tmp_path):
        """Test model save and load (skipping actual training)."""
        params = {
            "ensemble_size": 8,
            "n_blocks": 2,
            "d_block": 64,
            "epochs": 1,
        }
        model = TabMRankModel(params=params)
        model.feature_names_ = ["feat1", "feat2"]

        # Fake some training history
        model.training_history_["train_loss"] = [0.5, 0.4]
        model.training_history_["val_ndcg"] = [0.3, 0.35]

        # Saving normally requires an initialized network;
        # here we only exercise metadata persistence.
        save_path = tmp_path / "test_model"
        import pickle

        with open(save_path.with_suffix(".meta"), "wb") as f:
            pickle.dump(
                {
                    "params": model.params,
                    "feature_names": model.feature_names_,
                    "training_history": model.training_history_,
                    "device": str(model.device),
                },
                f,
            )

        # Load the metadata back
        loaded_model = TabMRankModel.load(save_path)
        assert loaded_model.params == params
        assert loaded_model.feature_names_ == ["feat1", "feat2"]


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
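The quantile-label conversion exercised by `test_convert_to_quantile_labels` can be sketched with plain numpy. This is a hypothetical helper, not the project's `convert_to_quantile_labels`: it handles a single trade-date group, ranking raw returns ascending and bucketing the ranks into `n_quantiles` integer labels, so the best stock of the day gets label `n_quantiles - 1`.

```python
import numpy as np


def quantile_labels(returns, n_quantiles=5):
    """Map one group's raw returns to integer quantile labels 0..n_quantiles-1.

    Hypothetical sketch: double argsort yields each element's ascending rank
    (0 = worst return, len-1 = best), which is then bucketed into bins.
    """
    ranks = np.argsort(np.argsort(returns))  # ascending rank of each element
    return (ranks * n_quantiles // len(returns)).astype(int)


# Five stocks on one date, as in the test above: labels follow return order
print(quantile_labels(np.array([0.1, 0.05, 0.0, -0.05, -0.1]), n_quantiles=5))
# → [4 3 2 1 0]
```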