# Source file: ProStock/src/experiment/learn_to_rank.py (Python, 340 lines, 7.6 KiB)

# %% md
# # LightGBM LambdaRank 排序学习训练流程(模块化版本)
#
# 使用新的模块化 Trainer 架构,代码更简洁、可维护性更高。
# %% md
# ## 1. 导入依赖
# %%
import os
from datetime import datetime
from typing import List, Tuple, Optional
import numpy as np
import polars as pl
import pandas as pd
import matplotlib.pyplot as plt
from src.factors import FactorEngine
from src.training import (
FactorManager,
DataPipeline,
RankTask,
NullFiller,
Winsorizer,
CrossSectionalStandardScaler,
)
from src.training.trainer_v2 import Trainer
from src.training.components.filters import STFilter
from src.experiment.common import (
SELECTED_FACTORS,
FACTOR_DEFINITIONS,
get_label_factor,
TRAIN_START,
TRAIN_END,
VAL_START,
VAL_END,
TEST_START,
TEST_END,
stock_pool_filter,
STOCK_FILTER_REQUIRED_COLUMNS,
OUTPUT_DIR,
SAVE_PREDICTIONS,
SAVE_MODEL,
get_model_save_path,
save_model_with_factors,
TOP_N,
)
# Training-type identifier; passed to get_model_save_path() when
# output_config is built below.
TRAINING_TYPE = "rank"
# %% md
# ## 2. 训练特定配置
# %%
# Label configuration: the label name and the label factor that
# get_label_factor() returns for it.
LABEL_NAME = "future_return_5"
LABEL_FACTOR = get_label_factor(LABEL_NAME)
# Quantile configuration: passed to RankTask as n_quantiles and also sets the
# length of MODEL_PARAMS["label_gain"].
N_QUANTILES = 20
# Factors to exclude; handed to FactorManager as ``excluded_factors`` in main().
# Kept as a whitespace-separated block of full names (order preserved) so each
# name stays greppable; ``.split()`` turns it into a plain list of strings.
EXCLUDED_FACTORS = """
    volatility_5 volume_ratio_5_20 capital_retention_20 volatility_squeeze_5_60
    drawdown_from_high_60 ma_ratio_5_20 bias_10 high_low_ratio
    bbi_ratio volatility_20 std_return_20 sharpe_ratio_20
    ma_5 max_ret_20 CP net_profit_yoy
    debt_to_equity EP_rank turnover_rank return_5_rank
    ebit_rank BP EP amihud_illiq_20
    profit_margin return_5 return_20 kaufman_ER_20

    GTJA_alpha043 GTJA_alpha042 GTJA_alpha041 GTJA_alpha040
    GTJA_alpha039 GTJA_alpha037 GTJA_alpha036 GTJA_alpha035
    GTJA_alpha033 GTJA_alpha032 GTJA_alpha031 GTJA_alpha028
    GTJA_alpha026 GTJA_alpha027 GTJA_alpha023 GTJA_alpha024
    GTJA_alpha009 GTJA_alpha011 GTJA_alpha022 GTJA_alpha020
    GTJA_alpha018 GTJA_alpha019 GTJA_alpha014 GTJA_alpha013
    GTJA_alpha010 GTJA_alpha001 GTJA_alpha003 GTJA_alpha002
    GTJA_alpha004 GTJA_alpha005 GTJA_alpha006 GTJA_alpha008

    turnover_deviation turnover_cv_20 roa

    GTJA_alpha073 GTJA_alpha078 GTJA_alpha077 GTJA_alpha076
    GTJA_alpha067 GTJA_alpha085 GTJA_alpha084 GTJA_alpha087
    GTJA_alpha088 GTJA_alpha090 GTJA_alpha083 GTJA_alpha079
    GTJA_alpha080 GTJA_alpha094 GTJA_alpha092 GTJA_alpha089
    GTJA_alpha095 GTJA_alpha064 GTJA_alpha065 GTJA_alpha066
    GTJA_alpha063 GTJA_alpha060 GTJA_alpha058 GTJA_alpha057
    GTJA_alpha056 GTJA_alpha046 GTJA_alpha044 GTJA_alpha049
    GTJA_alpha050 GTJA_alpha110 GTJA_alpha107 GTJA_alpha104
    GTJA_alpha106 GTJA_alpha103 GTJA_alpha100 GTJA_alpha101
    GTJA_alpha102 GTJA_alpha098 GTJA_alpha097 GTJA_alpha096
    GTJA_alpha099 GTJA_alpha117 GTJA_alpha118 GTJA_alpha114
    GTJA_alpha111 GTJA_alpha129 GTJA_alpha130 GTJA_alpha132
    GTJA_alpha131 GTJA_alpha134 GTJA_alpha135 GTJA_alpha136
    GTJA_alpha112 GTJA_alpha120 GTJA_alpha119 GTJA_alpha122
    GTJA_alpha124 GTJA_alpha126 GTJA_alpha127 GTJA_alpha128
    GTJA_alpha115 GTJA_alpha153 GTJA_alpha152 GTJA_alpha151
    GTJA_alpha150 GTJA_alpha148 GTJA_alpha142 GTJA_alpha141
    GTJA_alpha139 GTJA_alpha133 GTJA_alpha161 GTJA_alpha164
    GTJA_alpha162 GTJA_alpha157 GTJA_alpha156 GTJA_alpha160
    GTJA_alpha155 GTJA_alpha170 GTJA_alpha169 GTJA_alpha168
    GTJA_alpha166 GTJA_alpha163 GTJA_alpha176 GTJA_alpha175
    GTJA_alpha174 GTJA_alpha178 GTJA_alpha177 GTJA_alpha185
    GTJA_alpha180 GTJA_alpha187 GTJA_alpha188 GTJA_alpha189
    GTJA_alpha191
""".split()
# LightGBM parameter set for the LambdaRank model; handed to RankTask as
# ``model_params`` in main().
MODEL_PARAMS = {
    # Objective and evaluation metric
    "objective": "lambdarank",
    "metric": "ndcg",
    "ndcg_at": 25,
    # Boosting schedule
    "learning_rate": 0.1,
    "n_estimators": 1000,
    "early_stopping_round": 50,
    # Core constraints against overfitting
    "max_depth": 4,
    "num_leaves": 32,
    "min_data_in_leaf": 256,
    # Random sampling (adds robustness)
    "subsample": 0.4,
    "subsample_freq": 1,
    "colsample_bytree": 0.4,
    # Regularisation penalties
    "reg_alpha": 10.0,
    "reg_lambda": 50.0,
    # LambdaRank-specific settings
    "lambdarank_truncation_level": 50,
    # One gain value per quantile: 1, 4, 9, ..., N_QUANTILES ** 2.
    # NOTE(review): assumes RankTask emits integer labels 0..N_QUANTILES-1 so
    # every label indexes into this list — confirm against RankTask.
    "label_gain": [grade ** 2 for grade in range(1, N_QUANTILES + 1)],
    "verbose": -1,
    "random_state": 42,
}
# Date-range configuration: one (start, end) pair per data split, using the
# boundaries imported from src.experiment.common. Passed to Trainer.run() in
# main().
date_range = {
    "train": (TRAIN_START, TRAIN_END),
    "val": (VAL_START, VAL_END),
    "test": (TEST_START, TEST_END),
}
# Output configuration: passed to Trainer as ``output_config`` in main();
# main() also reads "model_save_path" from it when saving the model.
output_config = {
    "output_dir": OUTPUT_DIR,
    "output_filename": "rank_output.csv",
    "save_predictions": SAVE_PREDICTIONS,
    "save_model": SAVE_MODEL,
    # Save path is derived from the training type ("rank").
    "model_save_path": get_model_save_path(TRAINING_TYPE),
    "top_n": TOP_N,
}
def main():
    """Run the LightGBM LambdaRank training workflow (modular version).

    Wires the module-level configuration into a FactorEngine, FactorManager,
    DataPipeline, RankTask and Trainer, runs training over ``date_range``
    and, when ``SAVE_MODEL`` is truthy, saves the model together with the
    factor information and the fitted processors.

    Returns:
        The value returned by ``Trainer.run`` for this run.
    """
    separator = "=" * 80

    print("\n" + separator)
    print("LightGBM LambdaRank 排序学习训练(模块化版本)")
    print(separator)

    # 1. Create the FactorEngine
    print("\n[1] 创建 FactorEngine")
    engine = FactorEngine()

    # 2. Create the FactorManager
    print("\n[2] 创建 FactorManager")
    factor_manager = FactorManager(
        selected_factors=SELECTED_FACTORS,
        factor_definitions=FACTOR_DEFINITIONS,
        label_factor=LABEL_FACTOR,
        excluded_factors=EXCLUDED_FACTORS,
    )

    # 3. Create the DataPipeline. Processors are given as (class, kwargs)
    #    pairs in the order listed.
    print("\n[3] 创建 DataPipeline")
    pipeline = DataPipeline(
        factor_manager=factor_manager,
        processor_configs=[
            (NullFiller, {"strategy": "mean"}),
            (Winsorizer, {"lower": 0.01, "upper": 0.99}),
            (CrossSectionalStandardScaler, {}),
        ],
        filters=[STFilter(data_router=engine.router)],
        stock_pool_filter_func=stock_pool_filter,
        stock_pool_required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
    )

    # 4. Create the RankTask
    print("\n[4] 创建 RankTask")
    task = RankTask(
        model_params=MODEL_PARAMS,
        label_name=LABEL_NAME,
        n_quantiles=N_QUANTILES,
    )

    # 5. Create the Trainer
    print("\n[5] 创建 Trainer")
    trainer = Trainer(
        data_pipeline=pipeline,
        task=task,
        output_config=output_config,
        verbose=True,
    )

    # 6. Run training
    print("\n[6] 执行训练")
    results = trainer.run(engine=engine, date_range=date_range)

    # 7. Save the model and factor information (if enabled)
    if SAVE_MODEL:
        print("\n[7] 保存模型和因子信息")
        # NOTE(review): the full SELECTED_FACTORS list is saved here although
        # training excluded EXCLUDED_FACTORS via FactorManager — confirm that
        # save_model_with_factors / downstream inference accounts for this.
        save_model_with_factors(
            model=task.get_model(),
            model_path=output_config["model_save_path"],
            selected_factors=SELECTED_FACTORS,
            factor_definitions=FACTOR_DEFINITIONS,
            fitted_processors=pipeline.get_fitted_processors(),
        )

    # Report the location from the same config the Trainer received, instead
    # of a second hard-coded copy of the directory and file name.
    result_path = os.path.join(
        output_config["output_dir"], output_config["output_filename"]
    )
    print("\n" + separator)
    print("训练流程完成!")
    print(f"结果保存路径: {result_path}")
    print(separator)
    return results
# Script entry point: run the training workflow only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()