Files
ProStock/src/experiment/regression.ipynb

1605 lines
196 KiB
Plaintext
Raw Normal View History

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. 导入依赖"
]
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T12:42:00.925979Z",
"start_time": "2026-03-08T12:42:00.366875Z"
}
},
"source": [
"import os\n",
"from datetime import datetime\n",
"from typing import List\n",
"\n",
"import polars as pl\n",
"\n",
"from src.factors import FactorEngine\n",
"from src.training import (\n",
" DateSplitter,\n",
" LightGBMModel,\n",
" STFilter,\n",
" StandardScaler,\n",
" StockFilterConfig,\n",
" StockPoolManager,\n",
" Trainer,\n",
" Winsorizer,\n",
" NullFiller,\n",
")\n",
"from src.training.config import TrainingConfig"
],
"outputs": [],
"execution_count": 1
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. 定义辅助函数"
]
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T12:42:00.940517Z",
"start_time": "2026-03-08T12:42:00.935125Z"
}
},
"cell_type": "code",
"source": [
"def create_factors_with_strings(engine: FactorEngine, factor_definitions: dict, label_factor: dict) -> List[str]:\n",
" print(\"=\" * 80)\n",
" print(\"使用字符串表达式定义因子\")\n",
" print(\"=\" * 80)\n",
"\n",
" # 注册所有特征因子\n",
" print(\"\\n注册特征因子:\")\n",
" for name, expr in factor_definitions.items():\n",
" engine.add_factor(name, expr)\n",
" print(f\" - {name}: {expr}\")\n",
"\n",
" # 注册 label 因子\n",
" print(\"\\n注册 Label 因子:\")\n",
" for name, expr in label_factor.items():\n",
" engine.add_factor(name, expr)\n",
" print(f\" - {name}: {expr}\")\n",
"\n",
" # 从字典自动获取特征列\n",
" feature_cols = list(factor_definitions.keys())\n",
"\n",
" print(f\"\\n特征因子数: {len(feature_cols)}\")\n",
" print(f\"Label: {list(label_factor.keys())[0]}\")\n",
" print(f\"已注册因子总数: {len(engine.list_registered())}\")\n",
"\n",
" return feature_cols\n",
"\n",
"\n",
"def prepare_data(\n",
" engine: FactorEngine,\n",
" feature_cols: List[str],\n",
" start_date: str,\n",
" end_date: str,\n",
") -> pl.DataFrame:\n",
" print(\"\\n\" + \"=\" * 80)\n",
" print(\"准备数据\")\n",
" print(\"=\" * 80)\n",
"\n",
" # 计算因子(全市场数据)\n",
" print(f\"\\n计算因子: {start_date} - {end_date}\")\n",
" factor_names = feature_cols + [LABEL_NAME] # 包含 label\n",
"\n",
" data = engine.compute(\n",
" factor_names=factor_names,\n",
" start_date=start_date,\n",
" end_date=end_date,\n",
" )\n",
"\n",
" print(f\"数据形状: {data.shape}\")\n",
" print(f\"数据列: {data.columns}\")\n",
" print(f\"\\n前5行预览:\")\n",
" print(data.head())\n",
"\n",
" return data"
],
"outputs": [],
"execution_count": 2
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. 配置参数\n",
"\n",
"### 3.1 因子定义"
]
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T12:42:00.953087Z",
"start_time": "2026-03-08T12:42:00.947665Z"
}
},
"cell_type": "code",
"source": [
"# 特征因子定义字典:新增因子只需在此处添加一行\n",
"LABEL_NAME = 'future_return_5'\n",
"\n",
"FACTOR_DEFINITIONS = {\n",
" # 1. 价格动量因子\n",
" \"ma5\": \"ts_mean(close, 5)\",\n",
" \"ma10\": \"ts_mean(close, 10)\",\n",
" \"ma20\": \"ts_mean(close, 20)\",\n",
" \"ma_ratio\": \"ts_mean(close, 5) / ts_mean(close, 20) - 1\",\n",
" # 2. 波动率因子\n",
" \"volatility_5\": \"ts_std(close, 5)\",\n",
" \"volatility_20\": \"ts_std(close, 20)\",\n",
" \"vol_ratio\": \"ts_std(close, 5) / (ts_std(close, 20) + 1e-8)\",\n",
" # 3. 收益率动量因子\n",
" \"return_10\": \"(close / ts_delay(close, 10)) - 1\",\n",
" \"return_20\": \"(close / ts_delay(close, 20)) - 1\",\n",
" # 4. 收益率变化因子\n",
" \"return_diff\": \"(close / ts_delay(close, 5)) - 1 - ((close / ts_delay(close, 10)) - 1)\",\n",
" # 5. 成交量因子\n",
" \"vol_ma5\": \"ts_mean(vol, 5)\",\n",
" \"vol_ma20\": \"ts_mean(vol, 20)\",\n",
" \"vol_ratio\": \"ts_mean(vol, 5) / (ts_mean(vol, 20) + 1e-8)\",\n",
" # 6. 市值因子(截面排名)\n",
" \"market_cap_rank\": \"cs_rank(total_mv)\",\n",
" # 7. 价格位置因子\n",
" \"high_low_ratio\": \"(close - ts_min(low, 20)) / (ts_max(high, 20) - ts_min(low, 20) + 1e-8)\",\n",
" # 8. 技术指标因子3.1 完全可实现)\n",
" \"turnover_rate_mean_5\": \"ts_mean(turnover_rate, 5)\", # 5日均换手率\n",
" \"bbi_ratio_factor\": \"(ts_mean(close, 3) + ts_mean(close, 6) + ts_mean(close, 12) + ts_mean(close, 24)) / 4 / close\", # BBI比率\n",
" # 9. ARBR 因子3.1 完全可实现)\n",
" # \"AR\": \"ts_sum(high - open, 26) / ts_sum(open - low, 26) * 100\", # AR人气指标\n",
" # \"BR\": \"ts_sum(max_(0, high - ts_delay(close, 1)), 26) / ts_sum(max_(0, ts_delay(close, 1) - low), 26) * 100\", # BR意愿指标\n",
" # \"AR_BR\": \"AR - BR\", # ARBR差值\n",
" # 10. 成交量因子3.1 完全可实现)\n",
" \"volume_change_rate\": \"ts_mean(vol, 2) / ts_mean(vol, 10) - 1\", # 成交量变化率\n",
" \"turnover_deviation\": \"(turnover_rate - ts_mean(turnover_rate, 3)) / ts_std(turnover_rate, 3)\", # 换手率偏离度\n",
" \"vol_std_5\": \"ts_std(ts_delta(vol, 1), 5)\", # 成交量变化标准差\n",
" # 11. 收益率因子3.1 完全可实现)\n",
" # \"return_5\": \"close / ts_delay(close, 5) - 1\", # 5日收益率\n",
" \"std_return_5\": \"ts_std((close - ts_delay(close, 1)) / ts_delay(close, 1), 5)\", # 5日收益率标准差\n",
" \"std_return_90\": \"ts_std((close - ts_delay(close, 1)) / ts_delay(close, 1), 90)\", # 90日收益率标准差\n",
" # 12. 截面排序因子3.1 完全可实现)\n",
" \"cs_rank_volume_ratio\": \"cs_rank(volume_ratio)\", # 量比截面排名\n",
" \"cs_rank_turnover_rate\": \"cs_rank(turnover_rate)\", # 换手率截面排名\n",
" \"n_income_rank\": \"cs_rank(n_income)\", # 净利润截面排名\n",
" # 13. 财务数据因子(来自利润表 financial_income\n",
" \"operate_profit_rank\": \"cs_rank(operate_profit)\", # 营业利润截面排名\n",
" \"total_profit_rank\": \"cs_rank(total_profit)\", # 利润总额截面排名\n",
" \"ebit_rank\": \"cs_rank(ebit)\", # 息税前利润截面排名\n",
" \"ebitda_rank\": \"cs_rank(ebitda)\", # 息税折旧摊销前利润截面排名\n",
" # 14. 财务数据因子(来自资产负债表 financial_balance\n",
" \"total_liab_rank\": \"cs_rank(total_liab)\", # 总负债截面排名\n",
" \"money_cap_rank\": \"cs_rank(money_cap)\", # 货币资金截面排名\n",
" # 15. 财务数据因子(来自现金流量表 financial_cashflow\n",
" \"n_cashflow_act_rank\": \"cs_rank(n_cashflow_act)\", # 经营活动现金流净额截面排名\n",
" # 16. 财务估值因子\n",
" \"profit_to_market_cap\": \"n_income / (total_mv + 1e-8)\", # 净利润率(净利润/市值)\n",
" \"cashflow_to_market_cap\": \"n_cashflow_act / (total_mv + 1e-8)\", # 经营现金流/市值\n",
" \"operate_profit_to_market_cap\": \"operate_profit / (total_mv + 1e-8)\", # 营业利润/市值\n",
"}\n",
"\n",
"# Label 因子定义(不参与训练,用于计算目标)\n",
"LABEL_FACTOR = {\n",
" LABEL_NAME: \"(ts_delay(close, -5) / close) - 1\", # 未来5日收益率\n",
"}"
],
"outputs": [],
"execution_count": 3
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3.2 训练参数配置"
]
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T12:42:00.961926Z",
"start_time": "2026-03-08T12:42:00.959032Z"
}
},
"source": [
"# 日期范围配置(正确的 train/val/test 三分法)\n",
"# Train: 用于训练模型参数\n",
"# Val: 用于验证/早停/调参(位于 train 之后test 之前)\n",
"# Test: 仅用于最终评估,完全独立于训练过程\n",
"TRAIN_START = \"20200101\"\n",
"TRAIN_END = \"20231231\"\n",
"VAL_START = \"20240101\"\n",
"VAL_END = \"20241231\"\n",
"TEST_START = \"20250101\"\n",
"TEST_END = \"20251231\"\n",
"\n",
"# 模型参数配置\n",
"MODEL_PARAMS = {\n",
" \"objective\": \"regression\",\n",
" \"metric\": \"mae\", # 改为 MAE对异常值更稳健\n",
" # 树结构控制(防过拟合核心)\n",
" \"num_leaves\": 20, # 从31降为20降低模型复杂度\n",
" \"max_depth\": 4, # 显式限制深度,防止过度拟合噪声\n",
" \"min_child_samples\": 50, # 叶子最小样本数,防止学习极端样本\n",
" \"min_child_weight\": 0.001,\n",
" # 学习参数\n",
" \"learning_rate\": 0.01, # 降低学习率,配合更多树\n",
" \"n_estimators\": 1000, # 增加树数量,配合早停\n",
" # 采样策略(关键防过拟合)\n",
" \"subsample\": 0.8, # 每棵树随机采样80%数据(行采样)\n",
" \"subsample_freq\": 5, # 每5轮迭代进行一次 subsample\n",
" \"colsample_bytree\": 0.8, # 每棵树随机选择80%特征(列采样)\n",
" # 正则化\n",
" \"reg_alpha\": 0.1, # L1正则增加稀疏性\n",
" \"reg_lambda\": 1.0, # L2正则平滑权重\n",
" # 数值稳定性\n",
" \"verbose\": -1,\n",
" \"random_state\": 42,\n",
"}\n",
"\n",
"# 数据处理器配置\n",
"PROCESSOR_CONFIGS = [\n",
" {\"name\": \"winsorizer\", \"params\": {\"lower\": 0.01, \"upper\": 0.99}},\n",
" {\"name\": \"cs_standard_scaler\", \"params\": {}},\n",
"]\n",
"\n",
"# 股票池筛选配置\n",
"STOCK_FILTER_CONFIG = {\n",
" \"exclude_cyb\": True, # 排除创业板\n",
" \"exclude_kcb\": True, # 排除科创板\n",
" \"exclude_bj\": True, # 排除北交所\n",
" \"exclude_st\": True, # 排除ST股票\n",
"}\n",
"\n",
"# 输出配置(相对于本文件所在目录)\n",
"OUTPUT_DIR = \"output\"\n",
"SAVE_PREDICTIONS = True\n",
"PERSIST_MODEL = False\n",
"\n",
"# Top N 配置:每日推荐股票数量\n",
"TOP_N = 5 # 可调整为 10, 20 等"
],
"outputs": [],
"execution_count": 4
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. 训练流程\n",
"\n",
"### 4.1 初始化组件"
]
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T12:42:08.872125Z",
"start_time": "2026-03-08T12:42:00.967162Z"
}
},
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"LightGBM 回归模型训练\")\n",
"print(\"=\" * 80)\n",
"\n",
"# 1. 创建 FactorEngine\n",
"print(\"\\n[1] 创建 FactorEngine\")\n",
"engine = FactorEngine()\n",
"\n",
"# 2. 使用字符串表达式定义因子\n",
"print(\"\\n[2] 定义因子(字符串表达式)\")\n",
"feature_cols = create_factors_with_strings(engine, FACTOR_DEFINITIONS, LABEL_FACTOR)\n",
"target_col = LABEL_NAME\n",
"\n",
"# 3. 准备数据(使用模块级别的日期配置)\n",
"print(\"\\n[3] 准备数据\")\n",
"\n",
"data = prepare_data(\n",
" engine=engine,\n",
" feature_cols=feature_cols,\n",
" start_date=TRAIN_START,\n",
" end_date=TEST_END,\n",
")\n",
"\n",
"# 4. 打印配置信息\n",
"print(f\"\\n[配置] 训练期: {TRAIN_START} - {TRAIN_END}\")\n",
"print(f\"[配置] 验证期: {VAL_START} - {VAL_END}\")\n",
"print(f\"[配置] 测试期: {TEST_START} - {TEST_END}\")\n",
"print(f\"[配置] 特征数: {len(feature_cols)}\")\n",
"print(f\"[配置] 目标变量: {target_col}\")\n",
"\n",
"# 5. 创建模型\n",
"model = LightGBMModel(params=MODEL_PARAMS)\n",
"\n",
"# 6. 创建数据处理器\n",
"processors = [\n",
" NullFiller(strategy=\"mean\"),\n",
" Winsorizer(**PROCESSOR_CONFIGS[0][\"params\"]),\n",
" StandardScaler(exclude_cols=[\"ts_code\", \"trade_date\", target_col]),\n",
"]\n",
"\n",
"# 7. 创建数据划分器(正确的 train/val/test 三分法)\n",
"# Train: 训练模型参数 | Val: 验证/早停 | Test: 最终评估\n",
"splitter = DateSplitter(\n",
" train_start=TRAIN_START,\n",
" train_end=TRAIN_END,\n",
" val_start=VAL_START,\n",
" val_end=VAL_END,\n",
" test_start=TEST_START,\n",
" test_end=TEST_END,\n",
")\n",
"\n",
"# 8. 创建股票池管理器\n",
"pool_manager = StockPoolManager(\n",
" filter_config=StockFilterConfig(**STOCK_FILTER_CONFIG),\n",
" selector_config=None, # 暂时不启用市值选择\n",
" data_router=engine.router,\n",
")\n",
"\n",
"# 9. 创建 ST 股票过滤器\n",
"st_filter = STFilter(\n",
" data_router=engine.router,\n",
")\n",
"\n",
"# 10. 创建训练器\n",
"trainer = Trainer(\n",
" model=model,\n",
" pool_manager=pool_manager,\n",
" processors=processors,\n",
" filters=[st_filter],\n",
" splitter=splitter,\n",
" target_col=target_col,\n",
" feature_cols=feature_cols,\n",
" persist_model=PERSIST_MODEL,\n",
")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"================================================================================\n",
"LightGBM 回归模型训练\n",
"================================================================================\n",
"\n",
"[1] 创建 FactorEngine\n",
"\n",
"[2] 定义因子(字符串表达式)\n",
"================================================================================\n",
"使用字符串表达式定义因子\n",
"================================================================================\n",
"\n",
"注册特征因子:\n",
" - ma5: ts_mean(close, 5)\n",
" - ma10: ts_mean(close, 10)\n",
" - ma20: ts_mean(close, 20)\n",
" - ma_ratio: ts_mean(close, 5) / ts_mean(close, 20) - 1\n",
" - volatility_5: ts_std(close, 5)\n",
" - volatility_20: ts_std(close, 20)\n",
" - vol_ratio: ts_mean(vol, 5) / (ts_mean(vol, 20) + 1e-8)\n",
" - return_10: (close / ts_delay(close, 10)) - 1\n",
" - return_20: (close / ts_delay(close, 20)) - 1\n",
" - return_diff: (close / ts_delay(close, 5)) - 1 - ((close / ts_delay(close, 10)) - 1)\n",
" - vol_ma5: ts_mean(vol, 5)\n",
" - vol_ma20: ts_mean(vol, 20)\n",
" - market_cap_rank: cs_rank(total_mv)\n",
" - high_low_ratio: (close - ts_min(low, 20)) / (ts_max(high, 20) - ts_min(low, 20) + 1e-8)\n",
" - turnover_rate_mean_5: ts_mean(turnover_rate, 5)\n",
" - bbi_ratio_factor: (ts_mean(close, 3) + ts_mean(close, 6) + ts_mean(close, 12) + ts_mean(close, 24)) / 4 / close\n",
" - volume_change_rate: ts_mean(vol, 2) / ts_mean(vol, 10) - 1\n",
" - turnover_deviation: (turnover_rate - ts_mean(turnover_rate, 3)) / ts_std(turnover_rate, 3)\n",
" - vol_std_5: ts_std(ts_delta(vol, 1), 5)\n",
" - std_return_5: ts_std((close - ts_delay(close, 1)) / ts_delay(close, 1), 5)\n",
" - std_return_90: ts_std((close - ts_delay(close, 1)) / ts_delay(close, 1), 90)\n",
" - cs_rank_volume_ratio: cs_rank(volume_ratio)\n",
" - cs_rank_turnover_rate: cs_rank(turnover_rate)\n",
" - n_income_rank: cs_rank(n_income)\n",
" - operate_profit_rank: cs_rank(operate_profit)\n",
" - total_profit_rank: cs_rank(total_profit)\n",
" - ebit_rank: cs_rank(ebit)\n",
" - ebitda_rank: cs_rank(ebitda)\n",
" - total_liab_rank: cs_rank(total_liab)\n",
" - money_cap_rank: cs_rank(money_cap)\n",
" - n_cashflow_act_rank: cs_rank(n_cashflow_act)\n",
" - profit_to_market_cap: n_income / (total_mv + 1e-8)\n",
" - cashflow_to_market_cap: n_cashflow_act / (total_mv + 1e-8)\n",
" - operate_profit_to_market_cap: operate_profit / (total_mv + 1e-8)\n",
"\n",
"注册 Label 因子:\n",
" - future_return_5: (ts_delay(close, -5) / close) - 1\n",
"\n",
"特征因子数: 34\n",
"Label: future_return_5\n",
"已注册因子总数: 35\n",
"\n",
"[3] 准备数据\n",
"\n",
"================================================================================\n",
"准备数据\n",
"================================================================================\n",
"\n",
"计算因子: 20200101 - 20251231\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"D:\\PyProject\\ProStock\\src\\data\\financial_loader.py:148: UserWarning: Sortedness of columns cannot be checked when 'by' groups provided\n",
" merged = df_price.join_asof(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"数据形状: (7044952, 53)\n",
"数据列: ['ts_code', 'trade_date', 'turnover_rate', 'volume_ratio', 'high', 'vol', 'close', 'low', 'total_mv', 'f_ann_date', 'total_profit', 'operate_profit', 'ebit', 'n_income', 'ebitda', 'money_cap', 'total_liab', 'n_cashflow_act', 'ma5', 'ma10', 'ma20', 'ma_ratio', 'volatility_5', 'volatility_20', 'vol_ratio', 'return_10', 'return_20', 'return_diff', 'vol_ma5', 'vol_ma20', 'market_cap_rank', 'high_low_ratio', 'turnover_rate_mean_5', 'bbi_ratio_factor', 'volume_change_rate', 'turnover_deviation', 'vol_std_5', 'std_return_5', 'std_return_90', 'cs_rank_volume_ratio', 'cs_rank_turnover_rate', 'n_income_rank', 'operate_profit_rank', 'total_profit_rank', 'ebit_rank', 'ebitda_rank', 'total_liab_rank', 'money_cap_rank', 'n_cashflow_act_rank', 'profit_to_market_cap', 'cashflow_to_market_cap', 'operate_profit_to_market_cap', 'future_return_5']\n",
"\n",
"前5行预览:\n",
"shape: (5, 53)\n",
"┌───────────┬───────────┬───────────┬───────────┬───┬───────────┬───────────┬───────────┬──────────┐\n",
"│ ts_code ┆ trade_dat ┆ turnover_ ┆ volume_ra ┆ … ┆ profit_to ┆ cashflow_ ┆ operate_p ┆ future_r │\n",
"│ --- ┆ e ┆ rate ┆ tio ┆ ┆ _market_c ┆ to_market ┆ rofit_to_ ┆ eturn_5 │\n",
"│ str ┆ --- ┆ --- ┆ --- ┆ ┆ ap ┆ _cap ┆ market_ca ┆ --- │\n",
"│ ┆ str ┆ f64 ┆ f64 ┆ ┆ --- ┆ --- ┆ p ┆ f64 │\n",
"│ ┆ ┆ ┆ ┆ ┆ f64 ┆ f64 ┆ --- ┆ │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ f64 ┆ │\n",
"╞═══════════╪═══════════╪═══════════╪═══════════╪═══╪═══════════╪═══════════╪═══════════╪══════════╡\n",
"│ 000001.SZ ┆ 20200102 ┆ 0.7885 ┆ 2.18 ┆ … ┆ 721.52104 ┆ 2580.5045 ┆ 938.15146 ┆ -0.00474 │\n",
"│ ┆ ┆ ┆ ┆ ┆ 1 ┆ 33 ┆ 4 ┆ 6 │\n",
"│ 000001.SZ ┆ 20200103 ┆ 0.5752 ┆ 1.21 ┆ … ┆ 708.50174 ┆ 2533.9412 ┆ 921.22323 ┆ -0.02852 │\n",
"│ ┆ ┆ ┆ ┆ ┆ 4 ┆ 96 ┆ 7 ┆ │\n",
"│ 000001.SZ ┆ 20200106 ┆ 0.4442 ┆ 0.8 ┆ … ┆ 713.06736 ┆ 2550.2701 ┆ 927.15964 ┆ -0.00468 │\n",
"│ ┆ ┆ ┆ ┆ ┆ 8 ┆ 5 ┆ 9 ┆ 5 │\n",
"│ 000001.SZ ┆ 20200107 ┆ 0.3755 ┆ 0.7 ┆ … ┆ 709.74110 ┆ 2538.3738 ┆ 922.83470 ┆ -0.02274 │\n",
"│ ┆ ┆ ┆ ┆ ┆ 6 ┆ 46 ┆ 6 ┆ 3 │\n",
"│ 000001.SZ ┆ 20200108 ┆ 0.4369 ┆ 0.86 ┆ … ┆ 730.61584 ┆ 2613.0319 ┆ 949.97690 ┆ -0.00840 │\n",
"│ ┆ ┆ ┆ ┆ ┆ 4 ┆ 01 ┆ 3 ┆ 1 │\n",
"└───────────┴───────────┴───────────┴───────────┴───┴───────────┴───────────┴───────────┴──────────┘\n",
"\n",
"[配置] 训练期: 20200101 - 20231231\n",
"[配置] 验证期: 20240101 - 20241231\n",
"[配置] 测试期: 20250101 - 20251231\n",
"[配置] 特征数: 34\n",
"[配置] 目标变量: future_return_5\n"
]
}
],
"execution_count": 5
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4.2 执行训练"
]
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T12:42:17.812047Z",
"start_time": "2026-03-08T12:42:08.881623Z"
}
},
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"开始训练\")\n",
"print(\"=\" * 80)\n",
"\n",
"# 步骤 1: 股票池筛选\n",
"print(\"\\n[步骤 1/6] 股票池筛选\")\n",
"print(\"-\" * 60)\n",
"if pool_manager:\n",
" print(\" 执行每日独立筛选股票池...\")\n",
" filtered_data = pool_manager.filter_and_select_daily(data)\n",
" print(f\" 筛选前数据规模: {data.shape}\")\n",
" print(f\" 筛选后数据规模: {filtered_data.shape}\")\n",
" print(f\" 筛选前股票数: {data['ts_code'].n_unique()}\")\n",
" print(f\" 筛选后股票数: {filtered_data['ts_code'].n_unique()}\")\n",
" print(f\" 删除记录数: {len(data) - len(filtered_data)}\")\n",
"else:\n",
" filtered_data = data\n",
" print(\" 未配置股票池管理器,跳过筛选\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"================================================================================\n",
"开始训练\n",
"================================================================================\n",
"\n",
"[步骤 1/6] 股票池筛选\n",
"------------------------------------------------------------\n",
" 执行每日独立筛选股票池...\n",
" 筛选前数据规模: (7044952, 53)\n",
" 筛选后数据规模: (4532198, 53)\n",
" 筛选前股票数: 5678\n",
" 筛选后股票数: 3359\n",
" 删除记录数: 2512754\n"
]
}
],
"execution_count": 6
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T12:42:19.603242Z",
"start_time": "2026-03-08T12:42:17.822762Z"
}
},
"source": [
"# 步骤 2: 划分训练/验证/测试集(正确的三分法)\n",
"print(\"\\n[步骤 2/6] 划分训练集、验证集和测试集\")\n",
"print(\"-\" * 60)\n",
"if splitter:\n",
" # 正确的三分法train用于训练val用于验证/早停test仅用于最终评估\n",
" train_data, val_data, test_data = splitter.split(filtered_data)\n",
" print(f\" 训练集数据规模: {train_data.shape}\")\n",
" print(f\" 验证集数据规模: {val_data.shape}\")\n",
" print(f\" 测试集数据规模: {test_data.shape}\")\n",
" print(f\" 训练集股票数: {train_data['ts_code'].n_unique()}\")\n",
" print(f\" 验证集股票数: {val_data['ts_code'].n_unique()}\")\n",
" print(f\" 测试集股票数: {test_data['ts_code'].n_unique()}\")\n",
" print(\n",
" f\" 训练集日期范围: {train_data['trade_date'].min()} - {train_data['trade_date'].max()}\"\n",
" )\n",
" print(\n",
" f\" 验证集日期范围: {val_data['trade_date'].min()} - {val_data['trade_date'].max()}\"\n",
" )\n",
" print(\n",
" f\" 测试集日期范围: {test_data['trade_date'].min()} - {test_data['trade_date'].max()}\"\n",
" )\n",
"\n",
" print(\"\\n 训练集前5行预览:\")\n",
" print(train_data.head())\n",
" print(\"\\n 验证集前5行预览:\")\n",
" print(val_data.head())\n",
" print(\"\\n 测试集前5行预览:\")\n",
" print(test_data.head())\n",
"else:\n",
" train_data = filtered_data\n",
" test_data = filtered_data\n",
" print(\" 未配置划分器,全部作为训练集\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"[步骤 2/6] 划分训练集、验证集和测试集\n",
"------------------------------------------------------------\n",
" 训练集数据规模: (2991506, 53)\n",
" 验证集数据规模: (769485, 53)\n",
" 测试集数据规模: (771207, 53)\n",
" 训练集股票数: 3297\n",
" 验证集股票数: 3220\n",
" 测试集股票数: 3215\n",
" 训练集日期范围: 20200102 - 20231229\n",
" 验证集日期范围: 20240102 - 20241231\n",
" 测试集日期范围: 20250102 - 20251231\n",
"\n",
" 训练集前5行预览:\n",
"shape: (5, 53)\n",
"┌───────────┬───────────┬───────────┬───────────┬───┬───────────┬───────────┬───────────┬──────────┐\n",
"│ ts_code ┆ trade_dat ┆ turnover_ ┆ volume_ra ┆ … ┆ profit_to ┆ cashflow_ ┆ operate_p ┆ future_r │\n",
"│ --- ┆ e ┆ rate ┆ tio ┆ ┆ _market_c ┆ to_market ┆ rofit_to_ ┆ eturn_5 │\n",
"│ str ┆ --- ┆ --- ┆ --- ┆ ┆ ap ┆ _cap ┆ market_ca ┆ --- │\n",
"│ ┆ str ┆ f64 ┆ f64 ┆ ┆ --- ┆ --- ┆ p ┆ f64 │\n",
"│ ┆ ┆ ┆ ┆ ┆ f64 ┆ f64 ┆ --- ┆ │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ f64 ┆ │\n",
"╞═══════════╪═══════════╪═══════════╪═══════════╪═══╪═══════════╪═══════════╪═══════════╪══════════╡\n",
"│ 000001.SZ ┆ 20200102 ┆ 0.7885 ┆ 2.18 ┆ … ┆ 721.52104 ┆ 2580.5045 ┆ 938.15146 ┆ -0.00474 │\n",
"│ ┆ ┆ ┆ ┆ ┆ 1 ┆ 33 ┆ 4 ┆ 6 │\n",
"│ 000002.SZ ┆ 20200102 ┆ 1.0418 ┆ 1.31 ┆ … ┆ 776.91820 ┆ 47.131053 ┆ 1140.2493 ┆ -0.01105 │\n",
"│ ┆ ┆ ┆ ┆ ┆ 1 ┆ ┆ 95 ┆ 7 │\n",
"│ 000004.SZ ┆ 20200102 ┆ 2.1613 ┆ 0.92 ┆ … ┆ -69.58089 ┆ -52.61755 ┆ -24.82135 ┆ -0.00044 │\n",
"│ ┆ ┆ ┆ ┆ ┆ 5 ┆ 4 ┆ 9 ┆ 1 │\n",
"│ 000005.SZ ┆ 20200102 ┆ 0.9843 ┆ 1.35 ┆ … ┆ 142.55925 ┆ 385.57490 ┆ 208.12520 ┆ 0.022337 │\n",
"│ ┆ ┆ ┆ ┆ ┆ 6 ┆ 4 ┆ 2 ┆ │\n",
"│ 000006.SZ ┆ 20200102 ┆ 0.9252 ┆ 1.62 ┆ … ┆ 633.27582 ┆ 650.95370 ┆ 819.10495 ┆ 0.012964 │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ 3 ┆ 5 ┆ │\n",
"└───────────┴───────────┴───────────┴───────────┴───┴───────────┴───────────┴───────────┴──────────┘\n",
"\n",
" 验证集前5行预览:\n",
"shape: (5, 53)\n",
"┌───────────┬───────────┬───────────┬───────────┬───┬───────────┬───────────┬───────────┬──────────┐\n",
"│ ts_code ┆ trade_dat ┆ turnover_ ┆ volume_ra ┆ … ┆ profit_to ┆ cashflow_ ┆ operate_p ┆ future_r │\n",
"│ --- ┆ e ┆ rate ┆ tio ┆ ┆ _market_c ┆ to_market ┆ rofit_to_ ┆ eturn_5 │\n",
"│ str ┆ --- ┆ --- ┆ --- ┆ ┆ ap ┆ _cap ┆ market_ca ┆ --- │\n",
"│ ┆ str ┆ f64 ┆ f64 ┆ ┆ --- ┆ --- ┆ p ┆ f64 │\n",
"│ ┆ ┆ ┆ ┆ ┆ f64 ┆ f64 ┆ --- ┆ │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ f64 ┆ │\n",
"╞═══════════╪═══════════╪═══════════╪═══════════╪═══╪═══════════╪═══════════╪═══════════╪══════════╡\n",
"│ 000001.SZ ┆ 20240102 ┆ 0.5969 ┆ 1.41 ┆ … ┆ 2217.6093 ┆ 6486.3743 ┆ 2744.2180 ┆ -0.00325 │\n",
"│ ┆ ┆ ┆ ┆ ┆ 09 ┆ 45 ┆ 84 ┆ 6 │\n",
"│ 000002.SZ ┆ 20240102 ┆ 0.8348 ┆ 1.71 ┆ … ┆ 1736.4093 ┆ 19.432701 ┆ 2329.7434 ┆ -0.02660 │\n",
"│ ┆ ┆ ┆ ┆ ┆ 99 ┆ ┆ 1 ┆ 1 │\n",
"│ 000004.SZ ┆ 20240102 ┆ 2.2858 ┆ 0.78 ┆ … ┆ -168.7552 ┆ -184.4013 ┆ -192.7135 ┆ -0.01478 │\n",
"│ ┆ ┆ ┆ ┆ ┆ 72 ┆ 85 ┆ 84 ┆ 9 │\n",
"│ 000005.SZ ┆ 20240102 ┆ 0.5958 ┆ 0.43 ┆ … ┆ -96.94997 ┆ -295.0388 ┆ -46.06373 ┆ -0.05395 │\n",
"│ ┆ ┆ ┆ ┆ ┆ 7 ┆ 72 ┆ 6 ┆ │\n",
"│ 000006.SZ ┆ 20240102 ┆ 1.9404 ┆ 0.97 ┆ … ┆ -6.971845 ┆ -51.5536 ┆ -5.32671 ┆ -0.01345 │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 4 │\n",
"└───────────┴───────────┴───────────┴───────────┴───┴───────────┴───────────┴───────────┴──────────┘\n",
"\n",
" 测试集前5行预览:\n",
"shape: (5, 53)\n",
"┌───────────┬───────────┬───────────┬───────────┬───┬───────────┬───────────┬───────────┬──────────┐\n",
"│ ts_code ┆ trade_dat ┆ turnover_ ┆ volume_ra ┆ … ┆ profit_to ┆ cashflow_ ┆ operate_p ┆ future_r │\n",
"│ --- ┆ e ┆ rate ┆ tio ┆ ┆ _market_c ┆ to_market ┆ rofit_to_ ┆ eturn_5 │\n",
"│ str ┆ --- ┆ --- ┆ --- ┆ ┆ ap ┆ _cap ┆ market_ca ┆ --- │\n",
"│ ┆ str ┆ f64 ┆ f64 ┆ ┆ --- ┆ --- ┆ p ┆ f64 │\n",
"│ ┆ ┆ ┆ ┆ ┆ f64 ┆ f64 ┆ --- ┆ │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ f64 ┆ │\n",
"╞═══════════╪═══════════╪═══════════╪═══════════╪═══╪═══════════╪═══════════╪═══════════╪══════════╡\n",
"│ 000001.SZ ┆ 20250102 ┆ 0.9377 ┆ 1.38 ┆ … ┆ 1791.1304 ┆ 6183.5904 ┆ 2158.1117 ┆ -0.00262 │\n",
"│ ┆ ┆ ┆ ┆ ┆ 08 ┆ 38 ┆ 45 ┆ 2 │\n",
"│ 000002.SZ ┆ 20250102 ┆ 1.2171 ┆ 1.06 ┆ … ┆ -1933.116 ┆ -1110.658 ┆ -1729.069 ┆ -0.02250 │\n",
"│ ┆ ┆ ┆ ┆ ┆ 105 ┆ 303 ┆ 737 ┆ 9 │\n",
"│ 000004.SZ ┆ 20250102 ┆ 9.4831 ┆ 0.8 ┆ … ┆ -199.1144 ┆ -126.8907 ┆ -197.3308 ┆ -0.06489 │\n",
"│ ┆ ┆ ┆ ┆ ┆ 31 ┆ 63 ┆ 47 ┆ 7 │\n",
"│ 000006.SZ ┆ 20250102 ┆ 2.2755 ┆ 0.79 ┆ … ┆ -646.1294 ┆ -325.4842 ┆ -637.5489 ┆ -0.04827 │\n",
"│ ┆ ┆ ┆ ┆ ┆ 33 ┆ 66 ┆ 17 ┆ 8 │\n",
"│ 000007.SZ ┆ 20250102 ┆ 1.9691 ┆ 1.05 ┆ … ┆ 6.740918 ┆ 108.91759 ┆ 22.556002 ┆ 0.015649 │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ 8 ┆ ┆ │\n",
"└───────────┴───────────┴───────────┴───────────┴───┴───────────┴───────────┴───────────┴──────────┘\n"
]
}
],
"execution_count": 7
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T12:42:20.933369Z",
"start_time": "2026-03-08T12:42:19.615723Z"
}
},
"source": [
"# 步骤 3: 训练集数据处理\n",
"print(\"\\n[步骤 3/6] 训练集数据处理\")\n",
"print(\"-\" * 60)\n",
"fitted_processors = []\n",
"if processors:\n",
" for i, processor in enumerate(processors, 1):\n",
" print(\n",
" f\" [{i}/{len(processors)}] 应用处理器: {processor.__class__.__name__}\"\n",
" )\n",
" train_data_before = len(train_data)\n",
" train_data = processor.fit_transform(train_data)\n",
" train_data_after = len(train_data)\n",
" fitted_processors.append(processor)\n",
" print(f\" 处理前记录数: {train_data_before}\")\n",
" print(f\" 处理后记录数: {train_data_after}\")\n",
" if train_data_before != train_data_after:\n",
" print(f\" 删除记录数: {train_data_before - train_data_after}\")\n",
"\n",
"print(\"\\n 训练集处理后前5行预览:\")\n",
"print(train_data.head())\n",
"print(f\"\\n 训练集特征统计:\")\n",
"print(f\" 特征数: {len(feature_cols)}\")\n",
"print(f\" 样本数: {len(train_data)}\")\n",
"print(f\" 缺失值统计:\")\n",
"for col in feature_cols[:5]: # 只显示前5个特征的缺失值\n",
" null_count = train_data[col].null_count()\n",
" if null_count > 0:\n",
" print(\n",
" f\" {col}: {null_count} ({null_count / len(train_data) * 100:.2f}%)\"\n",
" )"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"[步骤 3/6] 训练集数据处理\n",
"------------------------------------------------------------\n",
" [1/3] 应用处理器: NullFiller\n",
" 处理前记录数: 2991506\n",
" 处理后记录数: 2991506\n",
" [2/3] 应用处理器: Winsorizer\n",
" 处理前记录数: 2991506\n",
" 处理后记录数: 2991506\n",
" [3/3] 应用处理器: StandardScaler\n",
" 处理前记录数: 2991506\n",
" 处理后记录数: 2991506\n",
"\n",
" 训练集处理后前5行预览:\n",
"shape: (5, 53)\n",
"┌───────────┬───────────┬───────────┬───────────┬───┬───────────┬───────────┬───────────┬──────────┐\n",
"│ ts_code ┆ trade_dat ┆ turnover_ ┆ volume_ra ┆ … ┆ profit_to ┆ cashflow_ ┆ operate_p ┆ future_r │\n",
"│ --- ┆ e ┆ rate ┆ tio ┆ ┆ _market_c ┆ to_market ┆ rofit_to_ ┆ eturn_5 │\n",
"│ str ┆ --- ┆ --- ┆ --- ┆ ┆ ap ┆ _cap ┆ market_ca ┆ --- │\n",
"│ ┆ str ┆ f64 ┆ f64 ┆ ┆ --- ┆ --- ┆ p ┆ f64 │\n",
"│ ┆ ┆ ┆ ┆ ┆ f64 ┆ f64 ┆ --- ┆ │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ f64 ┆ │\n",
"╞═══════════╪═══════════╪═══════════╪═══════════╪═══╪═══════════╪═══════════╪═══════════╪══════════╡\n",
"│ 000001.SZ ┆ 20200102 ┆ -0.492311 ┆ 2.080178 ┆ … ┆ 1.441327 ┆ 2.715295 ┆ 1.645238 ┆ -0.00474 │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 6 │\n",
"│ 000002.SZ ┆ 20200102 ┆ -0.40826 ┆ 0.478477 ┆ … ┆ 1.5879 ┆ -0.110121 ┆ 2.111027 ┆ -0.01105 │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 7 │\n",
"│ 000004.SZ ┆ 20200102 ┆ -0.036785 ┆ -0.239526 ┆ … ┆ -0.651808 ┆ -0.221369 ┆ -0.574193 ┆ -0.00044 │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 1 │\n",
"│ 000005.SZ ┆ 20200102 ┆ -0.42734 ┆ 0.552119 ┆ … ┆ -0.090517 ┆ 0.267338 ┆ -0.037305 ┆ 0.022337 │\n",
"│ 000006.SZ ┆ 20200102 ┆ -0.446951 ┆ 1.049198 ┆ … ┆ 1.207844 ┆ 0.563309 ┆ 1.370863 ┆ 0.012964 │\n",
"└───────────┴───────────┴───────────┴───────────┴───┴───────────┴───────────┴───────────┴──────────┘\n",
"\n",
" 训练集特征统计:\n",
" 特征数: 34\n",
" 样本数: 2991506\n",
" 缺失值统计:\n",
" ma5: 11541 (0.39%)\n",
" ma10: 25950 (0.87%)\n",
" ma20: 54850 (1.83%)\n",
" ma_ratio: 54850 (1.83%)\n",
" volatility_5: 11541 (0.39%)\n"
]
}
],
"execution_count": 8
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T12:42:38.165178Z",
"start_time": "2026-03-08T12:42:20.939484Z"
}
},
"source": [
"# 步骤 4: 训练模型\n",
"print(\"\\n[步骤 4/6] 训练模型\")\n",
"print(\"-\" * 60)\n",
"print(f\" 模型类型: LightGBM\")\n",
"print(f\" 训练样本数: {len(train_data)}\")\n",
"print(f\" 特征数: {len(feature_cols)}\")\n",
"print(f\" 目标变量: {target_col}\")\n",
"\n",
"X_train = train_data.select(feature_cols)\n",
"y_train = train_data.select(target_col).to_series()\n",
"\n",
"print(f\"\\n 目标变量统计:\")\n",
"print(f\" 均值: {y_train.mean():.6f}\")\n",
"print(f\" 标准差: {y_train.std():.6f}\")\n",
"print(f\" 最小值: {y_train.min():.6f}\")\n",
"print(f\" 最大值: {y_train.max():.6f}\")\n",
"print(f\" 缺失值: {y_train.null_count()}\")\n",
"\n",
"print(\"\\n 开始训练...\")\n",
"model.fit(X_train, y_train)\n",
"print(\" 训练完成!\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"[步骤 4/6] 训练模型\n",
"------------------------------------------------------------\n",
" 模型类型: LightGBM\n",
" 训练样本数: 2991506\n",
" 特征数: 34\n",
" 目标变量: future_return_5\n",
"\n",
" 目标变量统计:\n",
" 均值: 0.001610\n",
" 标准差: 0.059623\n",
" 最小值: -0.155098\n",
" 最大值: 0.212842\n",
" 缺失值: 0\n",
"\n",
" 开始训练...\n",
" 训练完成!\n"
]
}
],
"execution_count": 9
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T12:42:38.355756Z",
"start_time": "2026-03-08T12:42:38.169740Z"
}
},
"source": [
"# 步骤 5: 测试集数据处理\n",
"print(\"\\n[步骤 5/6] 测试集数据处理\")\n",
"print(\"-\" * 60)\n",
"if processors and test_data is not train_data:\n",
" for i, processor in enumerate(fitted_processors, 1):\n",
" print(\n",
" f\" [{i}/{len(fitted_processors)}] 应用处理器: {processor.__class__.__name__}\"\n",
" )\n",
" test_data_before = len(test_data)\n",
" test_data = processor.transform(test_data)\n",
" test_data_after = len(test_data)\n",
" print(f\" 处理前记录数: {test_data_before}\")\n",
" print(f\" 处理后记录数: {test_data_after}\")\n",
"else:\n",
" print(\" 跳过测试集处理\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"[步骤 5/6] 测试集数据处理\n",
"------------------------------------------------------------\n",
" [1/3] 应用处理器: NullFiller\n",
" 处理前记录数: 771207\n",
" 处理后记录数: 771207\n",
" [2/3] 应用处理器: Winsorizer\n",
" 处理前记录数: 771207\n",
" 处理后记录数: 771207\n",
" [3/3] 应用处理器: StandardScaler\n",
" 处理前记录数: 771207\n",
" 处理后记录数: 771207\n"
]
}
],
"execution_count": 10
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T12:42:39.677665Z",
"start_time": "2026-03-08T12:42:38.364128Z"
}
},
"source": [
"# 步骤 6: 生成预测\n",
"print(\"\\n[步骤 6/6] 生成预测\")\n",
"print(\"-\" * 60)\n",
"X_test = test_data.select(feature_cols)\n",
"print(f\" 测试样本数: {len(X_test)}\")\n",
"print(\" 预测中...\")\n",
"predictions = model.predict(X_test)\n",
"print(f\" 预测完成!\")\n",
"\n",
"print(f\"\\n 预测结果统计:\")\n",
"print(f\" 均值: {predictions.mean():.6f}\")\n",
"print(f\" 标准差: {predictions.std():.6f}\")\n",
"print(f\" 最小值: {predictions.min():.6f}\")\n",
"print(f\" 最大值: {predictions.max():.6f}\")\n",
"\n",
"# 保存结果到 trainer\n",
"trainer.results = test_data.with_columns([pl.Series(\"prediction\", predictions)])"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"[步骤 6/6] 生成预测\n",
"------------------------------------------------------------\n",
" 测试样本数: 771207\n",
" 预测中...\n",
" 预测完成!\n",
"\n",
" 预测结果统计:\n",
" 均值: -0.000501\n",
" 标准差: 0.008088\n",
" 最小值: -0.154524\n",
" 最大值: 0.096327\n"
]
}
],
"execution_count": 11
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4.3 训练指标曲线"
]
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T12:42:42.772470Z",
"start_time": "2026-03-08T12:42:39.683522Z"
}
},
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"训练指标曲线\")\n",
"print(\"=\" * 80)\n",
"\n",
"# 重新训练以收集指标(因为之前的训练没有保存评估结果)\n",
"print(\"\\n重新训练模型以收集训练指标...\")\n",
"\n",
"import lightgbm as lgb\n",
"\n",
"# 准备数据(使用 val 做验证test 不参与训练过程)\n",
"X_train_np = X_train.to_numpy()\n",
"y_train_np = y_train.to_numpy()\n",
"X_val_np = val_data.select(feature_cols).to_numpy()\n",
"y_val_np = val_data.select(target_col).to_series().to_numpy()\n",
"\n",
"# 创建数据集\n",
"train_dataset = lgb.Dataset(X_train_np, label=y_train_np)\n",
"val_dataset = lgb.Dataset(X_val_np, label=y_val_np, reference=train_dataset)\n",
"\n",
"# 用于存储评估结果\n",
"evals_result = {}\n",
"\n",
"# 使用与原模型相同的参数重新训练\n",
"# 正确的三分法train用于训练val用于验证test不参与训练过程\n",
"# 添加早停如果验证指标连续100轮没有改善则停止训练\n",
"booster_with_eval = lgb.train(\n",
" MODEL_PARAMS,\n",
" train_dataset,\n",
" num_boost_round=MODEL_PARAMS.get(\"n_estimators\", 100),\n",
" valid_sets=[train_dataset, val_dataset],\n",
" valid_names=[\"train\", \"val\"],\n",
" callbacks=[\n",
" lgb.record_evaluation(evals_result),\n",
" lgb.early_stopping(stopping_rounds=100, verbose=True),\n",
" ],\n",
")\n",
"\n",
"print(\"训练完成,指标已收集\")\n",
"\n",
"# 获取指标名称\n",
"metric_name = list(evals_result[\"train\"].keys())[0]\n",
"print(f\"\\n评估指标: {metric_name}\")\n",
"\n",
"# 提取训练和验证指标\n",
"train_metric = evals_result[\"train\"][metric_name]\n",
"val_metric = evals_result[\"val\"][metric_name]\n",
"\n",
"# 显示早停信息\n",
"actual_rounds = len(train_metric)\n",
"expected_rounds = MODEL_PARAMS.get(\"n_estimators\", 100)\n",
"print(f\"\\n[早停信息]\")\n",
"print(f\" 配置的最大轮数: {expected_rounds}\")\n",
"print(f\" 实际训练轮数: {actual_rounds}\")\n",
"if actual_rounds < expected_rounds:\n",
" print(f\" 早停状态: 已触发连续100轮验证指标未改善\")\n",
"else:\n",
" print(f\" 早停状态: 未触发(达到最大轮数)\")\n",
"\n",
"print(f\"\\n最终指标:\")\n",
"print(f\" 训练 {metric_name}: {train_metric[-1]:.6f}\")\n",
"print(f\" 验证 {metric_name}: {val_metric[-1]:.6f}\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"================================================================================\n",
"训练指标曲线\n",
"================================================================================\n",
"\n",
"重新训练模型以收集训练指标...\n",
"Training until validation scores don't improve for 100 rounds\n",
"Early stopping, best iteration is:\n",
"[31]\ttrain's l1: 0.0428788\tval's l1: 0.0539432\n",
"训练完成,指标已收集\n",
"\n",
"评估指标: l1\n",
"\n",
"[早停信息]\n",
" 配置的最大轮数: 1000\n",
" 实际训练轮数: 131\n",
" 早停状态: 已触发连续100轮验证指标未改善\n",
"\n",
"最终指标:\n",
" 训练 l1: 0.042513\n",
" 验证 l1: 0.053961\n"
]
}
],
"execution_count": 12
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T12:42:43.013683Z",
"start_time": "2026-03-08T12:42:42.777555Z"
}
},
"source": [
"# 绘制训练指标曲线\n",
"import matplotlib.pyplot as plt\n",
"\n",
"fig, ax = plt.subplots(figsize=(12, 6))\n",
"\n",
"# 绘制训练集和验证集的指标曲线注意val用于验证test不参与训练\n",
"iterations = range(1, len(train_metric) + 1)\n",
"ax.plot(iterations, train_metric, label=f\"Train {metric_name}\", linewidth=2, color=\"blue\")\n",
"ax.plot(iterations, val_metric, label=f\"Validation {metric_name}\", linewidth=2, color=\"red\")\n",
"\n",
"ax.set_xlabel(\"Iteration\", fontsize=12)\n",
"ax.set_ylabel(metric_name.upper(), fontsize=12)\n",
"ax.set_title(f\"Training and Validation {metric_name.upper()} Curve\", fontsize=14, fontweight=\"bold\")\n",
"ax.legend(fontsize=10)\n",
"ax.grid(True, alpha=0.3)\n",
"\n",
"# 标记最佳验证指标点(用于早停决策)\n",
"best_iter = val_metric.index(min(val_metric))\n",
"best_metric = min(val_metric)\n",
"ax.axvline(x=best_iter + 1, color=\"green\", linestyle=\"--\", alpha=0.7, label=f\"Best Iteration ({best_iter + 1})\")\n",
"ax.scatter([best_iter + 1], [best_metric], color=\"green\", s=100, zorder=5)\n",
"ax.annotate(\n",
" f\"Best: {best_metric:.6f}\\nIter: {best_iter + 1}\",\n",
" xy=(best_iter + 1, best_metric),\n",
" xytext=(best_iter + 1 + len(iterations) * 0.1, best_metric),\n",
" fontsize=9,\n",
" arrowprops=dict(arrowstyle=\"->\", color=\"green\", alpha=0.7),\n",
")\n",
"\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"print(f\"\\n[指标分析]\")\n",
"print(f\" 最佳验证 {metric_name}: {best_metric:.6f}\")\n",
"print(f\" 最佳迭代轮数: {best_iter + 1}\")\n",
"print(f\" 早停建议: 如果验证指标连续10轮不下降建议在第 {best_iter + 1} 轮停止训练\")\n",
"print(f\"\\n[重要提醒] 验证集仅用于早停/调参,测试集完全独立于训练过程!\")"
],
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 1200x600 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAABKUAAAJOCAYAAABm7rQwAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAdsRJREFUeJzt3Ql8E2X+x/Ff2vQAuY9y36DIIQgIggcqCCirIoiKKIcsnlziCSLIoqJ/BUFBkV1FdxcEu4onogjqqqCAoIgr3ohyFVBueuf/+j0wMUmTNm3TyTT9vH2NmcxMkknydEi++T3PuDwej0cAAAAAAAAAG8XZ+WAAAAAAAACAIpQCAAAAAACA7QilAAAAAAAAYDtCKQAAAAAAANiOUAoAAAAAAAC2I5QCAAAAAACA7QilAAAAAAAAYDtCKQAAAAAAANiOUAoAAAAAAAC2I5QCAMCBnn/+eXG5XN4pEho3buy9v/vvvz8i9xmr9PWxXit93ZzIt31oeylu24nGcy6Jdg4AAEoPQikAAIKENuFOH3zwAa8f5OGHH/ZrF2vXrg35qgwfPty7XWJiouzZsycmX8FYCZx8wzqdtm7dWuBtVqxYIWPHjpVu3bpJ+fLlC337UDwejyxbtkyuu+46Ofnkk6VSpUqSkJAgtWrVkh49esgjjzwiO3fuLPL9AwBgN7ftjwgAAAp0xhlnyKOPPhrRV+ree++VAwcOmHn9sozI0ZBAX9/c3Fxz/V//+pd07tw5z3bHjh2Tl19+2Xu9b9++UrNmTce3nZJSmva1MObOnSuvvfZaRO/z119/lWuuuUY+/vjjPOvS0tJk1apVZvrmm2/8KucAAHAyQikAAIKENuqPP/6Qhx56yHv9wgsvlF69evm9Xs2aNQv5+h08eNBUMhRF69atzRRJI0eOjOj94U/16tUz7eOdd94x1xcvXiwzZ840VSy+li5dKocOHfJeHzZsWMRfxpJoOyWlNO1rYWhFVP369aVTp06Sk5Mjb7zxRrHub/fu3dK9e3f5+eefvcuaNGkil156qamS0mPVp59+GjSwijR9PhkZGaYCDACA4qL7HgAAPqHNHXfc4Z0CQxytLvJdf8UVV0jDhg39uvI9++yz0qFDBylXrpyce+655nb6RXLcuHFyzjnnSIMGDeSkk06SpKQkE2RccsklQb+w5tf16bzzzvMu11Dj+++/l0GDBkmNGjUkOTnZPH6wKo1QY0rpfvs+1k8//SRPPfWUnHbaaeb+UlJS5K9//av54hvo6NGjMmHCBPM66LYaMMybN88856J0c9TtRowYYZ5DnTp1zOukX36bN29uur199dVXeW6jr4H1OPraaPelG264wXv7U089Vf7+978HfTy9v7/85S8mPNSpT58+smHDBikK3T/L3r175e23386zjVZQWfR11UoppdVC/fr1M12yqlWrZsKsKlWqmGqrBx98UI4cORKxbnNFec4apmk1mLYJDUG022GFChWkVatWMmrUKL8uaTqvj+v7eijffbLaX0H7qpVljz/+uJx11llStWpV87j6+BdffLG89NJLebYvTluOpEWLFpnKJn3d+vfvX+z70+OHbyB18803y3fffSezZs0yf3//93//J//973/l22+/NeFoqL+N/F4r3/cw8Hbbtm0z77++9to29TUt6O+7S5cu3vWBx9Ivv/xSrr/+ehPq67FS29Lpp59ufgQoTFsHAMQADwAACOrnn3/26D+V1jRlypR8159zzjl+19u1a2e2e+ONN/yWB5umTp3qd98LFizwW++re/fu3uWnnXaap2LFinnuz+Vyed577z2/2zVq1Cjoc3n//ff9bnv22WcH3cdzzz3X7/4yMzPzPGdruuSSS/yu62OE4/bbb8/3dUpMTPSsWLHC7zZDhw71rm/atKmnTp06QW/77LPP+t1u3bp1ngoVKuTZLjk52dOjRw/vdX3dwpGenu6pUqWK93ZXXHGF3/qdO3d64uPjvetvu+0277rq1avn+7zbtm3rOXTokN/9+a7X9hJO2ynqcx4wYEC++1epUiXPpk2bgv5dBJus9pffvurr1bp163zvR/crKyur2G05P7qvvrfV51cYgc+xsLffsWOH+Xu2bt++fXtPTk5OWLf1/dvQ44avwNfKd798b9eiRQtP7dq1/bZdunSp39/+DTfc4HffP/zwg9/2q1ev9q576qmnPG63O+R72qpVK/PeAwDKBrrvAQAQIR999JE0atRIBgwYYKp7dJwX5Xa7pX379qYrj44fpNUpWg3wySefyPvvv2+2mTZtmqkQ0uqpwti0aZOpILnttttMVYlWBGn3Gs0stPpGBz8uLO0CpLfTyrBXX33VW52klRjaRejMM88012fPnm2es0WrUS677DJTBfH6669LUWgVmXZTatu2rakY0iqKffv2yVtvvWXGysnMzJQxY8bI//73v6C318oYrYjRShK97dNPP21eF6XVJFqdofT10fnDhw+b61rNoeP1aDWZjvm0cuXKQu+7VmVdffXVplJMaQXc/v37TcWTVT2j702wrnva1ev888837UffT90/rYxZsmSJaSv6Hmh1yl133SVFVZznrM9Bu65q1ZlVsaRdyrQSSKtotKvq3XffbQbh1vdN29769evN/lt8x44KZ0yzwYMHy9dff+29rpWJWpmlg4ivWbPGLNP91uqayZMnF6stO5keI45nkMcNHTpU4uLs6+yglZhKK77atWsnv/zyi1SuXNlUwll////5z39kzpw53u6qL774ovf2LVu2lK5du5r51atXm8o6a+w1ff21Uk+7tL7wwgumwlD/tocMGSLvvvuubc8RABA9hFIAAESIjvGi3aCsEMKiX7p00u42GzduNGdb0y9v2gXps88+M13gsrOzzSDF2kWmMDRY0DBBu74oDWS0S49at25dkZ7H5Zdfbr7s631rtyHt8mSFKXqf1hf5f/zjH97baLChX/I1CLICF/2SWVhTp041X1g10NAQSkMd7TJ00UUXmetKL7VrlHaFDEbHc9JwTGm3Qn0OSrs26ZffihUrmtfdtyugjiemwaDS4Ee7FekX5MLSL+pWKKXj7mgXM+1KGNh1T98vDfEsX3zxhRnPTL+0a8ijQZQGQB07djQBitLxqooTShXnOet7nZWVZd5jDSk0hNIgTQOfBQsWmG20/eo2Grpq91btmucbSumycOnrofdn0f3TM8spDaC0K6wVTGk4OmnSpKBBTbht2cm2b9/ud11DHrvpMUXPJuhL26gGxBpy/v7776Z9arfQwFDKtxvnY4895g2ktFugHrus9+2qq67ynhxAg0cN3H3/RgAAsYlQCgCACLn11lvzBFJKx2rRqg8NHPLz22+/FfoxtQLBCqTUKaec4p0v6rg5WmVkje+jVS86VpVWxfjep34R1ZDHMnDgQG8gZX0RLUoopV9GdcwfDWYKeq2ChVJ169b1BlKBr4e1/xpKaejlS98fi4YqOtaXFbYUhn6p1moeq5JLgygNpTZv3myClmBf1PVL+j333GPCFa0Ey+85F0dxnvPChQtNqJNfUKchnK7XsbyKywqcfKuDLPHx8XLttdd6t9FARNuihnhFacvIn1bG6bEtWFWj/t1bbUaDKA2lNEyy2r++V75Bu1aHWnQcKl0fih4vCaUAIPYx0DkAABESqoJBB7AuKJCyvtQXllYoBXYhs/h2+YnUfVpVDlrB5Kt27dr5Xg/Hjh07zGtVUCCV32uV377nt/9aQeNLq7OKyjdA0S/h2g3vn//8p3eZdn3TbnOWJ554wnRtyy+QKmr78FXU56zVf9qdKpzKseLuo0WDpvz2LfB6qIApnLbsdIFderds2VKk+wk8HoT7XmkFnXZBDsbqDqv05Apa9andVC1a4egbUga+r/nRilIAQOwjlAIAIEK0ciCQVnDoGEsWDSO04kW/EOuXRB1jqjisMVwswc5gVhL3qWPK+LLGz7Ls2rWr0I+rYzDpl1rLjBkzTJCir5Pv2EKReD0CK9oC99+qpikKrQyxKkB037Ubm+8Xda0mqV69uve6bxc3rfTSbnYaGOht77zzTomUoj7n1NR
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"[指标分析]\n",
" 最佳验证 l1: 0.053943\n",
" 最佳迭代轮数: 31\n",
" 早停建议: 如果验证指标连续10轮不下降建议在第 31 轮停止训练\n",
"\n",
"[重要提醒] 验证集仅用于早停/调参,测试集完全独立于训练过程!\n"
]
}
],
"execution_count": 13
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4.4 查看结果"
]
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T12:42:43.064939Z",
"start_time": "2026-03-08T12:42:43.030137Z"
}
},
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"训练结果\")\n",
"print(\"=\" * 80)\n",
"\n",
"results = trainer.results\n",
"\n",
"print(f\"\\n结果数据形状: {results.shape}\")\n",
"print(f\"结果列: {results.columns}\")\n",
"print(f\"\\n结果前10行预览:\")\n",
"print(results.head(10))\n",
"print(f\"\\n结果后5行预览:\")\n",
"print(results.tail())\n",
"\n",
"print(f\"\\n每日预测样本数统计:\")\n",
"daily_counts = results.group_by(\"trade_date\").agg(pl.len()).sort(\"trade_date\")\n",
"print(f\" 最小: {daily_counts['len'].min()}\")\n",
"print(f\" 最大: {daily_counts['len'].max()}\")\n",
"print(f\" 平均: {daily_counts['len'].mean():.2f}\")\n",
"\n",
"# 展示某一天的前10个预测结果\n",
"sample_date = results[\"trade_date\"][0]\n",
"sample_data = results.filter(results[\"trade_date\"] == sample_date).head(10)\n",
"print(f\"\\n示例日期 {sample_date} 的前10条预测:\")\n",
"print(sample_data.select([\"ts_code\", \"trade_date\", target_col, \"prediction\"]))"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"================================================================================\n",
"训练结果\n",
"================================================================================\n",
"\n",
"结果数据形状: (771207, 54)\n",
"结果列: ['ts_code', 'trade_date', 'turnover_rate', 'volume_ratio', 'high', 'vol', 'close', 'low', 'total_mv', 'f_ann_date', 'total_profit', 'operate_profit', 'ebit', 'n_income', 'ebitda', 'money_cap', 'total_liab', 'n_cashflow_act', 'ma5', 'ma10', 'ma20', 'ma_ratio', 'volatility_5', 'volatility_20', 'vol_ratio', 'return_10', 'return_20', 'return_diff', 'vol_ma5', 'vol_ma20', 'market_cap_rank', 'high_low_ratio', 'turnover_rate_mean_5', 'bbi_ratio_factor', 'volume_change_rate', 'turnover_deviation', 'vol_std_5', 'std_return_5', 'std_return_90', 'cs_rank_volume_ratio', 'cs_rank_turnover_rate', 'n_income_rank', 'operate_profit_rank', 'total_profit_rank', 'ebit_rank', 'ebitda_rank', 'total_liab_rank', 'money_cap_rank', 'n_cashflow_act_rank', 'profit_to_market_cap', 'cashflow_to_market_cap', 'operate_profit_to_market_cap', 'future_return_5', 'prediction']\n",
"\n",
"结果前10行预览:\n",
"shape: (10, 54)\n",
"┌───────────┬───────────┬───────────┬───────────┬───┬───────────┬───────────┬───────────┬──────────┐\n",
"│ ts_code ┆ trade_dat ┆ turnover_ ┆ volume_ra ┆ … ┆ cashflow_ ┆ operate_p ┆ future_re ┆ predicti │\n",
"│ --- ┆ e ┆ rate ┆ tio ┆ ┆ to_market ┆ rofit_to_ ┆ turn_5 ┆ on │\n",
"│ str ┆ --- ┆ --- ┆ --- ┆ ┆ _cap ┆ market_ca ┆ --- ┆ --- │\n",
"│ ┆ str ┆ f64 ┆ f64 ┆ ┆ --- ┆ p ┆ f64 ┆ f64 │\n",
"│ ┆ ┆ ┆ ┆ ┆ f64 ┆ --- ┆ ┆ │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ f64 ┆ ┆ │\n",
"╞═══════════╪═══════════╪═══════════╪═══════════╪═══╪═══════════╪═══════════╪═══════════╪══════════╡\n",
"│ 000001.SZ ┆ 20250102 ┆ -0.442803 ┆ 0.60735 ┆ … ┆ 4.430923 ┆ 3.905767 ┆ -0.002622 ┆ -0.01380 │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 2 │\n",
"│ 000002.SZ ┆ 20250102 ┆ -0.350092 ┆ 0.018219 ┆ … ┆ -1.401378 ┆ -3.606928 ┆ -0.022509 ┆ -0.01000 │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 2 │\n",
"│ 000004.SZ ┆ 20250102 ┆ 2.392751 ┆ -0.46045 ┆ … ┆ -0.304204 ┆ -0.971788 ┆ -0.064897 ┆ -0.00486 │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 4 │\n",
"│ 000006.SZ ┆ 20250102 ┆ 0.001109 ┆ -0.478861 ┆ … ┆ -0.525691 ┆ -1.986389 ┆ -0.048278 ┆ -0.00030 │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 9 │\n",
"│ 000007.SZ ┆ 20250102 ┆ -0.100561 ┆ -0.000192 ┆ … ┆ -0.041212 ┆ -0.464999 ┆ 0.015649 ┆ 0.000112 │\n",
"│ 000008.SZ ┆ 20250102 ┆ 0.59965 ┆ -0.810247 ┆ … ┆ -0.360266 ┆ -1.189503 ┆ -0.066939 ┆ -0.01029 │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 8 │\n",
"│ 000009.SZ ┆ 20250102 ┆ -0.443002 ┆ 1.178071 ┆ … ┆ 0.296618 ┆ 0.676674 ┆ -0.036045 ┆ -0.00890 │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 9 │\n",
"│ 000010.SZ ┆ 20250102 ┆ 1.435875 ┆ 0.478477 ┆ … ┆ -0.214012 ┆ -1.156491 ┆ 0.092123 ┆ -0.00856 │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 6 │\n",
"│ 000011.SZ ┆ 20250102 ┆ -0.487201 ┆ -0.073833 ┆ … ┆ -1.762282 ┆ -0.517669 ┆ -0.022094 ┆ -0.00446 │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 2 │\n",
"│ 000012.SZ ┆ 20250102 ┆ -0.47635 ┆ 0.496888 ┆ … ┆ 0.793245 ┆ 0.662619 ┆ -0.029188 ┆ -0.01157 │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 2 │\n",
"└───────────┴───────────┴───────────┴───────────┴───┴───────────┴───────────┴───────────┴──────────┘\n",
"\n",
"结果后5行预览:\n",
"shape: (5, 54)\n",
"┌───────────┬───────────┬───────────┬───────────┬───┬───────────┬───────────┬───────────┬──────────┐\n",
"│ ts_code ┆ trade_dat ┆ turnover_ ┆ volume_ra ┆ … ┆ cashflow_ ┆ operate_p ┆ future_re ┆ predicti │\n",
"│ --- ┆ e ┆ rate ┆ tio ┆ ┆ to_market ┆ rofit_to_ ┆ turn_5 ┆ on │\n",
"│ str ┆ --- ┆ --- ┆ --- ┆ ┆ _cap ┆ market_ca ┆ --- ┆ --- │\n",
"│ ┆ str ┆ f64 ┆ f64 ┆ ┆ --- ┆ p ┆ f64 ┆ f64 │\n",
"│ ┆ ┆ ┆ ┆ ┆ f64 ┆ --- ┆ ┆ │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ f64 ┆ ┆ │\n",
"╞═══════════╪═══════════╪═══════════╪═══════════╪═══╪═══════════╪═══════════╪═══════════╪══════════╡\n",
"│ 605588.SH ┆ 20251231 ┆ -0.287278 ┆ -0.478861 ┆ … ┆ 0.245664 ┆ -0.750859 ┆ null ┆ -0.00085 │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 3 │\n",
"│ 605589.SH ┆ 20251231 ┆ -0.211755 ┆ -0.828658 ┆ … ┆ -0.316701 ┆ 0.328847 ┆ null ┆ -0.00538 │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 4 │\n",
"│ 605598.SH ┆ 20251231 ┆ 0.754611 ┆ -0.552502 ┆ … ┆ -0.157841 ┆ -0.386306 ┆ null ┆ -0.00914 │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 6 │\n",
"│ 605599.SH ┆ 20251231 ┆ -0.565809 ┆ 0.404836 ┆ … ┆ 1.377021 ┆ 1.078554 ┆ null ┆ -0.00074 │\n",
"│ 689009.SH ┆ 20251231 ┆ -0.486869 ┆ -0.570913 ┆ … ┆ 1.185366 ┆ 0.848817 ┆ null ┆ -0.00457 │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 1 │\n",
"└───────────┴───────────┴───────────┴───────────┴───┴───────────┴───────────┴───────────┴──────────┘\n",
"\n",
"每日预测样本数统计:\n",
" 最小: 3147\n",
" 最大: 3186\n",
" 平均: 3173.69\n",
"\n",
"示例日期 20250102 的前10条预测:\n",
"shape: (10, 4)\n",
"┌───────────┬────────────┬─────────────────┬────────────┐\n",
"│ ts_code ┆ trade_date ┆ future_return_5 ┆ prediction │\n",
"│ --- ┆ --- ┆ --- ┆ --- │\n",
"│ str ┆ str ┆ f64 ┆ f64 │\n",
"╞═══════════╪════════════╪═════════════════╪════════════╡\n",
"│ 000001.SZ ┆ 20250102 ┆ -0.002622 ┆ -0.013802 │\n",
"│ 000002.SZ ┆ 20250102 ┆ -0.022509 ┆ -0.010002 │\n",
"│ 000004.SZ ┆ 20250102 ┆ -0.064897 ┆ -0.004864 │\n",
"│ 000006.SZ ┆ 20250102 ┆ -0.048278 ┆ -0.000309 │\n",
"│ 000007.SZ ┆ 20250102 ┆ 0.015649 ┆ 0.000112 │\n",
"│ 000008.SZ ┆ 20250102 ┆ -0.066939 ┆ -0.010298 │\n",
"│ 000009.SZ ┆ 20250102 ┆ -0.036045 ┆ -0.008909 │\n",
"│ 000010.SZ ┆ 20250102 ┆ 0.092123 ┆ -0.008566 │\n",
"│ 000011.SZ ┆ 20250102 ┆ -0.022094 ┆ -0.004462 │\n",
"│ 000012.SZ ┆ 20250102 ┆ -0.029188 ┆ -0.011572 │\n",
"└───────────┴────────────┴─────────────────┴────────────┘\n"
]
}
],
"execution_count": 14
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4.4 保存结果"
]
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T12:42:43.409508Z",
"start_time": "2026-03-08T12:42:43.073624Z"
}
},
"cell_type": "code",
"source": [
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"保存预测结果\")\n",
"print(\"=\" * 80)\n",
"\n",
"# 确保输出目录存在\n",
"os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
"\n",
"# 生成时间戳\n",
"start_dt = datetime.strptime(TEST_START, \"%Y%m%d\")\n",
"end_dt = datetime.strptime(TEST_END, \"%Y%m%d\")\n",
"date_str = f\"{start_dt.strftime('%Y%m%d')}_{end_dt.strftime('%Y%m%d')}\"\n",
"\n",
"# 保存每日 Top N\n",
"print(f\"\\n[1/1] 保存每日 Top {TOP_N} 股票...\")\n",
"topn_output_path = os.path.join(OUTPUT_DIR, f\"regression_output.csv\")\n",
"\n",
"# 按日期分组,取每日 top N\n",
"topn_by_date = []\n",
"unique_dates = results[\"trade_date\"].unique().sort()\n",
"for date in unique_dates:\n",
" day_data = results.filter(results[\"trade_date\"] == date)\n",
" # 按 prediction 降序排序,取前 N\n",
" topn = day_data.sort(\"prediction\", descending=True).head(TOP_N)\n",
" topn_by_date.append(topn)\n",
"\n",
"# 合并所有日期的 top N\n",
"topn_results = pl.concat(topn_by_date)\n",
"\n",
"# 格式化日期并调整列顺序:日期、分数、股票\n",
"topn_to_save = topn_results.select(\n",
" [\n",
" pl.col(\"trade_date\").str.slice(0, 4)\n",
" + \"-\"\n",
" + pl.col(\"trade_date\").str.slice(4, 2)\n",
" + \"-\"\n",
" + pl.col(\"trade_date\").str.slice(6, 2).alias(\"date\"),\n",
" pl.col(\"prediction\").alias(\"score\"),\n",
" pl.col(\"ts_code\"),\n",
" ]\n",
")\n",
"topn_to_save.write_csv(topn_output_path, include_header=True)\n",
"print(f\" 保存路径: {topn_output_path}\")\n",
"print(f\" 保存行数: {len(topn_to_save)}{len(unique_dates)}个交易日 × 每日top{TOP_N}\")\n",
"print(f\"\\n 预览前15行:\")\n",
"print(topn_to_save.head(15))"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"================================================================================\n",
"保存预测结果\n",
"================================================================================\n",
"\n",
"[1/1] 保存每日 Top 5 股票...\n",
" 保存路径: output\\regression_output.csv\n",
" 保存行数: 1215243个交易日 × 每日top5\n",
"\n",
" 预览前15行:\n",
"shape: (15, 3)\n",
"┌────────────┬──────────┬───────────┐\n",
"│ trade_date ┆ score ┆ ts_code │\n",
"│ --- ┆ --- ┆ --- │\n",
"│ str ┆ f64 ┆ str │\n",
"╞════════════╪══════════╪═══════════╡\n",
"│ 2025-01-02 ┆ 0.086703 ┆ 603007.SH │\n",
"│ 2025-01-02 ┆ 0.073642 ┆ 603559.SH │\n",
"│ 2025-01-02 ┆ 0.047455 ┆ 603959.SH │\n",
"│ 2025-01-02 ┆ 0.019551 ┆ 600530.SH │\n",
"│ 2025-01-02 ┆ 0.014435 ┆ 600608.SH │\n",
"│ … ┆ … ┆ … │\n",
"│ 2025-01-06 ┆ 0.087354 ┆ 603007.SH │\n",
"│ 2025-01-06 ┆ 0.053317 ┆ 603959.SH │\n",
"│ 2025-01-06 ┆ 0.033367 ┆ 000573.SZ │\n",
"│ 2025-01-06 ┆ 0.019381 ┆ 603848.SH │\n",
"│ 2025-01-06 ┆ 0.017292 ┆ 603226.SH │\n",
"└────────────┴──────────┴───────────┘\n"
]
}
],
"execution_count": 15
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4.5 特征重要性"
]
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T12:42:43.418164Z",
"start_time": "2026-03-08T12:42:43.414098Z"
}
},
"source": [
"importance = model.feature_importance()\n",
"if importance is not None:\n",
" print(\"\\n特征重要性:\")\n",
" print(importance.sort_values(ascending=False))\n",
"\n",
"print(\"\\n\" + \"=\" * 80)\n",
"print(\"训练完成!\")\n",
"print(\"=\" * 80)"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"特征重要性:\n",
"ebitda_rank 14994.629439\n",
"bbi_ratio_factor 1948.020045\n",
"cs_rank_turnover_rate 1686.329153\n",
"std_return_90 1566.943976\n",
"return_10 1430.825613\n",
"std_return_5 1242.399894\n",
"ma_ratio 979.636333\n",
"high_low_ratio 932.793097\n",
"volume_change_rate 652.408669\n",
"vol_ratio 651.873589\n",
"cs_rank_volume_ratio 649.430382\n",
"vol_ma20 649.122400\n",
"turnover_rate_mean_5 615.645898\n",
"ma20 535.948738\n",
"return_20 452.979679\n",
"market_cap_rank 407.870511\n",
"return_diff 304.914024\n",
"ebit_rank 294.325396\n",
"profit_to_market_cap 287.904060\n",
"vol_std_5 243.944097\n",
"operate_profit_to_market_cap 232.366538\n",
"ma10 201.935446\n",
"volatility_20 169.372244\n",
"volatility_5 161.736089\n",
"ma5 111.549055\n",
"n_income_rank 79.050783\n",
"vol_ma5 73.284933\n",
"operate_profit_rank 57.328938\n",
"total_profit_rank 35.760864\n",
"n_cashflow_act_rank 34.917162\n",
"money_cap_rank 32.511166\n",
"total_liab_rank 24.127689\n",
"cashflow_to_market_cap 20.326702\n",
"turnover_deviation 0.000000\n",
"dtype: float64\n",
"\n",
"================================================================================\n",
"训练完成!\n",
"================================================================================\n"
]
}
],
"execution_count": 16
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. 可视化分析\n",
"\n",
"使用训练好的模型直接绘图。\n",
"- **特征重要性图**:辅助特征选择\n",
"- **决策树图**:理解决策逻辑"
]
},
{
"cell_type": "code",
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T12:42:43.429903Z",
"start_time": "2026-03-08T12:42:43.423547Z"
}
},
"source": [
"# 导入可视化库\n",
"import matplotlib.pyplot as plt\n",
"import lightgbm as lgb\n",
"import pandas as pd\n",
"\n",
"# 从封装的model中取出底层Booster\n",
"booster = model.model\n",
"print(f\"模型类型: {type(booster)}\")\n",
"print(f\"特征数量: {len(feature_cols)}\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"模型类型: <class 'lightgbm.basic.Booster'>\n",
"特征数量: 34\n"
]
}
],
"execution_count": 17
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 5.1 绘制特征重要性(辅助特征选择)\n",
"\n",
"**解读**\n",
"- 重要性高的特征对模型贡献大\n",
"- 重要性为0的特征可以考虑删除\n",
"- 可以帮助理解哪些因子最有效"
]
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2026-03-08T12:42:43.605795Z",
"start_time": "2026-03-08T12:42:43.439482Z"
}
},
"cell_type": "code",
"source": [
"print(\"绘制特征重要性...\")\n",
"\n",
"fig, ax = plt.subplots(figsize=(10, 8))\n",
"lgb.plot_importance(\n",
" booster, \n",
" max_num_features=20,\n",
" importance_type='gain',\n",
" title='Feature Importance (Gain)',\n",
" ax=ax\n",
")\n",
"ax.set_xlabel('Importance (Gain)')\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"# 打印重要性排名\n",
"importance_gain = pd.Series(\n",
" booster.feature_importance(importance_type='gain'),\n",
" index=feature_cols\n",
").sort_values(ascending=False)\n",
"\n",
"print(\"\\n[特征重要性排名 - Gain]\")\n",
"print(importance_gain)\n",
"\n",
"# 识别低重要性特征\n",
"zero_importance = importance_gain[importance_gain == 0].index.tolist()\n",
"if zero_importance:\n",
" print(f\"\\n[低重要性特征] 以下{len(zero_importance)}个特征重要性为0可考虑删除:\")\n",
" for feat in zero_importance:\n",
" print(f\" - {feat}\")\n",
"else:\n",
" print(\"\\n所有特征都有一定重要性\")"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"绘制特征重要性...\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 1000x800 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA9kAAAMWCAYAAADlCkWLAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAA8BVJREFUeJzs3Qm4TeX7//HbPCZDpoQyRSIyZcg8RaKEyhiVypCppFSIEEXfTM2iuVQUkhTxlSlDKEqIMpYpZD7/6/N8/2v/9tln7+Mcds5yzvt1XTvn7L322sOzdPms+36elSomJibGAAAAAADAeUt9/rsAAAAAAACEbAAAAAAAoohKNgAAAAAAUULIBgAAAAAgSgjZAAAAAABECSEbAAAAAIAoIWQDAAAAABAlhGwAAAAAAKKEkA0AAAAAQJQQsgEAAC6w7du3W8aMGe2///3vBXm9Tp062ZVXXnlOz3300UetSpUqUX9PAJBcEbIBAMnK5MmTLVWqVGFvCgv/hsWLF9ugQYPswIED5tfvY8WKFXaxmjBhgvscycmQIUNccK1evXqcxxYuXGitW7e2AgUKWPr06e3SSy912+o5u3fvvuDvtVevXrZmzRqbMWPGBX9tALgYpU3qNwAAwL9BgeSqq66Kdd+11177r4XswYMHu2ph9uzZ/5XXSMkUsi+77DL3/SYHe/futTfffNPdQj355JP29NNPW5EiRdzn1Z/Hjh2z77//3p577jn3nF9//TXRr/nKK6/YmTNnzun95suXz5o3b26jR4+2W2655Zz2AQApCSEbAJAs3XTTTVaxYkW7mB05csSyZMliKdXRo0ctc+bMlty89dZbljZtWmvWrFms+99//30XsFXFnjp1qqtiBxszZoy7nYt06dKd13vWe2rVqpVt3rzZBX8AQGS0iwMAUqTZs2fbjTfe6ELsJZdcYk2bNrX169fH2uaHH34IVBM1f1YVvc6dO9tff/0V2EZt4g8//LD7WZVzrzV969at7qafw7U66349N3g/uu/HH3+0u+66y3LkyGE1atSIFcwqVKhgmTJlspw5c9odd9zh5vWeC32mrFmz2rZt2+zmm292P6s1efz48e7xtWvXWt26dd13U7hwYXvnnXfCtqB/++231rVrV8uVK5dly5bNOnToYPv37w9biS5durRlyJDBLr/8cuvWrVuc1vratWu7TgNVbGvWrOnC9WOPPebmEWtcFixYEPhuta3s27fP+vXrZ2XKlHGfQe9BJ1fU2hxs/vz57nkffPCBDRs2zK644go3nvXq1bNNmzbFeb9Lly61Jk2auDHQd1C2bFl74YUXYm2zYcMGu/32291YaF86oZPQdupPP/3UtX/rPYdWsVWxf+211+IEbFHbePAxI9OnT3fHrr5Xfb9FixZ1Qf306dPxzsn2jk1Vp19++WX3PD2/UqVKtnz58jivXb9+/cDrAQDiRyUbAJAsHTx40P78889Y9ynAiKqEHTt2tEaNGtnIkSNdxXTixIku1K5atSoQRubOnesqd3fffbcL2Ap7CiT6c8mSJS6k3Hbbbfbzzz/bu+++66qM3mvkzp3btQUnlqqFxYsXt2eeecZiYmLcfQqGTzzxhKsm3nPPPW6/L774ogujer/n0qKuEKZAqn08++yz9vbbb1v37t1dqHz88cetbdu27rNNmjTJheeqVavGab/X9nptBb+NGze67/C3334LhFrRY2qlV0h74IEHAtspyGnRr+AKq05e6D3pBEK7du0sb968LlD36NHDBVK9L9H9orFRYNV3pvem+covvfSS1apVy52sUPAMNmLECEudOrUL5jo+9Ln1ORWqPRpznXjInz+/PfTQQ27cf/rpJ/v888/d76Lx11xqnZjQPH99ZwrwLVq0sGnTptmtt94a8Xs/efKk++z6LoLpGNJN4xsavuOjEx7avk+fPu7Pr7/+2oX1Q4cO2ahRo876fJ1A+fvvv93JEo2ZvhONu77b4LFRwFcQ15j17t07we8PAFKkGAAAkpE33nhDyTTsTf7++++Y7Nmzx9x7772xnrdr166YSy+9NNb9R48ejbP/d9991+3r22+/Ddw3atQod9+WLVtibavfdb/eUyjd/9RTTwV+18+6784774y13datW2PSpEkTM2zYsFj3r127NiZt2rRx7o/0fSxfvjxwX8eOHd19zzzzTOC+/fv3x2TKlCkmVapUMe+9917g/g0bNsR5r94+K1SoEHPixInA/c8++6y7f/r06e73PXv2xKRPnz6mYcOGMadPnw5sN27cOLfd66+/HrivVq1a7r5JkybF+QylS5d2j4c6duxYrP1633mGDBlihgwZErjvm2++cfsuVapUzPHjxwP3v/DCC+5+fZdy6tSpmKuuuiqmcOHC7vsIdubMmcDP9erViylTpox7/eDHq1WrFlO8ePGY+GzatMm95osvvhjrfn1nun/s2LFxXnfv3r2xbidPnoz3GO3atWtM5syZY70/jbk+V/D3pNfLlStXzL59++K8j88++yzOfjWO+g4BAPGjXRwAkCyp9VlVyeCb6E+1Kt95552u0u3d0qRJ41p4v/nmm8A+1Jrt0eJT2u6GG25wv69cufJfed/3339/rN8//vhjt2CVqtjB71cVVlW8g99vYqlq6lFF+uqrr3ZVWb2WR/fpMVU2Q913332xqp2qzmqu8axZs9zvX331lZ04ccKtTq0Ksufee+91rd0zZ86MtT+1K6trIKG0vbdfVeZVCVc1V+853Pho38Ft2JouIN5nU1fAli1b3PsN7Q7wKvNqUVe1WN+RKsDeeOi11Rnxyy+/2B9//BHxPXtTDdSKHkyVZwmtYqvirq6I4Nvq1avDHqPe+9HnUneGWtrPpk2bNrHeS+h3EkzbhXaHAADiol0cAJAsVa5cOezCZwpBojnH4Sj8eRSo1Or83nvv2Z49e+KEn39DaEu23q8K3wrU0VzQSvOIFdiCqSVY85W9QBl8f7i51qHvSQFRbdaa7ytqHReF3mAKuprn7j3u8S5ZlVA6+aC50przrXAcPA9Z88RDFSpUKNbvXrj0Ppu3and8q9BrDrfGQ+37uoWjY0WfJT7eVACP1gWQw4cPx/lOvRNEX375ZZwWcLWuDxw40AV/L6gn5hg923cS+p5Djw0AQFyEbABAiuJdxkjzslUNDqVKrEfVSl2eSwublStXzgUePb9x48YJuhxSpEASuihVsODKpPd+tR8t1KZqe6jEzN8NFm5f8d0fGgr/DaGf/Ww0b11BV4vRabEvLUKmyrYq0eHGJxqfzduv5nWrch1OsWLFIj7fC/+hIbZkyZLuz3Xr1sU5Hr1Fx37//fdYj6kjQ/PPdWJIl6zTnGmdPFEVv3///gk6RhPzneg9e2sOAAAiI2QDAFIUBRHJkydPILyEo0Axb948V8nWQlKhlfCEhGmvKhi6knZoBfds71eBRxXuEiVKmJ/ou6hTp07gd1Vhd+7c6VbmFq1MLlrsLPiyT2ohV+U5vu8/Id/vRx995F5fq3EH0/d9LmHQOzYUdCO9N+9zqIMgoe8/tHKskwn6/MFU7VdngBZyGzt2bIIu3aYF5tR+rikFWsDOE7rvaNF+r7vuun9l3wCQnDAnGwCQoqj6qMqfqqBa6TmUtyK4V+ELregpAIXyAlFomNbrKOzpUlfB1N6cUFrpWe9FYT/0vej34MuJXWhaaT34O9Sq4adOnXIrhItCqNq///Of/8R67wrFamXWpacSQt9v6Hcr+l5Cv5MPP/ww3jnR8bn++uvdyQyNcejrea+jkzNa8VyrmOuEQqizrSivcK5pDCtWrIjzmFZi15xnzVkPd2yGftZwx6hOYCTm+EoojZfa6atVqxb1fQNAckMlGwCQoij4Kgy2b9/ehSpdLkpzk3XNaC3EpUszjRs3zm3nXd5KgUdzbDUnNlyVUNevFl1iSvtTkGrWrJkLh1pcTJeO0p8KVwrculRTYqqrQ4cOtQEDBri5zrpMlObv6n188sknbvExtS4nBQU6XWtabfWqVivc6TJot9xyi3tc36vet04QqMVe93vb6XrMukxXQuj71Zjpe1ArtoKu5tT
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"[特征重要性排名 - Gain]\n",
"ebitda_rank 14994.629439\n",
"bbi_ratio_factor 1948.020045\n",
"cs_rank_turnover_rate 1686.329153\n",
"std_return_90 1566.943976\n",
"return_10 1430.825613\n",
"std_return_5 1242.399894\n",
"ma_ratio 979.636333\n",
"high_low_ratio 932.793097\n",
"volume_change_rate 652.408669\n",
"vol_ratio 651.873589\n",
"cs_rank_volume_ratio 649.430382\n",
"vol_ma20 649.122400\n",
"turnover_rate_mean_5 615.645898\n",
"ma20 535.948738\n",
"return_20 452.979679\n",
"market_cap_rank 407.870511\n",
"return_diff 304.914024\n",
"ebit_rank 294.325396\n",
"profit_to_market_cap 287.904060\n",
"vol_std_5 243.944097\n",
"operate_profit_to_market_cap 232.366538\n",
"ma10 201.935446\n",
"volatility_20 169.372244\n",
"volatility_5 161.736089\n",
"ma5 111.549055\n",
"n_income_rank 79.050783\n",
"vol_ma5 73.284933\n",
"operate_profit_rank 57.328938\n",
"total_profit_rank 35.760864\n",
"n_cashflow_act_rank 34.917162\n",
"money_cap_rank 32.511166\n",
"total_liab_rank 24.127689\n",
"cashflow_to_market_cap 20.326702\n",
"turnover_deviation 0.000000\n",
"dtype: float64\n",
"\n",
"[低重要性特征] 以下1个特征重要性为0可考虑删除:\n",
" - turnover_deviation\n"
]
}
],
"execution_count": 18
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}