Files
ProStock/src/experiment/learn_to_rank.ipynb

1371 lines
120 KiB
Plaintext
Raw Normal View History

{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
"source": [
"# Learn-to-Rank 排序学习训练流程\n",
"#\n",
"本 Notebook 实现基于 LightGBM LambdaRank 的排序学习训练,用于股票排序任务。\n",
"#\n",
"## 核心特点\n",
"#\n",
"1. **Label 转换**: 将 `future_return_5` 按每日进行 20 分位数划分qcut\n",
"2. **排序学习**: 使用 LambdaRank 目标函数,学习每日股票排序\n",
"3. **NDCG 评估**: 使用 NDCG@1/5/10/20 评估排序质量\n",
"4. **策略回测**: 基于排序分数构建 Top-k 选股策略"
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## 1. 导入依赖"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T15:09:08.941539Z",
"start_time": "2026-03-14T15:09:08.938469Z"
}
},
"cell_type": "code",
"source": [
"import os\n",
"from datetime import datetime\n",
"from typing import List, Tuple, Optional\n",
"\n",
"import numpy as np\n",
"import polars as pl\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.metrics import ndcg_score\n",
"\n",
"from src.factors import FactorEngine\n",
"from src.training import (\n",
" DateSplitter,\n",
" STFilter,\n",
" StockPoolManager,\n",
" Trainer,\n",
" Winsorizer,\n",
" NullFiller,\n",
" StandardScaler,\n",
" check_data_quality,\n",
")\n",
"from src.training.components.models import LightGBMLambdaRankModel\n",
"from src.training.config import TrainingConfig\n",
"\n",
"# 从 common 模块导入共用配置和函数\n",
"from src.experiment.common import (\n",
" SELECTED_FACTORS,\n",
" FACTOR_DEFINITIONS,\n",
" get_label_factor,\n",
" register_factors,\n",
" prepare_data,\n",
" TRAIN_START,\n",
" TRAIN_END,\n",
" VAL_START,\n",
" VAL_END,\n",
" TEST_START,\n",
" TEST_END,\n",
" stock_pool_filter,\n",
" STOCK_FILTER_REQUIRED_COLUMNS,\n",
" OUTPUT_DIR,\n",
" SAVE_PREDICTIONS,\n",
" PERSIST_MODEL,\n",
" TOP_N,\n",
")\n",
"\n"
],
"outputs": [],
"execution_count": 13
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## 2. 本地辅助函数"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T15:09:08.987446Z",
"start_time": "2026-03-14T15:09:08.981368Z"
}
},
"cell_type": "code",
"source": [
"# 注意register_factors 和 prepare_data 已从 common 模块导入\n",
"\n",
"\n",
"def prepare_ranking_data(\n",
" df: pl.DataFrame,\n",
" label_col: str = \"future_return_5\",\n",
" date_col: str = \"trade_date\",\n",
" n_quantiles: int = 20,\n",
") -> Tuple[pl.DataFrame, str]:\n",
" \"\"\"准备排序学习数据\n",
"\n",
" 将连续 label 转换为分位数标签,用于排序学习任务。\n",
"\n",
" Args:\n",
" df: 原始数据\n",
" label_col: 原始标签列名\n",
" date_col: 日期列名\n",
" n_quantiles: 分位数数量\n",
"\n",
" Returns:\n",
" (处理后的 DataFrame, 新的标签列名)\n",
" \"\"\"\n",
" print(\"\\n\" + \"=\" * 80)\n",
" print(f\"准备排序学习数据(将 {label_col} 转换为 {n_quantiles} 分位数标签)\")\n",
" print(\"=\" * 80)\n",
"\n",
" # 新的标签列名\n",
" rank_col = f\"{label_col}_rank\"\n",
"\n",
" # 按日期分组进行分位数划分\n",
" # 使用 rank 生成 0, 1, 2, ..., n_quantiles-1 的标签\n",
" # 方法: 计算每天内的排名,然后映射到 n_quantiles 个分位数组\n",
" df_ranked = (\n",
" df.with_columns(\n",
" # 计算每天内的排名 (1-based)\n",
" pl.col(label_col).rank(method=\"min\").over(date_col).alias(\"_rank\")\n",
" )\n",
" .with_columns(\n",
" # 将排名转换为分位数标签 (0 to n_quantiles-1)\n",
" ((pl.col(\"_rank\") - 1) / pl.len().over(date_col) * n_quantiles)\n",
" .floor()\n",
" .cast(pl.Int64)\n",
" .clip(0, n_quantiles - 1)\n",
" .alias(rank_col)\n",
" )\n",
" .drop(\"_rank\")\n",
" )\n",
"\n",
" # 检查转换结果\n",
" print(f\"\\n原始 {label_col} 统计:\")\n",
" print(df_ranked[label_col].describe())\n",
"\n",
" print(f\"\\n转换后 {rank_col} 统计:\")\n",
" print(df_ranked[rank_col].describe())\n",
"\n",
" # 检查每日样本分布\n",
" print(f\"\\n每日样本数统计:\")\n",
" daily_counts = df_ranked.group_by(date_col).agg(pl.count().alias(\"count\"))\n",
" print(daily_counts[\"count\"].describe())\n",
"\n",
" # 检查分位数分布(应该是均匀的)\n",
" print(f\"\\n分位数标签分布:\")\n",
" rank_dist = df_ranked[rank_col].value_counts().sort(rank_col)\n",
" print(rank_dist)\n",
"\n",
" return df_ranked, rank_col\n",
"\n",
"\n",
"def compute_group_array(df: pl.DataFrame, date_col: str = \"trade_date\") -> np.ndarray:\n",
" \"\"\"计算 group 数组用于 LambdaRank\n",
"\n",
" 每个日期作为一个 querygroup 数组表示每个 query 的样本数。\n",
"\n",
" Args:\n",
" df: 数据框\n",
" date_col: 日期列名\n",
"\n",
" Returns:\n",
" group 数组\n",
" \"\"\"\n",
" group_counts = df.group_by(date_col, maintain_order=True).agg(\n",
" pl.count().alias(\"count\")\n",
" )\n",
" return group_counts[\"count\"].to_numpy()\n",
"\n",
"\n",
"def evaluate_ndcg_at_k(\n",
" y_true: np.ndarray,\n",
" y_pred: np.ndarray,\n",
" group: np.ndarray,\n",
" k_list: List[int] = [1, 5, 10, 20],\n",
") -> dict:\n",
" \"\"\"计算 NDCG@k 指标\n",
"\n",
" Args:\n",
" y_true: 真实标签\n",
" y_pred: 预测分数\n",
" group: 分组数组\n",
" k_list: 要计算的 k 值列表\n",
"\n",
" Returns:\n",
" NDCG 指标字典\n",
" \"\"\"\n",
" results = {}\n",
"\n",
" # 按 group 拆分数据\n",
" start_idx = 0\n",
" y_true_groups = []\n",
" y_pred_groups = []\n",
"\n",
" for group_size in group:\n",
" end_idx = start_idx + group_size\n",
" y_true_groups.append(y_true[start_idx:end_idx])\n",
" y_pred_groups.append(y_pred[start_idx:end_idx])\n",
" start_idx = end_idx\n",
"\n",
" # 计算每个 k 值的平均 NDCG\n",
" for k in k_list:\n",
" ndcg_scores = []\n",
" for yt, yp in zip(y_true_groups, y_pred_groups):\n",
" if len(yt) > 1:\n",
" try:\n",
" score = ndcg_score([yt], [yp], k=k)\n",
" ndcg_scores.append(score)\n",
" except ValueError:\n",
" # 标签都相同,无法计算\n",
" pass\n",
"\n",
" results[f\"ndcg@{k}\"] = np.mean(ndcg_scores) if ndcg_scores else 0.0\n",
"\n",
" return results\n",
"\n"
],
"outputs": [],
"execution_count": 14
},
{
"metadata": {},
"cell_type": "markdown",
"source": [
"## 3. 配置参数\n",
"#\n",
"### 3.1 因子与日期配置"
]
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T15:09:08.996882Z",
"start_time": "2026-03-14T15:09:08.993154Z"
}
},
"cell_type": "code",
"source": [
"# 注意SELECTED_FACTORS, FACTOR_DEFINITIONS, 日期配置等已从 common 模块导入\n",
"# 本脚本特有的配置:\n",
"\n",
"# Label 名称(排序学习使用原始收益率,会后续转换为分位数标签)\n",
"LABEL_NAME = \"future_return_5\"\n",
"\n",
"# 获取 Label 因子定义\n",
"LABEL_FACTOR = get_label_factor(LABEL_NAME)\n",
"\n",
"# 分位数配置\n",
"N_QUANTILES = 20 # 将 label 分为 20 组\n",
"\n",
"\n",
"# 分位数配置\n",
"N_QUANTILES = 20 # 将 label 分为 20 组\n",
"\n",
"# LambdaRank 模型参数配置\n",
"MODEL_PARAMS = {\n",
" \"objective\": \"lambdarank\",\n",
" \"metric\": \"ndcg\",\n",
" \"ndcg_at\": 5, # 评估 NDCG@k\n",
" \"learning_rate\": 0.05,\n",
" \"num_leaves\": 31,\n",
" \"max_depth\": 4,\n",
" \"min_data_in_leaf\": 20,\n",
" \"n_estimators\": 2000,\n",
" \"early_stopping_round\": 100,\n",
" \"subsample\": 0.8,\n",
" \"colsample_bytree\": 0.8,\n",
" \"reg_alpha\": 0.1,\n",
" \"reg_lambda\": 1.0,\n",
" \"verbose\": -1,\n",
" \"random_state\": 42,\n",
" \"lambdarank_truncation_level\": 5,\n",
" \"label_gain\": [i for i in range(1, N_QUANTILES + 1)],\n",
"}\n",
"\n",
"# 注意stock_pool_filter, STOCK_FILTER_REQUIRED_COLUMNS, OUTPUT_DIR 等配置\n",
"# 已从 common 模块导入"
],
"outputs": [],
"execution_count": 15
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## 4. 训练流程"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T15:09:21.575611Z",
"start_time": "2026-03-14T15:09:09.003108Z"
}
},
"cell_type": "code",
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"LightGBM LambdaRank 排序学习训练\")\n",
"print(\"=\" * 80)\n",
"\n",
"# 1. 创建 FactorEngine启用 metadata 功能)\n",
"print(\"\\n[1] 创建 FactorEngine\")\n",
"engine = FactorEngine()\n",
"\n",
"# 2. 使用 metadata 定义因子\n",
"print(\"\\n[2] 定义因子(从 metadata 注册)\")\n",
"feature_cols = register_factors(\n",
" engine, SELECTED_FACTORS, FACTOR_DEFINITIONS, LABEL_FACTOR\n",
")\n",
"\n",
"# 3. 准备数据\n",
"print(\"\\n[3] 准备数据\")\n",
"data = prepare_data(\n",
" engine=engine,\n",
" feature_cols=feature_cols,\n",
" start_date=TRAIN_START,\n",
" end_date=TEST_END,\n",
" label_name=LABEL_NAME,\n",
")\n",
"\n",
"# 4. 转换为排序学习格式(分位数标签)\n",
"print(\"\\n[4] 转换为排序学习格式\")\n",
"data, target_col = prepare_ranking_data(\n",
" df=data,\n",
" label_col=LABEL_NAME,\n",
" n_quantiles=N_QUANTILES,\n",
")\n",
"\n",
"# 5. 打印配置信息\n",
"print(f\"\\n[配置] 训练期: {TRAIN_START} - {TRAIN_END}\")\n",
"print(f\"[配置] 验证期: {VAL_START} - {VAL_END}\")\n",
"print(f\"[配置] 测试期: {TEST_START} - {TEST_END}\")\n",
"print(f\"[配置] 特征数: {len(feature_cols)}\")\n",
"print(f\"[配置] 目标变量: {target_col}{N_QUANTILES}分位数)\")\n",
"\n",
"# 6. 创建排序学习模型\n",
"model = LightGBMLambdaRankModel(params=MODEL_PARAMS)\n",
"\n",
"# 7. 创建数据处理器(使用函数返回的完整特征列表)\n",
"processors = [\n",
" NullFiller(feature_cols=feature_cols, strategy=\"mean\"),\n",
" Winsorizer(feature_cols=feature_cols, lower=0.01, upper=0.99),\n",
" StandardScaler(feature_cols=feature_cols),\n",
"]\n",
"\n",
"# 8. 创建数据划分器\n",
"splitter = DateSplitter(\n",
" train_start=TRAIN_START,\n",
" train_end=TRAIN_END,\n",
" val_start=VAL_START,\n",
" val_end=VAL_END,\n",
" test_start=TEST_START,\n",
" test_end=TEST_END,\n",
")\n",
"\n",
"# 9. 创建股票池管理器\n",
"pool_manager = StockPoolManager(\n",
" filter_func=stock_pool_filter,\n",
" required_columns=STOCK_FILTER_REQUIRED_COLUMNS,\n",
" data_router=engine.router,\n",
")\n",
"\n",
"# 10. 创建 ST 过滤器\n",
"st_filter = STFilter(data_router=engine.router)\n",
"\n",
"# 11. 创建训练器\n",
"trainer = Trainer(\n",
" model=model,\n",
" pool_manager=pool_manager,\n",
" processors=processors,\n",
" filters=[st_filter],\n",
" splitter=splitter,\n",
" target_col=target_col,\n",
" feature_cols=feature_cols,\n",
" persist_model=PERSIST_MODEL,\n",
")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"================================================================================\n",
"LightGBM LambdaRank 排序学习训练\n",
"================================================================================\n",
"\n",
"[1] 创建 FactorEngine\n",
"\n",
"[2] 定义因子(从 metadata 注册)\n",
"================================================================================\n",
"注册因子\n",
"================================================================================\n",
"\n",
"注册特征因子(从 metadata:\n",
" - ma_5\n",
" - ma_20\n",
" - ma_ratio_5_20\n",
" - bias_10\n",
" - high_low_ratio\n",
" - bbi_ratio\n",
" - return_5\n",
" - return_20\n",
" - kaufman_ER_20\n",
" - mom_acceleration_10_20\n",
" - drawdown_from_high_60\n",
" - up_days_ratio_20\n",
" - volatility_5\n",
" - volatility_20\n",
" - volatility_ratio\n",
" - std_return_20\n",
" - sharpe_ratio_20\n",
" - min_ret_20\n",
" - volatility_squeeze_5_60\n",
" - overnight_intraday_diff\n",
" - upper_shadow_ratio\n",
" - capital_retention_20\n",
" - max_ret_20\n",
" - volume_ratio_5_20\n",
" - turnover_rate_mean_5\n",
" - turnover_deviation\n",
" - amihud_illiq_20\n",
" - turnover_cv_20\n",
" - pv_corr_20\n",
" - close_vwap_deviation\n",
" - roe\n",
" - roa\n",
" - profit_margin\n",
" - debt_to_equity\n",
" - current_ratio\n",
" - net_profit_yoy\n",
" - revenue_yoy\n",
" - healthy_expansion_velocity\n",
" - EP\n",
" - BP\n",
" - CP\n",
" - market_cap_rank\n",
" - turnover_rank\n",
" - return_5_rank\n",
" - EP_rank\n",
" - pe_expansion_trend\n",
" - value_price_divergence\n",
" - active_market_cap\n",
" - ebit_rank\n",
"\n",
"注册特征因子(表达式):\n",
"\n",
"注册 Label 因子(表达式):\n",
" - future_return_5: (ts_delay(close, -5) / ts_delay(open, -1)) - 1\n",
"\n",
"特征因子数: 49\n",
" - 来自 metadata: 49\n",
" - 来自表达式: 0\n",
"Label: future_return_5\n",
"已注册因子总数: 63\n",
"\n",
"[3] 准备数据\n",
"\n",
"================================================================================\n",
"准备数据\n",
"================================================================================\n",
"\n",
"计算因子: 20200101 - 20261231\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"D:\\PyProject\\ProStock\\src\\data\\financial_loader.py:148: UserWarning: Sortedness of columns cannot be checked when 'by' groups provided\n",
" merged = df_price.join_asof(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"数据形状: (7255513, 70)\n",
"数据列: ['ts_code', 'trade_date', 'amount', 'low', 'turnover_rate', 'vol', 'open', 'close', 'high', 'total_assets', 'total_mv', 'f_ann_date', 'total_hldr_eqy_exc_min_int', 'total_liab', 'total_cur_liab', 'total_cur_assets', 'ebit', 'n_income', 'revenue', 'n_cashflow_act', 'ma_5', 'ma_20', 'ma_ratio_5_20', 'bias_10', 'high_low_ratio', 'bbi_ratio', 'return_5', 'return_20', 'kaufman_ER_20', 'mom_acceleration_10_20', 'drawdown_from_high_60', 'up_days_ratio_20', 'volatility_5', 'volatility_20', 'volatility_ratio', 'std_return_20', 'sharpe_ratio_20', 'min_ret_20', 'volatility_squeeze_5_60', 'overnight_intraday_diff', 'upper_shadow_ratio', 'capital_retention_20', 'max_ret_20', 'volume_ratio_5_20', 'turnover_rate_mean_5', 'turnover_deviation', 'amihud_illiq_20', 'turnover_cv_20', 'pv_corr_20', 'close_vwap_deviation', 'roe', 'roa', 'profit_margin', 'debt_to_equity', 'current_ratio', 'net_profit_yoy', 'revenue_yoy', 'healthy_expansion_velocity', 'EP', 'BP', 'CP', 'market_cap_rank', 'turnover_rank', 'return_5_rank', 'EP_rank', 'pe_expansion_trend', 'value_price_divergence', 'active_market_cap', 'ebit_rank', 'future_return_5']\n",
"\n",
"前5行预览:\n",
"shape: (5, 70)\n",
"┌───────────┬────────────┬──────────┬─────────┬───┬────────────┬───────────┬───────────┬───────────┐\n",
"│ ts_code ┆ trade_date ┆ amount ┆ low ┆ … ┆ value_pric ┆ active_ma ┆ ebit_rank ┆ future_re │\n",
"│ --- ┆ --- ┆ --- ┆ --- ┆ ┆ e_divergen ┆ rket_cap ┆ --- ┆ turn_5 │\n",
"│ str ┆ str ┆ f64 ┆ f64 ┆ ┆ ce ┆ --- ┆ f64 ┆ --- │\n",
"│ ┆ ┆ ┆ ┆ ┆ --- ┆ f64 ┆ ┆ f64 │\n",
"│ ┆ ┆ ┆ ┆ ┆ f64 ┆ ┆ ┆ │\n",
"╞═══════════╪════════════╪══════════╪═════════╪═══╪════════════╪═══════════╪═══════════╪═══════════╡\n",
"│ 000001.SZ ┆ 20200102 ┆ 2.5712e6 ┆ 1806.75 ┆ … ┆ null ┆ null ┆ null ┆ -0.008857 │\n",
"│ 000001.SZ ┆ 20200103 ┆ 1.9145e6 ┆ 1847.15 ┆ … ┆ null ┆ null ┆ null ┆ -0.01881 │\n",
"│ 000001.SZ ┆ 20200106 ┆ 1.4779e6 ┆ 1846.05 ┆ … ┆ null ┆ null ┆ null ┆ -0.008171 │\n",
"│ 000001.SZ ┆ 20200107 ┆ 1.2470e6 ┆ 1850.42 ┆ … ┆ null ┆ null ┆ null ┆ -0.014117 │\n",
"│ 000001.SZ ┆ 20200108 ┆ 1.4236e6 ┆ 1815.49 ┆ … ┆ null ┆ null ┆ null ┆ -0.017252 │\n",
"└───────────┴────────────┴──────────┴─────────┴───┴────────────┴───────────┴───────────┴───────────┘\n",
"\n",
"[4] 转换为排序学习格式\n",
"\n",
"================================================================================\n",
"准备排序学习数据(将 future_return_5 转换为 20 分位数标签)\n",
"================================================================================\n",
"\n",
"原始 future_return_5 统计:\n",
"shape: (9, 2)\n",
"┌────────────┬────────────┐\n",
"│ statistic ┆ value │\n",
"│ --- ┆ --- │\n",
"│ str ┆ f64 │\n",
"╞════════════╪════════════╡\n",
"│ count ┆ 7.227054e6 │\n",
"│ null_count ┆ 28459.0 │\n",
"│ mean ┆ 0.003978 │\n",
"│ std ┆ 0.073204 │\n",
"│ min ┆ -0.969459 │\n",
"│ 25% ┆ -0.032998 │\n",
"│ 50% ┆ -0.001278 │\n",
"│ 75% ┆ 0.032666 │\n",
"│ max ┆ 10.361925 │\n",
"└────────────┴────────────┘\n",
"\n",
"转换后 future_return_5_rank 统计:\n",
"shape: (9, 2)\n",
"┌────────────┬────────────┐\n",
"│ statistic ┆ value │\n",
"│ --- ┆ --- │\n",
"│ str ┆ f64 │\n",
"╞════════════╪════════════╡\n",
"│ count ┆ 7.227054e6 │\n",
"│ null_count ┆ 28459.0 │\n",
"│ mean ┆ 9.493551 │\n",
"│ std ┆ 5.765628 │\n",
"│ min ┆ 0.0 │\n",
"│ 25% ┆ 4.0 │\n",
"│ 50% ┆ 9.0 │\n",
"│ 75% ┆ 14.0 │\n",
"│ max ┆ 19.0 │\n",
"└────────────┴────────────┘\n",
"\n",
"每日样本数统计:\n",
"shape: (9, 2)\n",
"┌────────────┬─────────────┐\n",
"│ statistic ┆ value │\n",
"│ --- ┆ --- │\n",
"│ str ┆ f64 │\n",
"╞════════════╪═════════════╡\n",
"│ count ┆ 1494.0 │\n",
"│ null_count ┆ 0.0 │\n",
"│ mean ┆ 4856.434404 │\n",
"│ std ┆ 564.521537 │\n",
"│ min ┆ 2885.0 │\n",
"│ 25% ┆ 4382.0 │\n",
"│ 50% ┆ 5069.0 │\n",
"│ 75% ┆ 5347.0 │\n",
"│ max ┆ 5476.0 │\n",
"└────────────┴─────────────┘\n",
"\n",
"分位数标签分布:\n",
"shape: (21, 2)\n",
"┌──────────────────────┬────────┐\n",
"│ future_return_5_rank ┆ count │\n",
"│ --- ┆ --- │\n",
"│ i64 ┆ u32 │\n",
"╞══════════════════════╪════════╡\n",
"│ null ┆ 28459 │\n",
"│ 0 ┆ 362270 │\n",
"│ 1 ┆ 361546 │\n",
"│ 2 ┆ 361599 │\n",
"│ 3 ┆ 361755 │\n",
"│ … ┆ … │\n",
"│ 15 ┆ 361289 │\n",
"│ 16 ┆ 361218 │\n",
"│ 17 ┆ 361227 │\n",
"│ 18 ┆ 361252 │\n",
"│ 19 ┆ 359483 │\n",
"└──────────────────────┴────────┘\n",
"\n",
"[配置] 训练期: 20200101 - 20231231\n",
"[配置] 验证期: 20240101 - 20241231\n",
"[配置] 测试期: 20250101 - 20261231\n",
"[配置] 特征数: 49\n",
"[配置] 目标变量: future_return_5_rank20分位数\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\liaozhaorun\\AppData\\Local\\Temp\\ipykernel_28476\\562285170.py:58: DeprecationWarning: `pl.count()` is deprecated. Please use `pl.len()` instead.\n",
"(Deprecated in version 0.20.5)\n",
" daily_counts = df_ranked.group_by(date_col).agg(pl.count().alias(\"count\"))\n"
]
}
],
"execution_count": 16
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### 4.1 股票池筛选"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T15:09:45.297379Z",
"start_time": "2026-03-14T15:09:21.587998Z"
}
},
"cell_type": "code",
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"股票池筛选\")\n",
"print(\"=\" * 80)\n",
"\n",
"# 先执行 ST 过滤(在股票池筛选之前,与 Trainer.train() 保持一致)\n",
"if st_filter:\n",
" print(\"\\n[过滤] 应用 ST 过滤器...\")\n",
" data = st_filter.filter(data)\n",
" print(f\" ST 过滤后数据规模: {data.shape}\")\n",
"\n",
"if pool_manager:\n",
" print(\"\\n执行每日独立筛选股票池...\")\n",
" filtered_data = pool_manager.filter_and_select_daily(data)\n",
" print(f\" 筛选前数据规模: {data.shape}\")\n",
" print(f\" 筛选后数据规模: {filtered_data.shape}\")\n",
" print(f\" 筛选前股票数: {data['ts_code'].n_unique()}\")\n",
" print(f\" 筛选后股票数: {filtered_data['ts_code'].n_unique()}\")\n",
" print(f\" 删除记录数: {len(data) - len(filtered_data)}\")\n",
"else:\n",
" filtered_data = data\n",
" print(\" 未配置股票池管理器,跳过筛选\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"================================================================================\n",
"股票池筛选\n",
"================================================================================\n",
"\n",
"[过滤] 应用 ST 过滤器...\n",
" ST 过滤后数据规模: (7027678, 71)\n",
"\n",
"执行每日独立筛选股票池...\n",
" 筛选前数据规模: (7027678, 71)\n",
" 筛选后数据规模: (747000, 71)\n",
" 筛选前股票数: 5694\n",
" 筛选后股票数: 1439\n",
" 删除记录数: 6280678\n"
]
}
],
"execution_count": 17
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### 4.2 数据划分"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T15:09:45.376100Z",
"start_time": "2026-03-14T15:09:45.307622Z"
}
},
"cell_type": "code",
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"数据划分\")\n",
"print(\"=\" * 80)\n",
"\n",
"if splitter:\n",
" train_data, val_data, test_data = splitter.split(filtered_data)\n",
" print(f\"\\n训练集数据规模: {train_data.shape}\")\n",
" print(f\"验证集数据规模: {val_data.shape}\")\n",
" print(f\"测试集数据规模: {test_data.shape}\")\n",
"\n",
" # 计算各集的 group 数组\n",
" train_group = compute_group_array(train_data)\n",
" val_group = compute_group_array(val_data)\n",
" test_group = compute_group_array(test_data)\n",
"\n",
" print(f\"\\n训练集 group 数量: {len(train_group)}\")\n",
" print(f\"验证集 group 数量: {len(val_group)}\")\n",
" print(f\"测试集 group 数量: {len(test_group)}\")\n",
" print(f\"训练集日均样本数: {np.mean(train_group):.1f}\")\n",
" print(f\"验证集日均样本数: {np.mean(val_group):.1f}\")\n",
" print(f\"测试集日均样本数: {np.mean(test_group):.1f}\")\n",
"else:\n",
" raise ValueError(\"必须配置数据划分器\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"================================================================================\n",
"数据划分\n",
"================================================================================\n",
"\n",
"训练集数据规模: (485000, 71)\n",
"验证集数据规模: (121000, 71)\n",
"测试集数据规模: (141000, 71)\n",
"\n",
"训练集 group 数量: 970\n",
"验证集 group 数量: 242\n",
"测试集 group 数量: 282\n",
"训练集日均样本数: 500.0\n",
"验证集日均样本数: 500.0\n",
"测试集日均样本数: 500.0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\liaozhaorun\\AppData\\Local\\Temp\\ipykernel_28476\\562285170.py:82: DeprecationWarning: `pl.count()` is deprecated. Please use `pl.len()` instead.\n",
"(Deprecated in version 0.20.5)\n",
" pl.count().alias(\"count\")\n"
]
}
],
"execution_count": 18
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### 4.3 数据质量检查"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T15:09:47.637521Z",
"start_time": "2026-03-14T15:09:45.382468Z"
}
},
"cell_type": "code",
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"数据质量检查(必须在预处理之前)\")\n",
"print(\"=\" * 80)\n",
"\n",
"print(\"\\n检查训练集...\")\n",
"check_data_quality(train_data, feature_cols, raise_on_error=False)\n",
"\n",
"print(\"\\n检查验证集...\")\n",
"check_data_quality(val_data, feature_cols, raise_on_error=True)\n",
"\n",
"print(\"\\n检查测试集...\")\n",
"check_data_quality(test_data, feature_cols, raise_on_error=True)\n",
"\n",
"print(\"[成功] 数据质量检查通过,未发现异常\")\n"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"================================================================================\n",
"数据质量检查(必须在预处理之前)\n",
"================================================================================\n",
"\n",
"检查训练集...\n",
"\n",
"================================================================================\n",
"数据质量检查报告\n",
"================================================================================\n",
"\n",
"[严重] 发现 1638 个全空因子:\n",
" (某天的某个因子所有值都是 null可能是数据缺失或计算错误)\n",
" - 日期 20200824: roa (样本数: 500)\n",
" - 日期 20200824: net_profit_yoy (样本数: 500)\n",
" - 日期 20200824: revenue_yoy (样本数: 500)\n",
" - 日期 20200824: healthy_expansion_velocity (样本数: 500)\n",
" - 日期 20200824: value_price_divergence (样本数: 500)\n",
" - 日期 20200115: ma_20 (样本数: 500)\n",
" - 日期 20200115: ma_ratio_5_20 (样本数: 500)\n",
" - 日期 20200115: high_low_ratio (样本数: 500)\n",
" - 日期 20200115: bbi_ratio (样本数: 500)\n",
" - 日期 20200115: return_20 (样本数: 500)\n",
" ... 还有 1628 个\n",
"\n",
"--------------------------------------------------------------------------------\n",
"建议处理方式:\n",
" 1. 检查因子定义和数据源,确认计算逻辑是否正确\n",
" 2. 如果是预期内的缺失(如新股无历史数据),考虑调整因子计算窗口\n",
" 3. 如果是数据同步问题,重新同步相关数据\n",
" 4. 可以使用 filter 排除问题日期或因子\n",
"================================================================================\n",
"\n",
"检查验证集...\n",
"\n",
"检查测试集...\n",
"[成功] 数据质量检查通过,未发现异常\n"
]
}
],
"execution_count": 19
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### 4.4 数据预处理"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T15:09:47.934178Z",
"start_time": "2026-03-14T15:09:47.643592Z"
}
},
"cell_type": "code",
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"数据预处理\")\n",
"print(\"=\" * 80)\n",
"\n",
"fitted_processors = []\n",
"if processors:\n",
" print(\"\\n训练集处理...\")\n",
" for i, processor in enumerate(processors, 1):\n",
" print(f\" [{i}/{len(processors)}] {processor.__class__.__name__}\")\n",
" train_data = processor.fit_transform(train_data)\n",
" fitted_processors.append(processor)\n",
"\n",
" print(\"\\n验证集处理...\")\n",
" for processor in fitted_processors:\n",
" val_data = processor.transform(val_data)\n",
"\n",
" print(\"\\n测试集处理...\")\n",
" for processor in fitted_processors:\n",
" test_data = processor.transform(test_data)\n",
"\n",
"print(f\"\\n处理后训练集形状: {train_data.shape}\")\n",
"print(f\"处理后验证集形状: {val_data.shape}\")\n",
"print(f\"处理后测试集形状: {test_data.shape}\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"================================================================================\n",
"数据预处理\n",
"================================================================================\n",
"\n",
"训练集处理...\n",
" [1/3] NullFiller\n",
" [2/3] Winsorizer\n",
" [3/3] StandardScaler\n",
"\n",
"验证集处理...\n",
"\n",
"测试集处理...\n",
"\n",
"处理后训练集形状: (485000, 71)\n",
"处理后验证集形状: (121000, 71)\n",
"处理后测试集形状: (141000, 71)\n"
]
}
],
"execution_count": 20
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### 4.4 训练 LambdaRank 模型"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T15:09:48.844310Z",
"start_time": "2026-03-14T15:09:47.939155Z"
}
},
"cell_type": "code",
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"训练 LambdaRank 模型\")\n",
"print(\"=\" * 80)\n",
"\n",
"# 准备数据\n",
"X_train = train_data.select(feature_cols)\n",
"y_train = train_data.select(target_col).to_series()\n",
"\n",
"X_val = val_data.select(feature_cols)\n",
"y_val = val_data.select(target_col).to_series()\n",
"\n",
"print(f\"\\n训练样本数: {len(X_train)}\")\n",
"print(f\"验证样本数: {len(X_val)}\")\n",
"print(f\"特征数: {len(feature_cols)}\")\n",
"print(f\"目标变量: {target_col}\")\n",
"\n",
"print(\"\\n目标变量统计训练集:\")\n",
"print(y_train.describe())\n",
"\n",
"print(\"\\n开始训练...\")\n",
"model.fit(\n",
" X=X_train,\n",
" y=y_train,\n",
" group=train_group,\n",
" eval_set=(X_val, y_val, val_group),\n",
")\n",
"print(\"训练完成!\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"================================================================================\n",
"训练 LambdaRank 模型\n",
"================================================================================\n",
"\n",
"训练样本数: 485000\n",
"验证样本数: 121000\n",
"特征数: 49\n",
"目标变量: future_return_5_rank\n",
"\n",
"目标变量统计(训练集):\n",
"shape: (9, 2)\n",
"┌────────────┬──────────┐\n",
"│ statistic ┆ value │\n",
"│ --- ┆ --- │\n",
"│ str ┆ f64 │\n",
"╞════════════╪══════════╡\n",
"│ count ┆ 484665.0 │\n",
"│ null_count ┆ 335.0 │\n",
"│ mean ┆ 9.988943 │\n",
"│ std ┆ 5.224762 │\n",
"│ min ┆ 0.0 │\n",
"│ 25% ┆ 6.0 │\n",
"│ 50% ┆ 10.0 │\n",
"│ 75% ┆ 14.0 │\n",
"│ max ┆ 19.0 │\n",
"└────────────┴──────────┘\n",
"\n",
"开始训练...\n",
"Training until validation scores don't improve for 100 rounds\n",
"Early stopping, best iteration is:\n",
"[5]\ttrain's ndcg@5: 0.610972\tval's ndcg@5: 0.558883\n",
"训练完成!\n"
]
}
],
"execution_count": 21
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### 4.5 训练指标曲线"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T15:09:48.923721Z",
"start_time": "2026-03-14T15:09:48.850317Z"
}
},
"cell_type": "code",
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"训练指标曲线\")\n",
"print(\"=\" * 80)\n",
"\n",
"# 从模型获取训练评估结果\n",
"evals_result = model.get_evals_result()\n",
"\n",
"if evals_result is None or not evals_result:\n",
" print(\"[警告] 没有可用的训练指标,请确保训练时使用了 eval_set 参数\")\n",
"else:\n",
" print(\"[成功] 已从模型获取训练评估结果\")\n",
"\n",
" # 获取评估的 NDCG 指标\n",
" ndcg_metrics = [k for k in evals_result[\"train\"].keys() if \"ndcg\" in k]\n",
" print(f\"\\n评估的 NDCG 指标: {ndcg_metrics}\")\n",
"\n",
" # 显示早停信息\n",
" actual_rounds = len(list(evals_result[\"train\"].values())[0])\n",
" expected_rounds = MODEL_PARAMS.get(\"n_estimators\", 1000)\n",
" print(f\"\\n[早停信息]\")\n",
" print(f\" 配置的最大轮数: {expected_rounds}\")\n",
" print(f\" 实际训练轮数: {actual_rounds}\")\n",
"\n",
" best_iter = model.get_best_iteration()\n",
" if best_iter is not None and best_iter < actual_rounds:\n",
" print(f\" 早停状态: 已触发(最佳迭代: {best_iter}\")\n",
" else:\n",
" print(f\" 早停状态: 未触发(达到最大轮数)\")\n",
"\n",
" # 显示各 NDCG 指标的最终值\n",
" print(f\"\\n最终 NDCG 指标:\")\n",
" for metric in ndcg_metrics:\n",
" train_ndcg = evals_result[\"train\"][metric][-1]\n",
" val_ndcg = evals_result[\"val\"][metric][-1]\n",
" print(f\" {metric}: 训练集={train_ndcg:.4f}, 验证集={val_ndcg:.4f}\")\n",
"\n",
" # 使用封装好的方法绘制所有指标\n",
" print(\"\\n[绘图] 使用 LightGBM 原生接口绘制训练曲线...\")\n",
" fig = model.plot_all_metrics(metrics=ndcg_metrics[:4], figsize=(14, 10))\n",
" plt.show()\n",
"\n",
" print(f\"\\n[指标分析]\")\n",
" print(f\" 各NDCG指标在验证集上的最佳值:\")\n",
" for metric in ndcg_metrics:\n",
" val_metric_list = evals_result[\"val\"][metric]\n",
" best_iter_metric = val_metric_list.index(max(val_metric_list))\n",
" best_val = max(val_metric_list)\n",
" print(f\" {metric}: {best_val:.4f} (迭代 {best_iter_metric + 1})\")\n",
" print(f\"\\n[重要提醒] 验证集仅用于早停/调参,测试集完全独立于训练过程!\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"================================================================================\n",
"训练指标曲线\n",
"================================================================================\n",
"[成功] 已从模型获取训练评估结果\n",
"\n",
"评估的 NDCG 指标: ['ndcg@5']\n",
"\n",
"[早停信息]\n",
" 配置的最大轮数: 2000\n",
" 实际训练轮数: 105\n",
" 早停状态: 已触发(最佳迭代: 5\n",
"\n",
"最终 NDCG 指标:\n",
" ndcg@5: 训练集=0.7120, 验证集=0.5334\n",
"\n",
"[绘图] 使用 LightGBM 原生接口绘制训练曲线...\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 1400x1000 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAABW0AAAPdCAYAAADxjUr8AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAA0G9JREFUeJzs3QeYXXWZP/B3eksySWbSe4EQShJ6lY4ogoKNZgFXcC279lXsWP+uLrKWFRu2lcWyrA2kF6XXJLR00vskM5Nk+sz9P+dMMiQkQMqUm5nPx+c8p9xzz/3dc+cm8s077y8nk8lkAgAAAACArJDb0wMAAAAAAOBFQlsAAAAAgCwitAUAAAAAyCJCWwAAAACALCK0BQAAAADIIkJbAAAAAIAsIrQFAAAAAMgiQlsAAAAAgCwitAUAAAAAyCJCWwAA9juXXXZZ5OTkpMuBBx4YbW1tHY9dd911HY/94he/SI+NHz++41heXl6UlZWlx97whjfEb37zm2htbd3l61RVVcUXvvCFmDFjRvTv3z/69esXU6dOjQ984APx7LPP7nBuQ0ND/OAHP4hTTjklKioqorCwMEaNGhXHHXdcfPGLX4xFixZ1yntfvHhxx3t56TJw4MBOeQ0AAHpWfg+/PgAA7JP58+fH7373u7jooot26/wk4K2rq4slS5akyy233BLXX399/N///V8MGDCg47ynnnoqDXVXrVq1w/PnzJmTLkkoe+2116bHVq5cGeecc07MmjVrh3OT48nyyCOPREFBQXzuc5/baTxJmJtc529/+1ssX748HcO0adPiXe96V1x66aWRm6vOAgCgr/H/AAEA2O99/etfj0wms1vnJudt2bIl7rrrrjj88MPTY3fffXe8973v7Thn8+bN8cY3vrEjsH3nO98ZCxcujMbGxpg7d258/vOfT6t1t13vzW9+c0dge+qpp6YhbVJ5W1NTE/fdd1986EMf6jh/e9/97nfTyt3vfe97sWDBgvQ5a9eujTvvvDMNbZNrJdW+L+eFF15IX3/bUl1dvYd3DgCAbKTSFgCA/VrS7uDpp5+OP//5z/GmN71pt55TWloap59+etx2220xefLkqK2tjd///vfpdQ477LD46U9/mla9JpL2Br/61a86npu0Y/jyl7/c0VLhL3/5SxrSJsaMGZNW7paUlKT7RUVFcfLJJ6fLS/3Xf/1XfPjDH47Kysq03cJb3/rWmDhxYjQ1NcVjjz2WBrl//etf49xzz02D36SyFwCAvkGlLQAA+7W3v/3t6fprX/vaHj93yJAhaQuCbZIWBYlbb72141gSrL5cWLz9cxJJ+LotsH0ly5Yti4997GNx6KGHxsyZM9Oq3ksuuSSGDx8eP//5z+Nf//VfY9OmTfHNb34zHn744bQid1eOOeaYtO3CyJEj4/LLL48VK1bswbsHACBbCW0BANivJQFn0gc2qU69/fbb9/j5SeXs9pN8JZJet9scdNBBr/j87c+dMmVKx/Y///M/7zBJWBLIbvOf//mf6bGkj+6wYcPiLW95SzzzzDNpW4Z/+7d/S1skJJLtpCI4qfzdlXXr1kVLS0vaxiGZdO3YY49NjwEAsH8T2gIAsF8bOHBgWuG6t9W2ycRk2yRB6vbrPbW7z0sqeZM+uElrhqRFQxL8vu51r4uNGzemk5Ul622SSuCkj27SwiGR9Mb9xje+kYa8yYRqzz33XJxwwgnpY0ml7Q9+8IO9GjsAANlDaAsAwH4vaTWQ9Kn9+9//Hvfff/8ePXfevHkd2+PHj0/XY8eO7TiWBKavZNy4cbu81nXXXZdODrb949tPIJa0NkjMnz+/ozI3CaA/+tGP7nDutgrdpAp3W0uHT3/603HIIYekrRiSicy+/e1vd5yfVBwDALB/E9oCALDfS4LMK664It2+8cYbd/t5SSuBG264oWP/9a9/fbpOql63ebl+stsmItv+3CSobW5u3q3X3tYTNwl2t/fS/aeeeirtW1tRUbFTZfCuKnz3tkoYAIDsIbQFAKBX+OQnPxmFhYUdYeorqa+vj3vuuScNXJMJvxIXXnhhOjFY4r3vfW+MHj063X7wwQfTSb6S6tgkkE2qaT//+c+nS+K8886Lo446Kt1Ozrngggvi6aefTs9NJhxraGjY6fWTit4nnnhih566P/rRj6K6urojJE7C2Ztvvjm+9a1vxRlnnBFFRUXp8aR9QvJek9doamqKOXPmxMc//vGOa5944on7eCcBAOhpOZmX/lM+AABkucsuuyx++ctfptvPP/98x2Rh73vf++LHP/5xx3k///nP03OTkHT7CcNeKglFb7rppnRCs+0rXM8555xYvXr1Lp/z4Q9/OK699tp0Owlnzz777HQsLyeZcGzbtZIWCMk4k8B1xIgRaW/bbeNLKnCTcWzra1teXp4GxwcffHC6/5GPfCSdyGxXkvvwyCOP7PA+AADY/6i0BQCg1/jUpz4V+fn5r3hO0j4g6QWb9JpNQtmkPUIyGdhLg87DDz88rWZNKlunTZuWTgCW9M2dMmVK2n82qcbdZsyYMWkv2aQq9thjj02vlYSvSUuDZJKwz3zmM3H33Xd3nP+v//qvaSXu+eefn4azv//979NQdtskY0mQm/Syfcc73hGPP/54R2CbSELoD37wg2lP2yTQTaqLk9D3E5/4RDz00EMCWwCAXkClLQAA9IBrrrkmbWuQBLRJ2Jy0WUhaMtTW1sbSpUvj3nvvTQPY7cNhAAD6BqEtAAD0kG9+85vx2c9+dpd9eJMK2p/85Cfxrne9q0fGBgBAzxHaAgBAD3ruuefS3rh33nlnrFy5MgYNGhSvfe1r46qrruro1QsAQN8itAUAAAAAyCImIgMAAAAAyCJCWwAAAACALJIfWeAHP/hBfOtb34rVq1fH9OnT43vf+14cc8wxuzz31FNPjfvuu2+n4+ecc07cfPPN6fZll10Wv/zlL3d4/Oyzz45bb711t8bT1taW9hPr379/5OTk7NV7AgAAAADYXiaTiU2bNsXIkSMjNzc3e0Pb3/72t/Gxj30srrvuujj22GPTSRiSgHXu3LkxdOjQnc6/6aaboqmpqWO/qqoqDXrf9ra37XDe6173uvj5z3/esV9UVLTbY0oC2zFjxuz1ewIAAAAAeDnLli2L0aNHZ29oe80118QVV1wRl19+ebqfhLdJxez1118fn/70p3c6f/DgwTvs33jjjVFaWrpTaJuEtMOHD9+rMSUVtokXXnhhp9cDeofm5ua4/fbb09m5CwoKeno4QCfzHYfez/ccejffcej9+ur3vLa2Ni0W3ZY/ZmVom1TMPvHEE3HVVVd1HEvKgs8888x46KGHdusaP/vZz+Kiiy6KsrKyHY7fe++9aaXuoEGD4vTTT4+vfvWrUVFRsctrNDY2pss2SYlyori4OEpKSvby3QHZLD8/P/0Hn+Q73pf+coC+wnccej/fc+jdfMeh9+ur3/Pm5uZ0/WotWXs0tF2/fn20trbGsGHDdjie7M+ZM+dVn//oo4/GM888kwa3L22N8OY3vzkmTJgQCxcujM985jPx+te/Pg2C8/LydrrON77xjbj66qt3On7PPfekPzxA73XHHXf09BCALuQ7Dr2f7zn0br7j0Pv1te95XV3dbp3X4+0R9kUS1h522GE7TVqWVN5ukzw+bdq0mDRpUlp9e8YZZ+x0naTSN+mr+9Iy5dNOO+1lq3OB/f9ftpK/GM4666w+9S960Ff4jkPv53sOvZvvOPR+ffV7Xltbm/2hbWVlZVr5umbNmh2OJ/uv1o92y5YtaT/bL3/5y6/6OhMnTkxfa8GCBbsMbZP+t7uaqCz5gelLPzTQF/meQ+/mOw69n+859G6+49D79bXvecFuvtceDW0LCwvjyCOPjLvuuivOP//89FhbW1u6/6EPfegVn/v73/8+7UP7jne841VfZ/ny5VFVVRUjRozotLEDAAAAQG+U5HPJXFRdXWmb9LVtaGhI26f2plA2bxftWfdUj7dHSNoSvPvd746jjjoqbXNw7bXXplW0l19+efr4u971rhg1alTad/alrRGSoPel7Qs2b96c9qd9y1v
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"[指标分析]\n",
" 各NDCG指标在验证集上的最佳值:\n",
" ndcg@5: 0.5589 (迭代 5)\n",
"\n",
"[重要提醒] 验证集仅用于早停/调参,测试集完全独立于训练过程!\n"
]
}
],
"execution_count": 22
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### 4.6 模型评估"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T15:09:49.191132Z",
"start_time": "2026-03-14T15:09:48.936532Z"
}
},
"cell_type": "code",
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"模型评估\")\n",
"print(\"=\" * 80)\n",
"\n",
"# 准备测试集\n",
"X_test = test_data.select(feature_cols)\n",
"y_test = test_data.select(target_col).to_series()\n",
"\n",
"# 预测\n",
"print(\"\\n生成预测...\")\n",
"predictions = model.predict(X_test)\n",
"\n",
"# 添加预测列\n",
"test_data = test_data.with_columns([pl.Series(\"prediction\", predictions)])\n",
"\n",
"# 计算 NDCG 指标\n",
"print(\"\\n计算 NDCG 指标...\")\n",
"ndcg_results = evaluate_ndcg_at_k(\n",
" y_true=y_test.to_numpy(),\n",
" y_pred=predictions,\n",
" group=test_group,\n",
" k_list=[1, 5, 10, 20],\n",
")\n",
"\n",
"print(\"\\nNDCG 评估结果:\")\n",
"print(\"-\" * 40)\n",
"for metric, value in ndcg_results.items():\n",
" print(f\" {metric}: {value:.4f}\")\n",
"\n",
"# 特征重要性\n",
"print(\"\\n特征重要性Top 20:\")\n",
"print(\"-\" * 40)\n",
"importance = model.feature_importance()\n",
"if importance is not None:\n",
" top_features = importance.sort_values(ascending=False).head(20)\n",
" for i, (feature, score) in enumerate(top_features.items(), 1):\n",
" print(f\" {i:2d}. {feature:30s} {score:10.2f}\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"================================================================================\n",
"模型评估\n",
"================================================================================\n",
"\n",
"生成预测...\n",
"\n",
"计算 NDCG 指标...\n",
"\n",
"NDCG 评估结果:\n",
"----------------------------------------\n",
" ndcg@1: 0.5390\n",
" ndcg@5: 0.5280\n",
" ndcg@10: 0.5261\n",
" ndcg@20: 0.5281\n",
"\n",
"特征重要性Top 20:\n",
"----------------------------------------\n",
" 1. max_ret_20 178.55\n",
" 2. ma_ratio_5_20 175.30\n",
" 3. ma_5 161.37\n",
" 4. market_cap_rank 144.83\n",
" 5. CP 130.89\n",
" 6. roa 109.85\n",
" 7. healthy_expansion_velocity 108.79\n",
" 8. roe 107.79\n",
" 9. ebit_rank 103.48\n",
" 10. close_vwap_deviation 96.07\n",
" 11. std_return_20 94.57\n",
" 12. revenue_yoy 93.39\n",
" 13. ma_20 90.67\n",
" 14. turnover_rank 78.58\n",
" 15. pv_corr_20 75.06\n",
" 16. amihud_illiq_20 71.67\n",
" 17. EP_rank 60.11\n",
" 18. return_20 46.65\n",
" 19. min_ret_20 46.34\n",
" 20. volume_ratio_5_20 45.85\n"
]
}
],
"execution_count": 23
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-14T15:09:49.453399Z",
"start_time": "2026-03-14T15:09:49.213145Z"
}
},
"cell_type": "code",
"source": [
"# 确保输出目录存在\n",
"os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
"\n",
"# 生成时间戳\n",
"start_dt = datetime.strptime(TEST_START, \"%Y%m%d\")\n",
"end_dt = datetime.strptime(TEST_END, \"%Y%m%d\")\n",
"date_str = f\"{start_dt.strftime('%Y%m%d')}_{end_dt.strftime('%Y%m%d')}\"\n",
"\n",
"# 保存每日 Top N\n",
"print(f\"\\n[1/1] 保存每日 Top {TOP_N} 股票...\")\n",
"topn_output_path = os.path.join(OUTPUT_DIR, \"rank_output.csv\")\n",
"\n",
"# 按日期分组,取每日 top N\n",
"topn_by_date = []\n",
"unique_dates = test_data[\"trade_date\"].unique().sort()\n",
"for date in unique_dates:\n",
" day_data = test_data.filter(test_data[\"trade_date\"] == date)\n",
" # 按 prediction 降序排序,取前 N\n",
" topn = day_data.sort(\"prediction\", descending=True).head(TOP_N)\n",
" topn_by_date.append(topn)\n",
"\n",
"# 合并所有日期的 top N\n",
"topn_results = pl.concat(topn_by_date)\n",
"\n",
"# 格式化日期并调整列顺序:日期、分数、股票\n",
"topn_to_save = topn_results.select(\n",
" [\n",
" pl.col(\"trade_date\").str.slice(0, 4)\n",
" + \"-\"\n",
" + pl.col(\"trade_date\").str.slice(4, 2)\n",
" + \"-\"\n",
" + pl.col(\"trade_date\").str.slice(6, 2).alias(\"date\"),\n",
" pl.col(\"prediction\").alias(\"score\"),\n",
" pl.col(\"ts_code\"),\n",
" ]\n",
")\n",
"topn_to_save.write_csv(topn_output_path, include_header=True)\n",
"print(f\" 保存路径: {topn_output_path}\")\n",
"print(\n",
" f\" 保存行数: {len(topn_to_save)}{len(unique_dates)}个交易日 x 每日top{TOP_N}\"\n",
")\n",
"print(f\"\\n 预览前15行:\")\n",
"print(topn_to_save.head(15))\n",
"\n",
"print(\"\\n训练流程完成\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"[1/1] 保存每日 Top 5 股票...\n",
" 保存路径: output\\rank_output.csv\n",
" 保存行数: 1410282个交易日 x 每日top5\n",
"\n",
" 预览前15行:\n",
"shape: (15, 3)\n",
"┌────────────┬──────────┬───────────┐\n",
"│ trade_date ┆ score ┆ ts_code │\n",
"│ --- ┆ --- ┆ --- │\n",
"│ str ┆ f64 ┆ str │\n",
"╞════════════╪══════════╪═══════════╡\n",
"│ 2025-01-02 ┆ 0.092165 ┆ 002816.SZ │\n",
"│ 2025-01-02 ┆ 0.067764 ┆ 002634.SZ │\n",
"│ 2025-01-02 ┆ 0.066779 ┆ 002836.SZ │\n",
"│ 2025-01-02 ┆ 0.054118 ┆ 000004.SZ │\n",
"│ 2025-01-02 ┆ 0.046321 ┆ 000691.SZ │\n",
"│ … ┆ … ┆ … │\n",
"│ 2025-01-06 ┆ 0.092165 ┆ 002816.SZ │\n",
"│ 2025-01-06 ┆ 0.066779 ┆ 002836.SZ │\n",
"│ 2025-01-06 ┆ 0.05733 ┆ 002634.SZ │\n",
"│ 2025-01-06 ┆ 0.054118 ┆ 000004.SZ │\n",
"│ 2025-01-06 ┆ 0.052639 ┆ 600857.SH │\n",
"└────────────┴──────────┴───────────┘\n",
"\n",
"训练流程完成!\n"
]
}
],
"execution_count": 24
},
{
"metadata": {},
"cell_type": "markdown",
"source": [
"## 5. 总结\n",
"#\n",
"本 Notebook 实现了完整的 Learn-to-Rank 训练流程:\n",
"#\n",
"### 核心步骤\n",
"#\n",
"1. **数据准备**: 计算 49 个特征因子,将 `future_return_5` 转换为 20 分位数标签\n",
"2. **模型训练**: 使用 LightGBM LambdaRank 学习每日股票排序\n",
"3. **模型评估**: 使用 NDCG@1/5/10/20 评估排序质量\n",
"4. **策略分析**: 基于排序分数构建 Top-k 选股策略\n",
"#\n",
"### 关键参数\n",
"#\n",
"- **Objective**: lambdarank\n",
"- **Metric**: ndcg\n",
"- **Learning Rate**: 0.05\n",
"- **Num Leaves**: 31\n",
"- **N Quantiles**: 20\n",
"#\n",
"### 输出结果\n",
"#\n",
"- rank_output.csv: 每日Top-N推荐股票格式date, score, ts_code\n",
"- 特征重要性排名\n",
"- Top-k 策略统计和图表\n",
"- NDCG训练指标曲线\n",
"#\n",
"### 后续优化方向\n",
"#\n",
"1. **特征工程**: 尝试更多因子组合\n",
"2. **超参数调优**: 使用网格搜索优化 LambdaRank 参数\n",
"3. **模型集成**: 结合多个排序模型的预测\n",
"4. **更复杂的分组**: 考虑按行业分组排序\n",
"#\n"
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": [
"## 5. 总结\n",
"#\n",
"本 Notebook 实现了完整的 Learn-to-Rank 训练流程:\n",
"#\n",
"### 核心步骤\n",
"#\n",
"1. **数据准备**: 计算 49 个特征因子,将 `future_return_5` 转换为 20 分位数标签\n",
"2. **模型训练**: 使用 LightGBM LambdaRank 学习每日股票排序\n",
"3. **模型评估**: 使用 NDCG@1/5/10/20 评估排序质量\n",
"4. **策略分析**: 基于排序分数构建 Top-k 选股策略\n",
"#\n",
"### 关键参数\n",
"#\n",
"- **Objective**: lambdarank\n",
"- **Metric**: ndcg\n",
"- **Learning Rate**: 0.05\n",
"- **Num Leaves**: 31\n",
"- **N Quantiles**: 20\n",
"#\n",
"### 输出结果\n",
"#\n",
"- rank_output.csv: 每日Top-N推荐股票格式date, score, ts_code\n",
"- 特征重要性排名\n",
"- Top-k 策略统计和图表\n",
"- NDCG训练指标曲线\n",
"#\n",
"### 后续优化方向\n",
"#\n",
"1. **特征工程**: 尝试更多因子组合\n",
"2. **超参数调优**: 使用网格搜索优化 LambdaRank 参数\n",
"3. **模型集成**: 结合多个排序模型的预测\n",
"4. **更复杂的分组**: 考虑按行业分组排序\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}