"""实验脚本的共用配置和辅助函数。 此模块包含 regression.py 和 learn_to_rank.py 共用的代码, 避免重复维护两份相同的配置和函数。 """ from datetime import datetime from typing import Dict, List, Optional import polars as pl from src.factors import FactorEngine # ============================================================================= # 日期范围配置(正确的 train/val/test 三分法) # ============================================================================= TRAIN_START = "20200101" TRAIN_END = "20231231" VAL_START = "20240101" VAL_END = "20241231" TEST_START = "20250101" TEST_END = "20261231" # ============================================================================= # 因子配置 # ============================================================================= # 当前选择的因子列表(从 FACTOR_DEFINITIONS 中选择要使用的因子) SELECTED_FACTORS = [ "ma_5", "ma_20", "ma_ratio_5_20", "bias_10", "high_low_ratio", "bbi_ratio", "return_5", "return_20", "kaufman_ER_20", "mom_acceleration_10_20", "drawdown_from_high_60", "up_days_ratio_20", "volatility_5", "volatility_20", "volatility_ratio", "std_return_20", "sharpe_ratio_20", "min_ret_20", "volatility_squeeze_5_60", "overnight_intraday_diff", "upper_shadow_ratio", "capital_retention_20", "max_ret_20", "volume_ratio_5_20", "turnover_rate_mean_5", "turnover_deviation", "amihud_illiq_20", "turnover_cv_20", "pv_corr_20", "close_vwap_deviation", "roe", "roa", "profit_margin", "debt_to_equity", "current_ratio", "net_profit_yoy", "revenue_yoy", "healthy_expansion_velocity", "EP", "BP", "CP", "market_cap_rank", "turnover_rank", "return_5_rank", "EP_rank", "pe_expansion_trend", "value_price_divergence", "active_market_cap", "ebit_rank", "GTJA_alpha001", "GTJA_alpha002", "GTJA_alpha003", "GTJA_alpha004", "GTJA_alpha005", "GTJA_alpha006", "GTJA_alpha007", "GTJA_alpha008", "GTJA_alpha009", "GTJA_alpha010", "GTJA_alpha011", "GTJA_alpha012", "GTJA_alpha013", "GTJA_alpha014", "GTJA_alpha015", "GTJA_alpha016", "GTJA_alpha017", "GTJA_alpha018", "GTJA_alpha019", "GTJA_alpha020", "GTJA_alpha022", "GTJA_alpha023", "GTJA_alpha024", "GTJA_alpha025", "GTJA_alpha026", "GTJA_alpha027", "GTJA_alpha028", "GTJA_alpha029", "GTJA_alpha031", "GTJA_alpha032", "GTJA_alpha033", "GTJA_alpha034", "GTJA_alpha035", "GTJA_alpha036", "GTJA_alpha037", # "GTJA_alpha038", "GTJA_alpha039", "GTJA_alpha040", "GTJA_alpha041", "GTJA_alpha042", "GTJA_alpha043", "GTJA_alpha044", "GTJA_alpha045", "GTJA_alpha046", "GTJA_alpha047", "GTJA_alpha048", "GTJA_alpha049", "GTJA_alpha050", "GTJA_alpha051", "GTJA_alpha052", "GTJA_alpha053", "GTJA_alpha054", "GTJA_alpha056", "GTJA_alpha057", "GTJA_alpha058", "GTJA_alpha059", "GTJA_alpha060", "GTJA_alpha061", "GTJA_alpha062", "GTJA_alpha063", "GTJA_alpha064", "GTJA_alpha065", "GTJA_alpha066", "GTJA_alpha067", "GTJA_alpha068", "GTJA_alpha070", "GTJA_alpha071", "GTJA_alpha072", "GTJA_alpha073", "GTJA_alpha074", "GTJA_alpha076", "GTJA_alpha077", "GTJA_alpha078", "GTJA_alpha079", "GTJA_alpha080", "GTJA_alpha081", "GTJA_alpha082", "GTJA_alpha083", "GTJA_alpha084", "GTJA_alpha085", "GTJA_alpha086", "GTJA_alpha087", "GTJA_alpha088", "GTJA_alpha089", "GTJA_alpha090", "GTJA_alpha091", "GTJA_alpha092", "GTJA_alpha093", "GTJA_alpha094", "GTJA_alpha095", "GTJA_alpha096", "GTJA_alpha097", "GTJA_alpha098", "GTJA_alpha099", "GTJA_alpha100", "GTJA_alpha101", "GTJA_alpha102", "GTJA_alpha103", "GTJA_alpha104", "GTJA_alpha105", "GTJA_alpha106", "GTJA_alpha107", "GTJA_alpha108", "GTJA_alpha109", "GTJA_alpha110", "GTJA_alpha111", "GTJA_alpha112", # "GTJA_alpha113", "GTJA_alpha114", "GTJA_alpha115", "GTJA_alpha117", "GTJA_alpha118", "GTJA_alpha119", "GTJA_alpha120", # "GTJA_alpha121", "GTJA_alpha122", "GTJA_alpha123", "GTJA_alpha124", "GTJA_alpha125", "GTJA_alpha126", "GTJA_alpha127", "GTJA_alpha128", "GTJA_alpha129", "GTJA_alpha130", "GTJA_alpha131", "GTJA_alpha132", "GTJA_alpha133", "GTJA_alpha134", "GTJA_alpha135", "GTJA_alpha136", # "GTJA_alpha138", "GTJA_alpha139", # "GTJA_alpha140", "GTJA_alpha141", "GTJA_alpha142", "GTJA_alpha145", # "GTJA_alpha146", "GTJA_alpha148", "GTJA_alpha150", "GTJA_alpha151", "GTJA_alpha152", "GTJA_alpha153", "GTJA_alpha154", "GTJA_alpha155", "GTJA_alpha156", "GTJA_alpha157", "GTJA_alpha158", "GTJA_alpha159", "GTJA_alpha160", "GTJA_alpha161", "GTJA_alpha162", "GTJA_alpha163", "GTJA_alpha164", # "GTJA_alpha165", "GTJA_alpha166", "GTJA_alpha167", "GTJA_alpha168", "GTJA_alpha169", "GTJA_alpha170", "GTJA_alpha171", "GTJA_alpha173", "GTJA_alpha174", "GTJA_alpha175", "GTJA_alpha176", "GTJA_alpha177", "GTJA_alpha178", "GTJA_alpha179", "GTJA_alpha180", # "GTJA_alpha183", "GTJA_alpha184", "GTJA_alpha185", "GTJA_alpha187", "GTJA_alpha188", "GTJA_alpha189", "GTJA_alpha191", "chip_dispersion_90", "chip_dispersion_70", "cost_skewness", "dispersion_change_20", "price_to_avg_cost", "price_to_median_cost", "mean_median_dev", "trap_pressure", "bottom_profit", "history_position", "winner_rate_surge_5", "winner_rate_cs_rank", "winner_rate_dev_20", "winner_rate_volatility", "smart_money_accumulation", "winner_vol_corr_20", "cost_base_momentum", "bottom_cost_stability", "pivot_reversion", "chip_transition", ] # 因子定义字典(完整因子库,用于存放尚未注册到metadata的因子) FACTOR_DEFINITIONS = {"cs_rank_circ_mv": "cs_rank(circ_mv)"} # ============================================================================= # Label 配置(统一绑定 label_name 和 label_dsl) # ============================================================================= # Label 名称 LABEL_NAME = "future_return_5" # Label DSL 公式 LABEL_DSL = "(ts_delay(close, -5) / ts_delay(open, -1)) - 1" # Label 配置字典(绑定 name 和 dsl) LABEL_FACTOR = {LABEL_NAME: LABEL_DSL} def get_label_factor(label_name: str) -> dict: """获取Label因子定义字典。 警告: 此函数已废弃,请直接使用 LABEL_FACTOR 常量。 label_name 参数将被忽略,始终返回预定义的 LABEL_FACTOR。 Args: label_name: label因子名称(已废弃,仅保留参数保持向后兼容) Returns: Label因子定义字典 """ return LABEL_FACTOR # ============================================================================= # 辅助函数 # ============================================================================= def register_factors( engine: FactorEngine, selected_factors: List[str], factor_definitions: dict, label_factor: dict, excluded_factors: Optional[List[str]] = None, ) -> List[str]: """注册因子。 selected_factors 从 metadata 查询,factor_definitions 用 DSL 表达式注册。 excluded_factors 中的因子会被排除,不参与计算。 Args: engine: FactorEngine实例 selected_factors: 从metadata中选择的因子名称列表 factor_definitions: 通过表达式定义的因子字典 label_factor: label因子定义字典 excluded_factors: 需要排除的因子名称列表,默认为None Returns: 特征列名称列表(已排除excluded_factors中的因子) """ print("=" * 80) print("注册因子") print("=" * 80) # 处理排除列表 excluded = set(excluded_factors) if excluded_factors else set() if excluded: print(f"\n[排除因子] 以下 {len(excluded)} 个因子将被排除:") for name in sorted(excluded): print(f" - {name}") # 过滤 SELECTED_FACTORS 中的因子(排除excluded_factors) filtered_selected = [name for name in selected_factors if name not in excluded] excluded_from_selected = set(selected_factors) - set(filtered_selected) if excluded_from_selected: print( f"\n[排除详情] 从 SELECTED_FACTORS 排除 {len(excluded_from_selected)} 个因子" ) # 注册 SELECTED_FACTORS 中的因子(已在 metadata 中) print("\n注册特征因子(从 metadata):") for name in filtered_selected: engine.add_factor(name) print(f" - {name}") # 过滤 FACTOR_DEFINITIONS 中的因子(排除excluded_factors) filtered_definitions = { name: expr for name, expr in factor_definitions.items() if name not in excluded } excluded_from_definitions = set(factor_definitions.keys()) - set( filtered_definitions.keys() ) if excluded_from_definitions: print( f"\n[排除详情] 从 FACTOR_DEFINITIONS 排除 {len(excluded_from_definitions)} 个因子" ) # 注册 FACTOR_DEFINITIONS 中的因子(通过表达式,尚未在 metadata 中) print("\n注册特征因子(表达式):") for name, expr in filtered_definitions.items(): engine.add_factor(name, expr) print(f" - {name}: {expr}") # 注册 label 因子(通过表达式,label因子不受excluded_factors影响) print("\n注册 Label 因子(表达式):") for name, expr in label_factor.items(): engine.add_factor(name, expr) print(f" - {name}: {expr}") # 特征列 = 过滤后的 SELECTED_FACTORS + 过滤后的 FACTOR_DEFINITIONS 的 keys feature_cols = filtered_selected + list(filtered_definitions.keys()) print(f"\n特征因子数: {len(feature_cols)}") print(f" - 来自 metadata: {len(filtered_selected)}") print(f" - 来自表达式: {len(filtered_definitions)}") if excluded: print(f" - 已排除: {len(excluded)}") print(f"Label: {list(label_factor.keys())[0]}") print(f"已注册因子总数: {len(engine.list_registered())}") return feature_cols def prepare_data( engine: FactorEngine, feature_cols: List[str], start_date: str, end_date: str, label_name: str, ) -> pl.DataFrame: """准备数据。 计算因子并返回包含特征和label的数据框。 Args: engine: FactorEngine实例 feature_cols: 特征列名称列表 start_date: 开始日期 (YYYYMMDD) end_date: 结束日期 (YYYYMMDD) label_name: label列名称 Returns: 包含因子计算结果的数据框 """ print("\n" + "=" * 80) print("准备数据") print("=" * 80) # 计算因子(全市场数据) print(f"\n计算因子: {start_date} - {end_date}") factor_names = feature_cols + [label_name] # 包含 label data = engine.compute( factor_names=factor_names, start_date=start_date, end_date=end_date, ) print(f"数据形状: {data.shape}") print(f"数据列: {data.columns}") print(f"\n前5行预览:") print(data.head()) return data # ============================================================================= # 股票池筛选配置 # ============================================================================= def stock_pool_filter(df: pl.DataFrame) -> pl.Series: """股票池筛选函数(单日数据)。 筛选条件: 1. 排除创业板(代码以 300 开头) 2. 排除科创板(代码以 688 开头) 3. 排除北交所(代码以 8、9 或 4 开头) 4. 选取当日市值最小的500只股票 Args: df: 单日数据框 Returns: 布尔Series,表示哪些股票被选中 """ # 代码筛选(排除创业板、科创板、北交所) code_filter = ( ~df["ts_code"].str.starts_with("30") # 排除创业板 & ~df["ts_code"].str.starts_with("68") # 排除科创板 & ~df["ts_code"].str.starts_with("8") # 排除北交所 & ~df["ts_code"].str.starts_with("9") # 排除北交所 & ~df["ts_code"].str.starts_with("4") # 排除北交所 ) # 在已筛选的股票中,选取流通市值最小的500只 valid_df = df.filter(code_filter) n = min(1000, len(valid_df)) small_cap_codes = valid_df.sort("circ_mv").head(n)["ts_code"] # 返回布尔 Series:是否在被选中的股票中 return df["ts_code"].is_in(small_cap_codes) # 定义筛选所需的基础列 STOCK_FILTER_REQUIRED_COLUMNS = ["circ_mv"] # ============================================================================= # 输出配置 # ============================================================================= OUTPUT_DIR = "output" SAVE_PREDICTIONS = True # 模型保存配置 SAVE_MODEL = False # 是否保存模型 MODEL_SAVE_DIR = "models" # 模型保存目录 # Top N 配置:每日推荐股票数量 TOP_N = 5 # 可调整为 10, 20 等 def get_output_path(model_type: str, test_start: str, test_end: str) -> str: """生成输出文件路径。 Args: model_type: 模型类型("regression" 或 "rank") test_start: 测试开始日期 test_end: 测试结束日期 Returns: 输出文件路径 """ import os # 确保输出目录存在 os.makedirs(OUTPUT_DIR, exist_ok=True) # 生成文件名 start_dt = datetime.strptime(test_start, "%Y%m%d") end_dt = datetime.strptime(test_end, "%Y%m%d") date_str = f"{start_dt.strftime('%Y%m%d')}_{end_dt.strftime('%Y%m%d')}" filename = f"{model_type}_output.csv" return os.path.join(OUTPUT_DIR, filename) def get_model_save_path( model_type: str, ) -> Optional[str]: """生成模型保存路径。 模型将保存在 models/{model_type}/ 目录下,包含 model.pkl 和 factors.json Args: model_type: 模型类型("regression" 或 "rank") Returns: 模型保存路径(models/{model_type}/model.pkl),如果 SAVE_MODEL 为 False 则返回 None """ if not SAVE_MODEL: return None import os # 模型保存目录:models/{model_type}/ model_dir = os.path.join(MODEL_SAVE_DIR, model_type) os.makedirs(model_dir, exist_ok=True) # 模型文件路径 return os.path.join(model_dir, "model.pkl") def save_model_with_factors( model, model_path: str, selected_factors: list[str], factor_definitions: dict, fitted_processors: list | None = None, ) -> str: """保存模型及关联的因子信息和处理器。 将模型、因子信息和处理器保存到同一文件夹(models/{model_type}/)下: - model.pkl: 模型文件 - factors.json: 因子信息文件 - processors.pkl: 处理器状态文件(如果提供) Args: model: 训练好的模型实例(需有 save 方法) model_path: 模型保存路径(由 get_model_save_path 生成) selected_factors: 从 metadata 中选择的因子名称列表 factor_definitions: 通过表达式定义的因子字典 fitted_processors: 已拟合的处理器列表(可选) Returns: 模型文件夹路径 """ import json import os import pickle # 获取模型文件夹路径 model_dir = os.path.dirname(model_path) # 1. 保存模型本身 model.save(model_path) print(f"[模型保存] 模型已保存至: {model_path}") # 2. 保存因子信息到 factors.json 文件 factors_path = os.path.join(model_dir, "factors.json") factors_info = { "selected_factors": selected_factors, "factor_definitions": factor_definitions, "total_feature_count": len(selected_factors) + len(factor_definitions), "selected_factors_count": len(selected_factors), "factor_definitions_count": len(factor_definitions), } with open(factors_path, "w", encoding="utf-8") as f: json.dump(factors_info, f, ensure_ascii=False, indent=2) print(f"[模型保存] 因子信息已保存至: {factors_path}") print(f"[模型保存] 总计 {factors_info['total_feature_count']} 个因子") print(f" - 来自 metadata: {factors_info['selected_factors_count']} 个") print(f" - 来自表达式定义: {factors_info['factor_definitions_count']} 个") # 3. 保存处理器(如果提供) if fitted_processors is not None: processors_path = os.path.join(model_dir, "processors.pkl") with open(processors_path, "wb") as f: pickle.dump(fitted_processors, f) print(f"[模型保存] 处理器已保存至: {processors_path}") print(f"[模型保存] 共 {len(fitted_processors)} 个处理器") return model_dir def load_model_factors(model_path: str) -> Optional[dict]: """加载模型关联的因子信息。 Args: model_path: 模型保存路径(models/{model_type}/model.pkl) Returns: 包含因子信息的字典,如果文件不存在则返回 None """ import json import os # 获取模型文件夹路径 model_dir = os.path.dirname(model_path) factors_path = os.path.join(model_dir, "factors.json") if not os.path.exists(factors_path): print(f"[警告] 未找到因子信息文件: {factors_path}") return None with open(factors_path, "r", encoding="utf-8") as f: factors_info = json.load(f) print( f"[模型加载] 已加载因子信息,总计 {factors_info.get('total_feature_count', 'N/A')} 个因子" ) return factors_info def load_processors(model_path: str) -> list | None: """加载模型关联的处理器。 Args: model_path: 模型保存路径(models/{model_type}/model.pkl) Returns: 处理器列表,如果文件不存在则返回 None """ import pickle import os # 获取模型文件夹路径 model_dir = os.path.dirname(model_path) processors_path = os.path.join(model_dir, "processors.pkl") if not os.path.exists(processors_path): print(f"[警告] 未找到处理器文件: {processors_path}") return None with open(processors_path, "rb") as f: fitted_processors = pickle.load(f) print(f"[模型加载] 已加载 {len(fitted_processors)} 个处理器") return fitted_processors