feat(factors): 添加 GTJA alpha 因子并优化计算性能
- 新增 190+ 个 GTJA alpha 因子到因子列表 - 优化 ts_kurt、ts_rank、ts_argmax/min、ts_prod 计算性能 - 性能分析器新增超时检测(180秒)和实时打印功能 - 简化探针因子选择脚本的 main 函数入口
This commit is contained in:
@@ -91,8 +91,11 @@ class FactorEngine:
|
||||
# 调试模式配置
|
||||
self._debug = debug
|
||||
|
||||
# 初始化性能分析器
|
||||
self._profiler = PerformanceProfiler(enabled=debug)
|
||||
# 初始化性能分析器(debug 模式下启用实时打印和超时检测)
|
||||
self._profiler = PerformanceProfiler(
|
||||
enabled=debug,
|
||||
immediate_print=debug,
|
||||
)
|
||||
|
||||
@property
|
||||
def profiler(self) -> PerformanceProfiler:
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
df = df.with_columns(expr) # 触发真实计算
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
@@ -21,6 +22,17 @@ class ProfileRecord:
|
||||
# call_count: int = 1
|
||||
|
||||
|
||||
class FactorTimeoutError(TimeoutError):
|
||||
"""因子计算超时异常"""
|
||||
|
||||
def __init__(self, factor_name: str, timeout_seconds: float):
|
||||
self.factor_name = factor_name
|
||||
self.timeout_seconds = timeout_seconds
|
||||
super().__init__(
|
||||
f"因子 '{factor_name}' 计算超时(超过 {timeout_seconds:.0f} 秒)"
|
||||
)
|
||||
|
||||
|
||||
class PerformanceProfiler:
|
||||
"""性能分析器 - 独立组件,与 FactorEngine 解耦
|
||||
|
||||
@@ -29,20 +41,33 @@ class PerformanceProfiler:
|
||||
df = df.with_columns(expr) # 触发真实计算
|
||||
"""
|
||||
|
||||
def __init__(self, enabled: bool = False):
|
||||
# 默认超时时间:3分钟(180秒)
|
||||
DEFAULT_TIMEOUT_SECONDS = 180
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
enabled: bool = False,
|
||||
timeout_seconds: Optional[float] = None,
|
||||
immediate_print: bool = False,
|
||||
):
|
||||
self.enabled = enabled
|
||||
self.timeout_seconds = timeout_seconds or self.DEFAULT_TIMEOUT_SECONDS
|
||||
self.immediate_print = immediate_print
|
||||
self.records: Dict[str, ProfileRecord] = {}
|
||||
self._context_stack: List[str] = [] # 支持嵌套计时
|
||||
|
||||
@contextmanager
|
||||
def measure(self, name: str) -> Iterator[None]:
|
||||
"""上下文管理器:安全计时,自动处理异常
|
||||
"""上下文管理器:安全计时,自动处理异常,支持超时检测
|
||||
|
||||
Args:
|
||||
name: 计时任务名称(通常是因子名称)
|
||||
|
||||
Yields:
|
||||
None
|
||||
|
||||
Raises:
|
||||
FactorTimeoutError: 当计算时间超过超时阈值时
|
||||
"""
|
||||
if not self.enabled:
|
||||
yield
|
||||
@@ -63,6 +88,43 @@ class PerformanceProfiler:
|
||||
|
||||
self.records[name].time_ms += elapsed
|
||||
|
||||
# 实时打印当前因子性能信息
|
||||
if self.immediate_print:
|
||||
self._print_single_factor(name, elapsed)
|
||||
|
||||
# 检查是否超时
|
||||
if elapsed > self.timeout_seconds:
|
||||
self._handle_timeout(name, elapsed)
|
||||
|
||||
def _print_single_factor(self, name: str, elapsed_seconds: float) -> None:
|
||||
"""打印单个因子的性能信息
|
||||
|
||||
Args:
|
||||
name: 因子名称
|
||||
elapsed_seconds: 耗时(秒)
|
||||
"""
|
||||
elapsed_ms = elapsed_seconds * 1000
|
||||
print(f"[Performance] {name}: {elapsed_ms:.2f}ms")
|
||||
|
||||
def _handle_timeout(self, name: str, elapsed_seconds: float) -> None:
|
||||
"""处理超时情况
|
||||
|
||||
在 debug 模式下,如果因子计算超过阈值,强制中止进程。
|
||||
|
||||
Args:
|
||||
name: 因子名称
|
||||
elapsed_seconds: 实际耗时(秒)
|
||||
"""
|
||||
error_msg = (
|
||||
f"[ERROR] 因子 '{name}' 计算超时!"
|
||||
f" 耗时: {elapsed_seconds:.1f}s,阈值: {self.timeout_seconds:.0f}s"
|
||||
)
|
||||
print(error_msg, flush=True)
|
||||
print(f"[ERROR] 强制中止进程", flush=True)
|
||||
|
||||
# 强制退出进程
|
||||
os._exit(1)
|
||||
|
||||
def get_report(self) -> Dict[str, Any]:
|
||||
"""生成性能报告
|
||||
|
||||
|
||||
@@ -359,11 +359,16 @@ class PolarsTranslator:
|
||||
raise ValueError("ts_kurt 需要 2 个参数: (expr, window)")
|
||||
expr = self.translate(node.args[0])
|
||||
window = self._extract_window(node.args[1])
|
||||
# 使用 rolling_map 计算峰度
|
||||
return expr.rolling_map(
|
||||
lambda s: s.kurtosis() if len(s.drop_nulls()) >= 4 else float("nan"),
|
||||
window_size=window,
|
||||
)
|
||||
|
||||
# 抛弃极慢的 rolling_map,借用 pandas 的 Cython 引擎
|
||||
def kurt_calc(s: pl.Series) -> pl.Series:
|
||||
import pandas as pd
|
||||
# pandas.rolling.kurt() 是用 Cython 编写的,速度比 pure python 快很多
|
||||
pd_series = pd.Series(s.to_numpy())
|
||||
result = pd_series.rolling(window).kurt().to_numpy()
|
||||
return pl.Series(result)
|
||||
|
||||
return expr.map_batches(kurt_calc, return_dtype=pl.Float64)
|
||||
|
||||
@time_series
|
||||
def _handle_ts_pct_change(self, node: FunctionNode) -> pl.Expr:
|
||||
@@ -475,30 +480,30 @@ class PolarsTranslator:
|
||||
|
||||
@time_series
|
||||
def _handle_ts_rank(self, node: FunctionNode) -> pl.Expr:
|
||||
"""处理 ts_rank(x, window) -> 滚动排名(分位数)。
|
||||
|
||||
计算当前值在过去窗口内的分位排名(0-1之间)。
|
||||
"""
|
||||
"""处理 ts_rank(x, window) -> 滚动排名(分位数)。"""
|
||||
if len(node.args) != 2:
|
||||
raise ValueError("ts_rank 需要 2 个参数: (x, window)")
|
||||
expr = self.translate(node.args[0])
|
||||
window = self._extract_window(node.args[1])
|
||||
|
||||
def rank_calc(s: pl.Series) -> pl.Series:
|
||||
"""计算滚动排名。"""
|
||||
values = s.to_numpy()
|
||||
n = len(values)
|
||||
if n == 0:
|
||||
if n < window:
|
||||
return pl.Series([float("nan")] * n)
|
||||
|
||||
result = np.full(n, np.nan)
|
||||
for i in range(window - 1, n):
|
||||
window_slice = values[i - window + 1 : i + 1]
|
||||
# 计算分位排名 (0-1)
|
||||
current_value = values[i]
|
||||
rank = np.sum(window_slice <= current_value) / len(window_slice)
|
||||
result[i] = rank
|
||||
# 核心魔法:创建零拷贝的 2D 滑动窗口视图
|
||||
# 形状为 (N - window + 1, window)
|
||||
windows = np.lib.stride_tricks.sliding_window_view(values, window)
|
||||
|
||||
# 当前值即为每个窗口的最后一个元素 (N - window + 1, )
|
||||
current_vals = windows[:, -1]
|
||||
|
||||
# 向量化广播比较,然后沿窗口轴(axis=1)求和,直接得出排名比例
|
||||
ranks = np.sum(windows <= current_vals[:, None], axis=1) / window
|
||||
|
||||
result = np.full(n, np.nan)
|
||||
result[window - 1:] = ranks
|
||||
return pl.Series(result)
|
||||
|
||||
return expr.map_batches(rank_calc, return_dtype=pl.Float64)
|
||||
@@ -564,62 +569,54 @@ class PolarsTranslator:
|
||||
|
||||
@time_series
|
||||
def _handle_ts_argmax(self, node: FunctionNode) -> pl.Expr:
|
||||
"""处理 ts_argmax(x, window) -> 距离最大值的交易日数。
|
||||
|
||||
HIGHDAY 语义:返回距离过去 window 期内最大值的交易日数。
|
||||
例如:今天是最高点返回 0,昨天是最高点返回 1。
|
||||
"""
|
||||
"""处理 ts_argmax(x, window) -> 距离最大值的交易日数。"""
|
||||
if len(node.args) != 2:
|
||||
raise ValueError("ts_argmax 需要 2 个参数: (x, window)")
|
||||
expr = self.translate(node.args[0])
|
||||
window = self._extract_window(node.args[1])
|
||||
|
||||
def argmax_calc(s: pl.Series) -> pl.Series:
|
||||
"""计算距离最大值的交易日数。"""
|
||||
values = s.to_numpy()
|
||||
n = len(values)
|
||||
if n == 0:
|
||||
if n < window:
|
||||
return pl.Series([float("nan")] * n)
|
||||
|
||||
result = np.full(n, np.nan)
|
||||
for i in range(window - 1, n):
|
||||
window_slice = values[i - window + 1 : i + 1]
|
||||
# 距离 = 当前索引 - (i - window + 1) - argmax_idx
|
||||
# 其中 argmax_idx = np.argmax(window_slice)
|
||||
argmax_idx = np.argmax(window_slice)
|
||||
distance = window - 1 - argmax_idx
|
||||
result[i] = distance
|
||||
# 创建 2D 视图
|
||||
windows = np.lib.stride_tricks.sliding_window_view(values, window)
|
||||
|
||||
# 沿窗口轴(axis=1)进行 C 级极速 argmax 查找
|
||||
# 注意:若有缺失值,可用 np.nanargmax 防止报错
|
||||
argmax_indices = np.nanargmax(windows, axis=1)
|
||||
|
||||
# 计算距离
|
||||
distances = window - 1 - argmax_indices
|
||||
|
||||
result = np.full(n, np.nan)
|
||||
result[window - 1:] = distances
|
||||
return pl.Series(result)
|
||||
|
||||
return expr.map_batches(argmax_calc, return_dtype=pl.Float64)
|
||||
|
||||
@time_series
|
||||
def _handle_ts_argmin(self, node: FunctionNode) -> pl.Expr:
|
||||
"""处理 ts_argmin(x, window) -> 距离最小值的交易日数。
|
||||
|
||||
LOWDAY 语义:返回距离过去 window 期内最小值的交易日数。
|
||||
例如:今天是最低点返回 0,昨天是最低点返回 1。
|
||||
"""
|
||||
"""处理 ts_argmin(x, window) -> 距离最小值的交易日数。"""
|
||||
if len(node.args) != 2:
|
||||
raise ValueError("ts_argmin 需要 2 个参数: (x, window)")
|
||||
expr = self.translate(node.args[0])
|
||||
window = self._extract_window(node.args[1])
|
||||
|
||||
def argmin_calc(s: pl.Series) -> pl.Series:
|
||||
"""计算距离最小值的交易日数。"""
|
||||
values = s.to_numpy()
|
||||
n = len(values)
|
||||
if n == 0:
|
||||
if n < window:
|
||||
return pl.Series([float("nan")] * n)
|
||||
|
||||
result = np.full(n, np.nan)
|
||||
for i in range(window - 1, n):
|
||||
window_slice = values[i - window + 1 : i + 1]
|
||||
argmin_idx = np.argmin(window_slice)
|
||||
distance = window - 1 - argmin_idx
|
||||
result[i] = distance
|
||||
windows = np.lib.stride_tricks.sliding_window_view(values, window)
|
||||
argmin_indices = np.nanargmin(windows, axis=1)
|
||||
distances = window - 1 - argmin_indices
|
||||
|
||||
result = np.full(n, np.nan)
|
||||
result[window - 1:] = distances
|
||||
return pl.Series(result)
|
||||
|
||||
return expr.map_batches(argmin_calc, return_dtype=pl.Float64)
|
||||
@@ -636,27 +633,24 @@ class PolarsTranslator:
|
||||
|
||||
@time_series
|
||||
def _handle_ts_prod(self, node: FunctionNode) -> pl.Expr:
|
||||
"""处理 ts_prod(x, window) -> 滚动连乘积。
|
||||
|
||||
使用 map_batches 实现窗口内的累积乘积。
|
||||
"""
|
||||
"""处理 ts_prod(x, window) -> 滚动连乘积。"""
|
||||
if len(node.args) != 2:
|
||||
raise ValueError("ts_prod 需要 2 个参数: (x, window)")
|
||||
expr = self.translate(node.args[0])
|
||||
window = self._extract_window(node.args[1])
|
||||
|
||||
def prod_calc(s: pl.Series) -> pl.Series:
|
||||
"""计算窗口内的连乘积。"""
|
||||
values = s.to_numpy()
|
||||
n = len(values)
|
||||
if n == 0:
|
||||
if n < window:
|
||||
return pl.Series([float("nan")] * n)
|
||||
|
||||
result = np.full(n, np.nan)
|
||||
for i in range(window - 1, n):
|
||||
window_slice = values[i - window + 1 : i + 1]
|
||||
result[i] = np.prod(window_slice)
|
||||
windows = np.lib.stride_tricks.sliding_window_view(values, window)
|
||||
# 沿轴求积,极速实现
|
||||
prods = np.prod(windows, axis=1)
|
||||
|
||||
result = np.full(n, np.nan)
|
||||
result[window - 1:] = prods
|
||||
return pl.Series(result)
|
||||
|
||||
return expr.map_batches(prod_calc, return_dtype=pl.Float64)
|
||||
|
||||
Reference in New Issue
Block a user