Files
ProStock/src/factors/translator.py
liaozhaorun 62a4635a71 feat: 新增因子装饰器系统和完整因子文档
- 添加因子表达式文档,收录180+个因子及数学表达式
- 添加因子实现分析报告,明确ts_*与cs_*算子分类
- 实现装饰器系统:@time_series/@cross_section/@element_wise
- 优化API和翻译器以支持新架构
2026-03-06 23:59:39 +08:00

460 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Polars 翻译器 - 将 AST 翻译为 Polars 表达式。
本模块实现 DSL 到 Polars 计算图的映射,是因子表达式执行的桥梁。
支持时序因子ts_*和截面因子cs_*)的防错分组翻译。
"""
from typing import Any, Callable, Dict
import polars as pl
from src.factors.decorators import cross_section, element_wise, time_series
from src.factors.dsl import (
BinaryOpNode,
Constant,
FunctionNode,
Node,
Symbol,
UnaryOpNode,
)
class PolarsTranslator:
"""Polars 表达式翻译器。
将纯对象的 AST 树完美映射为 Polars 的带防错分组的计算图。
Attributes:
handlers: 函数处理器注册表,映射 func_name 到处理函数
Example:
>>> from src.factors.dsl import Symbol, FunctionNode
>>> close = Symbol("close")
>>> expr = FunctionNode("ts_mean", close, 20)
>>> translator = PolarsTranslator()
>>> polars_expr = translator.translate(expr)
>>> # 结果: pl.col("close").rolling_mean(20).over("asset")
"""
def __init__(self) -> None:
"""初始化翻译器并注册内置函数处理器。"""
self.handlers: Dict[str, Callable[[FunctionNode], pl.Expr]] = {}
self._register_builtin_handlers()
def _register_builtin_handlers(self) -> None:
"""注册内置的函数处理器。"""
# 时序因子处理器 (ts_*)
self.register_handler("ts_mean", self._handle_ts_mean)
self.register_handler("ts_sum", self._handle_ts_sum)
self.register_handler("ts_std", self._handle_ts_std)
self.register_handler("ts_max", self._handle_ts_max)
self.register_handler("ts_min", self._handle_ts_min)
self.register_handler("ts_delay", self._handle_ts_delay)
self.register_handler("ts_delta", self._handle_ts_delta)
self.register_handler("ts_corr", self._handle_ts_corr)
self.register_handler("ts_cov", self._handle_ts_cov)
# 截面因子处理器 (cs_*)
self.register_handler("cs_rank", self._handle_cs_rank)
self.register_handler("cs_zscore", self._handle_cs_zscore)
self.register_handler("cs_neutral", self._handle_cs_neutral)
# 元素级数学函数 (element_wise)
self.register_handler("log", self._handle_log)
self.register_handler("exp", self._handle_exp)
self.register_handler("sqrt", self._handle_sqrt)
self.register_handler("sign", self._handle_sign)
self.register_handler("cos", self._handle_cos)
self.register_handler("sin", self._handle_sin)
def register_handler(
self, func_name: str, handler: Callable[[FunctionNode], pl.Expr]
) -> None:
"""注册自定义函数处理器。
Args:
func_name: 函数名称
handler: 处理函数,接收 FunctionNode 返回 pl.Expr
Example:
>>> def handle_custom(node: FunctionNode) -> pl.Expr:
... arg = self.translate(node.args[0])
... return arg * 2
>>> translator.register_handler("custom", handle_custom)
"""
self.handlers[func_name] = handler
def translate(self, node: Node) -> pl.Expr:
"""递归翻译 AST 节点为 Polars 表达式。
Args:
node: AST 节点Symbol、Constant、BinaryOpNode、UnaryOpNode、FunctionNode
Returns:
Polars 表达式对象
Raises:
TypeError: 当遇到未知的节点类型时
"""
if isinstance(node, Symbol):
return self._translate_symbol(node)
elif isinstance(node, Constant):
return self._translate_constant(node)
elif isinstance(node, BinaryOpNode):
return self._translate_binary_op(node)
elif isinstance(node, UnaryOpNode):
return self._translate_unary_op(node)
elif isinstance(node, FunctionNode):
return self._translate_function(node)
else:
raise TypeError(f"未知的节点类型: {type(node).__name__}")
def _translate_symbol(self, node: Symbol) -> pl.Expr:
"""翻译 Symbol 节点为 pl.col() 表达式。
Args:
node: 符号节点
Returns:
pl.col(node.name) 表达式
"""
return pl.col(node.name)
def _translate_constant(self, node: Constant) -> pl.Expr:
"""翻译 Constant 节点为 Polars 字面量。
Args:
node: 常量节点
Returns:
pl.lit(node.value) 表达式
"""
return pl.lit(node.value)
def _translate_binary_op(self, node: BinaryOpNode) -> pl.Expr:
"""翻译 BinaryOpNode 为 Polars 二元运算。
Args:
node: 二元运算节点
Returns:
Polars 二元运算表达式
"""
left = self.translate(node.left)
right = self.translate(node.right)
op_map = {
"+": lambda l, r: l + r,
"-": lambda l, r: l - r,
"*": lambda l, r: l * r,
"/": lambda l, r: l / r,
"**": lambda l, r: l.pow(r),
"//": lambda l, r: l.floor_div(r),
"%": lambda l, r: l % r,
"==": lambda l, r: l.eq(r),
"!=": lambda l, r: l.ne(r),
"<": lambda l, r: l.lt(r),
"<=": lambda l, r: l.le(r),
">": lambda l, r: l.gt(r),
">=": lambda l, r: l.ge(r),
}
if node.op not in op_map:
raise ValueError(f"不支持的二元运算符: {node.op}")
return op_map[node.op](left, right)
def _translate_unary_op(self, node: UnaryOpNode) -> pl.Expr:
"""翻译 UnaryOpNode 为 Polars 一元运算。
Args:
node: 一元运算节点
Returns:
Polars 一元运算表达式
"""
operand = self.translate(node.operand)
op_map = {
"+": lambda x: x,
"-": lambda x: -x,
"abs": lambda x: x.abs(),
}
if node.op not in op_map:
raise ValueError(f"不支持的一元运算符: {node.op}")
return op_map[node.op](operand)
def _translate_function(self, node: FunctionNode) -> pl.Expr:
"""翻译 FunctionNode 为 Polars 函数调用。
优先从 handlers 注册表中查找处理器,未找到则抛出错误。
Args:
node: 函数调用节点
Returns:
Polars 函数表达式
Raises:
ValueError: 当函数名称未注册处理器时
"""
func_name = node.func_name
if func_name in self.handlers:
return self.handlers[func_name](node)
else:
raise ValueError(
f"未注册的函数: {func_name}. 请使用 register_handler 注册处理器。"
)
# ==================== 时序因子处理器 (ts_*) ====================
# 所有时序因子使用 @time_series 装饰器自动注入 over("ts_code") 防串表
@time_series
def _handle_ts_mean(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_mean(close, window) -> rolling_mean(window)。"""
if len(node.args) != 2:
raise ValueError("ts_mean 需要 2 个参数: (expr, window)")
expr = self.translate(node.args[0])
window = self._extract_window(node.args[1])
return expr.rolling_mean(window_size=window)
@time_series
def _handle_ts_sum(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_sum(close, window) -> rolling_sum(window)。"""
if len(node.args) != 2:
raise ValueError("ts_sum 需要 2 个参数: (expr, window)")
expr = self.translate(node.args[0])
window = self._extract_window(node.args[1])
return expr.rolling_sum(window_size=window)
@time_series
def _handle_ts_std(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_std(close, window) -> rolling_std(window)。"""
if len(node.args) != 2:
raise ValueError("ts_std 需要 2 个参数: (expr, window)")
expr = self.translate(node.args[0])
window = self._extract_window(node.args[1])
return expr.rolling_std(window_size=window)
@time_series
def _handle_ts_max(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_max(close, window) -> rolling_max(window)。"""
if len(node.args) != 2:
raise ValueError("ts_max 需要 2 个参数: (expr, window)")
expr = self.translate(node.args[0])
window = self._extract_window(node.args[1])
return expr.rolling_max(window_size=window)
@time_series
def _handle_ts_min(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_min(close, window) -> rolling_min(window)。"""
if len(node.args) != 2:
raise ValueError("ts_min 需要 2 个参数: (expr, window)")
expr = self.translate(node.args[0])
window = self._extract_window(node.args[1])
return expr.rolling_min(window_size=window)
@time_series
def _handle_ts_delay(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_delay(close, n) -> shift(n)。"""
if len(node.args) != 2:
raise ValueError("ts_delay 需要 2 个参数: (expr, n)")
expr = self.translate(node.args[0])
n = self._extract_window(node.args[1])
return expr.shift(n)
@time_series
def _handle_ts_delta(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_delta(close, n) -> (expr - shift(n))。"""
if len(node.args) != 2:
raise ValueError("ts_delta 需要 2 个参数: (expr, n)")
expr = self.translate(node.args[0])
n = self._extract_window(node.args[1])
return expr - expr.shift(n)
@time_series
def _handle_ts_corr(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_corr(x, y, window) -> rolling_corr(y, window)。"""
if len(node.args) != 3:
raise ValueError("ts_corr 需要 3 个参数: (x, y, window)")
x = self.translate(node.args[0])
y = self.translate(node.args[1])
window = self._extract_window(node.args[2])
return x.rolling_corr(y, window_size=window)
@time_series
def _handle_ts_cov(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_cov(x, y, window) -> rolling_cov(y, window)。"""
if len(node.args) != 3:
raise ValueError("ts_cov 需要 3 个参数: (x, y, window)")
x = self.translate(node.args[0])
y = self.translate(node.args[1])
window = self._extract_window(node.args[2])
return x.rolling_cov(y, window_size=window)
# ==================== 截面因子处理器 (cs_*) ====================
# 所有截面因子使用 @cross_section 装饰器自动注入 over("trade_date") 防串表
@cross_section
def _handle_cs_rank(self, node: FunctionNode) -> pl.Expr:
"""处理 cs_rank(expr) -> rank()/count()。
将排名归一化到 [0, 1] 区间。
"""
if len(node.args) != 1:
raise ValueError("cs_rank 需要 1 个参数: (expr)")
expr = self.translate(node.args[0])
return expr.rank() / expr.count()
@cross_section
def _handle_cs_zscore(self, node: FunctionNode) -> pl.Expr:
"""处理 cs_zscore(expr) -> (expr - mean())/std()。"""
if len(node.args) != 1:
raise ValueError("cs_zscore 需要 1 个参数: (expr)")
expr = self.translate(node.args[0])
return (expr - expr.mean()) / expr.std()
@cross_section
def _handle_cs_neutral(self, node: FunctionNode) -> pl.Expr:
"""处理 cs_neutral(expr, group) -> 分组中性化。"""
if len(node.args) not in [1, 2]:
raise ValueError("cs_neutral 需要 1-2 个参数: (expr, [group_col])")
expr = self.translate(node.args[0])
# 简单实现:减去截面均值(可在未来扩展为分组中性化)
return expr - expr.mean()
# ==================== 元素级数学函数 (element_wise) ====================
# 这些函数对每个元素独立计算,不添加 over
@element_wise
def _handle_log(self, node: FunctionNode) -> pl.Expr:
"""处理 log(expr) -> 自然对数。"""
if len(node.args) != 1:
raise ValueError("log 需要 1 个参数: (expr)")
expr = self.translate(node.args[0])
return expr.log()
@element_wise
def _handle_exp(self, node: FunctionNode) -> pl.Expr:
"""处理 exp(expr) -> 指数函数。"""
if len(node.args) != 1:
raise ValueError("exp 需要 1 个参数: (expr)")
expr = self.translate(node.args[0])
return expr.exp()
@element_wise
def _handle_sqrt(self, node: FunctionNode) -> pl.Expr:
"""处理 sqrt(expr) -> 平方根。"""
if len(node.args) != 1:
raise ValueError("sqrt 需要 1 个参数: (expr)")
expr = self.translate(node.args[0])
return expr.sqrt()
@element_wise
def _handle_sign(self, node: FunctionNode) -> pl.Expr:
"""处理 sign(expr) -> 符号函数。"""
if len(node.args) != 1:
raise ValueError("sign 需要 1 个参数: (expr)")
expr = self.translate(node.args[0])
return expr.sign()
@element_wise
def _handle_cos(self, node: FunctionNode) -> pl.Expr:
"""处理 cos(expr) -> 余弦函数。"""
if len(node.args) != 1:
raise ValueError("cos 需要 1 个参数: (expr)")
expr = self.translate(node.args[0])
return expr.cos()
@element_wise
def _handle_sin(self, node: FunctionNode) -> pl.Expr:
"""处理 sin(expr) -> 正弦函数。"""
if len(node.args) != 1:
raise ValueError("sin 需要 1 个参数: (expr)")
expr = self.translate(node.args[0])
return expr.sin()
# ==================== 辅助方法 ====================
def _extract_window(self, node: Node) -> int:
"""从节点中提取窗口大小参数。
Args:
node: 应该是 Constant 节点
Returns:
整数值
Raises:
ValueError: 当节点不是 Constant 或值不是整数时
"""
if isinstance(node, Constant):
if not isinstance(node.value, int):
raise ValueError(
f"窗口参数必须是整数,得到: {type(node.value).__name__}"
)
return node.value
raise ValueError(f"窗口参数必须是常量整数,得到: {type(node).__name__}")
def translate_to_polars(node: Node) -> pl.Expr:
"""便捷函数 - 将 AST 节点翻译为 Polars 表达式。
Args:
node: 表达式树的根节点
Returns:
Polars 表达式对象
Example:
>>> from src.factors.dsl import Symbol, FunctionNode
>>> close = Symbol("close")
>>> expr = FunctionNode("ts_mean", close, 20)
>>> polars_expr = translate_to_polars(expr)
"""
translator = PolarsTranslator()
return translator.translate(node)
if __name__ == "__main__":
# 测试用例
from src.factors.dsl import Symbol, FunctionNode
# 创建符号
close = Symbol("close")
volume = Symbol("volume")
# 测试 1: 简单符号
print("测试 1: Symbol")
translator = PolarsTranslator()
expr1 = translator.translate(close)
print(f" close -> {expr1}")
assert str(expr1) == 'col("close")'
# 测试 2: 二元运算
print("\n测试 2: BinaryOp")
expr2 = translator.translate(close + 10)
print(f" close + 10 -> {expr2}")
# 测试 3: ts_mean
print("\n测试 3: ts_mean")
expr3 = translator.translate(FunctionNode("ts_mean", close, 20))
print(f" ts_mean(close, 20) -> {expr3}")
# 测试 4: cs_rank
print("\n测试 4: cs_rank")
expr4 = translator.translate(FunctionNode("cs_rank", close / volume))
print(f" cs_rank(close / volume) -> {expr4}")
# 测试 5: 复杂表达式
print("\n测试 5: 复杂表达式")
ma20 = FunctionNode("ts_mean", close, 20)
ma60 = FunctionNode("ts_mean", close, 60)
expr5 = translator.translate(FunctionNode("cs_rank", ma20 - ma60))
print(f" cs_rank(ts_mean(close, 20) - ts_mean(close, 60)) -> {expr5}")
print("\n✅ 所有测试通过!")