# Changelog:
# - 新增 atan, log1p 数学函数
# - 新增 ts_var, ts_skew, ts_kurt, ts_pct_change, ts_ema 统计函数
# - 新增 ts_atr, ts_rsi, ts_obv TA-Lib 技术指标函数
# - 新增完整集成测试覆盖所有新函数
"""Polars 翻译器 - 将 AST 翻译为 Polars 表达式。
|
||
|
||
本模块实现 DSL 到 Polars 计算图的映射,是因子表达式执行的桥梁。
|
||
支持时序因子(ts_*)和截面因子(cs_*)的防错分组翻译。
|
||
"""
|
||
|
||
import numbers
from typing import Any, Callable, Dict

import numpy as np
import polars as pl

# TA-Lib 可选依赖
try:
    import talib

    HAS_TALIB = True
except ImportError:
    HAS_TALIB = False
    talib = None

from src.factors.decorators import cross_section, element_wise, time_series
from src.factors.dsl import (
    BinaryOpNode,
    Constant,
    FunctionNode,
    Node,
    Symbol,
    UnaryOpNode,
)


class PolarsTranslator:
    """Translate DSL AST nodes into Polars expressions.

    Every AST node maps onto one Polars expression. Grouping is not done by
    the handlers themselves but by the decorators wrapping them:

    * ``ts_*`` handlers use ``@time_series``; this module's comments state it
      appends ``.over("ts_code")`` so rolling windows never mix assets.
    * ``cs_*`` handlers use ``@cross_section``; this module's comments state
      it appends ``.over("trade_date")``.
    * Element-wise math handlers use ``@element_wise`` and are not grouped.

    Attributes:
        handlers: Registry mapping a DSL function name to the callable that
            translates a ``FunctionNode`` of that name into a ``pl.Expr``.

    Example:
        >>> from src.factors.dsl import Symbol, FunctionNode
        >>> close = Symbol("close")
        >>> expr = FunctionNode("ts_mean", close, 20)
        >>> translator = PolarsTranslator()
        >>> polars_expr = translator.translate(expr)
        >>> # roughly: pl.col("close").rolling_mean(20).over("ts_code")
    """

    # DSL binary operator -> builder combining the two translated operands.
    # Kept on the class so it is built once, not on every translation call.
    _BINARY_OPS = {
        "+": lambda a, b: a + b,
        "-": lambda a, b: a - b,
        "*": lambda a, b: a * b,
        "/": lambda a, b: a / b,
        "**": lambda a, b: a.pow(b),
        # Use the ``//`` operator: Polars names the method ``Expr.floordiv``;
        # ``floor_div`` is not part of the Expr API and raised AttributeError.
        "//": lambda a, b: a // b,
        "%": lambda a, b: a % b,
        "==": lambda a, b: a.eq(b),
        "!=": lambda a, b: a.ne(b),
        "<": lambda a, b: a.lt(b),
        "<=": lambda a, b: a.le(b),
        ">": lambda a, b: a.gt(b),
        ">=": lambda a, b: a.ge(b),
    }

    # DSL unary operator -> builder applied to the translated operand.
    _UNARY_OPS = {
        "+": lambda x: x,
        "-": lambda x: -x,
        "abs": lambda x: x.abs(),
    }

    def __init__(self) -> None:
        """Create a translator with every built-in handler registered."""
        self.handlers: Dict[str, Callable[[FunctionNode], pl.Expr]] = {}
        self._register_builtin_handlers()

    def _register_builtin_handlers(self) -> None:
        """Register the built-in ts_*, cs_* and element-wise handlers.

        Registration goes through ``register_handler`` in a fixed order, so a
        subclass overriding ``register_handler`` still sees every call.
        """
        builtin_handlers: Dict[str, Callable[[FunctionNode], pl.Expr]] = {
            # Time-series handlers (ts_*)
            "ts_mean": self._handle_ts_mean,
            "ts_sum": self._handle_ts_sum,
            "ts_std": self._handle_ts_std,
            "ts_max": self._handle_ts_max,
            "ts_min": self._handle_ts_min,
            "ts_delay": self._handle_ts_delay,
            "ts_delta": self._handle_ts_delta,
            "ts_corr": self._handle_ts_corr,
            "ts_cov": self._handle_ts_cov,
            "ts_var": self._handle_ts_var,
            "ts_skew": self._handle_ts_skew,
            "ts_kurt": self._handle_ts_kurt,
            "ts_pct_change": self._handle_ts_pct_change,
            "ts_ema": self._handle_ts_ema,
            "ts_atr": self._handle_ts_atr,
            "ts_rsi": self._handle_ts_rsi,
            "ts_obv": self._handle_ts_obv,
            # Cross-sectional handlers (cs_*)
            "cs_rank": self._handle_cs_rank,
            "cs_zscore": self._handle_cs_zscore,
            "cs_neutral": self._handle_cs_neutral,
            # Element-wise math functions
            "log": self._handle_log,
            "exp": self._handle_exp,
            "sqrt": self._handle_sqrt,
            "sign": self._handle_sign,
            "cos": self._handle_cos,
            "sin": self._handle_sin,
            "atan": self._handle_atan,
            "log1p": self._handle_log1p,
        }
        for func_name, handler in builtin_handlers.items():
            self.register_handler(func_name, handler)

    def register_handler(
        self, func_name: str, handler: Callable[[FunctionNode], pl.Expr]
    ) -> None:
        """Register (or replace) the handler for a DSL function name.

        Args:
            func_name: DSL function name, e.g. ``"ts_mean"``.
            handler: Callable receiving the ``FunctionNode`` and returning the
                translated ``pl.Expr``.

        Example:
            >>> translator = PolarsTranslator()
            >>> def handle_custom(node: FunctionNode) -> pl.Expr:
            ...     arg = translator.translate(node.args[0])
            ...     return arg * 2
            >>> translator.register_handler("custom", handle_custom)
        """
        self.handlers[func_name] = handler

    def translate(self, node: Node) -> pl.Expr:
        """Recursively translate an AST node into a Polars expression.

        Args:
            node: A ``Symbol``, ``Constant``, ``BinaryOpNode``,
                ``UnaryOpNode`` or ``FunctionNode``.

        Returns:
            The equivalent Polars expression.

        Raises:
            TypeError: If ``node`` is of an unknown type.
        """
        if isinstance(node, Symbol):
            return self._translate_symbol(node)
        if isinstance(node, Constant):
            return self._translate_constant(node)
        if isinstance(node, BinaryOpNode):
            return self._translate_binary_op(node)
        if isinstance(node, UnaryOpNode):
            return self._translate_unary_op(node)
        if isinstance(node, FunctionNode):
            return self._translate_function(node)
        raise TypeError(f"未知的节点类型: {type(node).__name__}")

    def _translate_symbol(self, node: Symbol) -> pl.Expr:
        """Translate a ``Symbol`` into ``pl.col(node.name)``."""
        return pl.col(node.name)

    def _translate_constant(self, node: Constant) -> pl.Expr:
        """Translate a ``Constant`` into ``pl.lit(node.value)``."""
        return pl.lit(node.value)

    def _translate_binary_op(self, node: BinaryOpNode) -> pl.Expr:
        """Translate a ``BinaryOpNode`` into a Polars arithmetic/comparison.

        Both operands are translated first, then combined with the builder
        registered for ``node.op`` in ``_BINARY_OPS``.

        Raises:
            ValueError: If ``node.op`` is not a supported binary operator.
        """
        left = self.translate(node.left)
        right = self.translate(node.right)
        builder = self._BINARY_OPS.get(node.op)
        if builder is None:
            raise ValueError(f"不支持的二元运算符: {node.op}")
        return builder(left, right)

    def _translate_unary_op(self, node: UnaryOpNode) -> pl.Expr:
        """Translate a ``UnaryOpNode`` (``+``, ``-`` or ``abs``).

        Raises:
            ValueError: If ``node.op`` is not a supported unary operator.
        """
        operand = self.translate(node.operand)
        builder = self._UNARY_OPS.get(node.op)
        if builder is None:
            raise ValueError(f"不支持的一元运算符: {node.op}")
        return builder(operand)

    def _translate_function(self, node: FunctionNode) -> pl.Expr:
        """Dispatch a ``FunctionNode`` to its registered handler.

        Raises:
            ValueError: If no handler is registered for ``node.func_name``.
        """
        func_name = node.func_name
        if func_name not in self.handlers:
            raise ValueError(
                f"未注册的函数: {func_name}. 请使用 register_handler 注册处理器。"
            )
        return self.handlers[func_name](node)

    # ==================== Time-series handlers (ts_*) ====================
    # Each is wrapped with @time_series, which this module documents as
    # injecting .over("ts_code") so windows do not cross assets.

    @time_series
    def _handle_ts_mean(self, node: FunctionNode) -> pl.Expr:
        """``ts_mean(expr, window)`` -> ``expr.rolling_mean(window)``."""
        expr, window = self._expr_and_int(
            node, "ts_mean 需要 2 个参数: (expr, window)"
        )
        return expr.rolling_mean(window_size=window)

    @time_series
    def _handle_ts_sum(self, node: FunctionNode) -> pl.Expr:
        """``ts_sum(expr, window)`` -> ``expr.rolling_sum(window)``."""
        expr, window = self._expr_and_int(
            node, "ts_sum 需要 2 个参数: (expr, window)"
        )
        return expr.rolling_sum(window_size=window)

    @time_series
    def _handle_ts_std(self, node: FunctionNode) -> pl.Expr:
        """``ts_std(expr, window)`` -> ``expr.rolling_std(window)``."""
        expr, window = self._expr_and_int(
            node, "ts_std 需要 2 个参数: (expr, window)"
        )
        return expr.rolling_std(window_size=window)

    @time_series
    def _handle_ts_max(self, node: FunctionNode) -> pl.Expr:
        """``ts_max(expr, window)`` -> ``expr.rolling_max(window)``."""
        expr, window = self._expr_and_int(
            node, "ts_max 需要 2 个参数: (expr, window)"
        )
        return expr.rolling_max(window_size=window)

    @time_series
    def _handle_ts_min(self, node: FunctionNode) -> pl.Expr:
        """``ts_min(expr, window)`` -> ``expr.rolling_min(window)``."""
        expr, window = self._expr_and_int(
            node, "ts_min 需要 2 个参数: (expr, window)"
        )
        return expr.rolling_min(window_size=window)

    @time_series
    def _handle_ts_delay(self, node: FunctionNode) -> pl.Expr:
        """``ts_delay(expr, n)`` -> ``expr.shift(n)``."""
        expr, n = self._expr_and_int(node, "ts_delay 需要 2 个参数: (expr, n)")
        return expr.shift(n)

    @time_series
    def _handle_ts_delta(self, node: FunctionNode) -> pl.Expr:
        """``ts_delta(expr, n)`` -> ``expr - expr.shift(n)``."""
        expr, n = self._expr_and_int(node, "ts_delta 需要 2 个参数: (expr, n)")
        return expr - expr.shift(n)

    @time_series
    def _handle_ts_corr(self, node: FunctionNode) -> pl.Expr:
        """``ts_corr(x, y, window)`` -> ``x.rolling_corr(y, window)``."""
        if len(node.args) != 3:
            raise ValueError("ts_corr 需要 3 个参数: (x, y, window)")
        x = self.translate(node.args[0])
        y = self.translate(node.args[1])
        window = self._extract_window(node.args[2])
        return x.rolling_corr(y, window_size=window)

    @time_series
    def _handle_ts_cov(self, node: FunctionNode) -> pl.Expr:
        """``ts_cov(x, y, window)`` -> ``x.rolling_cov(y, window)``."""
        if len(node.args) != 3:
            raise ValueError("ts_cov 需要 3 个参数: (x, y, window)")
        x = self.translate(node.args[0])
        y = self.translate(node.args[1])
        window = self._extract_window(node.args[2])
        return x.rolling_cov(y, window_size=window)

    @time_series
    def _handle_ts_var(self, node: FunctionNode) -> pl.Expr:
        """``ts_var(expr, window)`` -> ``expr.rolling_var(window)``."""
        expr, window = self._expr_and_int(
            node, "ts_var 需要 2 个参数: (expr, window)"
        )
        return expr.rolling_var(window_size=window)

    @time_series
    def _handle_ts_skew(self, node: FunctionNode) -> pl.Expr:
        """``ts_skew(expr, window)`` -> ``expr.rolling_skew(window)``."""
        expr, window = self._expr_and_int(
            node, "ts_skew 需要 2 个参数: (expr, window)"
        )
        return expr.rolling_skew(window_size=window)

    @time_series
    def _handle_ts_kurt(self, node: FunctionNode) -> pl.Expr:
        """``ts_kurt(expr, window)`` -> rolling kurtosis.

        Calls ``Series.kurtosis()`` with its default arguments for every window
        through ``rolling_map``. Windows with fewer than 4 non-null values
        yield NaN.
        """
        expr, window = self._expr_and_int(
            node, "ts_kurt 需要 2 个参数: (expr, window)"
        )
        return expr.rolling_map(
            lambda s: s.kurtosis() if len(s.drop_nulls()) >= 4 else float("nan"),
            window_size=window,
        )

    @time_series
    def _handle_ts_pct_change(self, node: FunctionNode) -> pl.Expr:
        """``ts_pct_change(x, n)`` -> ``(x - x.shift(n)) / x.shift(n)``."""
        expr, n = self._expr_and_int(
            node, "ts_pct_change 需要 2 个参数: (expr, periods)"
        )
        shifted = expr.shift(n)
        return (expr - shifted) / shifted

    @time_series
    def _handle_ts_ema(self, node: FunctionNode) -> pl.Expr:
        """``ts_ema(x, window)`` -> ``x.ewm_mean(span=window)``.

        The other ``ewm_mean`` settings are left at the Polars defaults.
        """
        expr, window = self._expr_and_int(
            node, "ts_ema 需要 2 个参数: (expr, window)"
        )
        return expr.ewm_mean(span=window)

    # ============== TA-Lib based time-series handlers (ts_*) ==============
    # These hand whole batches to TA-Lib through map_batches; @time_series is
    # documented in this module as adding .over("ts_code").
    # NOTE(review): confirm that, for the Polars version in use, map_batches
    # under .over() receives one asset's rows at a time; otherwise the
    # indicator would run across asset boundaries.

    @staticmethod
    def _require_talib(func_name: str) -> None:
        """Raise ``ImportError`` when the optional TA-Lib package is missing."""
        if not HAS_TALIB:
            raise ImportError(f"{func_name} 需要安装 TA-Lib。请运行: pip install TA-Lib")

    @staticmethod
    def _as_float_array(series: pl.Series) -> np.ndarray:
        """Convert ``series`` to a float64 NumPy array; nulls become NaN."""
        return np.array(series.to_list(), dtype=float)

    @staticmethod
    def _empty_float_series() -> pl.Series:
        """Return an empty Float64 Series, matching the indicators' dtype."""
        return pl.Series([], dtype=pl.Float64)

    @time_series
    def _handle_ts_atr(self, node: FunctionNode) -> pl.Expr:
        """``ts_atr(high, low, close, window)`` -> TA-Lib ``ATR``.

        Raises:
            ImportError: If TA-Lib is not installed.
            ValueError: On a wrong argument count or a non-integer window.
        """
        self._require_talib("ts_atr")
        if len(node.args) != 4:
            raise ValueError("ts_atr 需要 4 个参数: (high, low, close, window)")

        high = self.translate(node.args[0])
        low = self.translate(node.args[1])
        close = self.translate(node.args[2])
        window = self._extract_window(node.args[3])

        def calc_atr(struct_series: pl.Series) -> pl.Series:
            """Run TA-Lib ATR on one batch of packed (h, l, c) rows."""
            if len(struct_series) == 0:
                return self._empty_float_series()
            fields = struct_series.struct
            result = talib.ATR(
                self._as_float_array(fields.field("h")),
                self._as_float_array(fields.field("l")),
                self._as_float_array(fields.field("c")),
                timeperiod=window,
            )
            return pl.Series(result)

        # Pack the three inputs into one struct so a single UDF call sees them
        # row-aligned.
        return pl.struct(
            [high.alias("h"), low.alias("l"), close.alias("c")]
        ).map_batches(calc_atr, return_dtype=pl.Float64)

    @time_series
    def _handle_ts_rsi(self, node: FunctionNode) -> pl.Expr:
        """``ts_rsi(close, window)`` -> TA-Lib ``RSI``.

        Raises:
            ImportError: If TA-Lib is not installed.
            ValueError: On a wrong argument count or a non-integer window.
        """
        self._require_talib("ts_rsi")
        if len(node.args) != 2:
            raise ValueError("ts_rsi 需要 2 个参数: (close, window)")

        close = self.translate(node.args[0])
        window = self._extract_window(node.args[1])

        def calc_rsi(series: pl.Series) -> pl.Series:
            """Run TA-Lib RSI on one batch of close prices."""
            # Same empty-batch guard as ts_atr / ts_obv rather than passing a
            # zero-length array to TA-Lib.
            if len(series) == 0:
                return self._empty_float_series()
            result = talib.RSI(self._as_float_array(series), timeperiod=window)
            return pl.Series(result)

        return close.map_batches(calc_rsi, return_dtype=pl.Float64)

    @time_series
    def _handle_ts_obv(self, node: FunctionNode) -> pl.Expr:
        """``ts_obv(close, volume)`` -> TA-Lib ``OBV``.

        Raises:
            ImportError: If TA-Lib is not installed.
            ValueError: On a wrong argument count.
        """
        self._require_talib("ts_obv")
        if len(node.args) != 2:
            raise ValueError("ts_obv 需要 2 个参数: (close, volume)")

        close = self.translate(node.args[0])
        volume = self.translate(node.args[1])

        def calc_obv(struct_series: pl.Series) -> pl.Series:
            """Run TA-Lib OBV on one batch of packed (c, v) rows."""
            if len(struct_series) == 0:
                return self._empty_float_series()
            fields = struct_series.struct
            result = talib.OBV(
                self._as_float_array(fields.field("c")),
                self._as_float_array(fields.field("v")),
            )
            return pl.Series(result)

        return pl.struct([close.alias("c"), volume.alias("v")]).map_batches(
            calc_obv, return_dtype=pl.Float64
        )

    # ==================== Cross-sectional handlers (cs_*) ====================
    # Each is wrapped with @cross_section, which this module documents as
    # injecting .over("trade_date").

    @cross_section
    def _handle_cs_rank(self, node: FunctionNode) -> pl.Expr:
        """``cs_rank(expr)`` -> ``expr.rank() / expr.count()``.

        Uses Polars' default rank method, so results lie in (0, 1].
        """
        expr = self._single_expr(node, "cs_rank 需要 1 个参数: (expr)")
        return expr.rank() / expr.count()

    @cross_section
    def _handle_cs_zscore(self, node: FunctionNode) -> pl.Expr:
        """``cs_zscore(expr)`` -> ``(expr - mean) / std``."""
        expr = self._single_expr(node, "cs_zscore 需要 1 个参数: (expr)")
        return (expr - expr.mean()) / expr.std()

    @cross_section
    def _handle_cs_neutral(self, node: FunctionNode) -> pl.Expr:
        """``cs_neutral(expr[, group])`` -> ``expr - mean``.

        The optional ``group`` argument is accepted but currently ignored.
        The value is demeaned against the whole cross-section, not within
        groups.
        """
        if len(node.args) not in (1, 2):
            raise ValueError("cs_neutral 需要 1-2 个参数: (expr, [group_col])")
        expr = self.translate(node.args[0])
        return expr - expr.mean()

    # ==================== Element-wise math handlers ====================
    # Computed independently per element; no .over() grouping is added.

    @element_wise
    def _handle_log(self, node: FunctionNode) -> pl.Expr:
        """``log(expr)`` -> natural logarithm."""
        return self._single_expr(node, "log 需要 1 个参数: (expr)").log()

    @element_wise
    def _handle_exp(self, node: FunctionNode) -> pl.Expr:
        """``exp(expr)`` -> exponential."""
        return self._single_expr(node, "exp 需要 1 个参数: (expr)").exp()

    @element_wise
    def _handle_sqrt(self, node: FunctionNode) -> pl.Expr:
        """``sqrt(expr)`` -> square root."""
        return self._single_expr(node, "sqrt 需要 1 个参数: (expr)").sqrt()

    @element_wise
    def _handle_sign(self, node: FunctionNode) -> pl.Expr:
        """``sign(expr)`` -> sign of each element."""
        return self._single_expr(node, "sign 需要 1 个参数: (expr)").sign()

    @element_wise
    def _handle_cos(self, node: FunctionNode) -> pl.Expr:
        """``cos(expr)`` -> cosine."""
        return self._single_expr(node, "cos 需要 1 个参数: (expr)").cos()

    @element_wise
    def _handle_sin(self, node: FunctionNode) -> pl.Expr:
        """``sin(expr)`` -> sine."""
        return self._single_expr(node, "sin 需要 1 个参数: (expr)").sin()

    @element_wise
    def _handle_atan(self, node: FunctionNode) -> pl.Expr:
        """``atan(expr)`` -> arctangent."""
        return self._single_expr(node, "atan 需要 1 个参数: (expr)").arctan()

    @element_wise
    def _handle_log1p(self, node: FunctionNode) -> pl.Expr:
        """``log1p(expr)`` -> ``log(1 + x)``."""
        return self._single_expr(node, "log1p 需要 1 个参数: (expr)").log1p()

    # ==================== Helpers ====================

    def _single_expr(self, node: FunctionNode, usage: str) -> pl.Expr:
        """Validate a one-argument call and return its translated argument.

        Raises:
            ValueError: With message ``usage`` if the call has != 1 argument.
        """
        if len(node.args) != 1:
            raise ValueError(usage)
        return self.translate(node.args[0])

    def _expr_and_int(self, node: FunctionNode, usage: str) -> "tuple[pl.Expr, int]":
        """Validate an ``(expr, int)`` call and return both parts.

        Args:
            node: Function call expected to carry exactly two arguments.
            usage: Error message raised when the argument count is wrong.

        Returns:
            The translated first argument and the integer second argument.

        Raises:
            ValueError: On a wrong argument count (``usage``) or when the
                second argument is not a constant integer.
        """
        if len(node.args) != 2:
            raise ValueError(usage)
        return self.translate(node.args[0]), self._extract_window(node.args[1])

    def _extract_window(self, node: Node) -> int:
        """Extract an integer window/period from a constant node.

        Python ``int`` and NumPy integer constants are accepted and returned
        as a plain ``int``. ``bool`` is rejected even though it subclasses
        ``int``; otherwise ``ts_mean(x, True)`` would silently become a
        1-period window.

        Args:
            node: Expected to be a ``Constant`` holding an integer.

        Returns:
            The window size as a plain ``int``.

        Raises:
            ValueError: If ``node`` is not a ``Constant`` or its value is not
                an integer.
        """
        if not isinstance(node, Constant):
            raise ValueError(f"窗口参数必须是常量整数,得到: {type(node).__name__}")
        value = node.value
        if isinstance(value, bool) or not isinstance(value, numbers.Integral):
            raise ValueError(f"窗口参数必须是整数,得到: {type(value).__name__}")
        return int(value)


def translate_to_polars(node: Node) -> pl.Expr:
    """Translate an AST into a Polars expression using a fresh translator.

    Convenience wrapper around ``PolarsTranslator().translate``. A new
    translator is created on every call, so only the built-in handlers are
    available; custom handlers require keeping your own translator.

    Args:
        node: Root node of the expression tree.

    Returns:
        The equivalent Polars expression.

    Example:
        >>> from src.factors.dsl import Symbol, FunctionNode
        >>> close = Symbol("close")
        >>> polars_expr = translate_to_polars(FunctionNode("ts_mean", close, 20))
    """
    return PolarsTranslator().translate(node)


if __name__ == "__main__":
    # Manual smoke test: print the Polars expression produced for a handful
    # of representative DSL inputs. Only the first case is asserted.
    close_sym = Symbol("close")
    volume_sym = Symbol("volume")
    demo = PolarsTranslator()

    print("测试 1: Symbol")
    symbol_expr = demo.translate(close_sym)
    print(f" close -> {symbol_expr}")
    assert str(symbol_expr) == 'col("close")'

    # (section header, printed label, DSL node) for the remaining cases.
    demo_cases = [
        ("\n测试 2: BinaryOp", "close + 10", close_sym + 10),
        (
            "\n测试 3: ts_mean",
            "ts_mean(close, 20)",
            FunctionNode("ts_mean", close_sym, 20),
        ),
        (
            "\n测试 4: cs_rank",
            "cs_rank(close / volume)",
            FunctionNode("cs_rank", close_sym / volume_sym),
        ),
        (
            "\n测试 5: 复杂表达式",
            "cs_rank(ts_mean(close, 20) - ts_mean(close, 60))",
            FunctionNode(
                "cs_rank",
                FunctionNode("ts_mean", close_sym, 20)
                - FunctionNode("ts_mean", close_sym, 60),
            ),
        ),
    ]
    for header, label, dsl_node in demo_cases:
        print(header)
        print(f" {label} -> {demo.translate(dsl_node)}")

    print("\n✅ 所有测试通过!")