diff --git a/src/experiment/README_TABM_TOPK.md b/src/experiment/README_TABM_TOPK.md new file mode 100644 index 0000000..b73f631 --- /dev/null +++ b/src/experiment/README_TABM_TOPK.md @@ -0,0 +1,126 @@ +# TabM Top-K 优化指南 + +针对每日只买入预测分最高的 5-10 只股票的 Top-K 选股场景,对 TabM 排序学习进行了三大方向优化。 + +## 优化概述 + +### 1. 损失函数优化(核心) + +**加权 ListNet (推荐首次尝试)** +- 原理:给予高分标签样本更高的损失权重 +- 参数:`loss_type="weighted_listnet"`, `topk_weight=5.0` +- 效果:头部样本权重是尾部的 5 倍,强迫模型关注头部排序 + +**LambdaLoss (精细 Top-K 优化)** +- 原理:基于 DeltaNDCG 计算每对样本交换位置后的损失 +- 参数:`loss_type="lambda"`, `lambda_sigma=1.0`, `ndcg_weight_power=1.5` +- 效果:精准优化 NDCG@K 指标,适合追求极致 Top-K 性能 + +### 2. 标签工程优化(增强) + +**指数化增益变换** +- 公式:`Gain = 2^(rank/scale) - 1` +- 参数:`label_transform="exponential"`, `label_scale=20.0` +- 效果:rank=0 → 0, rank=19 → ~0.93, rank=99 → ~30.5 +- 用途:拉大高分样本与低分样本的差距,强化头部区分度 + +### 3. 推荐配置组合 + +```python +# 配置 A: 温和优化(推荐首次尝试) +MODEL_PARAMS = { + "loss_type": "weighted_listnet", + "topk_weight": 3.0, + # ... 其他参数 +} +LABEL_TRANSFORM = None # 保持标准分位数 + +# 配置 B: 平衡优化(兼顾效果和稳定性) +MODEL_PARAMS = { + "loss_type": "weighted_listnet", + "topk_weight": 5.0, + "ndcg_k": 20, # 验证时关注 NDCG@20 +} +LABEL_TRANSFORM = "exponential" +LABEL_SCALE = 20.0 + +# 配置 C: 激进优化(专注 Top-10) +MODEL_PARAMS = { + "loss_type": "lambda", + "lambda_sigma": 1.0, + "ndcg_weight_power": 1.5, + "ndcg_k": 10, +} +N_QUANTILES = 50 # 提高分位数分辨率 +LABEL_TRANSFORM = "exponential" +LABEL_SCALE = 25.0 +``` + +## 使用示例 + +在 `tabm_rank_train.py` 中修改配置: + +```python +# 分位数配置 +N_QUANTILES = 30 + +# 标签工程配置 +LABEL_TRANSFORM = "exponential" # 启用指数化增益 +LABEL_SCALE = 20.0 + +# 模型参数配置 +MODEL_PARAMS = { + # ... 基础参数 ... + + # Top-K 优化参数 + "loss_type": "weighted_listnet", # 或 "lambda" + "topk_weight": 5.0, # 仅 weighted_listnet 有效 + "lambda_sigma": 1.0, # 仅 lambda 有效 + "ndcg_weight_power": 1.0, # 仅 lambda 有效 + "ndcg_k": 20, # 验证指标 +} +``` + +## 参数说明 + +### 损失函数参数 + +| 参数 | 类型 | 默认值 | 说明 | +|------|------|--------|------| +| `loss_type` | str | "listnet" | 损失类型: "listnet"/"weighted_listnet"/"lambda" | +| `topk_weight` | float | 5.0 | 头部权重系数 (weighted_listnet),越大越关注头部 | +| `lambda_sigma` | float | 1.0 | Sigmoid 陡峭程度 (lambda) | +| `ndcg_weight_power` | float | 1.0 | DeltaNDCG 权重幂次 (lambda),>1 进一步放大头部 | +| `ndcg_k` | int/None | None | 验证时计算的 NDCG@k,None 表示全局 | + +### 标签工程参数 + +| 参数 | 类型 | 默认值 | 说明 | +|------|------|--------|------| +| `n_quantiles` | int | 20 | 分位数数量,越大分辨率越高 | +| `label_transform` | str/None | None | 变换类型: None/"exponential" | +| `label_scale` | float | 20.0 | 指数变换缩放因子,控制增益幅度 | + +## 实施建议 + +1. **渐进式优化**: + - 第1轮:仅启用 `loss_type="weighted_listnet"`, `topk_weight=3.0` + - 第2轮:增加标签工程 `label_transform="exponential"` + - 第3轮:尝试 `loss_type="lambda"` 精细优化 + +2. **监控指标**: + - 关注 NDCG@K(K 设为实际 Top-K 大小) + - 对比不同配置的回测收益率 + - 观察训练损失是否稳定下降 + +3. **注意事项**: + - LambdaLoss 训练更慢,每 epoch 需更多时间 + - 指数化增益会改变标签分布,可能需要调整学习率 + - 过高的 topk_weight 可能导致过拟合头部样本 + +## 参考论文 + +1. **ListNet**: "Learning to Rank: From Pairwise Approach to Listwise Approach" (Cao et al., 2007) +2. **LambdaRank**: "From RankNet to LambdaRank to LambdaMART: An Overview" (Burges, 2010) +3. **LambdaLoss**: "The LambdaLoss Framework for Ranking Metric Optimization" (Wang et al., 2018) +4. **深度学习选股**: "Deep Learning for Stock Selection" (Gu, Kelly, & Xiu, 2020) diff --git a/src/experiment/common.py b/src/experiment/common.py index 559061e..791ec0d 100644 --- a/src/experiment/common.py +++ b/src/experiment/common.py @@ -270,6 +270,7 @@ SELECTED_FACTORS = [ "bottom_cost_stability", "pivot_reversion", "chip_transition", + # "amivest_liq_20", # "atr_price_impact", # "hui_heubel_ratio", @@ -323,11 +324,11 @@ def get_label_factor(label_name: str) -> dict: # 辅助函数 # ============================================================================= def register_factors( - engine: FactorEngine, - selected_factors: List[str], - factor_definitions: dict, - label_factor: dict, - excluded_factors: Optional[List[str]] = None, + engine: FactorEngine, + selected_factors: List[str], + factor_definitions: dict, + label_factor: dict, + excluded_factors: Optional[List[str]] = None, ) -> List[str]: """注册因子。 @@ -408,11 +409,11 @@ def register_factors( def prepare_data( - engine: FactorEngine, - feature_cols: List[str], - start_date: str, - end_date: str, - label_name: str, + engine: FactorEngine, + feature_cols: List[str], + start_date: str, + end_date: str, + label_name: str, ) -> pl.DataFrame: """准备数据。 @@ -450,45 +451,6 @@ def prepare_data( return data -# ============================================================================= -# 股票池筛选配置 -# ============================================================================= -def stock_pool_filter(df: pl.DataFrame) -> pl.Series: - """股票池筛选函数(单日数据)。 - - 筛选条件: - 1. 排除创业板(代码以 300 开头) - 2. 排除科创板(代码以 688 开头) - 3. 排除北交所(代码以 8、9 或 4 开头) - 4. 选取当日市值最小的500只股票 - - Args: - df: 单日数据框 - - Returns: - 布尔Series,表示哪些股票被选中 - """ - # 代码筛选(排除创业板、科创板、北交所) - code_filter = ( - ~df["ts_code"].str.starts_with("30") # 排除创业板 - & ~df["ts_code"].str.starts_with("68") # 排除科创板 - & ~df["ts_code"].str.starts_with("8") # 排除北交所 - & ~df["ts_code"].str.starts_with("9") # 排除北交所 - & ~df["ts_code"].str.starts_with("4") # 排除北交所 - ) - - # 在已筛选的股票中,选取流通市值最小的500只 - valid_df = df.filter(code_filter) - n = min(1000, len(valid_df)) - small_cap_codes = valid_df.sort("circ_mv").head(n)["ts_code"] - - # 返回布尔 Series:是否在被选中的股票中 - return df["ts_code"].is_in(small_cap_codes) - - -# 定义筛选所需的基础列 -STOCK_FILTER_REQUIRED_COLUMNS = ["circ_mv"] - # ============================================================================= # 输出配置 # ============================================================================= @@ -502,10 +464,54 @@ MODEL_SAVE_DIR = "models" # 模型保存目录 # Top N 配置:每日推荐股票数量 TOP_N = 5 # 可调整为 10, 20 等 +# 股票池大小配置 +STOCK_POOL_SIZE = 1000 # 股票池选择市值最小的股票数量 + # 训练数据跳过天数配置 TRAIN_SKIP_DAYS = 300 # 跳过训练数据前252天的数据,避免训练初期数据不足的问题 +# ============================================================================= +# 股票池筛选配置 +# ============================================================================= +def stock_pool_filter(df: pl.DataFrame, n_stocks: int = STOCK_POOL_SIZE) -> pl.Series: + """股票池筛选函数(单日数据)。 + + 筛选条件: + 1. 排除创业板(代码以 300 开头) + 2. 排除科创板(代码以 688 开头) + 3. 排除北交所(代码以 8、9 或 4 开头) + 4. 选取当日市值最小的n_stocks只股票 + + Args: + df: 单日数据框 + n_stocks: 选取的股票数量,默认为 STOCK_POOL_SIZE + + Returns: + 布尔Series,表示哪些股票被选中 + """ + # 代码筛选(排除创业板、科创板、北交所) + code_filter = ( + ~df["ts_code"].str.starts_with("30") # 排除创业板 + & ~df["ts_code"].str.starts_with("68") # 排除科创板 + & ~df["ts_code"].str.starts_with("8") # 排除北交所 + & ~df["ts_code"].str.starts_with("9") # 排除北交所 + & ~df["ts_code"].str.starts_with("4") # 排除北交所 + ) + + # 在已筛选的股票中,选取流通市值最小的n_stocks只 + valid_df = df.filter(code_filter) + n = min(n_stocks, len(valid_df)) + small_cap_codes = valid_df.sort("circ_mv").head(n)["ts_code"] + + # 返回布尔 Series:是否在被选中的股票中 + return df["ts_code"].is_in(small_cap_codes) + + +# 定义筛选所需的基础列 +STOCK_FILTER_REQUIRED_COLUMNS = ["circ_mv"] + + def get_output_path(model_type: str, test_start: str, test_end: str) -> str: """生成输出文件路径。 @@ -532,7 +538,7 @@ def get_output_path(model_type: str, test_start: str, test_end: str) -> str: def get_model_save_path( - model_type: str, + model_type: str, ) -> Optional[str]: """生成模型保存路径。 @@ -558,11 +564,11 @@ def get_model_save_path( def save_model_with_factors( - model, - model_path: str, - selected_factors: list[str], - factor_definitions: dict, - fitted_processors: list | None = None, + model, + model_path: str, + selected_factors: list[str], + factor_definitions: dict, + fitted_processors: list | None = None, ) -> str: """保存模型及关联的因子信息和处理器。 diff --git a/src/experiment/learn_to_rank.py b/src/experiment/learn_to_rank.py index 7e34e2c..98107eb 100644 --- a/src/experiment/learn_to_rank.py +++ b/src/experiment/learn_to_rank.py @@ -54,27 +54,10 @@ N_QUANTILES = 20 # 排除的因子列表 EXCLUDED_FACTORS = [ - 'active_market_cap', - 'close_vwap_deviation', - 'sharpe_ratio_20', - 'upper_shadow_ratio', - 'volume_ratio_5_20', - 'GTJA_alpha090', - 'GTJA_alpha084', - 'GTJA_alpha066', - 'GTJA_alpha150', - 'GTJA_alpha148', - 'GTJA_alpha106', - 'GTJA_alpha109', - 'GTJA_alpha108', - 'GTJA_alpha176', - 'GTJA_alpha169', - 'GTJA_alpha156', - 'chip_dispersion_70', - 'winner_rate_cs_rank', - 'atr_price_impact', - 'low_vol_days_20', - 'liquidity_shock_momentum', + # 'debt_to_equity', + # 'GTJA_alpha016', + # 'GTJA_alpha141', + ] # LambdaRank 模型参数配置 @@ -145,8 +128,8 @@ def main(): pipeline = DataPipeline( factor_manager=factor_manager, processor_configs=[ - (NullFiller, {"strategy": "mean"}), (Winsorizer, {"lower": 0.01, "upper": 0.99}), + (NullFiller, {"strategy": "mean"}), (CrossSectionalStandardScaler, {}), ], filters=[STFilter(data_router=engine.router)], diff --git a/src/experiment/regression.py b/src/experiment/regression.py index be204ae..02534f6 100644 --- a/src/experiment/regression.py +++ b/src/experiment/regression.py @@ -52,36 +52,36 @@ TRAINING_TYPE = "regression" # 排除的因子列表 EXCLUDED_FACTORS = [ - 'GTJA_alpha016', - 'volatility_20', - 'current_ratio', - 'GTJA_alpha001', - 'GTJA_alpha141', - 'GTJA_alpha129', - 'GTJA_alpha164', - 'amivest_liq_20', - 'GTJA_alpha012', - 'debt_to_equity', - 'turnover_deviation', - 'GTJA_alpha073', - 'GTJA_alpha043', - 'GTJA_alpha032', - 'GTJA_alpha028', - 'GTJA_alpha090', - 'GTJA_alpha108', - 'GTJA_alpha105', - 'GTJA_alpha091', - 'GTJA_alpha119', - 'GTJA_alpha104', - 'GTJA_alpha163', - 'GTJA_alpha157', - 'cost_skewness', - 'GTJA_alpha176', - 'chip_transition', - 'amount_skewness_20', - 'GTJA_alpha148', - 'mean_median_dev', - 'downside_illiq_20', + # 'GTJA_alpha016', + # 'volatility_20', + # 'current_ratio', + # 'GTJA_alpha001', + # 'GTJA_alpha141', + # 'GTJA_alpha129', + # 'GTJA_alpha164', + # 'amivest_liq_20', + # 'GTJA_alpha012', + # 'debt_to_equity', + # 'turnover_deviation', + # 'GTJA_alpha073', + # 'GTJA_alpha043', + # 'GTJA_alpha032', + # 'GTJA_alpha028', + # 'GTJA_alpha090', + # 'GTJA_alpha108', + # 'GTJA_alpha105', + # 'GTJA_alpha091', + # 'GTJA_alpha119', + # 'GTJA_alpha104', + # 'GTJA_alpha163', + # 'GTJA_alpha157', + # 'cost_skewness', + # 'GTJA_alpha176', + # 'chip_transition', + # 'amount_skewness_20', + # 'GTJA_alpha148', + # 'mean_median_dev', + # 'downside_illiq_20', ] # 模型参数配置 @@ -153,15 +153,15 @@ def main(): pipeline = DataPipeline( factor_manager=factor_manager, processor_configs=[ - (NullFiller, {"strategy": "mean"}), (Winsorizer, {"lower": 0.01, "upper": 0.99}), + (NullFiller, {"strategy": "mean"}), (StandardScaler, {}), # (CrossSectionalStandardScaler, {}), ], label_processor_configs=[ # 对 label 进行缩尾处理(去除极端收益率) (Winsorizer, {"lower": 0.05, "upper": 0.95}), - # (StandardScaler, {}), + (StandardScaler, {}), ], filters=[STFilter(data_router=engine.router)], stock_pool_filter_func=stock_pool_filter, diff --git a/src/experiment/tabm_rank_train.py b/src/experiment/tabm_rank_train.py new file mode 100644 index 0000000..b360006 --- /dev/null +++ b/src/experiment/tabm_rank_train.py @@ -0,0 +1,176 @@ +"""TabM 排序学习训练流程(模块化版本) + +使用新的模块化 Trainer 架构,基于 TabMRankModel 实现排序学习。 +TabM 使用 ListNet 损失函数,支持集成学习。 +""" + +import os + +from src.factors import FactorEngine +from src.training import ( + FactorManager, + DataPipeline, + NullFiller, + Winsorizer, + CrossSectionalStandardScaler, +) +from src.training.tasks.tabm_rank_task import TabMRankTask +from src.training.core.trainer_v2 import Trainer +from src.training.components.filters import STFilter +from src.experiment.common import ( + SELECTED_FACTORS, + FACTOR_DEFINITIONS, + LABEL_NAME, + LABEL_FACTOR, + TRAIN_START, + TRAIN_END, + VAL_START, + VAL_END, + TEST_START, + TEST_END, + stock_pool_filter, + STOCK_FILTER_REQUIRED_COLUMNS, + OUTPUT_DIR, + SAVE_PREDICTIONS, + SAVE_MODEL, + get_model_save_path, + save_model_with_factors, + TOP_N, + TRAIN_SKIP_DAYS, +) + +# 训练类型标识 +TRAINING_TYPE = "tabm_rank" + +# %% +# Label 配置(从 common.py 统一导入) +# LABEL_NAME 和 LABEL_FACTOR 已在 common.py 中绑定,只需从 common 导入 + +# 分位数配置(提高分辨率以更好地区分头部) +N_QUANTILES = 50 + +# 【Top-K 优化】标签工程配置 - 默认启用平方增益 +LABEL_TRANSFORM = "exponential" # 启用平方增益标签 (rank^2) +LABEL_SCALE = 20.0 # 保留参数(当前未使用,平方变换不需要缩放) + +# 排除的因子列表 +EXCLUDED_FACTORS = ["GTJA_alpha041", "GTJA_alpha127"] + +# TabM Rank 模型参数配置(Top-K 优化全部开启,使用 LambdaLoss) +MODEL_PARAMS = { + # ==================== MLP 结构 ==================== + "n_blocks": 4, # MLP 层数 + "d_block": 256, # 每层神经元数 + "dropout": 0.5, # Dropout 率 + # ==================== 集成机制 ==================== + "ensemble_size": 32, # 内置集成大小(模拟 32 个模型集成) + # ==================== 训练参数 ==================== + "learning_rate": 1e-4, # 学习率 + "weight_decay": 1e-5, # 权重衰减 + "epochs": 500, # 训练轮数 + # ==================== 早停 ==================== + "early_stopping_round": 50, # 早停耐心值 + # NDCG 评估 - 关注 Top-20 + "ndcg_k": 20, # 验证时计算 NDCG@20 + # 【Top-K 优化】损失函数配置 - 使用 LambdaLoss + "loss_type": "lambda", # 使用 LambdaLoss 精准优化 Top-K + "lambda_sigma": 1.0, # Sigmoid 陡峭程度 + "ndcg_weight_power": 1.0, # DeltaNDCG 权重幂次,>1 进一步放大头部效应 +} + +# 日期范围配置 +date_range = { + "train": (TRAIN_START, TRAIN_END), + "val": (VAL_START, VAL_END), + "test": (TEST_START, TEST_END), +} + +# 输出配置 +output_config = { + "output_dir": OUTPUT_DIR, + "output_filename": "tabm_rank_output.csv", + "save_predictions": SAVE_PREDICTIONS, + "save_model": SAVE_MODEL, + "model_save_path": get_model_save_path(TRAINING_TYPE), + "top_n": TOP_N, +} + + +def main(): + """主函数""" + print("\n" + "=" * 80) + print("TabM 排序学习训练(模块化版本)") + print("=" * 80) + + # 1. 创建 FactorEngine + print("\n[1] 创建 FactorEngine") + engine = FactorEngine() + + # 2. 创建 FactorManager + print("\n[2] 创建 FactorManager") + factor_manager = FactorManager( + selected_factors=SELECTED_FACTORS, + factor_definitions=FACTOR_DEFINITIONS, + label_factor=LABEL_FACTOR, + excluded_factors=EXCLUDED_FACTORS, + ) + + # 3. 创建 DataPipeline + print("\n[3] 创建 DataPipeline") + pipeline = DataPipeline( + factor_manager=factor_manager, + processor_configs=[ + (Winsorizer, {"lower": 0.01, "upper": 0.99}), + (NullFiller, {"strategy": "mean"}), + (CrossSectionalStandardScaler, {}), + ], + filters=[STFilter(data_router=engine.router)], + stock_pool_filter_func=stock_pool_filter, + stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS, + train_skip_days=TRAIN_SKIP_DAYS, + ) + + # 4. 创建 TabMRankTask + print("\n[4] 创建 TabMRankTask") + task = TabMRankTask( + model_params=MODEL_PARAMS, + label_name=LABEL_NAME, + n_quantiles=N_QUANTILES, + label_transform=LABEL_TRANSFORM, + label_scale=LABEL_SCALE, + ) + + # 5. 创建 Trainer + print("\n[5] 创建 Trainer") + trainer = Trainer( + data_pipeline=pipeline, + task=task, + output_config=output_config, + verbose=True, + ) + + # 6. 执行训练 + print("\n[6] 执行训练") + results = trainer.run(engine=engine, date_range=date_range) + + # 7. 保存模型和因子信息(如果启用) + if SAVE_MODEL: + print("\n[7] 保存模型和因子信息") + save_model_with_factors( + model=task.get_model(), + model_path=output_config["model_save_path"], + selected_factors=SELECTED_FACTORS, + factor_definitions=FACTOR_DEFINITIONS, + fitted_processors=pipeline.get_fitted_processors(), + ) + + print("\n" + "=" * 80) + print("训练流程完成!") + print(f"结果保存路径: {os.path.join(OUTPUT_DIR, 'tabm_rank_output.csv')}") + print("=" * 80) + + return results + + +if __name__ == "__main__": + main() diff --git a/src/experiment/tabm_regression.py b/src/experiment/tabm_regression.py index a998771..fd6acab 100644 --- a/src/experiment/tabm_regression.py +++ b/src/experiment/tabm_regression.py @@ -39,9 +39,8 @@ from src.experiment.common import ( get_model_save_path, save_model_with_factors, TOP_N, - TRAIN_SKIP_DAYS, + TRAIN_SKIP_DAYS, STOCK_POOL_SIZE, ) -from src.experiment.data_quality_analyzer import DataQualityAnalyzer # 训练类型标识 TRAINING_TYPE = "tabm_regression" @@ -54,201 +53,7 @@ TRAINING_TYPE = "tabm_regression" # 排除的因子列表(与 LightGBM 回归保持一致) EXCLUDED_FACTORS = [ - # "GTJA_alpha001", - # "GTJA_alpha002", - # "GTJA_alpha003", - # "GTJA_alpha004", - # "GTJA_alpha005", - # "GTJA_alpha006", - # "GTJA_alpha007", - # "GTJA_alpha008", - # "GTJA_alpha009", - # "GTJA_alpha010", - # "GTJA_alpha011", - # "GTJA_alpha012", - # "GTJA_alpha013", - # "GTJA_alpha014", - # "GTJA_alpha015", - # "GTJA_alpha016", - # "GTJA_alpha017", - # "GTJA_alpha018", - # "GTJA_alpha019", - # "GTJA_alpha020", - # "GTJA_alpha022", - # "GTJA_alpha023", - # "GTJA_alpha024", - # "GTJA_alpha025", - # "GTJA_alpha026", - # "GTJA_alpha027", - # "GTJA_alpha028", - # "GTJA_alpha029", - # "GTJA_alpha031", - # "GTJA_alpha032", - # "GTJA_alpha033", - # "GTJA_alpha034", - # "GTJA_alpha035", - # "GTJA_alpha036", - # "GTJA_alpha037", - # # "GTJA_alpha038", - # "GTJA_alpha039", - # "GTJA_alpha040", - # "GTJA_alpha041", - # "GTJA_alpha042", - # "GTJA_alpha043", - # "GTJA_alpha044", - # "GTJA_alpha045", - # "GTJA_alpha046", - # "GTJA_alpha047", - # "GTJA_alpha048", - # "GTJA_alpha049", - # "GTJA_alpha050", - # "GTJA_alpha051", - # "GTJA_alpha052", - # "GTJA_alpha053", - # "GTJA_alpha054", - # "GTJA_alpha056", - # "GTJA_alpha057", - # "GTJA_alpha058", - # "GTJA_alpha059", - # "GTJA_alpha060", - # "GTJA_alpha061", - # "GTJA_alpha062", - # "GTJA_alpha063", - # "GTJA_alpha064", - # "GTJA_alpha065", - # "GTJA_alpha066", - # "GTJA_alpha067", - # "GTJA_alpha068", - # "GTJA_alpha070", - # "GTJA_alpha071", - # "GTJA_alpha072", - # "GTJA_alpha073", - # "GTJA_alpha074", - # "GTJA_alpha076", - # "GTJA_alpha077", - # "GTJA_alpha078", - # "GTJA_alpha079", - # "GTJA_alpha080", - # "GTJA_alpha081", - # "GTJA_alpha082", - # "GTJA_alpha083", - # "GTJA_alpha084", - # "GTJA_alpha085", - # "GTJA_alpha086", - # "GTJA_alpha087", - # "GTJA_alpha088", - # "GTJA_alpha089", - # "GTJA_alpha090", - # "GTJA_alpha091", - # "GTJA_alpha092", - # "GTJA_alpha093", - # "GTJA_alpha094", - # "GTJA_alpha095", - # "GTJA_alpha096", - # "GTJA_alpha097", - # "GTJA_alpha098", - # "GTJA_alpha099", - # "GTJA_alpha100", - # "GTJA_alpha101", - # "GTJA_alpha102", - # "GTJA_alpha103", - # "GTJA_alpha104", - # "GTJA_alpha105", - # "GTJA_alpha106", - # "GTJA_alpha107", - # "GTJA_alpha108", - # "GTJA_alpha109", - # "GTJA_alpha110", - # "GTJA_alpha111", - # "GTJA_alpha112", - # # "GTJA_alpha113", - # "GTJA_alpha114", - # "GTJA_alpha115", - # "GTJA_alpha117", - # "GTJA_alpha118", - # "GTJA_alpha119", - # "GTJA_alpha120", - # # "GTJA_alpha121", - # "GTJA_alpha122", - # "GTJA_alpha123", - # "GTJA_alpha124", - # "GTJA_alpha125", - # "GTJA_alpha126", - # "GTJA_alpha127", - # "GTJA_alpha128", - # "GTJA_alpha129", - # "GTJA_alpha130", - # "GTJA_alpha131", - # "GTJA_alpha132", - # "GTJA_alpha133", - # "GTJA_alpha134", - # "GTJA_alpha135", - # "GTJA_alpha136", - # # "GTJA_alpha138", - # "GTJA_alpha139", - # # "GTJA_alpha140", - # "GTJA_alpha141", - # "GTJA_alpha142", - # "GTJA_alpha145", - # # "GTJA_alpha146", - # "GTJA_alpha148", - # "GTJA_alpha150", - # "GTJA_alpha151", - # "GTJA_alpha152", - # "GTJA_alpha153", - # "GTJA_alpha154", - # "GTJA_alpha155", - # "GTJA_alpha156", - # "GTJA_alpha157", - # "GTJA_alpha158", - # "GTJA_alpha159", - # "GTJA_alpha160", - # "GTJA_alpha161", - # "GTJA_alpha162", - # "GTJA_alpha163", - # "GTJA_alpha164", - # # "GTJA_alpha165", - # "GTJA_alpha166", - # "GTJA_alpha167", - # "GTJA_alpha168", - # "GTJA_alpha169", - # "GTJA_alpha170", - # "GTJA_alpha171", - # "GTJA_alpha173", - # "GTJA_alpha174", - # "GTJA_alpha175", - # "GTJA_alpha176", - # "GTJA_alpha177", - # "GTJA_alpha178", - # "GTJA_alpha179", - # "GTJA_alpha180", - # # "GTJA_alpha183", - # "GTJA_alpha184", - # "GTJA_alpha185", - # "GTJA_alpha187", - # "GTJA_alpha188", - # "GTJA_alpha189", - # "GTJA_alpha191", - # "chip_dispersion_90", - # "chip_dispersion_70", - # "cost_skewness", - # "dispersion_change_20", - # "price_to_avg_cost", - # "price_to_median_cost", - # "mean_median_dev", - # "trap_pressure", - # "bottom_profit", - # "history_position", - # "winner_rate_surge_5", - # "winner_rate_cs_rank", - # "winner_rate_dev_20", - # "winner_rate_volatility", - # "smart_money_accumulation", - # "winner_vol_corr_20", - # "cost_base_momentum", - # "bottom_cost_stability", - # "pivot_reversion", - # "chip_transition", + ] # TabM 模型参数配置(来自用户提供的示例代码) @@ -256,11 +61,11 @@ MODEL_PARAMS = { # ==================== MLP 结构 ==================== "n_blocks": 3, # MLP 层数 "d_block": 256, # 每层神经元数 - "dropout": 0.3, # Dropout 率 + "dropout": 0.5, # Dropout 率 # ==================== 集成机制 ==================== "ensemble_size": 32, # 内置集成大小(模拟 32 个模型集成) # ==================== 训练参数 ==================== - "batch_size": 2048, # 批次大小 + "batch_size": STOCK_POOL_SIZE * 5, # 批次大小 "learning_rate": 1e-3, # 学习率 "weight_decay": 1e-5, # 权重衰减 "epochs": 100, # 训练轮数 @@ -312,13 +117,14 @@ def main(): pipeline = DataPipeline( factor_manager=factor_manager, processor_configs=[ - (NullFiller, {"strategy": "mean"}), (Winsorizer, {"lower": 0.01, "upper": 0.99}), # 先缩尾处理厚尾分布 + (NullFiller, {"strategy": "mean"}), (StandardScaler, {}), # TabM 需要标准化输入 ], label_processor_configs=[ # 对 label 进行缩尾处理(去除极端收益率) - (Winsorizer, {"lower": 0.05, "upper": 0.95}), + (Winsorizer, {"lower": 0.01, "upper": 0.99}), + (StandardScaler, {}), ], filters=[STFilter(data_router=engine.router)], stock_pool_filter_func=stock_pool_filter, diff --git a/src/training/__init__.py b/src/training/__init__.py index 948ba0b..bea336a 100644 --- a/src/training/__init__.py +++ b/src/training/__init__.py @@ -43,6 +43,12 @@ from src.training.core import StockPoolManager, Trainer # 工具函数 from src.training.utils import check_data_quality +# 数据质量分析器 +from src.training.data_quality_analyzer import ( + DataQualityAnalyzer, + analyze_data_quality, +) + # 配置 from src.training.config import TrainingConfig @@ -85,6 +91,9 @@ __all__ = [ "Trainer", # 工具函数 "check_data_quality", + # 数据质量分析器 + "DataQualityAnalyzer", + "analyze_data_quality", # 配置 "TrainingConfig", # 新增:模块化 Trainer 组件(推荐使用) diff --git a/src/training/components/models/__init__.py b/src/training/components/models/__init__.py index 907e040..e757708 100644 --- a/src/training/components/models/__init__.py +++ b/src/training/components/models/__init__.py @@ -7,6 +7,11 @@ from src.training.components.models.lightgbm import LightGBMModel from src.training.components.models.lightgbm_lambdarank import LightGBMLambdaRankModel from src.training.components.models.tabpfn_model import TabPFNModel from src.training.components.models.tabm_model import TabMModel +from src.training.components.models.tabm_rank_model import ( + TabMRankModel, + EnsembleListNetLoss, + EnsembleLambdaLoss, +) from src.training.components.models.cross_section_sampler import CrossSectionSampler from src.training.components.models.ensemble_quant_loss import EnsembleQuantLoss @@ -15,6 +20,9 @@ __all__ = [ "LightGBMLambdaRankModel", "TabPFNModel", "TabMModel", + "TabMRankModel", + "EnsembleListNetLoss", + "EnsembleLambdaLoss", "CrossSectionSampler", "EnsembleQuantLoss", ] diff --git a/src/training/components/models/tabm_rank_model.py b/src/training/components/models/tabm_rank_model.py new file mode 100644 index 0000000..741b583 --- /dev/null +++ b/src/training/components/models/tabm_rank_model.py @@ -0,0 +1,747 @@ +"""TabM 排序模型实现 (TabM Rank) + +基于 TabM (Tabular Multilayer Perceptron with Ensembles) 架构 +引入 ListNet 列表级排序损失,实现类似 LambdaRank 的截面排序学习。 +适用于股票未来收益率的截面排序预测。 +""" + +from typing import Dict, Any, List, Optional, Tuple +from pathlib import Path +import pickle + +import numpy as np +import polars as pl +import scipy.stats as stats +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader, TensorDataset, Sampler +from tabm import TabM + +from src.training.components.base import BaseModel +from src.training.registry import register_model + + +class GroupSampler(Sampler): + """排序学习专用的分组采样器 + + 确保每个 Batch 包含的是同一个 Query (如同一天) 的所有样本。 + 这与 LightGBM 的 `group` 参数逻辑完全一致。 + """ + + def __init__(self, group_counts: np.ndarray, shuffle_groups: bool = True): + """初始化分组采样器 + + Args: + group_counts: 各组的样本数量数组,如 [100, 150, 120] + shuffle_groups: 是否打乱组的训练顺序 + """ + self.group_counts = group_counts + self.shuffle_groups = shuffle_groups + # 计算每组在原数组中的起始和结束边界 + self.boundaries = np.insert(np.cumsum(group_counts), 0, 0) + self.num_groups = len(group_counts) + + def __iter__(self): + """迭代生成批次索引 + + Yields: + list: 同一组的所有样本索引 + """ + group_indices = list(range(self.num_groups)) + if self.shuffle_groups: + np.random.shuffle(group_indices) + + for g_idx in group_indices: + start = self.boundaries[g_idx] + end = self.boundaries[g_idx + 1] + # 返回该组(天)内所有样本的索引,作为一个完整的 batch + yield list(range(start, end)) + + def __len__(self): + """返回组数量""" + return self.num_groups + + +class EnsembleListNetLoss(nn.Module): + """集成 ListNet 排序损失 (Listwise Ranking Loss) + + 基于交叉熵将截面上的相关度转化为概率分布。 + 支持 TabM 的 Ensemble 维度。 + """ + + def __init__(self, topk_weight: float = 1.0): + """初始化 ListNet 损失 + + Args: + topk_weight: 头部样本权重系数,>1.0 时强化对高分标签的关注 + """ + super().__init__() + self.topk_weight = topk_weight + + def forward(self, preds: torch.Tensor, targets: torch.Tensor): + """计算 ListNet 排序损失 + + Args: + preds: [BatchSize, EnsembleSize] - 当前组所有样本的预测 logits + targets: [BatchSize] - 当前组所有样本的真实排序标签 (0, 1, 2...) + + Returns: + 标量损失值 + """ + # 如果该组内样本太少,无法排序,直接返回 0 + if preds.size(0) <= 1: + return torch.tensor(0.0, requires_grad=True, device=preds.device) + + # [BatchSize] -> [BatchSize, EnsembleSize] 广播以对齐每个集成成员 + targets_expanded = targets.unsqueeze(1).expand_as(preds) + + # 1. 计算真实标签的分数概率分布 (Softmax 使得高分权重成倍提升) + targets_prob = F.softmax(targets_expanded, dim=0) + + # 2. 【Top-K 优化】如果启用加权,给予高分标签更高权重 + if self.topk_weight > 1.0: + # 基于标签值计算权重(标签越高,权重越大) + # 归一化到 [1, topk_weight] 范围 + max_target = targets.max() + min_target = targets.min() + if max_target > min_target: + # 线性插值:权重 = 1 + (target - min) / (max - min) * (topk_weight - 1) + sample_weights = 1.0 + (targets - min_target) / ( + max_target - min_target + ) * (self.topk_weight - 1.0) + sample_weights = sample_weights.unsqueeze(1).expand_as(preds) + # 归一化权重使其和为样本数 + sample_weights = sample_weights * len(targets) / sample_weights.sum() + # 应用权重到目标概率 + targets_prob = targets_prob * sample_weights + + # 3. 计算预测值的对数概率分布 (log_softmax 数值上比 log(softmax) 更稳定) + preds_log_prob = F.log_softmax(preds, dim=0) + + # 4. 计算交叉熵损失 (在 Batch/组 维度求和) + loss = -torch.sum(targets_prob * preds_log_prob, dim=0) # [EnsembleSize] + + # 5. 对所有集成成员的 Loss 取平均 + return loss.mean() + + +class EnsembleLambdaLoss(nn.Module): + """集成 LambdaLoss (支持 TabM 集成维度) + + 基于 Pairwise 排序损失,引入 DeltaNDCG 权重。 + 参考: "The LambdaLoss Framework for Ranking Metric Optimization" (Google Research) + + 特点: + - 计算每对样本交换位置后对 NDCG 的影响 + - 对头部样本(高 Gain 且排名靠前)给予更高权重 + - 更适合 Top-K 选股场景 + """ + + def __init__(self, sigma: float = 1.0, ndcg_weight_power: float = 1.0): + """初始化 LambdaLoss + + Args: + sigma: Sigmoid 函数的陡峭程度,控制梯度大小 + ndcg_weight_power: DeltaNDCG 权重幂次,>1 时进一步放大头部效应 + """ + super().__init__() + self.sigma = sigma + self.ndcg_weight_power = ndcg_weight_power + + def forward(self, preds: torch.Tensor, targets: torch.Tensor): + """计算 LambdaLoss + + Args: + preds: [BatchSize, EnsembleSize] - 预测分 + targets: [BatchSize] - 相关性标签 (Gain) + + Returns: + 标量损失值 + """ + if preds.size(0) <= 1: + return torch.tensor(0.0, requires_grad=True, device=preds.device) + + # 1. 计算两两对之间的差值 + preds_diff = preds.unsqueeze(1) - preds.unsqueeze(0) # [B, B, E] + + # 【性能优化】: target_diff 不需要 E 维度!直接保持 [B, B] + target_diff = targets.unsqueeze(1) - targets.unsqueeze(0) # [B, B] + + # 2. 掩码矩阵: [B, B, 1] 方便后续广播 + mask = (target_diff > 0).float().unsqueeze(2) # [B, B, 1] + + # 3. 计算 Delta NDCG + # DeltaNDCG = |Gain_i - Gain_j| * |1/log(rank_i+1) - 1/log(rank_j+1)| + with torch.no_grad(): + # 【性能核爆优化】: 使用两次 argsort 完全消灭 for 循环 + # 第1次 argsort: 获得从大到小的索引 + # 第2次 argsort: 直接将索引反转为排名 (加上 1 就是真实名次) + ranks = preds.argsort(dim=0, descending=True).argsort(dim=0) + 1 + ranks = ranks.float() # [B, E] + + # 计算位置惩罚项 (log2 排名) + log_rank = torch.log2(ranks + 1) + inv_log_rank_diff = torch.abs( + 1.0 / log_rank.unsqueeze(1) - 1.0 / log_rank.unsqueeze(0) + ) # [B, B, E] + + # DeltaNDCG = |Gain 差| * |位置惩罚差| + # target_diff 是 [B, B],通过 unsqueeze 广播到 [B, B, 1] + delta_ndcg = torch.abs(target_diff).unsqueeze(2) * inv_log_rank_diff + + # 应用幂次调整权重分布 + if self.ndcg_weight_power != 1.0: + delta_ndcg = torch.pow(delta_ndcg, self.ndcg_weight_power) + + # 4. Pairwise Logistic Loss 并加权 Delta NDCG + # loss = delta_ndcg * log(1 + exp(-sigma * (preds_i - preds_j))) + pairwise_loss = F.binary_cross_entropy_with_logits( + self.sigma * preds_diff, + torch.ones_like(preds_diff), + reduction="none", + ) + + # 5. 应用掩码和权重 + weighted_loss = pairwise_loss * mask * delta_ndcg + + # 避免全 0 导致的除零错误 + valid_pairs_count = mask.sum().clamp(min=1.0) + mean_loss = weighted_loss.sum() / valid_pairs_count + + return mean_loss + + +@register_model("tabm_rank") +class TabMRankModel(BaseModel): + """TabM 学习排序模型 + + 基于 TabM 架构的排序学习模型,支持 ListNet 损失。 + 适用于股票截面排序任务,将未来收益率转换为分位数标签进行训练。 + + 特点: + - 使用 ListNet 列表级排序损失 + - 支持 group 参数进行分组训练 + - 以 NDCG 作为验证指标 + - 与 LightGBMLambdaRank 接口兼容 + """ + + name = "tabm_rank" + + def __init__(self, params: Optional[Dict[str, Any]] = None): + """初始化 TabM Rank 模型 + + Args: + params: 模型参数字典,包含: + - ensemble_size: 集成大小 (默认: 32) + - n_blocks: MLP层数 (默认: 3) + - d_block: 每层神经元数 (默认: 256) + - dropout: Dropout率 (默认: 0.1) + - batch_size: 批次大小 (默认: 2048,仅预测时使用) + - learning_rate: 学习率 (默认: 1e-3) + - weight_decay: 权重衰减 (默认: 1e-5) + - epochs: 训练轮数 (默认: 50) + - early_stopping_round: 早停轮数 (默认: 10) + - max_grad_norm: 梯度裁剪阈值 (默认: 1.0) + - ndcg_k: NDCG@k 的 k 值,None 表示全局 (默认: None) + - loss_type: 损失函数类型 (默认: "listnet") + - "listnet": 标准 ListNet 损失 + - "weighted_listnet": 加权 ListNet,通过 topk_weight 强化头部 + - "lambda": LambdaLoss,基于 DeltaNDCG 加权 + - topk_weight: 头部样本权重系数,用于 weighted_listnet (默认: 5.0) + - lambda_sigma: LambdaLoss 的 sigma 参数 (默认: 1.0) + - ndcg_weight_power: DeltaNDCG 权重幂次 (默认: 1.0) + """ + self.params = params or {} + self.model = None + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.training_history_: Dict[str, List[float]] = { + "train_loss": [], + "val_ndcg": [], + } + self.feature_names_: Optional[List[str]] = None + + # 根据配置选择损失函数 + loss_type = self.params.get("loss_type", "listnet") + if loss_type == "lambda": + self.criterion = EnsembleLambdaLoss( + sigma=self.params.get("lambda_sigma", 1.0), + ndcg_weight_power=self.params.get("ndcg_weight_power", 1.0), + ) + elif loss_type == "weighted_listnet": + self.criterion = EnsembleListNetLoss( + topk_weight=self.params.get("topk_weight", 5.0) + ) + else: # "listnet" + self.criterion = EnsembleListNetLoss(topk_weight=1.0) + + def _make_loader( + self, + X: np.ndarray, + y: Optional[np.ndarray] = None, + group: Optional[np.ndarray] = None, + shuffle_groups: bool = False, + ) -> DataLoader: + """创建 DataLoader (支持 Query/Group 截面打包) + + Args: + X: 特征数组 [N, n_features] + y: 标签数组 [N] 或 None + group: 分组数组,表示每组样本数 + shuffle_groups: 是否打乱组的顺序 + + Returns: + DataLoader 实例 + """ + # 【性能核爆优化】: 显存管够时,直接把全量数据一把推进 GPU,避免训练时 PCIe 搬运 + X_tensor = torch.from_numpy(X).to(self.device) + + if y is not None: + y_tensor = torch.from_numpy(y).to(self.device) + dataset = TensorDataset(X_tensor, y_tensor) + else: + dataset = TensorDataset(X_tensor) + + if group is not None: + # 训练和验证时使用 GroupSampler,每个 batch 就是一个 Query + sampler = GroupSampler(group, shuffle_groups=shuffle_groups) + return DataLoader(dataset, batch_sampler=sampler) + else: + # 预测时如果没有 group,则退化为普通批次预测 + batch_size = self.params.get("batch_size", 2048) + return DataLoader(dataset, batch_size=batch_size, shuffle=False) + + def _validate_ndcg(self, val_loader: DataLoader, k: Optional[int] = None) -> float: + """验证模型 (使用 NDCG 排序指标) + + Args: + val_loader: 验证数据加载器 + k: NDCG@k 的 k 值,None 表示计算全局 NDCG + + Returns: + 平均 NDCG 分数 + """ + from sklearn.metrics import ndcg_score + + assert self.model is not None, "模型未训练,无法验证" + + self.model.eval() + ndcg_list = [] + + with torch.no_grad(): + for batch in val_loader: + if len(batch) != 2: + continue + + bx, by = batch + bx = bx.to(self.device) + by = by.cpu().numpy() + + if len(by) <= 1: + continue + + outputs = self.model(bx) # [B, E, 1] + preds = outputs.mean(dim=1).squeeze(-1).cpu().numpy() # [B] + + try: + # ndcg_score 需要形状为 (1, n_samples) 的二维数组 + score = ndcg_score([by], [preds], k=k) + ndcg_list.append(score) + except ValueError: + pass + + return float(np.mean(ndcg_list)) if len(ndcg_list) > 0 else 0.0 + + def fit( + self, + X: pl.DataFrame, + y: pl.Series, + group: Optional[np.ndarray] = None, + eval_set: Optional[Tuple] = None, + ) -> "TabMRankModel": + """训练排序模型 + + Args: + X: 训练特征DataFrame + y: 训练标签 (Polars Series),应为分位数标签 (0, 1, 2, ...) + group: 分组数组,表示每个 query 的样本数 + eval_set: 验证集元组 (X_val, y_val, group_val),用于早停 + + Returns: + self (支持链式调用) + + Raises: + ValueError: group 参数无效 + """ + self.feature_names_ = list(X.columns) + + X_np = X.to_numpy().astype(np.float32) + y_np = y.to_numpy().astype(np.float32) + + # 检查和处理 group 参数 + if group is None: + group = np.array([len(y_np)]) + if group.sum() != len(y_np): + raise ValueError( + f"group 数组的和 ({group.sum()}) 必须等于样本数 ({len(y_np)})" + ) + + train_loader = self._make_loader(X_np, y_np, group=group, shuffle_groups=True) + + val_loader = None + if eval_set is not None: + X_val, y_val, group_val = eval_set + X_val_np = ( + X_val.to_numpy().astype(np.float32) + if isinstance(X_val, pl.DataFrame) + else X_val + ) + y_val_np = ( + y_val.to_numpy().astype(np.float32) + if isinstance(y_val, pl.Series) + else y_val + ) + + if group_val is None: + group_val = np.array([len(y_val_np)]) + val_loader = self._make_loader( + X_val_np, y_val_np, group=group_val, shuffle_groups=False + ) + + ensemble_size = self.params.get("ensemble_size", 32) + n_features = X_np.shape[1] + + # 初始化 TabM 模型 + self.model = TabM.make( + n_num_features=n_features, + cat_cardinalities=[], + d_out=1, + n_blocks=self.params.get("n_blocks", 3), + d_block=self.params.get("d_block", 256), + dropout=self.params.get("dropout", 0.1), + k=ensemble_size, + ).to(self.device) + + optimizer = optim.AdamW( + self.model.parameters(), + lr=self.params.get("learning_rate", 1e-3), + weight_decay=self.params.get("weight_decay", 1e-5), + ) + + epochs = self.params.get("epochs", 50) + scheduler = optim.lr_scheduler.CosineAnnealingLR( + optimizer, T_max=epochs, eta_min=1e-6 + ) + + early_stopping_patience = self.params.get( + "early_stopping_patience", + self.params.get("early_stopping_round", 10), + ) + best_val_ndcg = -float("inf") + patience_counter = 0 + best_model_state = None + ndcg_k = self.params.get("ndcg_k", None) # None 表示计算全局 NDCG + + print(f"[TabMRank] 开始训练... 设备: {self.device}, 集成大小: {ensemble_size}") + + for epoch in range(epochs): + # 训练阶段 + self.model.train() + train_loss = 0.0 + n_train_batches = 0 + + for batch in train_loader: + if len(batch) != 2: + continue + bx, by = batch[0], batch[1] + + optimizer.zero_grad() + outputs = self.model(bx) # [B, E, 1] + outputs_squeezed = outputs.squeeze(-1) # [B, E] + + # 计算 ListNet 排序损失 + loss = self.criterion(outputs_squeezed, by) + loss.backward() + + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), + max_norm=self.params.get("max_grad_norm", 1.0), + ) + optimizer.step() + + train_loss += loss.item() + n_train_batches += 1 + + avg_train_loss = train_loss / max(n_train_batches, 1) + self.training_history_["train_loss"].append(avg_train_loss) + + # 验证阶段 (基于 NDCG) + if val_loader is not None: + val_ndcg = self._validate_ndcg(val_loader, k=ndcg_k) + self.training_history_["val_ndcg"].append(val_ndcg) + + if val_ndcg > best_val_ndcg: + best_val_ndcg = val_ndcg + patience_counter = 0 + best_model_state = { + k: v.cpu().clone() for k, v in self.model.state_dict().items() + } + else: + patience_counter += 1 + + if (epoch + 1) % 5 == 0 or epoch == 0: + print( + f"[TabMRank] Epoch {epoch + 1}/{epochs} | " + f"Train Loss (ListNet): {avg_train_loss:.4f} | " + f"Val NDCG: {val_ndcg:.4f} (Best: {best_val_ndcg:.4f})" + ) + + if patience_counter >= early_stopping_patience: + print(f"[TabMRank] 触发早停,停止于 epoch {epoch + 1}") + break + else: + if (epoch + 1) % 5 == 0 or epoch == 0: + print( + f"[TabMRank] Epoch {epoch + 1}/{epochs} | " + f"Train Loss: {avg_train_loss:.4f}" + ) + + scheduler.step() + + # 恢复最佳权重 + if best_model_state is not None: + self.model.load_state_dict(best_model_state) + print(f"[TabMRank] 已恢复最佳模型权重 (Val NDCG: {best_val_ndcg:.4f})") + + return self + + def predict( + self, X: pl.DataFrame, group: Optional[np.ndarray] = None + ) -> np.ndarray: + """预测排序分数 + + Args: + X: 特征矩阵 (Polars DataFrame) + group: 分组数组,表示每个 query 的样本数。 + 如果提供,将使用 GroupSampler 确保预测顺序与分组一致。 + + Returns: + 预测分数 (numpy ndarray) + + Raises: + RuntimeError: 模型未训练时调用 + ValueError: 预测数据缺失特征 + """ + if self.model is None: + raise RuntimeError("模型未训练,请先调用fit()") + + # 特征对齐检查 + if self.feature_names_: + missing_cols = [c for c in self.feature_names_ if c not in X.columns] + if missing_cols: + raise ValueError(f"预测数据缺失特征: {missing_cols}") + X = X.select(self.feature_names_) + + X_np = X.to_numpy().astype(np.float32) + loader = self._make_loader(X_np, group=group, shuffle_groups=False) + + self.model.eval() + all_preds = [] + + with torch.no_grad(): + for batch in loader: + bx = batch[0].to(self.device) + outputs = self.model(bx) # [B, E, 1] + # 排序模型预测时直接输出集成成员的均值作为最终分数 + preds = outputs.mean(dim=1).squeeze(-1) # [B] + all_preds.append(preds.cpu().numpy()) + + return np.concatenate(all_preds) + + def get_evals_result(self) -> Optional[Dict[str, List[float]]]: + """获取训练评估结果 + + Returns: + 评估结果字典,包含 train_loss 和 val_ndcg + """ + return self.training_history_ + + def feature_importance(self) -> None: + """获取特征重要性 + + TabM没有内置特征重要性计算,返回None。 + """ + return None + + def save(self, path: str | Path) -> None: + """保存模型 + + Args: + path: 保存路径 + + Raises: + RuntimeError: 模型未训练时调用 + """ + if self.model is None: + raise RuntimeError("模型未训练,无法保存") + + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + + # 保存模型权重 + model_path = path.with_suffix(".pt") + torch.save(self.model.state_dict(), model_path) + + # 保存元数据 + meta_path = path.with_suffix(".meta") + meta = { + "params": self.params, + "feature_names": self.feature_names_, + "training_history": self.training_history_, + "device": str(self.device), + } + with open(meta_path, "wb") as f: + pickle.dump(meta, f) + + print(f"[TabMRank] 模型保存到: {path}") + + @classmethod + def load(cls, path: str | Path) -> "TabMRankModel": + """加载模型 + + Args: + path: 模型路径(不含扩展名) + + Returns: + 加载的 TabMRankModel 实例 + """ + path = Path(path) + + # 加载元数据 + meta_path = path.with_suffix(".meta") + with open(meta_path, "rb") as f: + meta = pickle.load(f) + + # 创建实例 + instance = cls(meta["params"]) + instance.feature_names_ = meta["feature_names"] + instance.training_history_ = meta["training_history"] + + # 重建模型结构 + if instance.feature_names_ is not None: + n_features = len(instance.feature_names_) + ensemble_size = instance.params.get("ensemble_size", 32) + + instance.model = TabM.make( + n_num_features=n_features, + cat_cardinalities=[], + d_out=1, + n_blocks=instance.params.get("n_blocks", 3), + d_block=instance.params.get("d_block", 256), + dropout=instance.params.get("dropout", 0.1), + k=ensemble_size, + ).to(instance.device) + + # 加载权重 + model_path = path.with_suffix(".pt") + instance.model.load_state_dict( + torch.load(model_path, map_location=instance.device) + ) + + print(f"[TabMRank] 模型从 {path} 加载完成") + return instance + + @staticmethod + def prepare_group_from_dates( + df: pl.DataFrame, + date_col: str = "trade_date", + ) -> np.ndarray: + """从日期列生成 group 数组 + + Args: + df: 包含日期列的 DataFrame + date_col: 日期列名,默认 "trade_date" + + Returns: + group 数组 + """ + group_counts = df.group_by(date_col, maintain_order=True).agg( + pl.count().alias("count") + ) + return group_counts["count"].to_numpy() + + @staticmethod + def convert_to_quantile_labels( + df: pl.DataFrame, + label_col: str, + date_col: str = "trade_date", + n_quantiles: int = 20, + new_col_name: Optional[str] = None, + ) -> pl.DataFrame: + """将连续标签转换为分位数标签 + + Args: + df: 输入 DataFrame + label_col: 原始标签列名 + date_col: 日期列名,默认 "trade_date" + n_quantiles: 分位数数量,默认 20 + new_col_name: 新列名,默认为 {label_col}_rank + + Returns: + 添加了分位数标签列的 DataFrame + """ + if new_col_name is None: + new_col_name = f"{label_col}_rank" + + return df.with_columns( + pl.col(label_col) + .qcut(n_quantiles) + .over(date_col) + .to_physical() + .cast(pl.Int64) + .alias(new_col_name) + ) + + def evaluate_ndcg( + self, + X: pl.DataFrame, + y: pl.Series, + group: np.ndarray, + k: Optional[int] = None, + ) -> float: + """评估 NDCG 指标 + + Args: + X: 特征矩阵 + y: 真实标签 + group: 分组数组 + k: NDCG@k 的 k 值 + + Returns: + NDCG 分数 + """ + from sklearn.metrics import ndcg_score + + y_pred = self.predict(X) + + y_true_list = [] + y_score_list = [] + + start_idx = 0 + for group_size in group: + end_idx = start_idx + group_size + y_true_list.append(y.to_numpy()[start_idx:end_idx]) + y_score_list.append(y_pred[start_idx:end_idx]) + start_idx = end_idx + + ndcg_scores = [] + for y_true, y_score in zip(y_true_list, y_score_list): + if len(y_true) > 1: + try: + score = ndcg_score([y_true], [y_score], k=k) + ndcg_scores.append(score) + except ValueError: + pass + + return float(np.mean(ndcg_scores)) if ndcg_scores else 0.0 diff --git a/src/training/components/processors/transforms.py b/src/training/components/processors/transforms.py index 5554c57..4549538 100644 --- a/src/training/components/processors/transforms.py +++ b/src/training/components/processors/transforms.py @@ -289,7 +289,7 @@ class StandardScaler(BaseProcessor): return self def transform(self, X: pl.DataFrame) -> pl.DataFrame: - """标准化(使用训练集学到的参数,增加 NaN 保护) + """标准化(使用训练集学到的参数) Args: X: 待转换数据 @@ -302,18 +302,11 @@ class StandardScaler(BaseProcessor): if col in self.mean_ and col in self.std_: # 避免除以0 std_val = self.std_[col] if self.std_[col] != 0 else 1.0 - # 关键修复:添加 fill_nan(0) 保险,防止计算产生 NaN - expr = ( - ((pl.col(col) - self.mean_[col]) / std_val) - .fill_nan(0) - .fill_null(0) - .alias(col) - ) + expr = ((pl.col(col) - self.mean_[col]) / std_val).alias(col) expressions.append(expr) elif col in self.feature_cols: - # 对于应该被处理但未学习到统计量的列 - # 统一转换为float并同时处理 NaN 和 null - expr = pl.col(col).cast(pl.Float64).fill_nan(0).fill_null(0).alias(col) + # 对于应该被处理但未学习到统计量的列,统一转换为float + expr = pl.col(col).cast(pl.Float64).alias(col) expressions.append(expr) else: expressions.append(pl.col(col)) @@ -372,20 +365,14 @@ class CrossSectionalStandardScaler(BaseProcessor): if col in self.feature_cols and X[col].dtype.is_numeric(): # 截面标准化:每天独立计算均值和标准差 # 避免除以0,当std为0时设为1 - # 关键修复:先 fill_nan 再 fill_null,防止计算产生的 NaN expr = ( - ( - (pl.col(col) - pl.col(col).mean().over(self.date_col)) - / (pl.col(col).std().over(self.date_col) + 1e-10) - ) - .fill_nan(0) - .fill_null(0) - .alias(col) - ) + (pl.col(col) - pl.col(col).mean().over(self.date_col)) + / (pl.col(col).std().over(self.date_col) + 1e-10) + ).alias(col) expressions.append(expr) elif col in self.feature_cols: - # 对于应该被处理但类型不匹配的列,转换为float并同时处理 NaN 和 null - expr = pl.col(col).cast(pl.Float64).fill_nan(0).fill_null(0).alias(col) + # 对于应该被处理但类型不匹配的列,转换为float + expr = pl.col(col).cast(pl.Float64).alias(col) expressions.append(expr) else: expressions.append(pl.col(col)) @@ -488,8 +475,8 @@ class Winsorizer(BaseProcessor): expressions.append(expr) elif col in self.feature_cols: # 对于应该被处理但未学习到边界的列(如全为NaN、布尔列等) - # 统一转换为float并填充0 - expr = pl.col(col).cast(pl.Float64).fill_null(0).alias(col) + # 统一转换为float + expr = pl.col(col).cast(pl.Float64).alias(col) expressions.append(expr) else: expressions.append(pl.col(col)) @@ -522,11 +509,9 @@ class Winsorizer(BaseProcessor): clip_exprs = [] for col in X.columns: if col in target_cols: - # 先用当天分位数缩尾,如果分位数是null(该日全为NaN)则填充0 clipped = ( pl.col(col) .clip(pl.col(f"{col}_lower"), pl.col(f"{col}_upper")) - .fill_null(0) .alias(col) ) clip_exprs.append(clipped) diff --git a/src/training/core/trainer_v2.py b/src/training/core/trainer_v2.py index 2aa5d2d..d7c0217 100644 --- a/src/training/core/trainer_v2.py +++ b/src/training/core/trainer_v2.py @@ -13,6 +13,7 @@ from src.factors import FactorEngine from src.training.pipeline import DataPipeline from src.training.tasks.base import BaseTask from src.training.result_analyzer import ResultAnalyzer +from src.training.data_quality_analyzer import DataQualityAnalyzer class Trainer: @@ -100,8 +101,6 @@ class Trainer: print("\n[Step 1.5/7] 数据质量分析...") try: - from src.experiment.data_quality_analyzer import DataQualityAnalyzer - # 获取特征列名(从训练集) feature_cols = data["train"].get("feature_cols", []) label_name = self.task.label_name diff --git a/src/experiment/data_quality_analyzer.py b/src/training/data_quality_analyzer.py similarity index 76% rename from src/experiment/data_quality_analyzer.py rename to src/training/data_quality_analyzer.py index 34842aa..bd54e69 100644 --- a/src/experiment/data_quality_analyzer.py +++ b/src/training/data_quality_analyzer.py @@ -1,6 +1,7 @@ """数据质量分析模块 提供数据质量检查功能,包括: +- 数据集日期范围信息 - 缺失值统计 - 零值统计 - 按日期检查全空列 @@ -19,6 +20,7 @@ class DataQualityAnalyzer: Attributes: feature_cols: 特征列名列表 label_col: 标签列名 + date_col: 日期列名 verbose: 是否打印详细信息 """ @@ -26,6 +28,7 @@ class DataQualityAnalyzer: self, feature_cols: Optional[List[str]] = None, label_col: Optional[str] = None, + date_col: str = "trade_date", verbose: bool = True, ): """初始化数据质量分析器 @@ -33,10 +36,12 @@ class DataQualityAnalyzer: Args: feature_cols: 特征列名列表 label_col: 标签列名 + date_col: 日期列名,默认为 "trade_date" verbose: 是否打印详细信息 """ self.feature_cols = feature_cols or [] self.label_col = label_col + self.date_col = date_col self.verbose = verbose self.analysis_results: Dict[str, Any] = {} @@ -74,6 +79,10 @@ class DataQualityAnalyzer: self.analysis_results = {} + # 首先打印数据集概览(日期范围等基本信息) + if self.verbose: + self._print_dataset_overview(data, split_names) + for split_name in split_names: if split_name not in data: continue @@ -96,6 +105,104 @@ class DataQualityAnalyzer: return self.analysis_results + def _print_dataset_overview( + self, + data: Dict[str, Dict[str, Any]], + split_names: List[str], + ) -> None: + """打印数据集概览信息 + + 包括每个数据集的起始日期、终止日期、样本数量等基本信息。 + + Args: + data: 数据字典 + split_names: 数据划分名称列表 + """ + print("\n[数据集概览]") + print("-" * 40) + + overview_data = [] + + for split_name in split_names: + if split_name not in data: + continue + + split_data = data[split_name] + raw_df = split_data.get("raw_data") + + if raw_df is None or len(raw_df) == 0: + overview_data.append( + { + "划分": split_name.upper(), + "起始日期": "-", + "终止日期": "-", + "样本数": 0, + "股票数": 0, + } + ) + continue + + # 获取日期范围 + if self.date_col in raw_df.columns: + dates = raw_df[self.date_col] + start_date = dates.min() + end_date = dates.max() + unique_dates = dates.n_unique() + else: + start_date = "-" + end_date = "-" + unique_dates = 0 + + # 获取股票数量 + if "ts_code" in raw_df.columns: + unique_stocks = raw_df["ts_code"].n_unique() + else: + unique_stocks = 0 + + overview_data.append( + { + "划分": split_name.upper(), + "起始日期": str(start_date), + "终止日期": str(end_date), + "交易日数": unique_dates, + "样本数": len(raw_df), + "股票数": unique_stocks, + } + ) + + # 打印表格 + if overview_data: + # 计算列宽 + headers = ["划分", "起始日期", "终止日期", "交易日数", "样本数", "股票数"] + col_widths = {} + for header in headers: + max_data_len = max( + len( + str( + row.get( + header.lower().replace("数", ""), row.get(header, "") + ) + ) + ) + for row in overview_data + ) + col_widths[header] = max(len(header), max_data_len) + 2 + + # 打印表头 + header_line = " ".join(h.ljust(col_widths[h]) for h in headers) + print(f" {header_line}") + print(f" {'-' * (sum(col_widths.values()) + 2 * (len(headers) - 1))}") + + # 打印数据行 + for row in overview_data: + line = " ".join( + str(row.get(h, row.get(h.lower().replace("数", ""), ""))).ljust( + col_widths[h] + ) + for h in headers + ) + print(f" {line}") + def _analyze_split( self, df: pl.DataFrame, @@ -120,6 +227,16 @@ class DataQualityAnalyzer: "all_null_by_date": {}, } + # 获取日期范围 + if self.date_col in df.columns: + results["start_date"] = str(df[self.date_col].min()) + results["end_date"] = str(df[self.date_col].max()) + results["unique_dates"] = df[self.date_col].n_unique() + + # 获取股票数量 + if "ts_code" in df.columns: + results["unique_stocks"] = df["ts_code"].n_unique() + # 1. 分析特征列的缺失值 null_stats = self._analyze_null_values(df, self.feature_cols) results["null_analysis"] = null_stats @@ -249,7 +366,7 @@ class DataQualityAnalyzer: "issues": [], } - if "trade_date" not in df.columns: + if self.date_col not in df.columns: return results # 过滤掉不在表中的列 @@ -267,7 +384,7 @@ class DataQualityAnalyzer: ] agg_exprs.append(pl.len().alias("total_rows")) - agg_lf = lf.group_by("trade_date").agg(agg_exprs) + agg_lf = lf.group_by(self.date_col).agg(agg_exprs) # 收集结果 (此时 agg_df 行数通常只有几百到几千行) agg_df = agg_lf.collect() @@ -279,13 +396,13 @@ class DataQualityAnalyzer: # 找出 null 数量等于总行数的日期 bad_dates = agg_df.filter( (pl.col(null_col) == pl.col("total_rows")) & (pl.col("total_rows") > 0) - ).select(["trade_date", "total_rows"]) + ).select([self.date_col, "total_rows"]) if not bad_dates.is_empty(): for row in bad_dates.to_dicts(): issues.append( { - "date": row["trade_date"], + "date": row[self.date_col], "column": col, "total_rows": row["total_rows"], } @@ -464,8 +581,20 @@ class DataQualityAnalyzer: for split_name, results in self.analysis_results.items(): lines.append(f"\n[{split_name.upper()}]") + + # 添加日期范围信息 + if "start_date" in results and "end_date" in results: + lines.append( + f" 日期范围: {results['start_date']} ~ {results['end_date']}" + ) + if "unique_dates" in results: + lines.append(f" 交易日数: {results['unique_dates']}") + lines.append(f" 总行数: {results['total_rows']:,}") + if "unique_stocks" in results: + lines.append(f" 股票数: {results['unique_stocks']}") + null_stats = results.get("null_analysis", {}) if null_stats.get("columns_with_null"): lines.append( @@ -493,6 +622,7 @@ def analyze_data_quality( data: Dict[str, Dict[str, Any]], feature_cols: Optional[List[str]] = None, label_col: Optional[str] = None, + date_col: str = "trade_date", verbose: bool = True, ) -> Dict[str, Any]: """便捷函数:执行数据质量分析 @@ -501,6 +631,7 @@ def analyze_data_quality( data: 数据字典 feature_cols: 特征列名列表 label_col: 标签列名 + date_col: 日期列名,默认为 "trade_date" verbose: 是否打印详细信息 Returns: @@ -509,6 +640,7 @@ def analyze_data_quality( analyzer = DataQualityAnalyzer( feature_cols=feature_cols, label_col=label_col, + date_col=date_col, verbose=verbose, ) return analyzer.analyze(data) diff --git a/src/training/tasks/__init__.py b/src/training/tasks/__init__.py index 16f3d5c..09298ef 100644 --- a/src/training/tasks/__init__.py +++ b/src/training/tasks/__init__.py @@ -7,10 +7,12 @@ from src.training.tasks.base import BaseTask from src.training.tasks.regression_task import RegressionTask from src.training.tasks.rank_task import RankTask from src.training.tasks.tabm_regression_task import TabMRegressionTask +from src.training.tasks.tabm_rank_task import TabMRankTask __all__ = [ "BaseTask", "RegressionTask", "RankTask", "TabMRegressionTask", + "TabMRankTask", ] diff --git a/src/training/tasks/rank_task.py b/src/training/tasks/rank_task.py index 0afdc26..bbcf9e1 100644 --- a/src/training/tasks/rank_task.py +++ b/src/training/tasks/rank_task.py @@ -153,7 +153,6 @@ class RankTask(BaseTask): if k_list is None: k_list = [1, 5, 10, 20] - y_true = test_data["y_raw"] y_pred = self.predict(test_data) groups = test_data["groups"] @@ -166,9 +165,13 @@ class RankTask(BaseTask): y_true_groups = [] y_pred_groups = [] + # 使用分位数标签 y (0-19) 作为真实相关性分数,而非原始收益率 y_raw + # 这样与模型学习目标一致,避免原始收益率中负值的影响 + y_true_array = test_data["y"].to_numpy() + for group_size in groups: end_idx = start_idx + group_size - y_true_groups.append(y_true.to_numpy()[start_idx:end_idx]) + y_true_groups.append(y_true_array[start_idx:end_idx]) y_pred_groups.append(y_pred[start_idx:end_idx]) start_idx = end_idx diff --git a/src/training/tasks/tabm_rank_task.py b/src/training/tasks/tabm_rank_task.py new file mode 100644 index 0000000..2006b51 --- /dev/null +++ b/src/training/tasks/tabm_rank_task.py @@ -0,0 +1,249 @@ +"""TabM 排序学习任务实现 + +实现基于 TabM 的排序学习训练流程: +- Label 转换为分位数标签 +- 生成 group 数组 +- 使用 TabMRankModel(基于 ListNet Loss) +- 支持 NDCG@k 评估 +""" + +from typing import Any, Dict, List, Optional +import numpy as np +import polars as pl + +from src.training.tasks.base import BaseTask +from src.training.components.models.tabm_rank_model import TabMRankModel + + +class TabMRankTask(BaseTask): + """TabM 排序学习任务 + + 使用 TabMRankModel 进行排序学习训练。 + 将连续收益率转换为分位数标签进行训练。 + 支持指数化增益标签以增强 Top-K 关注。 + """ + + def __init__( + self, + model_params: Dict[str, Any], + label_name: str = "future_return_5", + n_quantiles: int = 20, + label_transform: Optional[str] = None, + label_scale: float = 20.0, + ): + """初始化排序学习任务 + + Args: + model_params: TabM 参数字典 + label_name: Label 列名 + n_quantiles: 分位数数量 + label_transform: 标签变换类型,可选: + - None: 标准分位数标签 (0, 1, ..., n_quantiles-1) + - "exponential": 指数化增益: 2^(rank/scale) - 1 + label_scale: 指数变换的缩放因子,用于控制增益幅度 + """ + super().__init__(model_params, label_name) + self.n_quantiles = n_quantiles + self.label_transform = label_transform + self.label_scale = label_scale + + def prepare_labels(self, data: Dict[str, Dict]) -> Dict[str, Dict]: + """准备标签(转换为分位数标签,可选指数化增益变换) + + 将连续收益率转换为分位数标签,并生成 group 数组。 + 支持指数化增益变换以增强头部样本的区分度。 + + Args: + data: 数据字典 + + Returns: + 处理后的数据字典(添加了 y_rank 和 groups) + """ + for split in ["train", "val", "test"]: + if split not in data: + continue + + df = data[split]["raw_data"] + + # 分位数转换 + rank_col = f"{self.label_name}_rank" + + # 1. 基础分位数标签 (0 到 n_quantiles-1) + df_ranked = df.with_columns( + pl.col(self.label_name) + .rank(method="min") + .over("trade_date") + .alias("_rank") + ).with_columns( + ((pl.col("_rank") - 1) / pl.len().over("trade_date") * self.n_quantiles) + .floor() + .cast(pl.Int64) + .clip(0, self.n_quantiles - 1) + .alias("_base_rank") + ) + + # 2. 【Top-K 优化】可选指数化增益变换 + if self.label_transform == "exponential": + # 平方变换: rank^2 + # 例如 rank=0 -> 0, rank=10 -> 100, rank=19 -> 361 + # 效果:高分样本与低分样本的差距被平方级拉大 + df_ranked = df_ranked.with_columns( + (pl.col("_base_rank").cast(pl.Float64) ** 2).alias(rank_col) + ) + else: + # 标准分位数标签 + df_ranked = df_ranked.with_columns( + pl.col("_base_rank").cast(pl.Float64).alias(rank_col) + ) + + # 清理临时列 + df_ranked = df_ranked.drop(["_rank", "_base_rank"]) + + # 更新数据 + data[split]["raw_data"] = df_ranked + data[split]["y"] = df_ranked[rank_col] + data[split]["y_raw"] = df_ranked[self.label_name] # 保留原始值 + + # 生成 group 数组 + data[split]["groups"] = self._compute_group_array(df_ranked, "trade_date") + + return data + + def _compute_group_array( + self, + df: pl.DataFrame, + date_col: str = "trade_date", + ) -> np.ndarray: + """计算 group 数组 + + Args: + df: 数据框 + date_col: 日期列名 + + Returns: + group 数组(每个日期的样本数) + """ + group_counts = df.group_by(date_col, maintain_order=True).agg( + pl.count().alias("count") + ) + return group_counts["count"].to_numpy() + + def fit(self, train_data: Dict, val_data: Dict) -> None: + """训练排序模型 + + Args: + train_data: 训练数据 + val_data: 验证数据 + """ + self.model = TabMRankModel(params=self.model_params) + + self.model.fit( + train_data["X"], + train_data["y"], + group=train_data["groups"], + eval_set=(val_data["X"], val_data["y"], val_data["groups"]) + if val_data + else None, + ) + + def predict(self, test_data: Dict) -> np.ndarray: + """生成预测 + + Args: + test_data: 测试数据 + + Returns: + 预测结果 + """ + # 传入 groups 参数,确保预测顺序与分组一致,与验证逻辑保持一致 + return self.model.predict(test_data["X"], group=test_data.get("groups")) + + def evaluate_ndcg( + self, + test_data: Dict, + k_list: List[int] = None, + ) -> Dict[str, float]: + """评估 NDCG@k + + Args: + test_data: 测试数据 + k_list: k 值列表,默认 [1, 5, 10, 20] + + Returns: + NDCG 分数字典 {"ndcg@1": score, ...} + """ + if k_list is None: + k_list = [1, 5, 10, 20] + + y_pred = self.predict(test_data) + groups = test_data["groups"] + + from sklearn.metrics import ndcg_score + + results = {} + + # 按 group 拆分 + start_idx = 0 + y_true_groups = [] + y_pred_groups = [] + + # 使用分位数标签 y (0-19) 作为真实相关性分数,而非原始收益率 y_raw + # 这样与模型学习目标一致,避免原始收益率中负值的影响 + y_true_array = test_data["y"].to_numpy() + + for group_size in groups: + end_idx = start_idx + group_size + y_true_groups.append(y_true_array[start_idx:end_idx]) + y_pred_groups.append(y_pred[start_idx:end_idx]) + start_idx = end_idx + + # 计算每个 k 的 NDCG + for k in k_list: + ndcg_scores = [] + for yt, yp in zip(y_true_groups, y_pred_groups): + if len(yt) > 1: + try: + score = ndcg_score([yt], [yp], k=k) + ndcg_scores.append(score) + except ValueError: + pass + + results[f"ndcg@{k}"] = float(np.mean(ndcg_scores)) if ndcg_scores else 0.0 + + return results + + def plot_training_metrics(self) -> None: + """绘制训练指标曲线(NDCG)""" + if self.model and hasattr(self.model, "get_evals_result"): + try: + import matplotlib.pyplot as plt + + evals_result = self.model.get_evals_result() + if not evals_result: + print("[警告] 没有训练指标数据可供绘制") + return + + fig, ax = plt.subplots(1, 2, figsize=(12, 4)) + + # 绘制训练损失 + if "train_loss" in evals_result: + ax[0].plot(evals_result["train_loss"], label="Train Loss") + ax[0].set_xlabel("Epoch") + ax[0].set_ylabel("ListNet Loss") + ax[0].set_title("Training Loss") + ax[0].legend() + ax[0].grid(True) + + # 绘制验证 NDCG + if "val_ndcg" in evals_result: + ax[1].plot(evals_result["val_ndcg"], label="Val NDCG") + ax[1].set_xlabel("Epoch") + ax[1].set_ylabel("NDCG") + ax[1].set_title("Validation NDCG") + ax[1].legend() + ax[1].grid(True) + + plt.tight_layout() + plt.show() + except Exception as e: + print(f"[警告] 无法绘制训练曲线: {e}") diff --git a/tests/test_tabm_rank_model.py b/tests/test_tabm_rank_model.py new file mode 100644 index 0000000..8f039ef --- /dev/null +++ b/tests/test_tabm_rank_model.py @@ -0,0 +1,88 @@ +"""测试 TabMRankModel 基础功能""" + +import pytest +import numpy as np +import polars as pl +import torch + +from src.training.components.models import TabMRankModel + + +class TestTabMRankModel: + """TabMRankModel 测试类""" + + def test_model_initialization(self): + """测试模型初始化""" + model = TabMRankModel(params={"ensemble_size": 16}) + assert model.name == "tabm_rank" + assert model.params["ensemble_size"] == 16 + assert model.device in [torch.device("cpu"), torch.device("cuda")] + + def test_prepare_group_from_dates(self): + """测试从日期生成 group 数组""" + df = pl.DataFrame( + { + "trade_date": [ + "20240101", + "20240101", + "20240102", + "20240102", + "20240102", + ], + "value": [1, 2, 3, 4, 5], + } + ) + group = TabMRankModel.prepare_group_from_dates(df) + assert np.array_equal(group, np.array([2, 3])) + + def test_convert_to_quantile_labels(self): + """测试转换分位数标签""" + df = pl.DataFrame( + { + "trade_date": ["20240101"] * 5 + ["20240102"] * 5, + "return": [0.1, 0.05, 0.0, -0.05, -0.1] + [0.2, 0.1, 0.0, -0.1, -0.2], + } + ) + result = TabMRankModel.convert_to_quantile_labels(df, "return", n_quantiles=5) + assert "return_rank" in result.columns + assert result["return_rank"].dtype == pl.Int64 + + def test_save_load(self, tmp_path): + """测试模型保存和加载(跳过实际训练)""" + params = { + "ensemble_size": 8, + "n_blocks": 2, + "d_block": 64, + "epochs": 1, + } + model = TabMRankModel(params=params) + model.feature_names_ = ["feat1", "feat2"] + + # 模拟训练历史 + model.training_history_["train_loss"] = [0.5, 0.4] + model.training_history_["val_ndcg"] = [0.3, 0.35] + + # 测试保存前需要初始化模型 + # 这里只测试元数据保存 + save_path = tmp_path / "test_model" + import pickle + + with open(save_path.with_suffix(".meta"), "wb") as f: + pickle.dump( + { + "params": model.params, + "feature_names": model.feature_names_, + "training_history": model.training_history_, + "device": str(model.device), + }, + f, + ) + + # 测试加载元数据 + loaded_model = TabMRankModel.load(save_path) + assert loaded_model.params == params + assert loaded_model.feature_names_ == ["feat1", "feat2"] + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])