Files
ProStock/src/experiment/regression.ipynb

1595 lines
194 KiB
Plaintext
Raw Normal View History

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. 导入依赖"
]
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T06:10:19.495013Z",
"start_time": "2026-03-08T06:10:18.913682Z"
}
},
"source": [
"import os\n",
"from datetime import datetime\n",
"from typing import List\n",
"\n",
"import polars as pl\n",
"\n",
"from src.factors import FactorEngine\n",
"from src.training import (\n",
" DateSplitter,\n",
" LightGBMModel,\n",
" STFilter,\n",
" StandardScaler,\n",
" StockFilterConfig,\n",
" StockPoolManager,\n",
" Trainer,\n",
" Winsorizer,\n",
" NullFiller,\n",
")\n",
"from src.training.config import TrainingConfig"
],
"outputs": [],
"execution_count": 1
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. 定义辅助函数"
]
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T06:10:19.508163Z",
"start_time": "2026-03-08T06:10:19.502562Z"
}
},
"cell_type": "code",
"source": [
"def create_factors_with_strings(engine: FactorEngine, factor_definitions: dict, label_factor: dict) -> List[str]:\n",
" print(\"=\" * 80)\n",
" print(\"使用字符串表达式定义因子\")\n",
" print(\"=\" * 80)\n",
"\n",
" # 注册所有特征因子\n",
" print(\"\\n注册特征因子:\")\n",
" for name, expr in factor_definitions.items():\n",
" engine.add_factor(name, expr)\n",
" print(f\" - {name}: {expr}\")\n",
"\n",
" # 注册 label 因子\n",
" print(\"\\n注册 Label 因子:\")\n",
" for name, expr in label_factor.items():\n",
" engine.add_factor(name, expr)\n",
" print(f\" - {name}: {expr}\")\n",
"\n",
" # 从字典自动获取特征列\n",
" feature_cols = list(factor_definitions.keys())\n",
"\n",
" print(f\"\\n特征因子数: {len(feature_cols)}\")\n",
" print(f\"Label: {list(label_factor.keys())[0]}\")\n",
" print(f\"已注册因子总数: {len(engine.list_registered())}\")\n",
"\n",
" return feature_cols\n",
"\n",
"\n",
"def prepare_data(\n",
" engine: FactorEngine,\n",
" feature_cols: List[str],\n",
" start_date: str,\n",
" end_date: str,\n",
") -> pl.DataFrame:\n",
" print(\"\\n\" + \"=\" * 80)\n",
" print(\"准备数据\")\n",
" print(\"=\" * 80)\n",
"\n",
" # 计算因子(全市场数据)\n",
" print(f\"\\n计算因子: {start_date} - {end_date}\")\n",
" factor_names = feature_cols + [LABEL_NAME] # 包含 label\n",
"\n",
" data = engine.compute(\n",
" factor_names=factor_names,\n",
" start_date=start_date,\n",
" end_date=end_date,\n",
" )\n",
"\n",
" print(f\"数据形状: {data.shape}\")\n",
" print(f\"数据列: {data.columns}\")\n",
" print(f\"\\n前5行预览:\")\n",
" print(data.head())\n",
"\n",
" return data"
],
"outputs": [],
"execution_count": 2
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. 配置参数\n",
"\n",
"### 3.1 因子定义"
]
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T06:10:19.519954Z",
"start_time": "2026-03-08T06:10:19.515047Z"
}
},
"cell_type": "code",
"source": [
"# 特征因子定义字典:新增因子只需在此处添加一行\n",
"LABEL_NAME = 'future_return_5'\n",
"\n",
"FACTOR_DEFINITIONS = {\n",
" # 1. 价格动量因子\n",
" \"ma5\": \"ts_mean(close, 5)\",\n",
" \"ma10\": \"ts_mean(close, 10)\",\n",
" \"ma20\": \"ts_mean(close, 20)\",\n",
" \"ma_ratio\": \"ts_mean(close, 5) / ts_mean(close, 20) - 1\",\n",
" # 2. 波动率因子\n",
" \"volatility_5\": \"ts_std(close, 5)\",\n",
" \"volatility_20\": \"ts_std(close, 20)\",\n",
" \"vol_ratio\": \"ts_std(close, 5) / (ts_std(close, 20) + 1e-8)\",\n",
" # 3. 收益率动量因子\n",
" \"return_10\": \"(close / ts_delay(close, 10)) - 1\",\n",
" \"return_20\": \"(close / ts_delay(close, 20)) - 1\",\n",
" # 4. 收益率变化因子\n",
" \"return_diff\": \"(close / ts_delay(close, 5)) - 1 - ((close / ts_delay(close, 10)) - 1)\",\n",
" # 5. 成交量因子\n",
" \"vol_ma5\": \"ts_mean(vol, 5)\",\n",
" \"vol_ma20\": \"ts_mean(vol, 20)\",\n",
" \"vol_ratio\": \"ts_mean(vol, 5) / (ts_mean(vol, 20) + 1e-8)\",\n",
" # 6. 市值因子(截面排名)\n",
" \"market_cap_rank\": \"cs_rank(total_mv)\",\n",
" # 7. 价格位置因子\n",
" \"high_low_ratio\": \"(close - ts_min(low, 20)) / (ts_max(high, 20) - ts_min(low, 20) + 1e-8)\",\n",
" # 8. 技术指标因子3.1 完全可实现)\n",
" \"turnover_rate_mean_5\": \"ts_mean(turnover_rate, 5)\", # 5日均换手率\n",
" \"bbi_ratio_factor\": \"(ts_mean(close, 3) + ts_mean(close, 6) + ts_mean(close, 12) + ts_mean(close, 24)) / 4 / close\", # BBI比率\n",
" # 9. ARBR 因子3.1 完全可实现)\n",
" # \"AR\": \"ts_sum(high - open, 26) / ts_sum(open - low, 26) * 100\", # AR人气指标\n",
" # \"BR\": \"ts_sum(max_(0, high - ts_delay(close, 1)), 26) / ts_sum(max_(0, ts_delay(close, 1) - low), 26) * 100\", # BR意愿指标\n",
" # \"AR_BR\": \"AR - BR\", # ARBR差值\n",
" # 10. 成交量因子3.1 完全可实现)\n",
" \"volume_change_rate\": \"ts_mean(vol, 2) / ts_mean(vol, 10) - 1\", # 成交量变化率\n",
" \"turnover_deviation\": \"(turnover_rate - ts_mean(turnover_rate, 3)) / ts_std(turnover_rate, 3)\", # 换手率偏离度\n",
" \"vol_std_5\": \"ts_std(ts_delta(vol, 1), 5)\", # 成交量变化标准差\n",
" # 11. 收益率因子3.1 完全可实现)\n",
" # \"return_5\": \"close / ts_delay(close, 5) - 1\", # 5日收益率\n",
" \"std_return_5\": \"ts_std((close - ts_delay(close, 1)) / ts_delay(close, 1), 5)\", # 5日收益率标准差\n",
" \"std_return_90\": \"ts_std((close - ts_delay(close, 1)) / ts_delay(close, 1), 90)\", # 90日收益率标准差\n",
" # 12. 截面排序因子3.1 完全可实现)\n",
" \"cs_rank_volume_ratio\": \"cs_rank(volume_ratio)\", # 量比截面排名\n",
" \"cs_rank_turnover_rate\": \"cs_rank(turnover_rate)\", # 换手率截面排名\n",
" \"n_income_rank\": \"cs_rank(n_income)\", # 净利润截面排名\n",
" # 13. 财务数据因子(来自利润表 financial_income\n",
" \"operate_profit_rank\": \"cs_rank(operate_profit)\", # 营业利润截面排名\n",
" \"total_profit_rank\": \"cs_rank(total_profit)\", # 利润总额截面排名\n",
" \"ebit_rank\": \"cs_rank(ebit)\", # 息税前利润截面排名\n",
" \"ebitda_rank\": \"cs_rank(ebitda)\", # 息税折旧摊销前利润截面排名\n",
" # 14. 财务数据因子(来自资产负债表 financial_balance\n",
" \"total_liab_rank\": \"cs_rank(total_liab)\", # 总负债截面排名\n",
" \"money_cap_rank\": \"cs_rank(money_cap)\", # 货币资金截面排名\n",
" # 15. 财务数据因子(来自现金流量表 financial_cashflow\n",
" \"n_cashflow_act_rank\": \"cs_rank(n_cashflow_act)\", # 经营活动现金流净额截面排名\n",
" # 16. 财务估值因子\n",
" \"profit_to_market_cap\": \"n_income / (total_mv + 1e-8)\", # 净利润率(净利润/市值)\n",
" \"cashflow_to_market_cap\": \"n_cashflow_act / (total_mv + 1e-8)\", # 经营现金流/市值\n",
" \"operate_profit_to_market_cap\": \"operate_profit / (total_mv + 1e-8)\", # 营业利润/市值\n",
"}\n",
"\n",
"# Label 因子定义(不参与训练,用于计算目标)\n",
"LABEL_FACTOR = {\n",
" LABEL_NAME: \"(ts_delay(close, -5) / close) - 1\", # 未来5日收益率\n",
"}"
],
"outputs": [],
"execution_count": 3
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3.2 训练参数配置"
]
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T06:10:19.527636Z",
"start_time": "2026-03-08T06:10:19.523669Z"
}
},
"source": [
"# 日期范围配置(正确的 train/val/test 三分法)\n",
"# Train: 用于训练模型参数\n",
"# Val: 用于验证/早停/调参(位于 train 之后test 之前)\n",
"# Test: 仅用于最终评估,完全独立于训练过程\n",
"TRAIN_START = \"20200101\"\n",
"TRAIN_END = \"20231231\"\n",
"VAL_START = \"20240101\"\n",
"VAL_END = \"20241231\"\n",
"TEST_START = \"20250101\"\n",
"TEST_END = \"20261231\"\n",
"\n",
"# 模型参数配置\n",
"MODEL_PARAMS = {\n",
" \"objective\": \"regression\",\n",
" \"metric\": \"mae\", # 改为 MAE对异常值更稳健\n",
" # 树结构控制(防过拟合核心)\n",
" \"num_leaves\": 20, # 从31降为20降低模型复杂度\n",
" \"max_depth\": 4, # 显式限制深度,防止过度拟合噪声\n",
" \"min_child_samples\": 50, # 叶子最小样本数,防止学习极端样本\n",
" \"min_child_weight\": 0.001,\n",
" # 学习参数\n",
" \"learning_rate\": 0.01, # 降低学习率,配合更多树\n",
" \"n_estimators\": 1000, # 增加树数量,配合早停\n",
" # 采样策略(关键防过拟合)\n",
" \"subsample\": 0.8, # 每棵树随机采样80%数据(行采样)\n",
" \"subsample_freq\": 5, # 每5轮迭代进行一次 subsample\n",
" \"colsample_bytree\": 0.8, # 每棵树随机选择80%特征(列采样)\n",
" # 正则化\n",
" \"reg_alpha\": 0.1, # L1正则增加稀疏性\n",
" \"reg_lambda\": 1.0, # L2正则平滑权重\n",
" # 数值稳定性\n",
" \"verbose\": -1,\n",
" \"random_state\": 42,\n",
"}\n",
"\n",
"# 数据处理器配置\n",
"PROCESSOR_CONFIGS = [\n",
" {\"name\": \"winsorizer\", \"params\": {\"lower\": 0.01, \"upper\": 0.99}},\n",
" {\"name\": \"cs_standard_scaler\", \"params\": {}},\n",
"]\n",
"\n",
"# 股票池筛选配置\n",
"STOCK_FILTER_CONFIG = {\n",
" \"exclude_cyb\": True, # 排除创业板\n",
" \"exclude_kcb\": True, # 排除科创板\n",
" \"exclude_bj\": True, # 排除北交所\n",
" \"exclude_st\": True, # 排除ST股票\n",
"}\n",
"\n",
"# 输出配置(相对于本文件所在目录)\n",
"OUTPUT_DIR = \"output\"\n",
"SAVE_PREDICTIONS = True\n",
"PERSIST_MODEL = False\n",
"\n",
"# Top N 配置:每日推荐股票数量\n",
"TOP_N = 2 # 可调整为 10, 20 等"
],
"outputs": [],
"execution_count": 4
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. 训练流程\n",
"\n",
"### 4.1 初始化组件"
]
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T06:10:27.660787Z",
"start_time": "2026-03-08T06:10:19.531220Z"
}
},
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"LightGBM 回归模型训练\")\n",
"print(\"=\" * 80)\n",
"\n",
"# 1. 创建 FactorEngine\n",
"print(\"\\n[1] 创建 FactorEngine\")\n",
"engine = FactorEngine()\n",
"\n",
"# 2. 使用字符串表达式定义因子\n",
"print(\"\\n[2] 定义因子(字符串表达式)\")\n",
"feature_cols = create_factors_with_strings(engine, FACTOR_DEFINITIONS, LABEL_FACTOR)\n",
"target_col = LABEL_NAME\n",
"\n",
"# 3. 准备数据(使用模块级别的日期配置)\n",
"print(\"\\n[3] 准备数据\")\n",
"\n",
"data = prepare_data(\n",
" engine=engine,\n",
" feature_cols=feature_cols,\n",
" start_date=TRAIN_START,\n",
" end_date=TEST_END,\n",
")\n",
"\n",
"# 4. 打印配置信息\n",
"print(f\"\\n[配置] 训练期: {TRAIN_START} - {TRAIN_END}\")\n",
"print(f\"[配置] 验证期: {VAL_START} - {VAL_END}\")\n",
"print(f\"[配置] 测试期: {TEST_START} - {TEST_END}\")\n",
"print(f\"[配置] 特征数: {len(feature_cols)}\")\n",
"print(f\"[配置] 目标变量: {target_col}\")\n",
"\n",
"# 5. 创建模型\n",
"model = LightGBMModel(params=MODEL_PARAMS)\n",
"\n",
"# 6. 创建数据处理器\n",
"processors = [\n",
" NullFiller(strategy=\"mean\"),\n",
" Winsorizer(**PROCESSOR_CONFIGS[0][\"params\"]),\n",
" StandardScaler(exclude_cols=[\"ts_code\", \"trade_date\", target_col]),\n",
"]\n",
"\n",
"# 7. 创建数据划分器(正确的 train/val/test 三分法)\n",
"# Train: 训练模型参数 | Val: 验证/早停 | Test: 最终评估\n",
"splitter = DateSplitter(\n",
" train_start=TRAIN_START,\n",
" train_end=TRAIN_END,\n",
" val_start=VAL_START,\n",
" val_end=VAL_END,\n",
" test_start=TEST_START,\n",
" test_end=TEST_END,\n",
")\n",
"\n",
"# 8. 创建股票池管理器\n",
"pool_manager = StockPoolManager(\n",
" filter_config=StockFilterConfig(**STOCK_FILTER_CONFIG),\n",
" selector_config=None, # 暂时不启用市值选择\n",
" data_router=engine.router,\n",
")\n",
"\n",
"# 9. 创建 ST 股票过滤器\n",
"st_filter = STFilter(\n",
" data_router=engine.router,\n",
")\n",
"\n",
"# 10. 创建训练器\n",
"trainer = Trainer(\n",
" model=model,\n",
" pool_manager=pool_manager,\n",
" processors=processors,\n",
" filters=[st_filter],\n",
" splitter=splitter,\n",
" target_col=target_col,\n",
" feature_cols=feature_cols,\n",
" persist_model=PERSIST_MODEL,\n",
")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"================================================================================\n",
"LightGBM 回归模型训练\n",
"================================================================================\n",
"\n",
"[1] 创建 FactorEngine\n",
"\n",
"[2] 定义因子(字符串表达式)\n",
"================================================================================\n",
"使用字符串表达式定义因子\n",
"================================================================================\n",
"\n",
"注册特征因子:\n",
" - ma5: ts_mean(close, 5)\n",
" - ma10: ts_mean(close, 10)\n",
" - ma20: ts_mean(close, 20)\n",
" - ma_ratio: ts_mean(close, 5) / ts_mean(close, 20) - 1\n",
" - volatility_5: ts_std(close, 5)\n",
" - volatility_20: ts_std(close, 20)\n",
" - vol_ratio: ts_mean(vol, 5) / (ts_mean(vol, 20) + 1e-8)\n",
" - return_10: (close / ts_delay(close, 10)) - 1\n",
" - return_20: (close / ts_delay(close, 20)) - 1\n",
" - return_diff: (close / ts_delay(close, 5)) - 1 - ((close / ts_delay(close, 10)) - 1)\n",
" - vol_ma5: ts_mean(vol, 5)\n",
" - vol_ma20: ts_mean(vol, 20)\n",
" - market_cap_rank: cs_rank(total_mv)\n",
" - high_low_ratio: (close - ts_min(low, 20)) / (ts_max(high, 20) - ts_min(low, 20) + 1e-8)\n",
" - turnover_rate_mean_5: ts_mean(turnover_rate, 5)\n",
" - bbi_ratio_factor: (ts_mean(close, 3) + ts_mean(close, 6) + ts_mean(close, 12) + ts_mean(close, 24)) / 4 / close\n",
" - volume_change_rate: ts_mean(vol, 2) / ts_mean(vol, 10) - 1\n",
" - turnover_deviation: (turnover_rate - ts_mean(turnover_rate, 3)) / ts_std(turnover_rate, 3)\n",
" - vol_std_5: ts_std(ts_delta(vol, 1), 5)\n",
" - std_return_5: ts_std((close - ts_delay(close, 1)) / ts_delay(close, 1), 5)\n",
" - std_return_90: ts_std((close - ts_delay(close, 1)) / ts_delay(close, 1), 90)\n",
" - cs_rank_volume_ratio: cs_rank(volume_ratio)\n",
" - cs_rank_turnover_rate: cs_rank(turnover_rate)\n",
" - n_income_rank: cs_rank(n_income)\n",
" - operate_profit_rank: cs_rank(operate_profit)\n",
" - total_profit_rank: cs_rank(total_profit)\n",
" - ebit_rank: cs_rank(ebit)\n",
" - ebitda_rank: cs_rank(ebitda)\n",
" - total_liab_rank: cs_rank(total_liab)\n",
" - money_cap_rank: cs_rank(money_cap)\n",
" - n_cashflow_act_rank: cs_rank(n_cashflow_act)\n",
" - profit_to_market_cap: n_income / (total_mv + 1e-8)\n",
" - cashflow_to_market_cap: n_cashflow_act / (total_mv + 1e-8)\n",
" - operate_profit_to_market_cap: operate_profit / (total_mv + 1e-8)\n",
"\n",
"注册 Label 因子:\n",
" - future_return_5: (ts_delay(close, -5) / close) - 1\n",
"\n",
"特征因子数: 34\n",
"Label: future_return_5\n",
"已注册因子总数: 35\n",
"\n",
"[3] 准备数据\n",
"\n",
"================================================================================\n",
"准备数据\n",
"================================================================================\n",
"\n",
"计算因子: 20200101 - 20261231\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"D:\\PyProject\\ProStock\\src\\data\\financial_loader.py:123: UserWarning: Sortedness of columns cannot be checked when 'by' groups provided\n",
" merged = df_price.join_asof(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"数据形状: (7255513, 53)\n",
"数据列: ['ts_code', 'trade_date', 'vol', 'low', 'high', 'volume_ratio', 'turnover_rate', 'close', 'total_mv', 'f_ann_date', 'total_profit', 'ebit', 'operate_profit', 'ebitda', 'n_income', 'money_cap', 'total_liab', 'n_cashflow_act', 'ma5', 'ma10', 'ma20', 'ma_ratio', 'volatility_5', 'volatility_20', 'vol_ratio', 'return_10', 'return_20', 'return_diff', 'vol_ma5', 'vol_ma20', 'market_cap_rank', 'high_low_ratio', 'turnover_rate_mean_5', 'bbi_ratio_factor', 'volume_change_rate', 'turnover_deviation', 'vol_std_5', 'std_return_5', 'std_return_90', 'cs_rank_volume_ratio', 'cs_rank_turnover_rate', 'n_income_rank', 'operate_profit_rank', 'total_profit_rank', 'ebit_rank', 'ebitda_rank', 'total_liab_rank', 'money_cap_rank', 'n_cashflow_act_rank', 'profit_to_market_cap', 'cashflow_to_market_cap', 'operate_profit_to_market_cap', 'future_return_5']\n",
"\n",
"前5行预览:\n",
"shape: (5, 53)\n",
"┌───────────┬────────────┬───────────┬─────────┬───┬───────────┬───────────┬───────────┬───────────┐\n",
"│ ts_code ┆ trade_date ┆ vol ┆ low ┆ … ┆ profit_to ┆ cashflow_ ┆ operate_p ┆ future_re │\n",
"│ --- ┆ --- ┆ --- ┆ --- ┆ ┆ _market_c ┆ to_market ┆ rofit_to_ ┆ turn_5 │\n",
"│ str ┆ str ┆ f64 ┆ f64 ┆ ┆ ap ┆ _cap ┆ market_ca ┆ --- │\n",
"│ ┆ ┆ ┆ ┆ ┆ --- ┆ --- ┆ p ┆ f64 │\n",
"│ ┆ ┆ ┆ ┆ ┆ f64 ┆ f64 ┆ --- ┆ │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ f64 ┆ │\n",
"╞═══════════╪════════════╪═══════════╪═════════╪═══╪═══════════╪═══════════╪═══════════╪═══════════╡\n",
"│ 000001.SZ ┆ 20200102 ┆ 1.5302e6 ┆ 1806.75 ┆ … ┆ 721.52104 ┆ 2580.5045 ┆ 938.15146 ┆ -0.004746 │\n",
"│ ┆ ┆ ┆ ┆ ┆ 1 ┆ 33 ┆ 4 ┆ │\n",
"│ 000001.SZ ┆ 20200103 ┆ 1.1162e6 ┆ 1847.15 ┆ … ┆ 708.50174 ┆ 2533.9412 ┆ 921.22323 ┆ -0.02852 │\n",
"│ ┆ ┆ ┆ ┆ ┆ 4 ┆ 96 ┆ 7 ┆ │\n",
"│ 000001.SZ ┆ 20200106 ┆ 862083.5 ┆ 1846.05 ┆ … ┆ 713.06736 ┆ 2550.2701 ┆ 927.15964 ┆ -0.004685 │\n",
"│ ┆ ┆ ┆ ┆ ┆ 8 ┆ 5 ┆ 9 ┆ │\n",
"│ 000001.SZ ┆ 20200107 ┆ 728607.56 ┆ 1850.42 ┆ … ┆ 709.74110 ┆ 2538.3738 ┆ 922.83470 ┆ -0.022743 │\n",
"│ ┆ ┆ ┆ ┆ ┆ 6 ┆ 46 ┆ 6 ┆ │\n",
"│ 000001.SZ ┆ 20200108 ┆ 847824.12 ┆ 1815.49 ┆ … ┆ 730.61584 ┆ 2613.0319 ┆ 949.97690 ┆ -0.008401 │\n",
"│ ┆ ┆ ┆ ┆ ┆ 4 ┆ 01 ┆ 3 ┆ │\n",
"└───────────┴────────────┴───────────┴─────────┴───┴───────────┴───────────┴───────────┴───────────┘\n",
"\n",
"[配置] 训练期: 20200101 - 20231231\n",
"[配置] 验证期: 20240101 - 20241231\n",
"[配置] 测试期: 20250101 - 20261231\n",
"[配置] 特征数: 34\n",
"[配置] 目标变量: future_return_5\n"
]
}
],
"execution_count": 5
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4.2 执行训练"
]
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T06:10:37.221864Z",
"start_time": "2026-03-08T06:10:27.670279Z"
}
},
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"开始训练\")\n",
"print(\"=\" * 80)\n",
"\n",
"# 步骤 1: 股票池筛选\n",
"print(\"\\n[步骤 1/6] 股票池筛选\")\n",
"print(\"-\" * 60)\n",
"if pool_manager:\n",
" print(\" 执行每日独立筛选股票池...\")\n",
" filtered_data = pool_manager.filter_and_select_daily(data)\n",
" print(f\" 筛选前数据规模: {data.shape}\")\n",
" print(f\" 筛选后数据规模: {filtered_data.shape}\")\n",
" print(f\" 筛选前股票数: {data['ts_code'].n_unique()}\")\n",
" print(f\" 筛选后股票数: {filtered_data['ts_code'].n_unique()}\")\n",
" print(f\" 删除记录数: {len(data) - len(filtered_data)}\")\n",
"else:\n",
" filtered_data = data\n",
" print(\" 未配置股票池管理器,跳过筛选\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"================================================================================\n",
"开始训练\n",
"================================================================================\n",
"\n",
"[步骤 1/6] 股票池筛选\n",
"------------------------------------------------------------\n",
" 执行每日独立筛选股票池...\n",
" 筛选前数据规模: (7255513, 53)\n",
" 筛选后数据规模: (4654331, 53)\n",
" 筛选前股票数: 5694\n",
" 筛选后股票数: 3364\n",
" 删除记录数: 2601182\n"
]
}
],
"execution_count": 6
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T06:10:39.192697Z",
"start_time": "2026-03-08T06:10:37.226314Z"
}
},
"source": [
"# 步骤 2: 划分训练/验证/测试集(正确的三分法)\n",
"print(\"\\n[步骤 2/6] 划分训练集、验证集和测试集\")\n",
"print(\"-\" * 60)\n",
"if splitter:\n",
" # 正确的三分法train用于训练val用于验证/早停test仅用于最终评估\n",
" train_data, val_data, test_data = splitter.split(filtered_data)\n",
" print(f\" 训练集数据规模: {train_data.shape}\")\n",
" print(f\" 验证集数据规模: {val_data.shape}\")\n",
" print(f\" 测试集数据规模: {test_data.shape}\")\n",
" print(f\" 训练集股票数: {train_data['ts_code'].n_unique()}\")\n",
" print(f\" 验证集股票数: {val_data['ts_code'].n_unique()}\")\n",
" print(f\" 测试集股票数: {test_data['ts_code'].n_unique()}\")\n",
" print(\n",
" f\" 训练集日期范围: {train_data['trade_date'].min()} - {train_data['trade_date'].max()}\"\n",
" )\n",
" print(\n",
" f\" 验证集日期范围: {val_data['trade_date'].min()} - {val_data['trade_date'].max()}\"\n",
" )\n",
" print(\n",
" f\" 测试集日期范围: {test_data['trade_date'].min()} - {test_data['trade_date'].max()}\"\n",
" )\n",
"\n",
" print(\"\\n 训练集前5行预览:\")\n",
" print(train_data.head())\n",
" print(\"\\n 验证集前5行预览:\")\n",
" print(val_data.head())\n",
" print(\"\\n 测试集前5行预览:\")\n",
" print(test_data.head())\n",
"else:\n",
" train_data = filtered_data\n",
" test_data = filtered_data\n",
" print(\" 未配置划分器,全部作为训练集\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"[步骤 2/6] 划分训练集、验证集和测试集\n",
"------------------------------------------------------------\n",
" 训练集数据规模: (2991506, 53)\n",
" 验证集数据规模: (769485, 53)\n",
" 测试集数据规模: (893340, 53)\n",
" 训练集股票数: 3297\n",
" 验证集股票数: 3220\n",
" 测试集股票数: 3220\n",
" 训练集日期范围: 20200102 - 20231229\n",
" 验证集日期范围: 20240102 - 20241231\n",
" 测试集日期范围: 20250102 - 20260306\n",
"\n",
" 训练集前5行预览:\n",
"shape: (5, 53)\n",
"┌───────────┬────────────┬───────────┬─────────┬───┬───────────┬───────────┬───────────┬───────────┐\n",
"│ ts_code ┆ trade_date ┆ vol ┆ low ┆ … ┆ profit_to ┆ cashflow_ ┆ operate_p ┆ future_re │\n",
"│ --- ┆ --- ┆ --- ┆ --- ┆ ┆ _market_c ┆ to_market ┆ rofit_to_ ┆ turn_5 │\n",
"│ str ┆ str ┆ f64 ┆ f64 ┆ ┆ ap ┆ _cap ┆ market_ca ┆ --- │\n",
"│ ┆ ┆ ┆ ┆ ┆ --- ┆ --- ┆ p ┆ f64 │\n",
"│ ┆ ┆ ┆ ┆ ┆ f64 ┆ f64 ┆ --- ┆ │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ f64 ┆ │\n",
"╞═══════════╪════════════╪═══════════╪═════════╪═══╪═══════════╪═══════════╪═══════════╪═══════════╡\n",
"│ 000001.SZ ┆ 20200102 ┆ 1.5302e6 ┆ 1806.75 ┆ … ┆ 721.52104 ┆ 2580.5045 ┆ 938.15146 ┆ -0.004746 │\n",
"│ ┆ ┆ ┆ ┆ ┆ 1 ┆ 33 ┆ 4 ┆ │\n",
"│ 000002.SZ ┆ 20200102 ┆ 1012130.4 ┆ 4824.87 ┆ … ┆ 776.91820 ┆ 47.131053 ┆ 1140.2493 ┆ -0.011057 │\n",
"│ ┆ ┆ ┆ ┆ ┆ 1 ┆ ┆ 95 ┆ │\n",
"│ 000004.SZ ┆ 20200102 ┆ 17853.2 ┆ 90.1 ┆ … ┆ -69.58089 ┆ -52.61755 ┆ -24.82135 ┆ -0.000441 │\n",
"│ ┆ ┆ ┆ ┆ ┆ 5 ┆ 4 ┆ 9 ┆ │\n",
"│ 000005.SZ ┆ 20200102 ┆ 104134.12 ┆ 28.82 ┆ … ┆ 142.55925 ┆ 385.57490 ┆ 208.12520 ┆ 0.022337 │\n",
"│ ┆ ┆ ┆ ┆ ┆ 6 ┆ 4 ┆ 2 ┆ │\n",
"│ 000006.SZ ┆ 20200102 ┆ 124751.76 ┆ 190.24 ┆ … ┆ 633.27582 ┆ 650.95370 ┆ 819.10495 ┆ 0.012964 │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ 3 ┆ 5 ┆ │\n",
"└───────────┴────────────┴───────────┴─────────┴───┴───────────┴───────────┴───────────┴───────────┘\n",
"\n",
" 验证集前5行预览:\n",
"shape: (5, 53)\n",
"┌───────────┬────────────┬───────────┬─────────┬───┬───────────┬───────────┬───────────┬───────────┐\n",
"│ ts_code ┆ trade_date ┆ vol ┆ low ┆ … ┆ profit_to ┆ cashflow_ ┆ operate_p ┆ future_re │\n",
"│ --- ┆ --- ┆ --- ┆ --- ┆ ┆ _market_c ┆ to_market ┆ rofit_to_ ┆ turn_5 │\n",
"│ str ┆ str ┆ f64 ┆ f64 ┆ ┆ ap ┆ _cap ┆ market_ca ┆ --- │\n",
"│ ┆ ┆ ┆ ┆ ┆ --- ┆ --- ┆ p ┆ f64 │\n",
"│ ┆ ┆ ┆ ┆ ┆ f64 ┆ f64 ┆ --- ┆ │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ f64 ┆ │\n",
"╞═══════════╪════════════╪═══════════╪═════════╪═══╪═══════════╪═══════════╪═══════════╪═══════════╡\n",
"│ 000001.SZ ┆ 20240102 ┆ 1.1584e6 ┆ 1074.93 ┆ … ┆ 2217.6093 ┆ 6486.3743 ┆ 2744.2180 ┆ -0.003256 │\n",
"│ ┆ ┆ ┆ ┆ ┆ 09 ┆ 45 ┆ 84 ┆ │\n",
"│ 000002.SZ ┆ 20240102 ┆ 811106.29 ┆ 1844.3 ┆ … ┆ 1736.4093 ┆ 19.432701 ┆ 2329.7434 ┆ -0.026601 │\n",
"│ ┆ ┆ ┆ ┆ ┆ 99 ┆ ┆ 1 ┆ │\n",
"│ 000004.SZ ┆ 20240102 ┆ 28867.0 ┆ 65.23 ┆ … ┆ -168.7552 ┆ -184.4013 ┆ -192.7135 ┆ -0.014789 │\n",
"│ ┆ ┆ ┆ ┆ ┆ 72 ┆ 85 ┆ 84 ┆ │\n",
"│ 000005.SZ ┆ 20240102 ┆ 63028.0 ┆ 10.01 ┆ … ┆ -96.94997 ┆ -295.0388 ┆ -46.06373 ┆ -0.05395 │\n",
"│ ┆ ┆ ┆ ┆ ┆ 7 ┆ 72 ┆ 6 ┆ │\n",
"│ 000006.SZ ┆ 20240102 ┆ 261947.19 ┆ 176.84 ┆ … ┆ -6.971845 ┆ -51.5536 ┆ -5.32671 ┆ -0.013454 │\n",
"└───────────┴────────────┴───────────┴─────────┴───┴───────────┴───────────┴───────────┴───────────┘\n",
"\n",
" 测试集前5行预览:\n",
"shape: (5, 53)\n",
"┌───────────┬────────────┬───────────┬─────────┬───┬───────────┬───────────┬───────────┬───────────┐\n",
"│ ts_code ┆ trade_date ┆ vol ┆ low ┆ … ┆ profit_to ┆ cashflow_ ┆ operate_p ┆ future_re │\n",
"│ --- ┆ --- ┆ --- ┆ --- ┆ ┆ _market_c ┆ to_market ┆ rofit_to_ ┆ turn_5 │\n",
"│ str ┆ str ┆ f64 ┆ f64 ┆ ┆ ap ┆ _cap ┆ market_ca ┆ --- │\n",
"│ ┆ ┆ ┆ ┆ ┆ --- ┆ --- ┆ p ┆ f64 │\n",
"│ ┆ ┆ ┆ ┆ ┆ f64 ┆ f64 ┆ --- ┆ │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ f64 ┆ │\n",
"╞═══════════╪════════════╪═══════════╪═════════╪═══╪═══════════╪═══════════╪═══════════╪═══════════╡\n",
"│ 000001.SZ ┆ 20250102 ┆ 1.8196e6 ┆ 1455.46 ┆ … ┆ 1791.1304 ┆ 6183.5904 ┆ 2158.1117 ┆ -0.002622 │\n",
"│ ┆ ┆ ┆ ┆ ┆ 08 ┆ 38 ┆ 45 ┆ │\n",
"│ 000002.SZ ┆ 20250102 ┆ 1.1827e6 ┆ 1284.65 ┆ … ┆ -1933.116 ┆ -1110.658 ┆ -1729.069 ┆ -0.022509 │\n",
"│ ┆ ┆ ┆ ┆ ┆ 105 ┆ 303 ┆ 737 ┆ │\n",
"│ 000004.SZ ┆ 20250102 ┆ 119760.37 ┆ 54.17 ┆ … ┆ -199.1144 ┆ -126.8907 ┆ -197.3308 ┆ -0.064897 │\n",
"│ ┆ ┆ ┆ ┆ ┆ 31 ┆ 63 ┆ 47 ┆ │\n",
"│ 000006.SZ ┆ 20250102 ┆ 307195.1 ┆ 285.33 ┆ … ┆ -646.1294 ┆ 74.343232 ┆ -637.5489 ┆ -0.048278 │\n",
"│ ┆ ┆ ┆ ┆ ┆ 33 ┆ ┆ 17 ┆ │\n",
"│ 000007.SZ ┆ 20250102 ┆ 68219.01 ┆ 57.49 ┆ … ┆ 6.740918 ┆ 783.72753 ┆ 22.556002 ┆ 0.015649 │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ 9 ┆ ┆ │\n",
"└───────────┴────────────┴───────────┴─────────┴───┴───────────┴───────────┴───────────┴───────────┘\n"
]
}
],
"execution_count": 7
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T06:10:40.545732Z",
"start_time": "2026-03-08T06:10:39.198908Z"
}
},
"source": [
"# 步骤 3: 训练集数据处理\n",
"print(\"\\n[步骤 3/6] 训练集数据处理\")\n",
"print(\"-\" * 60)\n",
"fitted_processors = []\n",
"if processors:\n",
" for i, processor in enumerate(processors, 1):\n",
" print(\n",
" f\" [{i}/{len(processors)}] 应用处理器: {processor.__class__.__name__}\"\n",
" )\n",
" train_data_before = len(train_data)\n",
" train_data = processor.fit_transform(train_data)\n",
" train_data_after = len(train_data)\n",
" fitted_processors.append(processor)\n",
" print(f\" 处理前记录数: {train_data_before}\")\n",
" print(f\" 处理后记录数: {train_data_after}\")\n",
" if train_data_before != train_data_after:\n",
" print(f\" 删除记录数: {train_data_before - train_data_after}\")\n",
"\n",
"print(\"\\n 训练集处理后前5行预览:\")\n",
"print(train_data.head())\n",
"print(f\"\\n 训练集特征统计:\")\n",
"print(f\" 特征数: {len(feature_cols)}\")\n",
"print(f\" 样本数: {len(train_data)}\")\n",
"print(f\" 缺失值统计:\")\n",
"for col in feature_cols[:5]: # 只显示前5个特征的缺失值\n",
" null_count = train_data[col].null_count()\n",
" if null_count > 0:\n",
" print(\n",
" f\" {col}: {null_count} ({null_count / len(train_data) * 100:.2f}%)\"\n",
" )"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"[步骤 3/6] 训练集数据处理\n",
"------------------------------------------------------------\n",
" [1/3] 应用处理器: NullFiller\n",
" 处理前记录数: 2991506\n",
" 处理后记录数: 2991506\n",
" [2/3] 应用处理器: Winsorizer\n",
" 处理前记录数: 2991506\n",
" 处理后记录数: 2991506\n",
" [3/3] 应用处理器: StandardScaler\n",
" 处理前记录数: 2991506\n",
" 处理后记录数: 2991506\n",
"\n",
" 训练集处理后前5行预览:\n",
"shape: (5, 53)\n",
"┌───────────┬───────────┬───────────┬──────────┬───┬───────────┬───────────┬───────────┬───────────┐\n",
"│ ts_code ┆ trade_dat ┆ vol ┆ low ┆ … ┆ profit_to ┆ cashflow_ ┆ operate_p ┆ future_re │\n",
"│ --- ┆ e ┆ --- ┆ --- ┆ ┆ _market_c ┆ to_market ┆ rofit_to_ ┆ turn_5 │\n",
"│ str ┆ --- ┆ f64 ┆ f64 ┆ ┆ ap ┆ _cap ┆ market_ca ┆ --- │\n",
"│ ┆ str ┆ ┆ ┆ ┆ --- ┆ --- ┆ p ┆ f64 │\n",
"│ ┆ ┆ ┆ ┆ ┆ f64 ┆ f64 ┆ --- ┆ │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ f64 ┆ │\n",
"╞═══════════╪═══════════╪═══════════╪══════════╪═══╪═══════════╪═══════════╪═══════════╪═══════════╡\n",
"│ 000001.SZ ┆ 20200102 ┆ 4.749919 ┆ 7.140527 ┆ … ┆ 1.223327 ┆ 2.58711 ┆ 1.420902 ┆ -0.004746 │\n",
"│ 000002.SZ ┆ 20200102 ┆ 2.92576 ┆ 7.140527 ┆ … ┆ 1.348024 ┆ -0.147052 ┆ 1.826328 ┆ -0.011057 │\n",
"│ 000004.SZ ┆ 20200102 ┆ -0.574944 ┆ 0.102278 ┆ … ┆ -0.557417 ┆ -0.254707 ┆ -0.510901 ┆ -0.000441 │\n",
"│ 000005.SZ ┆ 20200102 ┆ -0.271162 ┆ -0.29962 ┆ … ┆ -0.079896 ┆ 0.218216 ┆ -0.043591 ┆ 0.022337 │\n",
"│ 000006.SZ ┆ 20200102 ┆ -0.19857 ┆ 0.759033 ┆ … ┆ 1.02469 ┆ 0.504628 ┆ 1.182085 ┆ 0.012964 │\n",
"└───────────┴───────────┴───────────┴──────────┴───┴───────────┴───────────┴───────────┴───────────┘\n",
"\n",
" 训练集特征统计:\n",
" 特征数: 34\n",
" 样本数: 2991506\n",
" 缺失值统计:\n",
" ma5: 11541 (0.39%)\n",
" ma10: 25950 (0.87%)\n",
" ma20: 54850 (1.83%)\n",
" ma_ratio: 54850 (1.83%)\n",
" volatility_5: 11541 (0.39%)\n"
]
}
],
"execution_count": 8
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T06:10:58.361151Z",
"start_time": "2026-03-08T06:10:40.549989Z"
}
},
"source": [
"# 步骤 4: 训练模型\n",
"print(\"\\n[步骤 4/6] 训练模型\")\n",
"print(\"-\" * 60)\n",
"print(f\" 模型类型: LightGBM\")\n",
"print(f\" 训练样本数: {len(train_data)}\")\n",
"print(f\" 特征数: {len(feature_cols)}\")\n",
"print(f\" 目标变量: {target_col}\")\n",
"\n",
"X_train = train_data.select(feature_cols)\n",
"y_train = train_data.select(target_col).to_series()\n",
"\n",
"print(f\"\\n 目标变量统计:\")\n",
"print(f\" 均值: {y_train.mean():.6f}\")\n",
"print(f\" 标准差: {y_train.std():.6f}\")\n",
"print(f\" 最小值: {y_train.min():.6f}\")\n",
"print(f\" 最大值: {y_train.max():.6f}\")\n",
"print(f\" 缺失值: {y_train.null_count()}\")\n",
"\n",
"print(\"\\n 开始训练...\")\n",
"model.fit(X_train, y_train)\n",
"print(\" 训练完成!\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"[步骤 4/6] 训练模型\n",
"------------------------------------------------------------\n",
" 模型类型: LightGBM\n",
" 训练样本数: 2991506\n",
" 特征数: 34\n",
" 目标变量: future_return_5\n",
"\n",
" 目标变量统计:\n",
" 均值: 0.001610\n",
" 标准差: 0.059623\n",
" 最小值: -0.155098\n",
" 最大值: 0.212842\n",
" 缺失值: 0\n",
"\n",
" 开始训练...\n",
" 训练完成!\n"
]
}
],
"execution_count": 9
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T06:10:58.523203Z",
"start_time": "2026-03-08T06:10:58.364995Z"
}
},
"source": [
"# 步骤 5: 测试集数据处理\n",
"print(\"\\n[步骤 5/6] 测试集数据处理\")\n",
"print(\"-\" * 60)\n",
"if processors and test_data is not train_data:\n",
" for i, processor in enumerate(fitted_processors, 1):\n",
" print(\n",
" f\" [{i}/{len(fitted_processors)}] 应用处理器: {processor.__class__.__name__}\"\n",
" )\n",
" test_data_before = len(test_data)\n",
" test_data = processor.transform(test_data)\n",
" test_data_after = len(test_data)\n",
" print(f\" 处理前记录数: {test_data_before}\")\n",
" print(f\" 处理后记录数: {test_data_after}\")\n",
"else:\n",
" print(\" 跳过测试集处理\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"[步骤 5/6] 测试集数据处理\n",
"------------------------------------------------------------\n",
" [1/3] 应用处理器: NullFiller\n",
" 处理前记录数: 893340\n",
" 处理后记录数: 893340\n",
" [2/3] 应用处理器: Winsorizer\n",
" 处理前记录数: 893340\n",
" 处理后记录数: 893340\n",
" [3/3] 应用处理器: StandardScaler\n",
" 处理前记录数: 893340\n",
" 处理后记录数: 893340\n"
]
}
],
"execution_count": 10
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T06:10:59.948592Z",
"start_time": "2026-03-08T06:10:58.533879Z"
}
},
"source": [
"# 步骤 6: 生成预测\n",
"print(\"\\n[步骤 6/6] 生成预测\")\n",
"print(\"-\" * 60)\n",
"X_test = test_data.select(feature_cols)\n",
"print(f\" 测试样本数: {len(X_test)}\")\n",
"print(\" 预测中...\")\n",
"predictions = model.predict(X_test)\n",
"print(f\" 预测完成!\")\n",
"\n",
"print(f\"\\n 预测结果统计:\")\n",
"print(f\" 均值: {predictions.mean():.6f}\")\n",
"print(f\" 标准差: {predictions.std():.6f}\")\n",
"print(f\" 最小值: {predictions.min():.6f}\")\n",
"print(f\" 最大值: {predictions.max():.6f}\")\n",
"\n",
"# 保存结果到 trainer\n",
"trainer.results = test_data.with_columns([pl.Series(\"prediction\", predictions)])"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"[步骤 6/6] 生成预测\n",
"------------------------------------------------------------\n",
" 测试样本数: 893340\n",
" 预测中...\n",
" 预测完成!\n",
"\n",
" 预测结果统计:\n",
" 均值: -0.000990\n",
" 标准差: 0.008152\n",
" 最小值: -0.145569\n",
" 最大值: 0.102702\n"
]
}
],
"execution_count": 11
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4.3 训练指标曲线"
]
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T06:11:05.252510Z",
"start_time": "2026-03-08T06:10:59.952123Z"
}
},
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"训练指标曲线\")\n",
"print(\"=\" * 80)\n",
"\n",
"# 重新训练以收集指标(因为之前的训练没有保存评估结果)\n",
"print(\"\\n重新训练模型以收集训练指标...\")\n",
"\n",
"import lightgbm as lgb\n",
"\n",
"# 准备数据(使用 val 做验证test 不参与训练过程)\n",
"X_train_np = X_train.to_numpy()\n",
"y_train_np = y_train.to_numpy()\n",
"X_val_np = val_data.select(feature_cols).to_numpy()\n",
"y_val_np = val_data.select(target_col).to_series().to_numpy()\n",
"\n",
"# 创建数据集\n",
"train_dataset = lgb.Dataset(X_train_np, label=y_train_np)\n",
"val_dataset = lgb.Dataset(X_val_np, label=y_val_np, reference=train_dataset)\n",
"\n",
"# 用于存储评估结果\n",
"evals_result = {}\n",
"\n",
"# 使用与原模型相同的参数重新训练\n",
"# 正确的三分法train用于训练val用于验证test不参与训练过程\n",
"# 添加早停如果验证指标连续100轮没有改善则停止训练\n",
"booster_with_eval = lgb.train(\n",
" MODEL_PARAMS,\n",
" train_dataset,\n",
" num_boost_round=MODEL_PARAMS.get(\"n_estimators\", 100),\n",
" valid_sets=[train_dataset, val_dataset],\n",
" valid_names=[\"train\", \"val\"],\n",
" callbacks=[\n",
" lgb.record_evaluation(evals_result),\n",
" lgb.early_stopping(stopping_rounds=100, verbose=True),\n",
" ],\n",
")\n",
"\n",
"print(\"训练完成,指标已收集\")\n",
"\n",
"# 获取指标名称\n",
"metric_name = list(evals_result[\"train\"].keys())[0]\n",
"print(f\"\\n评估指标: {metric_name}\")\n",
"\n",
"# 提取训练和验证指标\n",
"train_metric = evals_result[\"train\"][metric_name]\n",
"val_metric = evals_result[\"val\"][metric_name]\n",
"\n",
"# 显示早停信息\n",
"actual_rounds = len(train_metric)\n",
"expected_rounds = MODEL_PARAMS.get(\"n_estimators\", 100)\n",
"print(f\"\\n[早停信息]\")\n",
"print(f\" 配置的最大轮数: {expected_rounds}\")\n",
"print(f\" 实际训练轮数: {actual_rounds}\")\n",
"if actual_rounds < expected_rounds:\n",
" print(f\" 早停状态: 已触发连续100轮验证指标未改善\")\n",
"else:\n",
" print(f\" 早停状态: 未触发(达到最大轮数)\")\n",
"\n",
"print(f\"\\n最终指标:\")\n",
"print(f\" 训练 {metric_name}: {train_metric[-1]:.6f}\")\n",
"print(f\" 验证 {metric_name}: {val_metric[-1]:.6f}\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"================================================================================\n",
"训练指标曲线\n",
"================================================================================\n",
"\n",
"重新训练模型以收集训练指标...\n",
"Training until validation scores don't improve for 100 rounds\n",
"Early stopping, best iteration is:\n",
"[147]\ttrain's l1: 0.0424037\tval's l1: 0.0535696\n",
"训练完成,指标已收集\n",
"\n",
"评估指标: l1\n",
"\n",
"[早停信息]\n",
" 配置的最大轮数: 1000\n",
" 实际训练轮数: 247\n",
" 早停状态: 已触发连续100轮验证指标未改善\n",
"\n",
"最终指标:\n",
" 训练 l1: 0.042166\n",
" 验证 l1: 0.053583\n"
]
}
],
"execution_count": 12
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T06:11:05.495565Z",
"start_time": "2026-03-08T06:11:05.256254Z"
}
},
"source": [
"# 绘制训练指标曲线\n",
"import matplotlib.pyplot as plt\n",
"\n",
"fig, ax = plt.subplots(figsize=(12, 6))\n",
"\n",
"# 绘制训练集和验证集的指标曲线注意val用于验证test不参与训练\n",
"iterations = range(1, len(train_metric) + 1)\n",
"ax.plot(iterations, train_metric, label=f\"Train {metric_name}\", linewidth=2, color=\"blue\")\n",
"ax.plot(iterations, val_metric, label=f\"Validation {metric_name}\", linewidth=2, color=\"red\")\n",
"\n",
"ax.set_xlabel(\"Iteration\", fontsize=12)\n",
"ax.set_ylabel(metric_name.upper(), fontsize=12)\n",
"ax.set_title(f\"Training and Validation {metric_name.upper()} Curve\", fontsize=14, fontweight=\"bold\")\n",
"ax.legend(fontsize=10)\n",
"ax.grid(True, alpha=0.3)\n",
"\n",
"# 标记最佳验证指标点(用于早停决策)\n",
"best_iter = val_metric.index(min(val_metric))\n",
"best_metric = min(val_metric)\n",
"ax.axvline(x=best_iter + 1, color=\"green\", linestyle=\"--\", alpha=0.7, label=f\"Best Iteration ({best_iter + 1})\")\n",
"ax.scatter([best_iter + 1], [best_metric], color=\"green\", s=100, zorder=5)\n",
"ax.annotate(\n",
" f\"Best: {best_metric:.6f}\\nIter: {best_iter + 1}\",\n",
" xy=(best_iter + 1, best_metric),\n",
" xytext=(best_iter + 1 + len(iterations) * 0.1, best_metric),\n",
" fontsize=9,\n",
" arrowprops=dict(arrowstyle=\"->\", color=\"green\", alpha=0.7),\n",
")\n",
"\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"print(f\"\\n[指标分析]\")\n",
"print(f\" 最佳验证 {metric_name}: {best_metric:.6f}\")\n",
"print(f\" 最佳迭代轮数: {best_iter + 1}\")\n",
"print(f\" 早停建议: 如果验证指标连续10轮不下降建议在第 {best_iter + 1} 轮停止训练\")\n",
"print(f\"\\n[重要提醒] 验证集仅用于早停/调参,测试集完全独立于训练过程!\")"
],
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 1200x600 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAABKUAAAJOCAYAAABm7rQwAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAemRJREFUeJzt3QecE2X+x/Hf7mYLCyx96VWQLggIBxZQEBAUEUTFQpHDjiJWFEFExYqgougpqP+T4p5iRTkE9VSwICDCiV1BOoj07fm/fs/exElIstlsdjbZ/bx9PWYymSST2QyTfPN7nolzu91uAQAAAAAAABwU7+STAQAAAAAAAIpQCgAAAAAAAI4jlAIAAAAAAIDjCKUAAAAAAADgOEIpAAAAAAAAOI5QCgAAAAAAAI4jlAIAAAAAAIDjCKUAAAAAAADgOEIpAAAAAAAAOI5QCgCAKPTCCy9IXFycp0VCkyZNPI939913R+QxyyrdPta20u0WjezvD32/FPe9UxqvuSTe5wAAIHYQSgEA4Ce0CbV9+OGHbD/IAw884PW++OKLLwJuldGjR3uWS0pKkt27d5fJLVhWAid7WKft119/LfQ+y5YtkxtuuEF69OghqampRb5/IG63W5YsWSKXXXaZHH/88ZKWliaJiYlSu3Zt6d27tzz44IOyffv2sB8fAACnuRx/RgAAUKiTTjpJHn744YhuqTvvvFP2799vpvXLMiJHQwLdvvn5+eb6//3f/0nXrl2PWe7o0aPy6quveq4PHDhQatWqFfXvnZISS+taFLNnz5Y33ngjoo+5ZcsWufjii+WTTz455rZdu3bJihUrTPv222+9KucAAIhmhFIAAPgJbdS+ffvk/vvv91w/88wzpW/fvl7b67jjjgu4/Q4cOGAqGcLRtm1b0yJp7NixEX08/KV+/frm/bF06VJzfeHChTJjxgxTxWK3ePFiOXjwoOf6qFGjIr4ZS+K9U1JiaV2LQiuiGjRoIF26dJG8vDx56623ivV4O3fulJ49e8ovv/zimde0aVMZNGiQqZLSf6s+++wzv4FVpOnrycrKMhVgAAAUF933AACwhTY333yzp/mGOFpdZL/9/PPPl0aNGnl15Xv++eelU6dOUqFCBTnttNPM/fSL5Pjx4+XUU0+Vhg0bSsWKFSU5OdkEGeecc47fL6zBuj716tXLM19DjR9++EGGDx8uNWvWlJSUFPP8/qo0Ao0ppettf66ff/5ZnnrqKTnhhBPM46Wnp8vf//5388XX15EjR2TixIlmO+iyGjDMmTPHvOZwujnqcmPGjDGvoW7dumY76Zff5s2bm25v33zzzTH30W1gPY9uG+2+dMUVV3ju37p1a/nHP/7h9/n08c4++2wTHmrr37+/rFmzRsKh62fZs2ePvPvuu8csoxVUFt2uWimltFpo8ODBpktW9erVTZhVtWpVU2113333yeHDhyPWbS6c16xhmlaD6XtCQxDtdlipUiVp06aNXHfddV5d0nRan9e+PZR9naz3X2HrqpVljz32mJx88slSrVo187z6/AMGDJBXXnnlmOWL816OpPnz55vKJt1uQ4YMKfbj6b8f9kDq6quvlu+//15mzpxp9r+HHnpI/vOf/8h3331nwtFA+0awbWX/G/reb/Pmzebvr9te35u6TQvbv7t16+a53fff0q+//louv/xyE+rrv5X6XjrxxBPNjwBFea8DAMoANwAA8OuXX35x66HSalOmTAl6+6mnnup1vUOHDma5t956y2u+vzZ16lSvx543b57X7XY9e/b0zD/hhBPclStXPubx4uLi3O+//77X/Ro3buz3tXzwwQde9z3llFP8ruNpp53m9XjZ2dnHvGarnXPOOV7X9TlCcdNNNwXdTklJSe5ly5Z53WfkyJGe25s1a+auW7eu3/s+//zzXvf78ssv3ZUqVTpmuZSUFHfv3r0913W7hSIzM9NdtWpVz/3OP/98r9u3b9/uTkhI8Nx+4403em6rUaNG0Nfdvn1798GDB70ez367vl9Cee+E+5qHDh0adP3S0tLc69ev97tf+GvW+y/Yuur2atu2bdDH0fXKyckp9ns5GF1X+3319RWF72ss6v23bdtm9mfr/h07dnTn5eWFdF/7vqH/btj5biv7etnv16JFC3edOnW8ll28eLHXvn/FFVd4PfaPP/7otfzKlSs9tz311FNul8sV8G/apk0b87cHAJQPdN8DACBCPv74Y2ncuLEMHTrUVPfoOC/K5XJJx44dTVceHT9Iq1O0GuDTTz+VDz74wCwzbdo0UyGk1VNFsX79elNBcuONN5qqEq0I0u41mllo9Y0OflxU2gVI76eVYa+//rqnOkkrMbSL0N/+9jdzfdasWeY1W7Qa5dxzzzVVEG+++aaEQ6vItJtS+/btTcWQVlHs3btX3nnnHTNWTnZ2tlx//fXy3//+1+/9tTJGK2K0kkTv+/TTT5vtorSaRKszlG4fnT506JC5rtUcOl6PVpPpmE/Lly8v8rprVdZFF11kKsWUVsD9+eefpuLJqp7Rv42/rnva1ev000837x/9e+r6aWXMokWLzHtF/wZanXLrrbdKuIrzmvU1aNdVrTqzKpa0S5lWAmkVjXZVve2228wg3Pp30/fe6tWrzfpb7GNHhTKm2SWXXCIbN270XNfKRK3M0kHEV61aZebpemt1zeTJk4v1Xo5m+m9EQQZZYOTIkRIf71xnB63EVFrx1aFDB/ntt9+kSpUqphLO2v//9a9/yZNPPunprrpgwQLP/Vu1aiXdu3c30ytXrjSVddbYa7r9tVJPu7S++OKLpsJQ9+0RI0bIv//9b8deIwCg9BBKAQAQITrGi3aDskIIi37p0qbdbdauXWvOtqZf3rQL0ueff266wOXm5ppBirWLTFFosKBhgnZ9URrIaJce9eWXX4b1Os477zzzZV8fW7sNaZcnK0zRx7S+yD/33HOe+2iwoV/yNQiyAhf9kllUU6dONV9YNdDQEEpDHe0ydNZZZ5nrSi+1a5R2hfRHx3PScExpt0J9DUq7NumX38qVK5vtbu8KqOOJaTCoNPjRbkX6Bbmo9Iu6FUrpuDvaxUy7Evp23dO/l4Z4lnXr1pnxzPRLu4Y8GkRpANS5c2cToCgdr6o4oVRxXrP+rXNycszfWEMKDaE0SNPAZ968eWYZff/qMhq6avdW7ZpnD6V0Xqh0e+jjWXT99MxySgMo7QprBVMajk6aNMlvUBPqezmabd261eu6hjxO039T9GyCdvoe1YBYQ84//vjDvD+1W6hvKGXvxvnII494AintFqj/dll/twsvvNBzcgANHjVwt+8jAICyiVAKAIAIufbaa48JpJSO1aJVHxo4BPP7778X+Tm1AsEKpFTLli090+GOm6NVRtb4Plr1omNVaVWM/TH1i6iGPJZhw4Z5Ainri2g4oZR+GdUxfzSYKWxb+Qul6tWr5wmkfLeHtf4aSmnoZad/H4uGKjrWlxW2FIV+qdZqHquSS4MoDaU2bNhgghZ/X9T1S/rtt99uwhWtBAv2moujOK/55ZdfNqFOsKBOQzi9XcfyKi4rcLJXB1kSEhLk0ksv9SyjgYi+FzXEC+e9jOC0Mk7/bfNX1aj7vfWe0SBKQykNk6z3v/6t7EG7VodadBwqvT0Q/feSUAoAyj4GOgcAIEICVTDoANaFBVLWl/qi0gol3y5kFnuXn0g9plXloBVMdnXq1Al6PRTbtm0z26qwQCrYtgq27sHWXyto7LQ6K1z2AEW/hGs3vJdeeskzT7u+abc5y+OPP266tgULpMJ9f9iF+5q1+k+7U4VSOVbcdbRo0BRs3XyvBwqYQnkvRzvfLr2bNm0K63F8/z0I9W+lFXTaBdkfqzus0pMraNWndlO1aIWjPaT0/bsGoxWlAICyj1AKAIAI0coBX1rBoWMsWTSM0IoX/UKsXxJ1jKnisMZwsfg7g1lJPKaOKWNnjZ9l2bFjR5GfV8dg0i+1lkcffdQEKbqd7GMLRWJ7+Fa0+a6/VU0TDq0MsSpAdN21G5v9i7pWk9SoUcNz3d7FTSu9tJudBgZ631tuuUUiJdzXnJGR4Ql
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"[指标分析]\n",
" 最佳验证 l1: 0.053570\n",
" 最佳迭代轮数: 147\n",
" 早停建议: 如果验证指标连续10轮不下降建议在第 147 轮停止训练\n",
"\n",
"[重要提醒] 验证集仅用于早停/调参,测试集完全独立于训练过程!\n"
]
}
],
"execution_count": 13
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4.4 查看结果"
]
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T06:11:05.540372Z",
"start_time": "2026-03-08T06:11:05.504804Z"
}
},
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"训练结果\")\n",
"print(\"=\" * 80)\n",
"\n",
"results = trainer.results\n",
"\n",
"print(f\"\\n结果数据形状: {results.shape}\")\n",
"print(f\"结果列: {results.columns}\")\n",
"print(f\"\\n结果前10行预览:\")\n",
"print(results.head(10))\n",
"print(f\"\\n结果后5行预览:\")\n",
"print(results.tail())\n",
"\n",
"print(f\"\\n每日预测样本数统计:\")\n",
"daily_counts = results.group_by(\"trade_date\").agg(pl.len()).sort(\"trade_date\")\n",
"print(f\" 最小: {daily_counts['len'].min()}\")\n",
"print(f\" 最大: {daily_counts['len'].max()}\")\n",
"print(f\" 平均: {daily_counts['len'].mean():.2f}\")\n",
"\n",
"# 展示某一天的前10个预测结果\n",
"sample_date = results[\"trade_date\"][0]\n",
"sample_data = results.filter(results[\"trade_date\"] == sample_date).head(10)\n",
"print(f\"\\n示例日期 {sample_date} 的前10条预测:\")\n",
"print(sample_data.select([\"ts_code\", \"trade_date\", target_col, \"prediction\"]))"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"================================================================================\n",
"训练结果\n",
"================================================================================\n",
"\n",
"结果数据形状: (893340, 54)\n",
"结果列: ['ts_code', 'trade_date', 'vol', 'low', 'high', 'volume_ratio', 'turnover_rate', 'close', 'total_mv', 'f_ann_date', 'total_profit', 'ebit', 'operate_profit', 'ebitda', 'n_income', 'money_cap', 'total_liab', 'n_cashflow_act', 'ma5', 'ma10', 'ma20', 'ma_ratio', 'volatility_5', 'volatility_20', 'vol_ratio', 'return_10', 'return_20', 'return_diff', 'vol_ma5', 'vol_ma20', 'market_cap_rank', 'high_low_ratio', 'turnover_rate_mean_5', 'bbi_ratio_factor', 'volume_change_rate', 'turnover_deviation', 'vol_std_5', 'std_return_5', 'std_return_90', 'cs_rank_volume_ratio', 'cs_rank_turnover_rate', 'n_income_rank', 'operate_profit_rank', 'total_profit_rank', 'ebit_rank', 'ebitda_rank', 'total_liab_rank', 'money_cap_rank', 'n_cashflow_act_rank', 'profit_to_market_cap', 'cashflow_to_market_cap', 'operate_profit_to_market_cap', 'future_return_5', 'prediction']\n",
"\n",
"结果前10行预览:\n",
"shape: (10, 54)\n",
"┌───────────┬───────────┬───────────┬───────────┬───┬───────────┬───────────┬───────────┬──────────┐\n",
"│ ts_code ┆ trade_dat ┆ vol ┆ low ┆ … ┆ cashflow_ ┆ operate_p ┆ future_re ┆ predicti │\n",
"│ --- ┆ e ┆ --- ┆ --- ┆ ┆ to_market ┆ rofit_to_ ┆ turn_5 ┆ on │\n",
"│ str ┆ --- ┆ f64 ┆ f64 ┆ ┆ _cap ┆ market_ca ┆ --- ┆ --- │\n",
"│ ┆ str ┆ ┆ ┆ ┆ --- ┆ p ┆ f64 ┆ f64 │\n",
"│ ┆ ┆ ┆ ┆ ┆ f64 ┆ --- ┆ ┆ │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ f64 ┆ ┆ │\n",
"╞═══════════╪═══════════╪═══════════╪═══════════╪═══╪═══════════╪═══════════╪═══════════╪══════════╡\n",
"│ 000001.SZ ┆ 20250102 ┆ 5.587779 ┆ 7.140527 ┆ … ┆ 4.53796 ┆ 3.597925 ┆ -0.002622 ┆ -0.00503 │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 1 │\n",
"│ 000002.SZ ┆ 20250102 ┆ 3.526191 ┆ 7.140527 ┆ … ┆ -1.396605 ┆ -3.929764 ┆ -0.022509 ┆ 0.002998 │\n",
"│ 000004.SZ ┆ 20250102 ┆ -0.216144 ┆ -0.133365 ┆ … ┆ -0.334867 ┆ -0.856969 ┆ -0.064897 ┆ 0.000258 │\n",
"│ 000006.SZ ┆ 20250102 ┆ 0.443786 ┆ 1.382669 ┆ … ┆ -0.117683 ┆ -1.740083 ┆ -0.048278 ┆ 0.008562 │\n",
"│ 000007.SZ ┆ 20250102 ┆ -0.397614 ┆ -0.111591 ┆ … ┆ 0.647925 ┆ -0.415858 ┆ 0.015649 ┆ -0.00104 │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 9 │\n",
"│ 000008.SZ ┆ 20250102 ┆ 3.219998 ┆ -0.075651 ┆ … ┆ -0.389118 ┆ -1.046469 ┆ -0.066939 ┆ -0.00301 │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 4 │\n",
"│ 000009.SZ ┆ 20250102 ┆ 0.204049 ┆ 0.03066 ┆ … ┆ 0.246551 ┆ 0.57786 ┆ -0.036045 ┆ 0.014509 │\n",
"│ 000010.SZ ┆ 20250102 ┆ 0.584268 ┆ -0.299947 ┆ … ┆ -0.787198 ┆ -1.017736 ┆ 0.092123 ┆ -0.00703 │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 3 │\n",
"│ 000011.SZ ┆ 20250102 ┆ -0.488782 ┆ -0.241316 ┆ … ┆ -1.745853 ┆ -0.461702 ┆ -0.022094 ┆ 0.002963 │\n",
"│ 000012.SZ ┆ 20250102 ┆ -0.060654 ┆ 0.51434 ┆ … ┆ 0.727138 ┆ 0.565626 ┆ -0.029188 ┆ 0.012271 │\n",
"└───────────┴───────────┴───────────┴───────────┴───┴───────────┴───────────┴───────────┴──────────┘\n",
"\n",
"结果后5行预览:\n",
"shape: (5, 54)\n",
"┌───────────┬───────────┬───────────┬───────────┬───┬───────────┬───────────┬───────────┬──────────┐\n",
"│ ts_code ┆ trade_dat ┆ vol ┆ low ┆ … ┆ cashflow_ ┆ operate_p ┆ future_re ┆ predicti │\n",
"│ --- ┆ e ┆ --- ┆ --- ┆ ┆ to_market ┆ rofit_to_ ┆ turn_5 ┆ on │\n",
"│ str ┆ --- ┆ f64 ┆ f64 ┆ ┆ _cap ┆ market_ca ┆ --- ┆ --- │\n",
"│ ┆ str ┆ ┆ ┆ ┆ --- ┆ p ┆ f64 ┆ f64 │\n",
"│ ┆ ┆ ┆ ┆ ┆ f64 ┆ --- ┆ ┆ │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ f64 ┆ ┆ │\n",
"╞═══════════╪═══════════╪═══════════╪═══════════╪═══╪═══════════╪═══════════╪═══════════╪══════════╡\n",
"│ 605588.SH ┆ 20260306 ┆ -0.594236 ┆ -0.130348 ┆ … ┆ 0.166521 ┆ -0.648847 ┆ null ┆ -0.00224 │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 6 │\n",
"│ 605589.SH ┆ 20260306 ┆ 0.108199 ┆ -0.253383 ┆ … ┆ -0.32234 ┆ 0.153492 ┆ null ┆ -0.00560 │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 4 │\n",
"│ 605598.SH ┆ 20260306 ┆ -0.54733 ┆ -0.01433 ┆ … ┆ -0.191519 ┆ -0.305813 ┆ null ┆ 0.003137 │\n",
"│ 605599.SH ┆ 20260306 ┆ -0.409848 ┆ -0.309195 ┆ … ┆ 0.812476 ┆ 0.480653 ┆ null ┆ -0.00217 │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 3 │\n",
"│ 689009.SH ┆ 20260306 ┆ -0.419186 ┆ -0.168059 ┆ … ┆ 1.301768 ┆ 0.905556 ┆ null ┆ 0.001416 │\n",
"└───────────┴───────────┴───────────┴───────────┴───┴───────────┴───────────┴───────────┴──────────┘\n",
"\n",
"每日预测样本数统计:\n",
" 最小: 1040\n",
" 最大: 3192\n",
" 平均: 3167.87\n",
"\n",
"示例日期 20250102 的前10条预测:\n",
"shape: (10, 4)\n",
"┌───────────┬────────────┬─────────────────┬────────────┐\n",
"│ ts_code ┆ trade_date ┆ future_return_5 ┆ prediction │\n",
"│ --- ┆ --- ┆ --- ┆ --- │\n",
"│ str ┆ str ┆ f64 ┆ f64 │\n",
"╞═══════════╪════════════╪═════════════════╪════════════╡\n",
"│ 000001.SZ ┆ 20250102 ┆ -0.002622 ┆ -0.005031 │\n",
"│ 000002.SZ ┆ 20250102 ┆ -0.022509 ┆ 0.002998 │\n",
"│ 000004.SZ ┆ 20250102 ┆ -0.064897 ┆ 0.000258 │\n",
"│ 000006.SZ ┆ 20250102 ┆ -0.048278 ┆ 0.008562 │\n",
"│ 000007.SZ ┆ 20250102 ┆ 0.015649 ┆ -0.001049 │\n",
"│ 000008.SZ ┆ 20250102 ┆ -0.066939 ┆ -0.003014 │\n",
"│ 000009.SZ ┆ 20250102 ┆ -0.036045 ┆ 0.014509 │\n",
"│ 000010.SZ ┆ 20250102 ┆ 0.092123 ┆ -0.007033 │\n",
"│ 000011.SZ ┆ 20250102 ┆ -0.022094 ┆ 0.002963 │\n",
"│ 000012.SZ ┆ 20250102 ┆ -0.029188 ┆ 0.012271 │\n",
"└───────────┴────────────┴─────────────────┴────────────┘\n"
]
}
],
"execution_count": 14
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4.4 保存结果"
]
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T06:11:05.964933Z",
"start_time": "2026-03-08T06:11:05.551387Z"
}
},
"cell_type": "code",
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"保存预测结果\")\n",
"print(\"=\" * 80)\n",
"\n",
"# 确保输出目录存在\n",
"os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
"\n",
"# 生成时间戳\n",
"start_dt = datetime.strptime(TEST_START, \"%Y%m%d\")\n",
"end_dt = datetime.strptime(TEST_END, \"%Y%m%d\")\n",
"date_str = f\"{start_dt.strftime('%Y%m%d')}_{end_dt.strftime('%Y%m%d')}\"\n",
"\n",
"# 保存每日 Top N\n",
"print(f\"\\n[1/1] 保存每日 Top {TOP_N} 股票...\")\n",
"topn_output_path = os.path.join(OUTPUT_DIR, f\"regression_output.csv\")\n",
"\n",
"# 按日期分组,取每日 top N\n",
"topn_by_date = []\n",
"unique_dates = results[\"trade_date\"].unique().sort()\n",
"for date in unique_dates:\n",
" day_data = results.filter(results[\"trade_date\"] == date)\n",
" # 按 prediction 降序排序,取前 N\n",
" topn = day_data.sort(\"prediction\", descending=True).head(TOP_N)\n",
" topn_by_date.append(topn)\n",
"\n",
"# 合并所有日期的 top N\n",
"topn_results = pl.concat(topn_by_date)\n",
"\n",
"# 格式化日期并调整列顺序:日期、分数、股票\n",
"topn_to_save = topn_results.select(\n",
" [\n",
" pl.col(\"trade_date\").str.slice(0, 4)\n",
" + \"-\"\n",
" + pl.col(\"trade_date\").str.slice(4, 2)\n",
" + \"-\"\n",
" + pl.col(\"trade_date\").str.slice(6, 2).alias(\"date\"),\n",
" pl.col(\"prediction\").alias(\"score\"),\n",
" pl.col(\"ts_code\"),\n",
" ]\n",
")\n",
"topn_to_save.write_csv(topn_output_path, include_header=True)\n",
"print(f\" 保存路径: {topn_output_path}\")\n",
"print(f\" 保存行数: {len(topn_to_save)}{len(unique_dates)}个交易日 × 每日top{TOP_N}\")\n",
"print(f\"\\n 预览前15行:\")\n",
"print(topn_to_save.head(15))"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"================================================================================\n",
"保存预测结果\n",
"================================================================================\n",
"\n",
"[1/1] 保存每日 Top 2 股票...\n",
" 保存路径: output\\regression_output.csv\n",
" 保存行数: 564282个交易日 × 每日top2\n",
"\n",
" 预览前15行:\n",
"shape: (15, 3)\n",
"┌────────────┬──────────┬───────────┐\n",
"│ trade_date ┆ score ┆ ts_code │\n",
"│ --- ┆ --- ┆ --- │\n",
"│ str ┆ f64 ┆ str │\n",
"╞════════════╪══════════╪═══════════╡\n",
"│ 2025-01-02 ┆ 0.081732 ┆ 603007.SH │\n",
"│ 2025-01-02 ┆ 0.070111 ┆ 603559.SH │\n",
"│ 2025-01-03 ┆ 0.088554 ┆ 603007.SH │\n",
"│ 2025-01-03 ┆ 0.08042 ┆ 603559.SH │\n",
"│ 2025-01-06 ┆ 0.087152 ┆ 603007.SH │\n",
"│ … ┆ … ┆ … │\n",
"│ 2025-01-09 ┆ 0.043328 ┆ 605118.SH │\n",
"│ 2025-01-09 ┆ 0.039754 ┆ 603848.SH │\n",
"│ 2025-01-10 ┆ 0.037324 ┆ 002309.SZ │\n",
"│ 2025-01-10 ┆ 0.027656 ┆ 603848.SH │\n",
"│ 2025-01-13 ┆ 0.050825 ┆ 002309.SZ │\n",
"└────────────┴──────────┴───────────┘\n"
]
}
],
"execution_count": 15
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4.5 特征重要性"
]
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T06:11:05.972569Z",
"start_time": "2026-03-08T06:11:05.968651Z"
}
},
"source": [
"importance = model.feature_importance()\n",
"if importance is not None:\n",
" print(\"\\n特征重要性:\")\n",
" print(importance.sort_values(ascending=False))\n",
"\n",
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"训练完成!\")\n",
"print(\"=\" * 80)"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"特征重要性:\n",
"ebitda_rank 12044.385847\n",
"bbi_ratio_factor 1996.570270\n",
"std_return_90 1900.096352\n",
"return_10 1707.237531\n",
"cs_rank_turnover_rate 1700.674361\n",
"ma_ratio 1449.246176\n",
"std_return_5 1112.916802\n",
"high_low_ratio 1002.711145\n",
"vol_ma20 800.361497\n",
"ma20 743.156761\n",
"cs_rank_volume_ratio 620.018650\n",
"vol_ratio 614.280431\n",
"turnover_rate_mean_5 574.150844\n",
"return_20 530.409461\n",
"volume_change_rate 516.283890\n",
"return_diff 427.252813\n",
"market_cap_rank 367.685596\n",
"volatility_20 297.221084\n",
"vol_std_5 245.201921\n",
"volatility_5 210.447650\n",
"ma10 175.131429\n",
"ma5 137.251949\n",
"vol_ma5 105.623339\n",
"ebit_rank 102.653185\n",
"profit_to_market_cap 65.704505\n",
"operate_profit_rank 64.448026\n",
"operate_profit_to_market_cap 59.952241\n",
"cashflow_to_market_cap 40.900113\n",
"n_income_rank 34.135645\n",
"n_cashflow_act_rank 27.246159\n",
"money_cap_rank 22.246325\n",
"total_liab_rank 21.635597\n",
"total_profit_rank 17.579179\n",
"turnover_deviation 0.000000\n",
"dtype: float64\n",
"\n",
"================================================================================\n",
"训练完成!\n",
"================================================================================\n"
]
}
],
"execution_count": 16
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. 可视化分析\n",
"\n",
"使用训练好的模型直接绘图。\n",
"- **特征重要性图**:辅助特征选择\n",
"- **决策树图**:理解决策逻辑"
]
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T06:11:05.983915Z",
"start_time": "2026-03-08T06:11:05.981218Z"
}
},
"source": [
"# 导入可视化库\n",
"import matplotlib.pyplot as plt\n",
"import lightgbm as lgb\n",
"import pandas as pd\n",
"\n",
"# 从封装的model中取出底层Booster\n",
"booster = model.model\n",
"print(f\"模型类型: {type(booster)}\")\n",
"print(f\"特征数量: {len(feature_cols)}\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"模型类型: <class 'lightgbm.basic.Booster'>\n",
"特征数量: 34\n"
]
}
],
"execution_count": 17
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 5.1 绘制特征重要性(辅助特征选择)\n",
"\n",
"**解读**\n",
"- 重要性高的特征对模型贡献大\n",
"- 重要性为0的特征可以考虑删除\n",
"- 可以帮助理解哪些因子最有效"
]
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T06:11:06.124862Z",
"start_time": "2026-03-08T06:11:05.993617Z"
}
},
"cell_type": "code",
"source": [
"print(\"绘制特征重要性...\")\n",
"\n",
"fig, ax = plt.subplots(figsize=(10, 8))\n",
"lgb.plot_importance(\n",
" booster, \n",
" max_num_features=20,\n",
" importance_type='gain',\n",
" title='Feature Importance (Gain)',\n",
" ax=ax\n",
")\n",
"ax.set_xlabel('Importance (Gain)')\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"# 打印重要性排名\n",
"importance_gain = pd.Series(\n",
" booster.feature_importance(importance_type='gain'),\n",
" index=feature_cols\n",
").sort_values(ascending=False)\n",
"\n",
"print(\"\\n[特征重要性排名 - Gain]\")\n",
"print(importance_gain)\n",
"\n",
"# 识别低重要性特征\n",
"zero_importance = importance_gain[importance_gain == 0].index.tolist()\n",
"if zero_importance:\n",
" print(f\"\\n[低重要性特征] 以下{len(zero_importance)}个特征重要性为0可考虑删除:\")\n",
" for feat in zero_importance:\n",
" print(f\" - {feat}\")\n",
"else:\n",
" print(\"\\n所有特征都有一定重要性\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"绘制特征重要性...\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 1000x800 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA9kAAAMWCAYAAADlCkWLAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAA6KFJREFUeJzs3QmcjfX7//HLvmbNnuyyZCuRKCJLJEpI1laVJUnJN7JEiFCIdtFeClmTSkq2LJGlFJF9yf61DPN/vD/f/31+Z86cMwY3M2Zez8fjZOY+97nPPefcM533fV2fz50iOjo62gAAAAAAwAVLeeGbAAAAAAAAhGwAAAAAAHxEJRsAAAAAAJ8QsgEAAAAA8AkhGwAAAAAAnxCyAQAAAADwCSEbAAAAAACfELIBAAAAAPAJIRsAAAAAAJ8QsgEAAC6xrVu3Wvr06e2nn366JM/XoUMHK1y48Hk99tlnn7WqVav6vk8AkFQRsgEAScqECRMsRYoUYW8KCxfDwoULrV+/fnbgwAFLrK/HsmXL7HL12muvuZ8jKRkwYIALrtWrV49134IFC6xFixZWoEABS5s2rWXNmtWtq8fs2rXrku9rt27dbNWqVTZt2rRL/twAcDlKndA7AADAxaBAUqRIkRjLrr322osWsvv37++qhdmyZbsoz5GcKWRfeeWV7vVNCvbs2WPvvfeeu4V6/vnn7YUXXrCiRYu6n1f/Hj9+3H755Rd7+eWX3WP+/PPPc37ON998086cOXNe+5s3b15r0qSJDR8+3O68887z2gYAJCeEbABAknT77bdb5cqV7XJ29OhRy5QpkyVXx44ds4wZM1pS8/7771vq1KmtcePGMZZ/8sknLmCrij1p0iRXxQ42cuRIdzsfadKkuaB91j41b97c/vrrLxf8AQCR0S4OAEiWZs2aZTfffLMLsVdccYU1atTIfvvttxjr/Prrr4FqosbPqqL3wAMP2L59+wLrqE386aefdl+rcu61pm/evNnd9HW4Vmct12ODt6Nla9eutfvuu8+yZ89uNWrUiBHMrr/+esuQIYPlyJHD7r33Xjeu93zoZ8qcObNt2bLF7rjjDve1WpPHjh3r7l+9erXVrl3bvTaFChWyDz/8MGwL+g8//GAdO3a0nDlzWpYsWaxdu3b277//hq1Ely1b1tKlS2f58+e3Tp06xWqtr1Wrlus0UMX2lltuceH6P//5jxtHrPdl/vz5gddW68r+/futR48eVq5cOfczaB90ckWtzcG+//5797hPP/3UBg0aZFdddZV7P+vUqWMbN26Mtb+LFy+2hg0buvdAr0H58uXtlVdeibHO+vXr7Z577nHvhbalEzrxbaeeMmWKa//WPodWsVWxf/vtt2MFbFHbePAxI1OnTnXHrl5Xvb7FihVzQf306dNxjsn2jk1Vp9944w33OD3+hhtusKVLl8Z67ttuuy3wfACAuFHJBgAkSQcPHrS9e/fGWKYAI6oStm/f3urXr29Dhw51FdNx48a5ULtixYpAGJk7d66r3N1///0uYCvsKZDo30WLFrmQcvfdd9vvv/9uH330kasyes+RK1cu1xZ8rlQtLFGihL344osWHR3tlikY9unTx1UTH3roIbfd0aNHuzCq/T2fFnWFMAVSbeOll16yDz74wDp37uxC5XPPPWetW7d2P9v48eNdeK5WrVqs9nutr+dW8NuwYYN7Df/+++9AqBXdp1Z6hbTHHnsssJ6CnCb9Cq6w6uSF9kknENq0aWN58uRxgbpLly4ukGq/RMtF740Cq14z7ZvGK7/++utWs2ZNd7JCwTPYkCFDLGXKlC6Y6/jQz62fU6Hao/dcJx7y5ctnTzzxhHvf161bZ9OnT3ffi95/jaXWiQmN89drpgDftGlTmzx5st11110RX/dTp065n12vRTAdQ7rp/Q0N33HRCQ+t3717d/fvt99+68L6oUOHbNiwYWd9vE6gHD582J0s0Xum10Tvu17b4PdGAV9BXO/Zk08+Ge/9A4BkKRoAgCTk3XffVTINe5PDhw9HZ8uWLfrhhx+O8bidO3dGZ82aNcbyY8eOxdr+Rx995Lb1ww8/BJYNGzbMLdu0aVOMdfW9lmufQml53759A9/ray1r1apVjPU2b94cnSpVquhBgwbFWL569ero1KlTx1oe6fVYunRpYFn79u3dshdffDGw7N9//43OkCFDdIoUKaI//vjjwPL169fH2ldvm9dff330yZMnA8tfeuklt3zq1Knu+927d0enTZs2ul69etGnT58OrDdmzBi33jvvvBNYVrNmTbds/PjxsX6GsmXLuvtDHT9+PMZ2vdc8Xbp00QMGDAgs++6779y2S5cuHX3ixInA8ldeecUt12spUVFR0UWKFIkuVKiQez2CnTlzJvB1nTp1osuVK+eeP/j+m266KbpEiRLRcdm4caN7ztGjR8dYrtdMy0eNGhXreffs2RPjdurUqTiP0Y4dO0ZnzJgxxv7pPdfPFfw66fly5swZvX///lj78dVXX8Xart5HvYYAgLjRLg4ASJLU+qyqZPBN9K9alVu1auUq3d4tVapUroX3u+++C2xDrdkeTT6l9W688Ub3/fLlyy/Kfj/66KMxvv/iiy/chFWqYgfvryqsqngH7++5UtXUo4r0Nddc46qyei6Pluk+VTZDPfLIIzGqnarOaqzxzJkz3ffffPONnTx50s1OrQqy5+GHH3at3TNmzIixPbUrq2sgvrS+t11V5lUJVzVX+xzu/dG2g9uwNVxAvJ9NXQGbNm1y+xvaHeBV5tWirmqxXiNVgL33Q8+tzog//vjDtm3bFnGfvaEGakUPpsqzhFaxVXFXV0TwbeXKlWGPUW9/9HOpO0Mt7WfTsmXLGPsS+poE03qh3SEAgNhoFwcAJElVqlQJO/GZQpBozHE4Cn8eBSq1On/88ce2e/fuWOHnYghtydb+qvCtQO3nhFYaR6zAFkwtwRqv7AXK4OXhxlqH7pMCotqsNd5X1DouCr3BFHQ1zt273+Ndsiq+dPJBY6U15lvhOHgcssaJh7r66qtjfO+FS+9n82btjmsWeo3h1vuh9n3dwtGxop8lLt5QAI/mBZAjR47Eek29E0Rff/11rBZwta737t3bBX8vqJ/LMXq21yR0n0OPDQBAbIRsAECy4l3GSOOyVQ0OpUqsR9VKXZ5LE5tVrFjRBR49vkGDBvG6HFKkQBI6KVWw4Mqkt7/ajiZqU7U91LmM3w0WbltxLQ8NhRdD6M9+Nhq3rqCryeg02ZcmIVNlW5XocO+PHz+bt12N61blOpzixYtHfLwX/kNDbKlSpdy/a9asiXU8epOO/fPPPzHuU0eGxp/rxJAuWacx0zp5oip+z54943WMnstron325hwAAERGyAYAJCsKIpI7d+5AeAlHgWLevHmukq2JpEIr4fEJ015VMHQm7dAK7tn2V4FHFe6SJUtaYqLX4tZbbw18ryrsjh073MzcopnJRZOdBV/2SS3kqjzH9frH5/X9/PPP3fNrNu5ger3PJwx6x4aCbqR9834OdRDEd/9DK8c6maCfP5iq/eoM0ERuo0aNitel2zTBnNrPNaRAE9h5QrftF223QoUKF2XbAJCUMCYbAJCsqPqoyp+qoJrpOZQ3I7hX4Qut6CkAhfICUWiY1vMo7OlSV8HU3hxfmulZ+6KwH7ov+j74cmKXmmZaD34NNWt4VFSUmyFcFELV/v3qq6/G2HeFYrUy69JT8aHXN/S1Fb0uoa/JZ599FueY6Lhcd9117mSG3uPQ5/OeRydnNOO5ZjHXCYVQZ5tRXuFcwxiWLVsW6z7NxK4xzxqzHu7YDP1Zwx2jOoFxLsdXfOn9Ujv9TTfd5Pu2ASCpoZINAEhWFHwVBtu2betClS4XpbHJuma0JuLSpZnGjBnj1vMub6XAozG2GhMbrkqo61eLLjGl7SlINW7c2IVDTS6mS0fpX4UrBW5dqulcqqsDBw60Xr16ubHOukyUxu9qP7788ks3+ZhalxOCAp2uNa22elWrFe50GbQ777zT3a/XVfutEwRqsddybz1dj1mX6YoPvb56z/Q6qBVbQVdj6nWpLbV
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"[特征重要性排名 - Gain]\n",
"ebitda_rank 12044.385847\n",
"bbi_ratio_factor 1996.570270\n",
"std_return_90 1900.096352\n",
"return_10 1707.237531\n",
"cs_rank_turnover_rate 1700.674361\n",
"ma_ratio 1449.246176\n",
"std_return_5 1112.916802\n",
"high_low_ratio 1002.711145\n",
"vol_ma20 800.361497\n",
"ma20 743.156761\n",
"cs_rank_volume_ratio 620.018650\n",
"vol_ratio 614.280431\n",
"turnover_rate_mean_5 574.150844\n",
"return_20 530.409461\n",
"volume_change_rate 516.283890\n",
"return_diff 427.252813\n",
"market_cap_rank 367.685596\n",
"volatility_20 297.221084\n",
"vol_std_5 245.201921\n",
"volatility_5 210.447650\n",
"ma10 175.131429\n",
"ma5 137.251949\n",
"vol_ma5 105.623339\n",
"ebit_rank 102.653185\n",
"profit_to_market_cap 65.704505\n",
"operate_profit_rank 64.448026\n",
"operate_profit_to_market_cap 59.952241\n",
"cashflow_to_market_cap 40.900113\n",
"n_income_rank 34.135645\n",
"n_cashflow_act_rank 27.246159\n",
"money_cap_rank 22.246325\n",
"total_liab_rank 21.635597\n",
"total_profit_rank 17.579179\n",
"turnover_deviation 0.000000\n",
"dtype: float64\n",
"\n",
"[低重要性特征] 以下1个特征重要性为0可考虑删除:\n",
" - turnover_deviation\n"
]
}
],
"execution_count": 18
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}