test(debug): 添加因子回测一致性问题的调试测试套件
- 分析GTJA_alpha032等因子在不同LOOKBACK_DAYS下的差异来源 - 验证cs_rank嵌套和截面股票数量对结果的影响 - 测试ts_rank NaN处理和除法除零修复
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import polars as pl
|
||||
|
||||
@@ -182,14 +182,14 @@ SELECTED_FACTORS = [
|
||||
"GTJA_alpha110",
|
||||
"GTJA_alpha111",
|
||||
"GTJA_alpha112",
|
||||
"GTJA_alpha113",
|
||||
# "GTJA_alpha113",
|
||||
"GTJA_alpha114",
|
||||
"GTJA_alpha115",
|
||||
"GTJA_alpha117",
|
||||
"GTJA_alpha118",
|
||||
"GTJA_alpha119",
|
||||
"GTJA_alpha120",
|
||||
"GTJA_alpha121",
|
||||
# "GTJA_alpha121",
|
||||
"GTJA_alpha122",
|
||||
"GTJA_alpha123",
|
||||
"GTJA_alpha124",
|
||||
@@ -205,13 +205,13 @@ SELECTED_FACTORS = [
|
||||
"GTJA_alpha134",
|
||||
"GTJA_alpha135",
|
||||
"GTJA_alpha136",
|
||||
"GTJA_alpha138",
|
||||
# "GTJA_alpha138",
|
||||
"GTJA_alpha139",
|
||||
"GTJA_alpha140",
|
||||
# "GTJA_alpha140",
|
||||
"GTJA_alpha141",
|
||||
"GTJA_alpha142",
|
||||
"GTJA_alpha145",
|
||||
"GTJA_alpha146",
|
||||
# "GTJA_alpha146",
|
||||
"GTJA_alpha148",
|
||||
"GTJA_alpha150",
|
||||
"GTJA_alpha151",
|
||||
@@ -253,50 +253,50 @@ SELECTED_FACTORS = [
|
||||
]
|
||||
|
||||
# 因子定义字典(完整因子库,用于存放尚未注册到metadata的因子)
|
||||
FACTOR_DEFINITIONS = {}
|
||||
FACTOR_DEFINITIONS = {"cs_rank_circ_mv": "cs_rank(circ_mv)"}
|
||||
|
||||
# 需要排除的因子列表(这些因子不会被计算和使用)
|
||||
# 用于临时屏蔽效果不好的因子,无需从 SELECTED_FACTORS 中删除
|
||||
EXCLUDED_FACTORS: List[str] = [
|
||||
# "GTJA_alpha005",
|
||||
# "GTJA_alpha028",
|
||||
# "GTJA_alpha023",
|
||||
# "GTJA_alpha002",
|
||||
# "GTJA_alpha010",
|
||||
# "GTJA_alpha011",
|
||||
# "GTJA_alpha044",
|
||||
# "GTJA_alpha036",
|
||||
# "GTJA_alpha027",
|
||||
# "GTJA_alpha109",
|
||||
# "GTJA_alpha104",
|
||||
# "GTJA_alpha103",
|
||||
# "GTJA_alpha085",
|
||||
# "GTJA_alpha111",
|
||||
# "GTJA_alpha092",
|
||||
# "GTJA_alpha067",
|
||||
# "GTJA_alpha060",
|
||||
# "GTJA_alpha062",
|
||||
# "GTJA_alpha063",
|
||||
# "GTJA_alpha079",
|
||||
# "GTJA_alpha073",
|
||||
# "GTJA_alpha087",
|
||||
# "GTJA_alpha117",
|
||||
# "GTJA_alpha113",
|
||||
# "GTJA_alpha138",
|
||||
# "GTJA_alpha121",
|
||||
# "GTJA_alpha124",
|
||||
# "GTJA_alpha133",
|
||||
# "GTJA_alpha131",
|
||||
# "GTJA_alpha118",
|
||||
# "GTJA_alpha164",
|
||||
# "GTJA_alpha162",
|
||||
# "GTJA_alpha157",
|
||||
# "GTJA_alpha171",
|
||||
# "GTJA_alpha177",
|
||||
# "GTJA_alpha180",
|
||||
# "GTJA_alpha188",
|
||||
# "GTJA_alpha191",
|
||||
]
|
||||
# EXCLUDED_FACTORS: List[str] = [
|
||||
# # "GTJA_alpha005",
|
||||
# # "GTJA_alpha028",
|
||||
# # "GTJA_alpha023",
|
||||
# # "GTJA_alpha002",
|
||||
# # "GTJA_alpha010",
|
||||
# # "GTJA_alpha011",
|
||||
# # "GTJA_alpha044",
|
||||
# # "GTJA_alpha036",
|
||||
# # "GTJA_alpha027",
|
||||
# # "GTJA_alpha109",
|
||||
# # "GTJA_alpha104",
|
||||
# # "GTJA_alpha103",
|
||||
# # "GTJA_alpha085",
|
||||
# # "GTJA_alpha111",
|
||||
# # "GTJA_alpha092",
|
||||
# # "GTJA_alpha067",
|
||||
# # "GTJA_alpha060",
|
||||
# # "GTJA_alpha062",
|
||||
# # "GTJA_alpha063",
|
||||
# # "GTJA_alpha079",
|
||||
# # "GTJA_alpha073",
|
||||
# # "GTJA_alpha087",
|
||||
# # "GTJA_alpha117",
|
||||
# # "GTJA_alpha113",
|
||||
# # "GTJA_alpha138",
|
||||
# # "GTJA_alpha121",
|
||||
# # "GTJA_alpha124",
|
||||
# # "GTJA_alpha133",
|
||||
# # "GTJA_alpha131",
|
||||
# # "GTJA_alpha118",
|
||||
# # "GTJA_alpha164",
|
||||
# # "GTJA_alpha162",
|
||||
# # "GTJA_alpha157",
|
||||
# # "GTJA_alpha171",
|
||||
# # "GTJA_alpha177",
|
||||
# # "GTJA_alpha180",
|
||||
# # "GTJA_alpha188",
|
||||
# # "GTJA_alpha191",
|
||||
# ]
|
||||
|
||||
|
||||
def get_label_factor(label_name: str) -> dict:
|
||||
@@ -471,17 +471,18 @@ def stock_pool_filter(df: pl.DataFrame) -> pl.Series:
|
||||
& ~df["ts_code"].str.starts_with("4") # 排除北交所
|
||||
)
|
||||
|
||||
# 在已筛选的股票中,选取市值最小的500只
|
||||
# 在已筛选的股票中,选取流通市值最小的500只
|
||||
valid_df = df.filter(code_filter)
|
||||
n = min(500, len(valid_df))
|
||||
small_cap_codes = valid_df.sort("total_mv").head(n)["ts_code"]
|
||||
n = min(1000, len(valid_df))
|
||||
small_cap_codes = valid_df.sort("circ_mv").head(n)["ts_code"]
|
||||
|
||||
# 返回布尔 Series:是否在被选中的股票中
|
||||
return df["ts_code"].is_in(small_cap_codes)
|
||||
|
||||
|
||||
# 定义筛选所需的基础列
|
||||
STOCK_FILTER_REQUIRED_COLUMNS = ["total_mv"]
|
||||
STOCK_FILTER_REQUIRED_COLUMNS = ["circ_mv"]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 输出配置
|
||||
@@ -490,7 +491,7 @@ OUTPUT_DIR = "output"
|
||||
SAVE_PREDICTIONS = True
|
||||
|
||||
# 模型保存配置
|
||||
SAVE_MODEL = True # 是否保存模型
|
||||
SAVE_MODEL = False # 是否保存模型
|
||||
MODEL_SAVE_DIR = "models" # 模型保存目录
|
||||
|
||||
# Top N 配置:每日推荐股票数量
|
||||
|
||||
@@ -41,7 +41,6 @@ from src.training.config import TrainingConfig
|
||||
from src.experiment.common import (
|
||||
SELECTED_FACTORS,
|
||||
FACTOR_DEFINITIONS,
|
||||
EXCLUDED_FACTORS,
|
||||
get_label_factor,
|
||||
register_factors,
|
||||
prepare_data,
|
||||
@@ -223,26 +222,53 @@ N_QUANTILES = 20 # 将 label 分为 20 组
|
||||
N_QUANTILES = 20 # 将 label 分为 20 组
|
||||
|
||||
# LambdaRank 模型参数配置
|
||||
# MODEL_PARAMS = {
|
||||
# "objective": "lambdarank",
|
||||
# "metric": "ndcg",
|
||||
# "ndcg_at": 15,
|
||||
# "learning_rate": 0.001,
|
||||
# "num_leaves": 32,
|
||||
# "max_depth": 5,
|
||||
# "min_data_in_leaf": 32,
|
||||
# "n_estimators": 1000,
|
||||
# "early_stopping_round": 150,
|
||||
# "subsample": 0.6,
|
||||
# "colsample_bytree": 0.6,
|
||||
# "reg_alpha": 1,
|
||||
# "reg_lambda": 3.0,
|
||||
# "verbose": -1,
|
||||
# "random_state": 42,
|
||||
# "lambdarank_truncation_level": 30,
|
||||
# "label_gain": [
|
||||
# i for i in range(1, N_QUANTILES + 1)
|
||||
# ], # 如果收益率被分为了比如 5 档,建议用[0, 1, 3, 7, 15] 这种指数型 gain
|
||||
# }
|
||||
|
||||
MODEL_PARAMS = {
|
||||
"objective": "lambdarank",
|
||||
"metric": "ndcg",
|
||||
"ndcg_at": 15,
|
||||
"learning_rate": 0.002,
|
||||
"num_leaves": 63,
|
||||
"max_depth": 5,
|
||||
"min_data_in_leaf": 63,
|
||||
"ndcg_at": 25, # 根据你实际持仓数量调整,如果是前50只股票,改成50
|
||||
"learning_rate": 0.1, # 【修改】提高学习率,配合合理的早停
|
||||
"n_estimators": 1000,
|
||||
"early_stopping_round": 150,
|
||||
"subsample": 0.8,
|
||||
"colsample_bytree": 0.8,
|
||||
"reg_alpha": 0.5,
|
||||
"reg_lambda": 1.0,
|
||||
"early_stopping_round": 50, # 【修改】验证集一旦不降,50轮内尽早停下
|
||||
# --- 1. 防止过拟合的核心约束 ---
|
||||
"max_depth": 4, # 【修改】金融数据不需要太深,3~4 足够了
|
||||
"num_leaves": 32, # 【修改】大幅减少叶子数,避免过拟合 (2^4 = 16,取12限制生长)
|
||||
"min_data_in_leaf": 256, # 【修改】极度重要!强制每个叶子必须代表一个大的股票群体(如500只)
|
||||
# --- 2. 随机采样(增加鲁棒性) ---
|
||||
"subsample": 0.4, # 每棵树使用 70% 的样本
|
||||
"subsample_freq": 1, # 每 1 轮进行一次 subsample
|
||||
"colsample_bytree": 0.4, # 【修改】降低特征采样率,迫使模型不要只依赖那几个头部 Alpha 因子,增加树的多样性
|
||||
# --- 3. 正则化惩罚 ---
|
||||
"reg_alpha": 10.0, # 【修改】增加 L1 正则化,帮助剔除无效的 GTJA 噪音因子
|
||||
"reg_lambda": 50.0, # 【修改】增加 L2 正则化
|
||||
# --- 4. Lambdarank 专属配置 ---
|
||||
"lambdarank_truncation_level": 50,
|
||||
"label_gain": [
|
||||
i * i for i in range(1, N_QUANTILES + 1)
|
||||
], # 如果收益率被分为了比如 5 档,建议用[0, 1, 3, 7, 15] 这种指数型 gain
|
||||
"verbose": -1,
|
||||
"random_state": 42,
|
||||
"lambdarank_truncation_level": 30,
|
||||
"label_gain": [
|
||||
i for i in range(1, N_QUANTILES + 1)
|
||||
], # 如果收益率被分为了比如 5 档,建议用[0, 1, 3, 7, 15] 这种指数型 gain
|
||||
}
|
||||
|
||||
# 注意:stock_pool_filter, STOCK_FILTER_REQUIRED_COLUMNS, OUTPUT_DIR 等配置
|
||||
@@ -258,6 +284,163 @@ print("=" * 80)
|
||||
print("\n[1] 创建 FactorEngine")
|
||||
engine = FactorEngine()
|
||||
|
||||
EXCLUDED_FACTORS = ['volatility_5',
|
||||
'volume_ratio_5_20',
|
||||
'capital_retention_20',
|
||||
'volatility_squeeze_5_60',
|
||||
'drawdown_from_high_60',
|
||||
'ma_ratio_5_20',
|
||||
'bias_10',
|
||||
'high_low_ratio',
|
||||
'bbi_ratio',
|
||||
'volatility_20',
|
||||
'std_return_20',
|
||||
'sharpe_ratio_20',
|
||||
'ma_5',
|
||||
'max_ret_20',
|
||||
'CP',
|
||||
'net_profit_yoy',
|
||||
'debt_to_equity',
|
||||
'EP_rank',
|
||||
'turnover_rank',
|
||||
'return_5_rank',
|
||||
'ebit_rank',
|
||||
'BP',
|
||||
'EP',
|
||||
'amihud_illiq_20',
|
||||
'profit_margin',
|
||||
'return_5',
|
||||
'return_20',
|
||||
'kaufman_ER_20',
|
||||
'GTJA_alpha043',
|
||||
'GTJA_alpha042',
|
||||
'GTJA_alpha041',
|
||||
'GTJA_alpha040',
|
||||
'GTJA_alpha039',
|
||||
'GTJA_alpha037',
|
||||
'GTJA_alpha036',
|
||||
'GTJA_alpha035',
|
||||
'GTJA_alpha033',
|
||||
'GTJA_alpha032',
|
||||
'GTJA_alpha031',
|
||||
'GTJA_alpha028',
|
||||
'GTJA_alpha026',
|
||||
'GTJA_alpha027',
|
||||
'GTJA_alpha023',
|
||||
'GTJA_alpha024',
|
||||
'GTJA_alpha009',
|
||||
'GTJA_alpha011',
|
||||
'GTJA_alpha022',
|
||||
'GTJA_alpha020',
|
||||
'GTJA_alpha018',
|
||||
'GTJA_alpha019',
|
||||
'GTJA_alpha014',
|
||||
'GTJA_alpha013',
|
||||
'GTJA_alpha010',
|
||||
'GTJA_alpha001',
|
||||
'GTJA_alpha003',
|
||||
'GTJA_alpha002',
|
||||
'GTJA_alpha004',
|
||||
'GTJA_alpha005',
|
||||
'GTJA_alpha006',
|
||||
'GTJA_alpha008',
|
||||
'turnover_deviation',
|
||||
'turnover_cv_20',
|
||||
'roa',
|
||||
'GTJA_alpha073',
|
||||
'GTJA_alpha078',
|
||||
'GTJA_alpha077',
|
||||
'GTJA_alpha076',
|
||||
'GTJA_alpha067',
|
||||
'GTJA_alpha085',
|
||||
'GTJA_alpha084',
|
||||
'GTJA_alpha087',
|
||||
'GTJA_alpha088',
|
||||
'GTJA_alpha090',
|
||||
'GTJA_alpha083',
|
||||
'GTJA_alpha079',
|
||||
'GTJA_alpha080',
|
||||
'GTJA_alpha094',
|
||||
'GTJA_alpha092',
|
||||
'GTJA_alpha089',
|
||||
'GTJA_alpha095',
|
||||
'GTJA_alpha064',
|
||||
'GTJA_alpha065',
|
||||
'GTJA_alpha066',
|
||||
'GTJA_alpha063',
|
||||
'GTJA_alpha060',
|
||||
'GTJA_alpha058',
|
||||
'GTJA_alpha057',
|
||||
'GTJA_alpha056',
|
||||
'GTJA_alpha046',
|
||||
'GTJA_alpha044',
|
||||
'GTJA_alpha049',
|
||||
'GTJA_alpha050',
|
||||
'GTJA_alpha110',
|
||||
'GTJA_alpha107',
|
||||
'GTJA_alpha104',
|
||||
'GTJA_alpha106',
|
||||
'GTJA_alpha103',
|
||||
'GTJA_alpha100',
|
||||
'GTJA_alpha101',
|
||||
'GTJA_alpha102',
|
||||
'GTJA_alpha098',
|
||||
'GTJA_alpha097',
|
||||
'GTJA_alpha096',
|
||||
'GTJA_alpha099',
|
||||
'GTJA_alpha117',
|
||||
'GTJA_alpha118',
|
||||
'GTJA_alpha114',
|
||||
'GTJA_alpha111',
|
||||
'GTJA_alpha129',
|
||||
'GTJA_alpha130',
|
||||
'GTJA_alpha132',
|
||||
'GTJA_alpha131',
|
||||
'GTJA_alpha134',
|
||||
'GTJA_alpha135',
|
||||
'GTJA_alpha136',
|
||||
'GTJA_alpha112',
|
||||
'GTJA_alpha120',
|
||||
'GTJA_alpha119',
|
||||
'GTJA_alpha122',
|
||||
'GTJA_alpha124',
|
||||
'GTJA_alpha126',
|
||||
'GTJA_alpha127',
|
||||
'GTJA_alpha128',
|
||||
'GTJA_alpha115',
|
||||
'GTJA_alpha153',
|
||||
'GTJA_alpha152',
|
||||
'GTJA_alpha151',
|
||||
'GTJA_alpha150',
|
||||
'GTJA_alpha148',
|
||||
'GTJA_alpha142',
|
||||
'GTJA_alpha141',
|
||||
'GTJA_alpha139',
|
||||
'GTJA_alpha133',
|
||||
'GTJA_alpha161',
|
||||
'GTJA_alpha164',
|
||||
'GTJA_alpha162',
|
||||
'GTJA_alpha157',
|
||||
'GTJA_alpha156',
|
||||
'GTJA_alpha160',
|
||||
'GTJA_alpha155',
|
||||
'GTJA_alpha170',
|
||||
'GTJA_alpha169',
|
||||
'GTJA_alpha168',
|
||||
'GTJA_alpha166',
|
||||
'GTJA_alpha163',
|
||||
'GTJA_alpha176',
|
||||
'GTJA_alpha175',
|
||||
'GTJA_alpha174',
|
||||
'GTJA_alpha178',
|
||||
'GTJA_alpha177',
|
||||
'GTJA_alpha185',
|
||||
'GTJA_alpha180',
|
||||
'GTJA_alpha187',
|
||||
'GTJA_alpha188',
|
||||
'GTJA_alpha189',
|
||||
'GTJA_alpha191',]
|
||||
|
||||
# 2. 使用 metadata 定义因子
|
||||
print("\n[2] 定义因子(从 metadata 注册)")
|
||||
feature_cols = register_factors(
|
||||
@@ -538,14 +721,64 @@ print("-" * 40)
|
||||
for metric, value in ndcg_results.items():
|
||||
print(f" {metric}: {value:.4f}")
|
||||
|
||||
# 特征重要性
|
||||
print("\n特征重要性(Top 20):")
|
||||
print("-" * 40)
|
||||
importance = model.feature_importance()
|
||||
if importance is not None:
|
||||
top_features = importance.sort_values(ascending=False).head(20)
|
||||
for i, (feature, score) in enumerate(top_features.items(), 1):
|
||||
print(f" {i:2d}. {feature:30s} {score:10.2f}")
|
||||
# 特征重要性
|
||||
print("\n特征重要性分析:")
|
||||
print("=" * 80)
|
||||
importance = model.feature_importance()
|
||||
if importance is not None:
|
||||
# 按重要性降序排列
|
||||
importance_sorted = importance.sort_values(ascending=False)
|
||||
|
||||
# 计算总重要性和百分比
|
||||
total_importance = importance_sorted.sum()
|
||||
importance_pct = (importance_sorted / total_importance * 100).round(2)
|
||||
|
||||
# 找出贡献为0的特征
|
||||
zero_importance_features = importance_sorted[
|
||||
importance_sorted == 0
|
||||
].index.tolist()
|
||||
|
||||
# 打印所有特征重要性(带百分比)
|
||||
print(f"\n所有特征重要性(共 {len(importance_sorted)} 个):")
|
||||
print("-" * 80)
|
||||
print(f"{'排名':<6}{'特征名':<35}{'重要性':<15}{'占比':<10}")
|
||||
print("-" * 80)
|
||||
for i, (feature, score) in enumerate(importance_sorted.items(), 1):
|
||||
pct = importance_pct[feature]
|
||||
if score == 0:
|
||||
marker = " [零贡献]"
|
||||
elif pct >= 10:
|
||||
marker = " [高贡献]"
|
||||
elif pct >= 1:
|
||||
marker = " [中贡献]"
|
||||
else:
|
||||
marker = " [低贡献]"
|
||||
print(f"{i:<6}{feature:<35}{score:<15.2f}{pct:<8.2f}%{marker}")
|
||||
|
||||
# 打印贡献为0的特征
|
||||
print("\n" + "=" * 80)
|
||||
if zero_importance_features:
|
||||
print(f"[警告] 贡献为0的特征(共 {len(zero_importance_features)} 个):")
|
||||
print("-" * 80)
|
||||
for i, feature in enumerate(zero_importance_features, 1):
|
||||
print(f"'{feature}',")
|
||||
else:
|
||||
print("[信息] 所有特征都有贡献(无零贡献特征)")
|
||||
|
||||
# 打印统计摘要
|
||||
print("\n" + "=" * 80)
|
||||
print("特征重要性统计摘要:")
|
||||
print("-" * 80)
|
||||
print(f" 特征总数: {len(importance_sorted)}")
|
||||
print(
|
||||
f" 有贡献特征数: {len(importance_sorted) - len(zero_importance_features)}"
|
||||
)
|
||||
print(f" 零贡献特征数: {len(zero_importance_features)}")
|
||||
print(
|
||||
f" 零贡献占比: {len(zero_importance_features) / len(importance_sorted) * 100:.1f}%"
|
||||
)
|
||||
print(f" Top 10特征累计占比: {importance_pct.head(10).sum():.1f}%")
|
||||
print(f" Top 20特征累计占比: {importance_pct.head(20).sum():.1f}%")
|
||||
# %%
|
||||
# 确保输出目录存在
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
@@ -25,7 +25,6 @@ from src.training.config import TrainingConfig
|
||||
from src.experiment.common import (
|
||||
SELECTED_FACTORS,
|
||||
FACTOR_DEFINITIONS,
|
||||
EXCLUDED_FACTORS,
|
||||
get_label_factor,
|
||||
register_factors,
|
||||
prepare_data,
|
||||
@@ -97,6 +96,31 @@ print("=" * 80)
|
||||
print("\n[1] 创建 FactorEngine")
|
||||
engine = FactorEngine()
|
||||
|
||||
EXCLUDED_FACTORS = [
|
||||
'GTJA_alpha010',
|
||||
'GTJA_alpha005',
|
||||
'GTJA_alpha036',
|
||||
'GTJA_alpha027',
|
||||
'GTJA_alpha053',
|
||||
'GTJA_alpha073',
|
||||
'GTJA_alpha104',
|
||||
'GTJA_alpha103',
|
||||
'GTJA_alpha087',
|
||||
'GTJA_alpha092',
|
||||
'GTJA_alpha085',
|
||||
'GTJA_alpha044',
|
||||
'GTJA_alpha062',
|
||||
'GTJA_alpha124',
|
||||
'GTJA_alpha133',
|
||||
'GTJA_alpha131',
|
||||
'GTJA_alpha117',
|
||||
'GTJA_alpha157',
|
||||
'GTJA_alpha162',
|
||||
'GTJA_alpha177',
|
||||
'GTJA_alpha180',
|
||||
'GTJA_alpha191',
|
||||
]
|
||||
|
||||
# 2. 使用 metadata 定义因子
|
||||
print("\n[2] 定义因子(从 metadata 注册)")
|
||||
feature_cols = register_factors(
|
||||
|
||||
@@ -184,7 +184,8 @@ class PolarsTranslator:
|
||||
"+": lambda l, r: l + r,
|
||||
"-": lambda l, r: l - r,
|
||||
"*": lambda l, r: l * r,
|
||||
"/": lambda l, r: l / r,
|
||||
# 【修复】除法处理除零,避免产生 NaN/inf 导致 EMA 永久感染
|
||||
"/": lambda l, r: pl.when(r == 0).then(None).otherwise(l / r),
|
||||
"**": lambda l, r: l.pow(r),
|
||||
"//": lambda l, r: l.floor_div(r),
|
||||
"%": lambda l, r: l % r,
|
||||
@@ -363,6 +364,7 @@ class PolarsTranslator:
|
||||
# 抛弃极慢的 rolling_map,借用 pandas 的 Cython 引擎
|
||||
def kurt_calc(s: pl.Series) -> pl.Series:
|
||||
import pandas as pd
|
||||
|
||||
# pandas.rolling.kurt() 是用 Cython 编写的,速度比 pure python 快很多
|
||||
pd_series = pd.Series(s.to_numpy())
|
||||
result = pd_series.rolling(window).kurt().to_numpy()
|
||||
@@ -499,11 +501,21 @@ class PolarsTranslator:
|
||||
# 当前值即为每个窗口的最后一个元素 (N - window + 1, )
|
||||
current_vals = windows[:, -1]
|
||||
|
||||
# 向量化广播比较,然后沿窗口轴(axis=1)求和,直接得出排名比例
|
||||
ranks = np.sum(windows <= current_vals[:, None], axis=1) / window
|
||||
# 【终极修复】使用窗口内实际有效数据个数作为分母
|
||||
# 1. 统计小于等于当前值的个数
|
||||
less_equal = np.sum(windows <= current_vals[:, None], axis=1)
|
||||
# 2. 统计当前窗口内有效的非 NaN 数据个数
|
||||
valid_counts = np.sum(~np.isnan(windows), axis=1)
|
||||
|
||||
# 3. 使用真实有效个数作为分母,避免分母陷阱
|
||||
with np.errstate(divide="ignore", invalid="ignore"):
|
||||
ranks = np.where(valid_counts > 0, less_equal / valid_counts, np.nan)
|
||||
|
||||
# 【修复】如果当前值是 NaN,则排名也必须是 NaN
|
||||
ranks[np.isnan(current_vals)] = np.nan
|
||||
|
||||
result = np.full(n, np.nan)
|
||||
result[window - 1:] = ranks
|
||||
result[window - 1 :] = ranks
|
||||
return pl.Series(result)
|
||||
|
||||
return expr.map_batches(rank_calc, return_dtype=pl.Float64)
|
||||
@@ -592,7 +604,7 @@ class PolarsTranslator:
|
||||
distances = window - 1 - argmax_indices
|
||||
|
||||
result = np.full(n, np.nan)
|
||||
result[window - 1:] = distances
|
||||
result[window - 1 :] = distances
|
||||
return pl.Series(result)
|
||||
|
||||
return expr.map_batches(argmax_calc, return_dtype=pl.Float64)
|
||||
@@ -616,7 +628,7 @@ class PolarsTranslator:
|
||||
distances = window - 1 - argmin_indices
|
||||
|
||||
result = np.full(n, np.nan)
|
||||
result[window - 1:] = distances
|
||||
result[window - 1 :] = distances
|
||||
return pl.Series(result)
|
||||
|
||||
return expr.map_batches(argmin_calc, return_dtype=pl.Float64)
|
||||
@@ -650,7 +662,7 @@ class PolarsTranslator:
|
||||
prods = np.prod(windows, axis=1)
|
||||
|
||||
result = np.full(n, np.nan)
|
||||
result[window - 1:] = prods
|
||||
result[window - 1 :] = prods
|
||||
return pl.Series(result)
|
||||
|
||||
return expr.map_batches(prod_calc, return_dtype=pl.Float64)
|
||||
|
||||
Reference in New Issue
Block a user