Files
ProStock/src/experiment/common.py

671 lines
18 KiB
Python
Raw Normal View History

"""实验脚本的共用配置和辅助函数。
此模块包含 regression.py learn_to_rank.py 共用的代码
避免重复维护两份相同的配置和函数
"""
from datetime import datetime
from typing import Dict, List, Optional
import polars as pl
from src.factors import FactorEngine
# =============================================================================
# Date range configuration (a proper three-way train / val / test split)
# =============================================================================
# All boundaries are date strings in YYYYMMDD form.
TRAIN_START = "20200101"
TRAIN_END = "20231231"
VAL_START = "20240101"
VAL_END = "20241231"
TEST_START = "20250101"
TEST_END = "20261231"
# =============================================================================
# Factor configuration
# =============================================================================
# Currently selected factor list: the names of the factors to use.
# NOTE(review): the original comment says these are chosen from
# FACTOR_DEFINITIONS, but register_factors() registers "selected" factors by
# name only (looked up in the engine's metadata) — confirm which is intended.
# Entries that are commented out are simply not part of the list.
SELECTED_FACTORS = [
    "ma_5",
    "ma_20",
    "ma_ratio_5_20",
    "bias_10",
    "high_low_ratio",
    "bbi_ratio",
    "return_5",
    "return_20",
    "kaufman_ER_20",
    "mom_acceleration_10_20",
    "drawdown_from_high_60",
    "up_days_ratio_20",
    "volatility_5",
    "volatility_20",
    "volatility_ratio",
    "std_return_20",
    "sharpe_ratio_20",
    "min_ret_20",
    "volatility_squeeze_5_60",
    "overnight_intraday_diff",
    "upper_shadow_ratio",
    "capital_retention_20",
    "max_ret_20",
    "volume_ratio_5_20",
    "turnover_rate_mean_5",
    "turnover_deviation",
    "amihud_illiq_20",
    "turnover_cv_20",
    "pv_corr_20",
    "close_vwap_deviation",
    "roe",
    "roa",
    "profit_margin",
    "debt_to_equity",
    "current_ratio",
    "net_profit_yoy",
    "revenue_yoy",
    "healthy_expansion_velocity",
    "EP",
    "BP",
    "CP",
    "market_cap_rank",
    "turnover_rank",
    "return_5_rank",
    "EP_rank",
    "pe_expansion_trend",
    "value_price_divergence",
    "active_market_cap",
    "ebit_rank",
    "GTJA_alpha001",
    "GTJA_alpha002",
    "GTJA_alpha003",
    "GTJA_alpha004",
    "GTJA_alpha005",
    "GTJA_alpha006",
    "GTJA_alpha007",
    "GTJA_alpha008",
    "GTJA_alpha009",
    "GTJA_alpha010",
    "GTJA_alpha011",
    "GTJA_alpha012",
    "GTJA_alpha013",
    "GTJA_alpha014",
    "GTJA_alpha015",
    "GTJA_alpha016",
    "GTJA_alpha017",
    "GTJA_alpha018",
    "GTJA_alpha019",
    "GTJA_alpha020",
    "GTJA_alpha022",
    "GTJA_alpha023",
    "GTJA_alpha024",
    "GTJA_alpha025",
    "GTJA_alpha026",
    "GTJA_alpha027",
    "GTJA_alpha028",
    "GTJA_alpha029",
    "GTJA_alpha031",
    "GTJA_alpha032",
    "GTJA_alpha033",
    "GTJA_alpha034",
    "GTJA_alpha035",
    "GTJA_alpha036",
    "GTJA_alpha037",
    # "GTJA_alpha038",
    "GTJA_alpha039",
    "GTJA_alpha040",
    "GTJA_alpha041",
    "GTJA_alpha042",
    "GTJA_alpha043",
    "GTJA_alpha044",
    "GTJA_alpha045",
    "GTJA_alpha046",
    "GTJA_alpha047",
    "GTJA_alpha048",
    "GTJA_alpha049",
    "GTJA_alpha050",
    "GTJA_alpha051",
    "GTJA_alpha052",
    "GTJA_alpha053",
    "GTJA_alpha054",
    "GTJA_alpha056",
    "GTJA_alpha057",
    "GTJA_alpha058",
    "GTJA_alpha059",
    "GTJA_alpha060",
    "GTJA_alpha061",
    "GTJA_alpha062",
    "GTJA_alpha063",
    "GTJA_alpha064",
    "GTJA_alpha065",
    "GTJA_alpha066",
    "GTJA_alpha067",
    "GTJA_alpha068",
    "GTJA_alpha070",
    "GTJA_alpha071",
    "GTJA_alpha072",
    "GTJA_alpha073",
    "GTJA_alpha074",
    "GTJA_alpha076",
    "GTJA_alpha077",
    "GTJA_alpha078",
    "GTJA_alpha079",
    "GTJA_alpha080",
    "GTJA_alpha081",
    "GTJA_alpha082",
    "GTJA_alpha083",
    "GTJA_alpha084",
    "GTJA_alpha085",
    "GTJA_alpha086",
    "GTJA_alpha087",
    "GTJA_alpha088",
    "GTJA_alpha089",
    "GTJA_alpha090",
    "GTJA_alpha091",
    "GTJA_alpha092",
    "GTJA_alpha093",
    "GTJA_alpha094",
    "GTJA_alpha095",
    "GTJA_alpha096",
    "GTJA_alpha097",
    "GTJA_alpha098",
    "GTJA_alpha099",
    "GTJA_alpha100",
    "GTJA_alpha101",
    "GTJA_alpha102",
    "GTJA_alpha103",
    "GTJA_alpha104",
    "GTJA_alpha105",
    "GTJA_alpha106",
    "GTJA_alpha107",
    "GTJA_alpha108",
    "GTJA_alpha109",
    "GTJA_alpha110",
    "GTJA_alpha111",
    "GTJA_alpha112",
    # "GTJA_alpha113",
    "GTJA_alpha114",
    "GTJA_alpha115",
    "GTJA_alpha117",
    "GTJA_alpha118",
    "GTJA_alpha119",
    "GTJA_alpha120",
    # "GTJA_alpha121",
    "GTJA_alpha122",
    "GTJA_alpha123",
    "GTJA_alpha124",
    "GTJA_alpha125",
    "GTJA_alpha126",
    "GTJA_alpha127",
    "GTJA_alpha128",
    "GTJA_alpha129",
    "GTJA_alpha130",
    "GTJA_alpha131",
    "GTJA_alpha132",
    "GTJA_alpha133",
    "GTJA_alpha134",
    "GTJA_alpha135",
    "GTJA_alpha136",
    # "GTJA_alpha138",
    "GTJA_alpha139",
    # "GTJA_alpha140",
    "GTJA_alpha141",
    "GTJA_alpha142",
    "GTJA_alpha145",
    # "GTJA_alpha146",
    "GTJA_alpha148",
    "GTJA_alpha150",
    "GTJA_alpha151",
    "GTJA_alpha152",
    "GTJA_alpha153",
    "GTJA_alpha154",
    "GTJA_alpha155",
    "GTJA_alpha156",
    "GTJA_alpha157",
    "GTJA_alpha158",
    "GTJA_alpha159",
    "GTJA_alpha160",
    "GTJA_alpha161",
    "GTJA_alpha162",
    "GTJA_alpha163",
    "GTJA_alpha164",
    # "GTJA_alpha165",
    "GTJA_alpha166",
    "GTJA_alpha167",
    "GTJA_alpha168",
    "GTJA_alpha169",
    "GTJA_alpha170",
    "GTJA_alpha171",
    "GTJA_alpha173",
    "GTJA_alpha174",
    "GTJA_alpha175",
    "GTJA_alpha176",
    "GTJA_alpha177",
    "GTJA_alpha178",
    "GTJA_alpha179",
    "GTJA_alpha180",
    # "GTJA_alpha183",
    "GTJA_alpha184",
    "GTJA_alpha185",
    "GTJA_alpha187",
    "GTJA_alpha188",
    "GTJA_alpha189",
    "GTJA_alpha191",
]
# Factor definition dict (the full factor library): holds factors that are
# not yet registered in metadata. Maps factor name -> DSL expression string.
FACTOR_DEFINITIONS = {"cs_rank_circ_mv": "cs_rank(circ_mv)"}
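# Factors to exclude (these factors are neither computed nor used).
# Use this to temporarily disable poorly performing factors without having to
# delete them from SELECTED_FACTORS.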
# EXCLUDED_FACTORS: List[str] = [
# # "GTJA_alpha005",
# # "GTJA_alpha028",
# # "GTJA_alpha023",
# # "GTJA_alpha002",
# # "GTJA_alpha010",
# # "GTJA_alpha011",
# # "GTJA_alpha044",
# # "GTJA_alpha036",
# # "GTJA_alpha027",
# # "GTJA_alpha109",
# # "GTJA_alpha104",
# # "GTJA_alpha103",
# # "GTJA_alpha085",
# # "GTJA_alpha111",
# # "GTJA_alpha092",
# # "GTJA_alpha067",
# # "GTJA_alpha060",
# # "GTJA_alpha062",
# # "GTJA_alpha063",
# # "GTJA_alpha079",
# # "GTJA_alpha073",
# # "GTJA_alpha087",
# # "GTJA_alpha117",
# # "GTJA_alpha113",
# # "GTJA_alpha138",
# # "GTJA_alpha121",
# # "GTJA_alpha124",
# # "GTJA_alpha133",
# # "GTJA_alpha131",
# # "GTJA_alpha118",
# # "GTJA_alpha164",
# # "GTJA_alpha162",
# # "GTJA_alpha157",
# # "GTJA_alpha171",
# # "GTJA_alpha177",
# # "GTJA_alpha180",
# # "GTJA_alpha188",
# # "GTJA_alpha191",
# ]
def get_label_factor(label_name: str) -> dict:
    """Build the definition dict for the label factor.

    Args:
        label_name: Name under which the label factor will be registered.

    Returns:
        A single-entry dict mapping ``label_name`` to the label's DSL
        expression string.
    """
    # The expression is passed through to the factor engine verbatim.
    # NOTE(review): presumably the forward return from the next bar's open to
    # the close five bars ahead — confirm against the DSL's ts_delay semantics.
    label_expression = "(ts_delay(close, -5) / ts_delay(open, -1)) - 1"
    definition = {}
    definition[label_name] = label_expression
    return definition
# =============================================================================
# Helper functions
# =============================================================================
def register_factors(
    engine: FactorEngine,
    selected_factors: List[str],
    factor_definitions: dict,
    label_factor: dict,
    excluded_factors: Optional[List[str]] = None,
) -> List[str]:
    """Register feature and label factors on the engine.

    Names in ``selected_factors`` are registered by name only (the engine is
    expected to resolve them from its metadata); entries of
    ``factor_definitions`` and ``label_factor`` are registered together with
    their DSL expression. Any feature whose name appears in
    ``excluded_factors`` is skipped. Label factors are never excluded.
    Progress is reported with ``print``.

    Args:
        engine: FactorEngine instance to register the factors on.
        selected_factors: Names of factors registered by name only.
        factor_definitions: Mapping of factor name -> DSL expression.
        label_factor: Mapping of label name -> DSL expression.
        excluded_factors: Names of factors to leave out; ``None`` (the
            default) or an empty list excludes nothing.

    Returns:
        The feature column names: the kept selected factors followed by the
        kept expression-defined factors, in their original order.
    """
    banner = "=" * 80
    print(banner)
    print("注册因子")
    print(banner)

    # Normalise the exclusion list into a set for O(1) membership tests.
    excluded = set(excluded_factors or ())
    if excluded:
        print(f"\n[排除因子] 以下 {len(excluded)} 个因子将被排除:")
        print(*(f" - {factor}" for factor in sorted(excluded)), sep="\n")

    # Split the selected factors into the ones we keep and the ones dropped.
    kept_selected = []
    dropped_selected = set()
    for factor in selected_factors:
        if factor in excluded:
            dropped_selected.add(factor)
        else:
            kept_selected.append(factor)
    if dropped_selected:
        print(f"\n[排除详情] 从 SELECTED_FACTORS 排除 {len(dropped_selected)} 个因子")

    # Factors registered by name only.
    print("\n注册特征因子(从 metadata:")
    for factor in kept_selected:
        engine.add_factor(factor)
        print(f" - {factor}")

    # Same split for the expression-defined factors.
    kept_definitions = {}
    dropped_definition_count = 0
    for factor, expression in factor_definitions.items():
        if factor in excluded:
            dropped_definition_count += 1
        else:
            kept_definitions[factor] = expression
    if dropped_definition_count:
        print(f"\n[排除详情] 从 FACTOR_DEFINITIONS 排除 {dropped_definition_count} 个因子")

    # Factors registered together with their expression.
    print("\n注册特征因子(表达式):")
    for factor, expression in kept_definitions.items():
        engine.add_factor(factor, expression)
        print(f" - {factor}: {expression}")

    # Label factors are registered unconditionally (exclusions do not apply).
    print("\n注册 Label 因子(表达式):")
    for label, expression in label_factor.items():
        engine.add_factor(label, expression)
        print(f" - {label}: {expression}")

    # Feature columns = kept selected names + kept expression-defined names.
    feature_cols = [*kept_selected, *kept_definitions]

    print(f"\n特征因子数: {len(feature_cols)}")
    print(f" - 来自 metadata: {len(kept_selected)}")
    print(f" - 来自表达式: {len(kept_definitions)}")
    if excluded:
        # Reports the size of the requested exclusion set, not the number of
        # factors that were actually dropped.
        print(f" - 已排除: {len(excluded)}")
    print(f"Label: {list(label_factor)[0]}")
    print(f"已注册因子总数: {len(engine.list_registered())}")
    return feature_cols
def prepare_data(
    engine: FactorEngine,
    feature_cols: List[str],
    start_date: str,
    end_date: str,
    label_name: str,
) -> pl.DataFrame:
    """Compute the feature and label factors for a date range.

    Asks the engine to compute every feature column plus the label column,
    prints the shape, the column names and a head preview of the result, and
    returns the frame unchanged.

    Args:
        engine: FactorEngine instance the factors were registered on.
        feature_cols: Names of the feature columns to compute.
        start_date: Start date (YYYYMMDD).
        end_date: End date (YYYYMMDD).
        label_name: Name of the label column; it is computed together with
            the features.

    Returns:
        The data frame returned by ``engine.compute``.
    """
    separator = "=" * 80
    print("\n" + separator)
    print("准备数据")
    print(separator)

    print(f"\n计算因子: {start_date} - {end_date}")
    # The label is requested in the same call, after all feature columns.
    request = {
        "factor_names": feature_cols + [label_name],
        "start_date": start_date,
        "end_date": end_date,
    }
    result = engine.compute(**request)

    print(f"数据形状: {result.shape}")
    print(f"数据列: {result.columns}")
    print("\n前5行预览:")
    print(result.head())
    return result
# =============================================================================
# Stock pool filter configuration
# =============================================================================
def stock_pool_filter(df: pl.DataFrame, top_n: int = 1000) -> pl.Series:
    """Select the stock pool for a single trading day.

    A row is selected when both conditions hold:

    1. Its ``ts_code`` does not start with any excluded prefix. According to
       the original comments these are "30" (ChiNext), "68" (STAR Market)
       and "8" / "9" / "4" (Beijing Stock Exchange).
    2. Its ``ts_code`` is among the ``top_n`` eligible stocks with the
       smallest ``circ_mv``.

    The earlier documentation said "smallest 500" while the code has always
    taken 1000; the default of ``top_n`` keeps the behaviour of the code.

    Args:
        df: Single-day data frame; must contain ``ts_code`` and ``circ_mv``.
        top_n: Maximum number of smallest-``circ_mv`` stocks to keep.

    Returns:
        Boolean Series aligned with ``df``: True for rows whose ``ts_code``
        was selected.

    Raises:
        ValueError: If ``top_n`` is negative.
    """
    if top_n < 0:
        # head() with a negative count means "all but the last n" in polars,
        # which is never what a caller of this filter wants.
        raise ValueError(f"top_n must be non-negative, got {top_n}")

    codes = df["ts_code"]

    # Board exclusion by code prefix.
    excluded_prefixes = ("30", "68", "8", "9", "4")
    eligible = ~codes.str.starts_with(excluded_prefixes[0])
    for prefix in excluded_prefixes[1:]:
        eligible = eligible & ~codes.str.starts_with(prefix)

    # Keep the top_n smallest by circ_mv among the eligible rows. polars sorts
    # nulls first by default, which would rank rows with a missing circ_mv as
    # the "smallest" caps; nulls_last=True makes them the last to be picked.
    # head() already copes with frames shorter than top_n.
    smallest_codes = (
        df.filter(eligible)
        .sort("circ_mv", nulls_last=True)
        .head(top_n)["ts_code"]
    )

    # Membership mask over the full input frame.
    return codes.is_in(smallest_codes)
# Base columns the stock-pool filter requires.
# NOTE(review): stock_pool_filter also reads "ts_code", which is not listed
# here — presumably it is always present; confirm with the consumer of this list.
STOCK_FILTER_REQUIRED_COLUMNS = ["circ_mv"]
# =============================================================================
# Output configuration
# =============================================================================
OUTPUT_DIR = "output"
SAVE_PREDICTIONS = True
# Model persistence configuration
SAVE_MODEL = False # whether to save the model
MODEL_SAVE_DIR = "models" # directory models are saved under
# Top-N configuration: number of stocks recommended per day
TOP_N = 5 # can be adjusted to 10, 20, etc.
def get_output_path(model_type: str, test_start: str, test_end: str) -> str:
    """Build the path of the prediction output file.

    The file is named ``{model_type}_output.csv`` and lives in ``OUTPUT_DIR``,
    which is created if it does not exist. The test dates are only validated;
    they are not part of the file name.

    Args:
        model_type: Model type, e.g. "regression" or "rank".
        test_start: Test start date (YYYYMMDD).
        test_end: Test end date (YYYYMMDD).

    Returns:
        Path of the output file.

    Raises:
        ValueError: If ``test_start`` or ``test_end`` is not a valid
            YYYYMMDD date.
    """
    import os

    # Parse purely for validation, and do it before touching the file system
    # so that bad input does not leave an empty output directory behind.
    for date_text in (test_start, test_end):
        datetime.strptime(date_text, "%Y%m%d")

    # Make sure the output directory exists.
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    return os.path.join(OUTPUT_DIR, f"{model_type}_output.csv")
def get_model_save_path(
    model_type: str,
) -> Optional[str]:
    """Build the path the model file should be saved to.

    When saving is enabled the directory ``MODEL_SAVE_DIR/{model_type}/`` is
    created if needed. That directory is also where save_model_with_factors
    writes factors.json next to the model file.

    Args:
        model_type: Model type, e.g. "regression" or "rank".

    Returns:
        ``MODEL_SAVE_DIR/{model_type}/model.pkl``, or ``None`` when
        ``SAVE_MODEL`` is false.
    """
    if SAVE_MODEL:
        import os

        # One directory per model type.
        target_dir = os.path.join(MODEL_SAVE_DIR, model_type)
        os.makedirs(target_dir, exist_ok=True)
        return os.path.join(target_dir, "model.pkl")
    return None
def save_model_with_factors(
model,
model_path: str,
selected_factors: list[str],
factor_definitions: dict,
fitted_processors: list | None = None,
) -> str:
"""保存模型及关联的因子信息和处理器。
将模型因子信息和处理器保存到同一文件夹models/{model_type}/
- model.pkl: 模型文件
- factors.json: 因子信息文件
- processors.pkl: 处理器状态文件如果提供
Args:
model: 训练好的模型实例需有 save 方法
model_path: 模型保存路径 get_model_save_path 生成
selected_factors: metadata 中选择的因子名称列表
factor_definitions: 通过表达式定义的因子字典
fitted_processors: 已拟合的处理器列表可选
Returns:
模型文件夹路径
"""
import json
import os
import pickle
# 获取模型文件夹路径
model_dir = os.path.dirname(model_path)
# 1. 保存模型本身
model.save(model_path)
print(f"[模型保存] 模型已保存至: {model_path}")
# 2. 保存因子信息到 factors.json 文件
factors_path = os.path.join(model_dir, "factors.json")
factors_info = {
"selected_factors": selected_factors,
"factor_definitions": factor_definitions,
"total_feature_count": len(selected_factors) + len(factor_definitions),
"selected_factors_count": len(selected_factors),
"factor_definitions_count": len(factor_definitions),
}
with open(factors_path, "w", encoding="utf-8") as f:
json.dump(factors_info, f, ensure_ascii=False, indent=2)
print(f"[模型保存] 因子信息已保存至: {factors_path}")
print(f"[模型保存] 总计 {factors_info['total_feature_count']} 个因子")
print(f" - 来自 metadata: {factors_info['selected_factors_count']}")
print(f" - 来自表达式定义: {factors_info['factor_definitions_count']}")
# 3. 保存处理器(如果提供)
if fitted_processors is not None:
processors_path = os.path.join(model_dir, "processors.pkl")
with open(processors_path, "wb") as f:
pickle.dump(fitted_processors, f)
print(f"[模型保存] 处理器已保存至: {processors_path}")
print(f"[模型保存] 共 {len(fitted_processors)} 个处理器")
return model_dir
def load_model_factors(model_path: str) -> Optional[dict]:
    """Load the factor information stored next to a model.

    Reads ``factors.json`` from the directory that contains ``model_path``.

    Args:
        model_path: Path of the model file, e.g.
            ``models/{model_type}/model.pkl``.

    Returns:
        The parsed factor information, or ``None`` (after printing a warning)
        when ``factors.json`` does not exist.
    """
    import json
    import os

    # factors.json sits in the same directory as the model file.
    factors_path = os.path.join(os.path.dirname(model_path), "factors.json")

    if os.path.exists(factors_path):
        with open(factors_path, "r", encoding="utf-8") as handle:
            loaded = json.load(handle)
        total = loaded.get("total_feature_count", "N/A")
        print(f"[模型加载] 已加载因子信息,总计 {total} 个因子")
        return loaded

    print(f"[警告] 未找到因子信息文件: {factors_path}")
    return None
def load_processors(model_path: str) -> list | None:
"""加载模型关联的处理器。
Args:
model_path: 模型保存路径models/{model_type}/model.pkl
Returns:
处理器列表如果文件不存在则返回 None
"""
import pickle
import os
# 获取模型文件夹路径
model_dir = os.path.dirname(model_path)
processors_path = os.path.join(model_dir, "processors.pkl")
if not os.path.exists(processors_path):
print(f"[警告] 未找到处理器文件: {processors_path}")
return None
with open(processors_path, "rb") as f:
fitted_processors = pickle.load(f)
print(f"[模型加载] 已加载 {len(fitted_processors)} 个处理器")
return fitted_processors