feat(training): 添加数据质量检查工具并重构实验脚本
- 新增 check_data_quality 函数用于检测全空/全零/全NaN数据质量问题 - 重构 register_factors 函数,消除 FEATURE_COLS 和 PROCESSORS 冗余定义 - 修复实验脚本中特征列表不一致的问题,确保处理器覆盖所有特征 - 优化 LambdaRank 模型参数配置
This commit is contained in:
File diff suppressed because one or more lines are too long
@@ -31,6 +31,7 @@ from src.training import (
|
|||||||
Winsorizer,
|
Winsorizer,
|
||||||
NullFiller,
|
NullFiller,
|
||||||
StandardScaler,
|
StandardScaler,
|
||||||
|
check_data_quality,
|
||||||
)
|
)
|
||||||
from src.training.components.models import LightGBMLambdaRankModel
|
from src.training.components.models import LightGBMLambdaRankModel
|
||||||
from src.training.config import TrainingConfig
|
from src.training.config import TrainingConfig
|
||||||
@@ -39,13 +40,13 @@ from src.training.config import TrainingConfig
|
|||||||
# %% md
|
# %% md
|
||||||
# ## 2. 辅助函数
|
# ## 2. 辅助函数
|
||||||
# %%
|
# %%
|
||||||
def create_factors_with_metadata(
|
def register_factors(
|
||||||
engine: FactorEngine,
|
engine: FactorEngine,
|
||||||
selected_factors: List[str],
|
selected_factors: List[str],
|
||||||
factor_definitions: dict,
|
factor_definitions: dict,
|
||||||
label_factor: dict,
|
label_factor: dict,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""注册因子(SELECTED_FACTORS 从 metadata 查询,FACTOR_DEFINITIONS 用表达式注册)"""
|
"""注册因子(selected_factors 从 metadata 查询,factor_definitions 用 DSL 表达式注册)"""
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
print("注册因子")
|
print("注册因子")
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
@@ -326,14 +327,18 @@ VAL_END = "20241231"
|
|||||||
TEST_START = "20250101"
|
TEST_START = "20250101"
|
||||||
TEST_END = "20251231"
|
TEST_END = "20251231"
|
||||||
|
|
||||||
|
|
||||||
|
# 分位数配置
|
||||||
|
N_QUANTILES = 20 # 将 label 分为 20 组
|
||||||
|
|
||||||
# LambdaRank 模型参数配置
|
# LambdaRank 模型参数配置
|
||||||
MODEL_PARAMS = {
|
MODEL_PARAMS = {
|
||||||
"objective": "lambdarank",
|
"objective": "lambdarank",
|
||||||
"metric": "ndcg",
|
"metric": "ndcg",
|
||||||
"ndcg_at": 2, # 评估 NDCG@k
|
"ndcg_at": 10, # 评估 NDCG@k
|
||||||
"learning_rate": 0.01,
|
"learning_rate": 0.01,
|
||||||
"num_leaves": 31,
|
"num_leaves": 31,
|
||||||
"max_depth": 6,
|
"max_depth": 4,
|
||||||
"min_data_in_leaf": 20,
|
"min_data_in_leaf": 20,
|
||||||
"n_estimators": 2000,
|
"n_estimators": 2000,
|
||||||
"early_stopping_round": 300,
|
"early_stopping_round": 300,
|
||||||
@@ -343,21 +348,10 @@ MODEL_PARAMS = {
|
|||||||
"reg_lambda": 1.0,
|
"reg_lambda": 1.0,
|
||||||
"verbose": -1,
|
"verbose": -1,
|
||||||
"random_state": 42,
|
"random_state": 42,
|
||||||
|
"lambdarank_truncation_level": 10,
|
||||||
|
"label_gain": [i for i in range(1, N_QUANTILES + 1)],
|
||||||
}
|
}
|
||||||
|
|
||||||
# 分位数配置
|
|
||||||
N_QUANTILES = 20 # 将 label 分为 20 组
|
|
||||||
|
|
||||||
# 特征列(用于数据处理器)
|
|
||||||
FEATURE_COLS = SELECTED_FACTORS
|
|
||||||
|
|
||||||
# 数据处理器配置
|
|
||||||
PROCESSORS = [
|
|
||||||
NullFiller(feature_cols=FEATURE_COLS, strategy="mean"),
|
|
||||||
Winsorizer(feature_cols=FEATURE_COLS, lower=0.01, upper=0.99),
|
|
||||||
StandardScaler(feature_cols=FEATURE_COLS),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# 股票池筛选函数
|
# 股票池筛选函数
|
||||||
def stock_pool_filter(df: pl.DataFrame) -> pl.Series:
|
def stock_pool_filter(df: pl.DataFrame) -> pl.Series:
|
||||||
@@ -406,7 +400,7 @@ engine = FactorEngine()
|
|||||||
|
|
||||||
# 2. 使用 metadata 定义因子
|
# 2. 使用 metadata 定义因子
|
||||||
print("\n[2] 定义因子(从 metadata 注册)")
|
print("\n[2] 定义因子(从 metadata 注册)")
|
||||||
feature_cols = create_factors_with_metadata(
|
feature_cols = register_factors(
|
||||||
engine, SELECTED_FACTORS, FACTOR_DEFINITIONS, LABEL_FACTOR
|
engine, SELECTED_FACTORS, FACTOR_DEFINITIONS, LABEL_FACTOR
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -435,10 +429,14 @@ print(f"[配置] 特征数: {len(feature_cols)}")
|
|||||||
print(f"[配置] 目标变量: {target_col}({N_QUANTILES}分位数)")
|
print(f"[配置] 目标变量: {target_col}({N_QUANTILES}分位数)")
|
||||||
|
|
||||||
# 6. 创建排序学习模型
|
# 6. 创建排序学习模型
|
||||||
model = LightGBMLambdaRankModel(params=MODEL_PARAMS)
|
model: LightGBMLambdaRankModel = LightGBMLambdaRankModel(params=MODEL_PARAMS)
|
||||||
|
|
||||||
# 7. 创建数据处理器
|
# 7. 创建数据处理器(使用函数返回的完整特征列表)
|
||||||
processors = PROCESSORS
|
processors = [
|
||||||
|
NullFiller(feature_cols=feature_cols, strategy="mean"),
|
||||||
|
Winsorizer(feature_cols=feature_cols, lower=0.01, upper=0.99),
|
||||||
|
StandardScaler(feature_cols=feature_cols),
|
||||||
|
]
|
||||||
|
|
||||||
# 8. 创建数据划分器
|
# 8. 创建数据划分器
|
||||||
splitter = DateSplitter(
|
splitter = DateSplitter(
|
||||||
@@ -522,7 +520,25 @@ if splitter:
|
|||||||
else:
|
else:
|
||||||
raise ValueError("必须配置数据划分器")
|
raise ValueError("必须配置数据划分器")
|
||||||
# %% md
|
# %% md
|
||||||
# ### 4.3 数据预处理
|
# ### 4.3 数据质量检查
|
||||||
|
# %%
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("数据质量检查(必须在预处理之前)")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
print("\n检查训练集...")
|
||||||
|
check_data_quality(train_data, feature_cols, raise_on_error=True)
|
||||||
|
|
||||||
|
print("\n检查验证集...")
|
||||||
|
check_data_quality(val_data, feature_cols, raise_on_error=True)
|
||||||
|
|
||||||
|
print("\n检查测试集...")
|
||||||
|
check_data_quality(test_data, feature_cols, raise_on_error=True)
|
||||||
|
|
||||||
|
print("[成功] 数据质量检查通过,未发现异常")
|
||||||
|
|
||||||
|
# %% md
|
||||||
|
# ### 4.4 数据预处理
|
||||||
# %%
|
# %%
|
||||||
print("\n" + "=" * 80)
|
print("\n" + "=" * 80)
|
||||||
print("数据预处理")
|
print("数据预处理")
|
||||||
@@ -584,112 +600,51 @@ print("\n" + "=" * 80)
|
|||||||
print("训练指标曲线")
|
print("训练指标曲线")
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
|
|
||||||
# 重新训练以收集指标(因为之前的训练没有保存评估结果)
|
# 从模型获取训练评估结果
|
||||||
print("\n重新训练模型以收集训练指标...")
|
evals_result = model.get_evals_result()
|
||||||
|
|
||||||
import lightgbm as lgb
|
if evals_result is None or not evals_result:
|
||||||
|
print("[警告] 没有可用的训练指标,请确保训练时使用了 eval_set 参数")
|
||||||
# 准备数据(使用 val 做验证,test 不参与训练过程)
|
|
||||||
X_train_np = X_train.to_numpy()
|
|
||||||
y_train_np = y_train.to_numpy()
|
|
||||||
X_val_np = val_data.select(feature_cols).to_numpy()
|
|
||||||
y_val_np = val_data.select(target_col).to_series().to_numpy()
|
|
||||||
|
|
||||||
# 创建数据集
|
|
||||||
train_dataset = lgb.Dataset(X_train_np, label=y_train_np, group=train_group)
|
|
||||||
val_dataset = lgb.Dataset(
|
|
||||||
X_val_np, label=y_val_np, group=val_group, reference=train_dataset
|
|
||||||
)
|
|
||||||
|
|
||||||
# 用于存储评估结果
|
|
||||||
evals_result = {}
|
|
||||||
|
|
||||||
# 使用与原模型相同的参数重新训练
|
|
||||||
# 正确的三分法:train用于训练,val用于验证,test不参与训练过程
|
|
||||||
booster_with_eval = lgb.train(
|
|
||||||
MODEL_PARAMS,
|
|
||||||
train_dataset,
|
|
||||||
num_boost_round=MODEL_PARAMS.get("n_estimators", 1000),
|
|
||||||
valid_sets=[train_dataset, val_dataset],
|
|
||||||
valid_names=["train", "val"],
|
|
||||||
callbacks=[
|
|
||||||
lgb.record_evaluation(evals_result),
|
|
||||||
lgb.early_stopping(stopping_rounds=50, verbose=True),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
print("训练完成,指标已收集")
|
|
||||||
|
|
||||||
# 获取评估的 NDCG 指标
|
|
||||||
ndcg_metrics = [k for k in evals_result["train"].keys() if "ndcg" in k]
|
|
||||||
print(f"\n评估的 NDCG 指标: {ndcg_metrics}")
|
|
||||||
|
|
||||||
# 显示早停信息
|
|
||||||
actual_rounds = len(list(evals_result["train"].values())[0])
|
|
||||||
expected_rounds = MODEL_PARAMS.get("n_estimators", 1000)
|
|
||||||
print(f"\n[早停信息]")
|
|
||||||
print(f" 配置的最大轮数: {expected_rounds}")
|
|
||||||
print(f" 实际训练轮数: {actual_rounds}")
|
|
||||||
if actual_rounds < expected_rounds:
|
|
||||||
print(f" 早停状态: 已触发(连续50轮验证指标未改善)")
|
|
||||||
else:
|
else:
|
||||||
print(f" 早停状态: 未触发(达到最大轮数)")
|
print("[成功] 已从模型获取训练评估结果")
|
||||||
|
|
||||||
# 显示各 NDCG 指标的最终值
|
# 获取评估的 NDCG 指标
|
||||||
print(f"\n最终 NDCG 指标:")
|
ndcg_metrics = [k for k in evals_result["train"].keys() if "ndcg" in k]
|
||||||
for metric in ndcg_metrics:
|
print(f"\n评估的 NDCG 指标: {ndcg_metrics}")
|
||||||
train_ndcg = evals_result["train"][metric][-1]
|
|
||||||
val_ndcg = evals_result["val"][metric][-1]
|
|
||||||
print(f" {metric}: 训练集={train_ndcg:.4f}, 验证集={val_ndcg:.4f}")
|
|
||||||
# %%
|
|
||||||
# 绘制 NDCG 训练指标曲线
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
|
# 显示早停信息
|
||||||
axes = axes.flatten()
|
actual_rounds = len(list(evals_result["train"].values())[0])
|
||||||
|
expected_rounds = MODEL_PARAMS.get("n_estimators", 1000)
|
||||||
|
print(f"\n[早停信息]")
|
||||||
|
print(f" 配置的最大轮数: {expected_rounds}")
|
||||||
|
print(f" 实际训练轮数: {actual_rounds}")
|
||||||
|
|
||||||
for idx, metric in enumerate(ndcg_metrics[:4]): # 最多显示4个NDCG指标
|
best_iter = model.get_best_iteration()
|
||||||
ax = axes[idx]
|
if best_iter is not None and best_iter < actual_rounds:
|
||||||
train_metric = evals_result["train"][metric]
|
print(f" 早停状态: 已触发(最佳迭代: {best_iter})")
|
||||||
val_metric = evals_result["val"][metric]
|
else:
|
||||||
iterations = range(1, len(train_metric) + 1)
|
print(f" 早停状态: 未触发(达到最大轮数)")
|
||||||
|
|
||||||
ax.plot(
|
# 显示各 NDCG 指标的最终值
|
||||||
iterations, train_metric, label=f"Train {metric}", linewidth=2, color="blue"
|
print(f"\n最终 NDCG 指标:")
|
||||||
)
|
for metric in ndcg_metrics:
|
||||||
ax.plot(iterations, val_metric, label=f"Val {metric}", linewidth=2, color="red")
|
train_ndcg = evals_result["train"][metric][-1]
|
||||||
ax.set_xlabel("Iteration", fontsize=10)
|
val_ndcg = evals_result["val"][metric][-1]
|
||||||
ax.set_ylabel(metric.upper(), fontsize=10)
|
print(f" {metric}: 训练集={train_ndcg:.4f}, 验证集={val_ndcg:.4f}")
|
||||||
ax.set_title(
|
|
||||||
f"Training and Validation {metric.upper()}", fontsize=12, fontweight="bold"
|
|
||||||
)
|
|
||||||
ax.legend(fontsize=9)
|
|
||||||
ax.grid(True, alpha=0.3)
|
|
||||||
|
|
||||||
# 标记最佳验证指标点
|
# 使用封装好的方法绘制所有指标
|
||||||
best_iter = val_metric.index(max(val_metric))
|
print("\n[绘图] 使用 LightGBM 原生接口绘制训练曲线...")
|
||||||
best_metric = max(val_metric)
|
fig = model.plot_all_metrics(metrics=ndcg_metrics[:4], figsize=(14, 10))
|
||||||
ax.axvline(x=best_iter + 1, color="green", linestyle="--", alpha=0.7)
|
plt.show()
|
||||||
ax.scatter([best_iter + 1], [best_metric], color="green", s=80, zorder=5)
|
|
||||||
ax.annotate(
|
|
||||||
f"Best: {best_metric:.4f}",
|
|
||||||
xy=(best_iter + 1, best_metric),
|
|
||||||
xytext=(best_iter + 1 + len(iterations) * 0.05, best_metric),
|
|
||||||
fontsize=8,
|
|
||||||
arrowprops=dict(arrowstyle="->", color="green", alpha=0.7),
|
|
||||||
)
|
|
||||||
|
|
||||||
plt.tight_layout()
|
print(f"\n[指标分析]")
|
||||||
plt.show()
|
print(f" 各NDCG指标在验证集上的最佳值:")
|
||||||
|
for metric in ndcg_metrics:
|
||||||
print(f"\n[指标分析]")
|
val_metric_list = evals_result["val"][metric]
|
||||||
print(f" 各NDCG指标在验证集上的最佳值:")
|
best_iter_metric = val_metric_list.index(max(val_metric_list))
|
||||||
for metric in ndcg_metrics:
|
best_val = max(val_metric_list)
|
||||||
val_metric_list = evals_result["val"][metric]
|
print(f" {metric}: {best_val:.4f} (迭代 {best_iter_metric + 1})")
|
||||||
best_iter = val_metric_list.index(max(val_metric_list))
|
print(f"\n[重要提醒] 验证集仅用于早停/调参,测试集完全独立于训练过程!")
|
||||||
best_val = max(val_metric_list)
|
|
||||||
print(f" {metric}: {best_val:.4f} (迭代 {best_iter + 1})")
|
|
||||||
print(f"\n[重要提醒] 验证集仅用于早停/调参,测试集完全独立于训练过程!")
|
|
||||||
# %% md
|
# %% md
|
||||||
# ### 4.6 模型评估
|
# ### 4.6 模型评估
|
||||||
# %%
|
# %%
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from src.training import (
|
|||||||
Trainer,
|
Trainer,
|
||||||
Winsorizer,
|
Winsorizer,
|
||||||
NullFiller,
|
NullFiller,
|
||||||
|
check_data_quality,
|
||||||
)
|
)
|
||||||
from src.training.config import TrainingConfig
|
from src.training.config import TrainingConfig
|
||||||
|
|
||||||
@@ -25,13 +26,13 @@ from src.training.config import TrainingConfig
|
|||||||
# %% md
|
# %% md
|
||||||
# ## 2. 定义辅助函数
|
# ## 2. 定义辅助函数
|
||||||
# %%
|
# %%
|
||||||
def create_factors_with_metadata(
|
def register_factors(
|
||||||
engine: FactorEngine,
|
engine: FactorEngine,
|
||||||
selected_factors: List[str],
|
selected_factors: List[str],
|
||||||
factor_definitions: dict,
|
factor_definitions: dict,
|
||||||
label_factor: dict,
|
label_factor: dict,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""注册因子(SELECTED_FACTORS 从 metadata 查询,FACTOR_DEFINITIONS 用表达式注册)"""
|
"""注册因子(selected_factors 从 metadata 查询,factor_definitions 用 DSL 表达式注册)"""
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
print("注册因子")
|
print("注册因子")
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
@@ -285,9 +286,6 @@ MODEL_PARAMS = {
|
|||||||
"random_state": 42,
|
"random_state": 42,
|
||||||
}
|
}
|
||||||
|
|
||||||
# 数据处理器配置(新 API:需要传入 feature_cols)
|
|
||||||
# 注意:processor 现在需要显式指定要处理的特征列
|
|
||||||
|
|
||||||
|
|
||||||
# 股票池筛选函数
|
# 股票池筛选函数
|
||||||
# 使用新的 StockPoolManager API:传入自定义筛选函数和所需列/因子
|
# 使用新的 StockPoolManager API:传入自定义筛选函数和所需列/因子
|
||||||
@@ -355,7 +353,7 @@ engine = FactorEngine(metadata_path="data/factors.jsonl")
|
|||||||
|
|
||||||
# 2. 使用 metadata 定义因子
|
# 2. 使用 metadata 定义因子
|
||||||
print("\n[2] 定义因子(从 metadata 注册)")
|
print("\n[2] 定义因子(从 metadata 注册)")
|
||||||
feature_cols = create_factors_with_metadata(
|
feature_cols = register_factors(
|
||||||
engine, SELECTED_FACTORS, FACTOR_DEFINITIONS, LABEL_FACTOR
|
engine, SELECTED_FACTORS, FACTOR_DEFINITIONS, LABEL_FACTOR
|
||||||
)
|
)
|
||||||
target_col = LABEL_NAME
|
target_col = LABEL_NAME
|
||||||
@@ -380,7 +378,7 @@ print(f"[配置] 目标变量: {target_col}")
|
|||||||
# 5. 创建模型
|
# 5. 创建模型
|
||||||
model = LightGBMModel(params=MODEL_PARAMS)
|
model = LightGBMModel(params=MODEL_PARAMS)
|
||||||
|
|
||||||
# 6. 创建数据处理器(新 API:需要传入 feature_cols)
|
# 6. 创建数据处理器(使用函数返回的完整特征列表)
|
||||||
processors = [
|
processors = [
|
||||||
NullFiller(feature_cols=feature_cols, strategy="mean"),
|
NullFiller(feature_cols=feature_cols, strategy="mean"),
|
||||||
Winsorizer(feature_cols=feature_cols, lower=0.01, upper=0.99),
|
Winsorizer(feature_cols=feature_cols, lower=0.01, upper=0.99),
|
||||||
@@ -482,8 +480,26 @@ else:
|
|||||||
test_data = filtered_data
|
test_data = filtered_data
|
||||||
print(" 未配置划分器,全部作为训练集")
|
print(" 未配置划分器,全部作为训练集")
|
||||||
# %%
|
# %%
|
||||||
# 步骤 3: 训练集数据处理
|
# 步骤 3: 数据质量检查(必须在预处理之前)
|
||||||
print("\n[步骤 3/6] 训练集数据处理")
|
print("\n[步骤 3/7] 数据质量检查")
|
||||||
|
print("-" * 60)
|
||||||
|
print(" [说明] 此检查在 fillna 等处理之前执行,用于发现数据问题")
|
||||||
|
|
||||||
|
print("\n 检查训练集...")
|
||||||
|
check_data_quality(train_data, feature_cols, raise_on_error=True)
|
||||||
|
|
||||||
|
if "val_data" in locals() and val_data is not None:
|
||||||
|
print("\n 检查验证集...")
|
||||||
|
check_data_quality(val_data, feature_cols, raise_on_error=True)
|
||||||
|
|
||||||
|
print("\n 检查测试集...")
|
||||||
|
check_data_quality(test_data, feature_cols, raise_on_error=True)
|
||||||
|
|
||||||
|
print(" [成功] 数据质量检查通过,未发现异常")
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# 步骤 4: 训练集数据处理
|
||||||
|
print("\n[步骤 4/7] 训练集数据处理")
|
||||||
print("-" * 60)
|
print("-" * 60)
|
||||||
fitted_processors = []
|
fitted_processors = []
|
||||||
if processors:
|
if processors:
|
||||||
@@ -510,7 +526,7 @@ for col in feature_cols[:5]: # 只显示前5个特征的缺失值
|
|||||||
print(f" {col}: {null_count} ({null_count / len(train_data) * 100:.2f}%)")
|
print(f" {col}: {null_count} ({null_count / len(train_data) * 100:.2f}%)")
|
||||||
# %%
|
# %%
|
||||||
# 步骤 4: 训练模型
|
# 步骤 4: 训练模型
|
||||||
print("\n[步骤 4/6] 训练模型")
|
print("\n[步骤 5/7] 训练模型")
|
||||||
print("-" * 60)
|
print("-" * 60)
|
||||||
print(f" 模型类型: LightGBM")
|
print(f" 模型类型: LightGBM")
|
||||||
print(f" 训练样本数: {len(train_data)}")
|
print(f" 训练样本数: {len(train_data)}")
|
||||||
@@ -532,7 +548,7 @@ model.fit(X_train, y_train)
|
|||||||
print(" 训练完成!")
|
print(" 训练完成!")
|
||||||
# %%
|
# %%
|
||||||
# 步骤 5: 测试集数据处理
|
# 步骤 5: 测试集数据处理
|
||||||
print("\n[步骤 5/6] 测试集数据处理")
|
print("\n[步骤 6/7] 测试集数据处理")
|
||||||
print("-" * 60)
|
print("-" * 60)
|
||||||
if processors and test_data is not train_data:
|
if processors and test_data is not train_data:
|
||||||
for i, processor in enumerate(fitted_processors, 1):
|
for i, processor in enumerate(fitted_processors, 1):
|
||||||
@@ -548,7 +564,7 @@ else:
|
|||||||
print(" 跳过测试集处理")
|
print(" 跳过测试集处理")
|
||||||
# %%
|
# %%
|
||||||
# 步骤 6: 生成预测
|
# 步骤 6: 生成预测
|
||||||
print("\n[步骤 6/6] 生成预测")
|
print("\n[步骤 7/7] 生成预测")
|
||||||
print("-" * 60)
|
print("-" * 60)
|
||||||
X_test = test_data.select(feature_cols)
|
X_test = test_data.select(feature_cols)
|
||||||
print(f" 测试样本数: {len(X_test)}")
|
print(f" 测试样本数: {len(X_test)}")
|
||||||
|
|||||||
@@ -37,6 +37,9 @@ from src.training.components.filters import BaseFilter, STFilter
|
|||||||
# 训练核心
|
# 训练核心
|
||||||
from src.training.core import StockPoolManager, Trainer
|
from src.training.core import StockPoolManager, Trainer
|
||||||
|
|
||||||
|
# 工具函数
|
||||||
|
from src.training.utils import check_data_quality
|
||||||
|
|
||||||
# 配置
|
# 配置
|
||||||
from src.training.config import TrainingConfig
|
from src.training.config import TrainingConfig
|
||||||
|
|
||||||
@@ -67,6 +70,8 @@ __all__ = [
|
|||||||
# 训练核心
|
# 训练核心
|
||||||
"StockPoolManager",
|
"StockPoolManager",
|
||||||
"Trainer",
|
"Trainer",
|
||||||
|
# 工具函数
|
||||||
|
"check_data_quality",
|
||||||
# 配置
|
# 配置
|
||||||
"TrainingConfig",
|
"TrainingConfig",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -98,6 +98,7 @@ class LightGBMLambdaRankModel(BaseModel):
|
|||||||
|
|
||||||
self.model = None
|
self.model = None
|
||||||
self.feature_names_: Optional[list] = None
|
self.feature_names_: Optional[list] = None
|
||||||
|
self.evals_result_: Optional[dict] = None # 存储训练评估结果
|
||||||
|
|
||||||
def fit(
|
def fit(
|
||||||
self,
|
self,
|
||||||
@@ -155,8 +156,9 @@ class LightGBMLambdaRankModel(BaseModel):
|
|||||||
# 创建训练数据集
|
# 创建训练数据集
|
||||||
train_data = lgb.Dataset(X_np, label=y_np, group=group)
|
train_data = lgb.Dataset(X_np, label=y_np, group=group)
|
||||||
|
|
||||||
# 准备验证集
|
# 准备验证集和验证集名称
|
||||||
valid_sets = [train_data]
|
valid_sets = [train_data]
|
||||||
|
valid_names = ["train"]
|
||||||
if eval_set is not None:
|
if eval_set is not None:
|
||||||
X_val, y_val, group_val = eval_set
|
X_val, y_val, group_val = eval_set
|
||||||
X_val_np = X_val.to_numpy() if isinstance(X_val, pl.DataFrame) else X_val
|
X_val_np = X_val.to_numpy() if isinstance(X_val, pl.DataFrame) else X_val
|
||||||
@@ -169,15 +171,23 @@ class LightGBMLambdaRankModel(BaseModel):
|
|||||||
|
|
||||||
val_data = lgb.Dataset(X_val_np, label=y_val_np, group=group_val)
|
val_data = lgb.Dataset(X_val_np, label=y_val_np, group=group_val)
|
||||||
valid_sets.append(val_data)
|
valid_sets.append(val_data)
|
||||||
|
valid_names.append("val")
|
||||||
|
|
||||||
|
# 初始化评估结果存储
|
||||||
|
self.evals_result_ = {}
|
||||||
|
|
||||||
# 训练
|
# 训练
|
||||||
callbacks = [lgb.early_stopping(stopping_rounds=self.early_stopping_rounds)]
|
callbacks = [
|
||||||
|
lgb.early_stopping(stopping_rounds=self.early_stopping_rounds),
|
||||||
|
lgb.record_evaluation(self.evals_result_),
|
||||||
|
]
|
||||||
|
|
||||||
self.model = lgb.train(
|
self.model = lgb.train(
|
||||||
self.params,
|
self.params,
|
||||||
train_data,
|
train_data,
|
||||||
num_boost_round=self.n_estimators,
|
num_boost_round=self.n_estimators,
|
||||||
valid_sets=valid_sets,
|
valid_sets=valid_sets,
|
||||||
|
valid_names=valid_names,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -201,6 +211,185 @@ class LightGBMLambdaRankModel(BaseModel):
|
|||||||
X_np = X.to_numpy()
|
X_np = X.to_numpy()
|
||||||
return self.model.predict(X_np)
|
return self.model.predict(X_np)
|
||||||
|
|
||||||
|
def get_evals_result(self) -> Optional[dict]:
|
||||||
|
"""获取训练评估结果
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
评估结果字典,包含训练集和验证集的指标历史
|
||||||
|
格式: {'train': {'metric_name': [...]}, 'val': {'metric_name': [...]}}
|
||||||
|
如果模型尚未训练,返回 None
|
||||||
|
"""
|
||||||
|
return self.evals_result_
|
||||||
|
|
||||||
|
def plot_metric(
|
||||||
|
self,
|
||||||
|
metric: Optional[str] = None,
|
||||||
|
figsize: tuple = (10, 6),
|
||||||
|
title: Optional[str] = None,
|
||||||
|
ax=None,
|
||||||
|
):
|
||||||
|
"""绘制训练指标曲线
|
||||||
|
|
||||||
|
使用 LightGBM 原生的 plot_metric 接口绘制训练曲线。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metric: 要绘制的指标名称,如 'ndcg@5'、'ndcg@10' 等。
|
||||||
|
如果为 None,则自动选择第一个可用的 NDCG 指标。
|
||||||
|
figsize: 图大小,默认 (10, 6)
|
||||||
|
title: 图表标题,如果为 None 则自动生成
|
||||||
|
ax: matplotlib Axes 对象,如果为 None 则创建新图
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
matplotlib Axes 对象
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: 模型尚未训练
|
||||||
|
ValueError: 指定的指标不存在
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> model.plot_metric('ndcg@20') # 绘制 ndcg@20 曲线
|
||||||
|
>>> model.plot_metric() # 自动选择指标
|
||||||
|
"""
|
||||||
|
if self.model is None:
|
||||||
|
raise RuntimeError("模型尚未训练,请先调用 fit()")
|
||||||
|
|
||||||
|
if self.evals_result_ is None or not self.evals_result_:
|
||||||
|
raise RuntimeError("没有可用的评估结果")
|
||||||
|
|
||||||
|
import lightgbm as lgb
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
# 如果没有指定指标,自动选择第一个 NDCG 指标
|
||||||
|
if metric is None:
|
||||||
|
available_metrics = list(self.evals_result_.get("train", {}).keys())
|
||||||
|
ndcg_metrics = [m for m in available_metrics if "ndcg" in m.lower()]
|
||||||
|
if ndcg_metrics:
|
||||||
|
metric = ndcg_metrics[0]
|
||||||
|
elif available_metrics:
|
||||||
|
metric = available_metrics[0]
|
||||||
|
else:
|
||||||
|
raise ValueError("没有可用的评估指标")
|
||||||
|
|
||||||
|
# 检查指标是否存在
|
||||||
|
if metric not in self.evals_result_.get("train", {}):
|
||||||
|
available = list(self.evals_result_.get("train", {}).keys())
|
||||||
|
raise ValueError(f"指标 '{metric}' 不存在。可用的指标: {available}")
|
||||||
|
|
||||||
|
# 创建图表
|
||||||
|
if ax is None:
|
||||||
|
fig, ax = plt.subplots(figsize=figsize)
|
||||||
|
|
||||||
|
# 使用 LightGBM 原生接口绘制
|
||||||
|
lgb.plot_metric(self.evals_result_, metric=metric, ax=ax)
|
||||||
|
|
||||||
|
# 设置标题
|
||||||
|
if title is None:
|
||||||
|
title = f"Training Metric ({metric.upper()}) over Iterations"
|
||||||
|
ax.set_title(title, fontsize=12, fontweight="bold")
|
||||||
|
|
||||||
|
return ax
|
||||||
|
|
||||||
|
def plot_all_metrics(
|
||||||
|
self,
|
||||||
|
metrics: Optional[list] = None,
|
||||||
|
figsize: tuple = (14, 10),
|
||||||
|
max_cols: int = 2,
|
||||||
|
):
|
||||||
|
"""绘制所有训练指标曲线
|
||||||
|
|
||||||
|
在一个图表中绘制多个指标的训练曲线。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metrics: 要绘制的指标列表,如果为 None 则绘制所有 NDCG 指标
|
||||||
|
figsize: 图大小,默认 (14, 10)
|
||||||
|
max_cols: 每行最多显示的子图数,默认 2
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
matplotlib Figure 对象
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: 模型尚未训练
|
||||||
|
"""
|
||||||
|
if self.model is None:
|
||||||
|
raise RuntimeError("模型尚未训练,请先调用 fit()")
|
||||||
|
|
||||||
|
if self.evals_result_ is None or not self.evals_result_:
|
||||||
|
raise RuntimeError("没有可用的评估结果")
|
||||||
|
|
||||||
|
import lightgbm as lgb
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
available_metrics = list(self.evals_result_.get("train", {}).keys())
|
||||||
|
|
||||||
|
# 如果没有指定指标,使用所有 NDCG 指标(最多 4 个)
|
||||||
|
if metrics is None:
|
||||||
|
ndcg_metrics = [m for m in available_metrics if "ndcg" in m.lower()]
|
||||||
|
metrics = ndcg_metrics[:4] if ndcg_metrics else available_metrics[:4]
|
||||||
|
|
||||||
|
if not metrics:
|
||||||
|
raise ValueError("没有可用的评估指标")
|
||||||
|
|
||||||
|
# 计算子图布局
|
||||||
|
n_metrics = len(metrics)
|
||||||
|
n_cols = min(max_cols, n_metrics)
|
||||||
|
n_rows = (n_metrics + n_cols - 1) // n_cols
|
||||||
|
|
||||||
|
fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize)
|
||||||
|
if n_metrics == 1:
|
||||||
|
axes = [axes]
|
||||||
|
else:
|
||||||
|
axes = (
|
||||||
|
axes.flatten()
|
||||||
|
if n_rows > 1
|
||||||
|
else [axes]
|
||||||
|
if n_cols == 1
|
||||||
|
else axes.flatten()
|
||||||
|
)
|
||||||
|
|
||||||
|
for idx, metric in enumerate(metrics):
|
||||||
|
if idx < len(axes):
|
||||||
|
ax = axes[idx]
|
||||||
|
if metric in available_metrics:
|
||||||
|
self.plot_metric(metric=metric, ax=ax)
|
||||||
|
ax.set_title(f"{metric.upper()}", fontsize=11, fontweight="bold")
|
||||||
|
else:
|
||||||
|
ax.text(
|
||||||
|
0.5,
|
||||||
|
0.5,
|
||||||
|
f"Metric '{metric}' not found",
|
||||||
|
ha="center",
|
||||||
|
va="center",
|
||||||
|
transform=ax.transAxes,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 隐藏多余的子图
|
||||||
|
for idx in range(n_metrics, len(axes)):
|
||||||
|
axes[idx].axis("off")
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
return fig
|
||||||
|
|
||||||
|
def get_best_iteration(self) -> Optional[int]:
|
||||||
|
"""获取最佳迭代轮数(考虑早停)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
最佳迭代轮数,如果模型未训练返回 None
|
||||||
|
"""
|
||||||
|
if self.model is None:
|
||||||
|
return None
|
||||||
|
return self.model.best_iteration
|
||||||
|
|
||||||
|
def get_best_score(self) -> Optional[dict]:
|
||||||
|
"""获取最佳评分
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
最佳评分字典,格式: {'valid_0': {'metric': value}, 'valid_1': {...}}
|
||||||
|
如果模型未训练返回 None
|
||||||
|
"""
|
||||||
|
if self.model is None:
|
||||||
|
return None
|
||||||
|
return self.model.best_score
|
||||||
|
|
||||||
def feature_importance(self) -> Optional[pd.Series]:
|
def feature_importance(self) -> Optional[pd.Series]:
|
||||||
"""返回特征重要性
|
"""返回特征重要性
|
||||||
|
|
||||||
|
|||||||
171
src/training/utils.py
Normal file
171
src/training/utils.py
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
"""训练模块工具函数
|
||||||
|
|
||||||
|
提供数据质量检查、验证等通用工具函数。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
|
||||||
|
def check_data_quality(
|
||||||
|
df: pl.DataFrame,
|
||||||
|
feature_cols: List[str],
|
||||||
|
date_col: str = "trade_date",
|
||||||
|
check_all_null: bool = True,
|
||||||
|
check_all_zero: bool = True,
|
||||||
|
check_all_nan: bool = True,
|
||||||
|
raise_on_error: bool = True,
|
||||||
|
) -> Dict[str, List[Dict[str, Union[str, int]]]]:
|
||||||
|
"""检查数据质量,检测某天某个因子是否全部为空、0或NaN
|
||||||
|
|
||||||
|
此检查必须在 fillna、标准化等处理之前执行,否则错误会被掩盖。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df: 待检查的数据
|
||||||
|
feature_cols: 特征列名列表
|
||||||
|
date_col: 日期列名,默认 "trade_date"
|
||||||
|
check_all_null: 是否检查全空,默认 True
|
||||||
|
check_all_zero: 是否检查全零,默认 True
|
||||||
|
check_all_nan: 是否检查全NaN,默认 True
|
||||||
|
raise_on_error: 发现问题时是否抛出异常,默认 True
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
检查结果字典,格式:
|
||||||
|
{
|
||||||
|
"all_null": [{"date": "20240101", "factor": "factor_name", "count": 500}],
|
||||||
|
"all_zero": [{"date": "20240101", "factor": "factor_name", "count": 500}],
|
||||||
|
"all_nan": [{"date": "20240101", "factor": "factor_name", "count": 500}],
|
||||||
|
}
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 当发现质量问题且 raise_on_error=True 时
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> import polars as pl
|
||||||
|
>>> df = pl.DataFrame({
|
||||||
|
... "trade_date": ["20240101", "20240101", "20240102"],
|
||||||
|
... "ts_code": ["000001.SZ", "000002.SZ", "000001.SZ"],
|
||||||
|
... "factor1": [1.0, 2.0, None],
|
||||||
|
... "factor2": [0.0, 0.0, 1.0],
|
||||||
|
... })
|
||||||
|
>>> result = check_data_quality(df, ["factor1", "factor2"])
|
||||||
|
"""
|
||||||
|
issues = {
|
||||||
|
"all_null": [],
|
||||||
|
"all_zero": [],
|
||||||
|
"all_nan": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
# 获取实际存在的特征列
|
||||||
|
existing_cols = [col for col in feature_cols if col in df.columns]
|
||||||
|
if not existing_cols:
|
||||||
|
return issues
|
||||||
|
|
||||||
|
# 按日期分组检查
|
||||||
|
for date in df[date_col].unique():
|
||||||
|
day_data = df.filter(pl.col(date_col) == date)
|
||||||
|
day_str = str(date)
|
||||||
|
|
||||||
|
for col in existing_cols:
|
||||||
|
if not day_data[col].dtype.is_numeric():
|
||||||
|
continue
|
||||||
|
|
||||||
|
col_data = day_data[col]
|
||||||
|
non_null_count = col_data.count()
|
||||||
|
|
||||||
|
if non_null_count == 0:
|
||||||
|
# 该日期该因子完全没有有效数据
|
||||||
|
if check_all_null:
|
||||||
|
issues["all_null"].append(
|
||||||
|
{
|
||||||
|
"date": day_str,
|
||||||
|
"factor": col,
|
||||||
|
"count": len(day_data),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 检查是否全为零
|
||||||
|
if check_all_zero:
|
||||||
|
abs_sum = col_data.abs().sum()
|
||||||
|
if abs_sum == 0:
|
||||||
|
issues["all_zero"].append(
|
||||||
|
{
|
||||||
|
"date": day_str,
|
||||||
|
"factor": col,
|
||||||
|
"count": non_null_count,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查是否全为NaN(在Polars中表现为null)
|
||||||
|
if check_all_nan:
|
||||||
|
null_count = col_data.null_count()
|
||||||
|
if null_count == non_null_count:
|
||||||
|
issues["all_nan"].append(
|
||||||
|
{
|
||||||
|
"date": day_str,
|
||||||
|
"factor": col,
|
||||||
|
"count": non_null_count,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 生成报告
|
||||||
|
total_issues = sum(len(v) for v in issues.values())
|
||||||
|
|
||||||
|
if total_issues > 0:
|
||||||
|
report_lines = ["\n" + "=" * 80, "数据质量检查报告", "=" * 80]
|
||||||
|
|
||||||
|
if issues["all_null"]:
|
||||||
|
report_lines.append(f"\n[严重] 发现 {len(issues['all_null'])} 个全空因子:")
|
||||||
|
report_lines.append(
|
||||||
|
" (某天的某个因子所有值都是 null,可能是数据缺失或计算错误)"
|
||||||
|
)
|
||||||
|
for issue in issues["all_null"][:10]: # 最多显示10条
|
||||||
|
msg = f" - 日期 {issue['date']}: {issue['factor']} (样本数: {issue['count']})"
|
||||||
|
report_lines.append(msg)
|
||||||
|
if len(issues["all_null"]) > 10:
|
||||||
|
report_lines.append(f" ... 还有 {len(issues['all_null']) - 10} 个")
|
||||||
|
|
||||||
|
if issues["all_zero"]:
|
||||||
|
report_lines.append(f"\n[警告] 发现 {len(issues['all_zero'])} 个全零因子:")
|
||||||
|
report_lines.append(
|
||||||
|
" (某天的某个因子所有值都是 0,可能是计算错误或数据源问题)"
|
||||||
|
)
|
||||||
|
for issue in issues["all_zero"][:10]:
|
||||||
|
msg = f" - 日期 {issue['date']}: {issue['factor']} (样本数: {issue['count']})"
|
||||||
|
report_lines.append(msg)
|
||||||
|
if len(issues["all_zero"]) > 10:
|
||||||
|
report_lines.append(f" ... 还有 {len(issues['all_zero']) - 10} 个")
|
||||||
|
|
||||||
|
if issues["all_nan"]:
|
||||||
|
report_lines.append(f"\n[警告] 发现 {len(issues['all_nan'])} 个全NaN因子:")
|
||||||
|
report_lines.append(" (某天的某个因子所有值都是 NaN,可能是数值计算错误)")
|
||||||
|
for issue in issues["all_nan"][:10]:
|
||||||
|
msg = f" - 日期 {issue['date']}: {issue['factor']} (样本数: {issue['count']})"
|
||||||
|
report_lines.append(msg)
|
||||||
|
if len(issues["all_nan"]) > 10:
|
||||||
|
report_lines.append(f" ... 还有 {len(issues['all_nan']) - 10} 个")
|
||||||
|
|
||||||
|
report_lines.extend(
|
||||||
|
[
|
||||||
|
"\n" + "-" * 80,
|
||||||
|
"建议处理方式:",
|
||||||
|
" 1. 检查因子定义和数据源,确认计算逻辑是否正确",
|
||||||
|
" 2. 如果是预期内的缺失(如新股无历史数据),考虑调整因子计算窗口",
|
||||||
|
" 3. 如果是数据同步问题,重新同步相关数据",
|
||||||
|
" 4. 可以使用 filter 排除问题日期或因子",
|
||||||
|
"=" * 80,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
report = "\n".join(report_lines)
|
||||||
|
print(report)
|
||||||
|
|
||||||
|
if raise_on_error:
|
||||||
|
raise ValueError(
|
||||||
|
f"数据质量检查失败: 发现 {total_issues} 个问题,"
|
||||||
|
f"详见上方报告。如需忽略,请设置 raise_on_error=False"
|
||||||
|
)
|
||||||
|
|
||||||
|
return issues
|
||||||
Reference in New Issue
Block a user