Files
NewStock/main/factor/select_factor.py

156 lines
5.1 KiB
Python
Raw Normal View History

2025-11-29 00:23:12 +08:00
import pandas as pd
import numpy as np
from scipy.stats import spearmanr
from sklearn.ensemble import RandomForestRegressor
from sklearn.inspection import permutation_importance
def _univariate_ic(X, y):
    """Return the pooled Spearman IC of every column of ``X`` against ``y``.

    Pairs where either value is NaN are omitted (``nan_policy='omit'``).
    A column whose rank correlation is undefined (e.g. a constant column)
    may come back as NaN; the caller treats NaN as failing the threshold.
    """
    if X.shape[1] == 0:
        return pd.Series(dtype=float)
    return X.apply(lambda col: spearmanr(col, y, nan_policy='omit')[0])


def _drop_redundant(X, ic, corr_threshold):
    """Return the columns of ``X`` that survive greedy de-correlation.

    Features are visited in descending ``|ic|``.  A feature is kept only if
    its absolute Spearman correlation with every already-kept feature is
    below ``corr_threshold``.  This guarantees that no two returned features
    are correlated at or above the threshold, and that inside a correlated
    cluster the feature with the strongest IC wins.  The result is returned
    in the original column order of ``X``.
    """
    # An undefined correlation (NaN) is treated as "not redundant".
    corr = X.corr(method='spearman').abs().fillna(0.0)
    # Stable sort so features with equal |IC| keep their input order.
    by_strength = ic.abs().sort_values(ascending=False, kind='mergesort').index
    kept = []
    for feat in by_strength:
        if not kept or corr.loc[feat, kept].max() < corr_threshold:
            kept.append(feat)
    kept_set = set(kept)
    return [feat for feat in X.columns if feat in kept_set]


def _permutation_filter(X, y, threshold, n_perm, random_state):
    """Return the columns of ``X`` whose mean permutation importance exceeds ``threshold``.

    A ``RandomForestRegressor`` is fitted on ``(X, y)``, and
    ``permutation_importance`` is evaluated on the same rows.  The scores
    therefore reflect in-sample fit.  Rows with a missing label are dropped
    first, because scikit-learn estimators reject NaN targets.
    """
    has_label = y.notna()
    X_fit, y_fit = X[has_label], y[has_label]
    model = RandomForestRegressor(
        n_estimators=50,
        max_depth=10,
        random_state=random_state,
        n_jobs=-1,
    )
    model.fit(X_fit, y_fit)
    result = permutation_importance(
        model, X_fit, y_fit,
        n_repeats=n_perm,
        random_state=random_state,
        n_jobs=-1,
    )
    importance = pd.Series(result.importances_mean, index=X.columns)
    return importance[importance > threshold].index.tolist()


def _daily_ic(df_flat, features, label_col, date_col, min_group_size):
    """Return a (date x feature) DataFrame of per-date Spearman ICs, or None.

    Dates with fewer than ``min_group_size`` rows are skipped.  An IC that is
    not finite, or whose computation raises, is recorded as 0.0 so that one
    bad cross-section cannot abort the selection.  Returns None when no date
    has enough rows.
    """
    records = []
    for date, group in df_flat.groupby(date_col):
        if len(group) < min_group_size:
            continue
        row = {date_col: date}
        target = group[label_col]
        for feat in features:
            try:
                ic = spearmanr(group[feat], target, nan_policy='omit')[0]
                row[feat] = ic if np.isfinite(ic) else 0.0
            except Exception:  # best effort, but never swallow KeyboardInterrupt/SystemExit
                row[feat] = 0.0
        records.append(row)
    if not records:
        return None
    return pd.DataFrame(records).set_index(date_col)


def _stable_features(ic_df, ir_threshold, sign_threshold):
    """Return the columns of ``ic_df`` whose per-date IC is stable over time.

    A feature is stable when ``|mean IC / std IC| >= ir_threshold`` and the
    fraction of dates whose IC has the same sign as its mean IC is at least
    ``sign_threshold``.  A zero standard deviation gives an undefined IR,
    which fails the check.
    """
    mean_ic = ic_df.mean()
    ir = mean_ic / ic_df.std().replace(0, np.nan)
    # Agreement is measured against the direction of the mean IC.  A
    # sign-agnostic test such as (p >= t) | (p <= 1 - t) is true for every
    # p when t <= 0.5 (including the default 0.3) and would filter nothing.
    share_pos = (ic_df > 0).mean()
    share_neg = (ic_df < 0).mean()
    agreement = share_pos.where(mean_ic >= 0, share_neg)
    stable = (ir.abs() >= ir_threshold) & (agreement >= sign_threshold)
    return stable[stable].index.tolist()


def select_factors(
    df,
    all_features,
    label_col='label',
    ic_threshold=0.01,
    corr_threshold=0.5,
    ir_threshold=0.3,
    sign_consistency_threshold=0.3,
    perm_imp_threshold=0.0,
    n_perm=5,
    random_state=42,
    verbose=True,  # whether to print a log line for every step
    date_col='datetime',
    min_group_size=10,
):
    """Select factors through a four-stage filter, logging every stage.

    Stages:
      1. Univariate: keep features whose pooled ``|Spearman IC| >= ic_threshold``.
      2. Redundancy: greedily drop any feature whose ``|Spearman corr|`` with a
         stronger-IC kept feature is ``>= corr_threshold``.
      3. Permutation importance of a random forest, keeping importance
         ``> perm_imp_threshold``.  This stage is skipped when
         ``perm_imp_threshold`` is None.
      4. Time-series stability of the per-``date_col`` IC, keeping features
         with ``|IR| >= ir_threshold`` and sign agreement with the mean IC
         ``>= sign_consistency_threshold``.
    If stage 3 or stage 4 would remove every feature, the previous stage's
    result is kept instead.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``all_features`` and ``label_col`` as columns.
        ``date_col`` may be either a column or an index level.
    all_features : iterable of str
        Candidate feature column names.
    label_col : str
        Target column.
    ic_threshold, corr_threshold, ir_threshold, sign_consistency_threshold : float
        Thresholds for stages 1, 2 and 4 (see above).
    perm_imp_threshold : float or None
        Stage 3 threshold; None skips stage 3.
    n_perm : int
        ``n_repeats`` passed to ``permutation_importance``.
    random_state : int
        Seed for the random forest and for the permutations.
    verbose : bool
        Print a progress line for each stage.
    date_col : str
        Column or index level used to group rows by date in stage 4.
    min_group_size : int
        Dates with fewer rows than this are ignored in stage 4.

    Returns
    -------
    tuple(list of str, dict)
        The selected feature names, in input order, and a dict of per-stage
        counts.  ``'final'`` is always present; ``'after_univariate'``,
        ``'after_redundancy'`` and ``'after_permutation'`` are present for
        the stages that ran.

    Raises
    ------
    KeyError
        If a feature, the label, or ``date_col`` is missing from ``df``.
    """
    all_features = list(all_features)
    log = {}  # number of surviving features after each stage
    if verbose:
        print(f"🔍 开始因子筛选 | 初始因子数: {len(all_features)}")

    # --- Step 0: flatten to a plain table (index levels become columns) ---
    needed_cols = all_features + [label_col]
    # Carry the date along when it is a regular column; reset_index() below
    # only materializes it when it is an index level.
    if (date_col in df.columns and date_col not in needed_cols
            and date_col not in df.index.names):
        needed_cols.append(date_col)
    df_flat = df[needed_cols].reset_index()
    if date_col not in df_flat.columns:
        # Fail fast rather than after the expensive model fit in step 3.
        raise KeyError(f"date column {date_col!r} not found in df columns or index")
    y = df_flat[label_col]

    # --- Step 1: univariate IC filter ---
    ic_series = _univariate_ic(df_flat[all_features], y)
    valid_features = ic_series[ic_series.abs() >= ic_threshold].index.tolist()
    log['after_univariate'] = len(valid_features)
    if verbose:
        dropped = len(all_features) - len(valid_features)
        print(f" ✅ 单变量筛选 (|IC| ≥ {ic_threshold}) → 保留 {len(valid_features)} 个 (+{dropped} 被过滤)")
    if not valid_features:
        log['final'] = 0
        return [], log

    # --- Step 2: redundancy removal (never empties a non-empty input) ---
    selected = _drop_redundant(
        df_flat[valid_features], ic_series[valid_features], corr_threshold
    )
    log['after_redundancy'] = len(selected)
    if verbose:
        dropped = len(valid_features) - len(selected)
        print(f" 🔗 去冗余 (corr < {corr_threshold}) → 保留 {len(selected)} 个 (+{dropped} 被过滤)")

    # --- Step 3: permutation importance (optional) ---
    if perm_imp_threshold is None:
        candidates = selected
        if verbose:
            print(" ⏭️ 跳过 Permutation Importance")
    else:
        candidates = _permutation_filter(
            df_flat[selected], y, perm_imp_threshold, n_perm, random_state
        )
        # If everything was filtered out, fall back to the de-duplicated set.
        if not candidates:
            candidates = selected
            if verbose:
                print(" ⚠️ Permutation 全过滤,回退到去冗余结果")
    log['after_permutation'] = len(candidates)
    if verbose and len(candidates) != len(selected):
        dropped = len(selected) - len(candidates)
        print(f" 📊 Permutation Importance (> {perm_imp_threshold}) → 保留 {len(candidates)} 个 (+{dropped} 被过滤)")

    # --- Step 4: time-series stability of the per-date IC ---
    ic_df = _daily_ic(df_flat, candidates, label_col, date_col, min_group_size)
    if ic_df is None:
        log['final'] = len(candidates)
        if verbose:
            print(" ⏳ 无足够时间窗口,跳过稳定性验证")
        return candidates, log
    final_features = _stable_features(ic_df, ir_threshold, sign_consistency_threshold)
    if not final_features:
        final_features = candidates
        if verbose:
            print(" ⚠️ 稳定性全过滤,回退到 Permutation 结果")
    log['final'] = len(final_features)
    if verbose and len(final_features) != len(candidates):
        dropped = len(candidates) - len(final_features)
        print(f" 🕰️ 稳定性验证 (IR ≥ {ir_threshold}, 符号一致性 ≥ {sign_consistency_threshold}) → 保留 {len(final_features)} 个 (+{dropped} 被过滤)")
    if verbose:
        print(f"🎯 最终因子数: {len(final_features)}")
        if len(final_features) <= 5:
            print("💡 提示: 因子过少,建议降低 ic_threshold 或 corr_threshold")
    return final_features, log