# %% md
# # LightGBM LambdaRank 排序学习训练流程(模块化版本)
#
# 使用新的模块化 Trainer 架构,代码更简洁、可维护性更高。
# %% md
# ## 1. 导入依赖
# %%
|
2026-03-11 22:54:52 +08:00
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
from src.factors import FactorEngine
|
|
|
|
|
from src.training import (
|
2026-03-24 23:35:31 +08:00
|
|
|
FactorManager,
|
|
|
|
|
DataPipeline,
|
|
|
|
|
RankTask,
|
2026-03-11 22:54:52 +08:00
|
|
|
NullFiller,
|
2026-03-24 23:35:31 +08:00
|
|
|
Winsorizer,
|
2026-03-16 22:50:47 +08:00
|
|
|
CrossSectionalStandardScaler,
|
2026-03-11 22:54:52 +08:00
|
|
|
)
|
2026-03-26 00:15:30 +08:00
|
|
|
from src.training.core.trainer_v2 import Trainer
|
2026-03-24 23:35:31 +08:00
|
|
|
from src.training.components.filters import STFilter
|
2026-03-15 05:46:19 +08:00
|
|
|
from src.experiment.common import (
|
|
|
|
|
SELECTED_FACTORS,
|
|
|
|
|
FACTOR_DEFINITIONS,
|
|
|
|
|
get_label_factor,
|
|
|
|
|
TRAIN_START,
|
|
|
|
|
TRAIN_END,
|
|
|
|
|
VAL_START,
|
|
|
|
|
VAL_END,
|
|
|
|
|
TEST_START,
|
|
|
|
|
TEST_END,
|
|
|
|
|
stock_pool_filter,
|
|
|
|
|
STOCK_FILTER_REQUIRED_COLUMNS,
|
|
|
|
|
OUTPUT_DIR,
|
|
|
|
|
SAVE_PREDICTIONS,
|
2026-03-16 22:50:47 +08:00
|
|
|
SAVE_MODEL,
|
|
|
|
|
get_model_save_path,
|
|
|
|
|
save_model_with_factors,
|
2026-03-15 05:46:19 +08:00
|
|
|
TOP_N,
|
|
|
|
|
)
|
2026-03-11 22:54:52 +08:00
|
|
|
|
2026-03-16 22:50:47 +08:00
|
|
|
# Training type identifier; passed to get_model_save_path() when building
# output_config below.
TRAINING_TYPE = "rank"
|
|
|
|
|
|
2026-03-15 05:46:19 +08:00
|
|
|
# %% md
# ## 2. 训练特定配置
# %%
|
2026-03-24 23:35:31 +08:00
|
|
|
# Label configuration: the label name and the factor definition resolved
# for it via get_label_factor().
LABEL_NAME = "future_return_5"
LABEL_FACTOR = get_label_factor(LABEL_NAME)

# Quantile configuration: forwarded to RankTask as `n_quantiles` and used to
# size MODEL_PARAMS["label_gain"].
N_QUANTILES = 20
|
|
|
|
|
|
|
|
|
|
# Factors to exclude; handed to FactorManager as `excluded_factors`.
# Order is preserved from the original definition.
EXCLUDED_FACTORS = [
    "volatility_5",
    "volume_ratio_5_20",
    "capital_retention_20",
    "volatility_squeeze_5_60",
    "drawdown_from_high_60",
    "ma_ratio_5_20",
    "bias_10",
    "high_low_ratio",
    "bbi_ratio",
    "volatility_20",
    "std_return_20",
    "sharpe_ratio_20",
    "ma_5",
    "max_ret_20",
    "CP",
    "net_profit_yoy",
    "debt_to_equity",
    "EP_rank",
    "turnover_rank",
    "return_5_rank",
    "ebit_rank",
    "BP",
    "EP",
    "amihud_illiq_20",
    "profit_margin",
    "return_5",
    "return_20",
    "kaufman_ER_20",
    "GTJA_alpha043",
    "GTJA_alpha042",
    "GTJA_alpha041",
    "GTJA_alpha040",
    "GTJA_alpha039",
    "GTJA_alpha037",
    "GTJA_alpha036",
    "GTJA_alpha035",
    "GTJA_alpha033",
    "GTJA_alpha032",
    "GTJA_alpha031",
    "GTJA_alpha028",
    "GTJA_alpha026",
    "GTJA_alpha027",
    "GTJA_alpha023",
    "GTJA_alpha024",
    "GTJA_alpha009",
    "GTJA_alpha011",
    "GTJA_alpha022",
    "GTJA_alpha020",
    "GTJA_alpha018",
    "GTJA_alpha019",
    "GTJA_alpha014",
    "GTJA_alpha013",
    "GTJA_alpha010",
    "GTJA_alpha001",
    "GTJA_alpha003",
    "GTJA_alpha002",
    "GTJA_alpha004",
    "GTJA_alpha005",
    "GTJA_alpha006",
    "GTJA_alpha008",
    "turnover_deviation",
    "turnover_cv_20",
    "roa",
    "GTJA_alpha073",
    "GTJA_alpha078",
    "GTJA_alpha077",
    "GTJA_alpha076",
    "GTJA_alpha067",
    "GTJA_alpha085",
    "GTJA_alpha084",
    "GTJA_alpha087",
    "GTJA_alpha088",
    "GTJA_alpha090",
    "GTJA_alpha083",
    "GTJA_alpha079",
    "GTJA_alpha080",
    "GTJA_alpha094",
    "GTJA_alpha092",
    "GTJA_alpha089",
    "GTJA_alpha095",
    "GTJA_alpha064",
    "GTJA_alpha065",
    "GTJA_alpha066",
    "GTJA_alpha063",
    "GTJA_alpha060",
    "GTJA_alpha058",
    "GTJA_alpha057",
    "GTJA_alpha056",
    "GTJA_alpha046",
    "GTJA_alpha044",
    "GTJA_alpha049",
    "GTJA_alpha050",
    "GTJA_alpha110",
    "GTJA_alpha107",
    "GTJA_alpha104",
    "GTJA_alpha106",
    "GTJA_alpha103",
    "GTJA_alpha100",
    "GTJA_alpha101",
    "GTJA_alpha102",
    "GTJA_alpha098",
    "GTJA_alpha097",
    "GTJA_alpha096",
    "GTJA_alpha099",
    "GTJA_alpha117",
    "GTJA_alpha118",
    "GTJA_alpha114",
    "GTJA_alpha111",
    "GTJA_alpha129",
    "GTJA_alpha130",
    "GTJA_alpha132",
    "GTJA_alpha131",
    "GTJA_alpha134",
    "GTJA_alpha135",
    "GTJA_alpha136",
    "GTJA_alpha112",
    "GTJA_alpha120",
    "GTJA_alpha119",
    "GTJA_alpha122",
    "GTJA_alpha124",
    "GTJA_alpha126",
    "GTJA_alpha127",
    "GTJA_alpha128",
    "GTJA_alpha115",
    "GTJA_alpha153",
    "GTJA_alpha152",
    "GTJA_alpha151",
    "GTJA_alpha150",
    "GTJA_alpha148",
    "GTJA_alpha142",
    "GTJA_alpha141",
    "GTJA_alpha139",
    "GTJA_alpha133",
    "GTJA_alpha161",
    "GTJA_alpha164",
    "GTJA_alpha162",
    "GTJA_alpha157",
    "GTJA_alpha156",
    "GTJA_alpha160",
    "GTJA_alpha155",
    "GTJA_alpha170",
    "GTJA_alpha169",
    "GTJA_alpha168",
    "GTJA_alpha166",
    "GTJA_alpha163",
    "GTJA_alpha176",
    "GTJA_alpha175",
    "GTJA_alpha174",
    "GTJA_alpha178",
    "GTJA_alpha177",
    "GTJA_alpha185",
    "GTJA_alpha180",
    "GTJA_alpha187",
    "GTJA_alpha188",
    "GTJA_alpha189",
    "GTJA_alpha191",
]
|
2026-03-13 22:24:12 +08:00
|
|
|
|
2026-03-11 22:54:52 +08:00
|
|
|
# LambdaRank model hyper-parameters; handed to RankTask as `model_params`.
MODEL_PARAMS = {
    "objective": "lambdarank",
    "metric": "ndcg",
    "ndcg_at": 25,
    "learning_rate": 0.1,
    "n_estimators": 1000,
    "early_stopping_round": 50,
    # Core constraints against overfitting.
    # NOTE(review): a depth-4 tree holds at most 2**4 = 16 leaves, so
    # num_leaves=32 is presumably never reached — confirm the intent.
    "max_depth": 4,
    "num_leaves": 32,
    "min_data_in_leaf": 256,
    # Random sampling (adds robustness).
    "subsample": 0.4,
    "subsample_freq": 1,
    "colsample_bytree": 0.4,
    # Regularization penalties.
    "reg_alpha": 10.0,
    "reg_lambda": 50.0,
    # LambdaRank-specific settings.
    "lambdarank_truncation_level": 50,
    # One gain per label value: 1, 4, 9, ..., N_QUANTILES ** 2.
    # NOTE(review): assumes labels are bucket indices 0..N_QUANTILES-1 —
    # verify against RankTask.
    "label_gain": [i * i for i in range(1, N_QUANTILES + 1)],
    "verbose": -1,
    "random_state": 42,
}
|
|
|
|
|
|
2026-03-24 23:35:31 +08:00
|
|
|
# Date-range configuration: (start, end) pairs per data split, passed to
# Trainer.run() as `date_range`.
date_range = {
    "train": (TRAIN_START, TRAIN_END),
    "val": (VAL_START, VAL_END),
    "test": (TEST_START, TEST_END),
}
|
2026-03-11 22:54:52 +08:00
|
|
|
|
2026-03-24 23:35:31 +08:00
|
|
|
# Output configuration; passed to Trainer as `output_config`.
# "model_save_path" is also read by main() when saving the model.
output_config = {
    "output_dir": OUTPUT_DIR,
    "output_filename": "rank_output.csv",
    "save_predictions": SAVE_PREDICTIONS,
    "save_model": SAVE_MODEL,
    "model_save_path": get_model_save_path(TRAINING_TYPE),
    "top_n": TOP_N,
}
|
2026-03-11 22:54:52 +08:00
|
|
|
|
|
|
|
|
|
2026-03-24 23:35:31 +08:00
|
|
|
def main():
    """Run the LambdaRank training workflow end to end.

    Builds the FactorEngine, FactorManager, DataPipeline, RankTask and
    Trainer from the module-level configuration, runs training over
    ``date_range``, and, when ``SAVE_MODEL`` is truthy, saves the model
    together with the factor information and fitted processors.

    Returns:
        Whatever ``Trainer.run`` returns.
    """
    print("\n" + "=" * 80)
    print("LightGBM LambdaRank 排序学习训练(模块化版本)")
    print("=" * 80)

    # 1. Create the FactorEngine.
    print("\n[1] 创建 FactorEngine")
    engine = FactorEngine()

    # 2. Create the FactorManager.
    print("\n[2] 创建 FactorManager")
    factor_manager = FactorManager(
        selected_factors=SELECTED_FACTORS,
        factor_definitions=FACTOR_DEFINITIONS,
        label_factor=LABEL_FACTOR,
        excluded_factors=EXCLUDED_FACTORS,
    )

    # 3. Create the DataPipeline.
    print("\n[3] 创建 DataPipeline")
    pipeline = DataPipeline(
        factor_manager=factor_manager,
        processor_configs=[
            (NullFiller, {"strategy": "mean"}),
            (Winsorizer, {"lower": 0.01, "upper": 0.99}),
            (CrossSectionalStandardScaler, {}),
        ],
        filters=[STFilter(data_router=engine.router)],
        stock_pool_filter_func=stock_pool_filter,
        stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
    )

    # 4. Create the RankTask.
    print("\n[4] 创建 RankTask")
    task = RankTask(
        model_params=MODEL_PARAMS,
        label_name=LABEL_NAME,
        n_quantiles=N_QUANTILES,
    )

    # 5. Create the Trainer.
    print("\n[5] 创建 Trainer")
    trainer = Trainer(
        data_pipeline=pipeline,
        task=task,
        output_config=output_config,
        verbose=True,
    )

    # 6. Run training.
    print("\n[6] 执行训练")
    results = trainer.run(engine=engine, date_range=date_range)

    # 7. Save the model and factor information (if enabled).
    if SAVE_MODEL:
        print("\n[7] 保存模型和因子信息")
        # NOTE(review): SELECTED_FACTORS is saved as-is although
        # EXCLUDED_FACTORS was given to FactorManager — confirm that
        # save_model_with_factors expects the unfiltered list.
        save_model_with_factors(
            model=task.get_model(),
            model_path=output_config["model_save_path"],
            selected_factors=SELECTED_FACTORS,
            factor_definitions=FACTOR_DEFINITIONS,
            fitted_processors=pipeline.get_fitted_processors(),
        )

    # Build the reported path from output_config so the message cannot drift
    # from the file name the Trainer was configured with.
    result_path = os.path.join(
        output_config["output_dir"], output_config["output_filename"]
    )

    print("\n" + "=" * 80)
    print("训练流程完成!")
    print(f"结果保存路径: {result_path}")
    print("=" * 80)

    return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Script entry point: run the training workflow only when executed directly.
if __name__ == "__main__":
    main()
|