refactor: 优化回归实验配置和模型参数

- 将因子定义、模型参数、日期配置提取为模块级常量
- 优化 LightGBM 参数(降低过拟合风险)
- LightGBMModel 支持 params 字典参数传入
- 修复 StockFilter 创业板排除逻辑(支持 301xxx)
- 添加 experiment/output 到 .gitignore
This commit is contained in:
2026-03-05 00:38:20 +08:00
parent 3b42093100
commit 5a1f278df8
5 changed files with 183 additions and 1350 deletions

1
.gitignore vendored
View File

@@ -82,3 +82,4 @@ src/training/output/*
# AI Agent 工作目录 # AI Agent 工作目录
/.sisyphus/ /.sisyphus/
/src/experiment/output/

File diff suppressed because it is too large Load Diff

View File

@@ -5,6 +5,7 @@ Label: return_5 = (close / ts_delay(close, 5)) - 1
""" """
import os import os
from datetime import datetime
from typing import List, Tuple from typing import List, Tuple
import polars as pl import polars as pl
@@ -22,6 +23,95 @@ from src.training import (
) )
from src.training.config import TrainingConfig from src.training.config import TrainingConfig
# =============================================================================
# Factor definitions (kept in one place so they are easy to change)
# =============================================================================

# Feature factors, mapping factor name -> expression string. Adding a feature
# is a one-line change here; the dict keys become the feature column names.
#
# Keys MUST be unique: a dict literal with a repeated key silently keeps only
# the last value, so a duplicated name drops a factor without raising an error.
FACTOR_DEFINITIONS = {
    # 1. Price momentum
    "ma5": "ts_mean(close, 5)",
    "ma10": "ts_mean(close, 10)",
    "ma20": "ts_mean(close, 20)",
    "ma_ratio": "ts_mean(close, 5) / ts_mean(close, 20) - 1",
    # 2. Volatility
    "volatility_5": "ts_std(close, 5)",
    "volatility_20": "ts_std(close, 20)",
    # Named "volatility_ratio" (not "vol_ratio") so it does not collide with
    # the volume ratio in section 5, which previously overwrote this entry.
    "volatility_ratio": "ts_std(close, 5) / (ts_std(close, 20) + 1e-8)",
    # 3. Return momentum (return_5 is the label and is registered separately)
    "return_10": "(close / ts_delay(close, 10)) - 1",
    "return_20": "(close / ts_delay(close, 20)) - 1",
    # 4. Change in return: written as a full expression rather than
    #    referencing other factors by name.
    "return_diff": "(close / ts_delay(close, 5)) - 1 - ((close / ts_delay(close, 10)) - 1)",
    # 5. Volume
    "vol_ma5": "ts_mean(vol, 5)",
    "vol_ma20": "ts_mean(vol, 20)",
    "vol_ratio": "ts_mean(vol, 5) / (ts_mean(vol, 20) + 1e-8)",
    # 6. Market capitalisation (cross-sectional rank)
    "market_cap_rank": "cs_rank(total_mv)",
    # 7. Price position inside the 20-period high/low range; the 1e-8 term
    #    guards against division by zero when high == low.
    "high_low_ratio": "(close - ts_min(low, 20)) / (ts_max(high, 20) - ts_min(low, 20) + 1e-8)",
    # Raw field passed through unchanged: the expression is the bare field name.
    "n_income": "n_income",
}

# Label factor: registered with the engine to compute the target, but not
# included in the feature columns used for training.
LABEL_FACTOR = {
    "return_5": "(close / ts_delay(close, 5)) - 1",
}
# =============================================================================
# Training configuration (kept in one place so it is easy to change)
# =============================================================================

# Date ranges, as "YYYYMMDD" strings.
TRAIN_START = "20200101"
TRAIN_END = "20241231"
TEST_START = "20250101"
TEST_END = "20251231"

# LightGBM parameters. Compared with the previous settings (rmse, 31 leaves,
# learning rate 0.05, 100 trees) these are chosen to reduce overfitting.
MODEL_PARAMS = {
    "objective": "regression",
    "metric": "mae",  # MAE is less sensitive to outliers than RMSE
    # Tree structure control (the main guard against overfitting)
    "num_leaves": 20,  # lowered from 31 to reduce model complexity
    # Explicit depth cap. NOTE(review): a depth-4 tree has at most
    # 2**4 = 16 leaves, so max_depth is the binding limit, not num_leaves.
    "max_depth": 4,
    "min_child_samples": 50,  # minimum samples per leaf
    "min_child_weight": 0.001,
    # Learning parameters
    "learning_rate": 0.01,  # lower learning rate, paired with more trees
    # NOTE(review): the original comment says this relies on early stopping;
    # confirm that the training call actually enables it.
    "n_estimators": 1000,
    # Sampling strategy
    "subsample": 0.8,  # row sampling: 80% of the data
    "subsample_freq": 5,  # re-sample rows every 5 iterations
    "colsample_bytree": 0.8,  # column sampling: 80% of features per tree
    # Regularisation
    "reg_alpha": 0.1,  # L1
    "reg_lambda": 1.0,  # L2
    # Logging / reproducibility
    "verbose": -1,
    "random_state": 42,
}

# Data processor configuration: an ordered list of {"name", "params"} entries.
PROCESSOR_CONFIGS = [
    {"name": "winsorizer", "params": {"lower": 0.01, "upper": 0.99}},
    {"name": "cs_standard_scaler", "params": {}},
]

# Stock universe filter flags; unpacked as keyword arguments into
# StockFilterConfig.
STOCK_FILTER_CONFIG = {
    "exclude_cyb": True,  # exclude ChiNext board
    "exclude_kcb": True,  # exclude STAR Market
    "exclude_bj": True,  # exclude Beijing Stock Exchange
    "exclude_st": True,  # exclude ST stocks
}

# Output configuration. The directory is anchored to this file's location so
# results always land next to the script. A bare "output" would instead be
# resolved against whatever the current working directory happens to be.
OUTPUT_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "output")
SAVE_PREDICTIONS = True
PERSIST_MODEL = False
def create_factors_with_strings(engine: FactorEngine) -> List[str]: def create_factors_with_strings(engine: FactorEngine) -> List[str]:
"""使用字符串表达式定义因子 """使用字符串表达式定义因子
@@ -36,57 +126,24 @@ def create_factors_with_strings(engine: FactorEngine) -> List[str]:
print("使用字符串表达式定义因子") print("使用字符串表达式定义因子")
print("=" * 80) print("=" * 80)
# 定义所有因子(使用字典,方便维护和扩展) # 使用模块级别的因子定义
# 新增因子只需在此处添加一行即可
factor_definitions = {
# 1. 价格动量因子
"ma5": "ts_mean(close, 5)",
"ma10": "ts_mean(close, 10)",
"ma20": "ts_mean(close, 20)",
"ma_ratio": "ts_mean(close, 5) / ts_mean(close, 20) - 1",
# 2. 波动率因子
"volatility_5": "ts_std(close, 5)",
"volatility_20": "ts_std(close, 20)",
"vol_ratio": "ts_std(close, 5) / (ts_std(close, 20) + 1e-8)",
# 3. 收益率动量因子return_5 是 label需要单独注册
"return_10": "(close / ts_delay(close, 10)) - 1",
"return_20": "(close / ts_delay(close, 20)) - 1",
# 4. 收益率变化因子(使用完整表达式,不引用其他因子)
"return_diff": "(close / ts_delay(close, 5)) - 1 - ((close / ts_delay(close, 10)) - 1)",
# 5. 成交量因子
"vol_ma5": "ts_mean(vol, 5)",
"vol_ma20": "ts_mean(vol, 20)",
"vol_ratio": "ts_mean(vol, 5) / (ts_mean(vol, 20) + 1e-8)",
# 6. 市值因子(截面排名)
"market_cap_rank": "cs_rank(total_mv)",
# 7. 价格位置因子
"high_low_ratio": "(close - ts_min(low, 20)) / (ts_max(high, 20) - ts_min(low, 20) + 1e-8)",
"n_income": "n_income"
}
# Label 因子(单独定义,不参与训练)
label_factor = {
"return_5": "(close / ts_delay(close, 5)) - 1",
}
# 注册所有特征因子 # 注册所有特征因子
print("\n注册特征因子:") print("\n注册特征因子:")
for name, expr in factor_definitions.items(): for name, expr in FACTOR_DEFINITIONS.items():
engine.add_factor(name, expr) engine.add_factor(name, expr)
print(f" - {name}: {expr}") print(f" - {name}: {expr}")
# 注册 label 因子 # 注册 label 因子
print("\n注册 Label 因子:") print("\n注册 Label 因子:")
for name, expr in label_factor.items(): for name, expr in LABEL_FACTOR.items():
engine.add_factor(name, expr) engine.add_factor(name, expr)
print(f" - {name}: {expr}") print(f" - {name}: {expr}")
# 从字典自动获取特征列keys() 方法) # 从字典自动获取特征列
feature_cols = list(factor_definitions.keys()) feature_cols = list(FACTOR_DEFINITIONS.keys())
print(f"\n特征因子数: {len(feature_cols)}") print(f"\n特征因子数: {len(feature_cols)}")
print(f"Label: {list(label_factor.keys())[0]}") print(f"Label: {list(LABEL_FACTOR.keys())[0]}")
print(f"已注册因子总数: {len(engine.list_registered())}") print(f"已注册因子总数: {len(engine.list_registered())}")
return feature_cols return feature_cols
@@ -146,82 +203,42 @@ def train_regression_model():
feature_cols = create_factors_with_strings(engine) feature_cols = create_factors_with_strings(engine)
target_col = "return_5" target_col = "return_5"
# 3. 准备数据 # 3. 准备数据(使用模块级别的日期配置)
print("\n[3] 准备数据") print("\n[3] 准备数据")
train_start, train_end = "20200101", "20241231"
test_start, test_end = "20250101", "20251231"
data = prepare_data( data = prepare_data(
engine=engine, engine=engine,
feature_cols=feature_cols, feature_cols=feature_cols,
start_date=train_start, start_date=TRAIN_START,
end_date=test_end, end_date=TEST_END,
) )
# 4. 创建配置 # 4. 打印配置信息(使用模块级别的配置常量)
config = TrainingConfig( print(f"\n[配置] 训练期: {TRAIN_START} - {TRAIN_END}")
feature_cols=feature_cols, print(f"[配置] 测试期: {TEST_START} - {TEST_END}")
target_col=target_col,
date_col="trade_date",
code_col="ts_code",
train_start=train_start,
train_end=train_end,
test_start=test_start,
test_end=test_end,
model_type="lightgbm",
model_params={
"objective": "regression",
"metric": "rmse",
"num_leaves": 31,
"learning_rate": 0.05,
"n_estimators": 100,
},
processors=[
{"name": "winsorizer", "params": {"lower": 0.01, "upper": 0.99}},
{"name": "cs_standard_scaler", "params": {}},
],
persist_model=False,
model_save_path=None,
output_dir="output/regression",
save_predictions=True,
)
print(f"\n[配置] 训练期: {train_start} - {train_end}")
print(f"[配置] 测试期: {test_start} - {test_end}")
print(f"[配置] 特征数: {len(feature_cols)}") print(f"[配置] 特征数: {len(feature_cols)}")
print(f"[配置] 目标变量: {target_col}") print(f"[配置] 目标变量: {target_col}")
# 5. 创建模型 # 5. 创建模型(使用模块级别的模型参数)
model = LightGBMModel( model = LightGBMModel(params=MODEL_PARAMS)
objective="regression",
metric="rmse",
num_leaves=31,
learning_rate=0.05,
n_estimators=100,
)
# 6. 创建数据处理器 # 6. 创建数据处理器(从 PROCESSOR_CONFIGS 解析)
processors = [ processors = [
Winsorizer(lower=0.01, upper=0.99), Winsorizer(**PROCESSOR_CONFIGS[0]["params"]), # type: ignore[arg-type]
StandardScaler(exclude_cols=["ts_code", "trade_date", target_col]), StandardScaler(exclude_cols=["ts_code", "trade_date", target_col]), # type: ignore[call-arg]
] ]
# 7. 创建数据划分器 # 7. 创建数据划分器(使用模块级别的日期配置)
splitter = DateSplitter( splitter = DateSplitter(
train_start=train_start, train_start=TRAIN_START,
train_end=train_end, train_end=TRAIN_END,
test_start=test_start, test_start=TEST_START,
test_end=test_end, test_end=TEST_END,
) )
# 8. 创建股票池管理器(可选 # 8. 创建股票池管理器(使用模块级别的筛选配置
pool_manager = StockPoolManager( pool_manager = StockPoolManager(
filter_config=StockFilterConfig( filter_config=StockFilterConfig(**STOCK_FILTER_CONFIG),
exclude_cyb=True,
exclude_kcb=True,
exclude_bj=True,
exclude_st=True,
),
selector_config=None, # 暂时不启用市值选择 selector_config=None, # 暂时不启用市值选择
data_router=engine.router, # 从 FactorEngine 获取数据路由器 data_router=engine.router, # 从 FactorEngine 获取数据路由器
) )
@@ -240,7 +257,7 @@ def train_regression_model():
splitter=splitter, splitter=splitter,
target_col=target_col, target_col=target_col,
feature_cols=feature_cols, feature_cols=feature_cols,
persist_model=False, persist_model=PERSIST_MODEL,
) )
# 10. 手动执行训练流程(增加详细打印) # 10. 手动执行训练流程(增加详细打印)
@@ -401,22 +418,24 @@ def train_regression_model():
print(f"\n示例日期 {sample_date} 的前10条预测:") print(f"\n示例日期 {sample_date} 的前10条预测:")
print(sample_data.select(["ts_code", "trade_date", target_col, "prediction"])) print(sample_data.select(["ts_code", "trade_date", target_col, "prediction"]))
# 12. 保存结果(每日 top5 # 12. 保存结果
output_dir = "D:\\PyProject\\ProStock\\src\\training\\output" print("\n" + "=" * 80)
os.makedirs(output_dir, exist_ok=True) print("保存预测结果")
print("=" * 80)
# 生成文件名top_5_{开始日期}_{结束日期}.csv # 确保输出目录存在
from datetime import datetime os.makedirs(OUTPUT_DIR, exist_ok=True)
start_dt = datetime.strptime(test_start, "%Y%m%d") # 生成时间戳
end_dt = datetime.strptime(test_end, "%Y%m%d") start_dt = datetime.strptime(TEST_START, "%Y%m%d")
filename = ( end_dt = datetime.strptime(TEST_END, "%Y%m%d")
f"top_5_{start_dt.strftime('%Y-%m-%d')}_{end_dt.strftime('%Y-%m-%d')}.csv" date_str = f"{start_dt.strftime('%Y%m%d')}_{end_dt.strftime('%Y%m%d')}"
)
output_path = os.path.join(output_dir, filename) # 12.1 保存每日 Top5
print("\n[1/1] 保存每日 Top5 股票...")
top5_output_path = os.path.join(OUTPUT_DIR, f"top5_{date_str}.csv")
# 按日期分组,取每日 top5 # 按日期分组,取每日 top5
print("\n选取每日 Top 5 股票...")
top5_by_date = [] top5_by_date = []
unique_dates = results["trade_date"].unique().sort() unique_dates = results["trade_date"].unique().sort()
for date in unique_dates: for date in unique_dates:
@@ -425,29 +444,26 @@ def train_regression_model():
top5 = day_data.sort("prediction", descending=True).head(5) top5 = day_data.sort("prediction", descending=True).head(5)
top5_by_date.append(top5) top5_by_date.append(top5)
print(f" 处理完成: 共 {len(unique_dates)} 个交易日,每交易日取 top5")
# 合并所有日期的 top5 # 合并所有日期的 top5
top5_results = pl.concat(top5_by_date) top5_results = pl.concat(top5_by_date)
# 格式化日期并调整列顺序:日期、分数、股票 # 格式化日期并调整列顺序:日期、分数、股票
results_to_save = top5_results.select( top5_to_save = top5_results.select(
[ [
pl.col("trade_date").str.slice(0, 4) pl.col("trade_date").str.slice(0, 4)
+ "-" + "-"
+ pl.col("trade_date").str.slice(4, 2) + pl.col("trade_date").str.slice(4, 2)
+ "-" + "-"
+ pl.col("trade_date").str.slice(6, 2), + pl.col("trade_date").str.slice(6, 2).alias("date"),
pl.col("prediction").alias("score"), pl.col("prediction").alias("score"),
pl.col("ts_code"), pl.col("ts_code"),
] ]
).rename({"trade_date": "date"}) )
results_to_save.write_csv(output_path, include_header=True) top5_to_save.write_csv(top5_output_path, include_header=True)
print(f"\n预测结果已保存: {output_path}") print(f" 保存路径: {top5_output_path}")
print(f"保存列: {results_to_save.columns}") print(f" 保存行数: {len(top5_to_save)}{len(unique_dates)}个交易日 × 每日top5")
print(f"总行数: {len(results_to_save)}(每日 top5") print(f"\n 预览前15行:")
print(f"\n保存数据预览:") print(top5_to_save.head(15))
print(results_to_save.head(15))
# 13. 特征重要性 # 13. 特征重要性
importance = model.feature_importance() importance = model.feature_importance()

View File

@@ -3,7 +3,7 @@
提供 LightGBM 回归模型的实现,支持特征重要性和原生模型保存。 提供 LightGBM 回归模型的实现,支持特征重要性和原生模型保存。
""" """
from typing import Optional from typing import Any, Optional
import numpy as np import numpy as np
import pandas as pd import pandas as pd
@@ -31,6 +31,7 @@ class LightGBMModel(BaseModel):
def __init__( def __init__(
self, self,
params: Optional[dict] = None,
objective: str = "regression", objective: str = "regression",
metric: str = "rmse", metric: str = "rmse",
num_leaves: int = 31, num_leaves: int = 31,
@@ -40,23 +41,54 @@ class LightGBMModel(BaseModel):
): ):
"""初始化 LightGBM 模型 """初始化 LightGBM 模型
支持两种方式传入参数:
1. 通过 params 字典传入所有参数(推荐方式)
2. 通过独立参数传入(向后兼容)
Args: Args:
params: LightGBM 参数字典,如果提供则直接使用此字典
objective: 目标函数,默认 "regression" objective: 目标函数,默认 "regression"
metric: 评估指标,默认 "rmse" metric: 评估指标,默认 "rmse"
num_leaves: 叶子节点数,默认 31 num_leaves: 叶子节点数,默认 31
learning_rate: 学习率,默认 0.05 learning_rate: 学习率,默认 0.05
n_estimators: 迭代次数,默认 100 n_estimators: 迭代次数,默认 100
**kwargs: 其他 LightGBM 参数 **kwargs: 其他 LightGBM 参数
Examples:
>>> # 方式1通过 params 字典(推荐)
>>> model = LightGBMModel(params={
... "objective": "regression",
... "metric": "rmse",
... "num_leaves": 31,
... "learning_rate": 0.05,
... "n_estimators": 100,
... })
>>>
>>> # 方式2通过独立参数向后兼容
>>> model = LightGBMModel(
... objective="regression",
... num_leaves=31,
... learning_rate=0.05,
... )
""" """
self.params = { if params is not None:
"objective": objective, # 方式1直接使用 params 字典
"metric": metric, self.params = dict(params) # 复制一份,避免修改原始字典
"num_leaves": num_leaves, self.params.setdefault("verbose", -1) # 默认抑制训练输出
"learning_rate": learning_rate, # n_estimators 可能存在于 params 中
"verbose": -1, # 抑制训练输出 self.n_estimators = self.params.pop("n_estimators", n_estimators)
**kwargs, else:
} # 方式2通过独立参数构建 params
self.n_estimators = n_estimators self.params = {
"objective": objective,
"metric": metric,
"num_leaves": num_leaves,
"learning_rate": learning_rate,
"verbose": -1, # 抑制训练输出
**kwargs,
}
self.n_estimators = n_estimators
self.model = None self.model = None
self.feature_names_: Optional[list] = None self.feature_names_: Optional[list] = None

View File

@@ -15,7 +15,7 @@ class StockFilterConfig:
基于股票代码进行过滤,不依赖外部数据。 基于股票代码进行过滤,不依赖外部数据。
Attributes: Attributes:
exclude_cyb: 是否排除创业板300xxx exclude_cyb: 是否排除创业板300xxx, 301xxx
exclude_kcb: 是否排除科创板688xxx exclude_kcb: 是否排除科创板688xxx
exclude_bj: 是否排除北交所(.BJ 后缀) exclude_bj: 是否排除北交所(.BJ 后缀)
exclude_st: 是否排除ST股票需要外部数据支持 exclude_st: 是否排除ST股票需要外部数据支持
@@ -41,8 +41,8 @@ class StockFilterConfig:
""" """
result = [] result = []
for code in codes: for code in codes:
# 排除创业板300xxx # 排除创业板300xxx, 301xxx
if self.exclude_cyb and code.startswith("300"): if self.exclude_cyb and code.startswith(("300", "301")):
continue continue
# 排除科创板688xxx # 排除科创板688xxx
if self.exclude_kcb and code.startswith("688"): if self.exclude_kcb and code.startswith("688"):