feat(experiment): 新增因子排除机制并优化模型训练参数
- 添加 EXCLUDED_FACTORS 列表支持批量排除效果不佳的因子
- 修复 LightGBM 树结构冲突,调整正则化和采样策略防过拟合
- 调整数据处理器配置,关闭模型自动保存
This commit is contained in:
@@ -11,7 +11,6 @@ import polars as pl
|
||||
|
||||
from src.factors import FactorEngine
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 日期范围配置(正确的 train/val/test 三分法)
|
||||
# =============================================================================
|
||||
@@ -22,7 +21,6 @@ VAL_END = "20241231"
|
||||
TEST_START = "20250101"
|
||||
TEST_END = "20261231"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 因子配置
|
||||
# =============================================================================
|
||||
@@ -257,6 +255,49 @@ SELECTED_FACTORS = [
|
||||
# 因子定义字典(完整因子库,用于存放尚未注册到metadata的因子)
|
||||
FACTOR_DEFINITIONS = {}
|
||||
|
||||
# 需要排除的因子列表(这些因子不会被计算和使用)
|
||||
# 用于临时屏蔽效果不好的因子,无需从 SELECTED_FACTORS 中删除
|
||||
EXCLUDED_FACTORS: List[str] = [
|
||||
'GTJA_alpha005',
|
||||
'GTJA_alpha028',
|
||||
'GTJA_alpha023',
|
||||
'GTJA_alpha002',
|
||||
'GTJA_alpha010',
|
||||
'GTJA_alpha011',
|
||||
'GTJA_alpha044',
|
||||
'GTJA_alpha036',
|
||||
'GTJA_alpha027',
|
||||
'GTJA_alpha109',
|
||||
'GTJA_alpha104',
|
||||
'GTJA_alpha103',
|
||||
'GTJA_alpha085',
|
||||
'GTJA_alpha111',
|
||||
'GTJA_alpha092',
|
||||
'GTJA_alpha067',
|
||||
'GTJA_alpha060',
|
||||
'GTJA_alpha062',
|
||||
'GTJA_alpha063',
|
||||
'GTJA_alpha079',
|
||||
'GTJA_alpha073',
|
||||
'GTJA_alpha087',
|
||||
'GTJA_alpha117',
|
||||
'GTJA_alpha113',
|
||||
'GTJA_alpha138',
|
||||
'GTJA_alpha121',
|
||||
'GTJA_alpha124',
|
||||
'GTJA_alpha133',
|
||||
'GTJA_alpha131',
|
||||
'GTJA_alpha118',
|
||||
'GTJA_alpha164',
|
||||
'GTJA_alpha162',
|
||||
'GTJA_alpha157',
|
||||
'GTJA_alpha171',
|
||||
'GTJA_alpha177',
|
||||
'GTJA_alpha180',
|
||||
'GTJA_alpha188',
|
||||
'GTJA_alpha191',
|
||||
]
|
||||
|
||||
|
||||
def get_label_factor(label_name: str) -> dict:
|
||||
"""获取Label因子定义字典。
|
||||
@@ -276,52 +317,84 @@ def get_label_factor(label_name: str) -> dict:
|
||||
# 辅助函数
|
||||
# =============================================================================
|
||||
def register_factors(
|
||||
engine: FactorEngine,
|
||||
selected_factors: List[str],
|
||||
factor_definitions: dict,
|
||||
label_factor: dict,
|
||||
engine: FactorEngine,
|
||||
selected_factors: List[str],
|
||||
factor_definitions: dict,
|
||||
label_factor: dict,
|
||||
excluded_factors: Optional[List[str]] = None,
|
||||
) -> List[str]:
|
||||
"""注册因子。
|
||||
|
||||
selected_factors 从 metadata 查询,factor_definitions 用 DSL 表达式注册。
|
||||
excluded_factors 中的因子会被排除,不参与计算。
|
||||
|
||||
Args:
|
||||
engine: FactorEngine实例
|
||||
selected_factors: 从metadata中选择的因子名称列表
|
||||
factor_definitions: 通过表达式定义的因子字典
|
||||
label_factor: label因子定义字典
|
||||
excluded_factors: 需要排除的因子名称列表,默认为None
|
||||
|
||||
Returns:
|
||||
特征列名称列表
|
||||
特征列名称列表(已排除excluded_factors中的因子)
|
||||
"""
|
||||
print("=" * 80)
|
||||
print("注册因子")
|
||||
print("=" * 80)
|
||||
|
||||
# 处理排除列表
|
||||
excluded = set(excluded_factors) if excluded_factors else set()
|
||||
if excluded:
|
||||
print(f"\n[排除因子] 以下 {len(excluded)} 个因子将被排除:")
|
||||
for name in sorted(excluded):
|
||||
print(f" - {name}")
|
||||
|
||||
# 过滤 SELECTED_FACTORS 中的因子(排除excluded_factors)
|
||||
filtered_selected = [name for name in selected_factors if name not in excluded]
|
||||
excluded_from_selected = set(selected_factors) - set(filtered_selected)
|
||||
if excluded_from_selected:
|
||||
print(
|
||||
f"\n[排除详情] 从 SELECTED_FACTORS 排除 {len(excluded_from_selected)} 个因子"
|
||||
)
|
||||
|
||||
# 注册 SELECTED_FACTORS 中的因子(已在 metadata 中)
|
||||
print("\n注册特征因子(从 metadata):")
|
||||
for name in selected_factors:
|
||||
for name in filtered_selected:
|
||||
engine.add_factor(name)
|
||||
print(f" - {name}")
|
||||
|
||||
# 过滤 FACTOR_DEFINITIONS 中的因子(排除excluded_factors)
|
||||
filtered_definitions = {
|
||||
name: expr for name, expr in factor_definitions.items() if name not in excluded
|
||||
}
|
||||
excluded_from_definitions = set(factor_definitions.keys()) - set(
|
||||
filtered_definitions.keys()
|
||||
)
|
||||
if excluded_from_definitions:
|
||||
print(
|
||||
f"\n[排除详情] 从 FACTOR_DEFINITIONS 排除 {len(excluded_from_definitions)} 个因子"
|
||||
)
|
||||
|
||||
# 注册 FACTOR_DEFINITIONS 中的因子(通过表达式,尚未在 metadata 中)
|
||||
print("\n注册特征因子(表达式):")
|
||||
for name, expr in factor_definitions.items():
|
||||
for name, expr in filtered_definitions.items():
|
||||
engine.add_factor(name, expr)
|
||||
print(f" - {name}: {expr}")
|
||||
|
||||
# 注册 label 因子(通过表达式)
|
||||
# 注册 label 因子(通过表达式,label因子不受excluded_factors影响)
|
||||
print("\n注册 Label 因子(表达式):")
|
||||
for name, expr in label_factor.items():
|
||||
engine.add_factor(name, expr)
|
||||
print(f" - {name}: {expr}")
|
||||
|
||||
# 特征列 = SELECTED_FACTORS + FACTOR_DEFINITIONS 的 keys
|
||||
feature_cols = selected_factors + list(factor_definitions.keys())
|
||||
# 特征列 = 过滤后的 SELECTED_FACTORS + 过滤后的 FACTOR_DEFINITIONS 的 keys
|
||||
feature_cols = filtered_selected + list(filtered_definitions.keys())
|
||||
|
||||
print(f"\n特征因子数: {len(feature_cols)}")
|
||||
print(f" - 来自 metadata: {len(selected_factors)}")
|
||||
print(f" - 来自表达式: {len(factor_definitions)}")
|
||||
print(f" - 来自 metadata: {len(filtered_selected)}")
|
||||
print(f" - 来自表达式: {len(filtered_definitions)}")
|
||||
if excluded:
|
||||
print(f" - 已排除: {len(excluded)}")
|
||||
print(f"Label: {list(label_factor.keys())[0]}")
|
||||
print(f"已注册因子总数: {len(engine.list_registered())}")
|
||||
|
||||
@@ -329,11 +402,11 @@ def register_factors(
|
||||
|
||||
|
||||
def prepare_data(
|
||||
engine: FactorEngine,
|
||||
feature_cols: List[str],
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
label_name: str,
|
||||
engine: FactorEngine,
|
||||
feature_cols: List[str],
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
label_name: str,
|
||||
) -> pl.DataFrame:
|
||||
"""准备数据。
|
||||
|
||||
@@ -391,11 +464,11 @@ def stock_pool_filter(df: pl.DataFrame) -> pl.Series:
|
||||
"""
|
||||
# 代码筛选(排除创业板、科创板、北交所)
|
||||
code_filter = (
|
||||
~df["ts_code"].str.starts_with("30") # 排除创业板
|
||||
& ~df["ts_code"].str.starts_with("68") # 排除科创板
|
||||
& ~df["ts_code"].str.starts_with("8") # 排除北交所
|
||||
& ~df["ts_code"].str.starts_with("9") # 排除北交所
|
||||
& ~df["ts_code"].str.starts_with("4") # 排除北交所
|
||||
~df["ts_code"].str.starts_with("30") # 排除创业板
|
||||
& ~df["ts_code"].str.starts_with("68") # 排除科创板
|
||||
& ~df["ts_code"].str.starts_with("8") # 排除北交所
|
||||
& ~df["ts_code"].str.starts_with("9") # 排除北交所
|
||||
& ~df["ts_code"].str.starts_with("4") # 排除北交所
|
||||
)
|
||||
|
||||
# 在已筛选的股票中,选取市值最小的500只
|
||||
@@ -410,7 +483,6 @@ def stock_pool_filter(df: pl.DataFrame) -> pl.Series:
|
||||
# 定义筛选所需的基础列
|
||||
STOCK_FILTER_REQUIRED_COLUMNS = ["total_mv"]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 输出配置
|
||||
# =============================================================================
|
||||
@@ -418,7 +490,7 @@ OUTPUT_DIR = "output"
|
||||
SAVE_PREDICTIONS = True
|
||||
|
||||
# 模型保存配置
|
||||
SAVE_MODEL = True # 是否保存模型
|
||||
SAVE_MODEL = False # 是否保存模型
|
||||
MODEL_SAVE_DIR = "models" # 模型保存目录
|
||||
|
||||
# Top N 配置:每日推荐股票数量
|
||||
@@ -451,7 +523,7 @@ def get_output_path(model_type: str, test_start: str, test_end: str) -> str:
|
||||
|
||||
|
||||
def get_model_save_path(
|
||||
model_type: str, model_name: Optional[str] = None
|
||||
model_type: str, model_name: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
"""生成模型保存路径。
|
||||
|
||||
@@ -477,10 +549,10 @@ def get_model_save_path(
|
||||
|
||||
|
||||
def save_model_with_factors(
|
||||
model,
|
||||
model_path: str,
|
||||
selected_factors: List[str],
|
||||
factor_definitions: dict,
|
||||
model,
|
||||
model_path: str,
|
||||
selected_factors: List[str],
|
||||
factor_definitions: dict,
|
||||
) -> None:
|
||||
"""保存模型及关联的因子信息。
|
||||
|
||||
|
||||
Reference in New Issue
Block a user