feat(training): 添加数据质量检查工具并重构实验脚本
- 新增 check_data_quality 函数用于检测全空/全零/全NaN数据质量问题 - 重构 register_factors 函数,消除 FEATURE_COLS 和 PROCESSORS 冗余定义 - 修复实验脚本中特征列表不一致的问题,确保处理器覆盖所有特征 - 优化 LambdaRank 模型参数配置
This commit is contained in:
@@ -18,6 +18,7 @@ from src.training import (
|
||||
Trainer,
|
||||
Winsorizer,
|
||||
NullFiller,
|
||||
check_data_quality,
|
||||
)
|
||||
from src.training.config import TrainingConfig
|
||||
|
||||
@@ -25,13 +26,13 @@ from src.training.config import TrainingConfig
|
||||
# %% md
|
||||
# ## 2. 定义辅助函数
|
||||
# %%
|
||||
def create_factors_with_metadata(
|
||||
def register_factors(
|
||||
engine: FactorEngine,
|
||||
selected_factors: List[str],
|
||||
factor_definitions: dict,
|
||||
label_factor: dict,
|
||||
) -> List[str]:
|
||||
"""注册因子(SELECTED_FACTORS 从 metadata 查询,FACTOR_DEFINITIONS 用表达式注册)"""
|
||||
"""注册因子(selected_factors 从 metadata 查询,factor_definitions 用 DSL 表达式注册)"""
|
||||
print("=" * 80)
|
||||
print("注册因子")
|
||||
print("=" * 80)
|
||||
@@ -285,9 +286,6 @@ MODEL_PARAMS = {
|
||||
"random_state": 42,
|
||||
}
|
||||
|
||||
# 数据处理器配置(新 API:需要传入 feature_cols)
|
||||
# 注意:processor 现在需要显式指定要处理的特征列
|
||||
|
||||
|
||||
# 股票池筛选函数
|
||||
# 使用新的 StockPoolManager API:传入自定义筛选函数和所需列/因子
|
||||
@@ -355,7 +353,7 @@ engine = FactorEngine(metadata_path="data/factors.jsonl")
|
||||
|
||||
# 2. 使用 metadata 定义因子
|
||||
print("\n[2] 定义因子(从 metadata 注册)")
|
||||
feature_cols = create_factors_with_metadata(
|
||||
feature_cols = register_factors(
|
||||
engine, SELECTED_FACTORS, FACTOR_DEFINITIONS, LABEL_FACTOR
|
||||
)
|
||||
target_col = LABEL_NAME
|
||||
@@ -380,7 +378,7 @@ print(f"[配置] 目标变量: {target_col}")
|
||||
# 5. 创建模型
|
||||
model = LightGBMModel(params=MODEL_PARAMS)
|
||||
|
||||
# 6. 创建数据处理器(新 API:需要传入 feature_cols)
|
||||
# 6. 创建数据处理器(使用函数返回的完整特征列表)
|
||||
processors = [
|
||||
NullFiller(feature_cols=feature_cols, strategy="mean"),
|
||||
Winsorizer(feature_cols=feature_cols, lower=0.01, upper=0.99),
|
||||
@@ -482,8 +480,26 @@ else:
|
||||
test_data = filtered_data
|
||||
print(" 未配置划分器,全部作为训练集")
|
||||
# %%
|
||||
# 步骤 3: 训练集数据处理
|
||||
print("\n[步骤 3/6] 训练集数据处理")
|
||||
# 步骤 3: 数据质量检查(必须在预处理之前)
|
||||
print("\n[步骤 3/7] 数据质量检查")
|
||||
print("-" * 60)
|
||||
print(" [说明] 此检查在 fillna 等处理之前执行,用于发现数据问题")
|
||||
|
||||
print("\n 检查训练集...")
|
||||
check_data_quality(train_data, feature_cols, raise_on_error=True)
|
||||
|
||||
if "val_data" in locals() and val_data is not None:
|
||||
print("\n 检查验证集...")
|
||||
check_data_quality(val_data, feature_cols, raise_on_error=True)
|
||||
|
||||
print("\n 检查测试集...")
|
||||
check_data_quality(test_data, feature_cols, raise_on_error=True)
|
||||
|
||||
print(" [成功] 数据质量检查通过,未发现异常")
|
||||
|
||||
# %%
|
||||
# 步骤 4: 训练集数据处理
|
||||
print("\n[步骤 4/7] 训练集数据处理")
|
||||
print("-" * 60)
|
||||
fitted_processors = []
|
||||
if processors:
|
||||
@@ -510,7 +526,7 @@ for col in feature_cols[:5]: # 只显示前5个特征的缺失值
|
||||
print(f" {col}: {null_count} ({null_count / len(train_data) * 100:.2f}%)")
|
||||
# %%
|
||||
# 步骤 4: 训练模型
|
||||
print("\n[步骤 4/6] 训练模型")
|
||||
print("\n[步骤 5/7] 训练模型")
|
||||
print("-" * 60)
|
||||
print(f" 模型类型: LightGBM")
|
||||
print(f" 训练样本数: {len(train_data)}")
|
||||
@@ -532,7 +548,7 @@ model.fit(X_train, y_train)
|
||||
print(" 训练完成!")
|
||||
# %%
|
||||
# 步骤 5: 测试集数据处理
|
||||
print("\n[步骤 5/6] 测试集数据处理")
|
||||
print("\n[步骤 6/7] 测试集数据处理")
|
||||
print("-" * 60)
|
||||
if processors and test_data is not train_data:
|
||||
for i, processor in enumerate(fitted_processors, 1):
|
||||
@@ -548,7 +564,7 @@ else:
|
||||
print(" 跳过测试集处理")
|
||||
# %%
|
||||
# 步骤 6: 生成预测
|
||||
print("\n[步骤 6/6] 生成预测")
|
||||
print("\n[步骤 7/7] 生成预测")
|
||||
print("-" * 60)
|
||||
X_test = test_data.select(feature_cols)
|
||||
print(f" 测试样本数: {len(X_test)}")
|
||||
|
||||
Reference in New Issue
Block a user