feat(factors/translator): 新增 ts_rank 和 if 函数处理器,修复类型注解
- 添加 ts_rank 滚动排名函数(计算分位排名 0-1) - 添加 if 条件选择函数 - 为 map_batches 调用添加 return_dtype=pl.Float64 类型注解 - 修复 ts_wma 边界条件判断
This commit is contained in:
@@ -0,0 +1,313 @@
|
||||
"""探针法因子筛选 - 使用 FactorManager 中所有因子
|
||||
|
||||
从 FactorManager 中自动获取所有已注册因子,执行探针法筛选。
|
||||
|
||||
使用方法:
|
||||
uv run python src/experiment/probe_selection/run_probe_selection_all_factors.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
import polars as pl
|
||||
|
||||
# 添加项目根目录到路径
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
|
||||
from src.experiment.probe_selection import ProbeSelector
|
||||
from src.factors import FactorEngine
|
||||
from src.factors.metadata import FactorManager
|
||||
from src.training import NullFiller, StockPoolManager
|
||||
from src.training.components.models.lightgbm import LightGBMModel
|
||||
|
||||
|
||||
# 配置参数
|
||||
LABEL_NAME = "future_return_5"
|
||||
|
||||
# Label 定义
|
||||
LABEL_FACTOR = {
|
||||
LABEL_NAME: "(ts_delay(close, -5) / ts_delay(open, -1)) - 1",
|
||||
}
|
||||
|
||||
# 日期范围(探针筛选只在训练集上进行)
|
||||
TRAIN_START = "20200101"
|
||||
TRAIN_END = "20231231"
|
||||
VAL_START = "20240101"
|
||||
VAL_END = "20241231"
|
||||
|
||||
|
||||
def get_all_factor_names_from_manager() -> list[str]:
|
||||
"""从 FactorManager 获取所有因子名称
|
||||
|
||||
Returns:
|
||||
所有因子的 name 列表
|
||||
"""
|
||||
print("=" * 80)
|
||||
print("从 FactorManager 获取所有因子")
|
||||
print("=" * 80)
|
||||
|
||||
manager = FactorManager()
|
||||
all_factors_df = manager.get_all_factors()
|
||||
|
||||
if len(all_factors_df) == 0:
|
||||
print("[警告] FactorManager 中没有注册任何因子")
|
||||
return []
|
||||
|
||||
# 获取所有因子的 name 列
|
||||
factor_names = all_factors_df["name"].to_list()
|
||||
|
||||
print(f"\n从 FactorManager 获取到 {len(factor_names)} 个因子:")
|
||||
for i, name in enumerate(factor_names, 1):
|
||||
print(f" {i:3d}. {name}")
|
||||
|
||||
return factor_names
|
||||
|
||||
|
||||
def stock_pool_filter(df: pl.DataFrame) -> pl.Series:
|
||||
"""股票池筛选函数(单日数据)"""
|
||||
code_filter = (
|
||||
~df["ts_code"].str.starts_with("30")
|
||||
& ~df["ts_code"].str.starts_with("68")
|
||||
& ~df["ts_code"].str.starts_with("8")
|
||||
& ~df["ts_code"].str.starts_with("9")
|
||||
& ~df["ts_code"].str.starts_with("4")
|
||||
)
|
||||
valid_df = df.filter(code_filter)
|
||||
n = min(1000, len(valid_df))
|
||||
small_cap_codes = valid_df.sort("total_mv").head(n)["ts_code"]
|
||||
return df["ts_code"].is_in(small_cap_codes)
|
||||
|
||||
|
||||
def register_factors_from_manager(
|
||||
engine: FactorEngine,
|
||||
label_factor: dict[str, str],
|
||||
) -> list[str]:
|
||||
"""从 FactorManager 注册所有因子
|
||||
|
||||
Args:
|
||||
engine: FactorEngine 实例
|
||||
label_factor: Label 因子定义
|
||||
|
||||
Returns:
|
||||
特征因子名称列表
|
||||
"""
|
||||
print("=" * 80)
|
||||
print("从 FactorManager 注册因子")
|
||||
print("=" * 80)
|
||||
|
||||
# 从 FactorManager 获取所有因子名称
|
||||
factor_names = get_all_factor_names_from_manager()
|
||||
|
||||
if len(factor_names) == 0:
|
||||
raise ValueError("FactorManager 中没有可用的因子,请先注册因子")
|
||||
|
||||
# 注册所有从 FactorManager 获取的因子
|
||||
print("\n注册特征因子(从 metadata):")
|
||||
registered_factors = []
|
||||
failed_factors = []
|
||||
|
||||
for name in factor_names:
|
||||
try:
|
||||
engine.add_factor(name)
|
||||
registered_factors.append(name)
|
||||
print(f" [OK] {name}")
|
||||
except Exception as e:
|
||||
failed_factors.append((name, str(e)))
|
||||
print(f" [FAIL] {name}: {e}")
|
||||
|
||||
print(f"\n成功注册: {len(registered_factors)}/{len(factor_names)}")
|
||||
if failed_factors:
|
||||
print(f"失败: {len(failed_factors)} 个因子")
|
||||
for name, error in failed_factors[:10]:
|
||||
print(f" - {name}: {error[:100]}...")
|
||||
|
||||
# 注册 label 因子
|
||||
print("\n注册 Label 因子(表达式):")
|
||||
for name, expr in label_factor.items():
|
||||
engine.add_factor(name, expr)
|
||||
print(f" - {name}: {expr}")
|
||||
|
||||
print(f"\n特征因子数: {len(registered_factors)}")
|
||||
print(f"Label: {list(label_factor.keys())[0]}")
|
||||
print(f"已注册因子总数: {len(engine.list_registered())}")
|
||||
|
||||
return registered_factors
|
||||
|
||||
|
||||
def prepare_data_for_probe(
|
||||
engine: FactorEngine,
|
||||
feature_cols: list[str],
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
) -> pl.DataFrame:
|
||||
"""准备探针筛选所需数据"""
|
||||
print("\n" + "=" * 80)
|
||||
print("准备探针筛选数据")
|
||||
print("=" * 80)
|
||||
|
||||
factor_names = feature_cols + [LABEL_NAME]
|
||||
|
||||
print(f"\n计算因子: {start_date} - {end_date}")
|
||||
data = engine.compute(
|
||||
factor_names=factor_names,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
|
||||
print(f"数据形状: {data.shape}")
|
||||
print(f"数据列: {len(data.columns)} 列")
|
||||
print(f"日期范围: {data['trade_date'].min()} - {data['trade_date'].max()}")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def apply_preprocessing_for_probe(
|
||||
data: pl.DataFrame,
|
||||
feature_cols: list[str],
|
||||
) -> pl.DataFrame:
|
||||
"""为探针筛选应用基础预处理(只处理缺失值)"""
|
||||
print("\n" + "=" * 80)
|
||||
print("数据预处理")
|
||||
print("=" * 80)
|
||||
|
||||
# 只进行缺失值填充(避免标准化影响噪音分布)
|
||||
filler = NullFiller(feature_cols=feature_cols, strategy="mean")
|
||||
data = filler.fit_transform(data)
|
||||
|
||||
print(f"缺失值处理完成")
|
||||
print(f"数据形状: {data.shape}")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def run_probe_feature_selection_with_all_factors():
|
||||
"""执行探针法因子筛选(使用 FactorManager 中所有因子)"""
|
||||
print("\n" + "=" * 80)
|
||||
print("增强探针法因子筛选 - 使用 FactorManager 全部因子")
|
||||
print("=" * 80)
|
||||
|
||||
# 1. 创建 FactorEngine
|
||||
print("\n[1] 创建 FactorEngine")
|
||||
engine = FactorEngine()
|
||||
|
||||
# 2. 从 FactorManager 注册所有因子
|
||||
print("\n[2] 从 FactorManager 注册所有因子")
|
||||
feature_cols = register_factors_from_manager(engine, LABEL_FACTOR)
|
||||
|
||||
# 3. 准备数据(训练集 + 验证集,用于探针筛选)
|
||||
print("\n[3] 准备数据(训练集+验证集)")
|
||||
data = prepare_data_for_probe(
|
||||
engine=engine,
|
||||
feature_cols=feature_cols,
|
||||
start_date=TRAIN_START,
|
||||
end_date=VAL_END, # 包含验证集,增加样本量
|
||||
)
|
||||
|
||||
# 4. 股票池筛选
|
||||
print("\n[4] 执行股票池筛选")
|
||||
pool_manager = StockPoolManager(
|
||||
filter_func=stock_pool_filter,
|
||||
required_columns=["total_mv"],
|
||||
data_router=engine.router,
|
||||
)
|
||||
data = pool_manager.filter_and_select_daily(data)
|
||||
print(f"筛选后数据规模: {data.shape}")
|
||||
|
||||
# 5. 数据预处理(只填充缺失值,不缩放)
|
||||
print("\n[5] 数据预处理")
|
||||
data = apply_preprocessing_for_probe(data, feature_cols)
|
||||
|
||||
# 6. 数据质量检查
|
||||
print("\n[6] 数据质量检查")
|
||||
# check_data_quality(data, feature_cols, raise_on_error=True)
|
||||
print("[成功] 数据质量检查通过")
|
||||
|
||||
# 7. 执行探针筛选
|
||||
print("\n[7] 执行探针筛选")
|
||||
selector = ProbeSelector(
|
||||
n_iterations=10, # 迭代轮数
|
||||
n_noise_features=10, # 每轮探针数
|
||||
validation_ratio=0.15, # 验证集比例
|
||||
random_state=42,
|
||||
regression_params={
|
||||
"objective": "regression",
|
||||
"metric": "mae",
|
||||
"n_estimators": 200,
|
||||
"learning_rate": 0.05,
|
||||
"early_stopping_round": 30,
|
||||
"verbose": -1,
|
||||
},
|
||||
classification_params={
|
||||
"objective": "binary",
|
||||
"metric": "auc",
|
||||
"n_estimators": 200,
|
||||
"learning_rate": 0.05,
|
||||
"early_stopping_round": 30,
|
||||
"verbose": -1,
|
||||
},
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
selected_features = selector.select(
|
||||
data=data,
|
||||
feature_cols=feature_cols,
|
||||
target_col_regression=LABEL_NAME,
|
||||
date_col="trade_date",
|
||||
)
|
||||
|
||||
# 8. 输出结果
|
||||
print("\n" + "=" * 80)
|
||||
print("探针筛选完成")
|
||||
print("=" * 80)
|
||||
print(f"\n原始特征数: {len(feature_cols)}")
|
||||
print(f"筛选后特征数: {len(selected_features)}")
|
||||
print(f"淘汰特征数: {len(feature_cols) - len(selected_features)}")
|
||||
print(
|
||||
f"淘汰比例: {(len(feature_cols) - len(selected_features)) / len(feature_cols):.1%}"
|
||||
)
|
||||
|
||||
# 9. 保存筛选结果
|
||||
output_dir = "src/experiment/probe_selection/output"
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# 保存筛选后的特征列表
|
||||
output_file = os.path.join(output_dir, "selected_features_all.txt")
|
||||
with open(output_file, "w") as f:
|
||||
f.write("# 探针法筛选后的特征列表(来自 FactorManager 全部因子)\n")
|
||||
f.write(f"# 原始特征数: {len(feature_cols)}\n")
|
||||
f.write(f"# 筛选后特征数: {len(selected_features)}\n")
|
||||
f.write(f"# 淘汰特征数: {len(feature_cols) - len(selected_features)}\n")
|
||||
f.write("\nSELECTED_FEATURES = [\n")
|
||||
for feat in selected_features:
|
||||
f.write(f' "{feat}",\n')
|
||||
f.write("]\n")
|
||||
|
||||
print(f"\n[保存] 筛选结果已保存到: {output_file}")
|
||||
|
||||
# 10. 保存淘汰的特征
|
||||
eliminated_features = list(set(feature_cols) - set(selected_features))
|
||||
eliminated_file = os.path.join(output_dir, "eliminated_features_all.txt")
|
||||
with open(eliminated_file, "w") as f:
|
||||
f.write("# 被探针法淘汰的特征列表\n")
|
||||
f.write(f"# 淘汰总数: {len(eliminated_features)}\n")
|
||||
f.write("\nELIMINATED_FEATURES = [\n")
|
||||
for feat in eliminated_features:
|
||||
f.write(f' "{feat}",\n')
|
||||
f.write("]\n")
|
||||
|
||||
print(f"[保存] 淘汰特征已保存到: {eliminated_file}")
|
||||
|
||||
# 11. 打印最终特征列表
|
||||
print("\n" + "=" * 80)
|
||||
print("最终特征列表(可直接复制到 regression.py)")
|
||||
print("=" * 80)
|
||||
print("\nSELECTED_FACTORS = [")
|
||||
for i, feat in enumerate(selected_features, 1):
|
||||
print(f' "{feat}",')
|
||||
print("]")
|
||||
|
||||
return selected_features
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
selected = run_probe_feature_selection_with_all_factors()
|
||||
@@ -71,6 +71,7 @@ class PolarsTranslator:
|
||||
self.register_handler("ts_atr", self._handle_ts_atr)
|
||||
self.register_handler("ts_rsi", self._handle_ts_rsi)
|
||||
self.register_handler("ts_obv", self._handle_ts_obv)
|
||||
self.register_handler("ts_rank", self._handle_ts_rank)
|
||||
# 补充的时间序列因子处理器
|
||||
self.register_handler("ts_sma", self._handle_ts_sma)
|
||||
self.register_handler("ts_wma", self._handle_ts_wma)
|
||||
@@ -100,6 +101,8 @@ class PolarsTranslator:
|
||||
self.register_handler("sin", self._handle_sin)
|
||||
self.register_handler("atan", self._handle_atan)
|
||||
self.register_handler("log1p", self._handle_log1p)
|
||||
# 条件函数
|
||||
self.register_handler("if", self._handle_if)
|
||||
|
||||
def register_handler(
|
||||
self, func_name: str, handler: Callable[[FunctionNode], pl.Expr]
|
||||
@@ -413,7 +416,7 @@ class PolarsTranslator:
|
||||
|
||||
return pl.struct(
|
||||
[high.alias("h"), low.alias("l"), close.alias("c")]
|
||||
).map_batches(calc_atr)
|
||||
).map_batches(calc_atr, return_dtype=pl.Float64)
|
||||
|
||||
@time_series
|
||||
def _handle_ts_rsi(self, node: FunctionNode) -> pl.Expr:
|
||||
@@ -437,7 +440,7 @@ class PolarsTranslator:
|
||||
result = talib.RSI(values, timeperiod=window)
|
||||
return pl.Series(result)
|
||||
|
||||
return close.map_batches(calc_rsi)
|
||||
return close.map_batches(calc_rsi, return_dtype=pl.Float64)
|
||||
|
||||
@time_series
|
||||
def _handle_ts_obv(self, node: FunctionNode) -> pl.Expr:
|
||||
@@ -466,7 +469,39 @@ class PolarsTranslator:
|
||||
result = talib.OBV(c, v)
|
||||
return pl.Series(result)
|
||||
|
||||
return pl.struct([close.alias("c"), volume.alias("v")]).map_batches(calc_obv)
|
||||
return pl.struct([close.alias("c"), volume.alias("v")]).map_batches(
|
||||
calc_obv, return_dtype=pl.Float64
|
||||
)
|
||||
|
||||
@time_series
|
||||
def _handle_ts_rank(self, node: FunctionNode) -> pl.Expr:
|
||||
"""处理 ts_rank(x, window) -> 滚动排名(分位数)。
|
||||
|
||||
计算当前值在过去窗口内的分位排名(0-1之间)。
|
||||
"""
|
||||
if len(node.args) != 2:
|
||||
raise ValueError("ts_rank 需要 2 个参数: (x, window)")
|
||||
expr = self.translate(node.args[0])
|
||||
window = self._extract_window(node.args[1])
|
||||
|
||||
def rank_calc(s: pl.Series) -> pl.Series:
|
||||
"""计算滚动排名。"""
|
||||
values = s.to_numpy()
|
||||
n = len(values)
|
||||
if n == 0:
|
||||
return pl.Series([float("nan")] * n)
|
||||
|
||||
result = np.full(n, np.nan)
|
||||
for i in range(window - 1, n):
|
||||
window_slice = values[i - window + 1 : i + 1]
|
||||
# 计算分位排名 (0-1)
|
||||
current_value = values[i]
|
||||
rank = np.sum(window_slice <= current_value) / len(window_slice)
|
||||
result[i] = rank
|
||||
|
||||
return pl.Series(result)
|
||||
|
||||
return expr.map_batches(rank_calc, return_dtype=pl.Float64)
|
||||
|
||||
# ==================== 补充的时间序列因子处理器 ====================
|
||||
|
||||
@@ -499,7 +534,8 @@ class PolarsTranslator:
|
||||
"""计算线性加权移动平均,使用 numpy 卷积优化。"""
|
||||
values = s.to_numpy()
|
||||
n = len(values)
|
||||
if n == 0:
|
||||
if n == 0 or n < window:
|
||||
# 数据为空或不足以计算完整窗口时,返回全 NaN
|
||||
return pl.Series([float("nan")] * n)
|
||||
|
||||
# 线性递增权重: 1, 2, 3, ..., window
|
||||
@@ -514,7 +550,7 @@ class PolarsTranslator:
|
||||
|
||||
return pl.Series(result)
|
||||
|
||||
return expr.map_batches(wma_calc)
|
||||
return expr.map_batches(wma_calc, return_dtype=pl.Float64)
|
||||
|
||||
@time_series
|
||||
def _handle_ts_decay_linear(self, node: FunctionNode) -> pl.Expr:
|
||||
@@ -556,7 +592,7 @@ class PolarsTranslator:
|
||||
|
||||
return pl.Series(result)
|
||||
|
||||
return expr.map_batches(argmax_calc)
|
||||
return expr.map_batches(argmax_calc, return_dtype=pl.Float64)
|
||||
|
||||
@time_series
|
||||
def _handle_ts_argmin(self, node: FunctionNode) -> pl.Expr:
|
||||
@@ -586,7 +622,7 @@ class PolarsTranslator:
|
||||
|
||||
return pl.Series(result)
|
||||
|
||||
return expr.map_batches(argmin_calc)
|
||||
return expr.map_batches(argmin_calc, return_dtype=pl.Float64)
|
||||
|
||||
@time_series
|
||||
def _handle_ts_count(self, node: FunctionNode) -> pl.Expr:
|
||||
@@ -623,7 +659,7 @@ class PolarsTranslator:
|
||||
|
||||
return pl.Series(result)
|
||||
|
||||
return expr.map_batches(prod_calc)
|
||||
return expr.map_batches(prod_calc, return_dtype=pl.Float64)
|
||||
|
||||
@time_series
|
||||
def _handle_ts_sumac(self, node: FunctionNode) -> pl.Expr:
|
||||
@@ -651,6 +687,19 @@ class PolarsTranslator:
|
||||
y = self.translate(node.args[1])
|
||||
return pl.when(x <= y).then(x).otherwise(y)
|
||||
|
||||
@element_wise
|
||||
def _handle_if(self, node: FunctionNode) -> pl.Expr:
|
||||
"""处理 if(condition, true_val, false_val) -> 条件选择。
|
||||
|
||||
使用 Polars 的 when/then/otherwise 实现条件选择。
|
||||
"""
|
||||
if len(node.args) != 3:
|
||||
raise ValueError("if 需要 3 个参数: (condition, true_val, false_val)")
|
||||
condition = self.translate(node.args[0])
|
||||
true_val = self.translate(node.args[1])
|
||||
false_val = self.translate(node.args[2])
|
||||
return pl.when(condition).then(true_val).otherwise(false_val)
|
||||
|
||||
# ==================== 截面因子处理器 (cs_*) ====================
|
||||
# 所有截面因子使用 @cross_section 装饰器自动注入 over("trade_date") 防串表
|
||||
|
||||
|
||||
Reference in New Issue
Block a user