refactor(experiment): 提取共用配置到 common 模块
- 将因子定义、日期配置、股票池筛选等提取到 common.py - 重构 learn_to_rank 和 regression 脚本,统一使用公共配置 - 简化代码结构,消除重复定义
This commit is contained in:
@@ -3,7 +3,6 @@
|
||||
# %%
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
import polars as pl
|
||||
|
||||
@@ -13,7 +12,6 @@ from src.training import (
|
||||
LightGBMModel,
|
||||
STFilter,
|
||||
StandardScaler,
|
||||
# StockFilterConfig, # 已删除,使用 StockPoolManager + filter_func 替代
|
||||
StockPoolManager,
|
||||
Trainer,
|
||||
Winsorizer,
|
||||
@@ -22,167 +20,38 @@ from src.training import (
|
||||
)
|
||||
from src.training.config import TrainingConfig
|
||||
|
||||
|
||||
# %% md
|
||||
# ## 2. 定义辅助函数
|
||||
# %%
|
||||
def register_factors(
|
||||
engine: FactorEngine,
|
||||
selected_factors: List[str],
|
||||
factor_definitions: dict,
|
||||
label_factor: dict,
|
||||
) -> List[str]:
|
||||
"""注册因子(selected_factors 从 metadata 查询,factor_definitions 用 DSL 表达式注册)"""
|
||||
print("=" * 80)
|
||||
print("注册因子")
|
||||
print("=" * 80)
|
||||
|
||||
# 注册 SELECTED_FACTORS 中的因子(已在 metadata 中)
|
||||
print("\n注册特征因子(从 metadata):")
|
||||
for name in selected_factors:
|
||||
engine.add_factor(name)
|
||||
print(f" - {name}")
|
||||
|
||||
# 注册 FACTOR_DEFINITIONS 中的因子(通过表达式,尚未在 metadata 中)
|
||||
print("\n注册特征因子(表达式):")
|
||||
for name, expr in factor_definitions.items():
|
||||
engine.add_factor(name, expr)
|
||||
print(f" - {name}: {expr}")
|
||||
|
||||
# 注册 label 因子(通过表达式)
|
||||
print("\n注册 Label 因子(表达式):")
|
||||
for name, expr in label_factor.items():
|
||||
engine.add_factor(name, expr)
|
||||
print(f" - {name}: {expr}")
|
||||
|
||||
# 特征列 = SELECTED_FACTORS + FACTOR_DEFINITIONS 的 keys
|
||||
feature_cols = selected_factors + list(factor_definitions.keys())
|
||||
|
||||
print(f"\n特征因子数: {len(feature_cols)}")
|
||||
print(f" - 来自 metadata: {len(selected_factors)}")
|
||||
print(f" - 来自表达式: {len(factor_definitions)}")
|
||||
print(f"Label: {list(label_factor.keys())[0]}")
|
||||
print(f"已注册因子总数: {len(engine.list_registered())}")
|
||||
|
||||
return feature_cols
|
||||
|
||||
|
||||
def prepare_data(
|
||||
engine: FactorEngine,
|
||||
feature_cols: List[str],
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
) -> pl.DataFrame:
|
||||
print("\n" + "=" * 80)
|
||||
print("准备数据")
|
||||
print("=" * 80)
|
||||
|
||||
# 计算因子(全市场数据)
|
||||
print(f"\n计算因子: {start_date} - {end_date}")
|
||||
factor_names = feature_cols + [LABEL_NAME] # 包含 label
|
||||
|
||||
data = engine.compute(
|
||||
factor_names=factor_names,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
|
||||
print(f"数据形状: {data.shape}")
|
||||
print(f"数据列: {data.columns}")
|
||||
print(f"\n前5行预览:")
|
||||
print(data.head())
|
||||
|
||||
return data
|
||||
# 从 common 模块导入共用配置和函数
|
||||
from src.experiment.common import (
|
||||
SELECTED_FACTORS,
|
||||
FACTOR_DEFINITIONS,
|
||||
get_label_factor,
|
||||
register_factors,
|
||||
prepare_data,
|
||||
TRAIN_START,
|
||||
TRAIN_END,
|
||||
VAL_START,
|
||||
VAL_END,
|
||||
TEST_START,
|
||||
TEST_END,
|
||||
stock_pool_filter,
|
||||
STOCK_FILTER_REQUIRED_COLUMNS,
|
||||
OUTPUT_DIR,
|
||||
SAVE_PREDICTIONS,
|
||||
PERSIST_MODEL,
|
||||
TOP_N,
|
||||
)
|
||||
|
||||
|
||||
# %% md
|
||||
# ## 3. 配置参数
|
||||
# ## 2. 配置参数
|
||||
#
|
||||
# ### 3.1 因子定义
|
||||
# ### 2.1 标签定义
|
||||
# %%
|
||||
# 特征因子定义字典:新增因子只需在此处添加一行
|
||||
# Label 名称(回归任务使用连续收益率)
|
||||
LABEL_NAME = "future_return_5"
|
||||
|
||||
# 当前选择的因子列表(从 FACTOR_DEFINITIONS 中选择要使用的因子)
|
||||
SELECTED_FACTORS = [
|
||||
# ================= 1. 价格、趋势与路径依赖 =================
|
||||
"ma_5",
|
||||
"ma_20",
|
||||
"ma_ratio_5_20",
|
||||
"bias_10",
|
||||
"high_low_ratio",
|
||||
"bbi_ratio",
|
||||
"return_5",
|
||||
"return_20",
|
||||
"kaufman_ER_20",
|
||||
"mom_acceleration_10_20",
|
||||
"drawdown_from_high_60",
|
||||
"up_days_ratio_20",
|
||||
# ================= 2. 波动率、风险调整与高阶矩 =================
|
||||
"volatility_5",
|
||||
"volatility_20",
|
||||
"volatility_ratio",
|
||||
"std_return_20",
|
||||
"sharpe_ratio_20",
|
||||
"min_ret_20",
|
||||
"volatility_squeeze_5_60",
|
||||
# ================= 3. 日内微观结构与异象 =================
|
||||
"overnight_intraday_diff",
|
||||
"upper_shadow_ratio",
|
||||
"capital_retention_20",
|
||||
"max_ret_20",
|
||||
# ================= 4. 量能、流动性与量价背离 =================
|
||||
"volume_ratio_5_20",
|
||||
"turnover_rate_mean_5",
|
||||
"turnover_deviation",
|
||||
"amihud_illiq_20",
|
||||
"turnover_cv_20",
|
||||
"pv_corr_20",
|
||||
"close_vwap_deviation",
|
||||
# ================= 5. 基本面财务特征 =================
|
||||
"roe",
|
||||
"roa",
|
||||
"profit_margin",
|
||||
"debt_to_equity",
|
||||
"current_ratio",
|
||||
"net_profit_yoy",
|
||||
"revenue_yoy",
|
||||
"healthy_expansion_velocity",
|
||||
# ================= 6. 基本面估值与截面动量共振 =================
|
||||
"EP",
|
||||
"BP",
|
||||
"CP",
|
||||
"market_cap_rank",
|
||||
"turnover_rank",
|
||||
"return_5_rank",
|
||||
"EP_rank",
|
||||
"pe_expansion_trend",
|
||||
"value_price_divergence",
|
||||
"active_market_cap",
|
||||
"ebit_rank",
|
||||
]
|
||||
|
||||
# 因子定义字典(完整因子库)
|
||||
FACTOR_DEFINITIONS = {
|
||||
}
|
||||
|
||||
# Label 因子定义(不参与训练,用于计算目标)
|
||||
LABEL_FACTOR = {
|
||||
LABEL_NAME: "(ts_delay(close, -5) / ts_delay(open, -1)) - 1", # 未来5日收益率
|
||||
}
|
||||
# %% md
|
||||
# ### 3.2 训练参数配置
|
||||
# %%
|
||||
# 日期范围配置(正确的 train/val/test 三分法)
|
||||
# Train: 用于训练模型参数
|
||||
# Val: 用于验证/早停/调参(位于 train 之后,test 之前)
|
||||
# Test: 仅用于最终评估,完全独立于训练过程
|
||||
TRAIN_START = "20200101"
|
||||
TRAIN_END = "20231231"
|
||||
VAL_START = "20240101"
|
||||
VAL_END = "20241231"
|
||||
TEST_START = "20250101"
|
||||
TEST_END = "20261231"
|
||||
# 获取 Label 因子定义
|
||||
LABEL_FACTOR = get_label_factor(LABEL_NAME)
|
||||
|
||||
# 模型参数配置
|
||||
MODEL_PARAMS = {
|
||||
@@ -207,59 +76,6 @@ MODEL_PARAMS = {
|
||||
"verbose": -1,
|
||||
"random_state": 42,
|
||||
}
|
||||
|
||||
|
||||
# 股票池筛选函数
|
||||
# 使用新的 StockPoolManager API:传入自定义筛选函数和所需列/因子
|
||||
# 筛选函数接收单日 DataFrame,返回布尔 Series
|
||||
#
|
||||
# 筛选逻辑(针对单日数据):
|
||||
# 1. 先排除创业板、科创板、北交所(ST过滤由STFilter组件处理)
|
||||
# 2. 然后选取市值最小的500只股票
|
||||
def stock_pool_filter(df: pl.DataFrame) -> pl.Series:
|
||||
"""股票池筛选函数(单日数据)
|
||||
|
||||
筛选条件:
|
||||
1. 排除创业板(代码以 300 开头)
|
||||
2. 排除科创板(代码以 688 开头)
|
||||
3. 排除北交所(代码以 8、9 或 4 开头)
|
||||
4. 选取当日市值最小的500只股票
|
||||
"""
|
||||
# 代码筛选(排除创业板、科创板、北交所)
|
||||
code_filter = (
|
||||
~df["ts_code"].str.starts_with("30") # 排除创业板
|
||||
& ~df["ts_code"].str.starts_with("68") # 排除科创板
|
||||
& ~df["ts_code"].str.starts_with("8") # 排除北交所
|
||||
& ~df["ts_code"].str.starts_with("9") # 排除北交所
|
||||
& ~df["ts_code"].str.starts_with("4") # 排除北交所
|
||||
)
|
||||
|
||||
# 在已筛选的股票中,选取市值最小的500只
|
||||
# 按市值升序排序,取前500
|
||||
valid_df = df.filter(code_filter)
|
||||
n = min(1000, len(valid_df))
|
||||
small_cap_codes = valid_df.sort("total_mv").head(n)["ts_code"]
|
||||
|
||||
# 返回布尔 Series:是否在被选中的股票中
|
||||
return df["ts_code"].is_in(small_cap_codes)
|
||||
|
||||
|
||||
# 定义筛选所需的基础列
|
||||
STOCK_FILTER_REQUIRED_COLUMNS = ["total_mv"] # ST过滤由STFilter组件处理
|
||||
|
||||
# 可选:定义筛选所需的因子(如果需要用因子进行筛选)
|
||||
# STOCK_FILTER_REQUIRED_FACTORS = {
|
||||
# "market_cap_rank": "cs_rank(total_mv)",
|
||||
# }
|
||||
|
||||
|
||||
# 输出配置(相对于本文件所在目录)
|
||||
OUTPUT_DIR = "output"
|
||||
SAVE_PREDICTIONS = True
|
||||
PERSIST_MODEL = False
|
||||
|
||||
# Top N 配置:每日推荐股票数量
|
||||
TOP_N = 5 # 可调整为 10, 20 等
|
||||
# %% md
|
||||
# ## 4. 训练流程
|
||||
#
|
||||
@@ -288,6 +104,7 @@ data = prepare_data(
|
||||
feature_cols=feature_cols,
|
||||
start_date=TRAIN_START,
|
||||
end_date=TEST_END,
|
||||
label_name=LABEL_NAME,
|
||||
)
|
||||
|
||||
# 4. 打印配置信息
|
||||
|
||||
Reference in New Issue
Block a user