feat(probe-selection): 添加探针法因子筛选模块
This commit is contained in:
284
src/experiment/probe_selection/probe_selector.py
Normal file
284
src/experiment/probe_selection/probe_selector.py
Normal file
@@ -0,0 +1,284 @@
|
||||
"""探针选择器 - 主类
|
||||
|
||||
协调整个探针筛选流程,执行迭代特征选择。
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import polars as pl
|
||||
|
||||
from src.experiment.probe_selection.importance_evaluator import ImportanceEvaluator
|
||||
from src.experiment.probe_selection.noise_generator import NoiseGenerator
|
||||
from src.experiment.probe_selection.probe_trainer import ProbeTrainer
|
||||
|
||||
|
||||
class ProbeSelector:
|
||||
"""探针选择器
|
||||
|
||||
实现增强探针法因子筛选算法:
|
||||
1. 注入噪音探针
|
||||
2. 多任务训练(回归+分类)
|
||||
3. 基于噪音及格线交叉淘汰
|
||||
4. 迭代直到收敛
|
||||
|
||||
关键约束:
|
||||
- 分类目标使用截面中位数
|
||||
- 强制使用 Gain 重要性
|
||||
- 训练时使用验证集早停
|
||||
- Polars 零拷贝操作
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_iterations: int = 5,
|
||||
n_noise_features: int = 10,
|
||||
validation_ratio: float = 0.15,
|
||||
random_state: int = 42,
|
||||
regression_params: Optional[dict] = None,
|
||||
classification_params: Optional[dict] = None,
|
||||
verbose: bool = True,
|
||||
):
|
||||
"""初始化探针选择器
|
||||
|
||||
Args:
|
||||
n_iterations: 最大迭代轮数 K
|
||||
n_noise_features: 每轮注入的噪音数 M
|
||||
validation_ratio: 验证集比例(用于早停)
|
||||
random_state: 随机种子
|
||||
regression_params: 回归模型参数
|
||||
classification_params: 分类模型参数
|
||||
verbose: 是否输出详细日志
|
||||
"""
|
||||
self.n_iterations = n_iterations
|
||||
self.n_noise_features = n_noise_features
|
||||
self.validation_ratio = validation_ratio
|
||||
self.random_state = random_state
|
||||
self.verbose = verbose
|
||||
|
||||
# 初始化子组件
|
||||
self.noise_generator = NoiseGenerator(random_state=random_state)
|
||||
self.trainer = ProbeTrainer(
|
||||
regression_params=regression_params,
|
||||
classification_params=classification_params,
|
||||
validation_ratio=validation_ratio,
|
||||
random_state=random_state,
|
||||
)
|
||||
self.evaluator = ImportanceEvaluator(noise_prefix=NoiseGenerator.NOISE_PREFIX)
|
||||
|
||||
# 存储历史记录
|
||||
self.selection_history: List[dict] = []
|
||||
self.final_features: Optional[List[str]] = None
|
||||
|
||||
def select(
|
||||
self,
|
||||
data: pl.DataFrame,
|
||||
feature_cols: List[str],
|
||||
target_col_regression: str,
|
||||
date_col: str = "trade_date",
|
||||
) -> List[str]:
|
||||
"""执行特征选择
|
||||
|
||||
Args:
|
||||
data: 训练数据
|
||||
feature_cols: 候选特征列表
|
||||
target_col_regression: 回归目标列名
|
||||
date_col: 日期列名
|
||||
|
||||
Returns:
|
||||
筛选后的特征列表
|
||||
"""
|
||||
remaining_features = feature_cols.copy()
|
||||
original_count = len(remaining_features)
|
||||
|
||||
if self.verbose:
|
||||
print("=" * 80)
|
||||
print("增强探针法因子筛选")
|
||||
print("=" * 80)
|
||||
print(f"\n初始特征数: {original_count}")
|
||||
print(f"迭代轮数: {self.n_iterations}")
|
||||
print(f"每轮探针数: {self.n_noise_features}")
|
||||
print(f"验证集比例: {self.validation_ratio:.0%}")
|
||||
|
||||
for iteration in range(1, self.n_iterations + 1):
|
||||
if self.verbose:
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"探针筛选第 {iteration}/{self.n_iterations} 轮")
|
||||
print(f"当前候选特征: {len(remaining_features)} 个")
|
||||
print("=" * 80)
|
||||
|
||||
# 注入探针
|
||||
current_features = remaining_features.copy()
|
||||
feature_matrix = data.select(
|
||||
current_features + [target_col_regression, date_col]
|
||||
)
|
||||
|
||||
# 注入噪音特征
|
||||
seed = self.random_state + iteration # 每轮使用不同种子
|
||||
data_with_noise = self.noise_generator.generate_noise(
|
||||
feature_matrix, self.n_noise_features, seed
|
||||
)
|
||||
|
||||
all_feature_cols = (
|
||||
current_features
|
||||
+ self.noise_generator.get_noise_columns(data_with_noise)
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
print(f"\n[1/4] 注入探针: {self.n_noise_features} 列噪音特征")
|
||||
|
||||
# 多任务训练
|
||||
if self.verbose:
|
||||
print("\n[2/4] 多任务训练(回归 + 分类)...")
|
||||
|
||||
self.trainer.fit(
|
||||
df=data_with_noise,
|
||||
feature_cols=all_feature_cols,
|
||||
target_col_regression=target_col_regression,
|
||||
date_col=date_col,
|
||||
)
|
||||
|
||||
# 获取训练信息
|
||||
train_info = self.trainer.get_training_info()
|
||||
if self.verbose:
|
||||
print(
|
||||
f" 数据切分: 训练集 {train_info.get('train_size')} 条, 验证集 {train_info.get('val_size')} 条"
|
||||
)
|
||||
if "regression_best_iter" in train_info:
|
||||
print(f" 回归模型早停: {train_info['regression_best_iter']} 轮")
|
||||
if "classification_best_iter" in train_info:
|
||||
print(
|
||||
f" 分类模型早停: {train_info['classification_best_iter']} 轮"
|
||||
)
|
||||
|
||||
# 获取特征重要性
|
||||
reg_imp, cls_imp = self.trainer.get_feature_importance(
|
||||
importance_type="gain"
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
print("\n[3/4] 计算及格线...")
|
||||
|
||||
# 评估并淘汰
|
||||
remaining_features = self.evaluator.evaluate(
|
||||
regression_importance=reg_imp,
|
||||
classification_importance=cls_imp,
|
||||
candidate_features=current_features,
|
||||
)
|
||||
|
||||
thresholds = self.evaluator.get_thresholds()
|
||||
if self.verbose:
|
||||
print(f" 回归及格线: {thresholds[0]:.6f}")
|
||||
print(f" 分类及格线: {thresholds[1]:.6f}")
|
||||
|
||||
# 记录本轮结果
|
||||
stats = self.evaluator.get_elimination_stats()
|
||||
eliminated = stats["eliminated_count"]
|
||||
|
||||
if self.verbose:
|
||||
print(f"\n[4/4] 交叉淘汰...")
|
||||
print(f" 淘汰特征: {eliminated} 个")
|
||||
print(f" 剩余特征: {stats['survived_count']} 个")
|
||||
|
||||
if eliminated > 0:
|
||||
print("\n 淘汰的特征:")
|
||||
for feat_info in stats["eliminated_features"][:10]: # 只显示前10个
|
||||
print(
|
||||
f" - {feat_info['feature']}: 回归={feat_info['regression_importance']:.6f}, 分类={feat_info['classification_importance']:.6f}"
|
||||
)
|
||||
if eliminated > 10:
|
||||
print(f" ... 还有 {eliminated - 10} 个")
|
||||
|
||||
# 保存历史
|
||||
self.selection_history.append(
|
||||
{
|
||||
"iteration": iteration,
|
||||
"initial_features": len(current_features),
|
||||
"eliminated": eliminated,
|
||||
"survived": len(remaining_features),
|
||||
"regression_threshold": thresholds[0],
|
||||
"classification_threshold": thresholds[1],
|
||||
"eliminated_features": [
|
||||
f["feature"] for f in stats["eliminated_features"]
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
# 检查终止条件
|
||||
if eliminated == 0:
|
||||
if self.verbose:
|
||||
print(f"\n[提前终止] 第 {iteration} 轮没有因子被淘汰")
|
||||
break
|
||||
|
||||
self.final_features = remaining_features
|
||||
|
||||
if self.verbose:
|
||||
print("\n" + "=" * 80)
|
||||
print("探针筛选完成")
|
||||
print("=" * 80)
|
||||
print(f"\n原始特征数: {original_count}")
|
||||
print(f"最终特征数: {len(remaining_features)}")
|
||||
print(f"淘汰特征数: {original_count - len(remaining_features)}")
|
||||
print(
|
||||
f"淘汰比例: {(original_count - len(remaining_features)) / original_count:.1%}"
|
||||
)
|
||||
print(f"\n最终特征列表:")
|
||||
for i, feat in enumerate(remaining_features, 1):
|
||||
print(f" {i:2d}. {feat}")
|
||||
|
||||
return remaining_features
|
||||
|
||||
def get_selection_history(self) -> List[dict]:
|
||||
"""获取筛选历史
|
||||
|
||||
Returns:
|
||||
每轮筛选的历史记录列表
|
||||
"""
|
||||
return self.selection_history
|
||||
|
||||
def get_importance_report(
|
||||
self,
|
||||
data: pl.DataFrame,
|
||||
feature_cols: List[str],
|
||||
target_col_regression: str,
|
||||
date_col: str = "trade_date",
|
||||
) -> List[dict]:
|
||||
"""获取最后一轮的重要性详细报告
|
||||
|
||||
Args:
|
||||
data: 数据
|
||||
feature_cols: 特征列表
|
||||
target_col_regression: 回归目标
|
||||
date_col: 日期列名
|
||||
|
||||
Returns:
|
||||
特征对比列表
|
||||
"""
|
||||
# 注入探针
|
||||
feature_matrix = data.select(feature_cols + [target_col_regression, date_col])
|
||||
data_with_noise = self.noise_generator.generate_noise(
|
||||
feature_matrix, self.n_noise_features, self.random_state
|
||||
)
|
||||
all_feature_cols = feature_cols + self.noise_generator.get_noise_columns(
|
||||
data_with_noise
|
||||
)
|
||||
|
||||
# 训练
|
||||
self.trainer.fit(
|
||||
df=data_with_noise,
|
||||
feature_cols=all_feature_cols,
|
||||
target_col_regression=target_col_regression,
|
||||
date_col=date_col,
|
||||
)
|
||||
|
||||
# 获取重要性
|
||||
reg_imp, cls_imp = self.trainer.get_feature_importance(importance_type="gain")
|
||||
|
||||
# 执行评估以获取及格线
|
||||
self.evaluator.evaluate(reg_imp, cls_imp, feature_cols)
|
||||
|
||||
# 获取详细对比
|
||||
comparison = self.evaluator.get_feature_comparison(
|
||||
reg_imp, cls_imp, feature_cols
|
||||
)
|
||||
|
||||
return comparison
|
||||
Reference in New Issue
Block a user