refactor(experiment): 提取共用配置到 common 模块
- 将因子定义、日期配置、股票池筛选等提取到 common.py - 重构 learn_to_rank 和 regression 脚本,统一使用公共配置 - 简化代码结构,消除重复定义
This commit is contained in:
278
src/experiment/common.py
Normal file
278
src/experiment/common.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""实验脚本的共用配置和辅助函数。
|
||||
|
||||
此模块包含 regression.py 和 learn_to_rank.py 共用的代码,
|
||||
避免重复维护两份相同的配置和函数。
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
import polars as pl
|
||||
|
||||
from src.factors import FactorEngine
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 日期范围配置(正确的 train/val/test 三分法)
|
||||
# =============================================================================
|
||||
TRAIN_START = "20200101"
|
||||
TRAIN_END = "20231231"
|
||||
VAL_START = "20240101"
|
||||
VAL_END = "20241231"
|
||||
TEST_START = "20250101"
|
||||
TEST_END = "20261231"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 因子配置
|
||||
# =============================================================================
|
||||
# 当前选择的因子列表(从 FACTOR_DEFINITIONS 中选择要使用的因子)
|
||||
SELECTED_FACTORS = [
|
||||
# ================= 1. 价格、趋势与路径依赖 =================
|
||||
"ma_5",
|
||||
"ma_20",
|
||||
"ma_ratio_5_20",
|
||||
"bias_10",
|
||||
"high_low_ratio",
|
||||
"bbi_ratio",
|
||||
"return_5",
|
||||
"return_20",
|
||||
"kaufman_ER_20",
|
||||
"mom_acceleration_10_20",
|
||||
"drawdown_from_high_60",
|
||||
"up_days_ratio_20",
|
||||
# ================= 2. 波动率、风险调整与高阶矩 =================
|
||||
"volatility_5",
|
||||
"volatility_20",
|
||||
"volatility_ratio",
|
||||
"std_return_20",
|
||||
"sharpe_ratio_20",
|
||||
"min_ret_20",
|
||||
"volatility_squeeze_5_60",
|
||||
# ================= 3. 日内微观结构与异象 =================
|
||||
"overnight_intraday_diff",
|
||||
"upper_shadow_ratio",
|
||||
"capital_retention_20",
|
||||
"max_ret_20",
|
||||
# ================= 4. 量能、流动性与量价背离 =================
|
||||
"volume_ratio_5_20",
|
||||
"turnover_rate_mean_5",
|
||||
"turnover_deviation",
|
||||
"amihud_illiq_20",
|
||||
"turnover_cv_20",
|
||||
"pv_corr_20",
|
||||
"close_vwap_deviation",
|
||||
# ================= 5. 基本面财务特征 =================
|
||||
"roe",
|
||||
"roa",
|
||||
"profit_margin",
|
||||
"debt_to_equity",
|
||||
"current_ratio",
|
||||
"net_profit_yoy",
|
||||
"revenue_yoy",
|
||||
"healthy_expansion_velocity",
|
||||
# ================= 6. 基本面估值与截面动量共振 =================
|
||||
"EP",
|
||||
"BP",
|
||||
"CP",
|
||||
"market_cap_rank",
|
||||
"turnover_rank",
|
||||
"return_5_rank",
|
||||
"EP_rank",
|
||||
"pe_expansion_trend",
|
||||
"value_price_divergence",
|
||||
"active_market_cap",
|
||||
"ebit_rank",
|
||||
]
|
||||
|
||||
# 因子定义字典(完整因子库,用于存放尚未注册到metadata的因子)
|
||||
FACTOR_DEFINITIONS = {}
|
||||
|
||||
|
||||
def get_label_factor(label_name: str) -> dict:
|
||||
"""获取Label因子定义字典。
|
||||
|
||||
Args:
|
||||
label_name: label因子名称
|
||||
|
||||
Returns:
|
||||
Label因子定义字典
|
||||
"""
|
||||
return {
|
||||
label_name: "(ts_delay(close, -5) / ts_delay(open, -1)) - 1",
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 辅助函数
|
||||
# =============================================================================
|
||||
def register_factors(
|
||||
engine: FactorEngine,
|
||||
selected_factors: List[str],
|
||||
factor_definitions: dict,
|
||||
label_factor: dict,
|
||||
) -> List[str]:
|
||||
"""注册因子。
|
||||
|
||||
selected_factors 从 metadata 查询,factor_definitions 用 DSL 表达式注册。
|
||||
|
||||
Args:
|
||||
engine: FactorEngine实例
|
||||
selected_factors: 从metadata中选择的因子名称列表
|
||||
factor_definitions: 通过表达式定义的因子字典
|
||||
label_factor: label因子定义字典
|
||||
|
||||
Returns:
|
||||
特征列名称列表
|
||||
"""
|
||||
print("=" * 80)
|
||||
print("注册因子")
|
||||
print("=" * 80)
|
||||
|
||||
# 注册 SELECTED_FACTORS 中的因子(已在 metadata 中)
|
||||
print("\n注册特征因子(从 metadata):")
|
||||
for name in selected_factors:
|
||||
engine.add_factor(name)
|
||||
print(f" - {name}")
|
||||
|
||||
# 注册 FACTOR_DEFINITIONS 中的因子(通过表达式,尚未在 metadata 中)
|
||||
print("\n注册特征因子(表达式):")
|
||||
for name, expr in factor_definitions.items():
|
||||
engine.add_factor(name, expr)
|
||||
print(f" - {name}: {expr}")
|
||||
|
||||
# 注册 label 因子(通过表达式)
|
||||
print("\n注册 Label 因子(表达式):")
|
||||
for name, expr in label_factor.items():
|
||||
engine.add_factor(name, expr)
|
||||
print(f" - {name}: {expr}")
|
||||
|
||||
# 特征列 = SELECTED_FACTORS + FACTOR_DEFINITIONS 的 keys
|
||||
feature_cols = selected_factors + list(factor_definitions.keys())
|
||||
|
||||
print(f"\n特征因子数: {len(feature_cols)}")
|
||||
print(f" - 来自 metadata: {len(selected_factors)}")
|
||||
print(f" - 来自表达式: {len(factor_definitions)}")
|
||||
print(f"Label: {list(label_factor.keys())[0]}")
|
||||
print(f"已注册因子总数: {len(engine.list_registered())}")
|
||||
|
||||
return feature_cols
|
||||
|
||||
|
||||
def prepare_data(
|
||||
engine: FactorEngine,
|
||||
feature_cols: List[str],
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
label_name: str,
|
||||
) -> pl.DataFrame:
|
||||
"""准备数据。
|
||||
|
||||
计算因子并返回包含特征和label的数据框。
|
||||
|
||||
Args:
|
||||
engine: FactorEngine实例
|
||||
feature_cols: 特征列名称列表
|
||||
start_date: 开始日期 (YYYYMMDD)
|
||||
end_date: 结束日期 (YYYYMMDD)
|
||||
label_name: label列名称
|
||||
|
||||
Returns:
|
||||
包含因子计算结果的数据框
|
||||
"""
|
||||
print("\n" + "=" * 80)
|
||||
print("准备数据")
|
||||
print("=" * 80)
|
||||
|
||||
# 计算因子(全市场数据)
|
||||
print(f"\n计算因子: {start_date} - {end_date}")
|
||||
factor_names = feature_cols + [label_name] # 包含 label
|
||||
|
||||
data = engine.compute(
|
||||
factor_names=factor_names,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
|
||||
print(f"数据形状: {data.shape}")
|
||||
print(f"数据列: {data.columns}")
|
||||
print(f"\n前5行预览:")
|
||||
print(data.head())
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 股票池筛选配置
|
||||
# =============================================================================
|
||||
def stock_pool_filter(df: pl.DataFrame) -> pl.Series:
|
||||
"""股票池筛选函数(单日数据)。
|
||||
|
||||
筛选条件:
|
||||
1. 排除创业板(代码以 300 开头)
|
||||
2. 排除科创板(代码以 688 开头)
|
||||
3. 排除北交所(代码以 8、9 或 4 开头)
|
||||
4. 选取当日市值最小的500只股票
|
||||
|
||||
Args:
|
||||
df: 单日数据框
|
||||
|
||||
Returns:
|
||||
布尔Series,表示哪些股票被选中
|
||||
"""
|
||||
# 代码筛选(排除创业板、科创板、北交所)
|
||||
code_filter = (
|
||||
~df["ts_code"].str.starts_with("30") # 排除创业板
|
||||
& ~df["ts_code"].str.starts_with("68") # 排除科创板
|
||||
& ~df["ts_code"].str.starts_with("8") # 排除北交所
|
||||
& ~df["ts_code"].str.starts_with("9") # 排除北交所
|
||||
& ~df["ts_code"].str.starts_with("4") # 排除北交所
|
||||
)
|
||||
|
||||
# 在已筛选的股票中,选取市值最小的500只
|
||||
valid_df = df.filter(code_filter)
|
||||
n = min(500, len(valid_df))
|
||||
small_cap_codes = valid_df.sort("total_mv").head(n)["ts_code"]
|
||||
|
||||
# 返回布尔 Series:是否在被选中的股票中
|
||||
return df["ts_code"].is_in(small_cap_codes)
|
||||
|
||||
|
||||
# 定义筛选所需的基础列
|
||||
STOCK_FILTER_REQUIRED_COLUMNS = ["total_mv"]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 输出配置
|
||||
# =============================================================================
|
||||
OUTPUT_DIR = "output"
|
||||
SAVE_PREDICTIONS = True
|
||||
PERSIST_MODEL = False
|
||||
|
||||
# Top N 配置:每日推荐股票数量
|
||||
TOP_N = 5 # 可调整为 10, 20 等
|
||||
|
||||
|
||||
def get_output_path(model_type: str, test_start: str, test_end: str) -> str:
|
||||
"""生成输出文件路径。
|
||||
|
||||
Args:
|
||||
model_type: 模型类型("regression" 或 "rank")
|
||||
test_start: 测试开始日期
|
||||
test_end: 测试结束日期
|
||||
|
||||
Returns:
|
||||
输出文件路径
|
||||
"""
|
||||
import os
|
||||
|
||||
# 确保输出目录存在
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
# 生成文件名
|
||||
start_dt = datetime.strptime(test_start, "%Y%m%d")
|
||||
end_dt = datetime.strptime(test_end, "%Y%m%d")
|
||||
date_str = f"{start_dt.strftime('%Y%m%d')}_{end_dt.strftime('%Y%m%d')}"
|
||||
|
||||
filename = f"{model_type}_output.csv"
|
||||
return os.path.join(OUTPUT_DIR, filename)
|
||||
Reference in New Issue
Block a user