refactor: 优化回归实验配置和模型参数
- 将因子定义、模型参数、日期配置提取为模块级常量 - 优化 LightGBM 参数(降低过拟合风险) - LightGBMModel 支持 params 字典参数传入 - 修复 StockFilter 创业板排除逻辑(支持 301xxx) - 添加 experiment/output 到 .gitignore
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
提供 LightGBM 回归模型的实现,支持特征重要性和原生模型保存。
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -31,6 +31,7 @@ class LightGBMModel(BaseModel):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params: Optional[dict] = None,
|
||||
objective: str = "regression",
|
||||
metric: str = "rmse",
|
||||
num_leaves: int = 31,
|
||||
@@ -40,23 +41,54 @@ class LightGBMModel(BaseModel):
|
||||
):
|
||||
"""初始化 LightGBM 模型
|
||||
|
||||
支持两种方式传入参数:
|
||||
1. 通过 params 字典传入所有参数(推荐方式)
|
||||
2. 通过独立参数传入(向后兼容)
|
||||
|
||||
Args:
|
||||
params: LightGBM 参数字典,如果提供则直接使用此字典
|
||||
objective: 目标函数,默认 "regression"
|
||||
metric: 评估指标,默认 "rmse"
|
||||
num_leaves: 叶子节点数,默认 31
|
||||
learning_rate: 学习率,默认 0.05
|
||||
n_estimators: 迭代次数,默认 100
|
||||
**kwargs: 其他 LightGBM 参数
|
||||
|
||||
Examples:
|
||||
>>> # 方式1:通过 params 字典(推荐)
|
||||
>>> model = LightGBMModel(params={
|
||||
... "objective": "regression",
|
||||
... "metric": "rmse",
|
||||
... "num_leaves": 31,
|
||||
... "learning_rate": 0.05,
|
||||
... "n_estimators": 100,
|
||||
... })
|
||||
>>>
|
||||
>>> # 方式2:通过独立参数(向后兼容)
|
||||
>>> model = LightGBMModel(
|
||||
... objective="regression",
|
||||
... num_leaves=31,
|
||||
... learning_rate=0.05,
|
||||
... )
|
||||
"""
|
||||
self.params = {
|
||||
"objective": objective,
|
||||
"metric": metric,
|
||||
"num_leaves": num_leaves,
|
||||
"learning_rate": learning_rate,
|
||||
"verbose": -1, # 抑制训练输出
|
||||
**kwargs,
|
||||
}
|
||||
self.n_estimators = n_estimators
|
||||
if params is not None:
|
||||
# 方式1:直接使用 params 字典
|
||||
self.params = dict(params) # 复制一份,避免修改原始字典
|
||||
self.params.setdefault("verbose", -1) # 默认抑制训练输出
|
||||
# n_estimators 可能存在于 params 中
|
||||
self.n_estimators = self.params.pop("n_estimators", n_estimators)
|
||||
else:
|
||||
# 方式2:通过独立参数构建 params
|
||||
self.params = {
|
||||
"objective": objective,
|
||||
"metric": metric,
|
||||
"num_leaves": num_leaves,
|
||||
"learning_rate": learning_rate,
|
||||
"verbose": -1, # 抑制训练输出
|
||||
**kwargs,
|
||||
}
|
||||
self.n_estimators = n_estimators
|
||||
|
||||
self.model = None
|
||||
self.feature_names_: Optional[list] = None
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ class StockFilterConfig:
|
||||
基于股票代码进行过滤,不依赖外部数据。
|
||||
|
||||
Attributes:
|
||||
exclude_cyb: 是否排除创业板(300xxx)
|
||||
exclude_cyb: 是否排除创业板(300xxx, 301xxx)
|
||||
exclude_kcb: 是否排除科创板(688xxx)
|
||||
exclude_bj: 是否排除北交所(.BJ 后缀)
|
||||
exclude_st: 是否排除ST股票(需要外部数据支持)
|
||||
@@ -41,8 +41,8 @@ class StockFilterConfig:
|
||||
"""
|
||||
result = []
|
||||
for code in codes:
|
||||
# 排除创业板(300xxx)
|
||||
if self.exclude_cyb and code.startswith("300"):
|
||||
# 排除创业板(300xxx, 301xxx)
|
||||
if self.exclude_cyb and code.startswith(("300", "301")):
|
||||
continue
|
||||
# 排除科创板(688xxx)
|
||||
if self.exclude_kcb and code.startswith("688"):
|
||||
|
||||
Reference in New Issue
Block a user