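- Merge add_factor_by_name into add_factor, supporting three calling conventions
- Make the FactorManager constructor arguments optional, falling back to a default path
- FactorEngine enables metadata by default; no manual path configuration is needed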
1034 lines
44 KiB
Plaintext
{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
"source": [
"# Learn-to-Rank Training Pipeline\n",
"\n",
"This notebook trains a learning-to-rank model with LightGBM LambdaRank for cross-sectional stock ranking.\n",
"\n",
"## Key Features\n",
"\n",
"1. **Label transformation**: the 5-day forward return (`future_return_5`) is split into 20 quantile buckets within each trading day (qcut-style)\n",
"2. **Learning to rank**: the LambdaRank objective learns the daily ordering of stocks, with each trading day treated as one query\n",
"3. **NDCG evaluation**: ranking quality is measured with NDCG@1/5/10/20\n",
"4. **Strategy backtest**: a Top-k stock-selection strategy is built from the ranking scores"
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## 1. Imports"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-11T16:02:49.975545Z",
"start_time": "2026-03-11T16:02:48.487347Z"
}
},
"cell_type": "code",
"source": [
"import os\n",
"from datetime import datetime\n",
"from typing import List, Tuple, Optional\n",
"\n",
"import numpy as np\n",
"import polars as pl\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.metrics import ndcg_score\n",
"\n",
"from src.factors import FactorEngine\n",
"from src.training import (\n",
"    DateSplitter,\n",
"    STFilter,\n",
"    StockPoolManager,\n",
"    Trainer,\n",
"    Winsorizer,\n",
"    NullFiller,\n",
"    StandardScaler,\n",
")\n",
"from src.training.components.models import LightGBMLambdaRankModel\n",
"from src.training.config import TrainingConfig\n",
"\n"
],
"outputs": [],
"execution_count": 1
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## 2. Helper Functions"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-11T16:02:49.989220Z",
"start_time": "2026-03-11T16:02:49.981542Z"
}
},
"cell_type": "code",
"source": [
"def create_factors_with_metadata(\n",
"    engine: FactorEngine,\n",
"    selected_factors: List[str],\n",
"    factor_definitions: dict,\n",
"    label_factor: dict,\n",
") -> List[str]:\n",
"    \"\"\"Register factors: SELECTED_FACTORS are looked up by name in the metadata store; FACTOR_DEFINITIONS and the label are registered from expressions.\"\"\"\n",
"    print(\"=\" * 80)\n",
"    print(\"Registering factors\")\n",
"    print(\"=\" * 80)\n",
"\n",
"    # Register the factors in SELECTED_FACTORS (already stored in metadata)\n",
"    print(\"\\nRegistering feature factors (from metadata):\")\n",
"    for name in selected_factors:\n",
"        engine.add_factor(name)\n",
"        print(f\"  - {name}\")\n",
"\n",
"    # Register the factors in FACTOR_DEFINITIONS (from expressions; not yet in metadata)\n",
"    print(\"\\nRegistering feature factors (from expressions):\")\n",
"    for name, expr in factor_definitions.items():\n",
"        engine.add_factor(name, expr)\n",
"        print(f\"  - {name}: {expr}\")\n",
"\n",
"    # Register the label factor (from an expression)\n",
"    print(\"\\nRegistering label factor (from expression):\")\n",
"    for name, expr in label_factor.items():\n",
"        engine.add_factor(name, expr)\n",
"        print(f\"  - {name}: {expr}\")\n",
"\n",
"    # Feature columns = SELECTED_FACTORS + the keys of FACTOR_DEFINITIONS\n",
"    feature_cols = selected_factors + list(factor_definitions.keys())\n",
"\n",
"    print(f\"\\nNumber of feature factors: {len(feature_cols)}\")\n",
"    print(f\"  - from metadata: {len(selected_factors)}\")\n",
"    print(f\"  - from expressions: {len(factor_definitions)}\")\n",
"    print(f\"Label: {list(label_factor.keys())[0]}\")\n",
"    print(f\"Total registered factors: {len(engine.list_registered())}\")\n",
"\n",
"    return feature_cols\n",
"\n",
"\n",
"def prepare_data(\n",
"    engine: FactorEngine,\n",
"    feature_cols: List[str],\n",
"    start_date: str,\n",
"    end_date: str,\n",
") -> pl.DataFrame:\n",
"    \"\"\"Compute the feature factors and the label over [start_date, end_date].\"\"\"\n",
"    print(\"\\n\" + \"=\" * 80)\n",
"    print(\"Preparing data\")\n",
"    print(\"=\" * 80)\n",
"\n",
"    # Compute factors (full-market data)\n",
"    print(f\"\\nComputing factors: {start_date} - {end_date}\")\n",
"    factor_names = feature_cols + [LABEL_NAME]  # include the label (LABEL_NAME is the global defined in section 3.1)\n",
"\n",
"    data = engine.compute(\n",
"        factor_names=factor_names,\n",
"        start_date=start_date,\n",
"        end_date=end_date,\n",
"    )\n",
"\n",
"    print(f\"Data shape: {data.shape}\")\n",
"    print(f\"Columns: {data.columns}\")\n",
"    print(\"\\nFirst 5 rows:\")\n",
"    print(data.head())\n",
"\n",
"    return data\n",
"\n",
"\n",
"def prepare_ranking_data(\n",
"    df: pl.DataFrame,\n",
"    label_col: str = \"future_return_5\",\n",
"    date_col: str = \"trade_date\",\n",
"    n_quantiles: int = 20,\n",
") -> Tuple[pl.DataFrame, str]:\n",
"    \"\"\"Prepare learning-to-rank data.\n",
"\n",
"    Converts the continuous label into per-day quantile labels for the ranking task.\n",
"\n",
"    Args:\n",
"        df: raw data\n",
"        label_col: name of the raw label column\n",
"        date_col: name of the date column\n",
"        n_quantiles: number of quantile buckets\n",
"\n",
"    Returns:\n",
"        (processed DataFrame, name of the new label column)\n",
"    \"\"\"\n",
"    print(\"\\n\" + \"=\" * 80)\n",
"    print(f\"Preparing ranking data (converting {label_col} into {n_quantiles} quantile labels)\")\n",
"    print(\"=\" * 80)\n",
"\n",
"    # Name of the new label column\n",
"    rank_col = f\"{label_col}_rank\"\n",
"\n",
"    # Quantile buckets are assigned separately for each date:\n",
"    # rank the label within the day, then map the rank onto 0, 1, ..., n_quantiles-1\n",
"    df_ranked = (\n",
"        df.with_columns(\n",
"            # Rank within each day (1-based; ties share the lowest rank; null labels stay null)\n",
"            pl.col(label_col).rank(method=\"min\").over(date_col).alias(\"_rank\")\n",
"        )\n",
"        .with_columns(\n",
"            # Map the rank to a quantile label (0 to n_quantiles-1). Dividing by the number of\n",
"            # non-null labels that day keeps rows with a missing label from shrinking the buckets.\n",
"            ((pl.col(\"_rank\") - 1) / pl.col(label_col).count().over(date_col) * n_quantiles)\n",
"            .floor()\n",
"            .cast(pl.Int64)\n",
"            .clip(0, n_quantiles - 1)\n",
"            .alias(rank_col)\n",
"        )\n",
"        .drop(\"_rank\")\n",
"    )\n",
"\n",
"    # Inspect the result of the conversion\n",
"    print(f\"\\nStatistics of the raw {label_col}:\")\n",
"    print(df_ranked[label_col].describe())\n",
"\n",
"    print(f\"\\nStatistics of the converted {rank_col}:\")\n",
"    print(df_ranked[rank_col].describe())\n",
"\n",
"    # Samples per day\n",
"    print(\"\\nSamples per day:\")\n",
"    daily_counts = df_ranked.group_by(date_col).agg(pl.len().alias(\"count\"))\n",
"    print(daily_counts[\"count\"].describe())\n",
"\n",
"    # Distribution of the quantile labels (should be roughly uniform)\n",
"    print(\"\\nQuantile label distribution:\")\n",
"    rank_dist = df_ranked[rank_col].value_counts().sort(rank_col)\n",
"    print(rank_dist)\n",
"\n",
"    return df_ranked, rank_col\n",
"\n",
"\n",
"def compute_group_array(df: pl.DataFrame, date_col: str = \"trade_date\") -> np.ndarray:\n",
"    \"\"\"Compute the group array for LambdaRank.\n",
"\n",
"    Each date is one query; the group array holds the number of samples in each query.\n",
"    LightGBM reads the group sizes as consecutive blocks of rows, so df must already be\n",
"    sorted by date_col (all rows of the same day contiguous).\n",
"\n",
"    Args:\n",
"        df: input DataFrame, sorted by date_col\n",
"        date_col: name of the date column\n",
"\n",
"    Returns:\n",
"        group array (samples per date, in row order)\n",
"    \"\"\"\n",
"    if not df[date_col].is_sorted():\n",
"        raise ValueError(f\"df must be sorted by {date_col} before computing the group array\")\n",
"    group_counts = df.group_by(date_col, maintain_order=True).agg(\n",
"        pl.len().alias(\"count\")\n",
"    )\n",
"    return group_counts[\"count\"].to_numpy()\n",
"\n",
"\n",
"def evaluate_ndcg_at_k(\n",
"    y_true: np.ndarray,\n",
"    y_pred: np.ndarray,\n",
"    group: np.ndarray,\n",
"    k_list: Tuple[int, ...] = (1, 5, 10, 20),\n",
") -> dict:\n",
"    \"\"\"Compute the mean NDCG@k over queries (dates).\n",
"\n",
"    Args:\n",
"        y_true: true labels\n",
"        y_pred: predicted scores\n",
"        group: group array (samples per query, in row order)\n",
"        k_list: cutoffs k to evaluate\n",
"\n",
"    Returns:\n",
"        dict mapping \"ndcg@k\" to the mean NDCG@k\n",
"    \"\"\"\n",
"    results = {}\n",
"\n",
"    # Split the flat arrays into one slice per group (query)\n",
"    start_idx = 0\n",
"    y_true_groups = []\n",
"    y_pred_groups = []\n",
"\n",
"    for group_size in group:\n",
"        end_idx = start_idx + group_size\n",
"        y_true_groups.append(y_true[start_idx:end_idx])\n",
"        y_pred_groups.append(y_pred[start_idx:end_idx])\n",
"        start_idx = end_idx\n",
"\n",
"    # Average NDCG over groups for each k\n",
"    for k in k_list:\n",
"        ndcg_scores = []\n",
"        for yt, yp in zip(y_true_groups, y_pred_groups):\n",
"            if len(yt) > 1:  # ndcg_score needs more than one document per query\n",
"                try:\n",
"                    score = ndcg_score([yt], [yp], k=k)\n",
"                    ndcg_scores.append(score)\n",
"                except ValueError:\n",
"                    # ndcg_score rejected this group (e.g. NaN or negative labels); skip it\n",
"                    pass\n",
"\n",
"        results[f\"ndcg@{k}\"] = np.mean(ndcg_scores) if ndcg_scores else 0.0\n",
"\n",
"    return results\n",
"\n"
],
"outputs": [],
"execution_count": 2
},
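{
"metadata": {},
"cell_type": "markdown",
"source": [
"The self-contained cell below is a small sketch of the two steps implemented by `prepare_ranking_data` and `compute_group_array`. It uses a made-up two-day sample; the column names `trade_date`, `ts_code` and `ret` and all values are illustrative. The label is ranked within each day and mapped to quantile buckets, and the group array counts the rows per day. LightGBM reads the group sizes as consecutive blocks of rows, so the sample is sorted by date before the counts are taken."
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"import polars as pl\n",
"\n",
"# Hypothetical toy sample: 2 days x 4 stocks, rows deliberately not sorted by date\n",
"toy = pl.DataFrame({\n",
"    'trade_date': ['20240102', '20240103'] * 4,\n",
"    'ts_code': ['A', 'A', 'B', 'B', 'C', 'C', 'D', 'D'],\n",
"    'ret': [0.05, -0.01, -0.02, 0.03, 0.01, 0.00, 0.08, 0.02],\n",
"})\n",
"n_q = 4\n",
"\n",
"# Same mapping as prepare_ranking_data: 1-based rank within the day -> bucket 0..n_q-1\n",
"toy = toy.with_columns(\n",
"    ((pl.col('ret').rank(method='min').over('trade_date') - 1)\n",
"     / pl.col('ret').count().over('trade_date') * n_q)\n",
"    .floor().cast(pl.Int64).clip(0, n_q - 1).alias('ret_rank')\n",
")\n",
"\n",
"# Sort so each day's rows are contiguous, then count rows per day in row order\n",
"toy = toy.sort(['trade_date', 'ts_code'])\n",
"toy_group = toy.group_by('trade_date', maintain_order=True).agg(pl.len())['len'].to_numpy()\n",
"\n",
"print(toy)\n",
"print(toy_group)"
]
},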
{
"metadata": {},
"cell_type": "markdown",
"source": [
"## 3. Configuration\n",
"\n",
"### 3.1 Factor Definitions"
]
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-11T16:02:50.000875Z",
"start_time": "2026-03-11T16:02:49.994082Z"
}
},
"cell_type": "code",
"source": [
"# Factor definitions (reused from regression.ipynb)\n",
"LABEL_NAME = \"future_return_5_rank\"\n",
"\n",
"# Currently selected factors (registered by name from the factor metadata store)\n",
"SELECTED_FACTORS = [\n",
"    # ================= 1. Price, trend and path dependence =================\n",
"    \"ma_5\",\n",
"    \"ma_20\",\n",
"    \"ma_ratio_5_20\",\n",
"    \"bias_10\",\n",
"    \"high_low_ratio\",\n",
"    \"bbi_ratio\",\n",
"    \"return_5\",\n",
"    \"return_20\",\n",
"    \"kaufman_ER_20\",\n",
"    \"mom_acceleration_10_20\",\n",
"    \"drawdown_from_high_60\",\n",
"    \"up_days_ratio_20\",\n",
"    # ================= 2. Volatility, risk adjustment and higher moments =================\n",
"    \"volatility_5\",\n",
"    \"volatility_20\",\n",
"    \"volatility_ratio\",\n",
"    \"std_return_20\",\n",
"    \"sharpe_ratio_20\",\n",
"    \"min_ret_20\",\n",
"    \"volatility_squeeze_5_60\",\n",
"    # ================= 3. Intraday microstructure and anomalies =================\n",
"    \"overnight_intraday_diff\",\n",
"    \"upper_shadow_ratio\",\n",
"    \"capital_retention_20\",\n",
"    \"max_ret_20\",\n",
"    # ================= 4. Volume, liquidity and price-volume divergence =================\n",
"    \"volume_ratio_5_20\",\n",
"    \"turnover_rate_mean_5\",\n",
"    \"turnover_deviation\",\n",
"    \"amihud_illiq_20\",\n",
"    \"turnover_cv_20\",\n",
"    \"pv_corr_20\",\n",
"    \"close_vwap_deviation\",\n",
"    # ================= 5. Fundamental financial features =================\n",
"    \"roe\",\n",
"    \"roa\",\n",
"    \"profit_margin\",\n",
"    \"debt_to_equity\",\n",
"    \"current_ratio\",\n",
"    \"net_profit_yoy\",\n",
"    \"revenue_yoy\",\n",
"    \"healthy_expansion_velocity\",\n",
"    # ================= 6. Valuation and cross-sectional momentum resonance =================\n",
"    \"EP\",\n",
"    \"BP\",\n",
"    \"CP\",\n",
"    \"market_cap_rank\",\n",
"    \"turnover_rank\",\n",
"    \"return_5_rank\",\n",
"    \"EP_rank\",\n",
"    \"pe_expansion_trend\",\n",
"    \"value_price_divergence\",\n",
"    \"active_market_cap\",\n",
"    \"ebit_rank\",\n",
"]\n",
"\n",
"# Expression-defined factors (registered from their expressions; not yet in the metadata store)\n",
"FACTOR_DEFINITIONS = {\"turnover_volatility_ratio\": \"log(ts_std(turnover_rate, 20))\"}\n",
"\n",
"# Label factor definition (not a training feature; only used to compute the target)\n",
"LABEL_FACTOR = {\n",
"    LABEL_NAME: \"(ts_delay(close, -5) / ts_delay(open, -1)) - 1\",\n",
"}"
],
"outputs": [],
"execution_count": 3
},
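{
"metadata": {},
"cell_type": "markdown",
"source": [
"The meaning of the label expression depends on the factor DSL's convention for negative lags. This is an assumption not confirmed here, but if `ts_delay(x, -k)` returns the value of `x` k trading days ahead (like a polars `shift(-k)` within each stock), the label is the return from buying at the next day's open and selling at the close five days ahead. The self-contained cell below checks that reading on made-up prices for one stock. The last five rows get no label because their future prices are not available."
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"import polars as pl\n",
"\n",
"# Made-up daily prices for a single stock (illustrative only)\n",
"px = pl.DataFrame({\n",
"    'ts_code': ['A'] * 7,\n",
"    'trade_date': ['d0', 'd1', 'd2', 'd3', 'd4', 'd5', 'd6'],\n",
"    'open': [10.0, 10.2, 10.1, 10.4, 10.6, 10.5, 10.9],\n",
"    'close': [10.1, 10.0, 10.3, 10.5, 10.4, 11.0, 10.8],\n",
"})\n",
"\n",
"# Assumed polars equivalent of (ts_delay(close, -5) / ts_delay(open, -1)) - 1:\n",
"# a negative lag is read as a lead, computed within each stock\n",
"px = px.with_columns(\n",
"    (pl.col('close').shift(-5).over('ts_code') / pl.col('open').shift(-1).over('ts_code') - 1)\n",
"    .alias('label')\n",
")\n",
"\n",
"# d0: 11.0 / 10.2 - 1 (buy at the d1 open, sell at the d5 close)\n",
"print(px)"
]
},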
{
"metadata": {},
"cell_type": "markdown",
"source": "### 3.2 Training Configuration"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-11T16:02:50.009081Z",
"start_time": "2026-03-11T16:02:50.005330Z"
}
},
"cell_type": "code",
"source": [
"# Date ranges (proper train/val/test three-way split)\n",
"TRAIN_START = \"20200101\"\n",
"TRAIN_END = \"20231231\"\n",
"VAL_START = \"20240101\"\n",
"VAL_END = \"20241231\"\n",
"TEST_START = \"20250101\"\n",
"TEST_END = \"20251231\"\n",
"\n",
"# LambdaRank model parameters\n",
"MODEL_PARAMS = {\n",
"    \"objective\": \"lambdarank\",\n",
"    \"metric\": \"ndcg\",\n",
"    \"ndcg_at\": [1, 5, 10, 20],  # evaluate NDCG@k at these cutoffs\n",
"    \"learning_rate\": 0.05,\n",
"    \"num_leaves\": 31,\n",
"    \"max_depth\": 6,\n",
"    \"min_data_in_leaf\": 20,\n",
"    \"n_estimators\": 1000,\n",
"    \"early_stopping_rounds\": 50,\n",
"    \"subsample\": 0.8,  # LightGBM only applies row subsampling when subsample_freq (bagging_freq) > 0\n",
"    \"colsample_bytree\": 0.8,\n",
"    \"reg_alpha\": 0.1,\n",
"    \"reg_lambda\": 1.0,\n",
"    \"verbose\": -1,\n",
"    \"random_state\": 42,\n",
"}\n",
"\n",
"# Quantile configuration\n",
"N_QUANTILES = 20  # split the label into 20 buckets per day\n",
"\n",
"# Feature columns used by the data processors; must cover every feature passed to the model,\n",
"# including the expression-defined factors\n",
"FEATURE_COLS = SELECTED_FACTORS + list(FACTOR_DEFINITIONS.keys())\n",
"\n",
"# Data processors\n",
"PROCESSORS = [\n",
"    NullFiller(feature_cols=FEATURE_COLS, strategy=\"mean\"),\n",
"    Winsorizer(feature_cols=FEATURE_COLS, lower=0.01, upper=0.99),\n",
"    StandardScaler(feature_cols=FEATURE_COLS),\n",
"]\n",
"\n",
"\n",
"# Stock-pool filter function\n",
"def stock_pool_filter(df: pl.DataFrame) -> pl.Series:\n",
"    \"\"\"Stock-pool filter (applied to one day of data).\n",
"\n",
"    Rules:\n",
"    1. Exclude ChiNext (codes starting with 30)\n",
"    2. Exclude the STAR Market (codes starting with 68)\n",
"    3. Exclude the Beijing Stock Exchange (codes starting with 8, 9 or 4)\n",
"    4. Keep the 1000 stocks with the smallest total market value that day\n",
"    \"\"\"\n",
"    code_filter = (\n",
"        ~df[\"ts_code\"].str.starts_with(\"30\")\n",
"        & ~df[\"ts_code\"].str.starts_with(\"68\")\n",
"        & ~df[\"ts_code\"].str.starts_with(\"8\")\n",
"        & ~df[\"ts_code\"].str.starts_with(\"9\")\n",
"        & ~df[\"ts_code\"].str.starts_with(\"4\")\n",
"    )\n",
"\n",
"    # Drop rows without a market value: polars sorts nulls first, so they would otherwise count as the smallest caps\n",
"    valid_df = df.filter(code_filter & df[\"total_mv\"].is_not_null())\n",
"    n = min(1000, len(valid_df))\n",
"    small_cap_codes = valid_df.sort(\"total_mv\").head(n)[\"ts_code\"]\n",
"\n",
"    return df[\"ts_code\"].is_in(small_cap_codes)\n",
"\n",
"\n",
"STOCK_FILTER_REQUIRED_COLUMNS = [\"total_mv\"]\n",
"\n",
"# Output configuration\n",
"OUTPUT_DIR = \"output\"\n",
"SAVE_PREDICTIONS = True\n",
"PERSIST_MODEL = False\n",
"\n",
"# Top-N configuration: number of stocks recommended per day\n",
"TOP_N = 5  # can be changed to 10, 20, etc."
],
"outputs": [],
"execution_count": 4
},
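{
"metadata": {},
"cell_type": "markdown",
"source": [
"`TOP_N` sets how many stocks are picked each day from the ranking scores. The backtest code that uses it is not part of this section, so the self-contained cell below only sketches daily Top-N selection on made-up scores. The `score` column, the toy data and `TOP_N_DEMO` are illustrative assumptions; `TOP_N_DEMO` is used instead of `TOP_N` so the real setting is not overwritten. With `method='ordinal'`, ties are broken by row order, so exactly `TOP_N_DEMO` rows are kept per day."
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"import polars as pl\n",
"\n",
"TOP_N_DEMO = 2  # stand-in for TOP_N\n",
"\n",
"# Hypothetical model scores for two trading days\n",
"demo_preds = pl.DataFrame({\n",
"    'trade_date': ['20250102'] * 4 + ['20250103'] * 4,\n",
"    'ts_code': ['A', 'B', 'C', 'D'] * 2,\n",
"    'score': [0.3, 1.2, -0.5, 0.9, 0.1, -0.2, 0.8, 0.4],\n",
"})\n",
"\n",
"# Rank scores within each day (highest score = rank 1) and keep the first TOP_N_DEMO\n",
"demo_top = (\n",
"    demo_preds.with_columns(\n",
"        pl.col('score').rank(method='ordinal', descending=True).over('trade_date').alias('pick_rank')\n",
"    )\n",
"    .filter(pl.col('pick_rank') <= TOP_N_DEMO)\n",
"    .sort(['trade_date', 'pick_rank'])\n",
")\n",
"print(demo_top)"
]
},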
{
"metadata": {},
"cell_type": "markdown",
"source": "## 4. Training Pipeline"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-11T16:02:50.330018Z",
"start_time": "2026-03-11T16:02:50.012964Z"
}
},
"cell_type": "code",
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"LightGBM LambdaRank learning-to-rank training\")\n",
"print(\"=\" * 80)\n",
"\n",
"# 1. Create the FactorEngine (with metadata lookup enabled)\n",
"print(\"\\n[1] Creating FactorEngine\")\n",
"engine = FactorEngine()  # metadata is enabled by default with the default path; a relative metadata_path resolves against the notebook's working directory\n",
"\n",
"# 2. Define factors via metadata\n",
"print(\"\\n[2] Defining factors (registered from metadata)\")\n",
"feature_cols = create_factors_with_metadata(\n",
"    engine, SELECTED_FACTORS, FACTOR_DEFINITIONS, LABEL_FACTOR\n",
")\n",
"\n",
"# 3. Prepare data\n",
"print(\"\\n[3] Preparing data\")\n",
"data = prepare_data(\n",
"    engine=engine,\n",
"    feature_cols=feature_cols,\n",
"    start_date=TRAIN_START,\n",
"    end_date=TEST_END,\n",
")\n",
"\n",
"# 4. Convert to learning-to-rank format (quantile labels)\n",
"print(\"\\n[4] Converting to learning-to-rank format\")\n",
"data, target_col = prepare_ranking_data(\n",
"    df=data,\n",
"    label_col=LABEL_NAME,\n",
"    n_quantiles=N_QUANTILES,\n",
")\n",
"\n",
"# 5. Print the configuration\n",
"print(f\"\\n[Config] Training period: {TRAIN_START} - {TRAIN_END}\")\n",
"print(f\"[Config] Validation period: {VAL_START} - {VAL_END}\")\n",
"print(f\"[Config] Test period: {TEST_START} - {TEST_END}\")\n",
"print(f\"[Config] Number of features: {len(feature_cols)}\")\n",
"print(f\"[Config] Target: {target_col} ({N_QUANTILES} quantiles)\")\n",
"\n",
"# 6. Create the ranking model\n",
"model = LightGBMLambdaRankModel(params=MODEL_PARAMS)\n",
"\n",
"# 7. Data processors\n",
"processors = PROCESSORS\n",
"\n",
"# 8. Create the date splitter\n",
"splitter = DateSplitter(\n",
"    train_start=TRAIN_START,\n",
"    train_end=TRAIN_END,\n",
"    val_start=VAL_START,\n",
"    val_end=VAL_END,\n",
"    test_start=TEST_START,\n",
"    test_end=TEST_END,\n",
")\n",
"\n",
"# 9. Create the stock-pool manager\n",
"pool_manager = StockPoolManager(\n",
"    filter_func=stock_pool_filter,\n",
"    required_columns=STOCK_FILTER_REQUIRED_COLUMNS,\n",
"    data_router=engine.router,\n",
")\n",
"\n",
"# 10. Create the ST filter\n",
"st_filter = STFilter(data_router=engine.router)\n",
"\n",
"# 11. Create the trainer\n",
"trainer = Trainer(\n",
"    model=model,\n",
"    pool_manager=pool_manager,\n",
"    processors=processors,\n",
"    filters=[st_filter],\n",
"    splitter=splitter,\n",
"    target_col=target_col,\n",
"    feature_cols=feature_cols,\n",
"    persist_model=PERSIST_MODEL,\n",
")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"================================================================================\n",
"LightGBM LambdaRank learning-to-rank training\n",
"================================================================================\n",
"\n",
"[1] Creating FactorEngine\n",
"\n",
"[2] Defining factors (registered from metadata)\n",
"================================================================================\n",
"Registering factors\n",
"================================================================================\n",
"\n",
"Registering feature factors (from metadata):\n"
]
},
{
"ename": "QueryError",
"evalue": "查询执行失败: Binder Error: Referenced column \"name\" not found in FROM clause!\nCandidate bindings: \"json\"\n\nLINE 4: WHERE name = 'ma_5'\n ^\nSQL: \n SELECT *\n FROM read_json_auto('D:\\PyProject\\ProStock\\src\\experiment\\data\\factors.jsonl')\n WHERE name = 'ma_5'\n ",
"output_type": "error",
"traceback": [
"\u001B[31m---------------------------------------------------------------------------\u001B[39m",
"\u001B[31mBinderException\u001B[39m Traceback (most recent call last)",
"\u001B[36mFile \u001B[39m\u001B[32mD:\\PyProject\\ProStock\\src\\factors\\metadata\\manager.py:296\u001B[39m, in \u001B[36mFactorManager._execute_query\u001B[39m\u001B[34m(self, sql)\u001B[39m\n\u001B[32m 295\u001B[39m conn = \u001B[38;5;28mself\u001B[39m._get_connection()\n\u001B[32m--> \u001B[39m\u001B[32m296\u001B[39m result = \u001B[43mconn\u001B[49m\u001B[43m.\u001B[49m\u001B[43mexecute\u001B[49m\u001B[43m(\u001B[49m\u001B[43msql\u001B[49m\u001B[43m)\u001B[49m.pl()\n\u001B[32m 297\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m result\n",
"\u001B[31mBinderException\u001B[39m: Binder Error: Referenced column \"name\" not found in FROM clause!\nCandidate bindings: \"json\"\n\nLINE 4: WHERE name = 'ma_5'\n ^",
"\nDuring handling of the above exception, another exception occurred:\n",
"\u001B[31mQueryError\u001B[39m Traceback (most recent call last)",
"\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[5]\u001B[39m\u001B[32m, line 11\u001B[39m\n\u001B[32m 9\u001B[39m \u001B[38;5;66;03m# 2. 使用 metadata 定义因子\u001B[39;00m\n\u001B[32m 10\u001B[39m \u001B[38;5;28mprint\u001B[39m(\u001B[33m\"\u001B[39m\u001B[38;5;130;01m\\n\u001B[39;00m\u001B[33m[2] 定义因子(从 metadata 注册)\u001B[39m\u001B[33m\"\u001B[39m)\n\u001B[32m---> \u001B[39m\u001B[32m11\u001B[39m feature_cols = \u001B[43mcreate_factors_with_metadata\u001B[49m\u001B[43m(\u001B[49m\n\u001B[32m 12\u001B[39m \u001B[43m \u001B[49m\u001B[43mengine\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mSELECTED_FACTORS\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mFACTOR_DEFINITIONS\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mLABEL_FACTOR\u001B[49m\n\u001B[32m 13\u001B[39m \u001B[43m)\u001B[49m\n\u001B[32m 15\u001B[39m \u001B[38;5;66;03m# 3. 准备数据\u001B[39;00m\n\u001B[32m 16\u001B[39m \u001B[38;5;28mprint\u001B[39m(\u001B[33m\"\u001B[39m\u001B[38;5;130;01m\\n\u001B[39;00m\u001B[33m[3] 准备数据\u001B[39m\u001B[33m\"\u001B[39m)\n",
"\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[2]\u001B[39m\u001B[32m, line 15\u001B[39m, in \u001B[36mcreate_factors_with_metadata\u001B[39m\u001B[34m(engine, selected_factors, factor_definitions, label_factor)\u001B[39m\n\u001B[32m 13\u001B[39m \u001B[38;5;28mprint\u001B[39m(\u001B[33m\"\u001B[39m\u001B[38;5;130;01m\\n\u001B[39;00m\u001B[33m注册特征因子(从 metadata):\u001B[39m\u001B[33m\"\u001B[39m)\n\u001B[32m 14\u001B[39m \u001B[38;5;28;01mfor\u001B[39;00m name \u001B[38;5;129;01min\u001B[39;00m selected_factors:\n\u001B[32m---> \u001B[39m\u001B[32m15\u001B[39m \u001B[43mengine\u001B[49m\u001B[43m.\u001B[49m\u001B[43madd_factor\u001B[49m\u001B[43m(\u001B[49m\u001B[43mname\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 16\u001B[39m \u001B[38;5;28mprint\u001B[39m(\u001B[33mf\u001B[39m\u001B[33m\"\u001B[39m\u001B[33m - \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mname\u001B[38;5;132;01m}\u001B[39;00m\u001B[33m\"\u001B[39m)\n\u001B[32m 18\u001B[39m \u001B[38;5;66;03m# 注册 FACTOR_DEFINITIONS 中的因子(通过表达式,尚未在 metadata 中)\u001B[39;00m\n",
"\u001B[36mFile \u001B[39m\u001B[32mD:\\PyProject\\ProStock\\src\\factors\\engine\\factor_engine.py:225\u001B[39m, in \u001B[36mFactorEngine.add_factor\u001B[39m\u001B[34m(self, name, expression, data_specs)\u001B[39m\n\u001B[32m 182\u001B[39m \u001B[38;5;250m\u001B[39m\u001B[33;03m\"\"\"注册因子(支持多种调用方式)。\u001B[39;00m\n\u001B[32m 183\u001B[39m \n\u001B[32m 184\u001B[39m \u001B[33;03m这是 register 方法的增强版,支持以下调用方式:\u001B[39;00m\n\u001B[32m (...)\u001B[39m\u001B[32m 221\u001B[39m \u001B[33;03m ... .add_factor(\"golden_cross\", \"ma5 > ma10\"))\u001B[39;00m\n\u001B[32m 222\u001B[39m \u001B[33;03m\"\"\"\u001B[39;00m\n\u001B[32m 223\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m expression \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[32m 224\u001B[39m \u001B[38;5;66;03m# 从 metadata 查询表达式\u001B[39;00m\n\u001B[32m--> \u001B[39m\u001B[32m225\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43m_add_factor_from_metadata\u001B[49m\u001B[43m(\u001B[49m\u001B[43mname\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mname\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdata_specs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 227\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28misinstance\u001B[39m(expression, \u001B[38;5;28mstr\u001B[39m):\n\u001B[32m 228\u001B[39m \u001B[38;5;66;03m# Fail-Fast:立即解析,失败立即报错\u001B[39;00m\n\u001B[32m 229\u001B[39m node = \u001B[38;5;28mself\u001B[39m._parser.parse(expression)\n",
"\u001B[36mFile \u001B[39m\u001B[32mD:\\PyProject\\ProStock\\src\\factors\\engine\\factor_engine.py:159\u001B[39m, in \u001B[36mFactorEngine._add_factor_from_metadata\u001B[39m\u001B[34m(self, name, factor_name_in_metadata, data_specs)\u001B[39m\n\u001B[32m 153\u001B[39m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mRuntimeError\u001B[39;00m(\n\u001B[32m 154\u001B[39m \u001B[33m\"\u001B[39m\u001B[33m引擎未配置 metadata 路径。请在初始化时传入 metadata_path 参数,\u001B[39m\u001B[33m\"\u001B[39m\n\u001B[32m 155\u001B[39m + \u001B[33m\"\u001B[39m\u001B[33m例如:FactorEngine(metadata_path=\u001B[39m\u001B[33m'\u001B[39m\u001B[33mdata/factors.jsonl\u001B[39m\u001B[33m'\u001B[39m\u001B[33m)\u001B[39m\u001B[33m\"\u001B[39m\n\u001B[32m 156\u001B[39m )\n\u001B[32m 158\u001B[39m \u001B[38;5;66;03m# 从 metadata 查询因子\u001B[39;00m\n\u001B[32m--> \u001B[39m\u001B[32m159\u001B[39m df = \u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43m_metadata\u001B[49m\u001B[43m.\u001B[49m\u001B[43mget_factors_by_name\u001B[49m\u001B[43m(\u001B[49m\u001B[43mfactor_name_in_metadata\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 161\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mlen\u001B[39m(df) == \u001B[32m0\u001B[39m:\n\u001B[32m 162\u001B[39m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mValueError\u001B[39;00m(\n\u001B[32m 163\u001B[39m \u001B[33mf\u001B[39m\u001B[33m\"\u001B[39m\u001B[33m在 metadata 中未找到因子 \u001B[39m\u001B[33m'\u001B[39m\u001B[38;5;132;01m{\u001B[39;00mfactor_name_in_metadata\u001B[38;5;132;01m}\u001B[39;00m\u001B[33m'\u001B[39m\u001B[33m。\u001B[39m\u001B[33m\"\u001B[39m\n\u001B[32m 164\u001B[39m + \u001B[33m\"\u001B[39m\u001B[33m请确认因子名称正确,或先使用 FactorManager 添加该因子。\u001B[39m\u001B[33m\"\u001B[39m\n\u001B[32m 165\u001B[39m )\n",
"\u001B[36mFile \u001B[39m\u001B[32mD:\\PyProject\\ProStock\\src\\factors\\metadata\\manager.py:177\u001B[39m, in \u001B[36mFactorManager.get_factors_by_name\u001B[39m\u001B[34m(self, name)\u001B[39m\n\u001B[32m 154\u001B[39m \u001B[38;5;250m\u001B[39m\u001B[33;03m\"\"\"根据名称查询因子。\u001B[39;00m\n\u001B[32m 155\u001B[39m \n\u001B[32m 156\u001B[39m \u001B[33;03m使用DuckDB执行SQL查询,返回Polars DataFrame。\u001B[39;00m\n\u001B[32m (...)\u001B[39m\u001B[32m 170\u001B[39m \u001B[33;03m ... print(df[\"dsl\"][0])\u001B[39;00m\n\u001B[32m 171\u001B[39m \u001B[33;03m\"\"\"\u001B[39;00m\n\u001B[32m 172\u001B[39m sql = \u001B[33mf\u001B[39m\u001B[33m\"\"\"\u001B[39m\n\u001B[32m 173\u001B[39m \u001B[33m SELECT *\u001B[39m\n\u001B[32m 174\u001B[39m \u001B[33m FROM read_json_auto(\u001B[39m\u001B[33m'\u001B[39m\u001B[38;5;132;01m{\u001B[39;00m\u001B[38;5;28mself\u001B[39m.filepath\u001B[38;5;132;01m}\u001B[39;00m\u001B[33m'\u001B[39m\u001B[33m)\u001B[39m\n\u001B[32m 175\u001B[39m \u001B[33m WHERE name = \u001B[39m\u001B[33m'\u001B[39m\u001B[38;5;132;01m{\u001B[39;00mname\u001B[38;5;132;01m}\u001B[39;00m\u001B[33m'\u001B[39m\n\u001B[32m 176\u001B[39m \u001B[33m\u001B[39m\u001B[33m\"\"\"\u001B[39m\n\u001B[32m--> \u001B[39m\u001B[32m177\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43m_execute_query\u001B[49m\u001B[43m(\u001B[49m\u001B[43msql\u001B[49m\u001B[43m)\u001B[49m\n",
"\u001B[36mFile \u001B[39m\u001B[32mD:\\PyProject\\ProStock\\src\\factors\\metadata\\manager.py:299\u001B[39m, in \u001B[36mFactorManager._execute_query\u001B[39m\u001B[34m(self, sql)\u001B[39m\n\u001B[32m 297\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m result\n\u001B[32m 298\u001B[39m \u001B[38;5;28;01mexcept\u001B[39;00m \u001B[38;5;167;01mException\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m e:\n\u001B[32m--> \u001B[39m\u001B[32m299\u001B[39m \u001B[38;5;28;01mraise\u001B[39;00m QueryError(sql, e)\n",
"\u001B[31mQueryError\u001B[39m: 查询执行失败: Binder Error: Referenced column \"name\" not found in FROM clause!\nCandidate bindings: \"json\"\n\nLINE 4: WHERE name = 'ma_5'\n ^\nSQL: \n SELECT *\n FROM read_json_auto('D:\\PyProject\\ProStock\\src\\experiment\\data\\factors.jsonl')\n WHERE name = 'ma_5'\n "
]
}
],
"execution_count": 5
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### 4.1 Stock-Pool Filtering"
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"Stock-pool filtering\")\n",
"print(\"=\" * 80)\n",
"\n",
"# Apply the ST filter first (before stock-pool selection, consistent with Trainer.train())\n",
"if st_filter:\n",
"    print(\"\\n[Filter] Applying the ST filter...\")\n",
"    data = st_filter.filter(data)\n",
"    print(f\"  Data shape after the ST filter: {data.shape}\")\n",
"\n",
"if pool_manager:\n",
"    print(\"\\nSelecting the stock pool independently for each day...\")\n",
"    filtered_data = pool_manager.filter_and_select_daily(data)\n",
"    print(f\"  Data shape before selection: {data.shape}\")\n",
"    print(f\"  Data shape after selection: {filtered_data.shape}\")\n",
"    print(f\"  Stocks before selection: {data['ts_code'].n_unique()}\")\n",
"    print(f\"  Stocks after selection: {filtered_data['ts_code'].n_unique()}\")\n",
"    print(f\"  Rows removed: {len(data) - len(filtered_data)}\")\n",
"else:\n",
"    filtered_data = data\n",
"    print(\"  No stock-pool manager configured; skipping selection\")"
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### 4.2 Data Split"
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"Data split\")\n",
"print(\"=\" * 80)\n",
"\n",
"if splitter:\n",
"    train_data, val_data, test_data = splitter.split(filtered_data)\n",
"\n",
"    # LambdaRank needs a label on every training/validation row and the rows of each\n",
"    # query (day) to be contiguous, so drop missing targets and sort by date\n",
"    train_data = train_data.drop_nulls(subset=[target_col]).sort([\"trade_date\", \"ts_code\"])\n",
"    val_data = val_data.drop_nulls(subset=[target_col]).sort([\"trade_date\", \"ts_code\"])\n",
"    test_data = test_data.sort([\"trade_date\", \"ts_code\"])\n",
"\n",
"    print(f\"\\nTraining set shape: {train_data.shape}\")\n",
"    print(f\"Validation set shape: {val_data.shape}\")\n",
"    print(f\"Test set shape: {test_data.shape}\")\n",
"\n",
"    # Group arrays (samples per day) for each split\n",
"    train_group = compute_group_array(train_data)\n",
"    val_group = compute_group_array(val_data)\n",
"    test_group = compute_group_array(test_data)\n",
"\n",
"    print(f\"\\nTraining groups (days): {len(train_group)}\")\n",
"    print(f\"Validation groups (days): {len(val_group)}\")\n",
"    print(f\"Test groups (days): {len(test_group)}\")\n",
"    print(f\"Mean samples per day (train): {np.mean(train_group):.1f}\")\n",
"    print(f\"Mean samples per day (val): {np.mean(val_group):.1f}\")\n",
"    print(f\"Mean samples per day (test): {np.mean(test_group):.1f}\")\n",
"else:\n",
"    raise ValueError(\"A data splitter must be configured\")"
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### 4.3 Data Preprocessing"
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"Data preprocessing\")\n",
"print(\"=\" * 80)\n",
"\n",
"fitted_processors = []\n",
"if processors:\n",
"    print(\"\\nProcessing the training set (fit + transform)...\")\n",
"    for i, processor in enumerate(processors, 1):\n",
"        print(f\"  [{i}/{len(processors)}] {processor.__class__.__name__}\")\n",
"        train_data = processor.fit_transform(train_data)\n",
"        fitted_processors.append(processor)\n",
"\n",
"    # Validation and test sets reuse the statistics fitted on the training set\n",
"    print(\"\\nProcessing the validation set...\")\n",
"    for processor in fitted_processors:\n",
"        val_data = processor.transform(val_data)\n",
"\n",
"    print(\"\\nProcessing the test set...\")\n",
"    for processor in fitted_processors:\n",
"        test_data = processor.transform(test_data)\n",
"\n",
"print(f\"\\nTraining set shape after processing: {train_data.shape}\")\n",
"print(f\"Validation set shape after processing: {val_data.shape}\")\n",
"print(f\"Test set shape after processing: {test_data.shape}\")"
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### 4.4 Train the LambdaRank Model"
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"Training the LambdaRank model\")\n",
"print(\"=\" * 80)\n",
"\n",
"# Prepare features and targets\n",
"X_train = train_data.select(feature_cols)\n",
"y_train = train_data.select(target_col).to_series()\n",
"\n",
"X_val = val_data.select(feature_cols)\n",
"y_val = val_data.select(target_col).to_series()\n",
"\n",
"print(f\"\\nTraining samples: {len(X_train)}\")\n",
"print(f\"Validation samples: {len(X_val)}\")\n",
"print(f\"Number of features: {len(feature_cols)}\")\n",
"print(f\"Target: {target_col}\")\n",
"\n",
"print(\"\\nTarget statistics (training set):\")\n",
"print(y_train.describe())\n",
"\n",
"print(\"\\nTraining...\")\n",
"model.fit(\n",
"    X=X_train,\n",
"    y=y_train,\n",
"    group=train_group,\n",
"    eval_set=(X_val, y_val, val_group),\n",
")\n",
"print(\"Training finished!\")"
]
},
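{
"metadata": {},
"cell_type": "markdown",
"source": [
"`LightGBMLambdaRankModel.fit` is project code that is not shown here. For reference, the public LightGBM API receives the query sizes through the `group` argument of `lgb.Dataset`, and the group sizes must sum to the number of rows. The self-contained sketch below trains on synthetic data: 3 made-up days of 50 stocks, 4 random features and 5 per-day quantile labels. It only illustrates the LightGBM API and is not the project's wrapper."
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"import lightgbm as lgb\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"n_days, n_stocks, n_q = 3, 50, 5\n",
"\n",
"# Synthetic features; the hidden return depends on the first feature plus noise\n",
"X_syn = rng.normal(size=(n_days * n_stocks, 4))\n",
"ret_syn = X_syn[:, 0] + 0.5 * rng.normal(size=n_days * n_stocks)\n",
"\n",
"# Per-day quantile labels 0..n_q-1; rows are stored day by day, so each query is contiguous\n",
"y_syn = np.empty(n_days * n_stocks, dtype=int)\n",
"for d in range(n_days):\n",
"    day = slice(d * n_stocks, (d + 1) * n_stocks)\n",
"    rank0 = ret_syn[day].argsort().argsort()  # 0-based rank within the day\n",
"    y_syn[day] = rank0 * n_q // n_stocks\n",
"\n",
"group_syn = np.full(n_days, n_stocks)  # 50 rows per query (day); sums to len(X_syn)\n",
"syn_set = lgb.Dataset(X_syn, label=y_syn, group=group_syn)\n",
"\n",
"syn_params = {'objective': 'lambdarank', 'metric': 'ndcg', 'ndcg_eval_at': [5],\n",
"              'min_data_in_leaf': 5, 'verbose': -1}\n",
"syn_booster = lgb.train(syn_params, syn_set, num_boost_round=20)\n",
"print(syn_booster.predict(X_syn).shape)"
]
},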
|
||
{
|
||
"metadata": {},
|
||
"cell_type": "markdown",
|
||
"source": "### 4.5 训练指标曲线"
|
||
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"Training metric curves\")\n",
"print(\"=\" * 80)\n",
"\n",
"# Retrain with the same parameters to collect per-iteration metrics\n",
"# (the fit in 4.4 does not keep its evaluation history).\n",
"# booster_with_eval is only used for diagnostics; the predictions in 4.6 still come from `model`.\n",
"print(\"\\nRetraining to collect per-iteration metrics...\")\n",
"\n",
"import lightgbm as lgb\n",
"\n",
"# Prepare arrays (val is used for validation; test never takes part in training)\n",
"X_train_np = X_train.to_numpy()\n",
"y_train_np = y_train.to_numpy()\n",
"X_val_np = X_val.to_numpy()\n",
"y_val_np = y_val.to_numpy()\n",
"\n",
"# Build the LightGBM datasets; group gives the number of rows per trading day\n",
"train_dataset = lgb.Dataset(X_train_np, label=y_train_np, group=train_group)\n",
"val_dataset = lgb.Dataset(\n",
"    X_val_np, label=y_val_np, group=val_group, reference=train_dataset\n",
")\n",
"\n",
"# Filled in by lgb.record_evaluation: {dataset_name: {metric_name: [value per round]}}\n",
"evals_result = {}\n",
"\n",
"# Train/val/test split: train fits the model, val drives early stopping, test stays untouched\n",
"max_rounds = MODEL_PARAMS.get(\"n_estimators\", 1000)\n",
"booster_with_eval = lgb.train(\n",
"    MODEL_PARAMS,\n",
"    train_dataset,\n",
"    num_boost_round=max_rounds,\n",
"    valid_sets=[train_dataset, val_dataset],\n",
"    valid_names=[\"train\", \"val\"],\n",
"    callbacks=[\n",
"        lgb.record_evaluation(evals_result),\n",
"        lgb.early_stopping(stopping_rounds=50, verbose=True),\n",
"    ],\n",
")\n",
"\n",
"print(\"Training complete; metrics collected\")\n",
"\n",
"# NDCG metrics that were evaluated\n",
"ndcg_metrics = [k for k in evals_result[\"train\"].keys() if \"ndcg\" in k]\n",
"print(f\"\\nEvaluated NDCG metrics: {ndcg_metrics}\")\n",
"\n",
"# Early-stopping summary\n",
"actual_rounds = len(next(iter(evals_result[\"train\"].values())))\n",
"best_round = booster_with_eval.best_iteration\n",
"print(\"\\n[Early stopping]\")\n",
"print(f\"  Max rounds configured: {max_rounds}\")\n",
"print(f\"  Rounds actually trained: {actual_rounds}\")\n",
"print(f\"  Best iteration: {best_round}\")\n",
"if actual_rounds < max_rounds:\n",
"    print(\"  Status: triggered (a validation metric did not improve for 50 consecutive rounds)\")\n",
"else:\n",
"    print(\"  Status: not triggered (reached the maximum number of rounds)\")\n",
"\n",
"# Report NDCG at the best iteration; the last recorded round can lie 50 rounds past it\n",
"best_idx = (best_round if best_round > 0 else actual_rounds) - 1\n",
"print(f\"\\nNDCG at iteration {best_idx + 1}:\")\n",
"for metric in ndcg_metrics:\n",
"    train_ndcg = evals_result[\"train\"][metric][best_idx]\n",
"    val_ndcg = evals_result[\"val\"][metric][best_idx]\n",
"    print(f\"  {metric}: train={train_ndcg:.4f}, val={val_ndcg:.4f}\")"
]
},
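`lgb.record_evaluation` fills `evals_result` as `{dataset_name: {metric_name: [one value per boosting round]}}`, and when early stopping fires the recorded history extends past the best round by the patience window (50 here). A small stdlib sketch of reading the best round per metric and the train/validation gap at that round off such a dict; `best_rounds`, `train_val_gap` and the toy numbers are hypothetical illustrations, not output from this notebook:

```python
from typing import Dict, List, Tuple

EvalsResult = Dict[str, Dict[str, List[float]]]


def best_rounds(evals_result: EvalsResult, split: str = "val") -> Dict[str, Tuple[int, float]]:
    """Map each metric to (1-based best round, best value) on `split`.

    Assumes higher is better, which holds for NDCG. max() returns the first
    maximum, matching list.index(max(...)) used in the plotting cell.
    """
    out: Dict[str, Tuple[int, float]] = {}
    for metric, history in evals_result[split].items():
        best_idx = max(range(len(history)), key=history.__getitem__)
        out[metric] = (best_idx + 1, history[best_idx])
    return out


def train_val_gap(evals_result: EvalsResult, metric: str, round_1based: int) -> float:
    """Train minus validation value at a given round; a widening gap hints at overfitting."""
    i = round_1based - 1
    return evals_result["train"][metric][i] - evals_result["val"][metric][i]


# Toy history shaped like LightGBM's record_evaluation output (made-up numbers)
toy = {
    "train": {"ndcg@1": [0.50, 0.60, 0.70, 0.80]},
    "val": {"ndcg@1": [0.40, 0.45, 0.44, 0.43]},
}
best = best_rounds(toy)
print(best)  # {'ndcg@1': (2, 0.45)}
print(round(train_val_gap(toy, "ndcg@1", best["ndcg@1"][0]), 4))  # 0.15
```

Both functions can be applied unchanged to the `evals_result` collected in this cell.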
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"# Plot the NDCG training curves\n",
"import matplotlib.pyplot as plt\n",
"\n",
"fig, axes = plt.subplots(2, 2, figsize=(14, 10))\n",
"axes = axes.flatten()\n",
"\n",
"for idx, metric in enumerate(ndcg_metrics[:4]):  # show at most 4 NDCG metrics\n",
"    ax = axes[idx]\n",
"    train_metric = evals_result[\"train\"][metric]\n",
"    val_metric = evals_result[\"val\"][metric]\n",
"    iterations = range(1, len(train_metric) + 1)\n",
"\n",
"    ax.plot(\n",
"        iterations, train_metric, label=f\"Train {metric}\", linewidth=2, color=\"blue\"\n",
"    )\n",
"    ax.plot(iterations, val_metric, label=f\"Val {metric}\", linewidth=2, color=\"red\")\n",
"    ax.set_xlabel(\"Iteration\", fontsize=10)\n",
"    ax.set_ylabel(metric.upper(), fontsize=10)\n",
"    ax.set_title(\n",
"        f\"Training and Validation {metric.upper()}\", fontsize=12, fontweight=\"bold\"\n",
"    )\n",
"    ax.legend(fontsize=9)\n",
"    ax.grid(True, alpha=0.3)\n",
"\n",
"    # Mark the best validation value\n",
"    best_metric = max(val_metric)\n",
"    best_iter = val_metric.index(best_metric)\n",
"    ax.axvline(x=best_iter + 1, color=\"green\", linestyle=\"--\", alpha=0.7)\n",
"    ax.scatter([best_iter + 1], [best_metric], color=\"green\", s=80, zorder=5)\n",
"    ax.annotate(\n",
"        f\"Best: {best_metric:.4f}\",\n",
"        xy=(best_iter + 1, best_metric),\n",
"        xytext=(best_iter + 1 + len(iterations) * 0.05, best_metric),\n",
"        fontsize=8,\n",
"        arrowprops=dict(arrowstyle=\"->\", color=\"green\", alpha=0.7),\n",
"    )\n",
"\n",
"# Hide unused panels when fewer than 4 NDCG metrics were evaluated\n",
"for ax in axes[len(ndcg_metrics[:4]):]:\n",
"    ax.set_visible(False)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"print(\"\\n[Metric analysis]\")\n",
"print(\"  Best validation value of each NDCG metric:\")\n",
"for metric in ndcg_metrics:\n",
"    val_metric_list = evals_result[\"val\"][metric]\n",
"    best_val = max(val_metric_list)\n",
"    best_iter = val_metric_list.index(best_val)\n",
"    print(f\"    {metric}: {best_val:.4f} (iteration {best_iter + 1})\")\n",
"print(\"\\n[Reminder] The validation set is used only for early stopping / tuning; the test set is kept entirely out of training.\")"
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### 4.6 Model evaluation"
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"Model evaluation\")\n",
"print(\"=\" * 80)\n",
"\n",
"# Prepare the test set (row order must match test_group)\n",
"X_test = test_data.select(feature_cols)\n",
"y_test = test_data.select(target_col).to_series()\n",
"\n",
"# Predict\n",
"print(\"\\nGenerating predictions...\")\n",
"predictions = model.predict(X_test)\n",
"\n",
"# Attach the scores as a new column\n",
"test_data = test_data.with_columns([pl.Series(\"prediction\", predictions)])\n",
"\n",
"# Compute NDCG metrics\n",
"print(\"\\nComputing NDCG metrics...\")\n",
"ndcg_results = evaluate_ndcg_at_k(\n",
"    y_true=y_test.to_numpy(),\n",
"    y_pred=predictions,\n",
"    group=test_group,\n",
"    k_list=[1, 5, 10, 20],\n",
")\n",
"\n",
"print(\"\\nNDCG results:\")\n",
"print(\"-\" * 40)\n",
"for metric, value in ndcg_results.items():\n",
"    print(f\"  {metric}: {value:.4f}\")\n",
"\n",
"# Feature importance\n",
"print(\"\\nFeature importance (top 20):\")\n",
"print(\"-\" * 40)\n",
"importance = model.feature_importance()\n",
"if importance is not None:\n",
"    top_features = importance.sort_values(ascending=False).head(20)\n",
"    for i, (feature, score) in enumerate(top_features.items(), 1):\n",
"        print(f\"  {i:2d}. {feature:30s} {score:10.2f}\")"
]
},
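The helper `evaluate_ndcg_at_k` is defined earlier in the notebook and its body is not shown here. For reference, a stdlib sketch of per-day NDCG@k averaged over days, with two explicit choices: gain `2**rel - 1` (LightGBM's default `label_gain` for lambdarank) or linear gain `rel` (what `sklearn.metrics.ndcg_score` uses), and a score of 1.0 for a day whose ideal DCG is zero, which as far as I know mirrors LightGBM's convention. Which gain the notebook's helper uses is an assumption worth checking, since the two give different values for the same predictions and the test-set numbers may then not be directly comparable to the LightGBM training curves. `mean_ndcg_at_k` is hypothetical, and tied scores are broken by row order here rather than averaged as sklearn does.

```python
import math
from typing import Dict, List, Sequence


def dcg_at_k(relevance_in_rank_order: Sequence[float], k: int, exponential: bool = True) -> float:
    """DCG@k with a log2(position + 1) discount, positions starting at 1."""
    total = 0.0
    for pos, rel in enumerate(relevance_in_rank_order[:k], start=1):
        gain = (2.0 ** rel - 1.0) if exponential else float(rel)
        total += gain / math.log2(pos + 1)
    return total


def mean_ndcg_at_k(
    y_true: Sequence[float],
    y_pred: Sequence[float],
    group: Sequence[int],
    k_list: Sequence[int],
    exponential: bool = True,
) -> Dict[str, float]:
    """Average NDCG@k over queries (trading days) given row-aligned labels, scores and group sizes."""
    if sum(group) != len(y_true) or len(y_true) != len(y_pred):
        raise ValueError("labels, scores and sum(group) must all have the same length")
    per_k: Dict[int, List[float]] = {k: [] for k in k_list}
    start = 0
    for size in group:
        rel = list(y_true[start:start + size])
        pred = list(y_pred[start:start + size])
        start += size
        # Rank by predicted score, best first; sorted() is stable, so tied scores keep row order
        order = sorted(range(size), key=lambda i: pred[i], reverse=True)
        ranked = [rel[i] for i in order]
        ideal = sorted(rel, reverse=True)
        for k in k_list:
            idcg = dcg_at_k(ideal, k, exponential)
            # Days whose ideal DCG is 0 (all labels 0) are scored 1.0
            per_k[k].append(dcg_at_k(ranked, k, exponential) / idcg if idcg > 0 else 1.0)
    return {f"ndcg@{k}": (sum(v) / len(v) if v else float("nan")) for k, v in per_k.items()}


# Two toy "days": the first is ranked perfectly, the second backwards
labels = [3, 2, 0, 1, 0]
scores = [0.9, 0.5, 0.1, 0.2, 0.8]
print(mean_ndcg_at_k(labels, scores, group=[3, 2], k_list=[1]))  # {'ndcg@1': 0.5}
```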
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"# Make sure the output directory exists\n",
"os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
"\n",
"# Test-period label in YYYYMMDD_YYYYMMDD form (not currently used in the output file name)\n",
"start_dt = datetime.strptime(TEST_START, \"%Y%m%d\")\n",
"end_dt = datetime.strptime(TEST_END, \"%Y%m%d\")\n",
"date_str = f\"{start_dt.strftime('%Y%m%d')}_{end_dt.strftime('%Y%m%d')}\"\n",
"\n",
"# Save the daily Top N\n",
"print(f\"\\n[1/1] Saving the daily top {TOP_N} stocks...\")\n",
"topn_output_path = os.path.join(OUTPUT_DIR, \"rank_output.csv\")\n",
"\n",
"# For each trading day, keep the TOP_N highest-scoring stocks\n",
"topn_by_date = []\n",
"unique_dates = test_data[\"trade_date\"].unique().sort()\n",
"for date in unique_dates:\n",
"    day_data = test_data.filter(pl.col(\"trade_date\") == date)\n",
"    # Sort by prediction (descending) and take the first N\n",
"    topn = day_data.sort(\"prediction\", descending=True).head(TOP_N)\n",
"    topn_by_date.append(topn)\n",
"\n",
"# Concatenate the per-day Top N\n",
"topn_results = pl.concat(topn_by_date)\n",
"\n",
"# Format the date as YYYY-MM-DD and order the columns: date, score, stock.\n",
"# The concatenation is parenthesised so that .alias(\"date\") names the whole expression;\n",
"# without the parentheses the column keeps the left operand's name (trade_date).\n",
"topn_to_save = topn_results.select(\n",
"    [\n",
"        (\n",
"            pl.col(\"trade_date\").str.slice(0, 4)\n",
"            + \"-\"\n",
"            + pl.col(\"trade_date\").str.slice(4, 2)\n",
"            + \"-\"\n",
"            + pl.col(\"trade_date\").str.slice(6, 2)\n",
"        ).alias(\"date\"),\n",
"        pl.col(\"prediction\").alias(\"score\"),\n",
"        pl.col(\"ts_code\"),\n",
"    ]\n",
")\n",
"topn_to_save.write_csv(topn_output_path, include_header=True)\n",
"print(f\"  Saved to: {topn_output_path}\")\n",
"print(\n",
"    f\"  Rows saved: {len(topn_to_save)} ({len(unique_dates)} trading days x top {TOP_N} per day)\"\n",
")\n",
"print(\"\\n  Preview (first 15 rows):\")\n",
"print(topn_to_save.head(15))\n",
"\n",
"print(\"\\nTraining pipeline complete!\")"
]
},
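The saving step is a per-day top-N selection followed by a `YYYYMMDD` to `YYYY-MM-DD` reformat. A dependency-free sketch of the same logic on plain tuples, handy for checking the CSV contents by hand; the tuple layout `(trade_date, ts_code, prediction)`, the function name `daily_top_n` and the sample codes are assumptions for illustration. Ties in the score are broken by input order here; polars `sort` only promises that when `maintain_order=True`, so tied rows may come out in a different order in the notebook.

```python
from itertools import groupby
from operator import itemgetter
from typing import Iterable, List, Tuple

Row = Tuple[str, str, float]  # (trade_date as YYYYMMDD, ts_code, prediction)


def daily_top_n(rows: Iterable[Row], top_n: int) -> List[Tuple[str, float, str]]:
    """Return (date as YYYY-MM-DD, score, ts_code) for the top_n scores of each day, days ascending."""
    out: List[Tuple[str, float, str]] = []
    by_date = sorted(rows, key=itemgetter(0))  # stable: same-day rows keep input order
    for trade_date, day_rows in groupby(by_date, key=itemgetter(0)):
        best = sorted(day_rows, key=itemgetter(2), reverse=True)[:top_n]
        iso_date = f"{trade_date[:4]}-{trade_date[4:6]}-{trade_date[6:8]}"
        out.extend((iso_date, score, code) for _, code, score in best)
    return out


rows = [
    ("20240103", "000001.SZ", 0.2),
    ("20240102", "600000.SH", 0.7),
    ("20240102", "000002.SZ", 0.9),
    ("20240102", "000001.SZ", 0.1),
]
for line in daily_top_n(rows, top_n=2):
    print(line)
# ('2024-01-02', 0.9, '000002.SZ')
# ('2024-01-02', 0.7, '600000.SH')
# ('2024-01-03', 0.2, '000001.SZ')
```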
{
"metadata": {},
"cell_type": "markdown",
"source": [
"## 5. Summary\n",
"\n",
"This notebook implements a complete Learn-to-Rank training pipeline:\n",
"\n",
"### Core steps\n",
"\n",
"1. **Data preparation**: compute 49 factor features and convert `future_return_5` into 20-quantile labels per trading day\n",
"2. **Model training**: learn the daily stock ranking with LightGBM LambdaRank\n",
"3. **Model evaluation**: measure ranking quality with NDCG@1/5/10/20\n",
"4. **Strategy analysis**: build a Top-k stock-selection strategy from the ranking scores\n",
"\n",
"### Key parameters\n",
"\n",
"- **Objective**: lambdarank\n",
"- **Metric**: ndcg\n",
"- **Learning Rate**: 0.05\n",
"- **Num Leaves**: 31\n",
"- **N Quantiles**: 20\n",
"\n",
"### Outputs\n",
"\n",
"- `rank_output.csv`: daily Top-N recommended stocks (columns: date, score, ts_code)\n",
"- Feature importance ranking\n",
"- Top-k strategy statistics and charts\n",
"- NDCG training curves\n",
"\n",
"### Next steps\n",
"\n",
"1. **Feature engineering**: try more factor combinations\n",
"2. **Hyperparameter tuning**: optimize the LambdaRank parameters with grid search\n",
"3. **Model ensembling**: combine the predictions of several ranking models\n",
"4. **Finer grouping**: consider ranking within industry groups\n"
]
}
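Step 1 of the summary turns `future_return_5` into 20 per-day quantile labels; the code for that step is not part of this section. As a hedged, stdlib-only sketch: ranking each day's returns and bucketing by rank position gives labels 0 to `n_quantiles - 1` with the highest returns in the top bucket, roughly what `pd.qcut(returns.rank(method="first"), 20, labels=False)` produces when a day has at least 20 stocks. `daily_quantile_labels` is a hypothetical name, and null returns are assumed to have been dropped beforehand.

```python
from collections import defaultdict
from typing import Dict, List, Sequence


def daily_quantile_labels(
    dates: Sequence[str], returns: Sequence[float], n_quantiles: int = 20
) -> List[int]:
    """Row-aligned integer labels in [0, n_quantiles) computed within each date."""
    rows_by_date: Dict[str, List[int]] = defaultdict(list)
    for idx, d in enumerate(dates):
        rows_by_date[d].append(idx)

    labels = [0] * len(dates)
    for idxs in rows_by_date.values():
        # Ascending by return; ties keep row order (like rank(method="first"))
        ordered = sorted(idxs, key=lambda i: returns[i])
        n = len(ordered)
        for pos, i in enumerate(ordered):
            # Equal-sized buckets by rank position; the top row gets n_quantiles - 1 when n >= n_quantiles
            labels[i] = pos * n_quantiles // n
    return labels


dates = ["20240102"] * 40 + ["20240103"] * 3
returns = [float(x) for x in range(40)] + [0.3, 0.1, 0.2]
labels = daily_quantile_labels(dates, returns)
print(labels[0], labels[39], labels[40:])  # 0 19 [13, 0, 6]
```

Days with fewer stocks than quantiles leave gaps between labels (the 3-stock day gets 0, 6 and 13); whether the notebook's own labelling treats such days differently is not visible in this section.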
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}