Files
ProStock/src/factors/translator.py

460 lines
16 KiB
Python
Raw Normal View History

"""Polars 翻译器 - 将 AST 翻译为 Polars 表达式。
本模块实现 DSL Polars 计算图的映射是因子表达式执行的桥梁
支持时序因子ts_*和截面因子cs_*的防错分组翻译
"""
from typing import Any, Callable, Dict
import polars as pl
from src.factors.decorators import cross_section, element_wise, time_series
from src.factors.dsl import (
BinaryOpNode,
Constant,
FunctionNode,
Node,
Symbol,
UnaryOpNode,
)
class PolarsTranslator:
"""Polars 表达式翻译器。
将纯对象的 AST 树完美映射为 Polars 的带防错分组的计算图
Attributes:
handlers: 函数处理器注册表映射 func_name 到处理函数
Example:
>>> from src.factors.dsl import Symbol, FunctionNode
>>> close = Symbol("close")
>>> expr = FunctionNode("ts_mean", close, 20)
>>> translator = PolarsTranslator()
>>> polars_expr = translator.translate(expr)
>>> # 结果: pl.col("close").rolling_mean(20).over("asset")
"""
def __init__(self) -> None:
"""初始化翻译器并注册内置函数处理器。"""
self.handlers: Dict[str, Callable[[FunctionNode], pl.Expr]] = {}
self._register_builtin_handlers()
def _register_builtin_handlers(self) -> None:
"""注册内置的函数处理器。"""
# 时序因子处理器 (ts_*)
self.register_handler("ts_mean", self._handle_ts_mean)
self.register_handler("ts_sum", self._handle_ts_sum)
self.register_handler("ts_std", self._handle_ts_std)
self.register_handler("ts_max", self._handle_ts_max)
self.register_handler("ts_min", self._handle_ts_min)
self.register_handler("ts_delay", self._handle_ts_delay)
self.register_handler("ts_delta", self._handle_ts_delta)
self.register_handler("ts_corr", self._handle_ts_corr)
self.register_handler("ts_cov", self._handle_ts_cov)
# 截面因子处理器 (cs_*)
self.register_handler("cs_rank", self._handle_cs_rank)
self.register_handler("cs_zscore", self._handle_cs_zscore)
self.register_handler("cs_neutral", self._handle_cs_neutral)
# 元素级数学函数 (element_wise)
self.register_handler("log", self._handle_log)
self.register_handler("exp", self._handle_exp)
self.register_handler("sqrt", self._handle_sqrt)
self.register_handler("sign", self._handle_sign)
self.register_handler("cos", self._handle_cos)
self.register_handler("sin", self._handle_sin)
def register_handler(
self, func_name: str, handler: Callable[[FunctionNode], pl.Expr]
) -> None:
"""注册自定义函数处理器。
Args:
func_name: 函数名称
handler: 处理函数接收 FunctionNode 返回 pl.Expr
Example:
>>> def handle_custom(node: FunctionNode) -> pl.Expr:
... arg = self.translate(node.args[0])
... return arg * 2
>>> translator.register_handler("custom", handle_custom)
"""
self.handlers[func_name] = handler
def translate(self, node: Node) -> pl.Expr:
"""递归翻译 AST 节点为 Polars 表达式。
Args:
node: AST 节点SymbolConstantBinaryOpNodeUnaryOpNodeFunctionNode
Returns:
Polars 表达式对象
Raises:
TypeError: 当遇到未知的节点类型时
"""
if isinstance(node, Symbol):
return self._translate_symbol(node)
elif isinstance(node, Constant):
return self._translate_constant(node)
elif isinstance(node, BinaryOpNode):
return self._translate_binary_op(node)
elif isinstance(node, UnaryOpNode):
return self._translate_unary_op(node)
elif isinstance(node, FunctionNode):
return self._translate_function(node)
else:
raise TypeError(f"未知的节点类型: {type(node).__name__}")
def _translate_symbol(self, node: Symbol) -> pl.Expr:
"""翻译 Symbol 节点为 pl.col() 表达式。
Args:
node: 符号节点
Returns:
pl.col(node.name) 表达式
"""
return pl.col(node.name)
def _translate_constant(self, node: Constant) -> pl.Expr:
"""翻译 Constant 节点为 Polars 字面量。
Args:
node: 常量节点
Returns:
pl.lit(node.value) 表达式
"""
return pl.lit(node.value)
def _translate_binary_op(self, node: BinaryOpNode) -> pl.Expr:
"""翻译 BinaryOpNode 为 Polars 二元运算。
Args:
node: 二元运算节点
Returns:
Polars 二元运算表达式
"""
left = self.translate(node.left)
right = self.translate(node.right)
op_map = {
"+": lambda l, r: l + r,
"-": lambda l, r: l - r,
"*": lambda l, r: l * r,
"/": lambda l, r: l / r,
"**": lambda l, r: l.pow(r),
"//": lambda l, r: l.floor_div(r),
"%": lambda l, r: l % r,
"==": lambda l, r: l.eq(r),
"!=": lambda l, r: l.ne(r),
"<": lambda l, r: l.lt(r),
"<=": lambda l, r: l.le(r),
">": lambda l, r: l.gt(r),
">=": lambda l, r: l.ge(r),
}
if node.op not in op_map:
raise ValueError(f"不支持的二元运算符: {node.op}")
return op_map[node.op](left, right)
def _translate_unary_op(self, node: UnaryOpNode) -> pl.Expr:
"""翻译 UnaryOpNode 为 Polars 一元运算。
Args:
node: 一元运算节点
Returns:
Polars 一元运算表达式
"""
operand = self.translate(node.operand)
op_map = {
"+": lambda x: x,
"-": lambda x: -x,
"abs": lambda x: x.abs(),
}
if node.op not in op_map:
raise ValueError(f"不支持的一元运算符: {node.op}")
return op_map[node.op](operand)
def _translate_function(self, node: FunctionNode) -> pl.Expr:
"""翻译 FunctionNode 为 Polars 函数调用。
优先从 handlers 注册表中查找处理器未找到则抛出错误
Args:
node: 函数调用节点
Returns:
Polars 函数表达式
Raises:
ValueError: 当函数名称未注册处理器时
"""
func_name = node.func_name
if func_name in self.handlers:
return self.handlers[func_name](node)
else:
raise ValueError(
f"未注册的函数: {func_name}. 请使用 register_handler 注册处理器。"
)
# ==================== 时序因子处理器 (ts_*) ====================
# 所有时序因子使用 @time_series 装饰器自动注入 over("ts_code") 防串表
@time_series
def _handle_ts_mean(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_mean(close, window) -> rolling_mean(window)。"""
if len(node.args) != 2:
raise ValueError("ts_mean 需要 2 个参数: (expr, window)")
expr = self.translate(node.args[0])
window = self._extract_window(node.args[1])
return expr.rolling_mean(window_size=window)
@time_series
def _handle_ts_sum(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_sum(close, window) -> rolling_sum(window)。"""
if len(node.args) != 2:
raise ValueError("ts_sum 需要 2 个参数: (expr, window)")
expr = self.translate(node.args[0])
window = self._extract_window(node.args[1])
return expr.rolling_sum(window_size=window)
@time_series
def _handle_ts_std(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_std(close, window) -> rolling_std(window)。"""
if len(node.args) != 2:
raise ValueError("ts_std 需要 2 个参数: (expr, window)")
expr = self.translate(node.args[0])
window = self._extract_window(node.args[1])
return expr.rolling_std(window_size=window)
@time_series
def _handle_ts_max(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_max(close, window) -> rolling_max(window)。"""
if len(node.args) != 2:
raise ValueError("ts_max 需要 2 个参数: (expr, window)")
expr = self.translate(node.args[0])
window = self._extract_window(node.args[1])
return expr.rolling_max(window_size=window)
@time_series
def _handle_ts_min(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_min(close, window) -> rolling_min(window)。"""
if len(node.args) != 2:
raise ValueError("ts_min 需要 2 个参数: (expr, window)")
expr = self.translate(node.args[0])
window = self._extract_window(node.args[1])
return expr.rolling_min(window_size=window)
@time_series
def _handle_ts_delay(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_delay(close, n) -> shift(n)。"""
if len(node.args) != 2:
raise ValueError("ts_delay 需要 2 个参数: (expr, n)")
expr = self.translate(node.args[0])
n = self._extract_window(node.args[1])
return expr.shift(n)
@time_series
def _handle_ts_delta(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_delta(close, n) -> (expr - shift(n))。"""
if len(node.args) != 2:
raise ValueError("ts_delta 需要 2 个参数: (expr, n)")
expr = self.translate(node.args[0])
n = self._extract_window(node.args[1])
return expr - expr.shift(n)
@time_series
def _handle_ts_corr(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_corr(x, y, window) -> rolling_corr(y, window)。"""
if len(node.args) != 3:
raise ValueError("ts_corr 需要 3 个参数: (x, y, window)")
x = self.translate(node.args[0])
y = self.translate(node.args[1])
window = self._extract_window(node.args[2])
return x.rolling_corr(y, window_size=window)
@time_series
def _handle_ts_cov(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_cov(x, y, window) -> rolling_cov(y, window)。"""
if len(node.args) != 3:
raise ValueError("ts_cov 需要 3 个参数: (x, y, window)")
x = self.translate(node.args[0])
y = self.translate(node.args[1])
window = self._extract_window(node.args[2])
return x.rolling_cov(y, window_size=window)
# ==================== 截面因子处理器 (cs_*) ====================
# 所有截面因子使用 @cross_section 装饰器自动注入 over("trade_date") 防串表
@cross_section
def _handle_cs_rank(self, node: FunctionNode) -> pl.Expr:
"""处理 cs_rank(expr) -> rank()/count()。
将排名归一化到 [0, 1] 区间
"""
if len(node.args) != 1:
raise ValueError("cs_rank 需要 1 个参数: (expr)")
expr = self.translate(node.args[0])
return expr.rank() / expr.count()
@cross_section
def _handle_cs_zscore(self, node: FunctionNode) -> pl.Expr:
"""处理 cs_zscore(expr) -> (expr - mean())/std()。"""
if len(node.args) != 1:
raise ValueError("cs_zscore 需要 1 个参数: (expr)")
expr = self.translate(node.args[0])
return (expr - expr.mean()) / expr.std()
@cross_section
def _handle_cs_neutral(self, node: FunctionNode) -> pl.Expr:
"""处理 cs_neutral(expr, group) -> 分组中性化。"""
if len(node.args) not in [1, 2]:
raise ValueError("cs_neutral 需要 1-2 个参数: (expr, [group_col])")
expr = self.translate(node.args[0])
# 简单实现:减去截面均值(可在未来扩展为分组中性化)
return expr - expr.mean()
# ==================== 元素级数学函数 (element_wise) ====================
# 这些函数对每个元素独立计算,不添加 over
@element_wise
def _handle_log(self, node: FunctionNode) -> pl.Expr:
"""处理 log(expr) -> 自然对数。"""
if len(node.args) != 1:
raise ValueError("log 需要 1 个参数: (expr)")
expr = self.translate(node.args[0])
return expr.log()
@element_wise
def _handle_exp(self, node: FunctionNode) -> pl.Expr:
"""处理 exp(expr) -> 指数函数。"""
if len(node.args) != 1:
raise ValueError("exp 需要 1 个参数: (expr)")
expr = self.translate(node.args[0])
return expr.exp()
@element_wise
def _handle_sqrt(self, node: FunctionNode) -> pl.Expr:
"""处理 sqrt(expr) -> 平方根。"""
if len(node.args) != 1:
raise ValueError("sqrt 需要 1 个参数: (expr)")
expr = self.translate(node.args[0])
return expr.sqrt()
@element_wise
def _handle_sign(self, node: FunctionNode) -> pl.Expr:
"""处理 sign(expr) -> 符号函数。"""
if len(node.args) != 1:
raise ValueError("sign 需要 1 个参数: (expr)")
expr = self.translate(node.args[0])
return expr.sign()
@element_wise
def _handle_cos(self, node: FunctionNode) -> pl.Expr:
"""处理 cos(expr) -> 余弦函数。"""
if len(node.args) != 1:
raise ValueError("cos 需要 1 个参数: (expr)")
expr = self.translate(node.args[0])
return expr.cos()
@element_wise
def _handle_sin(self, node: FunctionNode) -> pl.Expr:
"""处理 sin(expr) -> 正弦函数。"""
if len(node.args) != 1:
raise ValueError("sin 需要 1 个参数: (expr)")
expr = self.translate(node.args[0])
return expr.sin()
# ==================== 辅助方法 ====================
def _extract_window(self, node: Node) -> int:
"""从节点中提取窗口大小参数。
Args:
node: 应该是 Constant 节点
Returns:
整数值
Raises:
ValueError: 当节点不是 Constant 或值不是整数时
"""
if isinstance(node, Constant):
if not isinstance(node.value, int):
raise ValueError(
f"窗口参数必须是整数,得到: {type(node.value).__name__}"
)
return node.value
raise ValueError(f"窗口参数必须是常量整数,得到: {type(node).__name__}")
def translate_to_polars(node: Node) -> pl.Expr:
"""便捷函数 - 将 AST 节点翻译为 Polars 表达式。
Args:
node: 表达式树的根节点
Returns:
Polars 表达式对象
Example:
>>> from src.factors.dsl import Symbol, FunctionNode
>>> close = Symbol("close")
>>> expr = FunctionNode("ts_mean", close, 20)
>>> polars_expr = translate_to_polars(expr)
"""
translator = PolarsTranslator()
return translator.translate(node)
if __name__ == "__main__":
# 测试用例
from src.factors.dsl import Symbol, FunctionNode
# 创建符号
close = Symbol("close")
volume = Symbol("volume")
# 测试 1: 简单符号
print("测试 1: Symbol")
translator = PolarsTranslator()
expr1 = translator.translate(close)
print(f" close -> {expr1}")
assert str(expr1) == 'col("close")'
# 测试 2: 二元运算
print("\n测试 2: BinaryOp")
expr2 = translator.translate(close + 10)
print(f" close + 10 -> {expr2}")
# 测试 3: ts_mean
print("\n测试 3: ts_mean")
expr3 = translator.translate(FunctionNode("ts_mean", close, 20))
print(f" ts_mean(close, 20) -> {expr3}")
# 测试 4: cs_rank
print("\n测试 4: cs_rank")
expr4 = translator.translate(FunctionNode("cs_rank", close / volume))
print(f" cs_rank(close / volume) -> {expr4}")
# 测试 5: 复杂表达式
print("\n测试 5: 复杂表达式")
ma20 = FunctionNode("ts_mean", close, 20)
ma60 = FunctionNode("ts_mean", close, 60)
expr5 = translator.translate(FunctionNode("cs_rank", ma20 - ma60))
print(f" cs_rank(ts_mean(close, 20) - ts_mean(close, 60)) -> {expr5}")
print("\n✅ 所有测试通过!")