feat(factors): 添加 GTJA alpha 因子并优化计算性能
- 新增 190+ 个 GTJA alpha 因子到因子列表 - 优化 ts_kurt、ts_rank、ts_argmax/min、ts_prod 计算性能 - 性能分析器新增超时检测(180秒)和实时打印功能 - 简化探针因子选择脚本的 main 函数入口
This commit is contained in:
@@ -28,7 +28,6 @@ TEST_END = "20261231"
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
# 当前选择的因子列表(从 FACTOR_DEFINITIONS 中选择要使用的因子)
|
# 当前选择的因子列表(从 FACTOR_DEFINITIONS 中选择要使用的因子)
|
||||||
SELECTED_FACTORS = [
|
SELECTED_FACTORS = [
|
||||||
# ================= 1. 价格、趋势与路径依赖 =================
|
|
||||||
"ma_5",
|
"ma_5",
|
||||||
"ma_20",
|
"ma_20",
|
||||||
"ma_ratio_5_20",
|
"ma_ratio_5_20",
|
||||||
@@ -41,7 +40,6 @@ SELECTED_FACTORS = [
|
|||||||
"mom_acceleration_10_20",
|
"mom_acceleration_10_20",
|
||||||
"drawdown_from_high_60",
|
"drawdown_from_high_60",
|
||||||
"up_days_ratio_20",
|
"up_days_ratio_20",
|
||||||
# ================= 2. 波动率、风险调整与高阶矩 =================
|
|
||||||
"volatility_5",
|
"volatility_5",
|
||||||
"volatility_20",
|
"volatility_20",
|
||||||
"volatility_ratio",
|
"volatility_ratio",
|
||||||
@@ -49,12 +47,10 @@ SELECTED_FACTORS = [
|
|||||||
"sharpe_ratio_20",
|
"sharpe_ratio_20",
|
||||||
"min_ret_20",
|
"min_ret_20",
|
||||||
"volatility_squeeze_5_60",
|
"volatility_squeeze_5_60",
|
||||||
# ================= 3. 日内微观结构与异象 =================
|
|
||||||
"overnight_intraday_diff",
|
"overnight_intraday_diff",
|
||||||
"upper_shadow_ratio",
|
"upper_shadow_ratio",
|
||||||
"capital_retention_20",
|
"capital_retention_20",
|
||||||
"max_ret_20",
|
"max_ret_20",
|
||||||
# ================= 4. 量能、流动性与量价背离 =================
|
|
||||||
"volume_ratio_5_20",
|
"volume_ratio_5_20",
|
||||||
"turnover_rate_mean_5",
|
"turnover_rate_mean_5",
|
||||||
"turnover_deviation",
|
"turnover_deviation",
|
||||||
@@ -62,7 +58,6 @@ SELECTED_FACTORS = [
|
|||||||
"turnover_cv_20",
|
"turnover_cv_20",
|
||||||
"pv_corr_20",
|
"pv_corr_20",
|
||||||
"close_vwap_deviation",
|
"close_vwap_deviation",
|
||||||
# ================= 5. 基本面财务特征 =================
|
|
||||||
"roe",
|
"roe",
|
||||||
"roa",
|
"roa",
|
||||||
"profit_margin",
|
"profit_margin",
|
||||||
@@ -71,7 +66,6 @@ SELECTED_FACTORS = [
|
|||||||
"net_profit_yoy",
|
"net_profit_yoy",
|
||||||
"revenue_yoy",
|
"revenue_yoy",
|
||||||
"healthy_expansion_velocity",
|
"healthy_expansion_velocity",
|
||||||
# ================= 6. 基本面估值与截面动量共振 =================
|
|
||||||
"EP",
|
"EP",
|
||||||
"BP",
|
"BP",
|
||||||
"CP",
|
"CP",
|
||||||
@@ -83,11 +77,185 @@ SELECTED_FACTORS = [
|
|||||||
"value_price_divergence",
|
"value_price_divergence",
|
||||||
"active_market_cap",
|
"active_market_cap",
|
||||||
"ebit_rank",
|
"ebit_rank",
|
||||||
|
"GTJA_alpha001",
|
||||||
|
"GTJA_alpha002",
|
||||||
|
"GTJA_alpha003",
|
||||||
|
"GTJA_alpha004",
|
||||||
|
"GTJA_alpha005",
|
||||||
|
"GTJA_alpha006",
|
||||||
|
"GTJA_alpha007",
|
||||||
|
"GTJA_alpha008",
|
||||||
|
"GTJA_alpha009",
|
||||||
|
"GTJA_alpha010",
|
||||||
|
"GTJA_alpha011",
|
||||||
|
"GTJA_alpha012",
|
||||||
|
"GTJA_alpha013",
|
||||||
|
"GTJA_alpha014",
|
||||||
|
"GTJA_alpha015",
|
||||||
|
"GTJA_alpha016",
|
||||||
|
"GTJA_alpha017",
|
||||||
|
"GTJA_alpha018",
|
||||||
|
"GTJA_alpha019",
|
||||||
|
"GTJA_alpha020",
|
||||||
|
"GTJA_alpha022",
|
||||||
|
"GTJA_alpha023",
|
||||||
|
"GTJA_alpha024",
|
||||||
|
"GTJA_alpha025",
|
||||||
|
"GTJA_alpha026",
|
||||||
|
"GTJA_alpha027",
|
||||||
|
"GTJA_alpha028",
|
||||||
|
"GTJA_alpha029",
|
||||||
|
"GTJA_alpha031",
|
||||||
|
"GTJA_alpha032",
|
||||||
|
"GTJA_alpha033",
|
||||||
|
"GTJA_alpha034",
|
||||||
|
"GTJA_alpha035",
|
||||||
|
"GTJA_alpha036",
|
||||||
|
"GTJA_alpha037",
|
||||||
|
# "GTJA_alpha038",
|
||||||
|
"GTJA_alpha039",
|
||||||
|
"GTJA_alpha040",
|
||||||
|
"GTJA_alpha041",
|
||||||
|
"GTJA_alpha042",
|
||||||
|
"GTJA_alpha043",
|
||||||
|
"GTJA_alpha044",
|
||||||
|
"GTJA_alpha045",
|
||||||
|
"GTJA_alpha046",
|
||||||
|
"GTJA_alpha047",
|
||||||
|
"GTJA_alpha048",
|
||||||
|
"GTJA_alpha049",
|
||||||
|
"GTJA_alpha050",
|
||||||
|
"GTJA_alpha051",
|
||||||
|
"GTJA_alpha052",
|
||||||
|
"GTJA_alpha053",
|
||||||
|
"GTJA_alpha054",
|
||||||
|
"GTJA_alpha056",
|
||||||
|
"GTJA_alpha057",
|
||||||
|
"GTJA_alpha058",
|
||||||
|
"GTJA_alpha059",
|
||||||
|
"GTJA_alpha060",
|
||||||
|
"GTJA_alpha061",
|
||||||
|
"GTJA_alpha062",
|
||||||
|
"GTJA_alpha063",
|
||||||
|
"GTJA_alpha064",
|
||||||
|
"GTJA_alpha065",
|
||||||
|
"GTJA_alpha066",
|
||||||
|
"GTJA_alpha067",
|
||||||
|
"GTJA_alpha068",
|
||||||
|
"GTJA_alpha070",
|
||||||
|
"GTJA_alpha071",
|
||||||
|
"GTJA_alpha072",
|
||||||
|
"GTJA_alpha073",
|
||||||
|
"GTJA_alpha074",
|
||||||
|
"GTJA_alpha076",
|
||||||
|
"GTJA_alpha077",
|
||||||
|
"GTJA_alpha078",
|
||||||
|
"GTJA_alpha079",
|
||||||
|
"GTJA_alpha080",
|
||||||
|
"GTJA_alpha081",
|
||||||
|
"GTJA_alpha082",
|
||||||
|
"GTJA_alpha083",
|
||||||
|
"GTJA_alpha084",
|
||||||
|
"GTJA_alpha085",
|
||||||
|
"GTJA_alpha086",
|
||||||
|
"GTJA_alpha087",
|
||||||
|
"GTJA_alpha088",
|
||||||
|
"GTJA_alpha089",
|
||||||
|
"GTJA_alpha090",
|
||||||
|
"GTJA_alpha091",
|
||||||
|
"GTJA_alpha092",
|
||||||
|
"GTJA_alpha093",
|
||||||
|
"GTJA_alpha094",
|
||||||
|
"GTJA_alpha095",
|
||||||
|
"GTJA_alpha096",
|
||||||
|
"GTJA_alpha097",
|
||||||
|
"GTJA_alpha098",
|
||||||
|
"GTJA_alpha099",
|
||||||
|
"GTJA_alpha100",
|
||||||
|
"GTJA_alpha101",
|
||||||
|
"GTJA_alpha102",
|
||||||
|
"GTJA_alpha103",
|
||||||
|
"GTJA_alpha104",
|
||||||
|
"GTJA_alpha105",
|
||||||
|
"GTJA_alpha106",
|
||||||
|
"GTJA_alpha107",
|
||||||
|
"GTJA_alpha108",
|
||||||
|
"GTJA_alpha109",
|
||||||
|
"GTJA_alpha110",
|
||||||
|
"GTJA_alpha111",
|
||||||
|
"GTJA_alpha112",
|
||||||
|
"GTJA_alpha113",
|
||||||
|
"GTJA_alpha114",
|
||||||
|
"GTJA_alpha115",
|
||||||
|
"GTJA_alpha117",
|
||||||
|
"GTJA_alpha118",
|
||||||
|
"GTJA_alpha119",
|
||||||
|
"GTJA_alpha120",
|
||||||
|
"GTJA_alpha121",
|
||||||
|
"GTJA_alpha122",
|
||||||
|
"GTJA_alpha123",
|
||||||
|
"GTJA_alpha124",
|
||||||
|
"GTJA_alpha125",
|
||||||
|
"GTJA_alpha126",
|
||||||
|
"GTJA_alpha127",
|
||||||
|
"GTJA_alpha128",
|
||||||
|
"GTJA_alpha129",
|
||||||
|
"GTJA_alpha130",
|
||||||
|
"GTJA_alpha131",
|
||||||
|
"GTJA_alpha132",
|
||||||
|
"GTJA_alpha133",
|
||||||
|
"GTJA_alpha134",
|
||||||
|
"GTJA_alpha135",
|
||||||
|
"GTJA_alpha136",
|
||||||
|
"GTJA_alpha138",
|
||||||
|
"GTJA_alpha139",
|
||||||
|
"GTJA_alpha140",
|
||||||
|
"GTJA_alpha141",
|
||||||
|
"GTJA_alpha142",
|
||||||
|
"GTJA_alpha145",
|
||||||
|
"GTJA_alpha146",
|
||||||
|
"GTJA_alpha148",
|
||||||
|
"GTJA_alpha150",
|
||||||
|
"GTJA_alpha151",
|
||||||
|
"GTJA_alpha152",
|
||||||
|
"GTJA_alpha153",
|
||||||
|
"GTJA_alpha154",
|
||||||
|
"GTJA_alpha155",
|
||||||
|
"GTJA_alpha156",
|
||||||
|
"GTJA_alpha157",
|
||||||
|
"GTJA_alpha158",
|
||||||
|
"GTJA_alpha159",
|
||||||
|
"GTJA_alpha160",
|
||||||
|
"GTJA_alpha161",
|
||||||
|
"GTJA_alpha162",
|
||||||
|
"GTJA_alpha163",
|
||||||
|
"GTJA_alpha164",
|
||||||
|
"GTJA_alpha165",
|
||||||
|
"GTJA_alpha166",
|
||||||
|
"GTJA_alpha167",
|
||||||
|
"GTJA_alpha168",
|
||||||
|
"GTJA_alpha169",
|
||||||
|
"GTJA_alpha170",
|
||||||
|
"GTJA_alpha171",
|
||||||
|
"GTJA_alpha173",
|
||||||
|
"GTJA_alpha174",
|
||||||
|
"GTJA_alpha175",
|
||||||
|
"GTJA_alpha176",
|
||||||
|
"GTJA_alpha177",
|
||||||
|
"GTJA_alpha178",
|
||||||
|
"GTJA_alpha179",
|
||||||
|
"GTJA_alpha180",
|
||||||
|
"GTJA_alpha183",
|
||||||
|
"GTJA_alpha184",
|
||||||
|
"GTJA_alpha185",
|
||||||
|
"GTJA_alpha187",
|
||||||
|
"GTJA_alpha188",
|
||||||
|
"GTJA_alpha189",
|
||||||
|
"GTJA_alpha191",
|
||||||
]
|
]
|
||||||
|
|
||||||
# 因子定义字典(完整因子库,用于存放尚未注册到metadata的因子)
|
# 因子定义字典(完整因子库,用于存放尚未注册到metadata的因子)
|
||||||
FACTOR_DEFINITIONS = {
|
FACTOR_DEFINITIONS = {
|
||||||
'test': '[([(col("close")) - (col("close").shift([dyn int: 5]).over([col("ts_code")]))]) / (col("close").shift([dyn int: 5]).over([col("ts_code")]))]'
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -328,20 +328,7 @@ def run_probe_feature_selection_with_all_factors(debug: bool = True):
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""主入口函数,支持命令行参数"""
|
selected = run_probe_feature_selection_with_all_factors(debug=True)
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="探针法因子筛选 - 使用 FactorManager 中所有因子"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--debug",
|
|
||||||
"-d",
|
|
||||||
action="store_true",
|
|
||||||
help="启用 debug 模式,显示详细的性能统计信息",
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
selected = run_probe_feature_selection_with_all_factors(debug=args.debug)
|
|
||||||
return selected
|
return selected
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -91,8 +91,11 @@ class FactorEngine:
|
|||||||
# 调试模式配置
|
# 调试模式配置
|
||||||
self._debug = debug
|
self._debug = debug
|
||||||
|
|
||||||
# 初始化性能分析器
|
# 初始化性能分析器(debug 模式下启用实时打印和超时检测)
|
||||||
self._profiler = PerformanceProfiler(enabled=debug)
|
self._profiler = PerformanceProfiler(
|
||||||
|
enabled=debug,
|
||||||
|
immediate_print=debug,
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def profiler(self) -> PerformanceProfiler:
|
def profiler(self) -> PerformanceProfiler:
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
df = df.with_columns(expr) # 触发真实计算
|
df = df.with_columns(expr) # 触发真实计算
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -21,6 +22,17 @@ class ProfileRecord:
|
|||||||
# call_count: int = 1
|
# call_count: int = 1
|
||||||
|
|
||||||
|
|
||||||
|
class FactorTimeoutError(TimeoutError):
|
||||||
|
"""因子计算超时异常"""
|
||||||
|
|
||||||
|
def __init__(self, factor_name: str, timeout_seconds: float):
|
||||||
|
self.factor_name = factor_name
|
||||||
|
self.timeout_seconds = timeout_seconds
|
||||||
|
super().__init__(
|
||||||
|
f"因子 '{factor_name}' 计算超时(超过 {timeout_seconds:.0f} 秒)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PerformanceProfiler:
|
class PerformanceProfiler:
|
||||||
"""性能分析器 - 独立组件,与 FactorEngine 解耦
|
"""性能分析器 - 独立组件,与 FactorEngine 解耦
|
||||||
|
|
||||||
@@ -29,20 +41,33 @@ class PerformanceProfiler:
|
|||||||
df = df.with_columns(expr) # 触发真实计算
|
df = df.with_columns(expr) # 触发真实计算
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, enabled: bool = False):
|
# 默认超时时间:3分钟(180秒)
|
||||||
|
DEFAULT_TIMEOUT_SECONDS = 180
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
enabled: bool = False,
|
||||||
|
timeout_seconds: Optional[float] = None,
|
||||||
|
immediate_print: bool = False,
|
||||||
|
):
|
||||||
self.enabled = enabled
|
self.enabled = enabled
|
||||||
|
self.timeout_seconds = timeout_seconds or self.DEFAULT_TIMEOUT_SECONDS
|
||||||
|
self.immediate_print = immediate_print
|
||||||
self.records: Dict[str, ProfileRecord] = {}
|
self.records: Dict[str, ProfileRecord] = {}
|
||||||
self._context_stack: List[str] = [] # 支持嵌套计时
|
self._context_stack: List[str] = [] # 支持嵌套计时
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def measure(self, name: str) -> Iterator[None]:
|
def measure(self, name: str) -> Iterator[None]:
|
||||||
"""上下文管理器:安全计时,自动处理异常
|
"""上下文管理器:安全计时,自动处理异常,支持超时检测
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: 计时任务名称(通常是因子名称)
|
name: 计时任务名称(通常是因子名称)
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
None
|
None
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FactorTimeoutError: 当计算时间超过超时阈值时
|
||||||
"""
|
"""
|
||||||
if not self.enabled:
|
if not self.enabled:
|
||||||
yield
|
yield
|
||||||
@@ -63,6 +88,43 @@ class PerformanceProfiler:
|
|||||||
|
|
||||||
self.records[name].time_ms += elapsed
|
self.records[name].time_ms += elapsed
|
||||||
|
|
||||||
|
# 实时打印当前因子性能信息
|
||||||
|
if self.immediate_print:
|
||||||
|
self._print_single_factor(name, elapsed)
|
||||||
|
|
||||||
|
# 检查是否超时
|
||||||
|
if elapsed > self.timeout_seconds:
|
||||||
|
self._handle_timeout(name, elapsed)
|
||||||
|
|
||||||
|
def _print_single_factor(self, name: str, elapsed_seconds: float) -> None:
|
||||||
|
"""打印单个因子的性能信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 因子名称
|
||||||
|
elapsed_seconds: 耗时(秒)
|
||||||
|
"""
|
||||||
|
elapsed_ms = elapsed_seconds * 1000
|
||||||
|
print(f"[Performance] {name}: {elapsed_ms:.2f}ms")
|
||||||
|
|
||||||
|
def _handle_timeout(self, name: str, elapsed_seconds: float) -> None:
|
||||||
|
"""处理超时情况
|
||||||
|
|
||||||
|
在 debug 模式下,如果因子计算超过阈值,强制中止进程。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 因子名称
|
||||||
|
elapsed_seconds: 实际耗时(秒)
|
||||||
|
"""
|
||||||
|
error_msg = (
|
||||||
|
f"[ERROR] 因子 '{name}' 计算超时!"
|
||||||
|
f" 耗时: {elapsed_seconds:.1f}s,阈值: {self.timeout_seconds:.0f}s"
|
||||||
|
)
|
||||||
|
print(error_msg, flush=True)
|
||||||
|
print(f"[ERROR] 强制中止进程", flush=True)
|
||||||
|
|
||||||
|
# 强制退出进程
|
||||||
|
os._exit(1)
|
||||||
|
|
||||||
def get_report(self) -> Dict[str, Any]:
|
def get_report(self) -> Dict[str, Any]:
|
||||||
"""生成性能报告
|
"""生成性能报告
|
||||||
|
|
||||||
|
|||||||
@@ -359,11 +359,16 @@ class PolarsTranslator:
|
|||||||
raise ValueError("ts_kurt 需要 2 个参数: (expr, window)")
|
raise ValueError("ts_kurt 需要 2 个参数: (expr, window)")
|
||||||
expr = self.translate(node.args[0])
|
expr = self.translate(node.args[0])
|
||||||
window = self._extract_window(node.args[1])
|
window = self._extract_window(node.args[1])
|
||||||
# 使用 rolling_map 计算峰度
|
|
||||||
return expr.rolling_map(
|
# 抛弃极慢的 rolling_map,借用 pandas 的 Cython 引擎
|
||||||
lambda s: s.kurtosis() if len(s.drop_nulls()) >= 4 else float("nan"),
|
def kurt_calc(s: pl.Series) -> pl.Series:
|
||||||
window_size=window,
|
import pandas as pd
|
||||||
)
|
# pandas.rolling.kurt() 是用 Cython 编写的,速度比 pure python 快很多
|
||||||
|
pd_series = pd.Series(s.to_numpy())
|
||||||
|
result = pd_series.rolling(window).kurt().to_numpy()
|
||||||
|
return pl.Series(result)
|
||||||
|
|
||||||
|
return expr.map_batches(kurt_calc, return_dtype=pl.Float64)
|
||||||
|
|
||||||
@time_series
|
@time_series
|
||||||
def _handle_ts_pct_change(self, node: FunctionNode) -> pl.Expr:
|
def _handle_ts_pct_change(self, node: FunctionNode) -> pl.Expr:
|
||||||
@@ -475,30 +480,30 @@ class PolarsTranslator:
|
|||||||
|
|
||||||
@time_series
|
@time_series
|
||||||
def _handle_ts_rank(self, node: FunctionNode) -> pl.Expr:
|
def _handle_ts_rank(self, node: FunctionNode) -> pl.Expr:
|
||||||
"""处理 ts_rank(x, window) -> 滚动排名(分位数)。
|
"""处理 ts_rank(x, window) -> 滚动排名(分位数)。"""
|
||||||
|
|
||||||
计算当前值在过去窗口内的分位排名(0-1之间)。
|
|
||||||
"""
|
|
||||||
if len(node.args) != 2:
|
if len(node.args) != 2:
|
||||||
raise ValueError("ts_rank 需要 2 个参数: (x, window)")
|
raise ValueError("ts_rank 需要 2 个参数: (x, window)")
|
||||||
expr = self.translate(node.args[0])
|
expr = self.translate(node.args[0])
|
||||||
window = self._extract_window(node.args[1])
|
window = self._extract_window(node.args[1])
|
||||||
|
|
||||||
def rank_calc(s: pl.Series) -> pl.Series:
|
def rank_calc(s: pl.Series) -> pl.Series:
|
||||||
"""计算滚动排名。"""
|
|
||||||
values = s.to_numpy()
|
values = s.to_numpy()
|
||||||
n = len(values)
|
n = len(values)
|
||||||
if n == 0:
|
if n < window:
|
||||||
return pl.Series([float("nan")] * n)
|
return pl.Series([float("nan")] * n)
|
||||||
|
|
||||||
result = np.full(n, np.nan)
|
# 核心魔法:创建零拷贝的 2D 滑动窗口视图
|
||||||
for i in range(window - 1, n):
|
# 形状为 (N - window + 1, window)
|
||||||
window_slice = values[i - window + 1 : i + 1]
|
windows = np.lib.stride_tricks.sliding_window_view(values, window)
|
||||||
# 计算分位排名 (0-1)
|
|
||||||
current_value = values[i]
|
|
||||||
rank = np.sum(window_slice <= current_value) / len(window_slice)
|
|
||||||
result[i] = rank
|
|
||||||
|
|
||||||
|
# 当前值即为每个窗口的最后一个元素 (N - window + 1, )
|
||||||
|
current_vals = windows[:, -1]
|
||||||
|
|
||||||
|
# 向量化广播比较,然后沿窗口轴(axis=1)求和,直接得出排名比例
|
||||||
|
ranks = np.sum(windows <= current_vals[:, None], axis=1) / window
|
||||||
|
|
||||||
|
result = np.full(n, np.nan)
|
||||||
|
result[window - 1:] = ranks
|
||||||
return pl.Series(result)
|
return pl.Series(result)
|
||||||
|
|
||||||
return expr.map_batches(rank_calc, return_dtype=pl.Float64)
|
return expr.map_batches(rank_calc, return_dtype=pl.Float64)
|
||||||
@@ -564,62 +569,54 @@ class PolarsTranslator:
|
|||||||
|
|
||||||
@time_series
|
@time_series
|
||||||
def _handle_ts_argmax(self, node: FunctionNode) -> pl.Expr:
|
def _handle_ts_argmax(self, node: FunctionNode) -> pl.Expr:
|
||||||
"""处理 ts_argmax(x, window) -> 距离最大值的交易日数。
|
"""处理 ts_argmax(x, window) -> 距离最大值的交易日数。"""
|
||||||
|
|
||||||
HIGHDAY 语义:返回距离过去 window 期内最大值的交易日数。
|
|
||||||
例如:今天是最高点返回 0,昨天是最高点返回 1。
|
|
||||||
"""
|
|
||||||
if len(node.args) != 2:
|
if len(node.args) != 2:
|
||||||
raise ValueError("ts_argmax 需要 2 个参数: (x, window)")
|
raise ValueError("ts_argmax 需要 2 个参数: (x, window)")
|
||||||
expr = self.translate(node.args[0])
|
expr = self.translate(node.args[0])
|
||||||
window = self._extract_window(node.args[1])
|
window = self._extract_window(node.args[1])
|
||||||
|
|
||||||
def argmax_calc(s: pl.Series) -> pl.Series:
|
def argmax_calc(s: pl.Series) -> pl.Series:
|
||||||
"""计算距离最大值的交易日数。"""
|
|
||||||
values = s.to_numpy()
|
values = s.to_numpy()
|
||||||
n = len(values)
|
n = len(values)
|
||||||
if n == 0:
|
if n < window:
|
||||||
return pl.Series([float("nan")] * n)
|
return pl.Series([float("nan")] * n)
|
||||||
|
|
||||||
result = np.full(n, np.nan)
|
# 创建 2D 视图
|
||||||
for i in range(window - 1, n):
|
windows = np.lib.stride_tricks.sliding_window_view(values, window)
|
||||||
window_slice = values[i - window + 1 : i + 1]
|
|
||||||
# 距离 = 当前索引 - (i - window + 1) - argmax_idx
|
|
||||||
# 其中 argmax_idx = np.argmax(window_slice)
|
|
||||||
argmax_idx = np.argmax(window_slice)
|
|
||||||
distance = window - 1 - argmax_idx
|
|
||||||
result[i] = distance
|
|
||||||
|
|
||||||
|
# 沿窗口轴(axis=1)进行 C 级极速 argmax 查找
|
||||||
|
# 注意:若有缺失值,可用 np.nanargmax 防止报错
|
||||||
|
argmax_indices = np.nanargmax(windows, axis=1)
|
||||||
|
|
||||||
|
# 计算距离
|
||||||
|
distances = window - 1 - argmax_indices
|
||||||
|
|
||||||
|
result = np.full(n, np.nan)
|
||||||
|
result[window - 1:] = distances
|
||||||
return pl.Series(result)
|
return pl.Series(result)
|
||||||
|
|
||||||
return expr.map_batches(argmax_calc, return_dtype=pl.Float64)
|
return expr.map_batches(argmax_calc, return_dtype=pl.Float64)
|
||||||
|
|
||||||
@time_series
|
@time_series
|
||||||
def _handle_ts_argmin(self, node: FunctionNode) -> pl.Expr:
|
def _handle_ts_argmin(self, node: FunctionNode) -> pl.Expr:
|
||||||
"""处理 ts_argmin(x, window) -> 距离最小值的交易日数。
|
"""处理 ts_argmin(x, window) -> 距离最小值的交易日数。"""
|
||||||
|
|
||||||
LOWDAY 语义:返回距离过去 window 期内最小值的交易日数。
|
|
||||||
例如:今天是最低点返回 0,昨天是最低点返回 1。
|
|
||||||
"""
|
|
||||||
if len(node.args) != 2:
|
if len(node.args) != 2:
|
||||||
raise ValueError("ts_argmin 需要 2 个参数: (x, window)")
|
raise ValueError("ts_argmin 需要 2 个参数: (x, window)")
|
||||||
expr = self.translate(node.args[0])
|
expr = self.translate(node.args[0])
|
||||||
window = self._extract_window(node.args[1])
|
window = self._extract_window(node.args[1])
|
||||||
|
|
||||||
def argmin_calc(s: pl.Series) -> pl.Series:
|
def argmin_calc(s: pl.Series) -> pl.Series:
|
||||||
"""计算距离最小值的交易日数。"""
|
|
||||||
values = s.to_numpy()
|
values = s.to_numpy()
|
||||||
n = len(values)
|
n = len(values)
|
||||||
if n == 0:
|
if n < window:
|
||||||
return pl.Series([float("nan")] * n)
|
return pl.Series([float("nan")] * n)
|
||||||
|
|
||||||
result = np.full(n, np.nan)
|
windows = np.lib.stride_tricks.sliding_window_view(values, window)
|
||||||
for i in range(window - 1, n):
|
argmin_indices = np.nanargmin(windows, axis=1)
|
||||||
window_slice = values[i - window + 1 : i + 1]
|
distances = window - 1 - argmin_indices
|
||||||
argmin_idx = np.argmin(window_slice)
|
|
||||||
distance = window - 1 - argmin_idx
|
|
||||||
result[i] = distance
|
|
||||||
|
|
||||||
|
result = np.full(n, np.nan)
|
||||||
|
result[window - 1:] = distances
|
||||||
return pl.Series(result)
|
return pl.Series(result)
|
||||||
|
|
||||||
return expr.map_batches(argmin_calc, return_dtype=pl.Float64)
|
return expr.map_batches(argmin_calc, return_dtype=pl.Float64)
|
||||||
@@ -636,27 +633,24 @@ class PolarsTranslator:
|
|||||||
|
|
||||||
@time_series
|
@time_series
|
||||||
def _handle_ts_prod(self, node: FunctionNode) -> pl.Expr:
|
def _handle_ts_prod(self, node: FunctionNode) -> pl.Expr:
|
||||||
"""处理 ts_prod(x, window) -> 滚动连乘积。
|
"""处理 ts_prod(x, window) -> 滚动连乘积。"""
|
||||||
|
|
||||||
使用 map_batches 实现窗口内的累积乘积。
|
|
||||||
"""
|
|
||||||
if len(node.args) != 2:
|
if len(node.args) != 2:
|
||||||
raise ValueError("ts_prod 需要 2 个参数: (x, window)")
|
raise ValueError("ts_prod 需要 2 个参数: (x, window)")
|
||||||
expr = self.translate(node.args[0])
|
expr = self.translate(node.args[0])
|
||||||
window = self._extract_window(node.args[1])
|
window = self._extract_window(node.args[1])
|
||||||
|
|
||||||
def prod_calc(s: pl.Series) -> pl.Series:
|
def prod_calc(s: pl.Series) -> pl.Series:
|
||||||
"""计算窗口内的连乘积。"""
|
|
||||||
values = s.to_numpy()
|
values = s.to_numpy()
|
||||||
n = len(values)
|
n = len(values)
|
||||||
if n == 0:
|
if n < window:
|
||||||
return pl.Series([float("nan")] * n)
|
return pl.Series([float("nan")] * n)
|
||||||
|
|
||||||
result = np.full(n, np.nan)
|
windows = np.lib.stride_tricks.sliding_window_view(values, window)
|
||||||
for i in range(window - 1, n):
|
# 沿轴求积,极速实现
|
||||||
window_slice = values[i - window + 1 : i + 1]
|
prods = np.prod(windows, axis=1)
|
||||||
result[i] = np.prod(window_slice)
|
|
||||||
|
|
||||||
|
result = np.full(n, np.nan)
|
||||||
|
result[window - 1:] = prods
|
||||||
return pl.Series(result)
|
return pl.Series(result)
|
||||||
|
|
||||||
return expr.map_batches(prod_calc, return_dtype=pl.Float64)
|
return expr.map_batches(prod_calc, return_dtype=pl.Float64)
|
||||||
|
|||||||
Reference in New Issue
Block a user