Files
ProStock/src/factors/translator.py
liaozhaorun 1520c2a51e feat(factors): 新增 Phase 1-2 数学和统计因子函数
- 新增 atan, log1p 数学函数
- 新增 ts_var, ts_skew, ts_kurt, ts_pct_change, ts_ema 统计函数
- 新增 ts_atr, ts_rsi, ts_obv TA-Lib 技术指标函数
- 新增完整集成测试覆盖所有新函数
2026-03-07 01:03:49 +08:00

633 lines
23 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Polars 翻译器 - 将 AST 翻译为 Polars 表达式。
本模块实现 DSL 到 Polars 计算图的映射,是因子表达式执行的桥梁。
支持时序因子ts_*和截面因子cs_*)的防错分组翻译。
"""
from typing import Any, Callable, Dict
import numpy as np
import polars as pl
# TA-Lib 可选依赖
try:
import talib
HAS_TALIB = True
except ImportError:
HAS_TALIB = False
talib = None
from src.factors.decorators import cross_section, element_wise, time_series
from src.factors.dsl import (
BinaryOpNode,
Constant,
FunctionNode,
Node,
Symbol,
UnaryOpNode,
)
class PolarsTranslator:
"""Polars 表达式翻译器。
将纯对象的 AST 树完美映射为 Polars 的带防错分组的计算图。
Attributes:
handlers: 函数处理器注册表,映射 func_name 到处理函数
Example:
>>> from src.factors.dsl import Symbol, FunctionNode
>>> close = Symbol("close")
>>> expr = FunctionNode("ts_mean", close, 20)
>>> translator = PolarsTranslator()
>>> polars_expr = translator.translate(expr)
>>> # 结果: pl.col("close").rolling_mean(20).over("asset")
"""
def __init__(self) -> None:
"""初始化翻译器并注册内置函数处理器。"""
self.handlers: Dict[str, Callable[[FunctionNode], pl.Expr]] = {}
self._register_builtin_handlers()
def _register_builtin_handlers(self) -> None:
"""注册内置的函数处理器。"""
# 时序因子处理器 (ts_*)
self.register_handler("ts_mean", self._handle_ts_mean)
self.register_handler("ts_sum", self._handle_ts_sum)
self.register_handler("ts_std", self._handle_ts_std)
self.register_handler("ts_max", self._handle_ts_max)
self.register_handler("ts_min", self._handle_ts_min)
self.register_handler("ts_delay", self._handle_ts_delay)
self.register_handler("ts_delta", self._handle_ts_delta)
self.register_handler("ts_corr", self._handle_ts_corr)
self.register_handler("ts_cov", self._handle_ts_cov)
self.register_handler("ts_var", self._handle_ts_var)
self.register_handler("ts_skew", self._handle_ts_skew)
self.register_handler("ts_kurt", self._handle_ts_kurt)
self.register_handler("ts_pct_change", self._handle_ts_pct_change)
self.register_handler("ts_ema", self._handle_ts_ema)
self.register_handler("ts_atr", self._handle_ts_atr)
self.register_handler("ts_rsi", self._handle_ts_rsi)
self.register_handler("ts_obv", self._handle_ts_obv)
# 截面因子处理器 (cs_*)
self.register_handler("cs_rank", self._handle_cs_rank)
self.register_handler("cs_zscore", self._handle_cs_zscore)
self.register_handler("cs_neutral", self._handle_cs_neutral)
# 元素级数学函数 (element_wise)
self.register_handler("log", self._handle_log)
self.register_handler("exp", self._handle_exp)
self.register_handler("sqrt", self._handle_sqrt)
self.register_handler("sign", self._handle_sign)
self.register_handler("cos", self._handle_cos)
self.register_handler("sin", self._handle_sin)
self.register_handler("atan", self._handle_atan)
self.register_handler("log1p", self._handle_log1p)
def register_handler(
self, func_name: str, handler: Callable[[FunctionNode], pl.Expr]
) -> None:
"""注册自定义函数处理器。
Args:
func_name: 函数名称
handler: 处理函数,接收 FunctionNode 返回 pl.Expr
Example:
>>> def handle_custom(node: FunctionNode) -> pl.Expr:
... arg = self.translate(node.args[0])
... return arg * 2
>>> translator.register_handler("custom", handle_custom)
"""
self.handlers[func_name] = handler
def translate(self, node: Node) -> pl.Expr:
"""递归翻译 AST 节点为 Polars 表达式。
Args:
node: AST 节点Symbol、Constant、BinaryOpNode、UnaryOpNode、FunctionNode
Returns:
Polars 表达式对象
Raises:
TypeError: 当遇到未知的节点类型时
"""
if isinstance(node, Symbol):
return self._translate_symbol(node)
elif isinstance(node, Constant):
return self._translate_constant(node)
elif isinstance(node, BinaryOpNode):
return self._translate_binary_op(node)
elif isinstance(node, UnaryOpNode):
return self._translate_unary_op(node)
elif isinstance(node, FunctionNode):
return self._translate_function(node)
else:
raise TypeError(f"未知的节点类型: {type(node).__name__}")
def _translate_symbol(self, node: Symbol) -> pl.Expr:
"""翻译 Symbol 节点为 pl.col() 表达式。
Args:
node: 符号节点
Returns:
pl.col(node.name) 表达式
"""
return pl.col(node.name)
def _translate_constant(self, node: Constant) -> pl.Expr:
"""翻译 Constant 节点为 Polars 字面量。
Args:
node: 常量节点
Returns:
pl.lit(node.value) 表达式
"""
return pl.lit(node.value)
def _translate_binary_op(self, node: BinaryOpNode) -> pl.Expr:
"""翻译 BinaryOpNode 为 Polars 二元运算。
Args:
node: 二元运算节点
Returns:
Polars 二元运算表达式
"""
left = self.translate(node.left)
right = self.translate(node.right)
op_map = {
"+": lambda l, r: l + r,
"-": lambda l, r: l - r,
"*": lambda l, r: l * r,
"/": lambda l, r: l / r,
"**": lambda l, r: l.pow(r),
"//": lambda l, r: l.floor_div(r),
"%": lambda l, r: l % r,
"==": lambda l, r: l.eq(r),
"!=": lambda l, r: l.ne(r),
"<": lambda l, r: l.lt(r),
"<=": lambda l, r: l.le(r),
">": lambda l, r: l.gt(r),
">=": lambda l, r: l.ge(r),
}
if node.op not in op_map:
raise ValueError(f"不支持的二元运算符: {node.op}")
return op_map[node.op](left, right)
def _translate_unary_op(self, node: UnaryOpNode) -> pl.Expr:
"""翻译 UnaryOpNode 为 Polars 一元运算。
Args:
node: 一元运算节点
Returns:
Polars 一元运算表达式
"""
operand = self.translate(node.operand)
op_map = {
"+": lambda x: x,
"-": lambda x: -x,
"abs": lambda x: x.abs(),
}
if node.op not in op_map:
raise ValueError(f"不支持的一元运算符: {node.op}")
return op_map[node.op](operand)
def _translate_function(self, node: FunctionNode) -> pl.Expr:
"""翻译 FunctionNode 为 Polars 函数调用。
优先从 handlers 注册表中查找处理器,未找到则抛出错误。
Args:
node: 函数调用节点
Returns:
Polars 函数表达式
Raises:
ValueError: 当函数名称未注册处理器时
"""
func_name = node.func_name
if func_name in self.handlers:
return self.handlers[func_name](node)
else:
raise ValueError(
f"未注册的函数: {func_name}. 请使用 register_handler 注册处理器。"
)
# ==================== 时序因子处理器 (ts_*) ====================
# 所有时序因子使用 @time_series 装饰器自动注入 over("ts_code") 防串表
@time_series
def _handle_ts_mean(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_mean(close, window) -> rolling_mean(window)。"""
if len(node.args) != 2:
raise ValueError("ts_mean 需要 2 个参数: (expr, window)")
expr = self.translate(node.args[0])
window = self._extract_window(node.args[1])
return expr.rolling_mean(window_size=window)
@time_series
def _handle_ts_sum(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_sum(close, window) -> rolling_sum(window)。"""
if len(node.args) != 2:
raise ValueError("ts_sum 需要 2 个参数: (expr, window)")
expr = self.translate(node.args[0])
window = self._extract_window(node.args[1])
return expr.rolling_sum(window_size=window)
@time_series
def _handle_ts_std(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_std(close, window) -> rolling_std(window)。"""
if len(node.args) != 2:
raise ValueError("ts_std 需要 2 个参数: (expr, window)")
expr = self.translate(node.args[0])
window = self._extract_window(node.args[1])
return expr.rolling_std(window_size=window)
@time_series
def _handle_ts_max(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_max(close, window) -> rolling_max(window)。"""
if len(node.args) != 2:
raise ValueError("ts_max 需要 2 个参数: (expr, window)")
expr = self.translate(node.args[0])
window = self._extract_window(node.args[1])
return expr.rolling_max(window_size=window)
@time_series
def _handle_ts_min(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_min(close, window) -> rolling_min(window)。"""
if len(node.args) != 2:
raise ValueError("ts_min 需要 2 个参数: (expr, window)")
expr = self.translate(node.args[0])
window = self._extract_window(node.args[1])
return expr.rolling_min(window_size=window)
@time_series
def _handle_ts_delay(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_delay(close, n) -> shift(n)。"""
if len(node.args) != 2:
raise ValueError("ts_delay 需要 2 个参数: (expr, n)")
expr = self.translate(node.args[0])
n = self._extract_window(node.args[1])
return expr.shift(n)
@time_series
def _handle_ts_delta(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_delta(close, n) -> (expr - shift(n))。"""
if len(node.args) != 2:
raise ValueError("ts_delta 需要 2 个参数: (expr, n)")
expr = self.translate(node.args[0])
n = self._extract_window(node.args[1])
return expr - expr.shift(n)
@time_series
def _handle_ts_corr(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_corr(x, y, window) -> rolling_corr(y, window)。"""
if len(node.args) != 3:
raise ValueError("ts_corr 需要 3 个参数: (x, y, window)")
x = self.translate(node.args[0])
y = self.translate(node.args[1])
window = self._extract_window(node.args[2])
return x.rolling_corr(y, window_size=window)
@time_series
def _handle_ts_cov(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_cov(x, y, window) -> rolling_cov(y, window)。"""
if len(node.args) != 3:
raise ValueError("ts_cov 需要 3 个参数: (x, y, window)")
x = self.translate(node.args[0])
y = self.translate(node.args[1])
window = self._extract_window(node.args[2])
return x.rolling_cov(y, window_size=window)
@time_series
def _handle_ts_var(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_var(close, window) -> rolling_var(window)。"""
if len(node.args) != 2:
raise ValueError("ts_var 需要 2 个参数: (expr, window)")
expr = self.translate(node.args[0])
window = self._extract_window(node.args[1])
return expr.rolling_var(window_size=window)
@time_series
def _handle_ts_skew(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_skew(close, window) -> rolling_skew(window)。"""
if len(node.args) != 2:
raise ValueError("ts_skew 需要 2 个参数: (expr, window)")
expr = self.translate(node.args[0])
window = self._extract_window(node.args[1])
return expr.rolling_skew(window_size=window)
@time_series
def _handle_ts_kurt(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_kurt(close, window) -> rolling_kurt(window)。"""
if len(node.args) != 2:
raise ValueError("ts_kurt 需要 2 个参数: (expr, window)")
expr = self.translate(node.args[0])
window = self._extract_window(node.args[1])
# 使用 rolling_map 计算峰度
return expr.rolling_map(
lambda s: s.kurtosis() if len(s.drop_nulls()) >= 4 else float("nan"),
window_size=window,
)
@time_series
def _handle_ts_pct_change(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_pct_change(x, n) -> (x - shift(n)) / shift(n)。"""
if len(node.args) != 2:
raise ValueError("ts_pct_change 需要 2 个参数: (expr, periods)")
expr = self.translate(node.args[0])
n = self._extract_window(node.args[1])
shifted = expr.shift(n)
return (expr - shifted) / shifted
@time_series
def _handle_ts_ema(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_ema(x, window) -> ewm_mean(span=window)。"""
if len(node.args) != 2:
raise ValueError("ts_ema 需要 2 个参数: (expr, window)")
expr = self.translate(node.args[0])
window = self._extract_window(node.args[1])
return expr.ewm_mean(span=window)
@time_series
def _handle_ts_atr(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_atr(high, low, close, window) -> 使用 TA-Lib 计算 ATR。
使用 map_batches 在每个分组上应用 TA-Lib ATR 函数。
@time_series 装饰器会自动添加 .over("ts_code")
"""
if not HAS_TALIB:
raise ImportError("ts_atr 需要安装 TA-Lib。请运行: pip install TA-Lib")
if len(node.args) != 4:
raise ValueError("ts_atr 需要 4 个参数: (high, low, close, window)")
high = self.translate(node.args[0])
low = self.translate(node.args[1])
close = self.translate(node.args[2])
window = self._extract_window(node.args[3])
# 使用 map_batches 应用 TA-Lib ATR 到整个分组
def calc_atr(struct_series: pl.Series) -> pl.Series:
"""计算 ATR 的辅助函数。"""
if len(struct_series) == 0:
return pl.Series([float("nan")] * len(struct_series))
# struct_series 包含 h, l, c 三个字段
h = np.array(struct_series.struct.field("h").to_list(), dtype=float)
l = np.array(struct_series.struct.field("l").to_list(), dtype=float)
c = np.array(struct_series.struct.field("c").to_list(), dtype=float)
result = talib.ATR(h, l, c, timeperiod=window)
return pl.Series(result)
return pl.struct(
[high.alias("h"), low.alias("l"), close.alias("c")]
).map_batches(calc_atr)
@time_series
def _handle_ts_rsi(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_rsi(close, window) -> 使用 TA-Lib 计算 RSI。
使用 map_batches 在每个分组上应用 TA-Lib RSI 函数。
@time_series 装饰器会自动添加 .over("ts_code")
"""
if not HAS_TALIB:
raise ImportError("ts_rsi 需要安装 TA-Lib。请运行: pip install TA-Lib")
if len(node.args) != 2:
raise ValueError("ts_rsi 需要 2 个参数: (close, window)")
close = self.translate(node.args[0])
window = self._extract_window(node.args[1])
# 使用 map_batches 应用 TA-Lib RSI 到整个分组
def calc_rsi(series: pl.Series) -> pl.Series:
"""计算 RSI 的辅助函数。"""
values = np.array(series.to_list(), dtype=float)
result = talib.RSI(values, timeperiod=window)
return pl.Series(result)
return close.map_batches(calc_rsi)
@time_series
def _handle_ts_obv(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_obv(close, volume) -> 使用 TA-Lib 计算 OBV。
使用 map_batches 在每个分组上应用 TA-Lib OBV 函数。
@time_series 装饰器会自动添加 .over("ts_code")
"""
if not HAS_TALIB:
raise ImportError("ts_obv 需要安装 TA-Lib。请运行: pip install TA-Lib")
if len(node.args) != 2:
raise ValueError("ts_obv 需要 2 个参数: (close, volume)")
close = self.translate(node.args[0])
volume = self.translate(node.args[1])
# 使用 map_batches 应用 TA-Lib OBV 到整个分组
def calc_obv(struct_series: pl.Series) -> pl.Series:
"""计算 OBV 的辅助函数。"""
if len(struct_series) == 0:
return pl.Series([float("nan")] * len(struct_series))
# struct_series 包含 c 和 v 两个字段
c = np.array(struct_series.struct.field("c").to_list(), dtype=float)
v = np.array(struct_series.struct.field("v").to_list(), dtype=float)
result = talib.OBV(c, v)
return pl.Series(result)
return pl.struct([close.alias("c"), volume.alias("v")]).map_batches(calc_obv)
# ==================== 截面因子处理器 (cs_*) ====================
# 所有截面因子使用 @cross_section 装饰器自动注入 over("trade_date") 防串表
@cross_section
def _handle_cs_rank(self, node: FunctionNode) -> pl.Expr:
"""处理 cs_rank(expr) -> rank()/count()。
将排名归一化到 [0, 1] 区间。
"""
if len(node.args) != 1:
raise ValueError("cs_rank 需要 1 个参数: (expr)")
expr = self.translate(node.args[0])
return expr.rank() / expr.count()
@cross_section
def _handle_cs_zscore(self, node: FunctionNode) -> pl.Expr:
"""处理 cs_zscore(expr) -> (expr - mean())/std()。"""
if len(node.args) != 1:
raise ValueError("cs_zscore 需要 1 个参数: (expr)")
expr = self.translate(node.args[0])
return (expr - expr.mean()) / expr.std()
@cross_section
def _handle_cs_neutral(self, node: FunctionNode) -> pl.Expr:
"""处理 cs_neutral(expr, group) -> 分组中性化。"""
if len(node.args) not in [1, 2]:
raise ValueError("cs_neutral 需要 1-2 个参数: (expr, [group_col])")
expr = self.translate(node.args[0])
# 简单实现:减去截面均值(可在未来扩展为分组中性化)
return expr - expr.mean()
# ==================== 元素级数学函数 (element_wise) ====================
# 这些函数对每个元素独立计算,不添加 over
@element_wise
def _handle_log(self, node: FunctionNode) -> pl.Expr:
"""处理 log(expr) -> 自然对数。"""
if len(node.args) != 1:
raise ValueError("log 需要 1 个参数: (expr)")
expr = self.translate(node.args[0])
return expr.log()
@element_wise
def _handle_exp(self, node: FunctionNode) -> pl.Expr:
"""处理 exp(expr) -> 指数函数。"""
if len(node.args) != 1:
raise ValueError("exp 需要 1 个参数: (expr)")
expr = self.translate(node.args[0])
return expr.exp()
@element_wise
def _handle_sqrt(self, node: FunctionNode) -> pl.Expr:
"""处理 sqrt(expr) -> 平方根。"""
if len(node.args) != 1:
raise ValueError("sqrt 需要 1 个参数: (expr)")
expr = self.translate(node.args[0])
return expr.sqrt()
@element_wise
def _handle_sign(self, node: FunctionNode) -> pl.Expr:
"""处理 sign(expr) -> 符号函数。"""
if len(node.args) != 1:
raise ValueError("sign 需要 1 个参数: (expr)")
expr = self.translate(node.args[0])
return expr.sign()
@element_wise
def _handle_cos(self, node: FunctionNode) -> pl.Expr:
"""处理 cos(expr) -> 余弦函数。"""
if len(node.args) != 1:
raise ValueError("cos 需要 1 个参数: (expr)")
expr = self.translate(node.args[0])
return expr.cos()
@element_wise
def _handle_sin(self, node: FunctionNode) -> pl.Expr:
"""处理 sin(expr) -> 正弦函数。"""
if len(node.args) != 1:
raise ValueError("sin 需要 1 个参数: (expr)")
expr = self.translate(node.args[0])
return expr.sin()
@element_wise
def _handle_atan(self, node: FunctionNode) -> pl.Expr:
"""处理 atan(expr) -> 反正切函数。"""
if len(node.args) != 1:
raise ValueError("atan 需要 1 个参数: (expr)")
expr = self.translate(node.args[0])
return expr.arctan()
@element_wise
def _handle_log1p(self, node: FunctionNode) -> pl.Expr:
"""处理 log1p(expr) -> log(1+x) 函数。"""
if len(node.args) != 1:
raise ValueError("log1p 需要 1 个参数: (expr)")
expr = self.translate(node.args[0])
return expr.log1p()
# ==================== 辅助方法 ====================
def _extract_window(self, node: Node) -> int:
"""从节点中提取窗口大小参数。
Args:
node: 应该是 Constant 节点
Returns:
整数值
Raises:
ValueError: 当节点不是 Constant 或值不是整数时
"""
if isinstance(node, Constant):
if not isinstance(node.value, int):
raise ValueError(
f"窗口参数必须是整数,得到: {type(node.value).__name__}"
)
return node.value
raise ValueError(f"窗口参数必须是常量整数,得到: {type(node).__name__}")
def translate_to_polars(node: Node) -> pl.Expr:
"""便捷函数 - 将 AST 节点翻译为 Polars 表达式。
Args:
node: 表达式树的根节点
Returns:
Polars 表达式对象
Example:
>>> from src.factors.dsl import Symbol, FunctionNode
>>> close = Symbol("close")
>>> expr = FunctionNode("ts_mean", close, 20)
>>> polars_expr = translate_to_polars(expr)
"""
translator = PolarsTranslator()
return translator.translate(node)
if __name__ == "__main__":
# 测试用例
from src.factors.dsl import Symbol, FunctionNode
# 创建符号
close = Symbol("close")
volume = Symbol("volume")
# 测试 1: 简单符号
print("测试 1: Symbol")
translator = PolarsTranslator()
expr1 = translator.translate(close)
print(f" close -> {expr1}")
assert str(expr1) == 'col("close")'
# 测试 2: 二元运算
print("\n测试 2: BinaryOp")
expr2 = translator.translate(close + 10)
print(f" close + 10 -> {expr2}")
# 测试 3: ts_mean
print("\n测试 3: ts_mean")
expr3 = translator.translate(FunctionNode("ts_mean", close, 20))
print(f" ts_mean(close, 20) -> {expr3}")
# 测试 4: cs_rank
print("\n测试 4: cs_rank")
expr4 = translator.translate(FunctionNode("cs_rank", close / volume))
print(f" cs_rank(close / volume) -> {expr4}")
# 测试 5: 复杂表达式
print("\n测试 5: 复杂表达式")
ma20 = FunctionNode("ts_mean", close, 20)
ma60 = FunctionNode("ts_mean", close, 60)
expr5 = translator.translate(FunctionNode("cs_rank", ma20 - ma60))
print(f" cs_rank(ts_mean(close, 20) - ts_mean(close, 60)) -> {expr5}")
print("\n✅ 所有测试通过!")