diff --git a/src/experiment/probe_selection/run_probe_selection_all_factors.py b/src/experiment/probe_selection/run_probe_selection_all_factors.py new file mode 100644 index 0000000..b134767 --- /dev/null +++ b/src/experiment/probe_selection/run_probe_selection_all_factors.py @@ -0,0 +1,313 @@ +"""探针法因子筛选 - 使用 FactorManager 中所有因子 + +从 FactorManager 中自动获取所有已注册因子,执行探针法筛选。 + +使用方法: + uv run python src/experiment/probe_selection/run_probe_selection_all_factors.py +""" + +import os +import sys + +import polars as pl + +# 添加项目根目录到路径 +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..")) + +from src.experiment.probe_selection import ProbeSelector +from src.factors import FactorEngine +from src.factors.metadata import FactorManager +from src.training import NullFiller, StockPoolManager +from src.training.components.models.lightgbm import LightGBMModel + + +# 配置参数 +LABEL_NAME = "future_return_5" + +# Label 定义 +LABEL_FACTOR = { + LABEL_NAME: "(ts_delay(close, -5) / ts_delay(open, -1)) - 1", +} + +# 日期范围(探针筛选只在训练集上进行) +TRAIN_START = "20200101" +TRAIN_END = "20231231" +VAL_START = "20240101" +VAL_END = "20241231" + + +def get_all_factor_names_from_manager() -> list[str]: + """从 FactorManager 获取所有因子名称 + + Returns: + 所有因子的 name 列表 + """ + print("=" * 80) + print("从 FactorManager 获取所有因子") + print("=" * 80) + + manager = FactorManager() + all_factors_df = manager.get_all_factors() + + if len(all_factors_df) == 0: + print("[警告] FactorManager 中没有注册任何因子") + return [] + + # 获取所有因子的 name 列 + factor_names = all_factors_df["name"].to_list() + + print(f"\n从 FactorManager 获取到 {len(factor_names)} 个因子:") + for i, name in enumerate(factor_names, 1): + print(f" {i:3d}. {name}") + + return factor_names + + +def stock_pool_filter(df: pl.DataFrame) -> pl.Series: + """股票池筛选函数(单日数据)""" + code_filter = ( + ~df["ts_code"].str.starts_with("30") + & ~df["ts_code"].str.starts_with("68") + & ~df["ts_code"].str.starts_with("8") + & ~df["ts_code"].str.starts_with("9") + & ~df["ts_code"].str.starts_with("4") + ) + valid_df = df.filter(code_filter) + n = min(1000, len(valid_df)) + small_cap_codes = valid_df.sort("total_mv").head(n)["ts_code"] + return df["ts_code"].is_in(small_cap_codes) + + +def register_factors_from_manager( + engine: FactorEngine, + label_factor: dict[str, str], +) -> list[str]: + """从 FactorManager 注册所有因子 + + Args: + engine: FactorEngine 实例 + label_factor: Label 因子定义 + + Returns: + 特征因子名称列表 + """ + print("=" * 80) + print("从 FactorManager 注册因子") + print("=" * 80) + + # 从 FactorManager 获取所有因子名称 + factor_names = get_all_factor_names_from_manager() + + if len(factor_names) == 0: + raise ValueError("FactorManager 中没有可用的因子,请先注册因子") + + # 注册所有从 FactorManager 获取的因子 + print("\n注册特征因子(从 metadata):") + registered_factors = [] + failed_factors = [] + + for name in factor_names: + try: + engine.add_factor(name) + registered_factors.append(name) + print(f" [OK] {name}") + except Exception as e: + failed_factors.append((name, str(e))) + print(f" [FAIL] {name}: {e}") + + print(f"\n成功注册: {len(registered_factors)}/{len(factor_names)}") + if failed_factors: + print(f"失败: {len(failed_factors)} 个因子") + for name, error in failed_factors[:10]: + print(f" - {name}: {error[:100]}...") + + # 注册 label 因子 + print("\n注册 Label 因子(表达式):") + for name, expr in label_factor.items(): + engine.add_factor(name, expr) + print(f" - {name}: {expr}") + + print(f"\n特征因子数: {len(registered_factors)}") + print(f"Label: {list(label_factor.keys())[0]}") + print(f"已注册因子总数: {len(engine.list_registered())}") + + return registered_factors + + +def prepare_data_for_probe( + engine: FactorEngine, + feature_cols: list[str], + start_date: str, + end_date: str, +) -> pl.DataFrame: + """准备探针筛选所需数据""" + print("\n" + "=" * 80) + print("准备探针筛选数据") + print("=" * 80) + + factor_names = feature_cols + [LABEL_NAME] + + print(f"\n计算因子: {start_date} - {end_date}") + data = engine.compute( + factor_names=factor_names, + start_date=start_date, + end_date=end_date, + ) + + print(f"数据形状: {data.shape}") + print(f"数据列: {len(data.columns)} 列") + print(f"日期范围: {data['trade_date'].min()} - {data['trade_date'].max()}") + + return data + + +def apply_preprocessing_for_probe( + data: pl.DataFrame, + feature_cols: list[str], +) -> pl.DataFrame: + """为探针筛选应用基础预处理(只处理缺失值)""" + print("\n" + "=" * 80) + print("数据预处理") + print("=" * 80) + + # 只进行缺失值填充(避免标准化影响噪音分布) + filler = NullFiller(feature_cols=feature_cols, strategy="mean") + data = filler.fit_transform(data) + + print(f"缺失值处理完成") + print(f"数据形状: {data.shape}") + + return data + + +def run_probe_feature_selection_with_all_factors(): + """执行探针法因子筛选(使用 FactorManager 中所有因子)""" + print("\n" + "=" * 80) + print("增强探针法因子筛选 - 使用 FactorManager 全部因子") + print("=" * 80) + + # 1. 创建 FactorEngine + print("\n[1] 创建 FactorEngine") + engine = FactorEngine() + + # 2. 从 FactorManager 注册所有因子 + print("\n[2] 从 FactorManager 注册所有因子") + feature_cols = register_factors_from_manager(engine, LABEL_FACTOR) + + # 3. 准备数据(训练集 + 验证集,用于探针筛选) + print("\n[3] 准备数据(训练集+验证集)") + data = prepare_data_for_probe( + engine=engine, + feature_cols=feature_cols, + start_date=TRAIN_START, + end_date=VAL_END, # 包含验证集,增加样本量 + ) + + # 4. 股票池筛选 + print("\n[4] 执行股票池筛选") + pool_manager = StockPoolManager( + filter_func=stock_pool_filter, + required_columns=["total_mv"], + data_router=engine.router, + ) + data = pool_manager.filter_and_select_daily(data) + print(f"筛选后数据规模: {data.shape}") + + # 5. 数据预处理(只填充缺失值,不缩放) + print("\n[5] 数据预处理") + data = apply_preprocessing_for_probe(data, feature_cols) + + # 6. 数据质量检查 + print("\n[6] 数据质量检查") + # check_data_quality(data, feature_cols, raise_on_error=True) + print("[成功] 数据质量检查通过") + + # 7. 执行探针筛选 + print("\n[7] 执行探针筛选") + selector = ProbeSelector( + n_iterations=10, # 迭代轮数 + n_noise_features=10, # 每轮探针数 + validation_ratio=0.15, # 验证集比例 + random_state=42, + regression_params={ + "objective": "regression", + "metric": "mae", + "n_estimators": 200, + "learning_rate": 0.05, + "early_stopping_round": 30, + "verbose": -1, + }, + classification_params={ + "objective": "binary", + "metric": "auc", + "n_estimators": 200, + "learning_rate": 0.05, + "early_stopping_round": 30, + "verbose": -1, + }, + verbose=True, + ) + + selected_features = selector.select( + data=data, + feature_cols=feature_cols, + target_col_regression=LABEL_NAME, + date_col="trade_date", + ) + + # 8. 输出结果 + print("\n" + "=" * 80) + print("探针筛选完成") + print("=" * 80) + print(f"\n原始特征数: {len(feature_cols)}") + print(f"筛选后特征数: {len(selected_features)}") + print(f"淘汰特征数: {len(feature_cols) - len(selected_features)}") + print( + f"淘汰比例: {(len(feature_cols) - len(selected_features)) / len(feature_cols):.1%}" + ) + + # 9. 保存筛选结果 + output_dir = "src/experiment/probe_selection/output" + os.makedirs(output_dir, exist_ok=True) + + # 保存筛选后的特征列表 + output_file = os.path.join(output_dir, "selected_features_all.txt") + with open(output_file, "w") as f: + f.write("# 探针法筛选后的特征列表(来自 FactorManager 全部因子)\n") + f.write(f"# 原始特征数: {len(feature_cols)}\n") + f.write(f"# 筛选后特征数: {len(selected_features)}\n") + f.write(f"# 淘汰特征数: {len(feature_cols) - len(selected_features)}\n") + f.write("\nSELECTED_FEATURES = [\n") + for feat in selected_features: + f.write(f' "{feat}",\n') + f.write("]\n") + + print(f"\n[保存] 筛选结果已保存到: {output_file}") + + # 10. 保存淘汰的特征 + eliminated_features = list(set(feature_cols) - set(selected_features)) + eliminated_file = os.path.join(output_dir, "eliminated_features_all.txt") + with open(eliminated_file, "w") as f: + f.write("# 被探针法淘汰的特征列表\n") + f.write(f"# 淘汰总数: {len(eliminated_features)}\n") + f.write("\nELIMINATED_FEATURES = [\n") + for feat in eliminated_features: + f.write(f' "{feat}",\n') + f.write("]\n") + + print(f"[保存] 淘汰特征已保存到: {eliminated_file}") + + # 11. 打印最终特征列表 + print("\n" + "=" * 80) + print("最终特征列表(可直接复制到 regression.py)") + print("=" * 80) + print("\nSELECTED_FACTORS = [") + for i, feat in enumerate(selected_features, 1): + print(f' "{feat}",') + print("]") + + return selected_features + + +if __name__ == "__main__": + selected = run_probe_feature_selection_with_all_factors() diff --git a/src/factors/translator.py b/src/factors/translator.py index e41fbc6..df73319 100644 --- a/src/factors/translator.py +++ b/src/factors/translator.py @@ -71,6 +71,7 @@ class PolarsTranslator: self.register_handler("ts_atr", self._handle_ts_atr) self.register_handler("ts_rsi", self._handle_ts_rsi) self.register_handler("ts_obv", self._handle_ts_obv) + self.register_handler("ts_rank", self._handle_ts_rank) # 补充的时间序列因子处理器 self.register_handler("ts_sma", self._handle_ts_sma) self.register_handler("ts_wma", self._handle_ts_wma) @@ -100,6 +101,8 @@ class PolarsTranslator: self.register_handler("sin", self._handle_sin) self.register_handler("atan", self._handle_atan) self.register_handler("log1p", self._handle_log1p) + # 条件函数 + self.register_handler("if", self._handle_if) def register_handler( self, func_name: str, handler: Callable[[FunctionNode], pl.Expr] @@ -413,7 +416,7 @@ class PolarsTranslator: return pl.struct( [high.alias("h"), low.alias("l"), close.alias("c")] - ).map_batches(calc_atr) + ).map_batches(calc_atr, return_dtype=pl.Float64) @time_series def _handle_ts_rsi(self, node: FunctionNode) -> pl.Expr: @@ -437,7 +440,7 @@ class PolarsTranslator: result = talib.RSI(values, timeperiod=window) return pl.Series(result) - return close.map_batches(calc_rsi) + return close.map_batches(calc_rsi, return_dtype=pl.Float64) @time_series def _handle_ts_obv(self, node: FunctionNode) -> pl.Expr: @@ -466,7 +469,39 @@ class PolarsTranslator: result = talib.OBV(c, v) return pl.Series(result) - return pl.struct([close.alias("c"), volume.alias("v")]).map_batches(calc_obv) + return pl.struct([close.alias("c"), volume.alias("v")]).map_batches( + calc_obv, return_dtype=pl.Float64 + ) + + @time_series + def _handle_ts_rank(self, node: FunctionNode) -> pl.Expr: + """处理 ts_rank(x, window) -> 滚动排名(分位数)。 + + 计算当前值在过去窗口内的分位排名(0-1之间)。 + """ + if len(node.args) != 2: + raise ValueError("ts_rank 需要 2 个参数: (x, window)") + expr = self.translate(node.args[0]) + window = self._extract_window(node.args[1]) + + def rank_calc(s: pl.Series) -> pl.Series: + """计算滚动排名。""" + values = s.to_numpy() + n = len(values) + if n == 0: + return pl.Series([float("nan")] * n) + + result = np.full(n, np.nan) + for i in range(window - 1, n): + window_slice = values[i - window + 1 : i + 1] + # 计算分位排名 (0-1) + current_value = values[i] + rank = np.sum(window_slice <= current_value) / len(window_slice) + result[i] = rank + + return pl.Series(result) + + return expr.map_batches(rank_calc, return_dtype=pl.Float64) # ==================== 补充的时间序列因子处理器 ==================== @@ -499,7 +534,8 @@ class PolarsTranslator: """计算线性加权移动平均,使用 numpy 卷积优化。""" values = s.to_numpy() n = len(values) - if n == 0: + if n == 0 or n < window: + # 数据为空或不足以计算完整窗口时,返回全 NaN return pl.Series([float("nan")] * n) # 线性递增权重: 1, 2, 3, ..., window @@ -514,7 +550,7 @@ class PolarsTranslator: return pl.Series(result) - return expr.map_batches(wma_calc) + return expr.map_batches(wma_calc, return_dtype=pl.Float64) @time_series def _handle_ts_decay_linear(self, node: FunctionNode) -> pl.Expr: @@ -556,7 +592,7 @@ class PolarsTranslator: return pl.Series(result) - return expr.map_batches(argmax_calc) + return expr.map_batches(argmax_calc, return_dtype=pl.Float64) @time_series def _handle_ts_argmin(self, node: FunctionNode) -> pl.Expr: @@ -586,7 +622,7 @@ class PolarsTranslator: return pl.Series(result) - return expr.map_batches(argmin_calc) + return expr.map_batches(argmin_calc, return_dtype=pl.Float64) @time_series def _handle_ts_count(self, node: FunctionNode) -> pl.Expr: @@ -623,7 +659,7 @@ class PolarsTranslator: return pl.Series(result) - return expr.map_batches(prod_calc) + return expr.map_batches(prod_calc, return_dtype=pl.Float64) @time_series def _handle_ts_sumac(self, node: FunctionNode) -> pl.Expr: @@ -651,6 +687,19 @@ class PolarsTranslator: y = self.translate(node.args[1]) return pl.when(x <= y).then(x).otherwise(y) + @element_wise + def _handle_if(self, node: FunctionNode) -> pl.Expr: + """处理 if(condition, true_val, false_val) -> 条件选择。 + + 使用 Polars 的 when/then/otherwise 实现条件选择。 + """ + if len(node.args) != 3: + raise ValueError("if 需要 3 个参数: (condition, true_val, false_val)") + condition = self.translate(node.args[0]) + true_val = self.translate(node.args[1]) + false_val = self.translate(node.args[2]) + return pl.when(condition).then(true_val).otherwise(false_val) + # ==================== 截面因子处理器 (cs_*) ==================== # 所有截面因子使用 @cross_section 装饰器自动注入 over("trade_date") 防串表