Files
ProStock/src/experiment/regression.ipynb
liaozhaorun 0e9ea5d533 refactor(experiment): 提取共用配置到 common 模块
- 将因子定义、日期配置、股票池筛选等提取到 common.py
- 重构 learn_to_rank 和 regression 脚本,统一使用公共配置
- 简化代码结构,消除重复定义
2026-03-15 05:46:19 +08:00

817 lines
28 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. 导入依赖"
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"import os\n",
"from datetime import datetime\n",
"\n",
"import polars as pl\n",
"\n",
"from src.factors import FactorEngine\n",
"from src.training import (\n",
" DateSplitter,\n",
" LightGBMModel,\n",
" STFilter,\n",
" StandardScaler,\n",
" StockPoolManager,\n",
" Trainer,\n",
" Winsorizer,\n",
" NullFiller,\n",
" check_data_quality,\n",
")\n",
"from src.training.config import TrainingConfig\n",
"\n",
"# 从 common 模块导入共用配置和函数\n",
"from src.experiment.common import (\n",
" SELECTED_FACTORS,\n",
" FACTOR_DEFINITIONS,\n",
" get_label_factor,\n",
" register_factors,\n",
" prepare_data,\n",
" TRAIN_START,\n",
" TRAIN_END,\n",
" VAL_START,\n",
" VAL_END,\n",
" TEST_START,\n",
" TEST_END,\n",
" stock_pool_filter,\n",
" STOCK_FILTER_REQUIRED_COLUMNS,\n",
" OUTPUT_DIR,\n",
" SAVE_PREDICTIONS,\n",
" PERSIST_MODEL,\n",
" TOP_N,\n",
")\n",
"\n"
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": [
"## 2. 配置参数\n",
"#\n",
"### 2.1 标签定义"
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"# Label name (the regression task predicts a continuous forward return)\n",
"LABEL_NAME = \"future_return_5\"\n",
"\n",
"# Fetch the label's factor definition from the shared common module\n",
"LABEL_FACTOR = get_label_factor(LABEL_NAME)\n",
"\n",
"# LightGBM hyper-parameters\n",
"MODEL_PARAMS = {\n",
" \"objective\": \"regression\",\n",
" \"metric\": \"mae\", # MAE is more robust to outliers than MSE\n",
" # Tree-structure controls (core overfitting guards), currently disabled:\n",
" # \"num_leaves\": 20, # lowered from 31 to reduce model complexity\n",
" # \"max_depth\": 16, # explicit depth cap to avoid fitting noise\n",
" # \"min_child_samples\": 50, # min samples per leaf, avoids learning extreme samples\n",
" # \"min_child_weight\": 0.001,\n",
" # Learning parameters\n",
" \"learning_rate\": 0.01, # low learning rate, paired with more trees\n",
" \"n_estimators\": 1000, # more trees, paired with early stopping\n",
" # Sampling strategy (key overfitting guard)\n",
" \"subsample\": 0.8, # row subsampling: 80% of rows per tree\n",
" \"subsample_freq\": 5, # re-subsample every 5 iterations\n",
" \"colsample_bytree\": 0.8, # column subsampling: 80% of features per tree\n",
" # Regularization\n",
" \"reg_alpha\": 0.1, # L1, encourages sparsity\n",
" \"reg_lambda\": 1.0, # L2, smooths weights\n",
" # Verbosity / reproducibility\n",
" \"verbose\": -1,\n",
" \"random_state\": 42,\n",
"}"
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": [
"## 4. 训练流程\n",
"#\n",
"### 4.1 初始化组件"
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"LightGBM 回归模型训练\")\n",
"print(\"=\" * 80)\n",
"\n",
"# 1. Create the FactorEngine (with factor-metadata support enabled)\n",
"print(\"\\n[1] 创建 FactorEngine\")\n",
"engine = FactorEngine(metadata_path=\"data/factors.jsonl\")\n",
"\n",
"# 2. Register factors from their metadata definitions\n",
"print(\"\\n[2] 定义因子(从 metadata 注册)\")\n",
"feature_cols = register_factors(\n",
" engine, SELECTED_FACTORS, FACTOR_DEFINITIONS, LABEL_FACTOR\n",
")\n",
"target_col = LABEL_NAME\n",
"\n",
"# 3. Prepare data (date range comes from the shared module-level config)\n",
"print(\"\\n[3] 准备数据\")\n",
"\n",
"data = prepare_data(\n",
" engine=engine,\n",
" feature_cols=feature_cols,\n",
" start_date=TRAIN_START,\n",
" end_date=TEST_END,\n",
" label_name=LABEL_NAME,\n",
")\n",
"\n",
"# 4. Print the run configuration\n",
"print(f\"\\n[配置] 训练期: {TRAIN_START} - {TRAIN_END}\")\n",
"print(f\"[配置] 验证期: {VAL_START} - {VAL_END}\")\n",
"print(f\"[配置] 测试期: {TEST_START} - {TEST_END}\")\n",
"print(f\"[配置] 特征数: {len(feature_cols)}\")\n",
"print(f\"[配置] 目标变量: {target_col}\")\n",
"\n",
"# 5. Create the model\n",
"model = LightGBMModel(params=MODEL_PARAMS)\n",
"\n",
"# 6. Create the feature processors (uses the full feature list returned above)\n",
"processors = [\n",
" NullFiller(feature_cols=feature_cols, strategy=\"mean\"),\n",
" Winsorizer(feature_cols=feature_cols, lower=0.01, upper=0.99),\n",
" StandardScaler(feature_cols=feature_cols),\n",
"]\n",
"\n",
"# 7. Create the date splitter (proper train/val/test three-way split)\n",
"# Train: fit parameters | Val: validation / early stopping | Test: final evaluation only\n",
"splitter = DateSplitter(\n",
" train_start=TRAIN_START,\n",
" train_end=TRAIN_END,\n",
" val_start=VAL_START,\n",
" val_end=VAL_END,\n",
" test_start=TEST_START,\n",
" test_end=TEST_END,\n",
")\n",
"\n",
"# 8. Create the stock-pool manager\n",
"# New API: pass a custom filter function plus the extra columns it needs\n",
"pool_manager = StockPoolManager(\n",
" filter_func=stock_pool_filter,\n",
" required_columns=STOCK_FILTER_REQUIRED_COLUMNS, # extra base columns the filter needs\n",
" # required_factors=STOCK_FILTER_REQUIRED_FACTORS, # optional: factors the filter needs\n",
" data_router=engine.router,\n",
")\n",
"print(\"[股票池筛选] 使用自定义函数进行股票池筛选\")\n",
"print(f\"[股票池筛选] 所需基础列: {STOCK_FILTER_REQUIRED_COLUMNS}\")\n",
"print(\"[股票池筛选] 筛选逻辑: 排除创业板/科创板/北交所后每日选市值最小的500只\")\n",
"# print(f\"[股票池筛选] 所需因子: {list(STOCK_FILTER_REQUIRED_FACTORS.keys())}\")\n",
"\n",
"# 9. Create the ST stock filter\n",
"st_filter = STFilter(\n",
" data_router=engine.router,\n",
")\n",
"\n",
"# 10. Assemble the trainer\n",
"trainer = Trainer(\n",
" model=model,\n",
" pool_manager=pool_manager,\n",
" processors=processors,\n",
" filters=[st_filter], # drop ST stocks via STFilter\n",
" splitter=splitter,\n",
" target_col=target_col,\n",
" feature_cols=feature_cols,\n",
" persist_model=PERSIST_MODEL,\n",
")"
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### 4.2 执行训练"
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"开始训练\")\n",
"print(\"=\" * 80)\n",
"\n",
"# Step 1: stock-pool filtering.\n",
"# Fix: the pipeline has 7 steps (later cells print 3/7..7/7); counter said 1/6.\n",
"print(\"\\n[步骤 1/7] 股票池筛选\")\n",
"print(\"-\" * 60)\n",
"if pool_manager:\n",
" print(\" 执行每日独立筛选股票池...\")\n",
" filtered_data = pool_manager.filter_and_select_daily(data)\n",
" print(f\" 筛选前数据规模: {data.shape}\")\n",
" print(f\" 筛选后数据规模: {filtered_data.shape}\")\n",
" print(f\" 筛选前股票数: {data['ts_code'].n_unique()}\")\n",
" print(f\" 筛选后股票数: {filtered_data['ts_code'].n_unique()}\")\n",
" print(f\" 删除记录数: {len(data) - len(filtered_data)}\")\n",
"else:\n",
" filtered_data = data\n",
" print(\" 未配置股票池管理器,跳过筛选\")"
]
},
{
"metadata": {},
"cell_type": "code",
"source": [
"# Step 2: split into train/validation/test (proper three-way split).\n",
"# Fix: counter said 2/6 but the pipeline has 7 steps.\n",
"print(\"\\n[步骤 2/7] 划分训练集、验证集和测试集\")\n",
"print(\"-\" * 60)\n",
"if splitter:\n",
" # train fits the model, val drives validation/early stopping, test is final evaluation only\n",
" train_data, val_data, test_data = splitter.split(filtered_data)\n",
" print(f\" 训练集数据规模: {train_data.shape}\")\n",
" print(f\" 验证集数据规模: {val_data.shape}\")\n",
" print(f\" 测试集数据规模: {test_data.shape}\")\n",
" print(f\" 训练集股票数: {train_data['ts_code'].n_unique()}\")\n",
" print(f\" 验证集股票数: {val_data['ts_code'].n_unique()}\")\n",
" print(f\" 测试集股票数: {test_data['ts_code'].n_unique()}\")\n",
" print(\n",
" f\" 训练集日期范围: {train_data['trade_date'].min()} - {train_data['trade_date'].max()}\"\n",
" )\n",
" print(\n",
" f\" 验证集日期范围: {val_data['trade_date'].min()} - {val_data['trade_date'].max()}\"\n",
" )\n",
" print(\n",
" f\" 测试集日期范围: {test_data['trade_date'].min()} - {test_data['trade_date'].max()}\"\n",
" )\n",
"\n",
" print(\"\\n 训练集前5行预览:\")\n",
" print(train_data.head())\n",
" print(\"\\n 验证集前5行预览:\")\n",
" print(val_data.head())\n",
" print(\"\\n 测试集前5行预览:\")\n",
" print(test_data.head())\n",
"else:\n",
" train_data = filtered_data\n",
" val_data = None # fix: downstream cells guard on `val_data is not None`; leaving it undefined would raise NameError\n",
" test_data = filtered_data\n",
" print(\" 未配置划分器,全部作为训练集\")"
],
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"# Step 3: data-quality checks (must run BEFORE any preprocessing such as fillna)\n",
"print(\"\\n[步骤 3/7] 数据质量检查\")\n",
"print(\"-\" * 60)\n",
"print(\" [说明] 此检查在 fillna 等处理之前执行,用于发现数据问题\")\n",
"\n",
"print(\"\\n 检查训练集...\")\n",
"check_data_quality(train_data, feature_cols, raise_on_error=True)\n",
"\n",
"if \"val_data\" in locals() and val_data is not None:\n",
" print(\"\\n 检查验证集...\")\n",
" check_data_quality(val_data, feature_cols, raise_on_error=True)\n",
"\n",
"print(\"\\n 检查测试集...\")\n",
"check_data_quality(test_data, feature_cols, raise_on_error=True)\n",
"\n",
"print(\" [成功] 数据质量检查通过,未发现异常\")\n"
]
},
{
"metadata": {},
"cell_type": "code",
"source": [
"# Step 4: fit the processors on the training set (fit_transform here;\n",
"# the test set only gets transform later, to avoid leakage)\n",
"print(\"\\n[步骤 4/7] 训练集数据处理\")\n",
"print(\"-\" * 60)\n",
"fitted_processors = []\n",
"if processors:\n",
" for i, processor in enumerate(processors, 1):\n",
" print(f\" [{i}/{len(processors)}] 应用处理器: {processor.__class__.__name__}\")\n",
" train_data_before = len(train_data)\n",
" train_data = processor.fit_transform(train_data)\n",
" train_data_after = len(train_data)\n",
" fitted_processors.append(processor)\n",
" print(f\" 处理前记录数: {train_data_before}\")\n",
" print(f\" 处理后记录数: {train_data_after}\")\n",
" if train_data_before != train_data_after:\n",
" print(f\" 删除记录数: {train_data_before - train_data_after}\")\n",
"\n",
"print(\"\\n 训练集处理后前5行预览:\")\n",
"print(train_data.head())\n",
"print(f\"\\n 训练集特征统计:\")\n",
"print(f\" 特征数: {len(feature_cols)}\")\n",
"print(f\" 样本数: {len(train_data)}\")\n",
"print(f\" 缺失值统计:\")\n",
"for col in feature_cols[:5]: # only report nulls for the first 5 features\n",
" null_count = train_data[col].null_count()\n",
" if null_count > 0:\n",
" print(f\" {col}: {null_count} ({null_count / len(train_data) * 100:.2f}%)\")"
],
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"# Step 5: train the model (the original comment said step 4; the print below says 5/7)\n",
"print(\"\\n[步骤 5/7] 训练模型\")\n",
"print(\"-\" * 60)\n",
"print(f\" 模型类型: LightGBM\")\n",
"print(f\" 训练样本数: {len(train_data)}\")\n",
"print(f\" 特征数: {len(feature_cols)}\")\n",
"print(f\" 目标变量: {target_col}\")\n",
"\n",
"X_train = train_data.select(feature_cols)\n",
"y_train = train_data.select(target_col).to_series()\n",
"\n",
"print(f\"\\n 目标变量统计:\")\n",
"print(f\" 均值: {y_train.mean():.6f}\")\n",
"print(f\" 标准差: {y_train.std():.6f}\")\n",
"print(f\" 最小值: {y_train.min():.6f}\")\n",
"print(f\" 最大值: {y_train.max():.6f}\")\n",
"print(f\" 缺失值: {y_train.null_count()}\")\n",
"\n",
"print(\"\\n 开始训练...\")\n",
"model.fit(X_train, y_train)\n",
"print(\" 训练完成!\")"
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"# Step 6: transform the test set with the processors fitted on the training set\n",
"# (transform only, never fit, to avoid look-ahead leakage)\n",
"print(\"\\n[步骤 6/7] 测试集数据处理\")\n",
"print(\"-\" * 60)\n",
"if processors and test_data is not train_data:\n",
" for i, processor in enumerate(fitted_processors, 1):\n",
" print(\n",
" f\" [{i}/{len(fitted_processors)}] 应用处理器: {processor.__class__.__name__}\"\n",
" )\n",
" test_data_before = len(test_data)\n",
" test_data = processor.transform(test_data)\n",
" test_data_after = len(test_data)\n",
" print(f\" 处理前记录数: {test_data_before}\")\n",
" print(f\" 处理后记录数: {test_data_after}\")\n",
"else:\n",
" print(\" 跳过测试集处理\")"
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"# Step 7: generate predictions on the test set\n",
"print(\"\\n[步骤 7/7] 生成预测\")\n",
"print(\"-\" * 60)\n",
"X_test = test_data.select(feature_cols)\n",
"print(f\" 测试样本数: {len(X_test)}\")\n",
"print(\" 预测中...\")\n",
"predictions = model.predict(X_test)\n",
"print(f\" 预测完成!\")\n",
"\n",
"print(f\"\\n 预测结果统计:\")\n",
"print(f\" 均值: {predictions.mean():.6f}\")\n",
"print(f\" 标准差: {predictions.std():.6f}\")\n",
"print(f\" 最小值: {predictions.min():.6f}\")\n",
"print(f\" 最大值: {predictions.max():.6f}\")\n",
"\n",
"# Attach the predictions to the trainer as its results frame\n",
"trainer.results = test_data.with_columns([pl.Series(\"prediction\", predictions)])"
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### 4.3 训练指标曲线"
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"训练指标曲线\")\n",
"print(\"=\" * 80)\n",
"\n",
"# Retrain to collect per-iteration metrics (the earlier fit did not record eval results)\n",
"print(\"\\n重新训练模型以收集训练指标...\")\n",
"\n",
"import lightgbm as lgb\n",
"\n",
"# Prepare arrays (val drives validation; test stays out of the training loop)\n",
"X_train_np = X_train.to_numpy()\n",
"y_train_np = y_train.to_numpy()\n",
"X_val_np = val_data.select(feature_cols).to_numpy()\n",
"y_val_np = val_data.select(target_col).to_series().to_numpy()\n",
"\n",
"# Build LightGBM datasets\n",
"train_dataset = lgb.Dataset(X_train_np, label=y_train_np)\n",
"val_dataset = lgb.Dataset(X_val_np, label=y_val_np, reference=train_dataset)\n",
"\n",
"# Holder for the recorded evaluation metrics\n",
"evals_result = {}\n",
"\n",
"# Retrain with the same parameters as the original model.\n",
"# Proper three-way split: train fits, val validates, test is untouched.\n",
"# Early stopping: halt if the validation metric fails to improve for 100 rounds.\n",
"booster_with_eval = lgb.train(\n",
" MODEL_PARAMS,\n",
" train_dataset,\n",
" num_boost_round=MODEL_PARAMS.get(\"n_estimators\", 100),\n",
" valid_sets=[train_dataset, val_dataset],\n",
" valid_names=[\"train\", \"val\"],\n",
" callbacks=[\n",
" lgb.record_evaluation(evals_result),\n",
" lgb.early_stopping(stopping_rounds=100, verbose=True),\n",
" ],\n",
")\n",
"\n",
"print(\"训练完成,指标已收集\")\n",
"\n",
"# Name of the metric LightGBM recorded\n",
"metric_name = list(evals_result[\"train\"].keys())[0]\n",
"print(f\"\\n评估指标: {metric_name}\")\n",
"\n",
"# Extract the train/validation metric curves\n",
"train_metric = evals_result[\"train\"][metric_name]\n",
"val_metric = evals_result[\"val\"][metric_name]\n",
"\n",
"# Report early-stopping status\n",
"actual_rounds = len(train_metric)\n",
"expected_rounds = MODEL_PARAMS.get(\"n_estimators\", 100)\n",
"print(f\"\\n[早停信息]\")\n",
"print(f\" 配置的最大轮数: {expected_rounds}\")\n",
"print(f\" 实际训练轮数: {actual_rounds}\")\n",
"if actual_rounds < expected_rounds:\n",
" print(f\" 早停状态: 已触发连续100轮验证指标未改善\")\n",
"else:\n",
" print(f\" 早停状态: 未触发(达到最大轮数)\")\n",
"\n",
"print(f\"\\n最终指标:\")\n",
"print(f\" 训练 {metric_name}: {train_metric[-1]:.6f}\")\n",
"print(f\" 验证 {metric_name}: {val_metric[-1]:.6f}\")"
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"# Plot the training/validation metric curves\n",
"import matplotlib.pyplot as plt\n",
"\n",
"fig, ax = plt.subplots(figsize=(12, 6))\n",
"\n",
"# Curves for train and val (val drives early stopping; test is never used here)\n",
"iterations = range(1, len(train_metric) + 1)\n",
"ax.plot(\n",
" iterations, train_metric, label=f\"Train {metric_name}\", linewidth=2, color=\"blue\"\n",
")\n",
"ax.plot(\n",
" iterations, val_metric, label=f\"Validation {metric_name}\", linewidth=2, color=\"red\"\n",
")\n",
"\n",
"ax.set_xlabel(\"Iteration\", fontsize=12)\n",
"ax.set_ylabel(metric_name.upper(), fontsize=12)\n",
"ax.set_title(\n",
" f\"Training and Validation {metric_name.upper()} Curve\",\n",
" fontsize=14,\n",
" fontweight=\"bold\",\n",
")\n",
"ax.legend(fontsize=10)\n",
"ax.grid(True, alpha=0.3)\n",
"\n",
"# Mark the best validation point (the iteration early stopping would pick)\n",
"best_iter = val_metric.index(min(val_metric))\n",
"best_metric = min(val_metric)\n",
"ax.axvline(\n",
" x=best_iter + 1,\n",
" color=\"green\",\n",
" linestyle=\"--\",\n",
" alpha=0.7,\n",
" label=f\"Best Iteration ({best_iter + 1})\",\n",
")\n",
"ax.scatter([best_iter + 1], [best_metric], color=\"green\", s=100, zorder=5)\n",
"ax.annotate(\n",
" f\"Best: {best_metric:.6f}\\nIter: {best_iter + 1}\",\n",
" xy=(best_iter + 1, best_metric),\n",
" xytext=(best_iter + 1 + len(iterations) * 0.1, best_metric),\n",
" fontsize=9,\n",
" arrowprops=dict(arrowstyle=\"->\", color=\"green\", alpha=0.7),\n",
")\n",
"\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"print(\"\\n[指标分析]\")\n",
"print(f\" 最佳验证 {metric_name}: {best_metric:.6f}\")\n",
"print(f\" 最佳迭代轮数: {best_iter + 1}\")\n",
"# Fix: the advice said 10 rounds, but lgb.early_stopping(stopping_rounds=100) is used above\n",
"print(f\" 早停建议: 如果验证指标连续100轮不下降建议在第 {best_iter + 1} 轮停止训练\")\n",
"print(\"\\n[重要提醒] 验证集仅用于早停/调参,测试集完全独立于训练过程!\")"
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### 4.4 查看结果"
},
{
"metadata": {},
"cell_type": "code",
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"训练结果\")\n",
"print(\"=\" * 80)\n",
"\n",
"results = trainer.results\n",
"\n",
"print(f\"\\n结果数据形状: {results.shape}\")\n",
"print(f\"结果列: {results.columns}\")\n",
"print(f\"\\n结果前10行预览:\")\n",
"print(results.head(10))\n",
"print(f\"\\n结果后5行预览:\")\n",
"print(results.tail())\n",
"\n",
"print(f\"\\n每日预测样本数统计:\")\n",
"daily_counts = results.group_by(\"trade_date\").agg(pl.len()).sort(\"trade_date\")\n",
"print(f\" 最小: {daily_counts['len'].min()}\")\n",
"print(f\" 最大: {daily_counts['len'].max()}\")\n",
"print(f\" 平均: {daily_counts['len'].mean():.2f}\")\n",
"\n",
"# Show the first 10 predictions for one sample date\n",
"sample_date = results[\"trade_date\"][0]\n",
"sample_data = results.filter(results[\"trade_date\"] == sample_date).head(10)\n",
"print(f\"\\n示例日期 {sample_date} 的前10条预测:\")\n",
"print(sample_data.select([\"ts_code\", \"trade_date\", target_col, \"prediction\"]))"
],
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### 4.4 保存结果"
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"保存预测结果\")\n",
"print(\"=\" * 80)\n",
"\n",
"# Make sure the output directory exists\n",
"os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
"\n",
"# Save the daily Top N picks (fixed output name; removed an unused\n",
"# date_str computation and a redundant f-string prefix)\n",
"print(f\"\\n[1/1] 保存每日 Top {TOP_N} 股票...\")\n",
"topn_output_path = os.path.join(OUTPUT_DIR, \"regression_output.csv\")\n",
"\n",
"# Group by date and keep the N highest-scoring stocks per day\n",
"topn_by_date = []\n",
"unique_dates = results[\"trade_date\"].unique().sort()\n",
"for date in unique_dates:\n",
" day_data = results.filter(results[\"trade_date\"] == date)\n",
" # Sort by prediction descending, keep the top N\n",
" topn = day_data.sort(\"prediction\", descending=True).head(TOP_N)\n",
" topn_by_date.append(topn)\n",
"\n",
"# Concatenate all dates' top-N frames\n",
"topn_results = pl.concat(topn_by_date)\n",
"\n",
"# Format the date as YYYY-MM-DD and order columns: date, score, ts_code.\n",
"# Fix: parentheses are required so .alias() names the whole concatenated\n",
"# expression; previously it only aliased the last slice, so the output\n",
"# column kept the name 'trade_date' instead of 'date'.\n",
"topn_to_save = topn_results.select(\n",
" [\n",
" (\n",
" pl.col(\"trade_date\").str.slice(0, 4)\n",
" + \"-\"\n",
" + pl.col(\"trade_date\").str.slice(4, 2)\n",
" + \"-\"\n",
" + pl.col(\"trade_date\").str.slice(6, 2)\n",
" ).alias(\"date\"),\n",
" pl.col(\"prediction\").alias(\"score\"),\n",
" pl.col(\"ts_code\"),\n",
" ]\n",
")\n",
"topn_to_save.write_csv(topn_output_path, include_header=True)\n",
"print(f\" 保存路径: {topn_output_path}\")\n",
"print(\n",
" f\" 保存行数: {len(topn_to_save)}{len(unique_dates)}个交易日 × 每日top{TOP_N}\"\n",
")\n",
"print(\"\\n 预览前15行:\")\n",
"print(topn_to_save.head(15))"
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### 4.5 特征重要性"
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"# Report feature importance from the trained model (if the wrapper exposes it)\n",
"importance = model.feature_importance()\n",
"if importance is not None:\n",
" print(\"\\n特征重要性:\")\n",
" print(importance.sort_values(ascending=False))\n",
"\n",
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"训练完成!\")\n",
"print(\"=\" * 80)"
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": [
"## 5. 可视化分析\n",
"#\n",
"使用训练好的模型直接绘图。\n",
"- **特征重要性图**:辅助特征选择\n",
"- **决策树图**:理解决策逻辑"
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"# Visualization imports\n",
"import matplotlib.pyplot as plt\n",
"import lightgbm as lgb\n",
"import pandas as pd\n",
"\n",
"# Unwrap the underlying LightGBM booster from the model wrapper\n",
"booster = model.model\n",
"print(f\"模型类型: {type(booster)}\")\n",
"print(f\"特征数量: {len(feature_cols)}\")"
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": [
"### 5.1 绘制特征重要性(辅助特征选择)\n",
"#\n",
"**解读**\n",
"- 重要性高的特征对模型贡献大\n",
"- 重要性为0的特征可以考虑删除\n",
"- 可以帮助理解哪些因子最有效"
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"print(\"绘制特征重要性...\")\n",
"\n",
"fig, ax = plt.subplots(figsize=(10, 8))\n",
"lgb.plot_importance(\n",
" booster,\n",
" max_num_features=20,\n",
" importance_type=\"gain\",\n",
" title=\"Feature Importance (Gain)\",\n",
" ax=ax,\n",
")\n",
"ax.set_xlabel(\"Importance (Gain)\")\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"# Print the gain-based importance ranking\n",
"importance_gain = pd.Series(\n",
" booster.feature_importance(importance_type=\"gain\"), index=feature_cols\n",
").sort_values(ascending=False)\n",
"\n",
"print(\"\\n[特征重要性排名 - Gain]\")\n",
"print(importance_gain)\n",
"\n",
"# Identify zero-importance features (candidates for removal)\n",
"zero_importance = importance_gain[importance_gain == 0].index.tolist()\n",
"if zero_importance:\n",
" print(f\"\\n[低重要性特征] 以下{len(zero_importance)}个特征重要性为0可考虑删除:\")\n",
" for feat in zero_importance:\n",
" print(f\" - {feat}\")\n",
"else:\n",
" print(\"\\n所有特征都有一定重要性\")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}