"""Probe selector - main class.

Orchestrates the enhanced-probe feature screening workflow and performs
iterative feature selection.

NOTE(review): user-facing log messages are intentionally kept in Chinese;
they are runtime output and must not be altered.
"""

from typing import List, Optional

import polars as pl

from src.experiment.probe_selection.importance_evaluator import ImportanceEvaluator
from src.experiment.probe_selection.noise_generator import NoiseGenerator
from src.experiment.probe_selection.probe_trainer import ProbeTrainer


class ProbeSelector:
    """Probe-based feature selector.

    Implements the enhanced-probe factor screening algorithm:
        1. Inject noise probes.
        2. Multi-task training (regression + classification).
        3. Cross-elimination against the noise pass line.
        4. Iterate until convergence.

    Key constraints:
        - The classification target uses the cross-sectional median.
        - Gain importance is enforced.
        - Training uses a validation split for early stopping.
        - Polars zero-copy operations.
    """

    def __init__(
        self,
        n_iterations: int = 5,
        n_noise_features: int = 10,
        validation_ratio: float = 0.15,
        random_state: int = 42,
        regression_params: Optional[dict] = None,
        classification_params: Optional[dict] = None,
        verbose: bool = True,
    ):
        """Initialize the probe selector.

        Args:
            n_iterations: Maximum number of iterations K.
            n_noise_features: Number of noise probes M injected per round.
            validation_ratio: Validation-set ratio (used for early stopping).
            random_state: Random seed.
            regression_params: Regression model parameters.
            classification_params: Classification model parameters.
            verbose: Whether to emit detailed logs.
        """
        self.n_iterations = n_iterations
        self.n_noise_features = n_noise_features
        self.validation_ratio = validation_ratio
        self.random_state = random_state
        self.verbose = verbose

        # Sub-components: noise injection, multi-task training, importance
        # evaluation / elimination.
        self.noise_generator = NoiseGenerator(random_state=random_state)
        self.trainer = ProbeTrainer(
            regression_params=regression_params,
            classification_params=classification_params,
            validation_ratio=validation_ratio,
            random_state=random_state,
        )
        self.evaluator = ImportanceEvaluator(noise_prefix=NoiseGenerator.NOISE_PREFIX)

        # Per-round history and the final surviving feature list.
        self.selection_history: List[dict] = []
        self.final_features: Optional[List[str]] = None

    def select(
        self,
        data: pl.DataFrame,
        feature_cols: List[str],
        target_col_regression: str,
        date_col: str = "trade_date",
    ) -> List[str]:
        """Run the iterative probe-based feature selection.

        Args:
            data: Training data.
            feature_cols: Candidate feature column names.
            target_col_regression: Regression target column name.
            date_col: Date column name.

        Returns:
            The list of features that survived screening.
        """
        remaining_features = feature_cols.copy()
        original_count = len(remaining_features)

        if self.verbose:
            print("=" * 80)
            print("增强探针法因子筛选")
            print("=" * 80)
            print(f"\n初始特征数: {original_count}")
            print(f"迭代轮数: {self.n_iterations}")
            print(f"每轮探针数: {self.n_noise_features}")
            print(f"验证集比例: {self.validation_ratio:.0%}")

        for iteration in range(1, self.n_iterations + 1):
            if self.verbose:
                print(f"\n{'=' * 80}")
                print(f"探针筛选第 {iteration}/{self.n_iterations} 轮")
                print(f"当前候选特征: {len(remaining_features)} 个")
                print("=" * 80)

            # Snapshot this round's candidates and keep only the columns the
            # trainer needs (candidates + regression target + date).
            current_features = remaining_features.copy()
            feature_matrix = data.select(
                current_features + [target_col_regression, date_col]
            )

            # Inject noise probes; a different seed per round keeps the
            # probes independent across iterations.
            seed = self.random_state + iteration
            data_with_noise = self.noise_generator.generate_noise(
                feature_matrix, self.n_noise_features, seed
            )
            all_feature_cols = (
                current_features
                + self.noise_generator.get_noise_columns(data_with_noise)
            )

            if self.verbose:
                print(f"\n[1/4] 注入探针: {self.n_noise_features} 列噪音特征")

            # Multi-task training (regression + classification).
            if self.verbose:
                print("\n[2/4] 多任务训练(回归 + 分类)...")
            self.trainer.fit(
                df=data_with_noise,
                feature_cols=all_feature_cols,
                target_col_regression=target_col_regression,
                date_col=date_col,
            )

            train_info = self.trainer.get_training_info()
            if self.verbose:
                print(
                    f"  数据切分: 训练集 {train_info.get('train_size')} 条, 验证集 {train_info.get('val_size')} 条"
                )
                if "regression_best_iter" in train_info:
                    print(f"  回归模型早停: {train_info['regression_best_iter']} 轮")
                if "classification_best_iter" in train_info:
                    print(
                        f"  分类模型早停: {train_info['classification_best_iter']} 轮"
                    )

            # Gain importance is enforced (see class docstring).
            reg_imp, cls_imp = self.trainer.get_feature_importance(
                importance_type="gain"
            )

            if self.verbose:
                print("\n[3/4] 计算及格线...")

            # Evaluate against the noise pass line and eliminate.
            remaining_features = self.evaluator.evaluate(
                regression_importance=reg_imp,
                classification_importance=cls_imp,
                candidate_features=current_features,
            )

            thresholds = self.evaluator.get_thresholds()
            if self.verbose:
                print(f"  回归及格线: {thresholds[0]:.6f}")
                print(f"  分类及格线: {thresholds[1]:.6f}")

            stats = self.evaluator.get_elimination_stats()
            eliminated = stats["eliminated_count"]

            if self.verbose:
                print("\n[4/4] 交叉淘汰...")
                print(f"  淘汰特征: {eliminated} 个")
                print(f"  剩余特征: {stats['survived_count']} 个")
                if eliminated > 0:
                    print("\n  淘汰的特征:")
                    # Only the first 10 eliminated features are listed.
                    for feat_info in stats["eliminated_features"][:10]:
                        print(
                            f"    - {feat_info['feature']}: 回归={feat_info['regression_importance']:.6f}, 分类={feat_info['classification_importance']:.6f}"
                        )
                    if eliminated > 10:
                        print(f"    ... 还有 {eliminated - 10} 个")

            # Record this round's outcome.
            self.selection_history.append(
                {
                    "iteration": iteration,
                    "initial_features": len(current_features),
                    "eliminated": eliminated,
                    "survived": len(remaining_features),
                    "regression_threshold": thresholds[0],
                    "classification_threshold": thresholds[1],
                    "eliminated_features": [
                        f["feature"] for f in stats["eliminated_features"]
                    ],
                }
            )

            # Convergence: stop early once a round eliminates nothing.
            if eliminated == 0:
                if self.verbose:
                    print(f"\n[提前终止] 第 {iteration} 轮没有因子被淘汰")
                break

        # Store a copy so callers mutating the returned list cannot
        # silently corrupt the recorded final feature set.
        self.final_features = remaining_features.copy()

        if self.verbose:
            self._print_final_summary(original_count, remaining_features)

        return remaining_features

    def _print_final_summary(
        self, original_count: int, remaining: List[str]
    ) -> None:
        """Print the end-of-run summary.

        Guards against ZeroDivisionError when the selector was called with
        an empty candidate feature list.
        """
        eliminated = original_count - len(remaining)
        print("\n" + "=" * 80)
        print("探针筛选完成")
        print("=" * 80)
        print(f"\n原始特征数: {original_count}")
        print(f"最终特征数: {len(remaining)}")
        print(f"淘汰特征数: {eliminated}")
        ratio = eliminated / original_count if original_count else 0.0
        print(f"淘汰比例: {ratio:.1%}")
        print("\n最终特征列表:")
        for i, feat in enumerate(remaining, 1):
            print(f"  {i:2d}. {feat}")

    def get_selection_history(self) -> List[dict]:
        """Return the per-round selection history.

        Returns:
            One dict per completed round (see ``select`` for the keys).
        """
        return self.selection_history

    def get_importance_report(
        self,
        data: pl.DataFrame,
        feature_cols: List[str],
        target_col_regression: str,
        date_col: str = "trade_date",
    ) -> List[dict]:
        """Produce a detailed importance report for one full train/evaluate pass.

        NOTE(review): this re-trains the internal models and re-runs the
        evaluator, so it mutates ``self.trainer`` / ``self.evaluator`` state.

        Args:
            data: Input data.
            feature_cols: Feature column names.
            target_col_regression: Regression target column name.
            date_col: Date column name.

        Returns:
            Feature comparison list from the evaluator.
        """
        # Inject probes (fixed seed here, unlike the per-round seeds in select).
        feature_matrix = data.select(feature_cols + [target_col_regression, date_col])
        data_with_noise = self.noise_generator.generate_noise(
            feature_matrix, self.n_noise_features, self.random_state
        )
        all_feature_cols = feature_cols + self.noise_generator.get_noise_columns(
            data_with_noise
        )

        # Train both tasks.
        self.trainer.fit(
            df=data_with_noise,
            feature_cols=all_feature_cols,
            target_col_regression=target_col_regression,
            date_col=date_col,
        )

        reg_imp, cls_imp = self.trainer.get_feature_importance(importance_type="gain")

        # Run evaluate first so the pass-line thresholds are populated.
        self.evaluator.evaluate(reg_imp, cls_imp, feature_cols)

        comparison = self.evaluator.get_feature_comparison(
            reg_imp, cls_imp, feature_cols
        )

        return comparison