2026-03-10 22:23:44 +08:00
{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
"source": [
"# Learn-to-Rank Training Pipeline\n",
"\n",
"This notebook trains a LightGBM LambdaRank learn-to-rank model for a stock-ranking task.\n",
"\n",
"## Key Features\n",
"\n",
"1. **Label conversion**: split `future_return_5` into 20 quantile buckets within each trading day (qcut-style)\n",
"2. **Learning to rank**: use the LambdaRank objective to learn the daily ordering of stocks\n",
"3. **NDCG evaluation**: measure ranking quality with NDCG@1/5/10/20\n",
"4. **Strategy backtest**: build a Top-k stock selection strategy from the ranking scores"
]
},
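{
"metadata": {},
"cell_type": "markdown",
"source": [
"The per-day quantile labeling described above can be sketched on toy data. This is an illustrative sketch only, not part of the pipeline: the column names `ret` and `ret_bucket` and all values are made up, and 2 buckets are used instead of 20 to keep the output small.\n",
"\n",
"```python\n",
"import polars as pl\n",
"\n",
"# Toy data (hypothetical): two trading days, four stocks each.\n",
"df = pl.DataFrame({\n",
"    'trade_date': ['20240102'] * 4 + ['20240103'] * 4,\n",
"    'ret': [0.01, -0.02, 0.03, 0.00, 0.05, 0.04, -0.01, 0.02],\n",
"})\n",
"n_quantiles = 2\n",
"\n",
"# Rank within each day, then map the rank to buckets 0 .. n_quantiles - 1.\n",
"rank_in_day = pl.col('ret').rank(method='min').over('trade_date')\n",
"labeled = df.with_columns(\n",
"    ((rank_in_day - 1) / pl.len().over('trade_date') * n_quantiles)\n",
"    .floor()\n",
"    .cast(pl.Int64)\n",
"    .clip(0, n_quantiles - 1)\n",
"    .alias('ret_bucket')\n",
")\n",
"print(labeled)\n",
"```\n",
"\n",
"Within each day the two lowest returns fall into bucket 0 and the two highest into bucket 1, so the label only encodes the relative order inside that day."
]
},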
{
"metadata": {},
"cell_type": "markdown",
"source": "## 1. Import dependencies"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-13T18:10:00.178020Z",
"start_time": "2026-03-13T18:09:58.791253Z"
}
},
"cell_type": "code",
"source": [
"import os\n",
"from datetime import datetime\n",
"from typing import List, Tuple, Optional\n",
"\n",
"import numpy as np\n",
"import polars as pl\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.metrics import ndcg_score\n",
"\n",
"from src.factors import FactorEngine\n",
"from src.training import (\n",
"    DateSplitter,\n",
"    STFilter,\n",
"    StockPoolManager,\n",
"    Trainer,\n",
"    Winsorizer,\n",
"    NullFiller,\n",
"    StandardScaler,\n",
"    check_data_quality,\n",
")\n",
"from src.training.components.models import LightGBMLambdaRankModel\n",
"from src.training.config import TrainingConfig\n",
"\n"
],
"outputs": [],
"execution_count": 1
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## 2. Helper functions"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-13T18:10:00.209449Z",
"start_time": "2026-03-13T18:10:00.197996Z"
}
},
"cell_type": "code",
"source": [
"def register_factors(\n",
"    engine: FactorEngine,\n",
"    selected_factors: List[str],\n",
"    factor_definitions: dict,\n",
"    label_factor: dict,\n",
") -> List[str]:\n",
"    \"\"\"Register factors (selected_factors are looked up in metadata; factor_definitions are registered as DSL expressions).\"\"\"\n",
"    print(\"=\" * 80)\n",
"    print(\"注册因子\")\n",
"    print(\"=\" * 80)\n",
"\n",
"    # Register the factors in SELECTED_FACTORS (already present in metadata)\n",
"    print(\"\\n注册特征因子( 从 metadata) :\")\n",
"    for name in selected_factors:\n",
"        engine.add_factor(name)\n",
"        print(f\" - {name}\")\n",
"\n",
"    # Register the factors in FACTOR_DEFINITIONS (by expression; not yet in metadata)\n",
"    print(\"\\n注册特征因子( 表达式) :\")\n",
"    for name, expr in factor_definitions.items():\n",
"        engine.add_factor(name, expr)\n",
"        print(f\" - {name}: {expr}\")\n",
"\n",
"    # Register the label factor (by expression)\n",
"    print(\"\\n注册 Label 因子(表达式):\")\n",
"    for name, expr in label_factor.items():\n",
"        engine.add_factor(name, expr)\n",
"        print(f\" - {name}: {expr}\")\n",
"\n",
"    # Feature columns = SELECTED_FACTORS + the keys of FACTOR_DEFINITIONS\n",
"    feature_cols = selected_factors + list(factor_definitions.keys())\n",
"\n",
"    print(f\"\\n特征因子数: {len(feature_cols)}\")\n",
"    print(f\" - 来自 metadata: {len(selected_factors)}\")\n",
"    print(f\" - 来自表达式: {len(factor_definitions)}\")\n",
"    print(f\"Label: {list(label_factor.keys())[0]}\")\n",
"    print(f\"已注册因子总数: {len(engine.list_registered())}\")\n",
"\n",
"    return feature_cols\n",
"\n",
"\n",
"def prepare_data(\n",
"    engine: FactorEngine,\n",
"    feature_cols: List[str],\n",
"    start_date: str,\n",
"    end_date: str,\n",
") -> pl.DataFrame:\n",
"    \"\"\"Prepare the data.\"\"\"\n",
"    print(\"\\n\" + \"=\" * 80)\n",
"    print(\"准备数据\")\n",
"    print(\"=\" * 80)\n",
"\n",
"    # Compute factors (whole-market data)\n",
"    print(f\"\\n计算因子: {start_date} - {end_date}\")\n",
"    # Include the label; LABEL_NAME is a notebook-level global defined in section 3.1\n",
"    factor_names = feature_cols + [LABEL_NAME]\n",
"\n",
"    data = engine.compute(\n",
"        factor_names=factor_names,\n",
"        start_date=start_date,\n",
"        end_date=end_date,\n",
"    )\n",
"\n",
"    print(f\"数据形状: {data.shape}\")\n",
"    print(f\"数据列: {data.columns}\")\n",
"    print(f\"\\n前5行预览:\")\n",
"    print(data.head())\n",
"\n",
"    return data\n",
"\n",
"\n",
"def prepare_ranking_data(\n",
"    df: pl.DataFrame,\n",
"    label_col: str = \"future_return_5\",\n",
"    date_col: str = \"trade_date\",\n",
"    n_quantiles: int = 20,\n",
") -> Tuple[pl.DataFrame, str]:\n",
"    \"\"\"Prepare data for learning to rank.\n",
"\n",
"    Converts the continuous label into quantile labels for the ranking task.\n",
"\n",
"    Args:\n",
"        df: raw data\n",
"        label_col: name of the original label column\n",
"        date_col: name of the date column\n",
"        n_quantiles: number of quantile buckets\n",
"\n",
"    Returns:\n",
"        (processed DataFrame, name of the new label column)\n",
"    \"\"\"\n",
"    print(\"\\n\" + \"=\" * 80)\n",
"    print(f\"准备排序学习数据(将 {label_col} 转换为 {n_quantiles} 分位数标签)\")\n",
"    print(\"=\" * 80)\n",
"\n",
"    # Name of the new label column\n",
"    rank_col = f\"{label_col}_rank\"\n",
"\n",
"    # Bucket into quantiles within each date:\n",
"    # rank within the day, then map the rank to labels 0, 1, ..., n_quantiles - 1.\n",
"    # Note: pl.len() also counts rows whose label is null, so the top bucket is\n",
"    # slightly smaller on days that contain null labels.\n",
"    df_ranked = (\n",
"        df.with_columns(\n",
"            # Rank within each day (1-based)\n",
"            pl.col(label_col).rank(method=\"min\").over(date_col).alias(\"_rank\")\n",
"        )\n",
"        .with_columns(\n",
"            # Map the rank to a quantile label (0 to n_quantiles - 1)\n",
"            ((pl.col(\"_rank\") - 1) / pl.len().over(date_col) * n_quantiles)\n",
"            .floor()\n",
"            .cast(pl.Int64)\n",
"            .clip(0, n_quantiles - 1)\n",
"            .alias(rank_col)\n",
"        )\n",
"        .drop(\"_rank\")\n",
"    )\n",
"\n",
"    # Inspect the conversion result\n",
"    print(f\"\\n原始 {label_col} 统计:\")\n",
"    print(df_ranked[label_col].describe())\n",
"\n",
"    print(f\"\\n转换后 {rank_col} 统计:\")\n",
"    print(df_ranked[rank_col].describe())\n",
"\n",
"    # Inspect the number of samples per day (pl.len() replaces the deprecated pl.count())\n",
"    print(f\"\\n每日样本数统计:\")\n",
"    daily_counts = df_ranked.group_by(date_col).agg(pl.len().alias(\"count\"))\n",
"    print(daily_counts[\"count\"].describe())\n",
"\n",
"    # Inspect the distribution of quantile labels (should be roughly uniform)\n",
"    print(f\"\\n分位数标签分布:\")\n",
"    rank_dist = df_ranked[rank_col].value_counts().sort(rank_col)\n",
"    print(rank_dist)\n",
"\n",
"    return df_ranked, rank_col\n",
"\n",
"\n",
"def compute_group_array(df: pl.DataFrame, date_col: str = \"trade_date\") -> np.ndarray:\n",
"    \"\"\"Compute the group array for LambdaRank.\n",
"\n",
"    Each date is one query; the group array holds the number of samples in each query.\n",
"    Rows of the same date must be contiguous in df, because LightGBM assigns rows to queries by position.\n",
"\n",
"    Args:\n",
"        df: data frame\n",
"        date_col: name of the date column\n",
"\n",
"    Returns:\n",
"        group array\n",
"    \"\"\"\n",
"    group_counts = df.group_by(date_col, maintain_order=True).agg(\n",
"        pl.len().alias(\"count\")\n",
"    )\n",
"    return group_counts[\"count\"].to_numpy()\n",
"\n",
"\n",
"def evaluate_ndcg_at_k(\n",
"    y_true: np.ndarray,\n",
"    y_pred: np.ndarray,\n",
"    group: np.ndarray,\n",
"    k_list: Optional[List[int]] = None,\n",
") -> dict:\n",
"    \"\"\"Compute NDCG@k metrics.\n",
"\n",
"    Args:\n",
"        y_true: true labels\n",
"        y_pred: predicted scores\n",
"        group: group array (number of samples per query)\n",
"        k_list: k values to evaluate; defaults to [1, 5, 10, 20]\n",
"\n",
"    Returns:\n",
"        dict of NDCG metrics\n",
"    \"\"\"\n",
"    if k_list is None:\n",
"        k_list = [1, 5, 10, 20]\n",
"\n",
"    results = {}\n",
"\n",
"    # Split the data by group\n",
"    start_idx = 0\n",
"    y_true_groups = []\n",
"    y_pred_groups = []\n",
"\n",
"    for group_size in group:\n",
"        end_idx = start_idx + group_size\n",
"        y_true_groups.append(y_true[start_idx:end_idx])\n",
"        y_pred_groups.append(y_pred[start_idx:end_idx])\n",
"        start_idx = end_idx\n",
"\n",
"    # Mean NDCG for each k\n",
"    for k in k_list:\n",
"        ndcg_scores = []\n",
"        for yt, yp in zip(y_true_groups, y_pred_groups):\n",
"            if len(yt) > 1:\n",
"                try:\n",
"                    score = ndcg_score([yt], [yp], k=k)\n",
"                    ndcg_scores.append(score)\n",
"                except ValueError:\n",
"                    # All labels identical; NDCG cannot be computed\n",
"                    pass\n",
"\n",
"        results[f\"ndcg@{k}\"] = np.mean(ndcg_scores) if ndcg_scores else 0.0\n",
"\n",
"    return results\n",
"\n"
],
"outputs": [],
"execution_count": 2
},
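{
"metadata": {},
"cell_type": "markdown",
"source": [
"How the group array and the per-query NDCG fit together can be sketched on toy data. This is an illustrative sketch only, not part of the pipeline: the labels, scores and group sizes are made up, and only `numpy` and `sklearn.metrics.ndcg_score` are assumed.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.metrics import ndcg_score\n",
"\n",
"# Toy data (hypothetical): two queries (days) with three stocks each, stored contiguously.\n",
"y_true = np.array([2, 1, 0, 1, 0, 2])\n",
"y_pred = np.array([0.9, 0.5, 0.1, 0.2, 0.3, 0.8])\n",
"group = np.array([3, 3])\n",
"\n",
"# Split the flat arrays at the cumulative group boundaries.\n",
"bounds = np.cumsum(group)[:-1]\n",
"per_query = [\n",
"    ndcg_score([yt], [yp], k=2)\n",
"    for yt, yp in zip(np.split(y_true, bounds), np.split(y_pred, bounds))\n",
"]\n",
"print(per_query, float(np.mean(per_query)))\n",
"```\n",
"\n",
"The first query is ranked perfectly, so its NDCG@2 is 1.0; the second puts a zero-label stock in second place and scores lower."
]
},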
{
"metadata": {},
"cell_type": "markdown",
"source": [
"## 3. Configuration\n",
"\n",
"### 3.1 Factor definitions"
]
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-13T18:10:00.241454Z",
"start_time": "2026-03-13T18:10:00.237620Z"
}
},
"cell_type": "code",
"source": [
"# Factor configuration (factor definitions are reused from regression.ipynb)\n",
"LABEL_NAME = \"future_return_5_rank\"\n",
"\n",
"# Currently selected factors (already registered in metadata and looked up by name)\n",
"SELECTED_FACTORS = [\n",
"    # ================= 1. Price, trend and path dependence =================\n",
"    \"ma_5\",\n",
"    \"ma_20\",\n",
"    \"ma_ratio_5_20\",\n",
"    \"bias_10\",\n",
"    \"high_low_ratio\",\n",
"    \"bbi_ratio\",\n",
"    \"return_5\",\n",
"    \"return_20\",\n",
"    \"kaufman_ER_20\",\n",
"    \"mom_acceleration_10_20\",\n",
"    \"drawdown_from_high_60\",\n",
"    \"up_days_ratio_20\",\n",
"    # ================= 2. Volatility, risk adjustment and higher moments =================\n",
"    \"volatility_5\",\n",
"    \"volatility_20\",\n",
"    \"volatility_ratio\",\n",
"    \"std_return_20\",\n",
"    \"sharpe_ratio_20\",\n",
"    \"min_ret_20\",\n",
"    \"volatility_squeeze_5_60\",\n",
"    # ================= 3. Intraday microstructure and anomalies =================\n",
"    \"overnight_intraday_diff\",\n",
"    \"upper_shadow_ratio\",\n",
"    \"capital_retention_20\",\n",
"    \"max_ret_20\",\n",
"    # ================= 4. Volume, liquidity and price-volume divergence =================\n",
"    \"volume_ratio_5_20\",\n",
"    \"turnover_rate_mean_5\",\n",
"    \"turnover_deviation\",\n",
"    \"amihud_illiq_20\",\n",
"    \"turnover_cv_20\",\n",
"    \"pv_corr_20\",\n",
"    \"close_vwap_deviation\",\n",
"    # ================= 5. Fundamental financial features =================\n",
"    \"roe\",\n",
"    \"roa\",\n",
"    \"profit_margin\",\n",
"    \"debt_to_equity\",\n",
"    \"current_ratio\",\n",
"    \"net_profit_yoy\",\n",
"    \"revenue_yoy\",\n",
"    \"healthy_expansion_velocity\",\n",
"    # ================= 6. Fundamental valuation and cross-sectional momentum resonance =================\n",
"    \"EP\",\n",
"    \"BP\",\n",
"    \"CP\",\n",
"    \"market_cap_rank\",\n",
"    \"turnover_rank\",\n",
"    \"return_5_rank\",\n",
"    \"EP_rank\",\n",
"    \"pe_expansion_trend\",\n",
"    \"value_price_divergence\",\n",
"    \"active_market_cap\",\n",
"    \"ebit_rank\",\n",
"]\n",
"\n",
"# Factors defined by DSL expression (not yet in metadata)\n",
"FACTOR_DEFINITIONS = {\n",
"    # \"turnover_rate_volatility\": \"ts_std(log(turnover_rate), 20)\"\n",
"}\n",
"\n",
"# Label factor definition (not used as a feature; only used to compute the target)\n",
"LABEL_FACTOR = {\n",
"    LABEL_NAME: \"(ts_delay(close, -5) / ts_delay(open, -1)) - 1\",\n",
"}"
],
"outputs": [],
"execution_count": 3
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### 3.2 Training parameters"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-13T18:10:00.257506Z",
"start_time": "2026-03-13T18:10:00.253934Z"
}
},
"cell_type": "code",
"source": [
"# Date ranges (a proper train/val/test three-way split)\n",
"TRAIN_START = \"20200101\"\n",
"TRAIN_END = \"20231231\"\n",
"VAL_START = \"20240101\"\n",
"VAL_END = \"20241231\"\n",
"TEST_START = \"20250101\"\n",
"TEST_END = \"20251231\"\n",
"\n",
"\n",
"# Quantile configuration\n",
"N_QUANTILES = 20  # split the label into 20 buckets\n",
"\n",
"# LambdaRank model parameters\n",
"MODEL_PARAMS = {\n",
"    \"objective\": \"lambdarank\",\n",
"    \"metric\": \"ndcg\",\n",
"    \"ndcg_at\": 10,  # evaluate NDCG@10\n",
"    \"learning_rate\": 0.01,\n",
"    \"num_leaves\": 31,\n",
"    \"max_depth\": 4,\n",
"    \"min_data_in_leaf\": 20,\n",
"    \"n_estimators\": 2000,\n",
"    \"early_stopping_round\": 300,\n",
"    \"subsample\": 0.8,\n",
"    \"colsample_bytree\": 0.8,\n",
"    \"reg_alpha\": 0.1,\n",
"    \"reg_lambda\": 1.0,\n",
"    \"verbose\": -1,\n",
"    \"random_state\": 42,\n",
"    \"lambdarank_truncation_level\": 10,\n",
"    \"label_gain\": list(range(1, N_QUANTILES + 1)),\n",
"}\n",
"\n",
"\n",
"# Stock pool filter function\n",
"def stock_pool_filter(df: pl.DataFrame) -> pl.Series:\n",
"    \"\"\"Stock pool filter (applied to a single day's data).\n",
"\n",
"    Filter criteria:\n",
"    1. Exclude ChiNext (codes starting with 30)\n",
"    2. Exclude the STAR Market (codes starting with 68)\n",
"    3. Exclude the Beijing Stock Exchange (codes starting with 8, 9 or 4)\n",
"    4. Keep the 1000 stocks with the smallest market cap on that day\n",
"    \"\"\"\n",
"    code_filter = (\n",
"        ~df[\"ts_code\"].str.starts_with(\"30\")\n",
"        & ~df[\"ts_code\"].str.starts_with(\"68\")\n",
"        & ~df[\"ts_code\"].str.starts_with(\"8\")\n",
"        & ~df[\"ts_code\"].str.starts_with(\"9\")\n",
"        & ~df[\"ts_code\"].str.starts_with(\"4\")\n",
"    )\n",
"\n",
"    valid_df = df.filter(code_filter)\n",
"    n = min(1000, len(valid_df))\n",
"    small_cap_codes = valid_df.sort(\"total_mv\").head(n)[\"ts_code\"]\n",
"\n",
"    # Pass a plain list: is_in with a same-dtype Series is deprecated in recent Polars\n",
"    return df[\"ts_code\"].is_in(small_cap_codes.to_list())\n",
"\n",
"\n",
"STOCK_FILTER_REQUIRED_COLUMNS = [\"total_mv\"]\n",
"\n",
"# Output configuration\n",
"OUTPUT_DIR = \"output\"\n",
"SAVE_PREDICTIONS = True\n",
"PERSIST_MODEL = False\n",
"\n",
"# Top N configuration: number of stocks recommended per day\n",
"TOP_N = 5  # can be changed to 10, 20, etc."
],
"outputs": [],
"execution_count": 4
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## 4. Training pipeline"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-13T18:10:11.844363Z",
"start_time": "2026-03-13T18:10:00.265162Z"
}
},
"cell_type": "code",
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"LightGBM LambdaRank 排序学习训练\")\n",
"print(\"=\" * 80)\n",
"\n",
"# 1. Create the FactorEngine (with metadata support enabled)\n",
"print(\"\\n[1] 创建 FactorEngine\")\n",
"engine = FactorEngine()\n",
"\n",
"# 2. Define factors via metadata\n",
"print(\"\\n[2] 定义因子(从 metadata 注册)\")\n",
"feature_cols = register_factors(\n",
"    engine, SELECTED_FACTORS, FACTOR_DEFINITIONS, LABEL_FACTOR\n",
")\n",
"\n",
"# 3. Prepare the data\n",
"print(\"\\n[3] 准备数据\")\n",
"data = prepare_data(\n",
"    engine=engine,\n",
"    feature_cols=feature_cols,\n",
"    start_date=TRAIN_START,\n",
"    end_date=TEST_END,\n",
")\n",
"\n",
"# 4. Convert to learning-to-rank format (quantile labels)\n",
"print(\"\\n[4] 转换为排序学习格式\")\n",
"data, target_col = prepare_ranking_data(\n",
"    df=data,\n",
"    label_col=LABEL_NAME,\n",
"    n_quantiles=N_QUANTILES,\n",
")\n",
"\n",
"# 5. Print the configuration\n",
"print(f\"\\n[配置] 训练期: {TRAIN_START} - {TRAIN_END}\")\n",
"print(f\"[配置] 验证期: {VAL_START} - {VAL_END}\")\n",
"print(f\"[配置] 测试期: {TEST_START} - {TEST_END}\")\n",
"print(f\"[配置] 特征数: {len(feature_cols)}\")\n",
"print(f\"[配置] 目标变量: {target_col}( {N_QUANTILES}分位数)\")\n",
"\n",
"# 6. Create the ranking model\n",
"model = LightGBMLambdaRankModel(params=MODEL_PARAMS)\n",
"\n",
"# 7. Create the data processors (using the full feature list returned by register_factors)\n",
"processors = [\n",
"    NullFiller(feature_cols=feature_cols, strategy=\"mean\"),\n",
"    Winsorizer(feature_cols=feature_cols, lower=0.01, upper=0.99),\n",
"    StandardScaler(feature_cols=feature_cols),\n",
"]\n",
"\n",
"# 8. Create the data splitter\n",
"splitter = DateSplitter(\n",
"    train_start=TRAIN_START,\n",
"    train_end=TRAIN_END,\n",
"    val_start=VAL_START,\n",
"    val_end=VAL_END,\n",
"    test_start=TEST_START,\n",
"    test_end=TEST_END,\n",
")\n",
"\n",
"# 9. Create the stock pool manager\n",
"pool_manager = StockPoolManager(\n",
"    filter_func=stock_pool_filter,\n",
"    required_columns=STOCK_FILTER_REQUIRED_COLUMNS,\n",
"    data_router=engine.router,\n",
")\n",
"\n",
"# 10. Create the ST filter\n",
"st_filter = STFilter(data_router=engine.router)\n",
"\n",
"# 11. Create the trainer\n",
"trainer = Trainer(\n",
"    model=model,\n",
"    pool_manager=pool_manager,\n",
"    processors=processors,\n",
"    filters=[st_filter],\n",
"    splitter=splitter,\n",
"    target_col=target_col,\n",
"    feature_cols=feature_cols,\n",
"    persist_model=PERSIST_MODEL,\n",
")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"================================================================================\n",
"LightGBM LambdaRank 排序学习训练\n",
"================================================================================\n",
"\n",
"[1] 创建 FactorEngine\n",
"\n",
"[2] 定义因子(从 metadata 注册)\n",
"================================================================================\n",
"注册因子\n",
"================================================================================\n",
"\n",
"注册特征因子(从 metadata) :\n",
" - ma_5\n",
" - ma_20\n",
" - ma_ratio_5_20\n",
" - bias_10\n",
" - high_low_ratio\n",
" - bbi_ratio\n",
" - return_5\n",
" - return_20\n",
" - kaufman_ER_20\n",
" - mom_acceleration_10_20\n",
" - drawdown_from_high_60\n",
" - up_days_ratio_20\n",
" - volatility_5\n",
" - volatility_20\n",
" - volatility_ratio\n",
" - std_return_20\n",
" - sharpe_ratio_20\n",
" - min_ret_20\n",
" - volatility_squeeze_5_60\n",
" - overnight_intraday_diff\n",
" - upper_shadow_ratio\n",
" - capital_retention_20\n",
" - max_ret_20\n",
" - volume_ratio_5_20\n",
" - turnover_rate_mean_5\n",
" - turnover_deviation\n",
" - amihud_illiq_20\n",
" - turnover_cv_20\n",
" - pv_corr_20\n",
" - close_vwap_deviation\n",
" - roe\n",
" - roa\n",
" - profit_margin\n",
" - debt_to_equity\n",
" - current_ratio\n",
" - net_profit_yoy\n",
" - revenue_yoy\n",
" - healthy_expansion_velocity\n",
" - EP\n",
" - BP\n",
" - CP\n",
" - market_cap_rank\n",
" - turnover_rank\n",
" - return_5_rank\n",
" - EP_rank\n",
" - pe_expansion_trend\n",
" - value_price_divergence\n",
" - active_market_cap\n",
" - ebit_rank\n",
"\n",
"注册特征因子(表达式):\n",
" - turnover_rate_volatility: ts_std(log(turnover_rate), 20)\n",
"\n",
"注册 Label 因子(表达式):\n",
" - future_return_5_rank: (ts_delay(close, -5) / ts_delay(open, -1)) - 1\n",
"\n",
"特征因子数: 50\n",
" - 来自 metadata: 49\n",
" - 来自表达式: 1\n",
"Label: future_return_5_rank\n",
"已注册因子总数: 64\n",
"\n",
"[3] 准备数据\n",
"\n",
"================================================================================\n",
"准备数据\n",
"================================================================================\n",
"\n",
"计算因子: 20200101 - 20251231\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"D:\\PyProject\\ProStock\\src\\data\\financial_loader.py:148: UserWarning: Sortedness of columns cannot be checked when 'by' groups provided\n",
" merged = df_price.join_asof(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"数据形状: (7044952, 71)\n",
"数据列: ['ts_code', 'trade_date', 'open', 'high', 'turnover_rate', 'low', 'vol', 'amount', 'close', 'total_assets', 'total_mv', 'f_ann_date', 'ebit', 'revenue', 'n_income', 'total_liab', 'total_hldr_eqy_exc_min_int', 'total_cur_assets', 'total_cur_liab', 'n_cashflow_act', 'ma_5', 'ma_20', 'ma_ratio_5_20', 'bias_10', 'high_low_ratio', 'bbi_ratio', 'return_5', 'return_20', 'kaufman_ER_20', 'mom_acceleration_10_20', 'drawdown_from_high_60', 'up_days_ratio_20', 'volatility_5', 'volatility_20', 'volatility_ratio', 'std_return_20', 'sharpe_ratio_20', 'min_ret_20', 'volatility_squeeze_5_60', 'overnight_intraday_diff', 'upper_shadow_ratio', 'capital_retention_20', 'max_ret_20', 'volume_ratio_5_20', 'turnover_rate_mean_5', 'turnover_deviation', 'amihud_illiq_20', 'turnover_cv_20', 'pv_corr_20', 'close_vwap_deviation', 'roe', 'roa', 'profit_margin', 'debt_to_equity', 'current_ratio', 'net_profit_yoy', 'revenue_yoy', 'healthy_expansion_velocity', 'EP', 'BP', 'CP', 'market_cap_rank', 'turnover_rank', 'return_5_rank', 'EP_rank', 'pe_expansion_trend', 'value_price_divergence', 'active_market_cap', 'ebit_rank', 'turnover_rate_volatility', 'future_return_5_rank']\n",
"\n",
"前5行预览:\n",
"shape: (5, 71)\n",
"┌───────────┬────────────┬─────────┬─────────┬───┬────────────┬───────────┬────────────┬───────────┐\n",
"│ ts_code ┆ trade_date ┆ open ┆ high ┆ … ┆ active_mar ┆ ebit_rank ┆ turnover_r ┆ future_re │\n",
"│ --- ┆ --- ┆ --- ┆ --- ┆ ┆ ket_cap ┆ --- ┆ ate_volati ┆ turn_5_ra │\n",
"│ str ┆ str ┆ f64 ┆ f64 ┆ ┆ --- ┆ f64 ┆ lity ┆ nk │\n",
"│ ┆ ┆ ┆ ┆ ┆ f64 ┆ ┆ --- ┆ --- │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ f64 ┆ f64 │\n",
"╞═══════════╪════════════╪═════════╪═════════╪═══╪════════════╪═══════════╪════════════╪═══════════╡\n",
"│ 000001.SZ ┆ 20200102 ┆ 1817.67 ┆ 1850.42 ┆ … ┆ null ┆ null ┆ null ┆ -0.008857 │\n",
"│ 000001.SZ ┆ 20200103 ┆ 1849.33 ┆ 1889.72 ┆ … ┆ null ┆ null ┆ null ┆ -0.01881 │\n",
"│ 000001.SZ ┆ 20200106 ┆ 1856.97 ┆ 1893.0 ┆ … ┆ null ┆ null ┆ null ┆ -0.008171 │\n",
"│ 000001.SZ ┆ 20200107 ┆ 1870.07 ┆ 1886.45 ┆ … ┆ null ┆ null ┆ null ┆ -0.014117 │\n",
"│ 000001.SZ ┆ 20200108 ┆ 1855.88 ┆ 1861.34 ┆ … ┆ null ┆ null ┆ null ┆ -0.017252 │\n",
"└───────────┴────────────┴─────────┴─────────┴───┴────────────┴───────────┴────────────┴───────────┘\n",
"\n",
"[4] 转换为排序学习格式\n",
"\n",
"================================================================================\n",
"准备排序学习数据(将 future_return_5_rank 转换为 20 分位数标签)\n",
"================================================================================\n",
"\n",
"原始 future_return_5_rank 统计:\n",
"shape: (9, 2)\n",
"┌────────────┬───────────┐\n",
"│ statistic ┆ value │\n",
"│ --- ┆ --- │\n",
"│ str ┆ f64 │\n",
"╞════════════╪═══════════╡\n",
"│ count ┆ 7.01659e6 │\n",
"│ null_count ┆ 28362.0 │\n",
"│ mean ┆ 0.003779 │\n",
"│ std ┆ 0.073221 │\n",
"│ min ┆ -0.969459 │\n",
"│ 25% ┆ -0.033163 │\n",
"│ 50% ┆ -0.001483 │\n",
"│ 75% ┆ 0.032547 │\n",
"│ max ┆ 10.361925 │\n",
"└────────────┴───────────┘\n",
"\n",
"转换后 future_return_5_rank_rank 统计:\n",
"shape: (9, 2)\n",
"┌────────────┬───────────┐\n",
"│ statistic ┆ value │\n",
"│ --- ┆ --- │\n",
"│ str ┆ f64 │\n",
"╞════════════╪═══════════╡\n",
"│ count ┆ 7.01659e6 │\n",
"│ null_count ┆ 28362.0 │\n",
"│ mean ┆ 9.495412 │\n",
"│ std ┆ 5.765668 │\n",
"│ min ┆ 0.0 │\n",
"│ 25% ┆ 4.0 │\n",
"│ 50% ┆ 9.0 │\n",
"│ 75% ┆ 14.0 │\n",
"│ max ┆ 19.0 │\n",
"└────────────┴───────────┘\n",
"\n",
"每日样本数统计:\n",
"shape: (9, 2)\n",
"┌────────────┬─────────────┐\n",
"│ statistic ┆ value │\n",
"│ --- ┆ --- │\n",
"│ str ┆ f64 │\n",
"╞════════════╪═════════════╡\n",
"│ count ┆ 1455.0 │\n",
"│ null_count ┆ 0.0 │\n",
"│ mean ┆ 4841.891409 │\n",
"│ std ┆ 560.948186 │\n",
"│ min ┆ 3740.0 │\n",
"│ 25% ┆ 4369.0 │\n",
"│ 50% ┆ 5060.0 │\n",
"│ 75% ┆ 5344.0 │\n",
"│ max ┆ 5458.0 │\n",
"└────────────┴─────────────┘\n",
"\n",
"分位数标签分布:\n",
"shape: (21, 2)\n",
"┌───────────────────────────┬────────┐\n",
"│ future_return_5_rank_rank ┆ count │\n",
"│ --- ┆ --- │\n",
"│ i64 ┆ u32 │\n",
"╞═══════════════════════════╪════════╡\n",
"│ null ┆ 28362 │\n",
"│ 0 ┆ 351599 │\n",
"│ 1 ┆ 350894 │\n",
"│ 2 ┆ 350944 │\n",
"│ 3 ┆ 351077 │\n",
"│ … ┆ … │\n",
"│ 15 ┆ 350910 │\n",
"│ 16 ┆ 350835 │\n",
"│ 17 ┆ 350848 │\n",
"│ 18 ┆ 350871 │\n",
"│ 19 ┆ 349137 │\n",
"└───────────────────────────┴────────┘\n",
"\n",
"[配置] 训练期: 20200101 - 20231231\n",
"[配置] 验证期: 20240101 - 20241231\n",
"[配置] 测试期: 20250101 - 20251231\n",
"[配置] 特征数: 50\n",
"[配置] 目标变量: future_return_5_rank_rank( 20分位数) \n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\liaozhaorun\\AppData\\Local\\Temp\\ipykernel_26384\\2929956284.py:125: DeprecationWarning: `pl.count()` is deprecated. Please use `pl.len()` instead.\n",
"(Deprecated in version 0.20.5)\n",
" daily_counts = df_ranked.group_by(date_col).agg(pl.count().alias(\"count\"))\n"
]
}
],
"execution_count": 5
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### 4.1 Stock pool filtering"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-13T18:10:34.925345Z",
"start_time": "2026-03-13T18:10:11.853318Z"
}
},
"cell_type": "code",
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"股票池筛选\")\n",
"print(\"=\" * 80)\n",
"\n",
"# Apply the ST filter first (before stock pool filtering, consistent with Trainer.train())\n",
"if st_filter:\n",
"    print(\"\\n[过滤] 应用 ST 过滤器...\")\n",
"    data = st_filter.filter(data)\n",
"    print(f\" ST 过滤后数据规模: {data.shape}\")\n",
"\n",
"if pool_manager:\n",
"    print(\"\\n执行每日独立筛选股票池...\")\n",
"    filtered_data = pool_manager.filter_and_select_daily(data)\n",
"    print(f\" 筛选前数据规模: {data.shape}\")\n",
"    print(f\" 筛选后数据规模: {filtered_data.shape}\")\n",
"    print(f\" 筛选前股票数: {data['ts_code'].n_unique()}\")\n",
"    print(f\" 筛选后股票数: {filtered_data['ts_code'].n_unique()}\")\n",
"    print(f\" 删除记录数: {len(data) - len(filtered_data)}\")\n",
"else:\n",
"    filtered_data = data\n",
"    print(\" 未配置股票池管理器,跳过筛选\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"================================================================================\n",
"股票池筛选\n",
"================================================================================\n",
"\n",
"[过滤] 应用 ST 过滤器...\n",
" ST 过滤后数据规模: (6823808, 72)\n",
"\n",
"执行每日独立筛选股票池...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\liaozhaorun\\AppData\\Local\\Temp\\ipykernel_26384\\4061767669.py:57: DeprecationWarning: `is_in` with a collection of the same datatype is ambiguous and deprecated.\n",
"Please use `implode` to return to previous behavior.\n",
"\n",
"See https://github.com/pola-rs/polars/issues/22149 for more information.\n",
" return df[\"ts_code\"].is_in(small_cap_codes)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" 筛选前数据规模: (6823808, 72)\n",
" 筛选后数据规模: (1455000, 72)\n",
" 筛选前股票数: 5678\n",
" 筛选后股票数: 1934\n",
" 删除记录数: 5368808\n"
]
}
],
"execution_count": 6
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### 4.2 Data splitting"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-13T18:10:34.997214Z",
"start_time": "2026-03-13T18:10:34.932346Z"
}
},
"cell_type": "code",
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"数据划分\")\n",
"print(\"=\" * 80)\n",
"\n",
"if splitter:\n",
"    train_data, val_data, test_data = splitter.split(filtered_data)\n",
"    print(f\"\\n训练集数据规模: {train_data.shape}\")\n",
"    print(f\"验证集数据规模: {val_data.shape}\")\n",
"    print(f\"测试集数据规模: {test_data.shape}\")\n",
"\n",
"    # Compute the group array of each split\n",
"    train_group = compute_group_array(train_data)\n",
"    val_group = compute_group_array(val_data)\n",
"    test_group = compute_group_array(test_data)\n",
"\n",
"    print(f\"\\n训练集 group 数量: {len(train_group)}\")\n",
"    print(f\"验证集 group 数量: {len(val_group)}\")\n",
"    print(f\"测试集 group 数量: {len(test_group)}\")\n",
"    print(f\"训练集日均样本数: {np.mean(train_group):.1f}\")\n",
"    print(f\"验证集日均样本数: {np.mean(val_group):.1f}\")\n",
"    print(f\"测试集日均样本数: {np.mean(test_group):.1f}\")\n",
"else:\n",
"    raise ValueError(\"必须配置数据划分器\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"================================================================================\n",
"数据划分\n",
"================================================================================\n",
"\n",
"训练集数据规模: (970000, 72)\n",
"验证集数据规模: (242000, 72)\n",
"测试集数据规模: (243000, 72)\n",
"\n",
"训练集 group 数量: 970\n",
"验证集 group 数量: 242\n",
"测试集 group 数量: 243\n",
"训练集日均样本数: 1000.0\n",
"验证集日均样本数: 1000.0\n",
"测试集日均样本数: 1000.0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\liaozhaorun\\AppData\\Local\\Temp\\ipykernel_26384\\2929956284.py:149: DeprecationWarning: `pl.count()` is deprecated. Please use `pl.len()` instead.\n",
"(Deprecated in version 0.20.5)\n",
" pl.count().alias(\"count\")\n"
]
}
],
"execution_count": 7
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### 4.3 Data quality check"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-13T18:10:37.417614Z",
"start_time": "2026-03-13T18:10:35.002702Z"
}
},
"cell_type": "code",
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"数据质量检查(必须在预处理之前)\")\n",
"print(\"=\" * 80)\n",
"\n",
"print(\"\\n检查训练集...\")\n",
"check_data_quality(train_data, feature_cols, raise_on_error=False)\n",
"\n",
"print(\"\\n检查验证集...\")\n",
"check_data_quality(val_data, feature_cols, raise_on_error=True)\n",
"\n",
"print(\"\\n检查测试集...\")\n",
"check_data_quality(test_data, feature_cols, raise_on_error=True)\n",
"\n",
"print(\"[成功] 数据质量检查通过,未发现异常\")\n"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"================================================================================\n",
"数据质量检查(必须在预处理之前)\n",
"================================================================================\n",
"\n",
"检查训练集...\n",
"\n",
"================================================================================\n",
"数据质量检查报告\n",
"================================================================================\n",
"\n",
2026-03-14 02:12:20 +08:00
"[严重] 发现 1657 个全空因子:\n",
2026-03-13 22:24:12 +08:00
" (某天的某个因子所有值都是 null, 可能是数据缺失或计算错误)\n",
2026-03-14 02:12:20 +08:00
" - 日期 20200217: drawdown_from_high_60 (样本数: 1000)\n",
" - 日期 20200217: volatility_squeeze_5_60 (样本数: 1000)\n",
" - 日期 20200217: net_profit_yoy (样本数: 1000)\n",
" - 日期 20200217: revenue_yoy (样本数: 1000)\n",
" - 日期 20200217: healthy_expansion_velocity (样本数: 1000)\n",
" - 日期 20200217: pe_expansion_trend (样本数: 1000)\n",
" - 日期 20200217: value_price_divergence (样本数: 1000)\n",
" - 日期 20200213: drawdown_from_high_60 (样本数: 1000)\n",
" - 日期 20200213: volatility_squeeze_5_60 (样本数: 1000)\n",
" - 日期 20200213: net_profit_yoy (样本数: 1000)\n",
" ... 还有 1647 个\n",
2026-03-13 22:24:12 +08:00
"\n",
"--------------------------------------------------------------------------------\n",
"建议处理方式:\n",
" 1. 检查因子定义和数据源,确认计算逻辑是否正确\n",
" 2. 如果是预期内的缺失(如新股无历史数据),考虑调整因子计算窗口\n",
" 3. 如果是数据同步问题,重新同步相关数据\n",
" 4. 可以使用 filter 排除问题日期或因子\n",
2026-03-14 00:19:03 +08:00
"================================================================================\n",
"\n",
"检查验证集...\n",
"\n",
"检查测试集...\n",
"[成功] 数据质量检查通过,未发现异常\n"
2026-03-13 22:24:12 +08:00
]
}
],
"execution_count": 8
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### 4.4 Data preprocessing"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-13T18:10:37.997346Z",
"start_time": "2026-03-13T18:10:37.438005Z"
}
},
"cell_type": "code",
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"Data preprocessing\")\n",
"print(\"=\" * 80)\n",
"\n",
"# Fit each processor on the training set only, then reuse the fitted\n",
"# processors on the validation and test sets so no statistics leak from them\n",
"fitted_processors = []\n",
"if processors:\n",
"    print(\"\\nProcessing training set...\")\n",
"    for i, processor in enumerate(processors, 1):\n",
"        print(f\" [{i}/{len(processors)}] {processor.__class__.__name__}\")\n",
"        train_data = processor.fit_transform(train_data)\n",
"        fitted_processors.append(processor)\n",
"\n",
"    print(\"\\nProcessing validation set...\")\n",
"    for processor in fitted_processors:\n",
"        val_data = processor.transform(val_data)\n",
"\n",
"    print(\"\\nProcessing test set...\")\n",
"    for processor in fitted_processors:\n",
"        test_data = processor.transform(test_data)\n",
"\n",
"print(f\"\\nTraining set shape after processing: {train_data.shape}\")\n",
"print(f\"Validation set shape after processing: {val_data.shape}\")\n",
"print(f\"Test set shape after processing: {test_data.shape}\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"================================================================================\n",
"Data preprocessing\n",
"================================================================================\n",
"\n",
"Processing training set...\n",
" [1/3] NullFiller\n",
" [2/3] Winsorizer\n",
" [3/3] StandardScaler\n",
"\n",
"Processing validation set...\n",
"\n",
"Processing test set...\n",
"\n",
"Training set shape after processing: (970000, 72)\n",
"Validation set shape after processing: (242000, 72)\n",
"Test set shape after processing: (243000, 72)\n"
]
}
],
"execution_count": 9
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### 4.4 Training the LambdaRank model"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-13T18:10:39.279997Z",
"start_time": "2026-03-13T18:10:38.028508Z"
}
},
"cell_type": "code",
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"Training the LambdaRank model\")\n",
"print(\"=\" * 80)\n",
"\n",
"# Prepare features and labels\n",
"X_train = train_data.select(feature_cols)\n",
"y_train = train_data.select(target_col).to_series()\n",
"\n",
"X_val = val_data.select(feature_cols)\n",
"y_val = val_data.select(target_col).to_series()\n",
"\n",
"print(f\"\\nTraining samples: {len(X_train)}\")\n",
"print(f\"Validation samples: {len(X_val)}\")\n",
"print(f\"Number of features: {len(feature_cols)}\")\n",
"print(f\"Target variable: {target_col}\")\n",
"\n",
"print(\"\\nTarget statistics (training set):\")\n",
"print(y_train.describe())\n",
"\n",
"print(\"\\nStarting training...\")\n",
"model.fit(\n",
"    X=X_train,\n",
"    y=y_train,\n",
"    group=train_group,\n",
"    eval_set=(X_val, y_val, val_group),\n",
")\n",
"print(\"Training complete!\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"================================================================================\n",
"Training the LambdaRank model\n",
"================================================================================\n",
"\n",
"Training samples: 970000\n",
"Validation samples: 242000\n",
"Number of features: 50\n",
"Target variable: future_return_5_rank_rank\n",
"\n",
"Target statistics (training set):\n",
"shape: (9, 2)\n",
"┌────────────┬──────────┐\n",
"│ statistic  ┆ value    │\n",
"│ ---        ┆ ---      │\n",
"│ str        ┆ f64      │\n",
"╞════════════╪══════════╡\n",
"│ count      ┆ 969665.0 │\n",
"│ null_count ┆ 335.0    │\n",
"│ mean       ┆ 9.810091 │\n",
"│ std        ┆ 5.346526 │\n",
"│ min        ┆ 0.0      │\n",
"│ 25%        ┆ 6.0      │\n",
"│ 50%        ┆ 10.0     │\n",
"│ 75%        ┆ 14.0     │\n",
"│ max        ┆ 19.0     │\n",
"└────────────┴──────────┘\n",
"\n",
"Starting training...\n",
"Training until validation scores don't improve for 50 rounds\n",
"Early stopping, best iteration is:\n",
"[5]\ttrain's ndcg@10: 0.59368\tval's ndcg@10: 0.546167\n",
"Training complete!\n"
]
}
],
"execution_count": 10
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### 4.5 Training metric curves"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-13T18:10:39.394848Z",
"start_time": "2026-03-13T18:10:39.285414Z"
}
},
"cell_type": "code",
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"Training metric curves\")\n",
"print(\"=\" * 80)\n",
"\n",
"# Fetch the per-round evaluation results recorded during training\n",
"evals_result = model.get_evals_result()\n",
"\n",
"if not evals_result:\n",
"    print(\"[Warning] No training metrics available; make sure eval_set was passed to fit()\")\n",
"else:\n",
"    print(\"[OK] Loaded evaluation results from the model\")\n",
"\n",
"    # Collect the NDCG metrics that were evaluated\n",
"    ndcg_metrics = [k for k in evals_result[\"train\"].keys() if \"ndcg\" in k]\n",
"    print(f\"\\nEvaluated NDCG metrics: {ndcg_metrics}\")\n",
"\n",
"    # Early-stopping summary\n",
"    actual_rounds = len(list(evals_result[\"train\"].values())[0])\n",
"    expected_rounds = MODEL_PARAMS.get(\"n_estimators\", 1000)\n",
"    print(\"\\n[Early stopping]\")\n",
"    print(f\" Configured maximum rounds: {expected_rounds}\")\n",
"    print(f\" Rounds actually trained: {actual_rounds}\")\n",
"\n",
"    best_iter = model.get_best_iteration()\n",
"    if best_iter is not None and best_iter < actual_rounds:\n",
"        print(f\" Early stopping: triggered (best iteration: {best_iter})\")\n",
"    else:\n",
"        print(\" Early stopping: not triggered (reached the maximum rounds)\")\n",
"\n",
"    # NDCG at the last trained round (not at the best iteration)\n",
"    print(\"\\nNDCG at the last trained round:\")\n",
"    for metric in ndcg_metrics:\n",
"        train_ndcg = evals_result[\"train\"][metric][-1]\n",
"        val_ndcg = evals_result[\"val\"][metric][-1]\n",
"        print(f\" {metric}: train={train_ndcg:.4f}, val={val_ndcg:.4f}\")\n",
"\n",
"    # Plot all metrics with the model's wrapper method\n",
"    print(\"\\n[Plot] Drawing training curves via the LightGBM native interface...\")\n",
"    fig = model.plot_all_metrics(metrics=ndcg_metrics[:4], figsize=(14, 10))\n",
"    plt.show()\n",
"\n",
"    print(\"\\n[Metric analysis]\")\n",
"    print(\" Best validation value for each NDCG metric:\")\n",
"    for metric in ndcg_metrics:\n",
"        val_metric_list = evals_result[\"val\"][metric]\n",
"        best_val = max(val_metric_list)\n",
"        best_iter_metric = val_metric_list.index(best_val)\n",
"        print(f\" {metric}: {best_val:.4f} (iteration {best_iter_metric + 1})\")\n",
"    print(\"\\n[Reminder] The validation set is used only for early stopping and tuning; the test set is fully independent of training!\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"================================================================================\n",
"Training metric curves\n",
"================================================================================\n",
"[OK] Loaded evaluation results from the model\n",
"\n",
"Evaluated NDCG metrics: ['ndcg@10']\n",
"\n",
"[Early stopping]\n",
" Configured maximum rounds: 2000\n",
" Rounds actually trained: 55\n",
" Early stopping: triggered (best iteration: 5)\n",
"\n",
"NDCG at the last trained round:\n",
" ndcg@10: train=0.6331, val=0.5393\n",
"\n",
"[Plot] Drawing training curves via the LightGBM native interface...\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 1400x1000 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"[Metric analysis]\n",
" Best validation value for each NDCG metric:\n",
" ndcg@10: 0.5462 (iteration 5)\n",
"\n",
"[Reminder] The validation set is used only for early stopping and tuning; the test set is fully independent of training!\n"
]
}
],
"execution_count": 11
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### 4.6 Model evaluation"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-13T18:10:39.655028Z",
"start_time": "2026-03-13T18:10:39.403487Z"
}
},
"cell_type": "code",
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"Model evaluation\")\n",
"print(\"=\" * 80)\n",
"\n",
"# Prepare the test set\n",
"X_test = test_data.select(feature_cols)\n",
"y_test = test_data.select(target_col).to_series()\n",
"\n",
"# Predict ranking scores\n",
"print(\"\\nGenerating predictions...\")\n",
"predictions = model.predict(X_test)\n",
"\n",
"# Attach the predictions as a new column\n",
"test_data = test_data.with_columns([pl.Series(\"prediction\", predictions)])\n",
"\n",
"# Compute NDCG metrics\n",
"print(\"\\nComputing NDCG metrics...\")\n",
"ndcg_results = evaluate_ndcg_at_k(\n",
"    y_true=y_test.to_numpy(),\n",
"    y_pred=predictions,\n",
"    group=test_group,\n",
"    k_list=[1, 5, 10, 20],\n",
")\n",
"\n",
"print(\"\\nNDCG evaluation results:\")\n",
"print(\"-\" * 40)\n",
"for metric, value in ndcg_results.items():\n",
"    print(f\" {metric}: {value:.4f}\")\n",
"\n",
"# Feature importance\n",
"print(\"\\nFeature importance (Top 20):\")\n",
"print(\"-\" * 40)\n",
"importance = model.feature_importance()\n",
"if importance is not None:\n",
"    top_features = importance.sort_values(ascending=False).head(20)\n",
"    for i, (feature, score) in enumerate(top_features.items(), 1):\n",
"        print(f\" {i:2d}. {feature:30s} {score:10.2f}\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"================================================================================\n",
"Model evaluation\n",
"================================================================================\n",
"\n",
"Generating predictions...\n",
"\n",
"Computing NDCG metrics...\n",
"\n",
"NDCG evaluation results:\n",
"----------------------------------------\n",
" ndcg@1: 0.4710\n",
" ndcg@5: 0.4927\n",
" ndcg@10: 0.5003\n",
" ndcg@20: 0.5075\n",
"\n",
"Feature importance (Top 20):\n",
"----------------------------------------\n",
"  1. roe                                345.77\n",
"  2. ma_20                              244.73\n",
"  3. profit_margin                      223.10\n",
"  4. revenue_yoy                        216.85\n",
"  5. active_market_cap                  214.10\n",
"  6. std_return_20                      149.44\n",
"  7. close_vwap_deviation               136.42\n",
"  8. max_ret_20                         110.61\n",
"  9. debt_to_equity                     105.78\n",
" 10. bbi_ratio                          104.63\n",
" 11. market_cap_rank                    104.10\n",
" 12. up_days_ratio_20                    99.44\n",
" 13. turnover_rank                       80.60\n",
" 14. min_ret_20                          72.72\n",
" 15. ebit_rank                           64.03\n",
" 16. drawdown_from_high_60               61.28\n",
" 17. turnover_rate_mean_5                57.16\n",
" 18. healthy_expansion_velocity          54.10\n",
" 19. BP                                  51.36\n",
" 20. EP                                  51.29\n"
]
}
],
"execution_count": 12
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-13T18:10:39.872615Z",
"start_time": "2026-03-13T18:10:39.663834Z"
}
},
"cell_type": "code",
"source": [
"# Make sure the output directory exists\n",
"os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
"\n",
"# Build the test date-range string (parsing also validates TEST_START/TEST_END;\n",
"# date_str is not used below because the output file name is fixed)\n",
"start_dt = datetime.strptime(TEST_START, \"%Y%m%d\")\n",
"end_dt = datetime.strptime(TEST_END, \"%Y%m%d\")\n",
"date_str = f\"{start_dt.strftime('%Y%m%d')}_{end_dt.strftime('%Y%m%d')}\"\n",
"\n",
"# Save the daily Top N\n",
"print(f\"\\n[1/1] Saving daily Top {TOP_N} stocks...\")\n",
"topn_output_path = os.path.join(OUTPUT_DIR, \"rank_output.csv\")\n",
"\n",
"# Group by date and take the top N of each day\n",
"topn_by_date = []\n",
"unique_dates = test_data[\"trade_date\"].unique().sort()\n",
"for date in unique_dates:\n",
"    day_data = test_data.filter(test_data[\"trade_date\"] == date)\n",
"    # Sort by prediction in descending order and keep the first N\n",
"    topn = day_data.sort(\"prediction\", descending=True).head(TOP_N)\n",
"    topn_by_date.append(topn)\n",
"\n",
"# Concatenate the top N of all dates\n",
"topn_results = pl.concat(topn_by_date)\n",
"\n",
"# Format the date and order the columns: date, score, stock code.\n",
"# The parentheses matter: without them .alias(\"date\") binds only to the last\n",
"# slice and the column keeps the name trade_date.\n",
"topn_to_save = topn_results.select(\n",
"    [\n",
"        (\n",
"            pl.col(\"trade_date\").str.slice(0, 4)\n",
"            + \"-\"\n",
"            + pl.col(\"trade_date\").str.slice(4, 2)\n",
"            + \"-\"\n",
"            + pl.col(\"trade_date\").str.slice(6, 2)\n",
"        ).alias(\"date\"),\n",
"        pl.col(\"prediction\").alias(\"score\"),\n",
"        pl.col(\"ts_code\"),\n",
"    ]\n",
")\n",
"topn_to_save.write_csv(topn_output_path, include_header=True)\n",
"print(f\" Saved to: {topn_output_path}\")\n",
"print(\n",
"    f\" Rows saved: {len(topn_to_save)} ({len(unique_dates)} trading days x top {TOP_N} per day)\"\n",
")\n",
"print(\"\\n Preview (first 15 rows):\")\n",
"print(topn_to_save.head(15))\n",
"\n",
"print(\"\\nTraining pipeline complete!\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"[1/1] Saving daily Top 5 stocks...\n",
" Saved to: output\\rank_output.csv\n",
" Rows saved: 1215 (243 trading days x top 5 per day)\n",
"\n",
" Preview (first 15 rows):\n",
"shape: (15, 3)\n",
"┌────────────┬──────────┬───────────┐\n",
"│ date       ┆ score    ┆ ts_code   │\n",
"│ ---        ┆ ---      ┆ ---       │\n",
"│ str        ┆ f64      ┆ str       │\n",
"╞════════════╪══════════╪═══════════╡\n",
"│ 2025-01-02 ┆ 0.022198 ┆ 002480.SZ │\n",
"│ 2025-01-02 ┆ 0.018886 ┆ 000826.SZ │\n",
"│ 2025-01-02 ┆ 0.018845 ┆ 600683.SH │\n",
"│ 2025-01-02 ┆ 0.016044 ┆ 000632.SZ │\n",
"│ 2025-01-02 ┆ 0.014577 ┆ 002072.SZ │\n",
"│ …          ┆ …        ┆ …         │\n",
"│ 2025-01-06 ┆ 0.03183  ┆ 600683.SH │\n",
"│ 2025-01-06 ┆ 0.023015 ┆ 000701.SZ │\n",
"│ 2025-01-06 ┆ 0.018886 ┆ 000826.SZ │\n",
"│ 2025-01-06 ┆ 0.016044 ┆ 000632.SZ │\n",
"│ 2025-01-06 ┆ 0.014577 ┆ 002072.SZ │\n",
"└────────────┴──────────┴───────────┘\n",
"\n",
"Training pipeline complete!\n"
]
}
],
"execution_count": 13
},
{
"metadata": {},
"cell_type": "markdown",
"source": [
"## 5. Summary\n",
"\n",
"This notebook implements a complete Learn-to-Rank training pipeline:\n",
"\n",
"### Core steps\n",
"\n",
"1. **Data preparation**: compute 50 feature factors and convert `future_return_5` into 20-quantile labels\n",
"2. **Model training**: use LightGBM LambdaRank to learn the daily stock ranking\n",
"3. **Model evaluation**: assess ranking quality with NDCG@1/5/10/20\n",
"4. **Strategy analysis**: build a Top-k stock-selection strategy from the ranking scores\n",
"\n",
"### Key parameters\n",
"\n",
"- **Objective**: lambdarank\n",
"- **Metric**: ndcg\n",
"- **Learning Rate**: 0.05\n",
"- **Num Leaves**: 31\n",
"- **N Quantiles**: 20\n",
"\n",
"### Outputs\n",
"\n",
"- rank_output.csv: daily Top-N recommended stocks (columns: date, score, ts_code)\n",
"- Feature importance ranking\n",
"- Top-k strategy statistics and charts\n",
"- NDCG training metric curves\n",
"\n",
"### Directions for further optimization\n",
"\n",
"1. **Feature engineering**: try more factor combinations\n",
"2. **Hyperparameter tuning**: use grid search to optimize the LambdaRank parameters\n",
"3. **Model ensembling**: combine the predictions of multiple ranking models\n",
"4. **More sophisticated grouping**: consider ranking within industry groups\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}