feat(factors): 添加 GTJA alpha 因子并优化计算性能

- 新增 190+ 个 GTJA alpha 因子到因子列表
- 优化 ts_kurt、ts_rank、ts_argmax/min、ts_prod 计算性能
- 性能分析器新增超时检测(180秒)和实时打印功能
- 简化探针因子选择脚本的 main 函数入口
This commit is contained in:
2026-03-15 22:21:21 +08:00
parent 81e89f3796
commit 5ed06d20d2
5 changed files with 295 additions and 81 deletions

View File

@@ -91,8 +91,11 @@ class FactorEngine:
# 调试模式配置
self._debug = debug
# 初始化性能分析器
self._profiler = PerformanceProfiler(enabled=debug)
# 初始化性能分析器debug 模式下启用实时打印和超时检测)
self._profiler = PerformanceProfiler(
enabled=debug,
immediate_print=debug,
)
@property
def profiler(self) -> PerformanceProfiler:

View File

@@ -5,6 +5,7 @@
df = df.with_columns(expr) # 触发真实计算
"""
import os
import time
from contextlib import contextmanager
from dataclasses import dataclass
@@ -21,6 +22,17 @@ class ProfileRecord:
# call_count: int = 1
class FactorTimeoutError(TimeoutError):
"""因子计算超时异常"""
def __init__(self, factor_name: str, timeout_seconds: float):
self.factor_name = factor_name
self.timeout_seconds = timeout_seconds
super().__init__(
f"因子 '{factor_name}' 计算超时(超过 {timeout_seconds:.0f} 秒)"
)
class PerformanceProfiler:
"""性能分析器 - 独立组件,与 FactorEngine 解耦
@@ -29,20 +41,33 @@ class PerformanceProfiler:
df = df.with_columns(expr) # 触发真实计算
"""
def __init__(self, enabled: bool = False):
# 默认超时时间3分钟180秒
DEFAULT_TIMEOUT_SECONDS = 180
def __init__(
self,
enabled: bool = False,
timeout_seconds: Optional[float] = None,
immediate_print: bool = False,
):
self.enabled = enabled
self.timeout_seconds = timeout_seconds or self.DEFAULT_TIMEOUT_SECONDS
self.immediate_print = immediate_print
self.records: Dict[str, ProfileRecord] = {}
self._context_stack: List[str] = [] # 支持嵌套计时
@contextmanager
def measure(self, name: str) -> Iterator[None]:
"""上下文管理器:安全计时,自动处理异常
"""上下文管理器:安全计时,自动处理异常,支持超时检测
Args:
name: 计时任务名称(通常是因子名称)
Yields:
None
Raises:
FactorTimeoutError: 当计算时间超过超时阈值时
"""
if not self.enabled:
yield
@@ -63,6 +88,43 @@ class PerformanceProfiler:
self.records[name].time_ms += elapsed
# 实时打印当前因子性能信息
if self.immediate_print:
self._print_single_factor(name, elapsed)
# 检查是否超时
if elapsed > self.timeout_seconds:
self._handle_timeout(name, elapsed)
def _print_single_factor(self, name: str, elapsed_seconds: float) -> None:
"""打印单个因子的性能信息
Args:
name: 因子名称
elapsed_seconds: 耗时(秒)
"""
elapsed_ms = elapsed_seconds * 1000
print(f"[Performance] {name}: {elapsed_ms:.2f}ms")
def _handle_timeout(self, name: str, elapsed_seconds: float) -> None:
"""处理超时情况
在 debug 模式下,如果因子计算超过阈值,强制中止进程。
Args:
name: 因子名称
elapsed_seconds: 实际耗时(秒)
"""
error_msg = (
f"[ERROR] 因子 '{name}' 计算超时!"
f" 耗时: {elapsed_seconds:.1f}s阈值: {self.timeout_seconds:.0f}s"
)
print(error_msg, flush=True)
print(f"[ERROR] 强制中止进程", flush=True)
# 强制退出进程
os._exit(1)
def get_report(self) -> Dict[str, Any]:
"""生成性能报告

View File

@@ -359,11 +359,16 @@ class PolarsTranslator:
raise ValueError("ts_kurt 需要 2 个参数: (expr, window)")
expr = self.translate(node.args[0])
window = self._extract_window(node.args[1])
# 使用 rolling_map 计算峰度
return expr.rolling_map(
lambda s: s.kurtosis() if len(s.drop_nulls()) >= 4 else float("nan"),
window_size=window,
)
# 抛弃极慢的 rolling_map借用 pandas 的 Cython 引擎
def kurt_calc(s: pl.Series) -> pl.Series:
import pandas as pd
# pandas.rolling.kurt() 是用 Cython 编写的,速度比 pure python 快很多
pd_series = pd.Series(s.to_numpy())
result = pd_series.rolling(window).kurt().to_numpy()
return pl.Series(result)
return expr.map_batches(kurt_calc, return_dtype=pl.Float64)
@time_series
def _handle_ts_pct_change(self, node: FunctionNode) -> pl.Expr:
@@ -475,30 +480,30 @@ class PolarsTranslator:
@time_series
def _handle_ts_rank(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_rank(x, window) -> 滚动排名(分位数)。
计算当前值在过去窗口内的分位排名0-1之间
"""
"""处理 ts_rank(x, window) -> 滚动排名(分位数)。"""
if len(node.args) != 2:
raise ValueError("ts_rank 需要 2 个参数: (x, window)")
expr = self.translate(node.args[0])
window = self._extract_window(node.args[1])
def rank_calc(s: pl.Series) -> pl.Series:
"""计算滚动排名。"""
values = s.to_numpy()
n = len(values)
if n == 0:
if n < window:
return pl.Series([float("nan")] * n)
result = np.full(n, np.nan)
for i in range(window - 1, n):
window_slice = values[i - window + 1 : i + 1]
# 计算分位排名 (0-1)
current_value = values[i]
rank = np.sum(window_slice <= current_value) / len(window_slice)
result[i] = rank
# 核心魔法:创建零拷贝的 2D 滑动窗口视图
# 形状为 (N - window + 1, window)
windows = np.lib.stride_tricks.sliding_window_view(values, window)
# 当前值即为每个窗口的最后一个元素 (N - window + 1, )
current_vals = windows[:, -1]
# 向量化广播比较,然后沿窗口轴(axis=1)求和,直接得出排名比例
ranks = np.sum(windows <= current_vals[:, None], axis=1) / window
result = np.full(n, np.nan)
result[window - 1:] = ranks
return pl.Series(result)
return expr.map_batches(rank_calc, return_dtype=pl.Float64)
@@ -564,62 +569,54 @@ class PolarsTranslator:
@time_series
def _handle_ts_argmax(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_argmax(x, window) -> 距离最大值的交易日数。
HIGHDAY 语义:返回距离过去 window 期内最大值的交易日数。
例如:今天是最高点返回 0昨天是最高点返回 1。
"""
"""处理 ts_argmax(x, window) -> 距离最大值的交易日数。"""
if len(node.args) != 2:
raise ValueError("ts_argmax 需要 2 个参数: (x, window)")
expr = self.translate(node.args[0])
window = self._extract_window(node.args[1])
def argmax_calc(s: pl.Series) -> pl.Series:
"""计算距离最大值的交易日数。"""
values = s.to_numpy()
n = len(values)
if n == 0:
if n < window:
return pl.Series([float("nan")] * n)
result = np.full(n, np.nan)
for i in range(window - 1, n):
window_slice = values[i - window + 1 : i + 1]
# 距离 = 当前索引 - (i - window + 1) - argmax_idx
# 其中 argmax_idx = np.argmax(window_slice)
argmax_idx = np.argmax(window_slice)
distance = window - 1 - argmax_idx
result[i] = distance
# 创建 2D 视图
windows = np.lib.stride_tricks.sliding_window_view(values, window)
# 沿窗口轴(axis=1)进行 C 级极速 argmax 查找
# 注意:若有缺失值,可用 np.nanargmax 防止报错
argmax_indices = np.nanargmax(windows, axis=1)
# 计算距离
distances = window - 1 - argmax_indices
result = np.full(n, np.nan)
result[window - 1:] = distances
return pl.Series(result)
return expr.map_batches(argmax_calc, return_dtype=pl.Float64)
@time_series
def _handle_ts_argmin(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_argmin(x, window) -> 距离最小值的交易日数。
LOWDAY 语义:返回距离过去 window 期内最小值的交易日数。
例如:今天是最低点返回 0昨天是最低点返回 1。
"""
"""处理 ts_argmin(x, window) -> 距离最小值的交易日数。"""
if len(node.args) != 2:
raise ValueError("ts_argmin 需要 2 个参数: (x, window)")
expr = self.translate(node.args[0])
window = self._extract_window(node.args[1])
def argmin_calc(s: pl.Series) -> pl.Series:
"""计算距离最小值的交易日数。"""
values = s.to_numpy()
n = len(values)
if n == 0:
if n < window:
return pl.Series([float("nan")] * n)
result = np.full(n, np.nan)
for i in range(window - 1, n):
window_slice = values[i - window + 1 : i + 1]
argmin_idx = np.argmin(window_slice)
distance = window - 1 - argmin_idx
result[i] = distance
windows = np.lib.stride_tricks.sliding_window_view(values, window)
argmin_indices = np.nanargmin(windows, axis=1)
distances = window - 1 - argmin_indices
result = np.full(n, np.nan)
result[window - 1:] = distances
return pl.Series(result)
return expr.map_batches(argmin_calc, return_dtype=pl.Float64)
@@ -636,27 +633,24 @@ class PolarsTranslator:
@time_series
def _handle_ts_prod(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_prod(x, window) -> 滚动连乘积。
使用 map_batches 实现窗口内的累积乘积。
"""
"""处理 ts_prod(x, window) -> 滚动连乘积。"""
if len(node.args) != 2:
raise ValueError("ts_prod 需要 2 个参数: (x, window)")
expr = self.translate(node.args[0])
window = self._extract_window(node.args[1])
def prod_calc(s: pl.Series) -> pl.Series:
"""计算窗口内的连乘积。"""
values = s.to_numpy()
n = len(values)
if n == 0:
if n < window:
return pl.Series([float("nan")] * n)
result = np.full(n, np.nan)
for i in range(window - 1, n):
window_slice = values[i - window + 1 : i + 1]
result[i] = np.prod(window_slice)
windows = np.lib.stride_tricks.sliding_window_view(values, window)
# 沿轴求积,极速实现
prods = np.prod(windows, axis=1)
result = np.full(n, np.nan)
result[window - 1:] = prods
return pl.Series(result)
return expr.map_batches(prod_calc, return_dtype=pl.Float64)