- 添加因子表达式文档,收录180+个因子及数学表达式 - 添加因子实现分析报告,明确ts_*与cs_*算子分类 - 实现装饰器系统:@time_series/@cross_section/@element_wise - 优化API和翻译器以支持新架构
460 lines
16 KiB
Python
460 lines
16 KiB
Python
"""Polars 翻译器 - 将 AST 翻译为 Polars 表达式。
|
||
|
||
本模块实现 DSL 到 Polars 计算图的映射,是因子表达式执行的桥梁。
|
||
支持时序因子(ts_*)和截面因子(cs_*)的防错分组翻译。
|
||
"""
|
||
|
||
from typing import Any, Callable, Dict
|
||
|
||
import polars as pl
|
||
|
||
from src.factors.decorators import cross_section, element_wise, time_series
|
||
from src.factors.dsl import (
|
||
BinaryOpNode,
|
||
Constant,
|
||
FunctionNode,
|
||
Node,
|
||
Symbol,
|
||
UnaryOpNode,
|
||
)
|
||
|
||
|
||
class PolarsTranslator:
    """Translate DSL AST nodes into Polars expressions.

    Maps the pure-object AST onto a Polars computation graph. Time-series
    (``ts_*``) handlers are wrapped with ``@time_series`` and cross-sectional
    (``cs_*``) handlers with ``@cross_section``. Per the section comments
    below, these decorators append the grouping ``over(...)``
    (``"ts_code"`` / ``"trade_date"`` respectively). That keeps rolling
    windows from crossing assets and keeps ranks from mixing dates.

    Attributes:
        handlers: Registry mapping a DSL function name to a callable that
            receives the ``FunctionNode`` and returns a ``pl.Expr``.

    Example:
        >>> from src.factors.dsl import Symbol, FunctionNode
        >>> close = Symbol("close")
        >>> expr = FunctionNode("ts_mean", close, 20)
        >>> translator = PolarsTranslator()
        >>> polars_expr = translator.translate(expr)
        >>> # Roughly: pl.col("close").rolling_mean(20).over("ts_code")
    """

    # Operator dispatch tables, built once at class creation instead of on
    # every call. Both operands are always ``pl.Expr`` (see ``translate``).
    _BINARY_OPS: Dict[str, Callable[[pl.Expr, pl.Expr], pl.Expr]] = {
        "+": lambda a, b: a + b,
        "-": lambda a, b: a - b,
        "*": lambda a, b: a * b,
        "/": lambda a, b: a / b,
        "**": lambda a, b: a.pow(b),
        # Polars spells this method ``Expr.floordiv``; ``floor_div`` does not
        # exist and used to raise AttributeError.
        "//": lambda a, b: a.floordiv(b),
        "%": lambda a, b: a % b,
        "==": lambda a, b: a.eq(b),
        "!=": lambda a, b: a.ne(b),
        "<": lambda a, b: a.lt(b),
        "<=": lambda a, b: a.le(b),
        ">": lambda a, b: a.gt(b),
        ">=": lambda a, b: a.ge(b),
    }

    _UNARY_OPS: Dict[str, Callable[[pl.Expr], pl.Expr]] = {
        "+": lambda x: x,
        "-": lambda x: -x,
        "abs": lambda x: x.abs(),
    }

    def __init__(self) -> None:
        """Create a translator with all built-in handlers registered."""
        self.handlers: Dict[str, Callable[[FunctionNode], pl.Expr]] = {}
        self._register_builtin_handlers()

    def _register_builtin_handlers(self) -> None:
        """Register the built-in handlers; each name ``n`` maps to ``_handle_n``."""
        builtin_names = (
            # Time-series operators (ts_*)
            "ts_mean",
            "ts_sum",
            "ts_std",
            "ts_max",
            "ts_min",
            "ts_delay",
            "ts_delta",
            "ts_corr",
            "ts_cov",
            # Cross-sectional operators (cs_*)
            "cs_rank",
            "cs_zscore",
            "cs_neutral",
            # Element-wise math functions
            "log",
            "exp",
            "sqrt",
            "sign",
            "cos",
            "sin",
        )
        for name in builtin_names:
            self.register_handler(name, getattr(self, f"_handle_{name}"))

    def register_handler(
        self, func_name: str, handler: Callable[[FunctionNode], pl.Expr]
    ) -> None:
        """Register (or replace) the handler for a DSL function name.

        Args:
            func_name: Name of the DSL function as it appears in ``FunctionNode``.
            handler: Callable that receives the ``FunctionNode`` and returns a
                ``pl.Expr``. An existing handler under the same name is
                overwritten.

        Example:
            >>> translator = PolarsTranslator()
            >>> def handle_custom(node: FunctionNode) -> pl.Expr:
            ...     return translator.translate(node.args[0]) * 2
            >>> translator.register_handler("custom", handle_custom)
        """
        self.handlers[func_name] = handler

    def translate(self, node: Node) -> pl.Expr:
        """Recursively translate an AST node into a Polars expression.

        Args:
            node: A ``Symbol``, ``Constant``, ``BinaryOpNode``, ``UnaryOpNode``
                or ``FunctionNode``.

        Returns:
            The equivalent Polars expression.

        Raises:
            TypeError: If the node type is not one of the above.
        """
        if isinstance(node, Symbol):
            return self._translate_symbol(node)
        if isinstance(node, Constant):
            return self._translate_constant(node)
        if isinstance(node, BinaryOpNode):
            return self._translate_binary_op(node)
        if isinstance(node, UnaryOpNode):
            return self._translate_unary_op(node)
        if isinstance(node, FunctionNode):
            return self._translate_function(node)
        raise TypeError(f"未知的节点类型: {type(node).__name__}")

    def _translate_symbol(self, node: Symbol) -> pl.Expr:
        """Translate a ``Symbol`` into ``pl.col(node.name)``."""
        return pl.col(node.name)

    def _translate_constant(self, node: Constant) -> pl.Expr:
        """Translate a ``Constant`` into ``pl.lit(node.value)``."""
        return pl.lit(node.value)

    def _translate_binary_op(self, node: BinaryOpNode) -> pl.Expr:
        """Translate a ``BinaryOpNode`` via the ``_BINARY_OPS`` table.

        Raises:
            ValueError: If ``node.op`` is not a supported binary operator.
        """
        apply_op = self._BINARY_OPS.get(node.op)
        if apply_op is None:
            raise ValueError(f"不支持的二元运算符: {node.op}")
        return apply_op(self.translate(node.left), self.translate(node.right))

    def _translate_unary_op(self, node: UnaryOpNode) -> pl.Expr:
        """Translate a ``UnaryOpNode`` via the ``_UNARY_OPS`` table.

        Raises:
            ValueError: If ``node.op`` is not a supported unary operator.
        """
        apply_op = self._UNARY_OPS.get(node.op)
        if apply_op is None:
            raise ValueError(f"不支持的一元运算符: {node.op}")
        return apply_op(self.translate(node.operand))

    def _translate_function(self, node: FunctionNode) -> pl.Expr:
        """Dispatch a ``FunctionNode`` to its registered handler.

        Raises:
            ValueError: If no handler is registered for ``node.func_name``.
        """
        handler = self.handlers.get(node.func_name)
        if handler is None:
            raise ValueError(
                f"未注册的函数: {node.func_name}. 请使用 register_handler 注册处理器。"
            )
        return handler(node)

    # ==================== Time-series handlers (ts_*) ====================
    # Every ts_* handler is wrapped with @time_series, which is expected to
    # append over("ts_code") so windows never span two assets.

    @time_series
    def _handle_ts_mean(self, node: FunctionNode) -> pl.Expr:
        """ts_mean(expr, window) -> ``expr.rolling_mean(window_size=window)``."""
        self._check_arity(node, "ts_mean", 2, "(expr, window)")
        expr = self.translate(node.args[0])
        window = self._extract_window(node.args[1])
        return expr.rolling_mean(window_size=window)

    @time_series
    def _handle_ts_sum(self, node: FunctionNode) -> pl.Expr:
        """ts_sum(expr, window) -> ``expr.rolling_sum(window_size=window)``."""
        self._check_arity(node, "ts_sum", 2, "(expr, window)")
        expr = self.translate(node.args[0])
        window = self._extract_window(node.args[1])
        return expr.rolling_sum(window_size=window)

    @time_series
    def _handle_ts_std(self, node: FunctionNode) -> pl.Expr:
        """ts_std(expr, window) -> ``expr.rolling_std(window_size=window)``."""
        self._check_arity(node, "ts_std", 2, "(expr, window)")
        expr = self.translate(node.args[0])
        window = self._extract_window(node.args[1])
        return expr.rolling_std(window_size=window)

    @time_series
    def _handle_ts_max(self, node: FunctionNode) -> pl.Expr:
        """ts_max(expr, window) -> ``expr.rolling_max(window_size=window)``."""
        self._check_arity(node, "ts_max", 2, "(expr, window)")
        expr = self.translate(node.args[0])
        window = self._extract_window(node.args[1])
        return expr.rolling_max(window_size=window)

    @time_series
    def _handle_ts_min(self, node: FunctionNode) -> pl.Expr:
        """ts_min(expr, window) -> ``expr.rolling_min(window_size=window)``."""
        self._check_arity(node, "ts_min", 2, "(expr, window)")
        expr = self.translate(node.args[0])
        window = self._extract_window(node.args[1])
        return expr.rolling_min(window_size=window)

    @time_series
    def _handle_ts_delay(self, node: FunctionNode) -> pl.Expr:
        """ts_delay(expr, n) -> ``expr.shift(n)`` (value from n rows earlier)."""
        self._check_arity(node, "ts_delay", 2, "(expr, n)")
        expr = self.translate(node.args[0])
        n = self._extract_window(node.args[1])
        return expr.shift(n)

    @time_series
    def _handle_ts_delta(self, node: FunctionNode) -> pl.Expr:
        """ts_delta(expr, n) -> ``expr - expr.shift(n)``."""
        self._check_arity(node, "ts_delta", 2, "(expr, n)")
        expr = self.translate(node.args[0])
        n = self._extract_window(node.args[1])
        return expr - expr.shift(n)

    @time_series
    def _handle_ts_corr(self, node: FunctionNode) -> pl.Expr:
        """ts_corr(x, y, window) -> rolling correlation of ``x`` and ``y``."""
        self._check_arity(node, "ts_corr", 3, "(x, y, window)")
        x = self.translate(node.args[0])
        y = self.translate(node.args[1])
        window = self._extract_window(node.args[2])
        # pl.Expr has no ``rolling_corr`` method; Polars provides it as the
        # top-level function ``pl.rolling_corr(a, b, window_size=...)``.
        return pl.rolling_corr(x, y, window_size=window)

    @time_series
    def _handle_ts_cov(self, node: FunctionNode) -> pl.Expr:
        """ts_cov(x, y, window) -> rolling covariance of ``x`` and ``y``."""
        self._check_arity(node, "ts_cov", 3, "(x, y, window)")
        x = self.translate(node.args[0])
        y = self.translate(node.args[1])
        window = self._extract_window(node.args[2])
        # Same as ts_corr: the rolling covariance is ``pl.rolling_cov``.
        return pl.rolling_cov(x, y, window_size=window)

    # ==================== Cross-sectional handlers (cs_*) ====================
    # Every cs_* handler is wrapped with @cross_section, which is expected to
    # append over("trade_date") so statistics are computed per date.
    # NOTE(review): cs_* applied to a ts_* result nests one window expression
    # inside another; confirm the installed Polars version accepts this, or
    # materialise the ts_* result as a column first.

    @cross_section
    def _handle_cs_rank(self, node: FunctionNode) -> pl.Expr:
        """cs_rank(expr) -> ``expr.rank() / expr.count()``.

        Ranks are scaled by the non-null count, so the largest value maps to
        1.0 and every non-null result lies in (0, 1].
        """
        self._check_arity(node, "cs_rank", 1, "(expr)")
        expr = self.translate(node.args[0])
        return expr.rank() / expr.count()

    @cross_section
    def _handle_cs_zscore(self, node: FunctionNode) -> pl.Expr:
        """cs_zscore(expr) -> ``(expr - expr.mean()) / expr.std()``.

        A constant cross-section has zero standard deviation, so the division
        yields non-finite values there.
        """
        self._check_arity(node, "cs_zscore", 1, "(expr)")
        expr = self.translate(node.args[0])
        return (expr - expr.mean()) / expr.std()

    @cross_section
    def _handle_cs_neutral(self, node: FunctionNode) -> pl.Expr:
        """cs_neutral(expr[, group]) -> ``expr - expr.mean()``.

        The optional ``group`` argument is arity-checked but currently
        ignored; the result is always demeaned over the whole cross-section.
        """
        if len(node.args) not in (1, 2):
            raise ValueError("cs_neutral 需要 1-2 个参数: (expr, [group_col])")
        expr = self.translate(node.args[0])
        # TODO: implement per-group neutralisation using node.args[1].
        return expr - expr.mean()

    # ==================== Element-wise math (element_wise) ====================
    # These operate on each value independently and add no over(...) clause.

    @element_wise
    def _handle_log(self, node: FunctionNode) -> pl.Expr:
        """log(expr) -> natural logarithm."""
        self._check_arity(node, "log", 1, "(expr)")
        return self.translate(node.args[0]).log()

    @element_wise
    def _handle_exp(self, node: FunctionNode) -> pl.Expr:
        """exp(expr) -> exponential."""
        self._check_arity(node, "exp", 1, "(expr)")
        return self.translate(node.args[0]).exp()

    @element_wise
    def _handle_sqrt(self, node: FunctionNode) -> pl.Expr:
        """sqrt(expr) -> square root."""
        self._check_arity(node, "sqrt", 1, "(expr)")
        return self.translate(node.args[0]).sqrt()

    @element_wise
    def _handle_sign(self, node: FunctionNode) -> pl.Expr:
        """sign(expr) -> sign function."""
        self._check_arity(node, "sign", 1, "(expr)")
        return self.translate(node.args[0]).sign()

    @element_wise
    def _handle_cos(self, node: FunctionNode) -> pl.Expr:
        """cos(expr) -> cosine."""
        self._check_arity(node, "cos", 1, "(expr)")
        return self.translate(node.args[0]).cos()

    @element_wise
    def _handle_sin(self, node: FunctionNode) -> pl.Expr:
        """sin(expr) -> sine."""
        self._check_arity(node, "sin", 1, "(expr)")
        return self.translate(node.args[0]).sin()

    # ==================== Helpers ====================

    @staticmethod
    def _check_arity(
        node: FunctionNode, name: str, count: int, signature: str
    ) -> None:
        """Raise ``ValueError`` unless ``node`` has exactly ``count`` arguments.

        Args:
            node: The function call being translated.
            name: DSL function name used in the error message.
            count: Required number of arguments.
            signature: Human-readable parameter list for the error message.
        """
        if len(node.args) != count:
            raise ValueError(f"{name} 需要 {count} 个参数: {signature}")

    def _extract_window(self, node: Node) -> int:
        """Extract a positive integer window / lag from a constant node.

        Args:
            node: Expected to be a ``Constant`` holding an ``int``.

        Returns:
            The integer value (always >= 1).

        Raises:
            ValueError: If ``node`` is not a ``Constant``, its value is not a
                plain ``int`` (``bool`` is rejected), or the value is < 1.
        """
        if not isinstance(node, Constant):
            raise ValueError(f"窗口参数必须是常量整数,得到: {type(node).__name__}")
        value = node.value
        # bool subclasses int, so True/False would silently become 1/0.
        if isinstance(value, bool) or not isinstance(value, int):
            raise ValueError(
                f"窗口参数必须是整数,得到: {type(value).__name__}"
            )
        # Rolling windows must be >= 1. A negative lag in ts_delay/ts_delta
        # would shift future rows backwards (look-ahead bias).
        if value < 1:
            raise ValueError(f"窗口参数必须是正整数,得到: {value}")
        return value
||
|
||
|
||
def translate_to_polars(node: Node) -> pl.Expr:
    """Translate an AST expression tree into a Polars expression in one call.

    A fresh ``PolarsTranslator`` with only the built-in handlers is created
    on every call. Build and reuse a translator yourself if you need custom
    handlers or translate many expressions.

    Args:
        node: Root node of the expression tree.

    Returns:
        The equivalent Polars expression.

    Example:
        >>> from src.factors.dsl import Symbol, FunctionNode
        >>> close = Symbol("close")
        >>> expr = FunctionNode("ts_mean", close, 20)
        >>> polars_expr = translate_to_polars(expr)
    """
    return PolarsTranslator().translate(node)
|
||
|
||
|
||
if __name__ == "__main__":
    # Smoke tests: translate a few representative expressions and print them.
    # Symbol and FunctionNode are already imported at module level, so they
    # are not re-imported here.
    close = Symbol("close")
    volume = Symbol("volume")
    translator = PolarsTranslator()

    # Test 1: a bare symbol becomes pl.col(...)
    print("测试 1: Symbol")
    expr1 = translator.translate(close)
    print(f" close -> {expr1}")
    assert str(expr1) == 'col("close")'

    # Test 2: binary operation with a constant
    print("\n测试 2: BinaryOp")
    expr2 = translator.translate(close + 10)
    print(f" close + 10 -> {expr2}")

    # Test 3: time-series operator
    print("\n测试 3: ts_mean")
    expr3 = translator.translate(FunctionNode("ts_mean", close, 20))
    print(f" ts_mean(close, 20) -> {expr3}")

    # Test 4: cross-sectional operator
    print("\n测试 4: cs_rank")
    expr4 = translator.translate(FunctionNode("cs_rank", close / volume))
    print(f" cs_rank(close / volume) -> {expr4}")

    # Test 5: cross-sectional rank of a difference of time-series means
    print("\n测试 5: 复杂表达式")
    ma20 = FunctionNode("ts_mean", close, 20)
    ma60 = FunctionNode("ts_mean", close, 60)
    expr5 = translator.translate(FunctionNode("cs_rank", ma20 - ma60))
    print(f" cs_rank(ts_mean(close, 20) - ts_mean(close, 60)) -> {expr5}")

    print("\n✅ 所有测试通过!")
|