feat(training): add TabM learning-to-rank model support and streamline the training flow

- Add TabMRankModel, TabMRankTask, and their accompanying loss functions and configuration
- Move DataQualityAnalyzer from the experiment module to the training module
- Remove the overly aggressive NaN/null hard-fill logic from the data processors
- Switch the RankTask evaluation metric to quantile labels instead of raw returns
- Update processor order and model hyperparameter configuration in the experiment scripts

2026-04-04 22:39:58 +08:00
parent 9e7d4241c6
commit a66d5e9db3
16 changed files with 1663 additions and 344 deletions


@@ -0,0 +1,126 @@
# TabM Top-K Optimization Guide
For the Top-K stock-picking scenario of buying only the 5-10 highest-scoring stocks each day, TabM learning-to-rank was optimized along three directions.
## Optimization Overview
### 1. Loss Function Optimization (core)
**Weighted ListNet (recommended first attempt)**
- Idea: give samples with high labels a larger loss weight
- Parameters: `loss_type="weighted_listnet"`, `topk_weight=5.0`
- Effect: head samples carry 5x the weight of tail samples, forcing the model to focus on the top of the ranking
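The bullets above can be sketched in plain NumPy. This is a minimal single-member illustration of the math, not the project's implementation: the actual `EnsembleListNetLoss` in `tabm_rank_model.py` works on torch tensors with an extra ensemble dimension.

```python
import numpy as np

def weighted_listnet_loss(preds, targets, topk_weight=5.0):
    """Weighted ListNet for one day's cross-section (single ensemble member).

    preds, targets: 1-D float arrays of the same length.
    """
    # Target distribution: softmax over the true labels
    t = np.exp(targets - targets.max())
    p_true = t / t.sum()
    # Per-sample weights, linearly interpolated into [1, topk_weight]
    span = targets.max() - targets.min()
    if topk_weight > 1.0 and span > 0:
        w = 1.0 + (targets - targets.min()) / span * (topk_weight - 1.0)
        w = w * len(targets) / w.sum()  # renormalize to sum to N
        p_true = p_true * w
    # Cross-entropy against the predicted log-softmax
    log_p_pred = preds - preds.max()
    log_p_pred = log_p_pred - np.log(np.exp(log_p_pred).sum())
    return float(-(p_true * log_p_pred).sum())
```

Predictions that agree with the labels give a lower loss than predictions that invert them, and the weighting makes mistakes on high-label samples cost more.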
**LambdaLoss (fine-grained Top-K optimization)**
- Idea: weight each sample pair by DeltaNDCG, the change in NDCG from swapping the pair's positions
- Parameters: `loss_type="lambda"`, `lambda_sigma=1.0`, `ndcg_weight_power=1.5`
- Effect: optimizes the NDCG@K metric directly, suited to squeezing out the best possible Top-K performance
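A NumPy sketch of this pairwise weighting, assuming a single ensemble member (the real `EnsembleLambdaLoss` vectorizes the same computation over TabM's ensemble dimension):

```python
import numpy as np

def lambda_pair_loss(scores, gains, sigma=1.0):
    """DeltaNDCG-weighted pairwise logistic loss for one day's group."""
    # Ranks induced by the current scores (1 = best), via double argsort
    ranks = scores.argsort()[::-1].argsort() + 1
    inv_log = 1.0 / np.log2(ranks + 1)
    s_diff = scores[:, None] - scores[None, :]
    g_diff = gains[:, None] - gains[None, :]
    # DeltaNDCG = |gain difference| * |positional-discount difference|
    delta_ndcg = np.abs(g_diff) * np.abs(inv_log[:, None] - inv_log[None, :])
    # Logistic pairwise loss, applied only where i truly outranks j
    pair_loss = np.log1p(np.exp(-sigma * s_diff))
    mask = (g_diff > 0).astype(float)
    weighted = pair_loss * mask * delta_ndcg
    return float(weighted.sum() / max(mask.sum(), 1.0))
```

Scores that already agree with the gains incur a much smaller loss than inverted scores, and pairs involving high-gain, high-rank samples dominate the gradient.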
### 2. Label Engineering (enhancement)
**Exponential gain transform**
- Formula: `Gain = 2^(rank/scale) - 1`
- Parameters: `label_transform="exponential"`, `label_scale=20.0`
- Effect: rank=0 → 0, rank=19 → ~0.93, rank=99 → ~29.9
- Purpose: widen the gap between high- and low-label samples, sharpening separation at the head
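As a quick check of the numbers above, the transform is simply:

```python
def exponential_gain(rank: float, scale: float = 20.0) -> float:
    """Gain = 2^(rank/scale) - 1, the transform described above."""
    return 2.0 ** (rank / scale) - 1.0
```

With `scale=20`, gains stay close to zero for most of the cross-section and only grow quickly near the top ranks, which is exactly the head emphasis the ranking loss needs.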
### 3. Recommended Configurations
```python
# Config A: mild optimization (recommended first attempt)
MODEL_PARAMS = {
    "loss_type": "weighted_listnet",
    "topk_weight": 3.0,
    # ... other parameters
}
LABEL_TRANSFORM = None  # keep standard quantile labels

# Config B: balanced optimization (trades off effect against stability)
MODEL_PARAMS = {
    "loss_type": "weighted_listnet",
    "topk_weight": 5.0,
    "ndcg_k": 20,  # watch NDCG@20 during validation
}
LABEL_TRANSFORM = "exponential"
LABEL_SCALE = 20.0

# Config C: aggressive optimization (focused on Top-10)
MODEL_PARAMS = {
    "loss_type": "lambda",
    "lambda_sigma": 1.0,
    "ndcg_weight_power": 1.5,
    "ndcg_k": 10,
}
N_QUANTILES = 50  # higher quantile resolution
LABEL_TRANSFORM = "exponential"
LABEL_SCALE = 25.0
```
## Usage Example
Modify the configuration in `tabm_rank_train.py`:
```python
# Quantile configuration
N_QUANTILES = 30

# Label engineering configuration
LABEL_TRANSFORM = "exponential"  # enable exponential gains
LABEL_SCALE = 20.0

# Model parameter configuration
MODEL_PARAMS = {
    # ... base parameters ...
    # Top-K optimization parameters
    "loss_type": "weighted_listnet",  # or "lambda"
    "topk_weight": 5.0,        # effective only for weighted_listnet
    "lambda_sigma": 1.0,       # effective only for lambda
    "ndcg_weight_power": 1.0,  # effective only for lambda
    "ndcg_k": 20,              # validation metric
}
```
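Validation inside `TabMRankModel` delegates to `sklearn.metrics.ndcg_score`; for monitoring outside the trainer, an equivalent single-query NDCG@k (linear gains, log2 discount, matching sklearn's default) can be written in plain NumPy:

```python
import numpy as np

def ndcg_at_k(y_true, y_score, k=None):
    """NDCG@k for one query: linear gains, log2 positional discount."""
    order = np.argsort(y_score)[::-1]  # indices sorted by predicted score, best first
    if k is not None:
        order = order[:k]
    discounts = 1.0 / np.log2(np.arange(2, len(order) + 2))
    dcg = float((y_true[order] * discounts).sum())
    ideal = np.sort(y_true)[::-1][: len(order)]  # best achievable ordering
    idcg = float((ideal * discounts).sum())
    return dcg / idcg if idcg > 0 else 0.0
```

A perfect ranking scores 1.0; any deviation at the top of the list is penalized more than the same deviation further down.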
## Parameter Reference
### Loss function parameters
| Parameter | Type | Default | Description |
|------|------|--------|------|
| `loss_type` | str | "listnet" | Loss type: "listnet" / "weighted_listnet" / "lambda" |
| `topk_weight` | float | 5.0 | Head-weight factor (weighted_listnet); larger values focus more on the head |
| `lambda_sigma` | float | 1.0 | Sigmoid steepness (lambda) |
| `ndcg_weight_power` | float | 1.0 | DeltaNDCG weight exponent (lambda); >1 further amplifies the head |
| `ndcg_k` | int/None | None | k for NDCG@k during validation; None means global |
### Label engineering parameters
| Parameter | Type | Default | Description |
|------|------|--------|------|
| `n_quantiles` | int | 20 | Number of quantiles; larger means finer resolution |
| `label_transform` | str/None | None | Transform type: None / "exponential" |
| `label_scale` | float | 20.0 | Scale factor of the exponential transform; controls the gain magnitude |
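For context on `n_quantiles`: the rank task bins each day's raw returns into integer quantile labels before the ranking loss sees them. One common construction is the following sketch (hypothetical; `TabMRankTask`'s exact binning may differ):

```python
import numpy as np

def quantile_labels(returns, n_quantiles=20):
    """Map one day's raw returns to integer quantile labels 0..n_quantiles-1.

    Higher return -> higher label (a common convention for ranking gains).
    """
    ranks = returns.argsort().argsort()  # ascending ranks, 0 = worst return
    return (ranks * n_quantiles // len(returns)).astype(int)
```

With 100 stocks and `n_quantiles=20`, each label bucket holds five stocks; raising `n_quantiles` to 50 shrinks the buckets and lets the loss distinguish finer gradations at the head.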
## Implementation Advice
1. **Optimize incrementally**
   - Round 1: enable only `loss_type="weighted_listnet"`, `topk_weight=3.0`
   - Round 2: add label engineering with `label_transform="exponential"`
   - Round 3: try `loss_type="lambda"` for fine-grained optimization
2. **Metrics to monitor**
   - Watch NDCG@K (set K to your actual Top-K size)
   - Compare backtest returns across configurations
   - Check that the training loss decreases steadily
3. **Caveats**
   - LambdaLoss trains more slowly; each epoch takes longer
   - Exponential gains change the label distribution and may require retuning the learning rate
   - An overly large topk_weight can overfit the head samples
## References
1. **ListNet**: "Learning to Rank: From Pairwise Approach to Listwise Approach" (Cao et al., 2007)
2. **LambdaRank**: "From RankNet to LambdaRank to LambdaMART: An Overview" (Burges, 2010)
3. **LambdaLoss**: "The LambdaLoss Framework for Ranking Metric Optimization" (Wang et al., 2018)
4. **Deep learning stock selection**: "Deep Learning for Stock Selection" (Gu, Kelly, & Xiu, 2020)


@@ -270,6 +270,7 @@ SELECTED_FACTORS = [
"bottom_cost_stability",
"pivot_reversion",
"chip_transition",
# "amivest_liq_20",
# "atr_price_impact",
# "hui_heubel_ratio",
@@ -323,11 +324,11 @@ def get_label_factor(label_name: str) -> dict:
# 辅助函数
# =============================================================================
def register_factors(
engine: FactorEngine,
selected_factors: List[str],
factor_definitions: dict,
label_factor: dict,
excluded_factors: Optional[List[str]] = None,
engine: FactorEngine,
selected_factors: List[str],
factor_definitions: dict,
label_factor: dict,
excluded_factors: Optional[List[str]] = None,
) -> List[str]:
"""注册因子。
@@ -408,11 +409,11 @@ def register_factors(
def prepare_data(
engine: FactorEngine,
feature_cols: List[str],
start_date: str,
end_date: str,
label_name: str,
engine: FactorEngine,
feature_cols: List[str],
start_date: str,
end_date: str,
label_name: str,
) -> pl.DataFrame:
"""准备数据。
@@ -450,45 +451,6 @@ def prepare_data(
return data
# =============================================================================
# 股票池筛选配置
# =============================================================================
def stock_pool_filter(df: pl.DataFrame) -> pl.Series:
"""股票池筛选函数(单日数据)。
筛选条件:
1. 排除创业板(代码以 300 开头)
2. 排除科创板(代码以 688 开头)
3. 排除北交所(代码以 8、9 或 4 开头)
4. 选取当日市值最小的500只股票
Args:
df: 单日数据框
Returns:
布尔Series表示哪些股票被选中
"""
# 代码筛选(排除创业板、科创板、北交所)
code_filter = (
~df["ts_code"].str.starts_with("30") # 排除创业板
& ~df["ts_code"].str.starts_with("68") # 排除科创板
& ~df["ts_code"].str.starts_with("8") # 排除北交所
& ~df["ts_code"].str.starts_with("9") # 排除北交所
& ~df["ts_code"].str.starts_with("4") # 排除北交所
)
# 在已筛选的股票中选取流通市值最小的500只
valid_df = df.filter(code_filter)
n = min(1000, len(valid_df))
small_cap_codes = valid_df.sort("circ_mv").head(n)["ts_code"]
# 返回布尔 Series是否在被选中的股票中
return df["ts_code"].is_in(small_cap_codes)
# 定义筛选所需的基础列
STOCK_FILTER_REQUIRED_COLUMNS = ["circ_mv"]
# =============================================================================
# 输出配置
# =============================================================================
@@ -502,10 +464,54 @@ MODEL_SAVE_DIR = "models" # 模型保存目录
# Top N 配置:每日推荐股票数量
TOP_N = 5 # 可调整为 10, 20 等
# 股票池大小配置
STOCK_POOL_SIZE = 1000 # 股票池选择市值最小的股票数量
# 训练数据跳过天数配置
TRAIN_SKIP_DAYS = 300 # 跳过训练数据前300天的数据,避免训练初期数据不足的问题
# =============================================================================
# 股票池筛选配置
# =============================================================================
def stock_pool_filter(df: pl.DataFrame, n_stocks: int = STOCK_POOL_SIZE) -> pl.Series:
"""股票池筛选函数(单日数据)。
筛选条件:
1. 排除创业板(代码以 300 开头)
2. 排除科创板(代码以 688 开头)
3. 排除北交所(代码以 8、9 或 4 开头)
4. 选取当日市值最小的n_stocks只股票
Args:
df: 单日数据框
n_stocks: 选取的股票数量,默认为 STOCK_POOL_SIZE
Returns:
布尔Series表示哪些股票被选中
"""
# 代码筛选(排除创业板、科创板、北交所)
code_filter = (
~df["ts_code"].str.starts_with("30") # 排除创业板
& ~df["ts_code"].str.starts_with("68") # 排除科创板
& ~df["ts_code"].str.starts_with("8") # 排除北交所
& ~df["ts_code"].str.starts_with("9") # 排除北交所
& ~df["ts_code"].str.starts_with("4") # 排除北交所
)
# 在已筛选的股票中选取流通市值最小的n_stocks只
valid_df = df.filter(code_filter)
n = min(n_stocks, len(valid_df))
small_cap_codes = valid_df.sort("circ_mv").head(n)["ts_code"]
# 返回布尔 Series是否在被选中的股票中
return df["ts_code"].is_in(small_cap_codes)
# 定义筛选所需的基础列
STOCK_FILTER_REQUIRED_COLUMNS = ["circ_mv"]
def get_output_path(model_type: str, test_start: str, test_end: str) -> str:
"""生成输出文件路径。
@@ -532,7 +538,7 @@ def get_output_path(model_type: str, test_start: str, test_end: str) -> str:
def get_model_save_path(
model_type: str,
model_type: str,
) -> Optional[str]:
"""生成模型保存路径。
@@ -558,11 +564,11 @@ def get_model_save_path(
def save_model_with_factors(
model,
model_path: str,
selected_factors: list[str],
factor_definitions: dict,
fitted_processors: list | None = None,
model,
model_path: str,
selected_factors: list[str],
factor_definitions: dict,
fitted_processors: list | None = None,
) -> str:
"""保存模型及关联的因子信息和处理器。


@@ -54,27 +54,10 @@ N_QUANTILES = 20
# 排除的因子列表
EXCLUDED_FACTORS = [
'active_market_cap',
'close_vwap_deviation',
'sharpe_ratio_20',
'upper_shadow_ratio',
'volume_ratio_5_20',
'GTJA_alpha090',
'GTJA_alpha084',
'GTJA_alpha066',
'GTJA_alpha150',
'GTJA_alpha148',
'GTJA_alpha106',
'GTJA_alpha109',
'GTJA_alpha108',
'GTJA_alpha176',
'GTJA_alpha169',
'GTJA_alpha156',
'chip_dispersion_70',
'winner_rate_cs_rank',
'atr_price_impact',
'low_vol_days_20',
'liquidity_shock_momentum',
# 'debt_to_equity',
# 'GTJA_alpha016',
# 'GTJA_alpha141',
]
# LambdaRank 模型参数配置
@@ -145,8 +128,8 @@ def main():
pipeline = DataPipeline(
factor_manager=factor_manager,
processor_configs=[
(NullFiller, {"strategy": "mean"}),
(Winsorizer, {"lower": 0.01, "upper": 0.99}),
(NullFiller, {"strategy": "mean"}),
(CrossSectionalStandardScaler, {}),
],
filters=[STFilter(data_router=engine.router)],


@@ -52,36 +52,36 @@ TRAINING_TYPE = "regression"
# 排除的因子列表
EXCLUDED_FACTORS = [
'GTJA_alpha016',
'volatility_20',
'current_ratio',
'GTJA_alpha001',
'GTJA_alpha141',
'GTJA_alpha129',
'GTJA_alpha164',
'amivest_liq_20',
'GTJA_alpha012',
'debt_to_equity',
'turnover_deviation',
'GTJA_alpha073',
'GTJA_alpha043',
'GTJA_alpha032',
'GTJA_alpha028',
'GTJA_alpha090',
'GTJA_alpha108',
'GTJA_alpha105',
'GTJA_alpha091',
'GTJA_alpha119',
'GTJA_alpha104',
'GTJA_alpha163',
'GTJA_alpha157',
'cost_skewness',
'GTJA_alpha176',
'chip_transition',
'amount_skewness_20',
'GTJA_alpha148',
'mean_median_dev',
'downside_illiq_20',
# 'GTJA_alpha016',
# 'volatility_20',
# 'current_ratio',
# 'GTJA_alpha001',
# 'GTJA_alpha141',
# 'GTJA_alpha129',
# 'GTJA_alpha164',
# 'amivest_liq_20',
# 'GTJA_alpha012',
# 'debt_to_equity',
# 'turnover_deviation',
# 'GTJA_alpha073',
# 'GTJA_alpha043',
# 'GTJA_alpha032',
# 'GTJA_alpha028',
# 'GTJA_alpha090',
# 'GTJA_alpha108',
# 'GTJA_alpha105',
# 'GTJA_alpha091',
# 'GTJA_alpha119',
# 'GTJA_alpha104',
# 'GTJA_alpha163',
# 'GTJA_alpha157',
# 'cost_skewness',
# 'GTJA_alpha176',
# 'chip_transition',
# 'amount_skewness_20',
# 'GTJA_alpha148',
# 'mean_median_dev',
# 'downside_illiq_20',
]
# 模型参数配置
@@ -153,15 +153,15 @@ def main():
pipeline = DataPipeline(
factor_manager=factor_manager,
processor_configs=[
(NullFiller, {"strategy": "mean"}),
(Winsorizer, {"lower": 0.01, "upper": 0.99}),
(NullFiller, {"strategy": "mean"}),
(StandardScaler, {}),
# (CrossSectionalStandardScaler, {}),
],
label_processor_configs=[
# 对 label 进行缩尾处理(去除极端收益率)
(Winsorizer, {"lower": 0.05, "upper": 0.95}),
# (StandardScaler, {}),
(StandardScaler, {}),
],
filters=[STFilter(data_router=engine.router)],
stock_pool_filter_func=stock_pool_filter,


@@ -0,0 +1,176 @@
"""TabM learning-to-rank training flow (modular version)

Uses the new modular Trainer architecture, built on TabMRankModel.
TabM is trained with a ListNet loss and supports ensembling.
"""
import os
from src.factors import FactorEngine
from src.training import (
FactorManager,
DataPipeline,
NullFiller,
Winsorizer,
CrossSectionalStandardScaler,
)
from src.training.tasks.tabm_rank_task import TabMRankTask
from src.training.core.trainer_v2 import Trainer
from src.training.components.filters import STFilter
from src.experiment.common import (
SELECTED_FACTORS,
FACTOR_DEFINITIONS,
LABEL_NAME,
LABEL_FACTOR,
TRAIN_START,
TRAIN_END,
VAL_START,
VAL_END,
TEST_START,
TEST_END,
stock_pool_filter,
STOCK_FILTER_REQUIRED_COLUMNS,
OUTPUT_DIR,
SAVE_PREDICTIONS,
SAVE_MODEL,
get_model_save_path,
save_model_with_factors,
TOP_N,
TRAIN_SKIP_DAYS,
)
# Training type identifier
TRAINING_TYPE = "tabm_rank"
# %%
# Label configuration (imported from common.py)
# LABEL_NAME and LABEL_FACTOR are bound in common.py; just import them from there
# Quantile configuration (higher resolution to better separate the head)
N_QUANTILES = 50
# [Top-K optimization] label engineering - squared gain enabled by default
LABEL_TRANSFORM = "exponential"  # enable squared-gain labels (rank^2)
LABEL_SCALE = 20.0  # kept for compatibility (currently unused; the squared transform needs no scale)
# Excluded factors
EXCLUDED_FACTORS = ["GTJA_alpha041", "GTJA_alpha127"]
# TabM Rank model parameters (all Top-K optimizations enabled, using LambdaLoss)
MODEL_PARAMS = {
    # ==================== MLP structure ====================
    "n_blocks": 4,  # number of MLP blocks
    "d_block": 256,  # units per block
    "dropout": 0.5,  # dropout rate
    # ==================== Ensembling ====================
    "ensemble_size": 32,  # built-in ensemble size (emulates a 32-model ensemble)
    # ==================== Training ====================
    "learning_rate": 1e-4,  # learning rate
    "weight_decay": 1e-5,  # weight decay
    "epochs": 500,  # number of epochs
    # ==================== Early stopping ====================
    "early_stopping_round": 50,  # early-stopping patience
    # NDCG evaluation - focus on Top-20
    "ndcg_k": 20,  # compute NDCG@20 during validation
    # [Top-K optimization] loss configuration - LambdaLoss
    "loss_type": "lambda",  # LambdaLoss targets Top-K directly
    "lambda_sigma": 1.0,  # sigmoid steepness
    "ndcg_weight_power": 1.0,  # DeltaNDCG weight exponent; >1 further amplifies the head
}
# Date range configuration
date_range = {
    "train": (TRAIN_START, TRAIN_END),
    "val": (VAL_START, VAL_END),
    "test": (TEST_START, TEST_END),
}
# Output configuration
output_config = {
    "output_dir": OUTPUT_DIR,
    "output_filename": "tabm_rank_output.csv",
    "save_predictions": SAVE_PREDICTIONS,
    "save_model": SAVE_MODEL,
    "model_save_path": get_model_save_path(TRAINING_TYPE),
    "top_n": TOP_N,
}
def main():
    """Entry point."""
    print("\n" + "=" * 80)
    print("TabM learning-to-rank training (modular version)")
    print("=" * 80)
    # 1. Create the FactorEngine
    print("\n[1] Creating FactorEngine")
    engine = FactorEngine()
    # 2. Create the FactorManager
    print("\n[2] Creating FactorManager")
    factor_manager = FactorManager(
        selected_factors=SELECTED_FACTORS,
        factor_definitions=FACTOR_DEFINITIONS,
        label_factor=LABEL_FACTOR,
        excluded_factors=EXCLUDED_FACTORS,
    )
    # 3. Create the DataPipeline
    print("\n[3] Creating DataPipeline")
    pipeline = DataPipeline(
        factor_manager=factor_manager,
        processor_configs=[
            (Winsorizer, {"lower": 0.01, "upper": 0.99}),
            (NullFiller, {"strategy": "mean"}),
            (CrossSectionalStandardScaler, {}),
        ],
        filters=[STFilter(data_router=engine.router)],
        stock_pool_filter_func=stock_pool_filter,
        stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
        train_skip_days=TRAIN_SKIP_DAYS,
    )
    # 4. Create the TabMRankTask
    print("\n[4] Creating TabMRankTask")
    task = TabMRankTask(
        model_params=MODEL_PARAMS,
        label_name=LABEL_NAME,
        n_quantiles=N_QUANTILES,
        label_transform=LABEL_TRANSFORM,
        label_scale=LABEL_SCALE,
    )
    # 5. Create the Trainer
    print("\n[5] Creating Trainer")
    trainer = Trainer(
        data_pipeline=pipeline,
        task=task,
        output_config=output_config,
        verbose=True,
    )
    # 6. Run training
    print("\n[6] Running training")
    results = trainer.run(engine=engine, date_range=date_range)
    # 7. Save the model and factor info (if enabled)
    if SAVE_MODEL:
        print("\n[7] Saving model and factor info")
        save_model_with_factors(
            model=task.get_model(),
            model_path=output_config["model_save_path"],
            selected_factors=SELECTED_FACTORS,
            factor_definitions=FACTOR_DEFINITIONS,
            fitted_processors=pipeline.get_fitted_processors(),
        )
    print("\n" + "=" * 80)
    print("Training flow complete!")
    print(f"Results saved to: {os.path.join(OUTPUT_DIR, 'tabm_rank_output.csv')}")
    print("=" * 80)
    return results


if __name__ == "__main__":
    main()


@@ -39,9 +39,8 @@ from src.experiment.common import (
get_model_save_path,
save_model_with_factors,
TOP_N,
TRAIN_SKIP_DAYS,
TRAIN_SKIP_DAYS, STOCK_POOL_SIZE,
)
from src.experiment.data_quality_analyzer import DataQualityAnalyzer
# 训练类型标识
TRAINING_TYPE = "tabm_regression"
@@ -54,201 +53,7 @@ TRAINING_TYPE = "tabm_regression"
# 排除的因子列表(与 LightGBM 回归保持一致)
EXCLUDED_FACTORS = [
# "GTJA_alpha001",
# "GTJA_alpha002",
# "GTJA_alpha003",
# "GTJA_alpha004",
# "GTJA_alpha005",
# "GTJA_alpha006",
# "GTJA_alpha007",
# "GTJA_alpha008",
# "GTJA_alpha009",
# "GTJA_alpha010",
# "GTJA_alpha011",
# "GTJA_alpha012",
# "GTJA_alpha013",
# "GTJA_alpha014",
# "GTJA_alpha015",
# "GTJA_alpha016",
# "GTJA_alpha017",
# "GTJA_alpha018",
# "GTJA_alpha019",
# "GTJA_alpha020",
# "GTJA_alpha022",
# "GTJA_alpha023",
# "GTJA_alpha024",
# "GTJA_alpha025",
# "GTJA_alpha026",
# "GTJA_alpha027",
# "GTJA_alpha028",
# "GTJA_alpha029",
# "GTJA_alpha031",
# "GTJA_alpha032",
# "GTJA_alpha033",
# "GTJA_alpha034",
# "GTJA_alpha035",
# "GTJA_alpha036",
# "GTJA_alpha037",
# # "GTJA_alpha038",
# "GTJA_alpha039",
# "GTJA_alpha040",
# "GTJA_alpha041",
# "GTJA_alpha042",
# "GTJA_alpha043",
# "GTJA_alpha044",
# "GTJA_alpha045",
# "GTJA_alpha046",
# "GTJA_alpha047",
# "GTJA_alpha048",
# "GTJA_alpha049",
# "GTJA_alpha050",
# "GTJA_alpha051",
# "GTJA_alpha052",
# "GTJA_alpha053",
# "GTJA_alpha054",
# "GTJA_alpha056",
# "GTJA_alpha057",
# "GTJA_alpha058",
# "GTJA_alpha059",
# "GTJA_alpha060",
# "GTJA_alpha061",
# "GTJA_alpha062",
# "GTJA_alpha063",
# "GTJA_alpha064",
# "GTJA_alpha065",
# "GTJA_alpha066",
# "GTJA_alpha067",
# "GTJA_alpha068",
# "GTJA_alpha070",
# "GTJA_alpha071",
# "GTJA_alpha072",
# "GTJA_alpha073",
# "GTJA_alpha074",
# "GTJA_alpha076",
# "GTJA_alpha077",
# "GTJA_alpha078",
# "GTJA_alpha079",
# "GTJA_alpha080",
# "GTJA_alpha081",
# "GTJA_alpha082",
# "GTJA_alpha083",
# "GTJA_alpha084",
# "GTJA_alpha085",
# "GTJA_alpha086",
# "GTJA_alpha087",
# "GTJA_alpha088",
# "GTJA_alpha089",
# "GTJA_alpha090",
# "GTJA_alpha091",
# "GTJA_alpha092",
# "GTJA_alpha093",
# "GTJA_alpha094",
# "GTJA_alpha095",
# "GTJA_alpha096",
# "GTJA_alpha097",
# "GTJA_alpha098",
# "GTJA_alpha099",
# "GTJA_alpha100",
# "GTJA_alpha101",
# "GTJA_alpha102",
# "GTJA_alpha103",
# "GTJA_alpha104",
# "GTJA_alpha105",
# "GTJA_alpha106",
# "GTJA_alpha107",
# "GTJA_alpha108",
# "GTJA_alpha109",
# "GTJA_alpha110",
# "GTJA_alpha111",
# "GTJA_alpha112",
# # "GTJA_alpha113",
# "GTJA_alpha114",
# "GTJA_alpha115",
# "GTJA_alpha117",
# "GTJA_alpha118",
# "GTJA_alpha119",
# "GTJA_alpha120",
# # "GTJA_alpha121",
# "GTJA_alpha122",
# "GTJA_alpha123",
# "GTJA_alpha124",
# "GTJA_alpha125",
# "GTJA_alpha126",
# "GTJA_alpha127",
# "GTJA_alpha128",
# "GTJA_alpha129",
# "GTJA_alpha130",
# "GTJA_alpha131",
# "GTJA_alpha132",
# "GTJA_alpha133",
# "GTJA_alpha134",
# "GTJA_alpha135",
# "GTJA_alpha136",
# # "GTJA_alpha138",
# "GTJA_alpha139",
# # "GTJA_alpha140",
# "GTJA_alpha141",
# "GTJA_alpha142",
# "GTJA_alpha145",
# # "GTJA_alpha146",
# "GTJA_alpha148",
# "GTJA_alpha150",
# "GTJA_alpha151",
# "GTJA_alpha152",
# "GTJA_alpha153",
# "GTJA_alpha154",
# "GTJA_alpha155",
# "GTJA_alpha156",
# "GTJA_alpha157",
# "GTJA_alpha158",
# "GTJA_alpha159",
# "GTJA_alpha160",
# "GTJA_alpha161",
# "GTJA_alpha162",
# "GTJA_alpha163",
# "GTJA_alpha164",
# # "GTJA_alpha165",
# "GTJA_alpha166",
# "GTJA_alpha167",
# "GTJA_alpha168",
# "GTJA_alpha169",
# "GTJA_alpha170",
# "GTJA_alpha171",
# "GTJA_alpha173",
# "GTJA_alpha174",
# "GTJA_alpha175",
# "GTJA_alpha176",
# "GTJA_alpha177",
# "GTJA_alpha178",
# "GTJA_alpha179",
# "GTJA_alpha180",
# # "GTJA_alpha183",
# "GTJA_alpha184",
# "GTJA_alpha185",
# "GTJA_alpha187",
# "GTJA_alpha188",
# "GTJA_alpha189",
# "GTJA_alpha191",
# "chip_dispersion_90",
# "chip_dispersion_70",
# "cost_skewness",
# "dispersion_change_20",
# "price_to_avg_cost",
# "price_to_median_cost",
# "mean_median_dev",
# "trap_pressure",
# "bottom_profit",
# "history_position",
# "winner_rate_surge_5",
# "winner_rate_cs_rank",
# "winner_rate_dev_20",
# "winner_rate_volatility",
# "smart_money_accumulation",
# "winner_vol_corr_20",
# "cost_base_momentum",
# "bottom_cost_stability",
# "pivot_reversion",
# "chip_transition",
]
# TabM 模型参数配置(来自用户提供的示例代码)
@@ -256,11 +61,11 @@ MODEL_PARAMS = {
# ==================== MLP 结构 ====================
"n_blocks": 3, # MLP 层数
"d_block": 256, # 每层神经元数
"dropout": 0.3, # Dropout 率
"dropout": 0.5, # Dropout 率
# ==================== 集成机制 ====================
"ensemble_size": 32, # 内置集成大小(模拟 32 个模型集成)
# ==================== 训练参数 ====================
"batch_size": 2048, # 批次大小
"batch_size": STOCK_POOL_SIZE * 5, # 批次大小
"learning_rate": 1e-3, # 学习率
"weight_decay": 1e-5, # 权重衰减
"epochs": 100, # 训练轮数
@@ -312,13 +117,14 @@ def main():
pipeline = DataPipeline(
factor_manager=factor_manager,
processor_configs=[
(NullFiller, {"strategy": "mean"}),
(Winsorizer, {"lower": 0.01, "upper": 0.99}), # 先缩尾处理厚尾分布
(NullFiller, {"strategy": "mean"}),
(StandardScaler, {}), # TabM 需要标准化输入
],
label_processor_configs=[
# 对 label 进行缩尾处理(去除极端收益率)
(Winsorizer, {"lower": 0.05, "upper": 0.95}),
(Winsorizer, {"lower": 0.01, "upper": 0.99}),
(StandardScaler, {}),
],
filters=[STFilter(data_router=engine.router)],
stock_pool_filter_func=stock_pool_filter,


@@ -43,6 +43,12 @@ from src.training.core import StockPoolManager, Trainer
# 工具函数
from src.training.utils import check_data_quality
# 数据质量分析器
from src.training.data_quality_analyzer import (
DataQualityAnalyzer,
analyze_data_quality,
)
# 配置
from src.training.config import TrainingConfig
@@ -85,6 +91,9 @@ __all__ = [
"Trainer",
# 工具函数
"check_data_quality",
# 数据质量分析器
"DataQualityAnalyzer",
"analyze_data_quality",
# 配置
"TrainingConfig",
# 新增:模块化 Trainer 组件(推荐使用)


@@ -7,6 +7,11 @@ from src.training.components.models.lightgbm import LightGBMModel
from src.training.components.models.lightgbm_lambdarank import LightGBMLambdaRankModel
from src.training.components.models.tabpfn_model import TabPFNModel
from src.training.components.models.tabm_model import TabMModel
from src.training.components.models.tabm_rank_model import (
TabMRankModel,
EnsembleListNetLoss,
EnsembleLambdaLoss,
)
from src.training.components.models.cross_section_sampler import CrossSectionSampler
from src.training.components.models.ensemble_quant_loss import EnsembleQuantLoss
@@ -15,6 +20,9 @@ __all__ = [
"LightGBMLambdaRankModel",
"TabPFNModel",
"TabMModel",
"TabMRankModel",
"EnsembleListNetLoss",
"EnsembleLambdaLoss",
"CrossSectionSampler",
"EnsembleQuantLoss",
]


@@ -0,0 +1,747 @@
"""TabM ranking model (TabM Rank)

Built on the TabM (Tabular Multilayer Perceptron with Ensembles) architecture,
with a listwise ListNet ranking loss for LambdaRank-style cross-sectional
learning-to-rank. Intended for ranking stocks by predicted future returns
within each cross-section.
"""
from typing import Dict, Any, List, Optional, Tuple
from pathlib import Path
import pickle

import numpy as np
import polars as pl
import scipy.stats as stats
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, Sampler
from tabm import TabM

from src.training.components.base import BaseModel
from src.training.registry import register_model
class GroupSampler(Sampler):
    """Group-aware sampler for learning-to-rank.

    Ensures that each batch contains all samples of a single query
    (e.g. one trading day), matching the semantics of LightGBM's
    `group` parameter.
    """

    def __init__(self, group_counts: np.ndarray, shuffle_groups: bool = True):
        """Initialize the group sampler.

        Args:
            group_counts: Array of per-group sample counts, e.g. [100, 150, 120]
            shuffle_groups: Whether to shuffle the training order of the groups
        """
        self.group_counts = group_counts
        self.shuffle_groups = shuffle_groups
        # Start/end boundary of each group within the flat sample array
        self.boundaries = np.insert(np.cumsum(group_counts), 0, 0)
        self.num_groups = len(group_counts)

    def __iter__(self):
        """Yield batches of indices.

        Yields:
            list: All sample indices belonging to one group
        """
        group_indices = list(range(self.num_groups))
        if self.shuffle_groups:
            np.random.shuffle(group_indices)
        for g_idx in group_indices:
            start = self.boundaries[g_idx]
            end = self.boundaries[g_idx + 1]
            # All indices of this group (day) form one complete batch
            yield list(range(start, end))

    def __len__(self):
        """Return the number of groups."""
        return self.num_groups
class EnsembleListNetLoss(nn.Module):
    """Ensemble ListNet loss (listwise ranking loss).

    Uses cross-entropy to turn cross-sectional relevance into a probability
    distribution. Supports TabM's ensemble dimension.
    """

    def __init__(self, topk_weight: float = 1.0):
        """Initialize the ListNet loss.

        Args:
            topk_weight: Head-sample weight factor; values > 1.0 emphasize
                high-label samples
        """
        super().__init__()
        self.topk_weight = topk_weight

    def forward(self, preds: torch.Tensor, targets: torch.Tensor):
        """Compute the ListNet ranking loss.

        Args:
            preds: [BatchSize, EnsembleSize] - predicted logits for all samples in the group
            targets: [BatchSize] - true ranking labels of the group (0, 1, 2, ...)

        Returns:
            Scalar loss
        """
        # Too few samples in this group to rank; return 0
        if preds.size(0) <= 1:
            return torch.tensor(0.0, requires_grad=True, device=preds.device)
        # [BatchSize] -> [BatchSize, EnsembleSize]: broadcast to every ensemble member
        targets_expanded = targets.unsqueeze(1).expand_as(preds)
        # 1. Target probability distribution over the labels
        #    (softmax amplifies the weight of high labels)
        targets_prob = F.softmax(targets_expanded, dim=0)
        # 2. [Top-K optimization] optionally upweight high-label samples
        if self.topk_weight > 1.0:
            # Label-based weights (higher label, larger weight),
            # normalized into the range [1, topk_weight]
            max_target = targets.max()
            min_target = targets.min()
            if max_target > min_target:
                # Linear interpolation: w = 1 + (t - min) / (max - min) * (topk_weight - 1)
                sample_weights = 1.0 + (targets - min_target) / (
                    max_target - min_target
                ) * (self.topk_weight - 1.0)
                # Renormalize so the weights sum to the number of samples
                # (done before expanding, so the sum holds per ensemble member)
                sample_weights = sample_weights * len(targets) / sample_weights.sum()
                sample_weights = sample_weights.unsqueeze(1).expand_as(preds)
                # Apply the weights to the target probabilities
                targets_prob = targets_prob * sample_weights
        # 3. Predicted log-probabilities (log_softmax is numerically
        #    more stable than log(softmax))
        preds_log_prob = F.log_softmax(preds, dim=0)
        # 4. Cross-entropy, summed over the batch/group dimension
        loss = -torch.sum(targets_prob * preds_log_prob, dim=0)  # [EnsembleSize]
        # 5. Average over all ensemble members
        return loss.mean()
class EnsembleLambdaLoss(nn.Module):
    """Ensemble LambdaLoss (supports TabM's ensemble dimension).

    A pairwise ranking loss weighted by DeltaNDCG.
    Reference: "The LambdaLoss Framework for Ranking Metric Optimization"
    (Google Research).

    Properties:
    - Weights each pair by the NDCG impact of swapping its two samples
    - Gives higher weight to head samples (high gain and high rank)
    - Better suited to Top-K stock selection
    """

    def __init__(self, sigma: float = 1.0, ndcg_weight_power: float = 1.0):
        """Initialize the LambdaLoss.

        Args:
            sigma: Steepness of the sigmoid; controls gradient magnitude
            ndcg_weight_power: Exponent on the DeltaNDCG weight; >1 further
                amplifies the head
        """
        super().__init__()
        self.sigma = sigma
        self.ndcg_weight_power = ndcg_weight_power

    def forward(self, preds: torch.Tensor, targets: torch.Tensor):
        """Compute the LambdaLoss.

        Args:
            preds: [BatchSize, EnsembleSize] - predicted scores
            targets: [BatchSize] - relevance labels (gains)

        Returns:
            Scalar loss
        """
        if preds.size(0) <= 1:
            return torch.tensor(0.0, requires_grad=True, device=preds.device)
        # 1. Pairwise differences
        preds_diff = preds.unsqueeze(1) - preds.unsqueeze(0)  # [B, B, E]
        # Performance note: target_diff needs no ensemble dimension; keep it [B, B]
        target_diff = targets.unsqueeze(1) - targets.unsqueeze(0)  # [B, B]
        # 2. Mask as [B, B, 1] for later broadcasting
        mask = (target_diff > 0).float().unsqueeze(2)  # [B, B, 1]
        # 3. DeltaNDCG
        # DeltaNDCG = |Gain_i - Gain_j| * |1/log(rank_i+1) - 1/log(rank_j+1)|
        with torch.no_grad():
            # Double argsort computes ranks with no Python loop:
            # the first argsort yields descending-order indices,
            # the second turns those indices into ranks (+1 for 1-based ranks)
            ranks = preds.argsort(dim=0, descending=True).argsort(dim=0) + 1
            ranks = ranks.float()  # [B, E]
            # Positional discount (log2 of rank)
            log_rank = torch.log2(ranks + 1)
            inv_log_rank_diff = torch.abs(
                1.0 / log_rank.unsqueeze(1) - 1.0 / log_rank.unsqueeze(0)
            )  # [B, B, E]
            # DeltaNDCG = |gain difference| * |discount difference|
            # target_diff is [B, B]; unsqueeze broadcasts it as [B, B, 1]
            delta_ndcg = torch.abs(target_diff).unsqueeze(2) * inv_log_rank_diff
            # Optionally reshape the weight distribution with a power
            if self.ndcg_weight_power != 1.0:
                delta_ndcg = torch.pow(delta_ndcg, self.ndcg_weight_power)
        # 4. Pairwise logistic loss, weighted by DeltaNDCG
        # loss = delta_ndcg * log(1 + exp(-sigma * (preds_i - preds_j)))
        pairwise_loss = F.binary_cross_entropy_with_logits(
            self.sigma * preds_diff,
            torch.ones_like(preds_diff),
            reduction="none",
        )
        # 5. Apply the mask and weights
        weighted_loss = pairwise_loss * mask * delta_ndcg
        # Guard against division by zero when no valid pairs exist
        valid_pairs_count = mask.sum().clamp(min=1.0)
        mean_loss = weighted_loss.sum() / valid_pairs_count
        return mean_loss
@register_model("tabm_rank")
class TabMRankModel(BaseModel):
    """TabM learning-to-rank model.

    A ranking model built on the TabM architecture with a ListNet loss.
    Intended for cross-sectional stock ranking; future returns are converted
    to quantile labels for training.

    Properties:
    - Listwise ListNet ranking loss
    - Grouped training via a `group` parameter
    - NDCG as the validation metric
    - Interface-compatible with LightGBMLambdaRank
    """

    name = "tabm_rank"

    def __init__(self, params: Optional[Dict[str, Any]] = None):
        """Initialize the TabM Rank model.

        Args:
            params: Parameter dict with keys:
                - ensemble_size: Ensemble size (default: 32)
                - n_blocks: Number of MLP blocks (default: 3)
                - d_block: Units per block (default: 256)
                - dropout: Dropout rate (default: 0.1)
                - batch_size: Batch size (default: 2048; used only for prediction)
                - learning_rate: Learning rate (default: 1e-3)
                - weight_decay: Weight decay (default: 1e-5)
                - epochs: Training epochs (default: 50)
                - early_stopping_round: Early-stopping patience (default: 10)
                - max_grad_norm: Gradient-clipping threshold (default: 1.0)
                - ndcg_k: k for NDCG@k; None means global (default: None)
                - loss_type: Loss type (default: "listnet")
                    - "listnet": standard ListNet loss
                    - "weighted_listnet": weighted ListNet; topk_weight emphasizes the head
                    - "lambda": LambdaLoss, weighted by DeltaNDCG
                - topk_weight: Head-weight factor for weighted_listnet (default: 5.0)
                - lambda_sigma: Sigma for LambdaLoss (default: 1.0)
                - ndcg_weight_power: DeltaNDCG weight exponent (default: 1.0)
        """
        self.params = params or {}
        self.model = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.training_history_: Dict[str, List[float]] = {
            "train_loss": [],
            "val_ndcg": [],
        }
        self.feature_names_: Optional[List[str]] = None
        # Select the loss function from the configuration
        loss_type = self.params.get("loss_type", "listnet")
        if loss_type == "lambda":
            self.criterion = EnsembleLambdaLoss(
                sigma=self.params.get("lambda_sigma", 1.0),
                ndcg_weight_power=self.params.get("ndcg_weight_power", 1.0),
            )
        elif loss_type == "weighted_listnet":
            self.criterion = EnsembleListNetLoss(
                topk_weight=self.params.get("topk_weight", 5.0)
            )
        else:  # "listnet"
            self.criterion = EnsembleListNetLoss(topk_weight=1.0)
    def _make_loader(
        self,
        X: np.ndarray,
        y: Optional[np.ndarray] = None,
        group: Optional[np.ndarray] = None,
        shuffle_groups: bool = False,
    ) -> DataLoader:
        """Create a DataLoader (packs samples by query/group cross-section).

        Args:
            X: Feature array [N, n_features]
            y: Label array [N], or None
            group: Array of per-group sample counts
            shuffle_groups: Whether to shuffle group order

        Returns:
            DataLoader instance
        """
        # Performance note: when GPU memory allows, push the full dataset to the
        # device up front to avoid per-batch PCIe transfers during training
        X_tensor = torch.from_numpy(X).to(self.device)
        if y is not None:
            y_tensor = torch.from_numpy(y).to(self.device)
            dataset = TensorDataset(X_tensor, y_tensor)
        else:
            dataset = TensorDataset(X_tensor)
        if group is not None:
            # Training/validation: each batch is exactly one query via GroupSampler
            sampler = GroupSampler(group, shuffle_groups=shuffle_groups)
            return DataLoader(dataset, batch_sampler=sampler)
        else:
            # Prediction without groups falls back to plain batched inference
            batch_size = self.params.get("batch_size", 2048)
            return DataLoader(dataset, batch_size=batch_size, shuffle=False)
    def _validate_ndcg(self, val_loader: DataLoader, k: Optional[int] = None) -> float:
        """Validate the model with the NDCG ranking metric.

        Args:
            val_loader: Validation data loader
            k: k for NDCG@k; None computes global NDCG

        Returns:
            Mean NDCG score
        """
        from sklearn.metrics import ndcg_score

        assert self.model is not None, "Model not trained; cannot validate"
        self.model.eval()
        ndcg_list = []
        with torch.no_grad():
            for batch in val_loader:
                if len(batch) != 2:
                    continue
                bx, by = batch
                bx = bx.to(self.device)
                by = by.cpu().numpy()
                if len(by) <= 1:
                    continue
                outputs = self.model(bx)  # [B, E, 1]
                preds = outputs.mean(dim=1).squeeze(-1).cpu().numpy()  # [B]
                try:
                    # ndcg_score expects 2-D arrays of shape (1, n_samples)
                    score = ndcg_score([by], [preds], k=k)
                    ndcg_list.append(score)
                except ValueError:
                    pass
        return float(np.mean(ndcg_list)) if len(ndcg_list) > 0 else 0.0
def fit(
self,
X: pl.DataFrame,
y: pl.Series,
group: Optional[np.ndarray] = None,
eval_set: Optional[Tuple] = None,
) -> "TabMRankModel":
"""训练排序模型
Args:
X: 训练特征DataFrame
y: 训练标签 (Polars Series),应为分位数标签 (0, 1, 2, ...)
group: 分组数组,表示每个 query 的样本数
eval_set: 验证集元组 (X_val, y_val, group_val),用于早停
Returns:
self (支持链式调用)
Raises:
ValueError: group 参数无效
"""
self.feature_names_ = list(X.columns)
X_np = X.to_numpy().astype(np.float32)
y_np = y.to_numpy().astype(np.float32)
# 检查和处理 group 参数
if group is None:
group = np.array([len(y_np)])
if group.sum() != len(y_np):
raise ValueError(
f"group 数组的和 ({group.sum()}) 必须等于样本数 ({len(y_np)})"
)
train_loader = self._make_loader(X_np, y_np, group=group, shuffle_groups=True)
val_loader = None
if eval_set is not None:
X_val, y_val, group_val = eval_set
X_val_np = (
X_val.to_numpy().astype(np.float32)
if isinstance(X_val, pl.DataFrame)
else X_val
)
y_val_np = (
y_val.to_numpy().astype(np.float32)
if isinstance(y_val, pl.Series)
else y_val
)
if group_val is None:
group_val = np.array([len(y_val_np)])
val_loader = self._make_loader(
X_val_np, y_val_np, group=group_val, shuffle_groups=False
)
ensemble_size = self.params.get("ensemble_size", 32)
n_features = X_np.shape[1]
# 初始化 TabM 模型
self.model = TabM.make(
n_num_features=n_features,
cat_cardinalities=[],
d_out=1,
n_blocks=self.params.get("n_blocks", 3),
d_block=self.params.get("d_block", 256),
dropout=self.params.get("dropout", 0.1),
k=ensemble_size,
).to(self.device)
optimizer = optim.AdamW(
self.model.parameters(),
lr=self.params.get("learning_rate", 1e-3),
weight_decay=self.params.get("weight_decay", 1e-5),
)
epochs = self.params.get("epochs", 50)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=epochs, eta_min=1e-6
)
early_stopping_patience = self.params.get(
"early_stopping_patience",
self.params.get("early_stopping_round", 10),
)
best_val_ndcg = -float("inf")
patience_counter = 0
best_model_state = None
ndcg_k = self.params.get("ndcg_k", None) # None 表示计算全局 NDCG
print(f"[TabMRank] 开始训练... 设备: {self.device}, 集成大小: {ensemble_size}")
for epoch in range(epochs):
# 训练阶段
self.model.train()
train_loss = 0.0
n_train_batches = 0
for batch in train_loader:
if len(batch) != 2:
continue
bx, by = batch[0], batch[1]
optimizer.zero_grad()
outputs = self.model(bx) # [B, E, 1]
outputs_squeezed = outputs.squeeze(-1) # [B, E]
# 计算 ListNet 排序损失
loss = self.criterion(outputs_squeezed, by)
loss.backward()
torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
max_norm=self.params.get("max_grad_norm", 1.0),
)
optimizer.step()
train_loss += loss.item()
n_train_batches += 1
avg_train_loss = train_loss / max(n_train_batches, 1)
self.training_history_["train_loss"].append(avg_train_loss)
# 验证阶段 (基于 NDCG)
if val_loader is not None:
val_ndcg = self._validate_ndcg(val_loader, k=ndcg_k)
self.training_history_["val_ndcg"].append(val_ndcg)
if val_ndcg > best_val_ndcg:
best_val_ndcg = val_ndcg
patience_counter = 0
best_model_state = {
k: v.cpu().clone() for k, v in self.model.state_dict().items()
}
else:
patience_counter += 1
if (epoch + 1) % 5 == 0 or epoch == 0:
print(
f"[TabMRank] Epoch {epoch + 1}/{epochs} | "
f"Train Loss (ListNet): {avg_train_loss:.4f} | "
f"Val NDCG: {val_ndcg:.4f} (Best: {best_val_ndcg:.4f})"
)
if patience_counter >= early_stopping_patience:
print(f"[TabMRank] 触发早停,停止于 epoch {epoch + 1}")
break
else:
if (epoch + 1) % 5 == 0 or epoch == 0:
print(
f"[TabMRank] Epoch {epoch + 1}/{epochs} | "
f"Train Loss: {avg_train_loss:.4f}"
)
scheduler.step()
# 恢复最佳权重
if best_model_state is not None:
self.model.load_state_dict(best_model_state)
print(f"[TabMRank] 已恢复最佳模型权重 (Val NDCG: {best_val_ndcg:.4f})")
return self
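上面训练循环中基于 Val NDCG 的早停逻辑,可以抽象成如下纯 Python 草图(`metrics`、`patience` 均为示例假设,仅演示"最佳值刷新 + 耐心计数"这一机制):

```python
def early_stop_epoch(metrics, patience):
    """返回触发早停的 epoch 序号(1-based);未触发则返回 None。

    metrics: 每个 epoch 的验证指标(越大越好),对应上文的 Val NDCG。
    """
    best = float("-inf")
    counter = 0
    for epoch, m in enumerate(metrics, start=1):
        if m > best:  # 刷新最佳指标,重置耐心计数
            best = m
            counter = 0
        else:  # 未提升,累计耐心
            counter += 1
            if counter >= patience:
                return epoch
    return None
```

以 `[0.30, 0.35, 0.34, 0.33, 0.32]` 为例:最佳值出现在第 2 个 epoch,其后连续 3 个 epoch 未刷新,于第 5 个 epoch 触发早停。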
def predict(
self, X: pl.DataFrame, group: Optional[np.ndarray] = None
) -> np.ndarray:
"""预测排序分数
Args:
X: 特征矩阵 (Polars DataFrame)
group: 分组数组,表示每个 query 的样本数。
如果提供,将使用 GroupSampler 确保预测顺序与分组一致。
Returns:
预测分数 (numpy ndarray)
Raises:
RuntimeError: 模型未训练时调用
ValueError: 预测数据缺失特征
"""
if self.model is None:
            raise RuntimeError("模型未训练,请先调用 fit()")
# 特征对齐检查
if self.feature_names_:
missing_cols = [c for c in self.feature_names_ if c not in X.columns]
if missing_cols:
raise ValueError(f"预测数据缺失特征: {missing_cols}")
X = X.select(self.feature_names_)
X_np = X.to_numpy().astype(np.float32)
loader = self._make_loader(X_np, group=group, shuffle_groups=False)
self.model.eval()
all_preds = []
with torch.no_grad():
for batch in loader:
bx = batch[0].to(self.device)
outputs = self.model(bx) # [B, E, 1]
# 排序模型预测时直接输出集成成员的均值作为最终分数
preds = outputs.mean(dim=1).squeeze(-1) # [B]
all_preds.append(preds.cpu().numpy())
return np.concatenate(all_preds)
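predict 中"对集成维度取均值"这一步可以用 numpy 最小化复现([B, E, 1] → [B];示例张量为假设数据):

```python
import numpy as np

# 假设 batch=4、ensemble=3 的模型输出,形状 [B, E, 1]
outputs = np.arange(12, dtype=np.float64).reshape(4, 3, 1)

# 对集成维度 (axis=1) 取均值,再去掉末尾的输出维度,
# 得到每个样本一个最终排序分数,形状 [4]
scores = outputs.mean(axis=1).squeeze(-1)
```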
def get_evals_result(self) -> Optional[Dict[str, List[float]]]:
"""获取训练评估结果
Returns:
评估结果字典,包含 train_loss 和 val_ndcg
"""
return self.training_history_
def feature_importance(self) -> None:
"""获取特征重要性
TabM没有内置特征重要性计算返回None。
"""
return None
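虽然 TabM 没有内置特征重要性,但可以在模型之外用置换重要性(permutation importance)近似评估;下面是与具体模型无关的草图,`score_fn` 为假设的"越大越好"评估函数:

```python
import numpy as np

def permutation_importance(score_fn, X, n_repeats=3, rng=None):
    """逐列打乱特征并记录得分下降量,下降越多说明该特征越重要。

    score_fn: 接受特征矩阵、返回标量得分(越大越好)的函数(假设)。
    """
    rng = np.random.default_rng(rng)
    base = score_fn(X)
    drops = np.zeros(X.shape[1])
    for j in range(X.shape[1]):
        for _ in range(n_repeats):
            Xp = X.copy()
            rng.shuffle(Xp[:, j])  # 只打乱第 j 列,破坏其与标签的关联
            drops[j] += base - score_fn(Xp)
    return drops / n_repeats
```

对排序模型,`score_fn` 可以取验证集上的平均 NDCG;计算成本与特征数和重复次数成正比。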
def save(self, path: str | Path) -> None:
"""保存模型
Args:
path: 保存路径
Raises:
RuntimeError: 模型未训练时调用
"""
if self.model is None:
raise RuntimeError("模型未训练,无法保存")
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
# 保存模型权重
model_path = path.with_suffix(".pt")
torch.save(self.model.state_dict(), model_path)
# 保存元数据
meta_path = path.with_suffix(".meta")
meta = {
"params": self.params,
"feature_names": self.feature_names_,
"training_history": self.training_history_,
"device": str(self.device),
}
with open(meta_path, "wb") as f:
pickle.dump(meta, f)
print(f"[TabMRank] 模型保存到: {path}")
@classmethod
def load(cls, path: str | Path) -> "TabMRankModel":
"""加载模型
Args:
path: 模型路径(不含扩展名)
Returns:
加载的 TabMRankModel 实例
"""
path = Path(path)
# 加载元数据
meta_path = path.with_suffix(".meta")
with open(meta_path, "rb") as f:
meta = pickle.load(f)
# 创建实例
instance = cls(meta["params"])
instance.feature_names_ = meta["feature_names"]
instance.training_history_ = meta["training_history"]
# 重建模型结构
if instance.feature_names_ is not None:
n_features = len(instance.feature_names_)
ensemble_size = instance.params.get("ensemble_size", 32)
instance.model = TabM.make(
n_num_features=n_features,
cat_cardinalities=[],
d_out=1,
n_blocks=instance.params.get("n_blocks", 3),
d_block=instance.params.get("d_block", 256),
dropout=instance.params.get("dropout", 0.1),
k=ensemble_size,
).to(instance.device)
# 加载权重
model_path = path.with_suffix(".pt")
instance.model.load_state_dict(
torch.load(model_path, map_location=instance.device)
)
print(f"[TabMRank] 模型从 {path} 加载完成")
return instance
@staticmethod
def prepare_group_from_dates(
df: pl.DataFrame,
date_col: str = "trade_date",
) -> np.ndarray:
"""从日期列生成 group 数组
Args:
df: 包含日期列的 DataFrame
date_col: 日期列名,默认 "trade_date"
Returns:
group 数组
"""
        group_counts = df.group_by(date_col, maintain_order=True).agg(
            pl.len().alias("count")
        )
return group_counts["count"].to_numpy()
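prepare_group_from_dates 的语义等价于对按日期排好序的序列做连续分段计数,可用纯 Python 验证(假设数据已按日期排序,与 maintain_order 的前提一致):

```python
from itertools import groupby

def group_sizes(dates):
    """按出现顺序统计每个连续日期段的样本数,即 group 数组的语义。"""
    return [len(list(g)) for _, g in groupby(dates)]
```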
@staticmethod
def convert_to_quantile_labels(
df: pl.DataFrame,
label_col: str,
date_col: str = "trade_date",
n_quantiles: int = 20,
new_col_name: Optional[str] = None,
) -> pl.DataFrame:
"""将连续标签转换为分位数标签
Args:
df: 输入 DataFrame
label_col: 原始标签列名
date_col: 日期列名,默认 "trade_date"
n_quantiles: 分位数数量,默认 20
new_col_name: 新列名,默认为 {label_col}_rank
Returns:
添加了分位数标签列的 DataFrame
"""
if new_col_name is None:
new_col_name = f"{label_col}_rank"
return df.with_columns(
pl.col(label_col)
.qcut(n_quantiles)
.over(date_col)
.to_physical()
.cast(pl.Int64)
.alias(new_col_name)
)
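分位数标签的核心是"组内名次 → 桶编号"。下面的纯 Python 草图演示 floor(rank/n × q) 这一映射(忽略并列值的处理差异;在无并列时与上文 qcut 等频分桶的结果一致):

```python
def quantile_labels(values, n_quantiles):
    """把组内数值按升序排名后映射到 [0, n_quantiles-1] 的桶编号。"""
    n = len(values)
    order = sorted(range(n), key=lambda i: values[i])
    labels = [0] * n
    for rank, i in enumerate(order):  # rank 从 0 开始
        bucket = rank * n_quantiles // n
        labels[i] = min(bucket, n_quantiles - 1)
    return labels
```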
def evaluate_ndcg(
self,
X: pl.DataFrame,
y: pl.Series,
group: np.ndarray,
k: Optional[int] = None,
) -> float:
"""评估 NDCG 指标
Args:
X: 特征矩阵
y: 真实标签
group: 分组数组
k: NDCG@k 的 k 值
Returns:
NDCG 分数
"""
from sklearn.metrics import ndcg_score
        y_pred = self.predict(X)
        y_true_np = y.to_numpy()
        y_true_list = []
        y_score_list = []
        start_idx = 0
        for group_size in group:
            end_idx = start_idx + group_size
            y_true_list.append(y_true_np[start_idx:end_idx])
            y_score_list.append(y_pred[start_idx:end_idx])
            start_idx = end_idx
ndcg_scores = []
for y_true, y_score in zip(y_true_list, y_score_list):
if len(y_true) > 1:
try:
score = ndcg_score([y_true], [y_score], k=k)
ndcg_scores.append(score)
except ValueError:
pass
return float(np.mean(ndcg_scores)) if ndcg_scores else 0.0
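evaluate_ndcg 依赖 sklearn 的 ndcg_score(默认线性增益:DCG = Σ relᵢ / log₂(i+1))。可以手工复现一个最小版本用于核对(假设无并列分数):

```python
import math

def ndcg_at_k(y_true, y_score, k=None):
    """线性增益版 NDCG@k,口径与 sklearn.ndcg_score 默认行为一致(假设无并列)。"""
    k = len(y_true) if k is None else min(k, len(y_true))
    # 按预测分数降序取前 k 个位置,累加折损后的真实相关性
    order = sorted(range(len(y_true)), key=lambda i: -y_score[i])
    dcg = sum(y_true[i] / math.log2(pos + 2) for pos, i in enumerate(order[:k]))
    # 理想排序(真实相关性降序)下的 DCG 作为归一化分母
    ideal = sorted(y_true, reverse=True)
    idcg = sum(rel / math.log2(pos + 2) for pos, rel in enumerate(ideal[:k]))
    return dcg / idcg if idcg > 0 else 0.0
```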

View File

@@ -289,7 +289,7 @@ class StandardScaler(BaseProcessor):
return self
def transform(self, X: pl.DataFrame) -> pl.DataFrame:
"""标准化(使用训练集学到的参数,增加 NaN 保护
"""标准化(使用训练集学到的参数)
Args:
X: 待转换数据
@@ -302,18 +302,11 @@ class StandardScaler(BaseProcessor):
if col in self.mean_ and col in self.std_:
# 避免除以0
std_val = self.std_[col] if self.std_[col] != 0 else 1.0
# 关键修复:添加 fill_nan(0) 保险,防止计算产生 NaN
expr = (
((pl.col(col) - self.mean_[col]) / std_val)
.fill_nan(0)
.fill_null(0)
.alias(col)
)
expr = ((pl.col(col) - self.mean_[col]) / std_val).alias(col)
expressions.append(expr)
elif col in self.feature_cols:
# 对于应该被处理但未学习到统计量的列
# 统一转换为float并同时处理 NaN 和 null
expr = pl.col(col).cast(pl.Float64).fill_nan(0).fill_null(0).alias(col)
# 对于应该被处理但未学习到统计量的列统一转换为float
expr = pl.col(col).cast(pl.Float64).alias(col)
expressions.append(expr)
else:
expressions.append(pl.col(col))
@@ -372,20 +365,14 @@ class CrossSectionalStandardScaler(BaseProcessor):
if col in self.feature_cols and X[col].dtype.is_numeric():
# 截面标准化:每天独立计算均值和标准差
# 避免除以0当std为0时设为1
# 关键修复:先 fill_nan 再 fill_null防止计算产生的 NaN
expr = (
(
(pl.col(col) - pl.col(col).mean().over(self.date_col))
/ (pl.col(col).std().over(self.date_col) + 1e-10)
)
.fill_nan(0)
.fill_null(0)
.alias(col)
)
(pl.col(col) - pl.col(col).mean().over(self.date_col))
/ (pl.col(col).std().over(self.date_col) + 1e-10)
).alias(col)
expressions.append(expr)
elif col in self.feature_cols:
# 对于应该被处理但类型不匹配的列转换为float并同时处理 NaN 和 null
expr = pl.col(col).cast(pl.Float64).fill_nan(0).fill_null(0).alias(col)
# 对于应该被处理但类型不匹配的列转换为float
expr = pl.col(col).cast(pl.Float64).alias(col)
expressions.append(expr)
else:
expressions.append(pl.col(col))
@@ -488,8 +475,8 @@ class Winsorizer(BaseProcessor):
expressions.append(expr)
elif col in self.feature_cols:
# 对于应该被处理但未学习到边界的列如全为NaN、布尔列等
# 统一转换为float并填充0
expr = pl.col(col).cast(pl.Float64).fill_null(0).alias(col)
# 统一转换为float
expr = pl.col(col).cast(pl.Float64).alias(col)
expressions.append(expr)
else:
expressions.append(pl.col(col))
@@ -522,11 +509,9 @@ class Winsorizer(BaseProcessor):
clip_exprs = []
for col in X.columns:
if col in target_cols:
# 先用当天分位数缩尾如果分位数是null该日全为NaN则填充0
clipped = (
pl.col(col)
.clip(pl.col(f"{col}_lower"), pl.col(f"{col}_upper"))
.fill_null(0)
.alias(col)
)
clip_exprs.append(clipped)
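上面 Winsorizer 的截面缩尾逻辑(把数值裁剪到当天分位数上下界之间)可用 numpy 最小化演示(分位点 0.1/0.9 为示例假设):

```python
import numpy as np

def winsorize(values, lower_q=0.05, upper_q=0.95):
    """把数组裁剪到 [lower_q, upper_q] 分位数之间,极端值被压到边界。"""
    lo, hi = np.quantile(values, [lower_q, upper_q])
    return np.clip(values, lo, hi)
```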

View File

@@ -13,6 +13,7 @@ from src.factors import FactorEngine
from src.training.pipeline import DataPipeline
from src.training.tasks.base import BaseTask
from src.training.result_analyzer import ResultAnalyzer
from src.training.data_quality_analyzer import DataQualityAnalyzer
class Trainer:
@@ -100,8 +101,6 @@ class Trainer:
print("\n[Step 1.5/7] 数据质量分析...")
try:
from src.experiment.data_quality_analyzer import DataQualityAnalyzer
# 获取特征列名(从训练集)
feature_cols = data["train"].get("feature_cols", [])
label_name = self.task.label_name

View File

@@ -1,6 +1,7 @@
"""数据质量分析模块
提供数据质量检查功能包括
- 数据集日期范围信息
- 缺失值统计
- 零值统计
- 按日期检查全空列
@@ -19,6 +20,7 @@ class DataQualityAnalyzer:
Attributes:
feature_cols: 特征列名列表
label_col: 标签列名
date_col: 日期列名
verbose: 是否打印详细信息
"""
@@ -26,6 +28,7 @@ class DataQualityAnalyzer:
self,
feature_cols: Optional[List[str]] = None,
label_col: Optional[str] = None,
date_col: str = "trade_date",
verbose: bool = True,
):
"""初始化数据质量分析器
@@ -33,10 +36,12 @@ class DataQualityAnalyzer:
Args:
feature_cols: 特征列名列表
label_col: 标签列名
date_col: 日期列名默认为 "trade_date"
verbose: 是否打印详细信息
"""
self.feature_cols = feature_cols or []
self.label_col = label_col
self.date_col = date_col
self.verbose = verbose
self.analysis_results: Dict[str, Any] = {}
@@ -74,6 +79,10 @@ class DataQualityAnalyzer:
self.analysis_results = {}
# 首先打印数据集概览(日期范围等基本信息)
if self.verbose:
self._print_dataset_overview(data, split_names)
for split_name in split_names:
if split_name not in data:
continue
@@ -96,6 +105,104 @@ class DataQualityAnalyzer:
return self.analysis_results
def _print_dataset_overview(
self,
data: Dict[str, Dict[str, Any]],
split_names: List[str],
) -> None:
"""打印数据集概览信息
包括每个数据集的起始日期终止日期样本数量等基本信息
Args:
data: 数据字典
split_names: 数据划分名称列表
"""
print("\n[数据集概览]")
print("-" * 40)
overview_data = []
for split_name in split_names:
if split_name not in data:
continue
split_data = data[split_name]
raw_df = split_data.get("raw_data")
if raw_df is None or len(raw_df) == 0:
                overview_data.append(
                    {
                        "划分": split_name.upper(),
                        "起始日期": "-",
                        "终止日期": "-",
                        "交易日数": 0,
                        "样本数": 0,
                        "股票数": 0,
                    }
                )
continue
# 获取日期范围
if self.date_col in raw_df.columns:
dates = raw_df[self.date_col]
start_date = dates.min()
end_date = dates.max()
unique_dates = dates.n_unique()
else:
start_date = "-"
end_date = "-"
unique_dates = 0
# 获取股票数量
if "ts_code" in raw_df.columns:
unique_stocks = raw_df["ts_code"].n_unique()
else:
unique_stocks = 0
overview_data.append(
{
"划分": split_name.upper(),
"起始日期": str(start_date),
"终止日期": str(end_date),
"交易日数": unique_dates,
"样本数": len(raw_df),
"股票数": unique_stocks,
}
)
# 打印表格
if overview_data:
# 计算列宽
headers = ["划分", "起始日期", "终止日期", "交易日数", "样本数", "股票数"]
col_widths = {}
            for header in headers:
                max_data_len = max(
                    len(str(row.get(header, ""))) for row in overview_data
                )
                col_widths[header] = max(len(header), max_data_len) + 2
# 打印表头
header_line = " ".join(h.ljust(col_widths[h]) for h in headers)
print(f" {header_line}")
print(f" {'-' * (sum(col_widths.values()) + 2 * (len(headers) - 1))}")
# 打印数据行
            for row in overview_data:
                line = " ".join(
                    str(row.get(h, "")).ljust(col_widths[h]) for h in headers
                )
                print(f" {line}")
def _analyze_split(
self,
df: pl.DataFrame,
@@ -120,6 +227,16 @@ class DataQualityAnalyzer:
"all_null_by_date": {},
}
# 获取日期范围
if self.date_col in df.columns:
results["start_date"] = str(df[self.date_col].min())
results["end_date"] = str(df[self.date_col].max())
results["unique_dates"] = df[self.date_col].n_unique()
# 获取股票数量
if "ts_code" in df.columns:
results["unique_stocks"] = df["ts_code"].n_unique()
# 1. 分析特征列的缺失值
null_stats = self._analyze_null_values(df, self.feature_cols)
results["null_analysis"] = null_stats
@@ -249,7 +366,7 @@ class DataQualityAnalyzer:
"issues": [],
}
if "trade_date" not in df.columns:
if self.date_col not in df.columns:
return results
# 过滤掉不在表中的列
@@ -267,7 +384,7 @@ class DataQualityAnalyzer:
]
agg_exprs.append(pl.len().alias("total_rows"))
agg_lf = lf.group_by("trade_date").agg(agg_exprs)
agg_lf = lf.group_by(self.date_col).agg(agg_exprs)
# 收集结果 (此时 agg_df 行数通常只有几百到几千行)
agg_df = agg_lf.collect()
@@ -279,13 +396,13 @@ class DataQualityAnalyzer:
# 找出 null 数量等于总行数的日期
bad_dates = agg_df.filter(
(pl.col(null_col) == pl.col("total_rows")) & (pl.col("total_rows") > 0)
).select(["trade_date", "total_rows"])
).select([self.date_col, "total_rows"])
if not bad_dates.is_empty():
for row in bad_dates.to_dicts():
issues.append(
{
"date": row["trade_date"],
"date": row[self.date_col],
"column": col,
"total_rows": row["total_rows"],
}
@@ -464,8 +581,20 @@ class DataQualityAnalyzer:
for split_name, results in self.analysis_results.items():
lines.append(f"\n[{split_name.upper()}]")
# 添加日期范围信息
if "start_date" in results and "end_date" in results:
lines.append(
f" 日期范围: {results['start_date']} ~ {results['end_date']}"
)
if "unique_dates" in results:
lines.append(f" 交易日数: {results['unique_dates']}")
lines.append(f" 总行数: {results['total_rows']:,}")
if "unique_stocks" in results:
lines.append(f" 股票数: {results['unique_stocks']}")
null_stats = results.get("null_analysis", {})
if null_stats.get("columns_with_null"):
lines.append(
@@ -493,6 +622,7 @@ def analyze_data_quality(
data: Dict[str, Dict[str, Any]],
feature_cols: Optional[List[str]] = None,
label_col: Optional[str] = None,
date_col: str = "trade_date",
verbose: bool = True,
) -> Dict[str, Any]:
"""便捷函数:执行数据质量分析
@@ -501,6 +631,7 @@ def analyze_data_quality(
data: 数据字典
feature_cols: 特征列名列表
label_col: 标签列名
date_col: 日期列名默认为 "trade_date"
verbose: 是否打印详细信息
Returns:
@@ -509,6 +640,7 @@ def analyze_data_quality(
analyzer = DataQualityAnalyzer(
feature_cols=feature_cols,
label_col=label_col,
date_col=date_col,
verbose=verbose,
)
return analyzer.analyze(data)
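数据质量分析中"缺失值统计"的最小口径可以用纯 Python 示意(行记录与列名均为假设数据):

```python
def null_stats(rows, columns):
    """统计每列 None 的数量与占比,近似 DataQualityAnalyzer 的缺失值口径。"""
    total = len(rows)
    stats = {}
    for col in columns:
        n_null = sum(1 for r in rows if r.get(col) is None)
        stats[col] = {
            "n_null": n_null,
            "null_rate": n_null / total if total else 0.0,
        }
    return stats
```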

View File

@@ -7,10 +7,12 @@ from src.training.tasks.base import BaseTask
from src.training.tasks.regression_task import RegressionTask
from src.training.tasks.rank_task import RankTask
from src.training.tasks.tabm_regression_task import TabMRegressionTask
from src.training.tasks.tabm_rank_task import TabMRankTask
__all__ = [
"BaseTask",
"RegressionTask",
"RankTask",
"TabMRegressionTask",
"TabMRankTask",
]

View File

@@ -153,7 +153,6 @@ class RankTask(BaseTask):
if k_list is None:
k_list = [1, 5, 10, 20]
y_true = test_data["y_raw"]
y_pred = self.predict(test_data)
groups = test_data["groups"]
@@ -166,9 +165,13 @@ class RankTask(BaseTask):
y_true_groups = []
y_pred_groups = []
# 使用分位数标签 y (0-19) 作为真实相关性分数,而非原始收益率 y_raw
# 这样与模型学习目标一致,避免原始收益率中负值的影响
y_true_array = test_data["y"].to_numpy()
for group_size in groups:
end_idx = start_idx + group_size
y_true_groups.append(y_true.to_numpy()[start_idx:end_idx])
y_true_groups.append(y_true_array[start_idx:end_idx])
y_pred_groups.append(y_pred[start_idx:end_idx])
start_idx = end_idx

View File

@@ -0,0 +1,249 @@
"""TabM 排序学习任务实现
实现基于 TabM 的排序学习训练流程:
- Label 转换为分位数标签
- 生成 group 数组
- 使用 TabMRankModel基于 ListNet Loss
- 支持 NDCG@k 评估
"""
from typing import Any, Dict, List, Optional
import numpy as np
import polars as pl
from src.training.tasks.base import BaseTask
from src.training.components.models.tabm_rank_model import TabMRankModel
class TabMRankTask(BaseTask):
"""TabM 排序学习任务
使用 TabMRankModel 进行排序学习训练。
将连续收益率转换为分位数标签进行训练。
支持指数化增益标签以增强 Top-K 关注。
"""
def __init__(
self,
model_params: Dict[str, Any],
label_name: str = "future_return_5",
n_quantiles: int = 20,
label_transform: Optional[str] = None,
label_scale: float = 20.0,
):
"""初始化排序学习任务
Args:
model_params: TabM 参数字典
label_name: Label 列名
n_quantiles: 分位数数量
label_transform: 标签变换类型,可选:
- None: 标准分位数标签 (0, 1, ..., n_quantiles-1)
- "exponential": 指数化增益: 2^(rank/scale) - 1
label_scale: 指数变换的缩放因子,用于控制增益幅度
"""
super().__init__(model_params, label_name)
self.n_quantiles = n_quantiles
self.label_transform = label_transform
self.label_scale = label_scale
def prepare_labels(self, data: Dict[str, Dict]) -> Dict[str, Dict]:
"""准备标签(转换为分位数标签,可选指数化增益变换)
将连续收益率转换为分位数标签,并生成 group 数组。
支持指数化增益变换以增强头部样本的区分度。
Args:
data: 数据字典
Returns:
处理后的数据字典(添加了 y_rank 和 groups
"""
for split in ["train", "val", "test"]:
if split not in data:
continue
df = data[split]["raw_data"]
# 分位数转换
rank_col = f"{self.label_name}_rank"
# 1. 基础分位数标签 (0 到 n_quantiles-1)
df_ranked = df.with_columns(
pl.col(self.label_name)
.rank(method="min")
.over("trade_date")
.alias("_rank")
).with_columns(
((pl.col("_rank") - 1) / pl.len().over("trade_date") * self.n_quantiles)
.floor()
.cast(pl.Int64)
.clip(0, self.n_quantiles - 1)
.alias("_base_rank")
)
            # 2. 【Top-K 优化】可选指数化增益变换
            if self.label_transform == "exponential":
                # 指数化增益: 2^(rank/scale) - 1
                # 例如 scale=20 时 rank=0 -> 0, rank=19 -> ~0.93, rank=99 -> ~30
                # 效果:高分样本与低分样本的差距被指数级拉大
                df_ranked = df_ranked.with_columns(
                    (
                        pl.lit(2.0).pow(
                            pl.col("_base_rank").cast(pl.Float64) / self.label_scale
                        )
                        - 1.0
                    ).alias(rank_col)
                )
else:
# 标准分位数标签
df_ranked = df_ranked.with_columns(
pl.col("_base_rank").cast(pl.Float64).alias(rank_col)
)
# 清理临时列
df_ranked = df_ranked.drop(["_rank", "_base_rank"])
# 更新数据
data[split]["raw_data"] = df_ranked
data[split]["y"] = df_ranked[rank_col]
data[split]["y_raw"] = df_ranked[self.label_name] # 保留原始值
# 生成 group 数组
data[split]["groups"] = self._compute_group_array(df_ranked, "trade_date")
return data
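指南中定义的指数化增益公式 Gain = 2^(rank/scale) - 1 可以用如下函数直接做数值核对(scale=20 为示例取值):

```python
def exponential_gain(rank, scale=20.0):
    """指数化增益:2^(rank/scale) - 1,拉大头部样本与尾部样本的差距。"""
    return 2.0 ** (rank / scale) - 1.0
```

scale 越小,头部样本的增益相对越大,对 Top-K 排序的强调也越激进。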
def _compute_group_array(
self,
df: pl.DataFrame,
date_col: str = "trade_date",
) -> np.ndarray:
"""计算 group 数组
Args:
df: 数据框
date_col: 日期列名
Returns:
group 数组(每个日期的样本数)
"""
        group_counts = df.group_by(date_col, maintain_order=True).agg(
            pl.len().alias("count")
        )
return group_counts["count"].to_numpy()
def fit(self, train_data: Dict, val_data: Dict) -> None:
"""训练排序模型
Args:
train_data: 训练数据
val_data: 验证数据
"""
self.model = TabMRankModel(params=self.model_params)
self.model.fit(
train_data["X"],
train_data["y"],
group=train_data["groups"],
eval_set=(val_data["X"], val_data["y"], val_data["groups"])
if val_data
else None,
)
def predict(self, test_data: Dict) -> np.ndarray:
"""生成预测
Args:
test_data: 测试数据
Returns:
预测结果
"""
# 传入 groups 参数,确保预测顺序与分组一致,与验证逻辑保持一致
return self.model.predict(test_data["X"], group=test_data.get("groups"))
def evaluate_ndcg(
self,
test_data: Dict,
        k_list: Optional[List[int]] = None,
) -> Dict[str, float]:
"""评估 NDCG@k
Args:
test_data: 测试数据
k_list: k 值列表,默认 [1, 5, 10, 20]
Returns:
NDCG 分数字典 {"ndcg@1": score, ...}
"""
if k_list is None:
k_list = [1, 5, 10, 20]
y_pred = self.predict(test_data)
groups = test_data["groups"]
from sklearn.metrics import ndcg_score
results = {}
# 按 group 拆分
start_idx = 0
y_true_groups = []
y_pred_groups = []
# 使用分位数标签 y (0-19) 作为真实相关性分数,而非原始收益率 y_raw
# 这样与模型学习目标一致,避免原始收益率中负值的影响
y_true_array = test_data["y"].to_numpy()
for group_size in groups:
end_idx = start_idx + group_size
y_true_groups.append(y_true_array[start_idx:end_idx])
y_pred_groups.append(y_pred[start_idx:end_idx])
start_idx = end_idx
# 计算每个 k 的 NDCG
for k in k_list:
ndcg_scores = []
for yt, yp in zip(y_true_groups, y_pred_groups):
if len(yt) > 1:
try:
score = ndcg_score([yt], [yp], k=k)
ndcg_scores.append(score)
except ValueError:
pass
results[f"ndcg@{k}"] = float(np.mean(ndcg_scores)) if ndcg_scores else 0.0
return results
def plot_training_metrics(self) -> None:
"""绘制训练指标曲线NDCG"""
if self.model and hasattr(self.model, "get_evals_result"):
try:
import matplotlib.pyplot as plt
evals_result = self.model.get_evals_result()
if not evals_result:
print("[警告] 没有训练指标数据可供绘制")
return
fig, ax = plt.subplots(1, 2, figsize=(12, 4))
# 绘制训练损失
if "train_loss" in evals_result:
ax[0].plot(evals_result["train_loss"], label="Train Loss")
ax[0].set_xlabel("Epoch")
ax[0].set_ylabel("ListNet Loss")
ax[0].set_title("Training Loss")
ax[0].legend()
ax[0].grid(True)
# 绘制验证 NDCG
if "val_ndcg" in evals_result:
ax[1].plot(evals_result["val_ndcg"], label="Val NDCG")
ax[1].set_xlabel("Epoch")
ax[1].set_ylabel("NDCG")
ax[1].set_title("Validation NDCG")
ax[1].legend()
ax[1].grid(True)
plt.tight_layout()
plt.show()
except Exception as e:
print(f"[警告] 无法绘制训练曲线: {e}")

View File

@@ -0,0 +1,88 @@
"""测试 TabMRankModel 基础功能"""
import pytest
import numpy as np
import polars as pl
import torch
from src.training.components.models import TabMRankModel
class TestTabMRankModel:
"""TabMRankModel 测试类"""
def test_model_initialization(self):
"""测试模型初始化"""
model = TabMRankModel(params={"ensemble_size": 16})
assert model.name == "tabm_rank"
assert model.params["ensemble_size"] == 16
assert model.device in [torch.device("cpu"), torch.device("cuda")]
def test_prepare_group_from_dates(self):
"""测试从日期生成 group 数组"""
df = pl.DataFrame(
{
"trade_date": [
"20240101",
"20240101",
"20240102",
"20240102",
"20240102",
],
"value": [1, 2, 3, 4, 5],
}
)
group = TabMRankModel.prepare_group_from_dates(df)
assert np.array_equal(group, np.array([2, 3]))
def test_convert_to_quantile_labels(self):
"""测试转换分位数标签"""
df = pl.DataFrame(
{
"trade_date": ["20240101"] * 5 + ["20240102"] * 5,
"return": [0.1, 0.05, 0.0, -0.05, -0.1] + [0.2, 0.1, 0.0, -0.1, -0.2],
}
)
result = TabMRankModel.convert_to_quantile_labels(df, "return", n_quantiles=5)
assert "return_rank" in result.columns
assert result["return_rank"].dtype == pl.Int64
def test_save_load(self, tmp_path):
"""测试模型保存和加载(跳过实际训练)"""
params = {
"ensemble_size": 8,
"n_blocks": 2,
"d_block": 64,
"epochs": 1,
}
model = TabMRankModel(params=params)
model.feature_names_ = ["feat1", "feat2"]
# 模拟训练历史
model.training_history_["train_loss"] = [0.5, 0.4]
model.training_history_["val_ndcg"] = [0.3, 0.35]
        # 保存前需要先构建模型结构(save 要求 self.model 不为 None)
        # 直接复用模型模块里的 TabM 构建与 params 一致的最小网络
        from src.training.components.models.tabm_rank_model import TabM

        model.model = TabM.make(
            n_num_features=len(model.feature_names_),
            cat_cardinalities=[],
            d_out=1,
            n_blocks=params["n_blocks"],
            d_block=params["d_block"],
            dropout=0.1,
            k=params["ensemble_size"],
        ).to(model.device)
        save_path = tmp_path / "test_model"
        model.save(save_path)
        # 加载并校验元数据与权重能完整恢复
        loaded_model = TabMRankModel.load(save_path)
assert loaded_model.params == params
assert loaded_model.feature_names_ == ["feat1", "feat2"]
if __name__ == "__main__":
pytest.main([__file__, "-v"])