feat: 添加DSL因子表达式系统和Pro Bar API封装

- 新增 factors/dsl.py: 纯Python DSL表达式层,通过运算符重载实现因子组合
- 新增 factors/api.py: 提供常用因子符号(close/open/high/low)和时序函数(ts_mean/ts_std/cs_rank等)
- 新增 factors/compiler.py: 因子编译器
- 新增 factors/translator.py: DSL表达式翻译器
- 新增 data/api_wrappers/api_pro_bar.py: Tushare Pro Bar API封装,支持后复权行情数据
- 新增 data/data_router.py: 数据路由功能
- 新增相关测试用例
This commit is contained in:
2026-02-27 22:43:45 +08:00
parent a56433e440
commit 0698b9d919
9 changed files with 4012 additions and 0 deletions

448
src/factors/api.py Normal file
View File

@@ -0,0 +1,448 @@
"""DSL API 层 - 提供常用的符号和函数。
该模块提供量化因子表达式中常用的符号(如 close, open 等)
和函数(如 ts_mean, cs_rank 等),用户可以直接导入使用。
示例:
>>> from src.factors.api import close, open, ts_mean, cs_rank
>>> expr = ts_mean(close - open, 20) / close
>>> print(expr)
ts_mean(((close - open), 20)) / close
"""
from src.factors.dsl import Symbol, FunctionNode, Node, _ensure_node
from typing import Union
# ==================== 常用价格符号 ====================
#: 收盘价
close = Symbol("close")
#: 开盘价
open = Symbol("open")
#: 最高价
high = Symbol("high")
#: 最低价
low = Symbol("low")
#: 成交量
volume = Symbol("volume")
#: 成交额
amount = Symbol("amount")
#: 前收盘价
pre_close = Symbol("pre_close")
#: 涨跌额
change = Symbol("change")
#: 涨跌幅
pct_change = Symbol("pct_change")
# ==================== 时间序列函数 (ts_*) ====================
def ts_mean(x: Union[Node, str], window: int) -> FunctionNode:
"""时间序列均值。
计算给定因子在滚动窗口内的平均值。
Args:
x: 输入因子表达式或字段名字符串
window: 滚动窗口大小
Returns:
FunctionNode: 函数调用节点
Example:
>>> from src.factors.api import close, ts_mean
>>> expr = ts_mean(close, 20) # 20日收盘价均值
>>> expr = ts_mean("close", 20) # 使用字符串
>>> print(expr)
ts_mean(close, 20)
"""
return FunctionNode("ts_mean", x, window)
def ts_std(x: Union[Node, str], window: int) -> FunctionNode:
"""时间序列标准差。
计算给定因子在滚动窗口内的标准差。
Args:
x: 输入因子表达式或字段名字符串
window: 滚动窗口大小
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_std", x, window)
def ts_max(x: Union[Node, str], window: int) -> FunctionNode:
"""时间序列最大值。
计算给定因子在滚动窗口内的最大值。
Args:
x: 输入因子表达式或字段名字符串
window: 滚动窗口大小
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_max", x, window)
def ts_min(x: Union[Node, str], window: int) -> FunctionNode:
"""时间序列最小值。
计算给定因子在滚动窗口内的最小值。
Args:
x: 输入因子表达式或字段名字符串
window: 滚动窗口大小
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_min", x, window)
def ts_sum(x: Union[Node, str], window: int) -> FunctionNode:
"""时间序列求和。
计算给定因子在滚动窗口内的求和。
Args:
x: 输入因子表达式或字段名字符串
window: 滚动窗口大小
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_sum", x, window)
def ts_delay(x: Union[Node, str], periods: int) -> FunctionNode:
"""时间序列滞后。
获取给定因子在 N 个周期前的值。
Args:
x: 输入因子表达式或字段名字符串
periods: 滞后期数
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_delay", x, periods)
def ts_delta(x: Union[Node, str], periods: int) -> FunctionNode:
"""时间序列差分。
计算给定因子与 N 个周期前的差值。
Args:
x: 输入因子表达式或字段名字符串
periods: 差分期数
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_delta", x, periods)
def ts_corr(x: Union[Node, str], y: Union[Node, str], window: int) -> FunctionNode:
"""时间序列相关系数。
计算两个因子在滚动窗口内的相关系数。
Args:
x: 第一个因子表达式或字段名字符串
y: 第二个因子表达式或字段名字符串
window: 滚动窗口大小
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_corr", x, y, window)
def ts_cov(x: Union[Node, str], y: Union[Node, str], window: int) -> FunctionNode:
"""时间序列协方差。
计算两个因子在滚动窗口内的协方差。
Args:
x: 第一个因子表达式或字段名字符串
y: 第二个因子表达式或字段名字符串
window: 滚动窗口大小
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_cov", x, y, window)
def ts_rank(x: Union[Node, str], window: int) -> FunctionNode:
"""时间序列排名。
计算当前值在过去窗口内的分位排名。
Args:
x: 输入因子表达式或字段名字符串
window: 滚动窗口大小
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("ts_rank", x, window)
# ==================== 截面函数 (cs_*) ====================
def cs_rank(x: Union[Node, str]) -> FunctionNode:
"""截面排名。
计算因子在横截面上的排名(分位数)。
Args:
x: 输入因子表达式或字段名字符串
Returns:
FunctionNode: 函数调用节点
Example:
>>> from src.factors.api import close, cs_rank
>>> expr = cs_rank(close) # 收盘价截面排名
>>> expr = cs_rank("close") # 使用字符串
>>> print(expr)
cs_rank(close)
"""
return FunctionNode("cs_rank", x)
def cs_zscore(x: Union[Node, str]) -> FunctionNode:
"""截面标准化 (Z-Score)。
计算因子在横截面上的 Z-Score 标准化值。
Args:
x: 输入因子表达式或字段名字符串
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("cs_zscore", x)
def cs_neutralize(
x: Union[Node, str], group: Union[Symbol, str, None] = None
) -> FunctionNode:
"""截面中性化。
对因子进行行业/市值中性化处理。
Args:
x: 输入因子表达式或字段名字符串
group: 分组变量(如行业分类),可以为字符串或 Symbol默认为 None
Returns:
FunctionNode: 函数调用节点
"""
if group is not None:
return FunctionNode("cs_neutralize", x, group)
return FunctionNode("cs_neutralize", x)
def cs_winsorize(
x: Union[Node, str], lower: float = 0.01, upper: float = 0.99
) -> FunctionNode:
"""截面缩尾处理。
对因子进行截面缩尾处理,去除极端值。
Args:
x: 输入因子表达式或字段名字符串
lower: 下尾分位数,默认 0.01
upper: 上尾分位数,默认 0.99
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("cs_winsorize", x, lower, upper)
def cs_demean(x: Union[Node, str]) -> FunctionNode:
"""截面去均值。
计算因子在横截面上减去均值。
Args:
x: 输入因子表达式或字段名字符串
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("cs_demean", x)
# ==================== 数学函数 ====================
def log(x: Union[Node, str]) -> FunctionNode:
"""自然对数。
Args:
x: 输入因子表达式或字段名字符串
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("log", x)
def exp(x: Union[Node, str]) -> FunctionNode:
"""指数函数。
Args:
x: 输入因子表达式或字段名字符串
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("exp", x)
def sqrt(x: Union[Node, str]) -> FunctionNode:
"""平方根。
Args:
x: 输入因子表达式或字段名字符串
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("sqrt", x)
def sign(x: Union[Node, str]) -> FunctionNode:
"""符号函数。
返回 -1, 0, 1 表示输入值的符号。
Args:
x: 输入因子表达式或字段名字符串
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("sign", x)
def abs(x: Union[Node, str]) -> FunctionNode:
"""绝对值。
Args:
x: 输入因子表达式或字段名字符串
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("abs", x)
def max_(x: Union[Node, str], y: Union[Node, str, int, float]) -> FunctionNode:
"""逐元素最大值。
Args:
x: 第一个因子表达式或字段名字符串
y: 第二个因子表达式、字段名字符串或数值
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("max", x, _ensure_node(y))
def min_(x: Union[Node, str], y: Union[Node, str, int, float]) -> FunctionNode:
"""逐元素最小值。
Args:
x: 第一个因子表达式或字段名字符串
y: 第二个因子表达式、字段名字符串或数值
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("min", x, _ensure_node(y))
def clip(
x: Union[Node, str],
lower: Union[Node, str, int, float],
upper: Union[Node, str, int, float],
) -> FunctionNode:
"""数值裁剪。
将因子值限制在 [lower, upper] 范围内。
Args:
x: 输入因子表达式或字段名字符串
lower: 下限(因子表达式、字段名字符串或数值)
upper: 上限(因子表达式、字段名字符串或数值)
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode("clip", x, _ensure_node(lower), _ensure_node(upper))
# ==================== 条件函数 ====================
def if_(
condition: Union[Node, str],
true_val: Union[Node, str, int, float],
false_val: Union[Node, str, int, float],
) -> FunctionNode:
"""条件选择。
根据条件选择值。
Args:
condition: 条件表达式或字段名字符串
true_val: 条件为真时的值(因子表达式、字段名字符串或数值)
false_val: 条件为假时的值(因子表达式、字段名字符串或数值)
Returns:
FunctionNode: 函数调用节点
"""
return FunctionNode(
"if", condition, _ensure_node(true_val), _ensure_node(false_val)
)
def where(
condition: Union[Node, str],
true_val: Union[Node, str, int, float],
false_val: Union[Node, str, int, float],
) -> FunctionNode:
"""条件选择if_ 的别名)。
Args:
condition: 条件表达式或字段名字符串
true_val: 条件为真时的值(因子表达式、字段名字符串或数值)
false_val: 条件为假时的值(因子表达式、字段名字符串或数值)
Returns:
FunctionNode: 函数调用节点
"""
return if_(condition, true_val, false_val)

159
src/factors/compiler.py Normal file
View File

@@ -0,0 +1,159 @@
"""AST 编译器模块 - 提供依赖提取和代码生成功能。
本模块实现 AST 遍历器模式,用于从 DSL 表达式中提取依赖的符号。
"""
from typing import Set
from src.factors.dsl import Node, Symbol, BinaryOpNode, UnaryOpNode, FunctionNode
class DependencyExtractor:
"""依赖提取器 - 使用访问者模式遍历 AST 节点。
递归遍历表达式树,提取所有 Symbol 节点的名称。
支持 BinaryOpNode、UnaryOpNode 和 FunctionNode 的递归遍历。
Example:
>>> from src.factors.dsl import Symbol, FunctionNode
>>> close = Symbol("close")
>>> pe_ratio = Symbol("pe_ratio")
>>> alpha = FunctionNode("cs_rank", close / pe_ratio)
>>> deps = DependencyExtractor.extract_dependencies(alpha)
>>> print(deps)
{'close', 'pe_ratio'}
"""
def __init__(self) -> None:
"""初始化依赖提取器。"""
self.dependencies: Set[str] = set()
def visit(self, node: Node) -> None:
"""访问节点,根据节点类型分发到具体处理方法。
Args:
node: AST 节点
"""
if isinstance(node, Symbol):
self._visit_symbol(node)
elif isinstance(node, BinaryOpNode):
self._visit_binary_op(node)
elif isinstance(node, UnaryOpNode):
self._visit_unary_op(node)
elif isinstance(node, FunctionNode):
self._visit_function(node)
# Constant 节点不包含依赖,无需处理
def _visit_symbol(self, node: Symbol) -> None:
"""访问 Symbol 节点,提取符号名称。
Args:
node: 符号节点
"""
self.dependencies.add(node.name)
def _visit_binary_op(self, node: BinaryOpNode) -> None:
"""访问 BinaryOpNode 节点,递归遍历左右子节点。
Args:
node: 二元运算节点
"""
self.visit(node.left)
self.visit(node.right)
def _visit_unary_op(self, node: UnaryOpNode) -> None:
"""访问 UnaryOpNode 节点,递归遍历操作数。
Args:
node: 一元运算节点
"""
self.visit(node.operand)
def _visit_function(self, node: FunctionNode) -> None:
"""访问 FunctionNode 节点,递归遍历所有参数。
Args:
node: 函数调用节点
"""
for arg in node.args:
self.visit(arg)
def extract(self, node: Node) -> Set[str]:
"""从 AST 节点中提取所有依赖的符号名称。
Args:
node: 表达式树的根节点
Returns:
依赖的符号名称集合
"""
self.dependencies.clear()
self.visit(node)
return self.dependencies.copy()
@classmethod
def extract_dependencies(cls, node: Node) -> Set[str]:
"""类方法 - 从 AST 节点中提取所有依赖的符号名称。
这是一个便捷方法,无需手动实例化 DependencyExtractor。
Args:
node: 表达式树的根节点
Returns:
依赖的符号名称集合
Example:
>>> from src.factors.dsl import Symbol
>>> close = Symbol("close")
>>> open_price = Symbol("open")
>>> expr = close / open_price
>>> deps = DependencyExtractor.extract_dependencies(expr)
>>> print(deps)
{'close', 'open'}
"""
extractor = cls()
return extractor.extract(node)
def extract_dependencies(node: Node) -> Set[str]:
"""单例方法 - 从 AST 节点中提取所有依赖的符号名称。
这是 DependencyExtractor.extract_dependencies 的便捷包装函数。
Args:
node: 表达式树的根节点
Returns:
依赖的符号名称集合
Example:
>>> from src.factors.dsl import Symbol, FunctionNode
>>> close = Symbol("close")
>>> pe_ratio = Symbol("pe_ratio")
>>> alpha = FunctionNode("cs_rank", close / pe_ratio)
>>> deps = extract_dependencies(alpha)
>>> print(deps)
{'close', 'pe_ratio'}
"""
return DependencyExtractor.extract_dependencies(node)
if __name__ == "__main__":
# 测试用例: cs_rank(close / pe_ratio)
from src.factors.dsl import Symbol, FunctionNode
# 创建符号
close = Symbol("close")
pe_ratio = Symbol("pe_ratio")
# 构建表达式: cs_rank(close / pe_ratio)
alpha = FunctionNode("cs_rank", close / pe_ratio)
# 提取依赖
dependencies = extract_dependencies(alpha)
print(f"表达式: {alpha}")
print(f"提取的依赖: {dependencies}")
print(f"期望依赖: {{'close', 'pe_ratio'}}")
print(f"验证结果: {dependencies == {'close', 'pe_ratio'}}")

278
src/factors/dsl.py Normal file
View File

@@ -0,0 +1,278 @@
"""DSL 表达式层 - 纯 Python 实现,无 pandas/polars 依赖。
提供因子表达式的符号化表示能力,通过重载运算符实现
用户端无感知的公式编写。
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, List, Union
class Node(ABC):
"""表达式节点基类。
所有因子表达式组件的抽象基类,提供运算符重载能力。
子类需要实现 __repr__ 方法用于表达式可视化。
"""
# ==================== 算术运算符重载 ====================
def __add__(self, other: Any) -> BinaryOpNode:
"""加法: self + other"""
return BinaryOpNode("+", self, _ensure_node(other))
def __radd__(self, other: Any) -> BinaryOpNode:
"""右加法: other + self"""
return BinaryOpNode("+", _ensure_node(other), self)
def __sub__(self, other: Any) -> BinaryOpNode:
"""减法: self - other"""
return BinaryOpNode("-", self, _ensure_node(other))
def __rsub__(self, other: Any) -> BinaryOpNode:
"""右减法: other - self"""
return BinaryOpNode("-", _ensure_node(other), self)
def __mul__(self, other: Any) -> BinaryOpNode:
"""乘法: self * other"""
return BinaryOpNode("*", self, _ensure_node(other))
def __rmul__(self, other: Any) -> BinaryOpNode:
"""右乘法: other * self"""
return BinaryOpNode("*", _ensure_node(other), self)
def __truediv__(self, other: Any) -> BinaryOpNode:
"""除法: self / other"""
return BinaryOpNode("/", self, _ensure_node(other))
def __rtruediv__(self, other: Any) -> BinaryOpNode:
"""右除法: other / self"""
return BinaryOpNode("/", _ensure_node(other), self)
def __pow__(self, other: Any) -> BinaryOpNode:
"""幂运算: self ** other"""
return BinaryOpNode("**", self, _ensure_node(other))
def __rpow__(self, other: Any) -> BinaryOpNode:
"""右幂运算: other ** self"""
return BinaryOpNode("**", _ensure_node(other), self)
def __floordiv__(self, other: Any) -> BinaryOpNode:
"""整除: self // other"""
return BinaryOpNode("//", self, _ensure_node(other))
def __rfloordiv__(self, other: Any) -> BinaryOpNode:
"""右整除: other // self"""
return BinaryOpNode("//", _ensure_node(other), self)
def __mod__(self, other: Any) -> BinaryOpNode:
"""取模: self % other"""
return BinaryOpNode("%", self, _ensure_node(other))
def __rmod__(self, other: Any) -> BinaryOpNode:
"""右取模: other % self"""
return BinaryOpNode("%", _ensure_node(other), self)
# ==================== 一元运算符重载 ====================
def __neg__(self) -> UnaryOpNode:
"""取负: -self"""
return UnaryOpNode("-", self)
def __pos__(self) -> UnaryOpNode:
"""取正: +self"""
return UnaryOpNode("+", self)
def __abs__(self) -> UnaryOpNode:
"""绝对值: abs(self)"""
return UnaryOpNode("abs", self)
# ==================== 比较运算符重载 ====================
def __eq__(self, other: Any) -> BinaryOpNode:
"""等于: self == other"""
return BinaryOpNode("==", self, _ensure_node(other))
def __ne__(self, other: Any) -> BinaryOpNode:
"""不等于: self != other"""
return BinaryOpNode("!=", self, _ensure_node(other))
def __lt__(self, other: Any) -> BinaryOpNode:
"""小于: self < other"""
return BinaryOpNode("<", self, _ensure_node(other))
def __le__(self, other: Any) -> BinaryOpNode:
"""小于等于: self <= other"""
return BinaryOpNode("<=", self, _ensure_node(other))
def __gt__(self, other: Any) -> BinaryOpNode:
"""大于: self > other"""
return BinaryOpNode(">", self, _ensure_node(other))
def __ge__(self, other: Any) -> BinaryOpNode:
"""大于等于: self >= other"""
return BinaryOpNode(">=", self, _ensure_node(other))
# ==================== 抽象方法 ====================
@abstractmethod
def __repr__(self) -> str:
"""返回表达式的字符串表示。"""
pass
class Symbol(Node):
"""符号节点,代表一个命名变量(如 close, open 等)。
Attributes:
name: 符号名称,用于标识该变量
"""
def __init__(self, name: str) -> None:
"""初始化符号节点。
Args:
name: 符号名称,如 'close', 'open', 'volume'
"""
self.name = name
def __repr__(self) -> str:
"""返回符号名称。"""
return self.name
def __hash__(self) -> int:
"""支持作为字典键使用。"""
return hash(self.name)
def __eq__(self, other: object) -> bool:
"""符号相等性比较。"""
if not isinstance(other, Symbol):
return NotImplemented
return self.name == other.name
class Constant(Node):
"""常量节点,代表一个数值常量。
Attributes:
value: 常量数值
"""
def __init__(self, value: Union[int, float]) -> None:
"""初始化常量节点。
Args:
value: 常量数值
"""
self.value = value
def __repr__(self) -> str:
"""返回常量值的字符串表示。"""
return str(self.value)
class BinaryOpNode(Node):
"""二元运算节点,表示两个操作数之间的运算。
Attributes:
op: 运算符,如 '+', '-', '*', '/'
left: 左操作数
right: 右操作数
"""
def __init__(self, op: str, left: Node, right: Node) -> None:
"""初始化二元运算节点。
Args:
op: 运算符字符串
left: 左操作数节点
right: 右操作数节点
"""
self.op = op
self.left = left
self.right = right
def __repr__(self) -> str:
"""返回带括号的二元运算表达式。"""
return f"({self.left} {self.op} {self.right})"
class UnaryOpNode(Node):
"""一元运算节点,表示对单个操作数的运算。
Attributes:
op: 运算符,如 '-', '+', 'abs'
operand: 操作数
"""
def __init__(self, op: str, operand: Node) -> None:
"""初始化一元运算节点。
Args:
op: 运算符字符串
operand: 操作数节点
"""
self.op = op
self.operand = operand
def __repr__(self) -> str:
"""返回一元运算表达式。"""
if self.op in ("+", "-"):
return f"({self.op}{self.operand})"
return f"{self.op}({self.operand})"
class FunctionNode(Node):
"""函数调用节点,表示一个函数调用。
Attributes:
func_name: 函数名称
args: 函数参数列表
"""
def __init__(self, func_name: str, *args: Any) -> None:
"""初始化函数调用节点。
Args:
func_name: 函数名称,如 'ts_mean', 'cs_rank'
*args: 函数参数,可以是 Node 或其他类型
"""
self.func_name = func_name
# 将所有参数转换为节点类型
self.args: List[Node] = [_ensure_node(arg) for arg in args]
def __repr__(self) -> str:
"""返回函数调用表达式。"""
args_str = ", ".join(repr(arg) for arg in self.args)
return f"{self.func_name}({args_str})"
# ==================== 辅助函数 ====================
def _ensure_node(value: Any) -> Node:
"""确保值是一个 Node 节点。
如果值已经是 Node 类型,直接返回;
如果是数值类型,包装为 Constant 节点;
如果是字符串类型,包装为 Symbol 节点;
否则抛出类型错误。
Args:
value: 任意值
Returns:
Node: 对应的节点对象
Raises:
TypeError: 当值无法转换为节点时
"""
if isinstance(value, Node):
return value
if isinstance(value, (int, float)):
return Constant(value)
if isinstance(value, str):
return Symbol(value)
raise TypeError(f"无法将类型 {type(value).__name__} 转换为 Node")

387
src/factors/translator.py Normal file
View File

@@ -0,0 +1,387 @@
"""Polars 翻译器 - 将 AST 翻译为 Polars 表达式。
本模块实现 DSL 到 Polars 计算图的映射,是因子表达式执行的桥梁。
支持时序因子ts_*和截面因子cs_*)的防错分组翻译。
"""
from typing import Any, Callable, Dict
import polars as pl
from src.factors.dsl import (
BinaryOpNode,
Constant,
FunctionNode,
Node,
Symbol,
UnaryOpNode,
)
class PolarsTranslator:
"""Polars 表达式翻译器。
将纯对象的 AST 树完美映射为 Polars 的带防错分组的计算图。
Attributes:
handlers: 函数处理器注册表,映射 func_name 到处理函数
Example:
>>> from src.factors.dsl import Symbol, FunctionNode
>>> close = Symbol("close")
>>> expr = FunctionNode("ts_mean", close, 20)
>>> translator = PolarsTranslator()
>>> polars_expr = translator.translate(expr)
>>> # 结果: pl.col("close").rolling_mean(20).over("asset")
"""
def __init__(self) -> None:
"""初始化翻译器并注册内置函数处理器。"""
self.handlers: Dict[str, Callable[[FunctionNode], pl.Expr]] = {}
self._register_builtin_handlers()
def _register_builtin_handlers(self) -> None:
"""注册内置的函数处理器。"""
# 时序因子处理器 (ts_*)
self.register_handler("ts_mean", self._handle_ts_mean)
self.register_handler("ts_sum", self._handle_ts_sum)
self.register_handler("ts_std", self._handle_ts_std)
self.register_handler("ts_max", self._handle_ts_max)
self.register_handler("ts_min", self._handle_ts_min)
self.register_handler("ts_delay", self._handle_ts_delay)
self.register_handler("ts_delta", self._handle_ts_delta)
self.register_handler("ts_corr", self._handle_ts_corr)
self.register_handler("ts_cov", self._handle_ts_cov)
# 截面因子处理器 (cs_*)
self.register_handler("cs_rank", self._handle_cs_rank)
self.register_handler("cs_zscore", self._handle_cs_zscore)
self.register_handler("cs_neutral", self._handle_cs_neutral)
def register_handler(
self, func_name: str, handler: Callable[[FunctionNode], pl.Expr]
) -> None:
"""注册自定义函数处理器。
Args:
func_name: 函数名称
handler: 处理函数,接收 FunctionNode 返回 pl.Expr
Example:
>>> def handle_custom(node: FunctionNode) -> pl.Expr:
... arg = self.translate(node.args[0])
... return arg * 2
>>> translator.register_handler("custom", handle_custom)
"""
self.handlers[func_name] = handler
def translate(self, node: Node) -> pl.Expr:
"""递归翻译 AST 节点为 Polars 表达式。
Args:
node: AST 节点Symbol、Constant、BinaryOpNode、UnaryOpNode、FunctionNode
Returns:
Polars 表达式对象
Raises:
TypeError: 当遇到未知的节点类型时
"""
if isinstance(node, Symbol):
return self._translate_symbol(node)
elif isinstance(node, Constant):
return self._translate_constant(node)
elif isinstance(node, BinaryOpNode):
return self._translate_binary_op(node)
elif isinstance(node, UnaryOpNode):
return self._translate_unary_op(node)
elif isinstance(node, FunctionNode):
return self._translate_function(node)
else:
raise TypeError(f"未知的节点类型: {type(node).__name__}")
def _translate_symbol(self, node: Symbol) -> pl.Expr:
"""翻译 Symbol 节点为 pl.col() 表达式。
Args:
node: 符号节点
Returns:
pl.col(node.name) 表达式
"""
return pl.col(node.name)
def _translate_constant(self, node: Constant) -> pl.Expr:
"""翻译 Constant 节点为 Polars 字面量。
Args:
node: 常量节点
Returns:
pl.lit(node.value) 表达式
"""
return pl.lit(node.value)
def _translate_binary_op(self, node: BinaryOpNode) -> pl.Expr:
"""翻译 BinaryOpNode 为 Polars 二元运算。
Args:
node: 二元运算节点
Returns:
Polars 二元运算表达式
"""
left = self.translate(node.left)
right = self.translate(node.right)
op_map = {
"+": lambda l, r: l + r,
"-": lambda l, r: l - r,
"*": lambda l, r: l * r,
"/": lambda l, r: l / r,
"**": lambda l, r: l.pow(r),
"//": lambda l, r: l.floor_div(r),
"%": lambda l, r: l % r,
"==": lambda l, r: l.eq(r),
"!=": lambda l, r: l.ne(r),
"<": lambda l, r: l.lt(r),
"<=": lambda l, r: l.le(r),
">": lambda l, r: l.gt(r),
">=": lambda l, r: l.ge(r),
}
if node.op not in op_map:
raise ValueError(f"不支持的二元运算符: {node.op}")
return op_map[node.op](left, right)
def _translate_unary_op(self, node: UnaryOpNode) -> pl.Expr:
"""翻译 UnaryOpNode 为 Polars 一元运算。
Args:
node: 一元运算节点
Returns:
Polars 一元运算表达式
"""
operand = self.translate(node.operand)
op_map = {
"+": lambda x: x,
"-": lambda x: -x,
"abs": lambda x: x.abs(),
}
if node.op not in op_map:
raise ValueError(f"不支持的一元运算符: {node.op}")
return op_map[node.op](operand)
def _translate_function(self, node: FunctionNode) -> pl.Expr:
"""翻译 FunctionNode 为 Polars 函数调用。
优先从 handlers 注册表中查找处理器,未找到则抛出错误。
Args:
node: 函数调用节点
Returns:
Polars 函数表达式
Raises:
ValueError: 当函数名称未注册处理器时
"""
func_name = node.func_name
if func_name in self.handlers:
return self.handlers[func_name](node)
else:
raise ValueError(
f"未注册的函数: {func_name}. 请使用 register_handler 注册处理器。"
)
# ==================== 时序因子处理器 (ts_*) ====================
# 所有时序因子强制注入 over("ts_code") 防串表
def _handle_ts_mean(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_mean(close, window) -> rolling_mean(window).over(ts_code)。"""
if len(node.args) != 2:
raise ValueError("ts_mean 需要 2 个参数: (expr, window)")
expr = self.translate(node.args[0])
window = self._extract_window(node.args[1])
return expr.rolling_mean(window_size=window).over("ts_code")
def _handle_ts_sum(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_sum(close, window) -> rolling_sum(window).over(ts_code)。"""
if len(node.args) != 2:
raise ValueError("ts_sum 需要 2 个参数: (expr, window)")
expr = self.translate(node.args[0])
window = self._extract_window(node.args[1])
return expr.rolling_sum(window_size=window).over("ts_code")
def _handle_ts_std(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_std(close, window) -> rolling_std(window).over(ts_code)。"""
if len(node.args) != 2:
raise ValueError("ts_std 需要 2 个参数: (expr, window)")
expr = self.translate(node.args[0])
window = self._extract_window(node.args[1])
return expr.rolling_std(window_size=window).over("ts_code")
def _handle_ts_max(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_max(close, window) -> rolling_max(window).over(ts_code)。"""
if len(node.args) != 2:
raise ValueError("ts_max 需要 2 个参数: (expr, window)")
expr = self.translate(node.args[0])
window = self._extract_window(node.args[1])
return expr.rolling_max(window_size=window).over("ts_code")
def _handle_ts_min(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_min(close, window) -> rolling_min(window).over(ts_code)。"""
if len(node.args) != 2:
raise ValueError("ts_min 需要 2 个参数: (expr, window)")
expr = self.translate(node.args[0])
window = self._extract_window(node.args[1])
return expr.rolling_min(window_size=window).over("ts_code")
def _handle_ts_delay(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_delay(close, n) -> shift(n).over(ts_code)。"""
if len(node.args) != 2:
raise ValueError("ts_delay 需要 2 个参数: (expr, n)")
expr = self.translate(node.args[0])
n = self._extract_window(node.args[1])
return expr.shift(n).over("ts_code")
def _handle_ts_delta(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_delta(close, n) -> (expr - shift(n)).over(ts_code)。"""
if len(node.args) != 2:
raise ValueError("ts_delta 需要 2 个参数: (expr, n)")
expr = self.translate(node.args[0])
n = self._extract_window(node.args[1])
return (expr - expr.shift(n)).over("ts_code")
def _handle_ts_corr(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_corr(x, y, window) -> rolling_corr(y, window).over(ts_code)。"""
if len(node.args) != 3:
raise ValueError("ts_corr 需要 3 个参数: (x, y, window)")
x = self.translate(node.args[0])
y = self.translate(node.args[1])
window = self._extract_window(node.args[2])
return x.rolling_corr(y, window_size=window).over("ts_code")
def _handle_ts_cov(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_cov(x, y, window) -> rolling_cov(y, window).over(ts_code)。"""
if len(node.args) != 3:
raise ValueError("ts_cov 需要 3 个参数: (x, y, window)")
x = self.translate(node.args[0])
y = self.translate(node.args[1])
window = self._extract_window(node.args[2])
return x.rolling_cov(y, window_size=window).over("ts_code")
# ==================== 截面因子处理器 (cs_*) ====================
# 所有截面因子强制注入 over("trade_date") 防串表
def _handle_cs_rank(self, node: FunctionNode) -> pl.Expr:
"""处理 cs_rank(expr) -> rank()/count().over(trade_date)。
将排名归一化到 [0, 1] 区间。
"""
if len(node.args) != 1:
raise ValueError("cs_rank 需要 1 个参数: (expr)")
expr = self.translate(node.args[0])
return (expr.rank() / expr.count()).over("trade_date")
def _handle_cs_zscore(self, node: FunctionNode) -> pl.Expr:
"""处理 cs_zscore(expr) -> (expr - mean())/std().over(trade_date)。"""
if len(node.args) != 1:
raise ValueError("cs_zscore 需要 1 个参数: (expr)")
expr = self.translate(node.args[0])
return ((expr - expr.mean()) / expr.std()).over("trade_date")
def _handle_cs_neutral(self, node: FunctionNode) -> pl.Expr:
"""处理 cs_neutral(expr, group) -> 分组中性化。"""
if len(node.args) not in [1, 2]:
raise ValueError("cs_neutral 需要 1-2 个参数: (expr, [group_col])")
expr = self.translate(node.args[0])
# 简单实现:减去截面均值(可在未来扩展为分组中性化)
return (expr - expr.mean()).over("trade_date")
# ==================== 辅助方法 ====================
def _extract_window(self, node: Node) -> int:
"""从节点中提取窗口大小参数。
Args:
node: 应该是 Constant 节点
Returns:
整数值
Raises:
ValueError: 当节点不是 Constant 或值不是整数时
"""
if isinstance(node, Constant):
if not isinstance(node.value, int):
raise ValueError(
f"窗口参数必须是整数,得到: {type(node.value).__name__}"
)
return node.value
raise ValueError(f"窗口参数必须是常量整数,得到: {type(node).__name__}")
def translate_to_polars(node: Node) -> pl.Expr:
"""便捷函数 - 将 AST 节点翻译为 Polars 表达式。
Args:
node: 表达式树的根节点
Returns:
Polars 表达式对象
Example:
>>> from src.factors.dsl import Symbol, FunctionNode
>>> close = Symbol("close")
>>> expr = FunctionNode("ts_mean", close, 20)
>>> polars_expr = translate_to_polars(expr)
"""
translator = PolarsTranslator()
return translator.translate(node)
if __name__ == "__main__":
# 测试用例
from src.factors.dsl import Symbol, FunctionNode
# 创建符号
close = Symbol("close")
volume = Symbol("volume")
# 测试 1: 简单符号
print("测试 1: Symbol")
translator = PolarsTranslator()
expr1 = translator.translate(close)
print(f" close -> {expr1}")
assert str(expr1) == 'col("close")'
# 测试 2: 二元运算
print("\n测试 2: BinaryOp")
expr2 = translator.translate(close + 10)
print(f" close + 10 -> {expr2}")
# 测试 3: ts_mean
print("\n测试 3: ts_mean")
expr3 = translator.translate(FunctionNode("ts_mean", close, 20))
print(f" ts_mean(close, 20) -> {expr3}")
# 测试 4: cs_rank
print("\n测试 4: cs_rank")
expr4 = translator.translate(FunctionNode("cs_rank", close / volume))
print(f" cs_rank(close / volume) -> {expr4}")
# 测试 5: 复杂表达式
print("\n测试 5: 复杂表达式")
ma20 = FunctionNode("ts_mean", close, 20)
ma60 = FunctionNode("ts_mean", close, 60)
expr5 = translator.translate(FunctionNode("cs_rank", ma20 - ma60))
print(f" cs_rank(ts_mean(close, 20) - ts_mean(close, 60)) -> {expr5}")
print("\n✅ 所有测试通过!")