feat(training): add TabM learning-to-rank model support and streamline the training pipeline

- Add TabMRankModel, TabMRankTask, and their accompanying loss functions and configuration
- Move DataQualityAnalyzer from the experiment module into the training module
- Remove the overly aggressive NaN/null hard-fill logic from the data processors
- Switch the RankTask evaluation metric from raw returns to quantile labels
- Update the processor order and model hyperparameter configuration in the experiment scripts

src/experiment/README_TABM_TOPK.md (new file, 126 lines)
@@ -0,0 +1,126 @@
# TabM Top-K Optimization Guide

For the Top-K stock-selection scenario — buying only the 5-10 stocks with the highest predicted scores each day — TabM learning-to-rank has been optimized along three directions.

## Optimization Overview

### 1. Loss Function Optimization (Core)

**Weighted ListNet (recommended first attempt)**

- Principle: give samples with high-score labels a larger weight in the loss
- Parameters: `loss_type="weighted_listnet"`, `topk_weight=5.0`
- Effect: head samples carry 5x the weight of tail samples, forcing the model to focus on getting the top of the ranking right (see the sketch below)
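As a rough illustration of the idea, here is a minimal PyTorch sketch of a weighted ListNet loss. The exact weighting scheme (a linear ramp from 1 to `topk_weight` across the label range) is an assumption for illustration; the loss actually used by `TabMRankModel` may differ.

```python
import torch
import torch.nn.functional as F

def weighted_listnet_loss(scores: torch.Tensor, labels: torch.Tensor,
                          topk_weight: float = 5.0) -> torch.Tensor:
    """ListNet top-one cross-entropy with extra weight on high-label samples.

    scores: (n,) predicted scores for one day's cross-section
    labels: (n,) float relevance labels (e.g., quantile ranks), higher = better
    """
    target = F.softmax(labels, dim=-1)        # target top-one distribution
    log_pred = F.log_softmax(scores, dim=-1)  # predicted log-distribution
    # Linear ramp: the worst label gets weight 1, the best gets topk_weight
    rank01 = (labels - labels.min()) / (labels.max() - labels.min() + 1e-12)
    weights = 1.0 + (topk_weight - 1.0) * rank01
    return -(weights * target * log_pred).sum()
```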
**LambdaLoss (fine-grained Top-K optimization)**

- Principle: weight each pair's loss by the DeltaNDCG obtained from swapping the two items
- Parameters: `loss_type="lambda"`, `lambda_sigma=1.0`, `ndcg_weight_power=1.5`
- Effect: directly optimizes the NDCG@K metric; suited to squeezing out maximum Top-K performance (see the sketch below)
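For intuition, a hedged sketch of a LambdaLoss-style objective — RankNet's pairwise logistic loss weighted by |DeltaNDCG| — assuming exponential gains and log2 position discounts; the library's actual implementation may differ.

```python
import torch
import torch.nn.functional as F

def lambda_loss(scores: torch.Tensor, labels: torch.Tensor,
                sigma: float = 1.0, ndcg_power: float = 1.0) -> torch.Tensor:
    n = scores.shape[0]
    positions = torch.arange(n, dtype=scores.dtype, device=scores.device)
    discounts = 1.0 / torch.log2(positions + 2.0)            # DCG position discounts
    ideal_labels, _ = torch.sort(labels, descending=True)
    idcg = ((2.0 ** ideal_labels - 1.0) * discounts).sum()   # ideal DCG
    # Discount each item receives at its current predicted rank
    order = torch.argsort(scores, descending=True)
    ranks = torch.empty_like(order)
    ranks[order] = torch.arange(n, device=scores.device)
    d = discounts[ranks]
    gains = (2.0 ** labels - 1.0) / idcg
    # |DeltaNDCG| from swapping items i and j, optionally amplified
    delta = ((gains[:, None] - gains[None, :]).abs()
             * (d[:, None] - d[None, :]).abs()) ** ndcg_power
    mask = (labels[:, None] > labels[None, :]).float()       # i truly ranks above j
    pair_loss = F.softplus(-sigma * (scores[:, None] - scores[None, :]))
    return (delta * mask * pair_loss).sum()
```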
### 2. Label Engineering Optimization (Enhancement)

**Exponential gain transform**

- Formula: `Gain = 2^(rank/scale) - 1`
- Parameters: `label_transform="exponential"`, `label_scale=20.0`
- Effect: rank=0 → 0, rank=19 → ~0.93, rank=99 → ~29.9
- Purpose: widens the gap between high- and low-score samples, sharpening the separation at the head (verified in the snippet below)
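A quick sanity check of the formula at `scale=20` (a standalone snippet, not part of the library):

```python
# Evaluate Gain = 2^(rank/scale) - 1 at scale = 20
for rank in (0, 19, 99):
    print(rank, 2 ** (rank / 20.0) - 1)
# 0  -> 0.0
# 19 -> ~0.93
# 99 -> ~29.9
```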
### 3. Recommended Configuration Combos

```python
# Config A: mild optimization (recommended first attempt)
MODEL_PARAMS = {
    "loss_type": "weighted_listnet",
    "topk_weight": 3.0,
    # ... other parameters
}
LABEL_TRANSFORM = None  # keep standard quantiles

# Config B: balanced optimization (trades off effect against stability)
MODEL_PARAMS = {
    "loss_type": "weighted_listnet",
    "topk_weight": 5.0,
    "ndcg_k": 20,  # track NDCG@20 during validation
}
LABEL_TRANSFORM = "exponential"
LABEL_SCALE = 20.0

# Config C: aggressive optimization (focused on Top-10)
MODEL_PARAMS = {
    "loss_type": "lambda",
    "lambda_sigma": 1.0,
    "ndcg_weight_power": 1.5,
    "ndcg_k": 10,
}
N_QUANTILES = 50  # higher quantile resolution
LABEL_TRANSFORM = "exponential"
LABEL_SCALE = 25.0
```
## Usage Example

Edit the configuration in `tabm_rank_train.py`:

```python
# Quantile configuration
N_QUANTILES = 30

# Label engineering configuration
LABEL_TRANSFORM = "exponential"  # enable exponential gains
LABEL_SCALE = 20.0

# Model parameter configuration
MODEL_PARAMS = {
    # ... base parameters ...

    # Top-K optimization parameters
    "loss_type": "weighted_listnet",  # or "lambda"
    "topk_weight": 5.0,        # effective only with weighted_listnet
    "lambda_sigma": 1.0,       # effective only with lambda
    "ndcg_weight_power": 1.0,  # effective only with lambda
    "ndcg_k": 20,              # validation metric
}
```
## Parameter Reference

### Loss function parameters

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `loss_type` | str | "listnet" | Loss type: "listnet" / "weighted_listnet" / "lambda" |
| `topk_weight` | float | 5.0 | Head weight factor (weighted_listnet); larger = more focus on the head |
| `lambda_sigma` | float | 1.0 | Sigmoid steepness (lambda) |
| `ndcg_weight_power` | float | 1.0 | DeltaNDCG weight exponent (lambda); >1 further amplifies the head |
| `ndcg_k` | int/None | None | NDCG@k computed during validation; None = whole list |

### Label engineering parameters

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `n_quantiles` | int | 20 | Number of quantile bins; more bins = higher resolution |
| `label_transform` | str/None | None | Transform type: None / "exponential" |
| `label_scale` | float | 20.0 | Scaling factor of the exponential transform; controls gain magnitude |
## Implementation Advice

1. **Optimize incrementally**:
   - Round 1: enable only `loss_type="weighted_listnet"`, `topk_weight=3.0`
   - Round 2: add label engineering with `label_transform="exponential"`
   - Round 3: try `loss_type="lambda"` for fine-grained optimization

2. **Monitor metrics** (an NDCG@K evaluation sketch follows this list):
   - Track NDCG@K, with K set to your actual Top-K size
   - Compare backtest returns across configurations
   - Check that the training loss decreases steadily

3. **Caveats**:
   - LambdaLoss trains more slowly; each epoch takes longer
   - Exponential gains change the label distribution and may require retuning the learning rate
   - Too high a `topk_weight` risks overfitting the head samples
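For monitoring, a minimal NDCG@K reference implementation, assuming exponential gains and log2 discounts as above; the metric actually computed inside the trainer may differ.

```python
import numpy as np

def ndcg_at_k(scores: np.ndarray, labels: np.ndarray, k: int = 20) -> float:
    """NDCG@k for one day's cross-section (higher label = better)."""
    k = min(k, len(scores))
    discounts = 1.0 / np.log2(np.arange(k) + 2.0)
    top_pred = np.argsort(-scores)[:k]           # items the model ranks highest
    dcg = ((2.0 ** labels[top_pred] - 1.0) * discounts).sum()
    ideal = np.sort(labels)[::-1][:k]            # best achievable ordering
    idcg = ((2.0 ** ideal - 1.0) * discounts).sum()
    return float(dcg / idcg) if idcg > 0 else 0.0
```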
## References

1. **ListNet**: "Learning to Rank: From Pairwise Approach to Listwise Approach" (Cao et al., 2007)
2. **LambdaRank**: "From RankNet to LambdaRank to LambdaMART: An Overview" (Burges, 2010)
3. **LambdaLoss**: "The LambdaLoss Framework for Ranking Metric Optimization" (Wang et al., 2018)
4. **Deep learning for stock selection**: "Deep Learning for Stock Selection" (Gu, Kelly, & Xiu, 2020)
@@ -270,6 +270,7 @@ SELECTED_FACTORS = [
    "bottom_cost_stability",
    "pivot_reversion",
    "chip_transition",

    # "amivest_liq_20",
    # "atr_price_impact",
    # "hui_heubel_ratio",
@@ -323,11 +324,11 @@ def get_label_factor(label_name: str) -> dict:
# Helper functions
# =============================================================================
def register_factors(
    engine: FactorEngine,
    selected_factors: List[str],
    factor_definitions: dict,
    label_factor: dict,
    excluded_factors: Optional[List[str]] = None,
) -> List[str]:
    """Register factors.

@@ -408,11 +409,11 @@ def register_factors(


def prepare_data(
    engine: FactorEngine,
    feature_cols: List[str],
    start_date: str,
    end_date: str,
    label_name: str,
) -> pl.DataFrame:
    """Prepare data.

@@ -450,45 +451,6 @@ def prepare_data(
    return data


# =============================================================================
# Stock pool filtering configuration
# =============================================================================
def stock_pool_filter(df: pl.DataFrame) -> pl.Series:
    """Stock pool filter (single-day data).

    Filtering rules:
    1. Exclude ChiNext (codes starting with 300)
    2. Exclude the STAR Market (codes starting with 688)
    3. Exclude the Beijing Stock Exchange (codes starting with 8, 9, or 4)
    4. Keep the 500 stocks with the smallest market cap that day

    Args:
        df: single-day data frame

    Returns:
        Boolean Series marking the selected stocks
    """
    # Code-based filtering (exclude ChiNext, STAR Market, BSE)
    code_filter = (
        ~df["ts_code"].str.starts_with("30")  # exclude ChiNext
        & ~df["ts_code"].str.starts_with("68")  # exclude STAR Market
        & ~df["ts_code"].str.starts_with("8")  # exclude BSE
        & ~df["ts_code"].str.starts_with("9")  # exclude BSE
        & ~df["ts_code"].str.starts_with("4")  # exclude BSE
    )

    # Among the filtered stocks, keep the 500 with the smallest float market cap
    valid_df = df.filter(code_filter)
    n = min(1000, len(valid_df))
    small_cap_codes = valid_df.sort("circ_mv").head(n)["ts_code"]

    # Return a boolean Series: is the stock among the selected codes
    return df["ts_code"].is_in(small_cap_codes)


# Base columns required by the filter
STOCK_FILTER_REQUIRED_COLUMNS = ["circ_mv"]

# =============================================================================
# Output configuration
# =============================================================================
@@ -502,10 +464,54 @@ MODEL_SAVE_DIR = "models"  # model save directory
# Top-N configuration: number of stocks recommended per day
TOP_N = 5  # can be changed to 10, 20, etc.

# Stock pool size configuration
STOCK_POOL_SIZE = 1000  # number of smallest-market-cap stocks to keep in the pool

# Training-data skip configuration
TRAIN_SKIP_DAYS = 300  # skip the first TRAIN_SKIP_DAYS days of training data to avoid the data-sparse warm-up period


# =============================================================================
# Stock pool filtering configuration
# =============================================================================
def stock_pool_filter(df: pl.DataFrame, n_stocks: int = STOCK_POOL_SIZE) -> pl.Series:
    """Stock pool filter (single-day data).

    Filtering rules:
    1. Exclude ChiNext (codes starting with 300)
    2. Exclude the STAR Market (codes starting with 688)
    3. Exclude the Beijing Stock Exchange (codes starting with 8, 9, or 4)
    4. Keep the n_stocks stocks with the smallest market cap that day

    Args:
        df: single-day data frame
        n_stocks: number of stocks to keep, defaults to STOCK_POOL_SIZE

    Returns:
        Boolean Series marking the selected stocks
    """
    # Code-based filtering (exclude ChiNext, STAR Market, BSE)
    code_filter = (
        ~df["ts_code"].str.starts_with("30")  # exclude ChiNext
        & ~df["ts_code"].str.starts_with("68")  # exclude STAR Market
        & ~df["ts_code"].str.starts_with("8")  # exclude BSE
        & ~df["ts_code"].str.starts_with("9")  # exclude BSE
        & ~df["ts_code"].str.starts_with("4")  # exclude BSE
    )

    # Among the filtered stocks, keep the n_stocks with the smallest float market cap
    valid_df = df.filter(code_filter)
    n = min(n_stocks, len(valid_df))
    small_cap_codes = valid_df.sort("circ_mv").head(n)["ts_code"]

    # Return a boolean Series: is the stock among the selected codes
    return df["ts_code"].is_in(small_cap_codes)


# Base columns required by the filter
STOCK_FILTER_REQUIRED_COLUMNS = ["circ_mv"]


def get_output_path(model_type: str, test_start: str, test_end: str) -> str:
    """Build the output file path.

@@ -532,7 +538,7 @@ def get_output_path(model_type: str, test_start: str, test_end: str) -> str:


def get_model_save_path(
    model_type: str,
) -> Optional[str]:
    """Build the model save path.

@@ -558,11 +564,11 @@ def get_model_save_path(


def save_model_with_factors(
    model,
    model_path: str,
    selected_factors: list[str],
    factor_definitions: dict,
    fitted_processors: list | None = None,
) -> str:
    """Save the model together with its factor info and fitted processors.

@@ -1,514 +0,0 @@
"""Data quality analysis module.

Provides data quality checks, including:
- missing-value statistics
- zero-value statistics
- per-date detection of all-null columns
"""

from typing import Any, Dict, List, Optional
import polars as pl
import numpy as np


class DataQualityAnalyzer:
    """Data quality analyzer.

    Analyzes quality issues in the training data to help spot data anomalies.

    Attributes:
        feature_cols: list of feature column names
        label_col: label column name
        verbose: whether to print details
    """

    def __init__(
        self,
        feature_cols: Optional[List[str]] = None,
        label_col: Optional[str] = None,
        verbose: bool = True,
    ):
        """Initialize the data quality analyzer.

        Args:
            feature_cols: list of feature column names
            label_col: label column name
            verbose: whether to print details
        """
        self.feature_cols = feature_cols or []
        self.label_col = label_col
        self.verbose = verbose
        self.analysis_results: Dict[str, Any] = {}

    def set_columns(self, feature_cols: List[str], label_col: str) -> None:
        """Set the columns to analyze.

        Args:
            feature_cols: list of feature column names
            label_col: label column name
        """
        self.feature_cols = feature_cols
        self.label_col = label_col

    def analyze(
        self,
        data: Dict[str, Dict[str, Any]],
        split_names: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Run the full data quality analysis.

        Args:
            data: data dict shaped like {"train": {...}, "val": {...}, "test": {...}}
            split_names: names of the splits to analyze, defaults to ["train", "val", "test"]

        Returns:
            Analysis results dict
        """
        if not split_names:
            split_names = ["train", "val", "test"]

        if self.verbose:
            print("\n" + "=" * 80)
            print("Data Quality Analysis Report")
            print("=" * 80)

        self.analysis_results = {}

        for split_name in split_names:
            if split_name not in data:
                continue

            split_data = data[split_name]
            raw_df = split_data.get("raw_data")

            if raw_df is None:
                continue

            if self.verbose:
                print(f"\n[{split_name.upper()} split]")
                print("-" * 40)

            split_results = self._analyze_split(raw_df, split_name)
            self.analysis_results[split_name] = split_results

        if self.verbose:
            print("\n" + "=" * 80)

        return self.analysis_results

    def _analyze_split(
        self,
        df: pl.DataFrame,
        split_name: str,
    ) -> Dict[str, Any]:
        """Analyze a single data split.

        Args:
            df: data frame
            split_name: name of the split

        Returns:
            Analysis results dict
        """
        results = {
            "total_rows": len(df),
            "total_cols": len(df.columns),
            "feature_cols": self.feature_cols,
            "label_col": self.label_col,
            "null_analysis": {},
            "zero_analysis": {},
            "all_null_by_date": {},
        }

        # 1. Missing values in the feature columns
        null_stats = self._analyze_null_values(df, self.feature_cols)
        results["null_analysis"] = null_stats

        if self.verbose:
            self._print_null_analysis(null_stats)

        # 2. Zero values in the feature columns
        zero_stats = self._analyze_zero_values(df, self.feature_cols)
        results["zero_analysis"] = zero_stats

        if self.verbose:
            self._print_zero_analysis(zero_stats)

        # 3. Check whether any column is entirely null on some date
        all_null_by_date = self._check_all_null_by_date(df, self.feature_cols)
        results["all_null_by_date"] = all_null_by_date

        if self.verbose:
            self._print_all_null_by_date(all_null_by_date)

        # 4. Analyze the label column
        if self.label_col and self.label_col in df.columns:
            label_stats = self._analyze_label(df, self.label_col)
            results["label_analysis"] = label_stats

            if self.verbose:
                self._print_label_analysis(label_stats)

        return results

    def _analyze_null_values(
        self,
        df: pl.DataFrame,
        cols: List[str],
    ) -> Dict[str, Any]:
        """Analyze missing values.

        Args:
            df: data frame
            cols: columns to analyze

        Returns:
            Missing-value statistics dict
        """
        stats = {
            "total_cells": len(df) * len(cols),
            "null_counts": {},
            "null_percentages": {},
            "columns_with_null": [],
            "total_null_cells": 0,
        }

        for col in cols:
            if col not in df.columns:
                continue

            null_count = df[col].null_count()
            if null_count > 0:
                null_pct = null_count / len(df) * 100
                stats["null_counts"][col] = null_count
                stats["null_percentages"][col] = null_pct
                stats["columns_with_null"].append(col)
                stats["total_null_cells"] += null_count

        return stats

    def _analyze_zero_values(
        self,
        df: pl.DataFrame,
        cols: List[str],
    ) -> Dict[str, Any]:
        """Analyze zero values.

        Args:
            df: data frame
            cols: columns to analyze

        Returns:
            Zero-value statistics dict
        """
        stats = {
            "total_cells": len(df) * len(cols),
            "zero_counts": {},
            "zero_percentages": {},
            "columns_with_zero": [],
            "total_zero_cells": 0,
        }

        for col in cols:
            if col not in df.columns:
                continue

            # Count zeros (excluding nulls)
            non_null_series = df[col].drop_nulls()
            if len(non_null_series) == 0:
                continue

            zero_count = (non_null_series == 0).sum()
            if zero_count > 0:
                zero_pct = zero_count / len(df) * 100
                stats["zero_counts"][col] = int(zero_count)
                stats["zero_percentages"][col] = zero_pct
                stats["columns_with_zero"].append(col)
                stats["total_zero_cells"] += int(zero_count)

        return stats

    def _check_all_null_by_date(
        self,
        df: pl.DataFrame,
        cols: List[str],
    ) -> Dict[str, Any]:
        """Check whether any column is entirely null on some date.

        Uses a polars lazy frame for memory-safe, efficient computation.

        Args:
            df: data frame
            cols: columns to analyze

        Returns:
            All-null check results dict
        """
        results = {
            "issues_found": False,
            "issues": [],
        }

        if "trade_date" not in df.columns:
            return results

        # Drop columns that are not present in the frame
        valid_cols = [c for c in cols if c in df.columns]
        if not valid_cols:
            return results

        # Use a lazy frame so polars can optimize the query
        lf = df.lazy()

        # Core step: compute only null counts and row totals (tiny after aggregation)
        # Build one null_count aggregation expression per column
        agg_exprs = [
            pl.col(col).null_count().alias(f"{col}_nulls") for col in valid_cols
        ]
        agg_exprs.append(pl.len().alias("total_rows"))

        agg_lf = lf.group_by("trade_date").agg(agg_exprs)

        # Collect (agg_df typically has only a few hundred to a few thousand rows)
        agg_df = agg_lf.collect()

        # Run the logical checks on this already-condensed small table
        issues = []
        for col in valid_cols:
            null_col = f"{col}_nulls"
            # Find dates where the null count equals the total row count
            bad_dates = agg_df.filter(
                (pl.col(null_col) == pl.col("total_rows")) & (pl.col("total_rows") > 0)
            ).select(["trade_date", "total_rows"])

            if not bad_dates.is_empty():
                for row in bad_dates.to_dicts():
                    issues.append(
                        {
                            "date": row["trade_date"],
                            "column": col,
                            "total_rows": row["total_rows"],
                        }
                    )

        if issues:
            results["issues_found"] = True
            results["issues"] = issues

        return results

    def _analyze_label(
        self,
        df: pl.DataFrame,
        label_col: str,
    ) -> Dict[str, Any]:
        """Analyze the label column.

        Args:
            df: data frame
            label_col: label column name

        Returns:
            Label analysis dict
        """
        stats = {
            "total_count": len(df),
            "null_count": 0,
            "null_percentage": 0.0,
            "zero_count": 0,
            "zero_percentage": 0.0,
            "min": None,
            "max": None,
            "mean": None,
            "std": None,
        }

        if label_col not in df.columns:
            return stats

        series = df[label_col]

        # Missing-value statistics
        null_count = series.null_count()
        stats["null_count"] = null_count
        stats["null_percentage"] = null_count / len(df) * 100 if len(df) > 0 else 0

        # Zero-value statistics
        non_null_series = series.drop_nulls()
        if len(non_null_series) > 0:
            zero_count = (non_null_series == 0).sum()
            stats["zero_count"] = int(zero_count)
            stats["zero_percentage"] = zero_count / len(df) * 100

            # Basic summary statistics
            stats["min"] = float(non_null_series.min())
            stats["max"] = float(non_null_series.max())
            stats["mean"] = float(non_null_series.mean())
            stats["std"] = float(non_null_series.std())

        return stats

    def _print_null_analysis(self, stats: Dict[str, Any]) -> None:
        """Print the missing-value analysis.

        Args:
            stats: missing-value statistics dict
        """
        total_cells = stats["total_cells"]
        total_null = stats["total_null_cells"]
        null_cols = stats["columns_with_null"]

        print("  Missing-value statistics:")
        print(f"    Total cells: {total_cells:,}")
        print(
            f"    Null cells: {total_null:,} ({total_null / total_cells * 100:.2f}%)"
        )
        print(f"    Columns with nulls: {len(null_cols)}/{len(self.feature_cols)}")

        if null_cols:
            print("    Top 5 features by null count:")
            sorted_cols = sorted(
                stats["null_counts"].items(),
                key=lambda x: x[1],
                reverse=True,
            )[:5]
            for col, count in sorted_cols:
                pct = stats["null_percentages"][col]
                print(f"      {col}: {count:,} ({pct:.2f}%)")

    def _print_zero_analysis(self, stats: Dict[str, Any]) -> None:
        """Print the zero-value analysis.

        Args:
            stats: zero-value statistics dict
        """
        total_cells = stats["total_cells"]
        total_zero = stats["total_zero_cells"]
        zero_cols = stats["columns_with_zero"]

        print("  Zero-value statistics:")
        print(f"    Total cells: {total_cells:,}")
        print(
            f"    Zero cells: {total_zero:,} ({total_zero / total_cells * 100:.2f}%)"
        )
        print(f"    Columns with zeros: {len(zero_cols)}/{len(self.feature_cols)}")

        if zero_cols:
            print("    Top 5 features by zero count:")
            sorted_cols = sorted(
                stats["zero_counts"].items(),
                key=lambda x: x[1],
                reverse=True,
            )[:5]
            for col, count in sorted_cols:
                pct = stats["zero_percentages"][col]
                print(f"      {col}: {count:,} ({pct:.2f}%)")

    def _print_all_null_by_date(self, results: Dict[str, Any]) -> None:
        """Print the per-date all-null check.

        Args:
            results: all-null check results dict
        """
        issues = results["issues"]

        print("  Per-date all-null check:")
        if results["issues_found"]:
            print(f"    [WARNING] Found {len(issues)} issues:")
            # Group the report by date
            by_date = {}
            for issue in issues:
                date = issue["date"]
                if date not in by_date:
                    by_date[date] = []
                by_date[date].append(issue["column"])

            for date in sorted(by_date.keys())[:5]:  # show only the first 5 dates
                cols = by_date[date]
                print(f"      Date {date}: {len(cols)} columns entirely null")
                if len(cols) <= 3:
                    print(f"        Columns: {', '.join(cols)}")

            if len(by_date) > 5:
                print(f"      ... {len(by_date) - 5} more dates with issues")
        else:
            print("    [OK] No column is entirely null on any date")

    def _print_label_analysis(self, stats: Dict[str, Any]) -> None:
        """Print the label analysis.

        Args:
            stats: label analysis dict
        """
        print(f"  Label column statistics ({self.label_col}):")
        print(f"    Total: {stats['total_count']:,}")
        print(f"    Nulls: {stats['null_count']:,} ({stats['null_percentage']:.2f}%)")
        print(f"    Zeros: {stats['zero_count']:,} ({stats['zero_percentage']:.2f}%)")

        if stats["mean"] is not None:
            print(f"    Min: {stats['min']:.6f}")
            print(f"    Max: {stats['max']:.6f}")
            print(f"    Mean: {stats['mean']:.6f}")
            print(f"    Std: {stats['std']:.6f}")

    def get_summary(self) -> str:
        """Get a summary of the analysis results.

        Returns:
            Summary string
        """
        if not self.analysis_results:
            return "No analysis has been run yet"

        lines = ["Data quality analysis summary", "=" * 40]

        for split_name, results in self.analysis_results.items():
            lines.append(f"\n[{split_name.upper()}]")
            lines.append(f"  Total rows: {results['total_rows']:,}")

            null_stats = results.get("null_analysis", {})
            if null_stats.get("columns_with_null"):
                lines.append(
                    f"  Nulls: {null_stats['total_null_cells']:,} cells, "
                    f"{len(null_stats['columns_with_null'])} columns affected"
                )

            zero_stats = results.get("zero_analysis", {})
            if zero_stats.get("columns_with_zero"):
                lines.append(
                    f"  Zeros: {zero_stats['total_zero_cells']:,} cells, "
                    f"{len(zero_stats['columns_with_zero'])} columns affected"
                )

            all_null = results.get("all_null_by_date", {})
            if all_null.get("issues_found"):
                lines.append(
                    f"  [WARNING] {len(all_null['issues'])} date/column all-null issues found"
                )

        return "\n".join(lines)


def analyze_data_quality(
    data: Dict[str, Dict[str, Any]],
    feature_cols: Optional[List[str]] = None,
    label_col: Optional[str] = None,
    verbose: bool = True,
) -> Dict[str, Any]:
    """Convenience wrapper: run a data quality analysis.

    Args:
        data: data dict
        feature_cols: list of feature column names
        label_col: label column name
        verbose: whether to print details

    Returns:
        Analysis results dict
    """
    analyzer = DataQualityAnalyzer(
        feature_cols=feature_cols,
        label_col=label_col,
        verbose=verbose,
    )
    return analyzer.analyze(data)
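# --- Usage sketch (added for illustration; not part of the original file) ---
# Hypothetical example of the convenience wrapper above. Per this commit the
# module now lives under the training package; the exact import path below is
# an assumption.
import polars as pl
from src.training.data_quality_analyzer import analyze_data_quality

df = pl.DataFrame({
    "trade_date": ["20240101", "20240101", "20240102"],
    "f1": [1.0, None, 0.0],
    "label": [0.01, -0.02, None],
})
results = analyze_data_quality(
    data={"train": {"raw_data": df}},
    feature_cols=["f1"],
    label_col="label",
)
print(results["train"]["null_analysis"]["total_null_cells"])  # -> 1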
@@ -54,27 +54,10 @@ N_QUANTILES = 20

# Excluded factor list
EXCLUDED_FACTORS = [
    'active_market_cap',
    'close_vwap_deviation',
    'sharpe_ratio_20',
    'upper_shadow_ratio',
    'volume_ratio_5_20',
    'GTJA_alpha090',
    'GTJA_alpha084',
    'GTJA_alpha066',
    'GTJA_alpha150',
    'GTJA_alpha148',
    'GTJA_alpha106',
    'GTJA_alpha109',
    'GTJA_alpha108',
    'GTJA_alpha176',
    'GTJA_alpha169',
    'GTJA_alpha156',
    'chip_dispersion_70',
    'winner_rate_cs_rank',
    'atr_price_impact',
    'low_vol_days_20',
    'liquidity_shock_momentum',
    # 'debt_to_equity',
    # 'GTJA_alpha016',
    # 'GTJA_alpha141',
]

# LambdaRank model parameter configuration
@@ -145,8 +128,8 @@ def main():
    pipeline = DataPipeline(
        factor_manager=factor_manager,
        processor_configs=[
            (NullFiller, {"strategy": "mean"}),
            (Winsorizer, {"lower": 0.01, "upper": 0.99}),
            (NullFiller, {"strategy": "mean"}),
            (CrossSectionalStandardScaler, {}),
        ],
        filters=[STFilter(data_router=engine.router)],
@@ -52,36 +52,36 @@ TRAINING_TYPE = "regression"

# Excluded factor list
EXCLUDED_FACTORS = [
    'GTJA_alpha016',
    'volatility_20',
    'current_ratio',
    'GTJA_alpha001',
    'GTJA_alpha141',
    'GTJA_alpha129',
    'GTJA_alpha164',
    'amivest_liq_20',
    'GTJA_alpha012',
    'debt_to_equity',
    'turnover_deviation',
    'GTJA_alpha073',
    'GTJA_alpha043',
    'GTJA_alpha032',
    'GTJA_alpha028',
    'GTJA_alpha090',
    'GTJA_alpha108',
    'GTJA_alpha105',
    'GTJA_alpha091',
    'GTJA_alpha119',
    'GTJA_alpha104',
    'GTJA_alpha163',
    'GTJA_alpha157',
    'cost_skewness',
    'GTJA_alpha176',
    'chip_transition',
    'amount_skewness_20',
    'GTJA_alpha148',
    'mean_median_dev',
    'downside_illiq_20',
    # 'GTJA_alpha016',
    # 'volatility_20',
    # 'current_ratio',
    # 'GTJA_alpha001',
    # 'GTJA_alpha141',
    # 'GTJA_alpha129',
    # 'GTJA_alpha164',
    # 'amivest_liq_20',
    # 'GTJA_alpha012',
    # 'debt_to_equity',
    # 'turnover_deviation',
    # 'GTJA_alpha073',
    # 'GTJA_alpha043',
    # 'GTJA_alpha032',
    # 'GTJA_alpha028',
    # 'GTJA_alpha090',
    # 'GTJA_alpha108',
    # 'GTJA_alpha105',
    # 'GTJA_alpha091',
    # 'GTJA_alpha119',
    # 'GTJA_alpha104',
    # 'GTJA_alpha163',
    # 'GTJA_alpha157',
    # 'cost_skewness',
    # 'GTJA_alpha176',
    # 'chip_transition',
    # 'amount_skewness_20',
    # 'GTJA_alpha148',
    # 'mean_median_dev',
    # 'downside_illiq_20',
]

# Model parameter configuration
@@ -153,15 +153,15 @@ def main():
    pipeline = DataPipeline(
        factor_manager=factor_manager,
        processor_configs=[
            (NullFiller, {"strategy": "mean"}),
            (Winsorizer, {"lower": 0.01, "upper": 0.99}),
            (NullFiller, {"strategy": "mean"}),
            (StandardScaler, {}),
            # (CrossSectionalStandardScaler, {}),
        ],
        label_processor_configs=[
            # Winsorize the label (clip extreme returns)
            (Winsorizer, {"lower": 0.05, "upper": 0.95}),
            # (StandardScaler, {}),
            (StandardScaler, {}),
        ],
        filters=[STFilter(data_router=engine.router)],
        stock_pool_filter_func=stock_pool_filter,
src/experiment/tabm_rank_train.py (new file, 176 lines)

@@ -0,0 +1,176 @@
"""TabM learning-to-rank training pipeline (modular version).

Uses the new modular Trainer architecture and implements learning-to-rank
on top of TabMRankModel. TabM trains with a ListNet loss and supports
ensemble learning.
"""

import os

from src.factors import FactorEngine
from src.training import (
    FactorManager,
    DataPipeline,
    NullFiller,
    Winsorizer,
    CrossSectionalStandardScaler,
)
from src.training.tasks.tabm_rank_task import TabMRankTask
from src.training.core.trainer_v2 import Trainer
from src.training.components.filters import STFilter
from src.experiment.common import (
    SELECTED_FACTORS,
    FACTOR_DEFINITIONS,
    LABEL_NAME,
    LABEL_FACTOR,
    TRAIN_START,
    TRAIN_END,
    VAL_START,
    VAL_END,
    TEST_START,
    TEST_END,
    stock_pool_filter,
    STOCK_FILTER_REQUIRED_COLUMNS,
    OUTPUT_DIR,
    SAVE_PREDICTIONS,
    SAVE_MODEL,
    get_model_save_path,
    save_model_with_factors,
    TOP_N,
    TRAIN_SKIP_DAYS,
)

# Training type identifier
TRAINING_TYPE = "tabm_rank"

# %%
# Label configuration (imported from common.py)
# LABEL_NAME and LABEL_FACTOR are bound in common.py; just import them from common

# Quantile configuration (higher resolution to better separate the head)
N_QUANTILES = 50

# [Top-K optimization] label engineering configuration - exponential gains on by default
LABEL_TRANSFORM = "exponential"  # enable exponential gain labels (2^(rank/scale) - 1)
LABEL_SCALE = 20.0  # scaling factor for the exponential transform

# Excluded factor list
EXCLUDED_FACTORS = ["GTJA_alpha041", "GTJA_alpha127"]

# TabM Rank model parameters (all Top-K optimizations enabled, using LambdaLoss)
MODEL_PARAMS = {
    # ==================== MLP structure ====================
    "n_blocks": 4,  # number of MLP layers
    "d_block": 256,  # neurons per layer
    "dropout": 0.5,  # dropout rate
    # ==================== Ensembling ====================
    "ensemble_size": 32,  # built-in ensemble size (emulates an ensemble of 32 models)
    # ==================== Training parameters ====================
    "learning_rate": 1e-4,  # learning rate
    "weight_decay": 1e-5,  # weight decay
    "epochs": 500,  # training epochs
    # ==================== Early stopping ====================
    "early_stopping_round": 50,  # early-stopping patience
    # NDCG evaluation - focus on Top-20
    "ndcg_k": 20,  # compute NDCG@20 during validation
    # [Top-K optimization] loss configuration - use LambdaLoss
    "loss_type": "lambda",  # LambdaLoss for precise Top-K optimization
    "lambda_sigma": 1.0,  # sigmoid steepness
    "ndcg_weight_power": 1.0,  # DeltaNDCG weight exponent; >1 further amplifies the head
}

# Date range configuration
date_range = {
    "train": (TRAIN_START, TRAIN_END),
    "val": (VAL_START, VAL_END),
    "test": (TEST_START, TEST_END),
}

# Output configuration
output_config = {
    "output_dir": OUTPUT_DIR,
    "output_filename": "tabm_rank_output.csv",
    "save_predictions": SAVE_PREDICTIONS,
    "save_model": SAVE_MODEL,
    "model_save_path": get_model_save_path(TRAINING_TYPE),
    "top_n": TOP_N,
}


def main():
    """Entry point."""
    print("\n" + "=" * 80)
    print("TabM learning-to-rank training (modular version)")
    print("=" * 80)

    # 1. Create the FactorEngine
    print("\n[1] Creating FactorEngine")
    engine = FactorEngine()

    # 2. Create the FactorManager
    print("\n[2] Creating FactorManager")
    factor_manager = FactorManager(
        selected_factors=SELECTED_FACTORS,
        factor_definitions=FACTOR_DEFINITIONS,
        label_factor=LABEL_FACTOR,
        excluded_factors=EXCLUDED_FACTORS,
    )

    # 3. Create the DataPipeline
    print("\n[3] Creating DataPipeline")
    pipeline = DataPipeline(
        factor_manager=factor_manager,
        processor_configs=[
            (Winsorizer, {"lower": 0.01, "upper": 0.99}),
            (NullFiller, {"strategy": "mean"}),
            (CrossSectionalStandardScaler, {}),
        ],
        filters=[STFilter(data_router=engine.router)],
        stock_pool_filter_func=stock_pool_filter,
        stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
        train_skip_days=TRAIN_SKIP_DAYS,
    )

    # 4. Create the TabMRankTask
    print("\n[4] Creating TabMRankTask")
    task = TabMRankTask(
        model_params=MODEL_PARAMS,
        label_name=LABEL_NAME,
        n_quantiles=N_QUANTILES,
        label_transform=LABEL_TRANSFORM,
        label_scale=LABEL_SCALE,
    )

    # 5. Create the Trainer
    print("\n[5] Creating Trainer")
    trainer = Trainer(
        data_pipeline=pipeline,
        task=task,
        output_config=output_config,
        verbose=True,
    )

    # 6. Run training
    print("\n[6] Running training")
    results = trainer.run(engine=engine, date_range=date_range)

    # 7. Save the model and factor info (if enabled)
    if SAVE_MODEL:
        print("\n[7] Saving model and factor info")
        save_model_with_factors(
            model=task.get_model(),
            model_path=output_config["model_save_path"],
            selected_factors=SELECTED_FACTORS,
            factor_definitions=FACTOR_DEFINITIONS,
            fitted_processors=pipeline.get_fitted_processors(),
        )

    print("\n" + "=" * 80)
    print("Training pipeline finished!")
    print(f"Results saved to: {os.path.join(OUTPUT_DIR, 'tabm_rank_output.csv')}")
    print("=" * 80)

    return results


if __name__ == "__main__":
    main()
@@ -39,9 +39,8 @@ from src.experiment.common import (
    get_model_save_path,
    save_model_with_factors,
    TOP_N,
    TRAIN_SKIP_DAYS,
    TRAIN_SKIP_DAYS, STOCK_POOL_SIZE,
)
from src.experiment.data_quality_analyzer import DataQualityAnalyzer

# Training type identifier
TRAINING_TYPE = "tabm_regression"
@@ -54,201 +53,7 @@ TRAINING_TYPE = "tabm_regression"

# Excluded factor list (kept consistent with the LightGBM regression)
EXCLUDED_FACTORS = [
    # "GTJA_alpha001",
    # "GTJA_alpha002",
    # "GTJA_alpha003",
    # "GTJA_alpha004",
    # "GTJA_alpha005",
    # "GTJA_alpha006",
    # "GTJA_alpha007",
    # "GTJA_alpha008",
    # "GTJA_alpha009",
    # "GTJA_alpha010",
    # "GTJA_alpha011",
    # "GTJA_alpha012",
    # "GTJA_alpha013",
    # "GTJA_alpha014",
    # "GTJA_alpha015",
    # "GTJA_alpha016",
    # "GTJA_alpha017",
    # "GTJA_alpha018",
    # "GTJA_alpha019",
    # "GTJA_alpha020",
    # "GTJA_alpha022",
    # "GTJA_alpha023",
    # "GTJA_alpha024",
    # "GTJA_alpha025",
    # "GTJA_alpha026",
    # "GTJA_alpha027",
    # "GTJA_alpha028",
    # "GTJA_alpha029",
    # "GTJA_alpha031",
    # "GTJA_alpha032",
    # "GTJA_alpha033",
    # "GTJA_alpha034",
    # "GTJA_alpha035",
    # "GTJA_alpha036",
    # "GTJA_alpha037",
    # # "GTJA_alpha038",
    # "GTJA_alpha039",
    # "GTJA_alpha040",
    # "GTJA_alpha041",
    # "GTJA_alpha042",
    # "GTJA_alpha043",
    # "GTJA_alpha044",
    # "GTJA_alpha045",
    # "GTJA_alpha046",
    # "GTJA_alpha047",
    # "GTJA_alpha048",
    # "GTJA_alpha049",
    # "GTJA_alpha050",
    # "GTJA_alpha051",
    # "GTJA_alpha052",
    # "GTJA_alpha053",
    # "GTJA_alpha054",
    # "GTJA_alpha056",
    # "GTJA_alpha057",
    # "GTJA_alpha058",
    # "GTJA_alpha059",
    # "GTJA_alpha060",
    # "GTJA_alpha061",
    # "GTJA_alpha062",
    # "GTJA_alpha063",
    # "GTJA_alpha064",
    # "GTJA_alpha065",
    # "GTJA_alpha066",
    # "GTJA_alpha067",
    # "GTJA_alpha068",
    # "GTJA_alpha070",
    # "GTJA_alpha071",
    # "GTJA_alpha072",
    # "GTJA_alpha073",
    # "GTJA_alpha074",
    # "GTJA_alpha076",
    # "GTJA_alpha077",
    # "GTJA_alpha078",
    # "GTJA_alpha079",
    # "GTJA_alpha080",
    # "GTJA_alpha081",
    # "GTJA_alpha082",
    # "GTJA_alpha083",
    # "GTJA_alpha084",
    # "GTJA_alpha085",
    # "GTJA_alpha086",
    # "GTJA_alpha087",
    # "GTJA_alpha088",
    # "GTJA_alpha089",
    # "GTJA_alpha090",
    # "GTJA_alpha091",
    # "GTJA_alpha092",
    # "GTJA_alpha093",
    # "GTJA_alpha094",
    # "GTJA_alpha095",
    # "GTJA_alpha096",
    # "GTJA_alpha097",
    # "GTJA_alpha098",
    # "GTJA_alpha099",
    # "GTJA_alpha100",
    # "GTJA_alpha101",
    # "GTJA_alpha102",
    # "GTJA_alpha103",
    # "GTJA_alpha104",
    # "GTJA_alpha105",
    # "GTJA_alpha106",
    # "GTJA_alpha107",
    # "GTJA_alpha108",
    # "GTJA_alpha109",
    # "GTJA_alpha110",
    # "GTJA_alpha111",
    # "GTJA_alpha112",
    # # "GTJA_alpha113",
    # "GTJA_alpha114",
    # "GTJA_alpha115",
    # "GTJA_alpha117",
    # "GTJA_alpha118",
    # "GTJA_alpha119",
    # "GTJA_alpha120",
    # # "GTJA_alpha121",
    # "GTJA_alpha122",
    # "GTJA_alpha123",
    # "GTJA_alpha124",
    # "GTJA_alpha125",
    # "GTJA_alpha126",
    # "GTJA_alpha127",
    # "GTJA_alpha128",
    # "GTJA_alpha129",
    # "GTJA_alpha130",
    # "GTJA_alpha131",
    # "GTJA_alpha132",
    # "GTJA_alpha133",
    # "GTJA_alpha134",
    # "GTJA_alpha135",
    # "GTJA_alpha136",
    # # "GTJA_alpha138",
    # "GTJA_alpha139",
    # # "GTJA_alpha140",
    # "GTJA_alpha141",
    # "GTJA_alpha142",
    # "GTJA_alpha145",
    # # "GTJA_alpha146",
    # "GTJA_alpha148",
    # "GTJA_alpha150",
    # "GTJA_alpha151",
    # "GTJA_alpha152",
    # "GTJA_alpha153",
    # "GTJA_alpha154",
    # "GTJA_alpha155",
    # "GTJA_alpha156",
    # "GTJA_alpha157",
    # "GTJA_alpha158",
    # "GTJA_alpha159",
    # "GTJA_alpha160",
    # "GTJA_alpha161",
    # "GTJA_alpha162",
    # "GTJA_alpha163",
    # "GTJA_alpha164",
    # # "GTJA_alpha165",
    # "GTJA_alpha166",
    # "GTJA_alpha167",
    # "GTJA_alpha168",
    # "GTJA_alpha169",
    # "GTJA_alpha170",
    # "GTJA_alpha171",
    # "GTJA_alpha173",
    # "GTJA_alpha174",
    # "GTJA_alpha175",
    # "GTJA_alpha176",
    # "GTJA_alpha177",
    # "GTJA_alpha178",
    # "GTJA_alpha179",
    # "GTJA_alpha180",
    # # "GTJA_alpha183",
    # "GTJA_alpha184",
    # "GTJA_alpha185",
    # "GTJA_alpha187",
    # "GTJA_alpha188",
    # "GTJA_alpha189",
    # "GTJA_alpha191",
    # "chip_dispersion_90",
    # "chip_dispersion_70",
    # "cost_skewness",
    # "dispersion_change_20",
    # "price_to_avg_cost",
    # "price_to_median_cost",
    # "mean_median_dev",
    # "trap_pressure",
    # "bottom_profit",
    # "history_position",
    # "winner_rate_surge_5",
    # "winner_rate_cs_rank",
    # "winner_rate_dev_20",
    # "winner_rate_volatility",
    # "smart_money_accumulation",
    # "winner_vol_corr_20",
    # "cost_base_momentum",
    # "bottom_cost_stability",
    # "pivot_reversion",
    # "chip_transition",

]

# TabM model parameter configuration (from the user-provided sample code)
@@ -256,11 +61,11 @@ MODEL_PARAMS = {
    # ==================== MLP structure ====================
    "n_blocks": 3,  # number of MLP layers
    "d_block": 256,  # neurons per layer
    "dropout": 0.3,  # dropout rate
    "dropout": 0.5,  # dropout rate
    # ==================== Ensembling ====================
    "ensemble_size": 32,  # built-in ensemble size (emulates an ensemble of 32 models)
    # ==================== Training parameters ====================
    "batch_size": 2048,  # batch size
    "batch_size": STOCK_POOL_SIZE * 5,  # batch size
    "learning_rate": 1e-3,  # learning rate
    "weight_decay": 1e-5,  # weight decay
    "epochs": 100,  # training epochs
@@ -312,13 +117,14 @@ def main():
    pipeline = DataPipeline(
        factor_manager=factor_manager,
        processor_configs=[
            (NullFiller, {"strategy": "mean"}),
            (Winsorizer, {"lower": 0.01, "upper": 0.99}),  # winsorize the heavy-tailed distributions first
            (NullFiller, {"strategy": "mean"}),
            (StandardScaler, {}),  # TabM needs standardized inputs
        ],
        label_processor_configs=[
            # Winsorize the label (clip extreme returns)
            (Winsorizer, {"lower": 0.05, "upper": 0.95}),
            (Winsorizer, {"lower": 0.01, "upper": 0.99}),
            (StandardScaler, {}),
        ],
        filters=[STFilter(data_router=engine.router)],
        stock_pool_filter_func=stock_pool_filter,