Files
ProStock/src/experiment/probe_selection/probe_selector.py

285 lines
9.8 KiB
Python
Raw Normal View History

"""探针选择器 - 主类
协调整个探针筛选流程执行迭代特征选择
"""
from typing import List, Optional
import polars as pl
from src.experiment.probe_selection.importance_evaluator import ImportanceEvaluator
from src.experiment.probe_selection.noise_generator import NoiseGenerator
from src.experiment.probe_selection.probe_trainer import ProbeTrainer
class ProbeSelector:
    """Probe selector: orchestrates the "enhanced probe" factor-screening loop.

    Algorithm, one round per iteration:

    1. Inject random noise probe columns next to the candidate features.
    2. Multi-task training (regression + classification) via ``ProbeTrainer``.
    3. Cross-elimination: ``ImportanceEvaluator`` compares every candidate's
       importance against thresholds ("passing lines") derived from the
       noise probes.
    4. Repeat until a round eliminates nothing, no candidates remain, or
       ``n_iterations`` rounds have run.

    Design constraints stated by the original authors. Most of them are
    enforced inside the sub-components rather than in this class:

    - The classification target is the cross-sectional median
      (presumably built inside ``ProbeTrainer`` -- not visible here).
    - Importance is always requested with ``importance_type="gain"``.
    - Training uses a validation split for early stopping
      (``validation_ratio`` is forwarded to ``ProbeTrainer``).
    - Polars zero-copy operations.
    """

    # How many eliminated features the verbose log lists per round.
    _MAX_LISTED_ELIMINATIONS = 10

    def __init__(
        self,
        n_iterations: int = 5,
        n_noise_features: int = 10,
        validation_ratio: float = 0.15,
        random_state: int = 42,
        regression_params: Optional[dict] = None,
        classification_params: Optional[dict] = None,
        verbose: bool = True,
    ):
        """Initialize the probe selector.

        Args:
            n_iterations: Maximum number of screening rounds (K).
            n_noise_features: Number of noise probes injected per round (M).
            validation_ratio: Fraction of data held out for early stopping.
            random_state: Base random seed.
            regression_params: Parameters for the regression model.
            classification_params: Parameters for the classification model.
            verbose: Whether to print a detailed progress log.
        """
        self.n_iterations = n_iterations
        self.n_noise_features = n_noise_features
        self.validation_ratio = validation_ratio
        self.random_state = random_state
        self.verbose = verbose

        # Sub-components.
        self.noise_generator = NoiseGenerator(random_state=random_state)
        self.trainer = ProbeTrainer(
            regression_params=regression_params,
            classification_params=classification_params,
            validation_ratio=validation_ratio,
            random_state=random_state,
        )
        self.evaluator = ImportanceEvaluator(noise_prefix=NoiseGenerator.NOISE_PREFIX)

        # Per-round records of the most recent ``select`` call.
        self.selection_history: List[dict] = []
        # Features kept by the most recent ``select`` call (None before any run).
        self.final_features: Optional[List[str]] = None

    def _inject_probes(
        self,
        data: pl.DataFrame,
        feature_cols: List[str],
        target_col_regression: str,
        date_col: str,
        seed: int,
    ) -> tuple:
        """Project ``data`` onto the needed columns and append noise probes.

        Args:
            data: Source frame.
            feature_cols: Real candidate feature columns.
            target_col_regression: Regression target column name.
            date_col: Date column name.
            seed: Seed handed to ``NoiseGenerator.generate_noise``.

        Returns:
            ``(data_with_noise, all_feature_cols)``. ``all_feature_cols`` is
            ``feature_cols`` followed by the noise column names found in
            ``data_with_noise``.
        """
        feature_matrix = data.select(feature_cols + [target_col_regression, date_col])
        data_with_noise = self.noise_generator.generate_noise(
            feature_matrix, self.n_noise_features, seed
        )
        all_feature_cols = feature_cols + self.noise_generator.get_noise_columns(
            data_with_noise
        )
        return data_with_noise, all_feature_cols

    @staticmethod
    def _print_summary(original_count: int, remaining_features: List[str]) -> None:
        """Print the end-of-run verbose report."""
        eliminated_total = original_count - len(remaining_features)
        # Guard the ratio: an empty candidate list used to raise ZeroDivisionError.
        eliminated_ratio = eliminated_total / original_count if original_count else 0.0
        print("\n" + "=" * 80)
        print("探针筛选完成")
        print("=" * 80)
        print(f"\n原始特征数: {original_count}")
        print(f"最终特征数: {len(remaining_features)}")
        print(f"淘汰特征数: {eliminated_total}")
        print(f"淘汰比例: {eliminated_ratio:.1%}")
        print("\n最终特征列表:")
        for i, feat in enumerate(remaining_features, 1):
            print(f" {i:2d}. {feat}")

    def select(
        self,
        data: pl.DataFrame,
        feature_cols: List[str],
        target_col_regression: str,
        date_col: str = "trade_date",
    ) -> List[str]:
        """Run iterative probe-based feature selection.

        Each call starts from a clean ``selection_history`` and resets
        ``final_features``, so both describe only the current run.

        Args:
            data: Training data.
            feature_cols: Candidate feature columns.
            target_col_regression: Regression target column name.
            date_col: Date column name.

        Returns:
            The features that survived screening. The same list object is
            stored in ``self.final_features``.
        """
        remaining_features = feature_cols.copy()
        original_count = len(remaining_features)

        # Reset per-run state so repeated calls do not mix histories.
        self.selection_history = []
        self.final_features = None

        if self.verbose:
            print("=" * 80)
            print("增强探针法因子筛选")
            print("=" * 80)
            print(f"\n初始特征数: {original_count}")
            print(f"迭代轮数: {self.n_iterations}")
            print(f"每轮探针数: {self.n_noise_features}")
            print(f"验证集比例: {self.validation_ratio:.0%}")

        for iteration in range(1, self.n_iterations + 1):
            # Nothing left to screen: fitting on noise probes alone is wasted work.
            if not remaining_features:
                if self.verbose:
                    print(f"\n[提前终止] 第 {iteration} 轮前已无候选特征")
                break

            if self.verbose:
                print(f"\n{'=' * 80}")
                print(f"探针筛选第 {iteration}/{self.n_iterations}")
                print(f"当前候选特征: {len(remaining_features)}")
                print("=" * 80)

            current_features = remaining_features.copy()

            # Step 1: inject probes. Each round uses a different seed, so the
            # candidates are compared against a fresh set of noise columns.
            data_with_noise, all_feature_cols = self._inject_probes(
                data,
                current_features,
                target_col_regression,
                date_col,
                seed=self.random_state + iteration,
            )
            if self.verbose:
                print(f"\n[1/4] 注入探针: {self.n_noise_features} 列噪音特征")

            # Step 2: multi-task training.
            if self.verbose:
                print("\n[2/4] 多任务训练(回归 + 分类)...")
            self.trainer.fit(
                df=data_with_noise,
                feature_cols=all_feature_cols,
                target_col_regression=target_col_regression,
                date_col=date_col,
            )

            train_info = self.trainer.get_training_info()
            if self.verbose:
                print(
                    f" 数据切分: 训练集 {train_info.get('train_size')} 条, 验证集 {train_info.get('val_size')}"
                )
                if "regression_best_iter" in train_info:
                    print(f" 回归模型早停: {train_info['regression_best_iter']}")
                if "classification_best_iter" in train_info:
                    print(
                        f" 分类模型早停: {train_info['classification_best_iter']}"
                    )

            reg_imp, cls_imp = self.trainer.get_feature_importance(
                importance_type="gain"
            )

            # Step 3: thresholds derived from the noise probes.
            if self.verbose:
                print("\n[3/4] 计算及格线...")
            remaining_features = self.evaluator.evaluate(
                regression_importance=reg_imp,
                classification_importance=cls_imp,
                candidate_features=current_features,
            )
            thresholds = self.evaluator.get_thresholds()
            if self.verbose:
                print(f" 回归及格线: {thresholds[0]:.6f}")
                print(f" 分类及格线: {thresholds[1]:.6f}")

            # Step 4: cross-elimination results.
            stats = self.evaluator.get_elimination_stats()
            eliminated = stats["eliminated_count"]
            if self.verbose:
                print("\n[4/4] 交叉淘汰...")
                print(f" 淘汰特征: {eliminated}")
                print(f" 剩余特征: {stats['survived_count']}")
                if eliminated > 0:
                    limit = self._MAX_LISTED_ELIMINATIONS
                    print("\n 淘汰的特征:")
                    for feat_info in stats["eliminated_features"][:limit]:
                        print(
                            f" - {feat_info['feature']}: 回归={feat_info['regression_importance']:.6f}, 分类={feat_info['classification_importance']:.6f}"
                        )
                    if eliminated > limit:
                        print(f" ... 还有 {eliminated - limit}")

            self.selection_history.append(
                {
                    "iteration": iteration,
                    "initial_features": len(current_features),
                    "eliminated": eliminated,
                    "survived": len(remaining_features),
                    "regression_threshold": thresholds[0],
                    "classification_threshold": thresholds[1],
                    "eliminated_features": [
                        f["feature"] for f in stats["eliminated_features"]
                    ],
                }
            )

            # Converged: this round removed nothing.
            if eliminated == 0:
                if self.verbose:
                    print(f"\n[提前终止] 第 {iteration} 轮没有因子被淘汰")
                break

        self.final_features = remaining_features

        if self.verbose:
            self._print_summary(original_count, remaining_features)

        return remaining_features

    def get_selection_history(self) -> List[dict]:
        """Return the per-round records of the most recent ``select`` call.

        Returns:
            The internal history list (not a copy), one dict per round.
        """
        return self.selection_history

    def get_importance_report(
        self,
        data: pl.DataFrame,
        feature_cols: List[str],
        target_col_regression: str,
        date_col: str = "trade_date",
    ) -> List[dict]:
        """Run one probe round on ``feature_cols`` and return a detailed comparison.

        Unlike ``select``, this uses ``self.random_state`` itself as the noise
        seed and does not touch ``selection_history``. It does refit
        ``self.trainer`` and re-run ``self.evaluator``, so their internal
        state reflects this call afterwards.

        Args:
            data: Source data.
            feature_cols: Features to report on.
            target_col_regression: Regression target column name.
            date_col: Date column name.

        Returns:
            The list produced by ``ImportanceEvaluator.get_feature_comparison``.
        """
        data_with_noise, all_feature_cols = self._inject_probes(
            data,
            feature_cols,
            target_col_regression,
            date_col,
            seed=self.random_state,
        )

        self.trainer.fit(
            df=data_with_noise,
            feature_cols=all_feature_cols,
            target_col_regression=target_col_regression,
            date_col=date_col,
        )

        reg_imp, cls_imp = self.trainer.get_feature_importance(importance_type="gain")

        # Evaluate first so the evaluator computes its thresholds
        # before the comparison is requested.
        self.evaluator.evaluate(reg_imp, cls_imp, feature_cols)

        return self.evaluator.get_feature_comparison(reg_imp, cls_imp, feature_cols)