feat(factors): 新增筹码集中度相关因子并优化训练框架

- 添加 19 个筹码分布和胜率相关因子(包括chip_dispersion、winner_rate等系列)
- LightGBM模型添加早停和训练指标记录功能
- 统一Label配置到common.py模块
- 新增list_factors.py因子列表脚本
This commit is contained in:
2026-03-29 01:34:58 +08:00
parent d4e0e2a0b6
commit c3d1b157e9
9 changed files with 373 additions and 246 deletions

View File

@@ -250,67 +250,58 @@ SELECTED_FACTORS = [
"GTJA_alpha188", "GTJA_alpha188",
"GTJA_alpha189", "GTJA_alpha189",
"GTJA_alpha191", "GTJA_alpha191",
"chip_dispersion_90",
"chip_dispersion_70",
"cost_skewness",
"dispersion_change_20",
"price_to_avg_cost",
"price_to_median_cost",
"mean_median_dev",
"trap_pressure",
"bottom_profit",
"history_position",
"winner_rate_surge_5",
"winner_rate_cs_rank",
"winner_rate_dev_20",
"winner_rate_volatility",
"smart_money_accumulation",
"winner_vol_corr_20",
"cost_base_momentum",
"bottom_cost_stability",
"pivot_reversion",
"chip_transition",
] ]
# 因子定义字典完整因子库用于存放尚未注册到metadata的因子 # 因子定义字典完整因子库用于存放尚未注册到metadata的因子
FACTOR_DEFINITIONS = {"cs_rank_circ_mv": "cs_rank(circ_mv)"} FACTOR_DEFINITIONS = {"cs_rank_circ_mv": "cs_rank(circ_mv)"}
# 需要排除的因子列表(这些因子不会被计算和使用)
# 用于临时屏蔽效果不好的因子,无需从 SELECTED_FACTORS 中删除 # =============================================================================
# EXCLUDED_FACTORS: List[str] = [ # Label 配置(统一绑定 label_name 和 label_dsl
# # "GTJA_alpha005", # =============================================================================
# # "GTJA_alpha028", # Label 名称
# # "GTJA_alpha023", LABEL_NAME = "future_return_5"
# # "GTJA_alpha002",
# # "GTJA_alpha010", # Label DSL 公式
# # "GTJA_alpha011", LABEL_DSL = "(ts_delay(close, -5) / ts_delay(open, -1)) - 1"
# # "GTJA_alpha044",
# # "GTJA_alpha036", # Label 配置字典(绑定 name 和 dsl
# # "GTJA_alpha027", LABEL_FACTOR = {LABEL_NAME: LABEL_DSL}
# # "GTJA_alpha109",
# # "GTJA_alpha104",
# # "GTJA_alpha103",
# # "GTJA_alpha085",
# # "GTJA_alpha111",
# # "GTJA_alpha092",
# # "GTJA_alpha067",
# # "GTJA_alpha060",
# # "GTJA_alpha062",
# # "GTJA_alpha063",
# # "GTJA_alpha079",
# # "GTJA_alpha073",
# # "GTJA_alpha087",
# # "GTJA_alpha117",
# # "GTJA_alpha113",
# # "GTJA_alpha138",
# # "GTJA_alpha121",
# # "GTJA_alpha124",
# # "GTJA_alpha133",
# # "GTJA_alpha131",
# # "GTJA_alpha118",
# # "GTJA_alpha164",
# # "GTJA_alpha162",
# # "GTJA_alpha157",
# # "GTJA_alpha171",
# # "GTJA_alpha177",
# # "GTJA_alpha180",
# # "GTJA_alpha188",
# # "GTJA_alpha191",
# ]
def get_label_factor(label_name: str) -> dict: def get_label_factor(label_name: str) -> dict:
"""获取Label因子定义字典。 """获取Label因子定义字典。
警告: 此函数已废弃,请直接使用 LABEL_FACTOR 常量。
label_name 参数将被忽略,始终返回预定义的 LABEL_FACTOR。
Args: Args:
label_name: label因子名称 label_name: label因子名称(已废弃,仅保留参数保持向后兼容)
Returns: Returns:
Label因子定义字典 Label因子定义字典
""" """
return { return LABEL_FACTOR
label_name: "(ts_delay(close, -5) / ts_delay(open, -1)) - 1",
}
# ============================================================================= # =============================================================================

View File

@@ -21,7 +21,8 @@ from src.training.components.filters import STFilter
from src.experiment.common import ( from src.experiment.common import (
SELECTED_FACTORS, SELECTED_FACTORS,
FACTOR_DEFINITIONS, FACTOR_DEFINITIONS,
get_label_factor, LABEL_NAME,
LABEL_FACTOR,
TRAIN_START, TRAIN_START,
TRAIN_END, TRAIN_END,
VAL_START, VAL_START,
@@ -44,171 +45,39 @@ TRAINING_TYPE = "rank"
# %% md # %% md
# ## 2. 训练特定配置 # ## 2. 训练特定配置
# %% # %%
# Label 配置 # Label 配置(从 common.py 统一导入)
LABEL_NAME = "future_return_5" # LABEL_NAME 和 LABEL_FACTOR 已在 common.py 中绑定,只需从 common 导入
LABEL_FACTOR = get_label_factor(LABEL_NAME)
# 分位数配置 # 分位数配置
N_QUANTILES = 20 N_QUANTILES = 20
# 排除的因子列表 # 排除的因子列表
EXCLUDED_FACTORS = [ EXCLUDED_FACTORS = [
"volatility_5",
"volume_ratio_5_20",
"capital_retention_20",
"volatility_squeeze_5_60",
"drawdown_from_high_60",
"ma_ratio_5_20",
"bias_10",
"high_low_ratio",
"bbi_ratio",
"volatility_20",
"std_return_20",
"sharpe_ratio_20",
"ma_5",
"max_ret_20",
"CP",
"net_profit_yoy",
"debt_to_equity",
"EP_rank",
"turnover_rank",
"return_5_rank",
"ebit_rank",
"BP",
"EP",
"amihud_illiq_20",
"profit_margin",
"return_5",
"return_20",
"kaufman_ER_20",
"GTJA_alpha043",
"GTJA_alpha042",
"GTJA_alpha041",
"GTJA_alpha040",
"GTJA_alpha039",
"GTJA_alpha037",
"GTJA_alpha036",
"GTJA_alpha035",
"GTJA_alpha033",
"GTJA_alpha032",
"GTJA_alpha031",
"GTJA_alpha028",
"GTJA_alpha026",
"GTJA_alpha027",
"GTJA_alpha023",
"GTJA_alpha024",
"GTJA_alpha009",
"GTJA_alpha011",
"GTJA_alpha022",
"GTJA_alpha020",
"GTJA_alpha018",
"GTJA_alpha019",
"GTJA_alpha014",
"GTJA_alpha013",
"GTJA_alpha010", "GTJA_alpha010",
"GTJA_alpha001",
"GTJA_alpha003",
"GTJA_alpha002",
"GTJA_alpha004",
"GTJA_alpha005", "GTJA_alpha005",
"GTJA_alpha006", "GTJA_alpha002",
"GTJA_alpha008", "GTJA_alpha027",
"turnover_deviation", "GTJA_alpha051",
"turnover_cv_20",
"roa",
"GTJA_alpha073",
"GTJA_alpha078",
"GTJA_alpha077",
"GTJA_alpha076",
"GTJA_alpha067",
"GTJA_alpha085",
"GTJA_alpha084",
"GTJA_alpha087",
"GTJA_alpha088",
"GTJA_alpha090",
"GTJA_alpha083",
"GTJA_alpha079",
"GTJA_alpha080",
"GTJA_alpha094",
"GTJA_alpha092",
"GTJA_alpha089",
"GTJA_alpha095",
"GTJA_alpha064",
"GTJA_alpha065",
"GTJA_alpha066",
"GTJA_alpha063",
"GTJA_alpha060",
"GTJA_alpha058",
"GTJA_alpha057",
"GTJA_alpha056",
"GTJA_alpha046",
"GTJA_alpha044", "GTJA_alpha044",
"GTJA_alpha049", "GTJA_alpha041",
"GTJA_alpha050",
"GTJA_alpha110",
"GTJA_alpha107",
"GTJA_alpha104",
"GTJA_alpha106",
"GTJA_alpha103",
"GTJA_alpha100",
"GTJA_alpha101",
"GTJA_alpha102",
"GTJA_alpha098",
"GTJA_alpha097",
"GTJA_alpha096",
"GTJA_alpha099",
"GTJA_alpha117",
"GTJA_alpha118",
"GTJA_alpha114",
"GTJA_alpha111",
"GTJA_alpha129",
"GTJA_alpha130",
"GTJA_alpha132",
"GTJA_alpha131", "GTJA_alpha131",
"GTJA_alpha134", "GTJA_alpha103",
"GTJA_alpha135", "GTJA_alpha087",
"GTJA_alpha136", "GTJA_alpha093",
"GTJA_alpha112", "GTJA_alpha092",
"GTJA_alpha120", "GTJA_alpha073",
"GTJA_alpha119",
"GTJA_alpha122",
"GTJA_alpha124",
"GTJA_alpha126",
"GTJA_alpha127", "GTJA_alpha127",
"GTJA_alpha128", "GTJA_alpha117",
"GTJA_alpha115", "GTJA_alpha124",
"GTJA_alpha153",
"GTJA_alpha152",
"GTJA_alpha151",
"GTJA_alpha150",
"GTJA_alpha148",
"GTJA_alpha142",
"GTJA_alpha141",
"GTJA_alpha139",
"GTJA_alpha133",
"GTJA_alpha161",
"GTJA_alpha164",
"GTJA_alpha162", "GTJA_alpha162",
"GTJA_alpha157",
"GTJA_alpha156",
"GTJA_alpha160",
"GTJA_alpha155",
"GTJA_alpha170",
"GTJA_alpha169",
"GTJA_alpha168",
"GTJA_alpha166",
"GTJA_alpha163",
"GTJA_alpha176",
"GTJA_alpha175",
"GTJA_alpha174",
"GTJA_alpha178",
"GTJA_alpha177", "GTJA_alpha177",
"GTJA_alpha185",
"GTJA_alpha180",
"GTJA_alpha187",
"GTJA_alpha188", "GTJA_alpha188",
"GTJA_alpha189", "smart_money_accumulation",
"GTJA_alpha191", "GTJA_alpha014",
"GTJA_alpha056",
"GTJA_alpha085",
"GTJA_alpha154",
"GTJA_alpha141",
] ]
# LambdaRank 模型参数配置 # LambdaRank 模型参数配置

View File

@@ -15,13 +15,15 @@ from src.training import (
NullFiller, NullFiller,
Winsorizer, Winsorizer,
StandardScaler, StandardScaler,
CrossSectionalStandardScaler,
) )
from src.training.core.trainer_v2 import Trainer from src.training.core.trainer_v2 import Trainer
from src.training.components.filters import STFilter from src.training.components.filters import STFilter
from src.experiment.common import ( from src.experiment.common import (
SELECTED_FACTORS, SELECTED_FACTORS,
FACTOR_DEFINITIONS, FACTOR_DEFINITIONS,
get_label_factor, LABEL_NAME,
LABEL_FACTOR,
TRAIN_START, TRAIN_START,
TRAIN_END, TRAIN_END,
VAL_START, VAL_START,
@@ -44,58 +46,93 @@ TRAINING_TYPE = "regression"
# %% md # %% md
# ## 2. 训练特定配置 # ## 2. 训练特定配置
# %% # %%
# Label 配置 # Label 配置(从 common.py 统一导入)
LABEL_NAME = "future_return_5" # LABEL_NAME 和 LABEL_FACTOR 已在 common.py 中绑定,只需从 common 导入
LABEL_FACTOR = get_label_factor(LABEL_NAME)
# 排除的因子列表 # 排除的因子列表
EXCLUDED_FACTORS = [ EXCLUDED_FACTORS = [
"GTJA_alpha062",
"GTJA_alpha060",
"GTJA_alpha058",
"GTJA_alpha056",
"GTJA_alpha053",
"GTJA_alpha040",
"GTJA_alpha043",
"GTJA_alpha027",
"CP",
"max_ret_20",
"debt_to_equity",
"close_vwap_deviation",
"EP",
"BP",
"EP_rank",
"GTJA_alpha044",
"GTJA_alpha036",
"GTJA_alpha010", "GTJA_alpha010",
"GTJA_alpha005", "GTJA_alpha005",
"GTJA_alpha036", "GTJA_alpha001",
"GTJA_alpha027", "GTJA_alpha002",
"GTJA_alpha044", "GTJA_alpha007",
"GTJA_alpha016",
"GTJA_alpha073", "GTJA_alpha073",
"GTJA_alpha104",
"GTJA_alpha103",
"GTJA_alpha105",
"GTJA_alpha092",
"GTJA_alpha087",
"GTJA_alpha085",
"GTJA_alpha062",
"GTJA_alpha124",
"GTJA_alpha133", "GTJA_alpha133",
"GTJA_alpha131", "GTJA_alpha131",
"GTJA_alpha117", "GTJA_alpha117",
"GTJA_alpha124",
"GTJA_alpha120",
"GTJA_alpha119",
"GTJA_alpha103",
"GTJA_alpha099",
"GTJA_alpha105",
"GTJA_alpha104",
"GTJA_alpha090",
"GTJA_alpha085",
"GTJA_alpha083",
"GTJA_alpha084",
"GTJA_alpha087",
"GTJA_alpha092",
"GTJA_alpha074",
"GTJA_alpha089",
"GTJA_alpha173",
"GTJA_alpha157", "GTJA_alpha157",
"GTJA_alpha139",
"GTJA_alpha162", "GTJA_alpha162",
"GTJA_alpha163",
"GTJA_alpha177", "GTJA_alpha177",
"GTJA_alpha180", "price_to_avg_cost",
"cost_skewness",
"GTJA_alpha191", "GTJA_alpha191",
"GTJA_alpha180",
"history_position",
"bottom_profit",
"smart_money_accumulation",
] ]
# 模型参数配置 # 模型参数配置
MODEL_PARAMS = { MODEL_PARAMS = {
# 基础设置 # ==================== 基础设置 ====================
"objective": "regression_l1", "objective": "huber", # 【修改】相比纯 L1(MAE)huber 对异常值鲁棒且在极小误差处平滑,更适合收益率预测
"metric": "mae", "metric": "mae",
# 树结构约束 # ==================== 树结构约束 ====================
"max_depth": 5, "max_depth": 5, # 【修改】适当加深,允许捕捉一定的高阶交叉
"num_leaves": 24, "num_leaves": 31, # 【修改】限制为 312的5次方-1确保树是不对称生长的防止过拟合
"min_data_in_leaf": 100, "min_data_in_leaf": 512, # 【大幅增加】从256加到1000。训练集有97万条极大地限制叶子节点样本量能有效抵抗股市噪音
# 学习参数 # ==================== 学习参数 ====================
"learning_rate": 0.01, "learning_rate": 0.02, # 【修改】稍微调大一点,帮助模型跳出初始的局部最优(避免十几轮就早停)
"n_estimators": 1500, "n_estimators": 2000,
# 随机采样 # ==================== 随机采样与降维 ====================
"subsample": 0.8, "subsample": 0.85,
"subsample_freq": 1, "subsample_freq": 1,
"colsample_bytree": 0.8, "colsample_bytree": 0.4, # 【大幅降低】从0.8降到0.4。强制打压 GTJA_alpha127 的霸权,逼迫模型去学习其他因子的信息
# 正则化 "extra_trees": True, # 【新增且极度推荐】极度随机树模式。在分裂点选择时增加随机性,是量化比赛中防过拟合的神器
"reg_alpha": 0.5, # ==================== 正则化 ====================
"reg_lambda": 1.0, "reg_alpha": 1.0, # 【修改】L1正则增加强行把一些无用特征的权重压到0
# 杂项 "reg_lambda": 5.0, # 【修改】L2正则大幅增加从1到5惩罚过大的叶子节点输出权重
"max_bin": 127, # 【新增】默认255降低到127相当于对连续特征做了一次粗颗粒度的分箱也是极好的正则化手段
# ==================== 杂项 ====================
"verbose": -1, "verbose": -1,
"random_state": 42, "random_state": 42,
"n_jobs": -1,
} }
# 日期范围配置 # 日期范围配置
@@ -143,6 +180,7 @@ def main():
(NullFiller, {"strategy": "mean"}), (NullFiller, {"strategy": "mean"}),
(Winsorizer, {"lower": 0.01, "upper": 0.99}), (Winsorizer, {"lower": 0.01, "upper": 0.99}),
(StandardScaler, {}), (StandardScaler, {}),
# (CrossSectionalStandardScaler, {}),
], ],
filters=[STFilter(data_router=engine.router)], filters=[STFilter(data_router=engine.router)],
stock_pool_filter_func=stock_pool_filter, stock_pool_filter_func=stock_pool_filter,

View File

@@ -0,0 +1,81 @@
"""列出所有已入库的因子。
以 Python 列表格式输出所有已注册因子的名称,方便复制使用。
保持 factors.jsonl 中的原始顺序(按 factor_id)
使用方法:
uv run python -m src.scripts.list_factors
"""
import json
import re
from pathlib import Path
from src.config.settings import get_settings
def extract_factor_id_number(factor_id: str) -> int:
    """Extract the numeric suffix of a factor_id for sorting.

    Args:
        factor_id: Identifier such as "F_001".

    Returns:
        The embedded number (e.g. 1 for "F_001"), or 0 when the id
        does not match the expected "F_<digits>" pattern.
    """
    m = re.match(r"F_(\d+)", factor_id)
    return int(m.group(1)) if m else 0
def list_factors():
    """Read factors.jsonl and print all factor names as a Python list.

    Output is ordered by the numeric part of factor_id (i.e. registration
    order). Malformed or blank lines in the JSONL file are skipped.
    """
    settings = get_settings()
    factors_path = settings.data_path_resolved / "factors.jsonl"
    if not factors_path.exists():
        print(f"[错误] 因子文件不存在: {factors_path}")
        return
    # Collect (factor_id, name) pairs; tolerate bad lines individually.
    factors = []
    try:
        with open(factors_path, "r", encoding="utf-8") as f:
            for raw in f:
                raw = raw.strip()
                if not raw:
                    continue
                try:
                    record = json.loads(raw)
                except json.JSONDecodeError:
                    continue
                fid = record.get("factor_id", "")
                fname = record.get("name")
                if fname and fid:
                    factors.append((fid, fname))
    except Exception as e:
        print(f"[错误] 读取因子文件失败: {e}")
        return
    if not factors:
        print("[信息] 没有找到任何因子")
        return
    # Numeric sort on the factor_id suffix preserves registration order.
    factors.sort(key=lambda pair: extract_factor_id_number(pair[0]))
    # Emit a copy-pasteable Python list (no trailing comma on the last item).
    last = len(factors) - 1
    print("[")
    for idx, (_, fname) in enumerate(factors):
        suffix = "" if idx == last else ","
        print(f' "{fname}"{suffix}')
    print("]")
    print(f"\n[统计] 共计 {len(factors)} 个因子")
if __name__ == "__main__":
    list_factors()

View File

@@ -26,24 +26,126 @@ from typing import Any, Dict, List, Optional
from src.factors.metadata import FactorManager from src.factors.metadata import FactorManager
from src.factors.metadata.exceptions import DuplicateFactorError, ValidationError from src.factors.metadata.exceptions import DuplicateFactorError, ValidationError
from src.config.settings import get_settings
# ============================================================================ # ============================================================================
# 用户配置区域 - 在这里添加要注册的因子 # 用户配置区域 - 在这里添加要注册的因子
# ============================================================================ # ============================================================================
FACTORS: List[Dict[str, Any]] = [ FACTORS: List[Dict[str, Any]] =[
# 示例因子,请根据实际需要修改或添加 # ==================== 第一类:筹码集中度与离散度因子 ====================
{ {
"name": "turnover_volatility_ratio", "name": "chip_dispersion_90",
"desc": "5日价格动量收盘价相对于5日前收盘价的涨跌幅进行截面排名", "desc": "90%筹码离散度衡量市场90%持仓筹码的宽度,值越小表示筹码越高度集中(单峰密集),往往是洗盘结束的前兆",
"dsl": "cs_rank(close / ts_delay(close, 5) - 1)", "dsl": "(cost_95pct - cost_5pct) / (cost_95pct + cost_5pct)",
"category": "momentum", },
{
"name": "chip_dispersion_70",
"desc": "70%核心筹码离散度剔除极端的底部死筹和高位套牢盘反映中间70%主流资金的成本集中度",
"dsl": "(cost_85pct - cost_15pct) / (cost_85pct + cost_15pct)",
},
{
"name": "cost_skewness",
"desc": "筹码偏度反映筹码分布的不对称性。大于1说明上方套牢盘拖尾严重小于1说明下方获利盘雄厚",
"dsl": "(cost_95pct - cost_50pct) / (cost_50pct - cost_5pct)",
},
{
"name": "dispersion_change_20",
"desc": "筹码集中度近期变化率过去20天筹码宽度的变化比例持续下降说明主力正在暗中吸筹",
"dsl": "ts_pct_change((cost_95pct - cost_5pct) / cost_50pct, 20)",
},
# ==================== 第二类:筹码相对位置与压力/支撑因子 ====================
{
"name": "price_to_avg_cost",
"desc": "整体浮盈比例:当前价格相对加权平均成本的溢价率。高溢价有均值回归压力,负溢价代表超跌",
"dsl": "(close - weight_avg) / weight_avg",
},
{
"name": "price_to_median_cost",
"desc": "中位数成本偏离度价格相对于50%分位点(绝对半数人持仓价)的偏离,向上突破通常是右侧买点",
"dsl": "(close - cost_50pct) / cost_50pct",
},
{
"name": "mean_median_dev",
"desc": "均值中位数背离:均值显著大于中位数说明高位筹码堆积,上涨阻力大",
"dsl": "(weight_avg - cost_50pct) / cost_50pct",
},
{
"name": "trap_pressure",
"desc": "高位套牢盘压力指数当前价格距离上方95%高位套牢成本的距离。距离越大,反弹的真空期阻力越小",
"dsl": "(cost_95pct - close) / close",
},
{
"name": "bottom_profit",
"desc": "底部支撑底仓利润率当前价格距离底部5%筹码的利润空间。暴跌时大于0说明底仓极度稳定",
"dsl": "(close - cost_5pct) / cost_5pct",
},
{
"name": "history_position",
"desc": "历史区间分位点:当前价格在个股上市以来历史最高点和最低点之间的相对位置",
"dsl": "(close - his_low) / (his_high - his_low)",
},
# ==================== 第三类:胜率相关的动量与反转因子 ====================
{
"name": "winner_rate_surge_5",
"desc": "获利盘短期爆发力胜率在过去5天内的变化值急剧上升是极强的动量做多信号",
"dsl": "ts_delta(winner_rate, 5)",
},
{
"name": "winner_rate_cs_rank",
"desc": "获利盘高位反转信号:全市场胜率截面排名,极端高胜率往往面临多头踩踏的获利了结压力(反转因子)",
"dsl": "cs_rank(winner_rate)",
},
{
"name": "winner_rate_dev_20",
"desc": "获利盘均线偏离当前胜率相对过去20天平均胜率的偏离程度捕捉筹码情绪的边际超买/超卖",
"dsl": "winner_rate - ts_mean(winner_rate, 20)",
},
{
"name": "winner_rate_volatility",
"desc": "获利盘波动率过去20天胜率的波动率。波动率低且胜率高说明单边上涨极度稳健",
"dsl": "ts_std(winner_rate, 20)",
},
{
"name": "smart_money_accumulation",
"desc": "潜在主力吸筹隐蔽指标胜率的60日时序分位数减去价格的时序分位数。值越大说明价平而获利盘增底部吸筹明显",
"dsl": "ts_rank(winner_rate, 60) - ts_rank(close, 60)",
},
# ==================== 第四类:量价与筹码交乘因子 ====================
{
"name": "winner_vol_corr_20",
"desc": "放量突破筹码密集区胜率与成交量的20日时序相关性正相关说明增量资金在主动解套上方筹码",
"dsl": "ts_corr(winner_rate, vol, 20)",
},
{
"name": "cost_base_momentum",
"desc": "成本重心上移换手率过去20天加权平均成本的变化幅度快速上移说明高位换手极其充分",
"dsl": "ts_pct_change(weight_avg, 20)",
},
{
"name": "bottom_cost_stability",
"desc": "底部坚如磐石因子底部5%成本的60天波动率相对于中位数的比值波动越小说明死筹越稳固",
"dsl": "ts_std(cost_5pct, 60) / cost_50pct",
},
{
"name": "pivot_reversion",
"desc": "盈亏分界线乖离修复价格偏离50%分位点除以近20日价格标准差用于寻找超跌后的均值回归买点",
"dsl": "(close - cost_50pct) / ts_std(close, 20)",
},
{
"name": "chip_transition",
"desc": "强弱筹码切换度上方厚度与下方厚度差值的20日变化量。由正变负说明筹码彻底完成了自上而下的转移洗盘结束",
"dsl": "ts_delta((cost_85pct - cost_50pct) - (cost_50pct - cost_15pct), 20)",
}, },
] ]
# 因子存储路径(默认使用实验目录) # 因子存储路径(使用项目根路径下的 data 目录)
OUTPUT_PATH = Path(__file__).parent.parent / "experiment" / "data" / "factors.jsonl" settings = get_settings()
OUTPUT_PATH = settings.data_path_resolved / "factors.jsonl"
# ============================================================================ # ============================================================================

View File

@@ -48,6 +48,7 @@ class LightGBMModel(BaseModel):
self.params = dict(params) if params is not None else {} self.params = dict(params) if params is not None else {}
self.model = None self.model = None
self.feature_names_: Optional[list] = None self.feature_names_: Optional[list] = None
self.evals_result_: Optional[dict] = None
def fit( def fit(
self, self,
@@ -90,14 +91,23 @@ class LightGBMModel(BaseModel):
y_val_np = y_val.to_numpy() y_val_np = y_val.to_numpy()
valid_sets = lgb.Dataset(X_val_np, label=y_val_np, reference=train_data) valid_sets = lgb.Dataset(X_val_np, label=y_val_np, reference=train_data)
# 从 params 中提取 num_boost_round默认 100 # 从 params 中提取训练控制参数
num_boost_round = self.params.pop("n_estimators", 100) params_copy = dict(self.params)
num_boost_round = params_copy.pop("n_estimators", 100)
early_stopping_round = params_copy.pop("early_stopping_round", 50)
self.evals_result_ = {}
callbacks = [
lgb.early_stopping(stopping_rounds=early_stopping_round),
lgb.record_evaluation(self.evals_result_),
]
self.model = lgb.train( self.model = lgb.train(
self.params, params_copy,
train_data, train_data,
num_boost_round=num_boost_round, num_boost_round=num_boost_round,
valid_sets=[valid_sets] if valid_sets else None, valid_sets=[valid_sets] if valid_sets else None,
callbacks=callbacks,
) )
return self return self
@@ -121,6 +131,34 @@ class LightGBMModel(BaseModel):
result = self.model.predict(X_np) result = self.model.predict(X_np)
return np.asarray(result) return np.asarray(result)
def get_evals_result(self) -> Optional[dict]:
"""获取训练评估结果
Returns:
评估结果字典,如果模型尚未训练返回 None
"""
return self.evals_result_
def get_best_iteration(self) -> Optional[int]:
"""获取最佳迭代轮数(考虑早停)
Returns:
最佳迭代轮数,如果模型未训练返回 None
"""
if self.model is None:
return None
return self.model.best_iteration
def get_best_score(self) -> Optional[dict]:
"""获取最佳评分
Returns:
最佳评分字典,如果模型未训练返回 None
"""
if self.model is None:
return None
return self.model.best_score
def feature_importance(self) -> Optional[pd.Series]: def feature_importance(self) -> Optional[pd.Series]:
"""返回特征重要性 """返回特征重要性

View File

@@ -84,7 +84,7 @@ class ResultAnalyzer:
print("\n" + "-" * 80) print("\n" + "-" * 80)
print(f"[警告] 贡献为0的特征{len(zero_importance_features)} 个):") print(f"[警告] 贡献为0的特征{len(zero_importance_features)} 个):")
for i, feature in enumerate(zero_importance_features, 1): for i, feature in enumerate(zero_importance_features, 1):
print(f" {i}. {feature}") print(f"'{feature}',")
# 统计摘要 # 统计摘要
print("\n" + "=" * 80) print("\n" + "=" * 80)

View File

@@ -189,10 +189,14 @@ class RankTask(BaseTask):
def plot_training_metrics(self) -> None: def plot_training_metrics(self) -> None:
"""绘制训练指标曲线NDCG""" """绘制训练指标曲线NDCG"""
if self.model and hasattr(self.model, "model") and self.model.model: if self.model and hasattr(self.model, "get_evals_result"):
try: try:
import lightgbm as lgb import lightgbm as lgb
lgb.plot_metric(self.model.model) evals_result = self.model.get_evals_result()
if evals_result:
lgb.plot_metric(evals_result)
else:
print("[警告] 没有训练指标数据可供绘制")
except Exception as e: except Exception as e:
print(f"[警告] 无法绘制训练曲线: {e}") print(f"[警告] 无法绘制训练曲线: {e}")

View File

@@ -77,10 +77,14 @@ class RegressionTask(BaseTask):
def plot_training_metrics(self) -> None: def plot_training_metrics(self) -> None:
"""绘制训练指标曲线""" """绘制训练指标曲线"""
if self.model and hasattr(self.model, "model") and self.model.model: if self.model and hasattr(self.model, "get_evals_result"):
try: try:
import lightgbm as lgb import lightgbm as lgb
lgb.plot_metric(self.model.model) evals_result = self.model.get_evals_result()
if evals_result:
lgb.plot_metric(evals_result)
else:
print("[警告] 没有训练指标数据可供绘制")
except Exception as e: except Exception as e:
print(f"[警告] 无法绘制训练曲线: {e}") print(f"[警告] 无法绘制训练曲线: {e}")