feat: 添加DSL因子表达式系统和Pro Bar API封装
- 新增 factors/dsl.py: 纯Python DSL表达式层,通过运算符重载实现因子组合 - 新增 factors/api.py: 提供常用因子符号(close/open/high/low)和时序函数(ts_mean/ts_std/cs_rank等) - 新增 factors/compiler.py: 因子编译器 - 新增 factors/translator.py: DSL表达式翻译器 - 新增 data/api_wrappers/api_pro_bar.py: Tushare Pro Bar API封装,支持后复权行情数据 - 新增 data/data_router.py: 数据路由功能 - 新增相关测试用例
This commit is contained in:
448
src/factors/api.py
Normal file
448
src/factors/api.py
Normal file
@@ -0,0 +1,448 @@
|
||||
"""DSL API 层 - 提供常用的符号和函数。
|
||||
|
||||
该模块提供量化因子表达式中常用的符号(如 close, open 等)
|
||||
和函数(如 ts_mean, cs_rank 等),用户可以直接导入使用。
|
||||
|
||||
示例:
|
||||
>>> from src.factors.api import close, open, ts_mean, cs_rank
|
||||
>>> expr = ts_mean(close - open, 20) / close
|
||||
>>> print(expr)
|
||||
ts_mean(((close - open), 20)) / close
|
||||
"""
|
||||
|
||||
from src.factors.dsl import Symbol, FunctionNode, Node, _ensure_node
|
||||
from typing import Union
|
||||
|
||||
# ==================== 常用价格符号 ====================
|
||||
|
||||
#: 收盘价
|
||||
close = Symbol("close")
|
||||
|
||||
#: 开盘价
|
||||
open = Symbol("open")
|
||||
|
||||
#: 最高价
|
||||
high = Symbol("high")
|
||||
|
||||
#: 最低价
|
||||
low = Symbol("low")
|
||||
|
||||
#: 成交量
|
||||
volume = Symbol("volume")
|
||||
|
||||
#: 成交额
|
||||
amount = Symbol("amount")
|
||||
|
||||
#: 前收盘价
|
||||
pre_close = Symbol("pre_close")
|
||||
|
||||
#: 涨跌额
|
||||
change = Symbol("change")
|
||||
|
||||
#: 涨跌幅
|
||||
pct_change = Symbol("pct_change")
|
||||
|
||||
|
||||
# ==================== 时间序列函数 (ts_*) ====================
|
||||
|
||||
|
||||
def ts_mean(x: Union[Node, str], window: int) -> FunctionNode:
|
||||
"""时间序列均值。
|
||||
|
||||
计算给定因子在滚动窗口内的平均值。
|
||||
|
||||
Args:
|
||||
x: 输入因子表达式或字段名字符串
|
||||
window: 滚动窗口大小
|
||||
|
||||
Returns:
|
||||
FunctionNode: 函数调用节点
|
||||
|
||||
Example:
|
||||
>>> from src.factors.api import close, ts_mean
|
||||
>>> expr = ts_mean(close, 20) # 20日收盘价均值
|
||||
>>> expr = ts_mean("close", 20) # 使用字符串
|
||||
>>> print(expr)
|
||||
ts_mean(close, 20)
|
||||
"""
|
||||
return FunctionNode("ts_mean", x, window)
|
||||
|
||||
|
||||
def ts_std(x: Union[Node, str], window: int) -> FunctionNode:
|
||||
"""时间序列标准差。
|
||||
|
||||
计算给定因子在滚动窗口内的标准差。
|
||||
|
||||
Args:
|
||||
x: 输入因子表达式或字段名字符串
|
||||
window: 滚动窗口大小
|
||||
|
||||
Returns:
|
||||
FunctionNode: 函数调用节点
|
||||
"""
|
||||
return FunctionNode("ts_std", x, window)
|
||||
|
||||
|
||||
def ts_max(x: Union[Node, str], window: int) -> FunctionNode:
|
||||
"""时间序列最大值。
|
||||
|
||||
计算给定因子在滚动窗口内的最大值。
|
||||
|
||||
Args:
|
||||
x: 输入因子表达式或字段名字符串
|
||||
window: 滚动窗口大小
|
||||
|
||||
Returns:
|
||||
FunctionNode: 函数调用节点
|
||||
"""
|
||||
return FunctionNode("ts_max", x, window)
|
||||
|
||||
|
||||
def ts_min(x: Union[Node, str], window: int) -> FunctionNode:
|
||||
"""时间序列最小值。
|
||||
|
||||
计算给定因子在滚动窗口内的最小值。
|
||||
|
||||
Args:
|
||||
x: 输入因子表达式或字段名字符串
|
||||
window: 滚动窗口大小
|
||||
|
||||
Returns:
|
||||
FunctionNode: 函数调用节点
|
||||
"""
|
||||
return FunctionNode("ts_min", x, window)
|
||||
|
||||
|
||||
def ts_sum(x: Union[Node, str], window: int) -> FunctionNode:
|
||||
"""时间序列求和。
|
||||
|
||||
计算给定因子在滚动窗口内的求和。
|
||||
|
||||
Args:
|
||||
x: 输入因子表达式或字段名字符串
|
||||
window: 滚动窗口大小
|
||||
|
||||
Returns:
|
||||
FunctionNode: 函数调用节点
|
||||
"""
|
||||
return FunctionNode("ts_sum", x, window)
|
||||
|
||||
|
||||
def ts_delay(x: Union[Node, str], periods: int) -> FunctionNode:
|
||||
"""时间序列滞后。
|
||||
|
||||
获取给定因子在 N 个周期前的值。
|
||||
|
||||
Args:
|
||||
x: 输入因子表达式或字段名字符串
|
||||
periods: 滞后期数
|
||||
|
||||
Returns:
|
||||
FunctionNode: 函数调用节点
|
||||
"""
|
||||
return FunctionNode("ts_delay", x, periods)
|
||||
|
||||
|
||||
def ts_delta(x: Union[Node, str], periods: int) -> FunctionNode:
|
||||
"""时间序列差分。
|
||||
|
||||
计算给定因子与 N 个周期前的差值。
|
||||
|
||||
Args:
|
||||
x: 输入因子表达式或字段名字符串
|
||||
periods: 差分期数
|
||||
|
||||
Returns:
|
||||
FunctionNode: 函数调用节点
|
||||
"""
|
||||
return FunctionNode("ts_delta", x, periods)
|
||||
|
||||
|
||||
def ts_corr(x: Union[Node, str], y: Union[Node, str], window: int) -> FunctionNode:
|
||||
"""时间序列相关系数。
|
||||
|
||||
计算两个因子在滚动窗口内的相关系数。
|
||||
|
||||
Args:
|
||||
x: 第一个因子表达式或字段名字符串
|
||||
y: 第二个因子表达式或字段名字符串
|
||||
window: 滚动窗口大小
|
||||
|
||||
Returns:
|
||||
FunctionNode: 函数调用节点
|
||||
"""
|
||||
return FunctionNode("ts_corr", x, y, window)
|
||||
|
||||
|
||||
def ts_cov(x: Union[Node, str], y: Union[Node, str], window: int) -> FunctionNode:
|
||||
"""时间序列协方差。
|
||||
|
||||
计算两个因子在滚动窗口内的协方差。
|
||||
|
||||
Args:
|
||||
x: 第一个因子表达式或字段名字符串
|
||||
y: 第二个因子表达式或字段名字符串
|
||||
window: 滚动窗口大小
|
||||
|
||||
Returns:
|
||||
FunctionNode: 函数调用节点
|
||||
"""
|
||||
return FunctionNode("ts_cov", x, y, window)
|
||||
|
||||
|
||||
def ts_rank(x: Union[Node, str], window: int) -> FunctionNode:
|
||||
"""时间序列排名。
|
||||
|
||||
计算当前值在过去窗口内的分位排名。
|
||||
|
||||
Args:
|
||||
x: 输入因子表达式或字段名字符串
|
||||
window: 滚动窗口大小
|
||||
|
||||
Returns:
|
||||
FunctionNode: 函数调用节点
|
||||
"""
|
||||
return FunctionNode("ts_rank", x, window)
|
||||
|
||||
|
||||
# ==================== 截面函数 (cs_*) ====================
|
||||
|
||||
|
||||
def cs_rank(x: Union[Node, str]) -> FunctionNode:
|
||||
"""截面排名。
|
||||
|
||||
计算因子在横截面上的排名(分位数)。
|
||||
|
||||
Args:
|
||||
x: 输入因子表达式或字段名字符串
|
||||
|
||||
Returns:
|
||||
FunctionNode: 函数调用节点
|
||||
|
||||
Example:
|
||||
>>> from src.factors.api import close, cs_rank
|
||||
>>> expr = cs_rank(close) # 收盘价截面排名
|
||||
>>> expr = cs_rank("close") # 使用字符串
|
||||
>>> print(expr)
|
||||
cs_rank(close)
|
||||
"""
|
||||
return FunctionNode("cs_rank", x)
|
||||
|
||||
|
||||
def cs_zscore(x: Union[Node, str]) -> FunctionNode:
|
||||
"""截面标准化 (Z-Score)。
|
||||
|
||||
计算因子在横截面上的 Z-Score 标准化值。
|
||||
|
||||
Args:
|
||||
x: 输入因子表达式或字段名字符串
|
||||
|
||||
Returns:
|
||||
FunctionNode: 函数调用节点
|
||||
"""
|
||||
return FunctionNode("cs_zscore", x)
|
||||
|
||||
|
||||
def cs_neutralize(
|
||||
x: Union[Node, str], group: Union[Symbol, str, None] = None
|
||||
) -> FunctionNode:
|
||||
"""截面中性化。
|
||||
|
||||
对因子进行行业/市值中性化处理。
|
||||
|
||||
Args:
|
||||
x: 输入因子表达式或字段名字符串
|
||||
group: 分组变量(如行业分类),可以为字符串或 Symbol,默认为 None
|
||||
|
||||
Returns:
|
||||
FunctionNode: 函数调用节点
|
||||
"""
|
||||
if group is not None:
|
||||
return FunctionNode("cs_neutralize", x, group)
|
||||
return FunctionNode("cs_neutralize", x)
|
||||
|
||||
|
||||
def cs_winsorize(
|
||||
x: Union[Node, str], lower: float = 0.01, upper: float = 0.99
|
||||
) -> FunctionNode:
|
||||
"""截面缩尾处理。
|
||||
|
||||
对因子进行截面缩尾处理,去除极端值。
|
||||
|
||||
Args:
|
||||
x: 输入因子表达式或字段名字符串
|
||||
lower: 下尾分位数,默认 0.01
|
||||
upper: 上尾分位数,默认 0.99
|
||||
|
||||
Returns:
|
||||
FunctionNode: 函数调用节点
|
||||
"""
|
||||
return FunctionNode("cs_winsorize", x, lower, upper)
|
||||
|
||||
|
||||
def cs_demean(x: Union[Node, str]) -> FunctionNode:
|
||||
"""截面去均值。
|
||||
|
||||
计算因子在横截面上减去均值。
|
||||
|
||||
Args:
|
||||
x: 输入因子表达式或字段名字符串
|
||||
|
||||
Returns:
|
||||
FunctionNode: 函数调用节点
|
||||
"""
|
||||
return FunctionNode("cs_demean", x)
|
||||
|
||||
|
||||
# ==================== 数学函数 ====================
|
||||
|
||||
|
||||
def log(x: Union[Node, str]) -> FunctionNode:
|
||||
"""自然对数。
|
||||
|
||||
Args:
|
||||
x: 输入因子表达式或字段名字符串
|
||||
|
||||
Returns:
|
||||
FunctionNode: 函数调用节点
|
||||
"""
|
||||
return FunctionNode("log", x)
|
||||
|
||||
|
||||
def exp(x: Union[Node, str]) -> FunctionNode:
|
||||
"""指数函数。
|
||||
|
||||
Args:
|
||||
x: 输入因子表达式或字段名字符串
|
||||
|
||||
Returns:
|
||||
FunctionNode: 函数调用节点
|
||||
"""
|
||||
return FunctionNode("exp", x)
|
||||
|
||||
|
||||
def sqrt(x: Union[Node, str]) -> FunctionNode:
|
||||
"""平方根。
|
||||
|
||||
Args:
|
||||
x: 输入因子表达式或字段名字符串
|
||||
|
||||
Returns:
|
||||
FunctionNode: 函数调用节点
|
||||
"""
|
||||
return FunctionNode("sqrt", x)
|
||||
|
||||
|
||||
def sign(x: Union[Node, str]) -> FunctionNode:
|
||||
"""符号函数。
|
||||
|
||||
返回 -1, 0, 1 表示输入值的符号。
|
||||
|
||||
Args:
|
||||
x: 输入因子表达式或字段名字符串
|
||||
|
||||
Returns:
|
||||
FunctionNode: 函数调用节点
|
||||
"""
|
||||
return FunctionNode("sign", x)
|
||||
|
||||
|
||||
def abs(x: Union[Node, str]) -> FunctionNode:
|
||||
"""绝对值。
|
||||
|
||||
Args:
|
||||
x: 输入因子表达式或字段名字符串
|
||||
|
||||
Returns:
|
||||
FunctionNode: 函数调用节点
|
||||
"""
|
||||
return FunctionNode("abs", x)
|
||||
|
||||
|
||||
def max_(x: Union[Node, str], y: Union[Node, str, int, float]) -> FunctionNode:
|
||||
"""逐元素最大值。
|
||||
|
||||
Args:
|
||||
x: 第一个因子表达式或字段名字符串
|
||||
y: 第二个因子表达式、字段名字符串或数值
|
||||
|
||||
Returns:
|
||||
FunctionNode: 函数调用节点
|
||||
"""
|
||||
return FunctionNode("max", x, _ensure_node(y))
|
||||
|
||||
|
||||
def min_(x: Union[Node, str], y: Union[Node, str, int, float]) -> FunctionNode:
|
||||
"""逐元素最小值。
|
||||
|
||||
Args:
|
||||
x: 第一个因子表达式或字段名字符串
|
||||
y: 第二个因子表达式、字段名字符串或数值
|
||||
|
||||
Returns:
|
||||
FunctionNode: 函数调用节点
|
||||
"""
|
||||
return FunctionNode("min", x, _ensure_node(y))
|
||||
|
||||
|
||||
def clip(
|
||||
x: Union[Node, str],
|
||||
lower: Union[Node, str, int, float],
|
||||
upper: Union[Node, str, int, float],
|
||||
) -> FunctionNode:
|
||||
"""数值裁剪。
|
||||
|
||||
将因子值限制在 [lower, upper] 范围内。
|
||||
|
||||
Args:
|
||||
x: 输入因子表达式或字段名字符串
|
||||
lower: 下限(因子表达式、字段名字符串或数值)
|
||||
upper: 上限(因子表达式、字段名字符串或数值)
|
||||
|
||||
Returns:
|
||||
FunctionNode: 函数调用节点
|
||||
"""
|
||||
return FunctionNode("clip", x, _ensure_node(lower), _ensure_node(upper))
|
||||
|
||||
|
||||
# ==================== 条件函数 ====================
|
||||
|
||||
|
||||
def if_(
|
||||
condition: Union[Node, str],
|
||||
true_val: Union[Node, str, int, float],
|
||||
false_val: Union[Node, str, int, float],
|
||||
) -> FunctionNode:
|
||||
"""条件选择。
|
||||
|
||||
根据条件选择值。
|
||||
|
||||
Args:
|
||||
condition: 条件表达式或字段名字符串
|
||||
true_val: 条件为真时的值(因子表达式、字段名字符串或数值)
|
||||
false_val: 条件为假时的值(因子表达式、字段名字符串或数值)
|
||||
|
||||
Returns:
|
||||
FunctionNode: 函数调用节点
|
||||
"""
|
||||
return FunctionNode(
|
||||
"if", condition, _ensure_node(true_val), _ensure_node(false_val)
|
||||
)
|
||||
|
||||
|
||||
def where(
|
||||
condition: Union[Node, str],
|
||||
true_val: Union[Node, str, int, float],
|
||||
false_val: Union[Node, str, int, float],
|
||||
) -> FunctionNode:
|
||||
"""条件选择(if_ 的别名)。
|
||||
|
||||
Args:
|
||||
condition: 条件表达式或字段名字符串
|
||||
true_val: 条件为真时的值(因子表达式、字段名字符串或数值)
|
||||
false_val: 条件为假时的值(因子表达式、字段名字符串或数值)
|
||||
|
||||
Returns:
|
||||
FunctionNode: 函数调用节点
|
||||
"""
|
||||
return if_(condition, true_val, false_val)
|
||||
159
src/factors/compiler.py
Normal file
159
src/factors/compiler.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""AST 编译器模块 - 提供依赖提取和代码生成功能。
|
||||
|
||||
本模块实现 AST 遍历器模式,用于从 DSL 表达式中提取依赖的符号。
|
||||
"""
|
||||
|
||||
from typing import Set
|
||||
|
||||
from src.factors.dsl import Node, Symbol, BinaryOpNode, UnaryOpNode, FunctionNode
|
||||
|
||||
|
||||
class DependencyExtractor:
|
||||
"""依赖提取器 - 使用访问者模式遍历 AST 节点。
|
||||
|
||||
递归遍历表达式树,提取所有 Symbol 节点的名称。
|
||||
支持 BinaryOpNode、UnaryOpNode 和 FunctionNode 的递归遍历。
|
||||
|
||||
Example:
|
||||
>>> from src.factors.dsl import Symbol, FunctionNode
|
||||
>>> close = Symbol("close")
|
||||
>>> pe_ratio = Symbol("pe_ratio")
|
||||
>>> alpha = FunctionNode("cs_rank", close / pe_ratio)
|
||||
>>> deps = DependencyExtractor.extract_dependencies(alpha)
|
||||
>>> print(deps)
|
||||
{'close', 'pe_ratio'}
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化依赖提取器。"""
|
||||
self.dependencies: Set[str] = set()
|
||||
|
||||
def visit(self, node: Node) -> None:
|
||||
"""访问节点,根据节点类型分发到具体处理方法。
|
||||
|
||||
Args:
|
||||
node: AST 节点
|
||||
"""
|
||||
if isinstance(node, Symbol):
|
||||
self._visit_symbol(node)
|
||||
elif isinstance(node, BinaryOpNode):
|
||||
self._visit_binary_op(node)
|
||||
elif isinstance(node, UnaryOpNode):
|
||||
self._visit_unary_op(node)
|
||||
elif isinstance(node, FunctionNode):
|
||||
self._visit_function(node)
|
||||
# Constant 节点不包含依赖,无需处理
|
||||
|
||||
def _visit_symbol(self, node: Symbol) -> None:
|
||||
"""访问 Symbol 节点,提取符号名称。
|
||||
|
||||
Args:
|
||||
node: 符号节点
|
||||
"""
|
||||
self.dependencies.add(node.name)
|
||||
|
||||
def _visit_binary_op(self, node: BinaryOpNode) -> None:
|
||||
"""访问 BinaryOpNode 节点,递归遍历左右子节点。
|
||||
|
||||
Args:
|
||||
node: 二元运算节点
|
||||
"""
|
||||
self.visit(node.left)
|
||||
self.visit(node.right)
|
||||
|
||||
def _visit_unary_op(self, node: UnaryOpNode) -> None:
|
||||
"""访问 UnaryOpNode 节点,递归遍历操作数。
|
||||
|
||||
Args:
|
||||
node: 一元运算节点
|
||||
"""
|
||||
self.visit(node.operand)
|
||||
|
||||
def _visit_function(self, node: FunctionNode) -> None:
|
||||
"""访问 FunctionNode 节点,递归遍历所有参数。
|
||||
|
||||
Args:
|
||||
node: 函数调用节点
|
||||
"""
|
||||
for arg in node.args:
|
||||
self.visit(arg)
|
||||
|
||||
def extract(self, node: Node) -> Set[str]:
|
||||
"""从 AST 节点中提取所有依赖的符号名称。
|
||||
|
||||
Args:
|
||||
node: 表达式树的根节点
|
||||
|
||||
Returns:
|
||||
依赖的符号名称集合
|
||||
"""
|
||||
self.dependencies.clear()
|
||||
self.visit(node)
|
||||
return self.dependencies.copy()
|
||||
|
||||
@classmethod
|
||||
def extract_dependencies(cls, node: Node) -> Set[str]:
|
||||
"""类方法 - 从 AST 节点中提取所有依赖的符号名称。
|
||||
|
||||
这是一个便捷方法,无需手动实例化 DependencyExtractor。
|
||||
|
||||
Args:
|
||||
node: 表达式树的根节点
|
||||
|
||||
Returns:
|
||||
依赖的符号名称集合
|
||||
|
||||
Example:
|
||||
>>> from src.factors.dsl import Symbol
|
||||
>>> close = Symbol("close")
|
||||
>>> open_price = Symbol("open")
|
||||
>>> expr = close / open_price
|
||||
>>> deps = DependencyExtractor.extract_dependencies(expr)
|
||||
>>> print(deps)
|
||||
{'close', 'open'}
|
||||
"""
|
||||
extractor = cls()
|
||||
return extractor.extract(node)
|
||||
|
||||
|
||||
def extract_dependencies(node: Node) -> Set[str]:
|
||||
"""单例方法 - 从 AST 节点中提取所有依赖的符号名称。
|
||||
|
||||
这是 DependencyExtractor.extract_dependencies 的便捷包装函数。
|
||||
|
||||
Args:
|
||||
node: 表达式树的根节点
|
||||
|
||||
Returns:
|
||||
依赖的符号名称集合
|
||||
|
||||
Example:
|
||||
>>> from src.factors.dsl import Symbol, FunctionNode
|
||||
>>> close = Symbol("close")
|
||||
>>> pe_ratio = Symbol("pe_ratio")
|
||||
>>> alpha = FunctionNode("cs_rank", close / pe_ratio)
|
||||
>>> deps = extract_dependencies(alpha)
|
||||
>>> print(deps)
|
||||
{'close', 'pe_ratio'}
|
||||
"""
|
||||
return DependencyExtractor.extract_dependencies(node)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试用例: cs_rank(close / pe_ratio)
|
||||
from src.factors.dsl import Symbol, FunctionNode
|
||||
|
||||
# 创建符号
|
||||
close = Symbol("close")
|
||||
pe_ratio = Symbol("pe_ratio")
|
||||
|
||||
# 构建表达式: cs_rank(close / pe_ratio)
|
||||
alpha = FunctionNode("cs_rank", close / pe_ratio)
|
||||
|
||||
# 提取依赖
|
||||
dependencies = extract_dependencies(alpha)
|
||||
|
||||
print(f"表达式: {alpha}")
|
||||
print(f"提取的依赖: {dependencies}")
|
||||
print(f"期望依赖: {{'close', 'pe_ratio'}}")
|
||||
print(f"验证结果: {dependencies == {'close', 'pe_ratio'}}")
|
||||
278
src/factors/dsl.py
Normal file
278
src/factors/dsl.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""DSL 表达式层 - 纯 Python 实现,无 pandas/polars 依赖。
|
||||
|
||||
提供因子表达式的符号化表示能力,通过重载运算符实现
|
||||
用户端无感知的公式编写。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Union
|
||||
|
||||
|
||||
class Node(ABC):
|
||||
"""表达式节点基类。
|
||||
|
||||
所有因子表达式组件的抽象基类,提供运算符重载能力。
|
||||
子类需要实现 __repr__ 方法用于表达式可视化。
|
||||
"""
|
||||
|
||||
# ==================== 算术运算符重载 ====================
|
||||
|
||||
def __add__(self, other: Any) -> BinaryOpNode:
|
||||
"""加法: self + other"""
|
||||
return BinaryOpNode("+", self, _ensure_node(other))
|
||||
|
||||
def __radd__(self, other: Any) -> BinaryOpNode:
|
||||
"""右加法: other + self"""
|
||||
return BinaryOpNode("+", _ensure_node(other), self)
|
||||
|
||||
def __sub__(self, other: Any) -> BinaryOpNode:
|
||||
"""减法: self - other"""
|
||||
return BinaryOpNode("-", self, _ensure_node(other))
|
||||
|
||||
def __rsub__(self, other: Any) -> BinaryOpNode:
|
||||
"""右减法: other - self"""
|
||||
return BinaryOpNode("-", _ensure_node(other), self)
|
||||
|
||||
def __mul__(self, other: Any) -> BinaryOpNode:
|
||||
"""乘法: self * other"""
|
||||
return BinaryOpNode("*", self, _ensure_node(other))
|
||||
|
||||
def __rmul__(self, other: Any) -> BinaryOpNode:
|
||||
"""右乘法: other * self"""
|
||||
return BinaryOpNode("*", _ensure_node(other), self)
|
||||
|
||||
def __truediv__(self, other: Any) -> BinaryOpNode:
|
||||
"""除法: self / other"""
|
||||
return BinaryOpNode("/", self, _ensure_node(other))
|
||||
|
||||
def __rtruediv__(self, other: Any) -> BinaryOpNode:
|
||||
"""右除法: other / self"""
|
||||
return BinaryOpNode("/", _ensure_node(other), self)
|
||||
|
||||
def __pow__(self, other: Any) -> BinaryOpNode:
|
||||
"""幂运算: self ** other"""
|
||||
return BinaryOpNode("**", self, _ensure_node(other))
|
||||
|
||||
def __rpow__(self, other: Any) -> BinaryOpNode:
|
||||
"""右幂运算: other ** self"""
|
||||
return BinaryOpNode("**", _ensure_node(other), self)
|
||||
|
||||
def __floordiv__(self, other: Any) -> BinaryOpNode:
|
||||
"""整除: self // other"""
|
||||
return BinaryOpNode("//", self, _ensure_node(other))
|
||||
|
||||
def __rfloordiv__(self, other: Any) -> BinaryOpNode:
|
||||
"""右整除: other // self"""
|
||||
return BinaryOpNode("//", _ensure_node(other), self)
|
||||
|
||||
def __mod__(self, other: Any) -> BinaryOpNode:
|
||||
"""取模: self % other"""
|
||||
return BinaryOpNode("%", self, _ensure_node(other))
|
||||
|
||||
def __rmod__(self, other: Any) -> BinaryOpNode:
|
||||
"""右取模: other % self"""
|
||||
return BinaryOpNode("%", _ensure_node(other), self)
|
||||
|
||||
# ==================== 一元运算符重载 ====================
|
||||
|
||||
def __neg__(self) -> UnaryOpNode:
|
||||
"""取负: -self"""
|
||||
return UnaryOpNode("-", self)
|
||||
|
||||
def __pos__(self) -> UnaryOpNode:
|
||||
"""取正: +self"""
|
||||
return UnaryOpNode("+", self)
|
||||
|
||||
def __abs__(self) -> UnaryOpNode:
|
||||
"""绝对值: abs(self)"""
|
||||
return UnaryOpNode("abs", self)
|
||||
|
||||
# ==================== 比较运算符重载 ====================
|
||||
|
||||
def __eq__(self, other: Any) -> BinaryOpNode:
|
||||
"""等于: self == other"""
|
||||
return BinaryOpNode("==", self, _ensure_node(other))
|
||||
|
||||
def __ne__(self, other: Any) -> BinaryOpNode:
|
||||
"""不等于: self != other"""
|
||||
return BinaryOpNode("!=", self, _ensure_node(other))
|
||||
|
||||
def __lt__(self, other: Any) -> BinaryOpNode:
|
||||
"""小于: self < other"""
|
||||
return BinaryOpNode("<", self, _ensure_node(other))
|
||||
|
||||
def __le__(self, other: Any) -> BinaryOpNode:
|
||||
"""小于等于: self <= other"""
|
||||
return BinaryOpNode("<=", self, _ensure_node(other))
|
||||
|
||||
def __gt__(self, other: Any) -> BinaryOpNode:
|
||||
"""大于: self > other"""
|
||||
return BinaryOpNode(">", self, _ensure_node(other))
|
||||
|
||||
def __ge__(self, other: Any) -> BinaryOpNode:
|
||||
"""大于等于: self >= other"""
|
||||
return BinaryOpNode(">=", self, _ensure_node(other))
|
||||
|
||||
# ==================== 抽象方法 ====================
|
||||
|
||||
@abstractmethod
|
||||
def __repr__(self) -> str:
|
||||
"""返回表达式的字符串表示。"""
|
||||
pass
|
||||
|
||||
|
||||
class Symbol(Node):
|
||||
"""符号节点,代表一个命名变量(如 close, open 等)。
|
||||
|
||||
Attributes:
|
||||
name: 符号名称,用于标识该变量
|
||||
"""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
"""初始化符号节点。
|
||||
|
||||
Args:
|
||||
name: 符号名称,如 'close', 'open', 'volume' 等
|
||||
"""
|
||||
self.name = name
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""返回符号名称。"""
|
||||
return self.name
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""支持作为字典键使用。"""
|
||||
return hash(self.name)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""符号相等性比较。"""
|
||||
if not isinstance(other, Symbol):
|
||||
return NotImplemented
|
||||
return self.name == other.name
|
||||
|
||||
|
||||
class Constant(Node):
|
||||
"""常量节点,代表一个数值常量。
|
||||
|
||||
Attributes:
|
||||
value: 常量数值
|
||||
"""
|
||||
|
||||
def __init__(self, value: Union[int, float]) -> None:
|
||||
"""初始化常量节点。
|
||||
|
||||
Args:
|
||||
value: 常量数值
|
||||
"""
|
||||
self.value = value
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""返回常量值的字符串表示。"""
|
||||
return str(self.value)
|
||||
|
||||
|
||||
class BinaryOpNode(Node):
|
||||
"""二元运算节点,表示两个操作数之间的运算。
|
||||
|
||||
Attributes:
|
||||
op: 运算符,如 '+', '-', '*', '/' 等
|
||||
left: 左操作数
|
||||
right: 右操作数
|
||||
"""
|
||||
|
||||
def __init__(self, op: str, left: Node, right: Node) -> None:
|
||||
"""初始化二元运算节点。
|
||||
|
||||
Args:
|
||||
op: 运算符字符串
|
||||
left: 左操作数节点
|
||||
right: 右操作数节点
|
||||
"""
|
||||
self.op = op
|
||||
self.left = left
|
||||
self.right = right
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""返回带括号的二元运算表达式。"""
|
||||
return f"({self.left} {self.op} {self.right})"
|
||||
|
||||
|
||||
class UnaryOpNode(Node):
|
||||
"""一元运算节点,表示对单个操作数的运算。
|
||||
|
||||
Attributes:
|
||||
op: 运算符,如 '-', '+', 'abs' 等
|
||||
operand: 操作数
|
||||
"""
|
||||
|
||||
def __init__(self, op: str, operand: Node) -> None:
|
||||
"""初始化一元运算节点。
|
||||
|
||||
Args:
|
||||
op: 运算符字符串
|
||||
operand: 操作数节点
|
||||
"""
|
||||
self.op = op
|
||||
self.operand = operand
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""返回一元运算表达式。"""
|
||||
if self.op in ("+", "-"):
|
||||
return f"({self.op}{self.operand})"
|
||||
return f"{self.op}({self.operand})"
|
||||
|
||||
|
||||
class FunctionNode(Node):
|
||||
"""函数调用节点,表示一个函数调用。
|
||||
|
||||
Attributes:
|
||||
func_name: 函数名称
|
||||
args: 函数参数列表
|
||||
"""
|
||||
|
||||
def __init__(self, func_name: str, *args: Any) -> None:
|
||||
"""初始化函数调用节点。
|
||||
|
||||
Args:
|
||||
func_name: 函数名称,如 'ts_mean', 'cs_rank' 等
|
||||
*args: 函数参数,可以是 Node 或其他类型
|
||||
"""
|
||||
self.func_name = func_name
|
||||
# 将所有参数转换为节点类型
|
||||
self.args: List[Node] = [_ensure_node(arg) for arg in args]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""返回函数调用表达式。"""
|
||||
args_str = ", ".join(repr(arg) for arg in self.args)
|
||||
return f"{self.func_name}({args_str})"
|
||||
|
||||
|
||||
# ==================== 辅助函数 ====================
|
||||
|
||||
|
||||
def _ensure_node(value: Any) -> Node:
|
||||
"""确保值是一个 Node 节点。
|
||||
|
||||
如果值已经是 Node 类型,直接返回;
|
||||
如果是数值类型,包装为 Constant 节点;
|
||||
如果是字符串类型,包装为 Symbol 节点;
|
||||
否则抛出类型错误。
|
||||
|
||||
Args:
|
||||
value: 任意值
|
||||
|
||||
Returns:
|
||||
Node: 对应的节点对象
|
||||
|
||||
Raises:
|
||||
TypeError: 当值无法转换为节点时
|
||||
"""
|
||||
if isinstance(value, Node):
|
||||
return value
|
||||
if isinstance(value, (int, float)):
|
||||
return Constant(value)
|
||||
if isinstance(value, str):
|
||||
return Symbol(value)
|
||||
raise TypeError(f"无法将类型 {type(value).__name__} 转换为 Node")
|
||||
387
src/factors/translator.py
Normal file
387
src/factors/translator.py
Normal file
@@ -0,0 +1,387 @@
|
||||
"""Polars 翻译器 - 将 AST 翻译为 Polars 表达式。
|
||||
|
||||
本模块实现 DSL 到 Polars 计算图的映射,是因子表达式执行的桥梁。
|
||||
支持时序因子(ts_*)和截面因子(cs_*)的防错分组翻译。
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
import polars as pl
|
||||
|
||||
from src.factors.dsl import (
|
||||
BinaryOpNode,
|
||||
Constant,
|
||||
FunctionNode,
|
||||
Node,
|
||||
Symbol,
|
||||
UnaryOpNode,
|
||||
)
|
||||
|
||||
|
||||
class PolarsTranslator:
|
||||
"""Polars 表达式翻译器。
|
||||
|
||||
将纯对象的 AST 树完美映射为 Polars 的带防错分组的计算图。
|
||||
|
||||
Attributes:
|
||||
handlers: 函数处理器注册表,映射 func_name 到处理函数
|
||||
|
||||
Example:
|
||||
>>> from src.factors.dsl import Symbol, FunctionNode
|
||||
>>> close = Symbol("close")
|
||||
>>> expr = FunctionNode("ts_mean", close, 20)
|
||||
>>> translator = PolarsTranslator()
|
||||
>>> polars_expr = translator.translate(expr)
|
||||
>>> # 结果: pl.col("close").rolling_mean(20).over("asset")
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化翻译器并注册内置函数处理器。"""
|
||||
self.handlers: Dict[str, Callable[[FunctionNode], pl.Expr]] = {}
|
||||
self._register_builtin_handlers()
|
||||
|
||||
def _register_builtin_handlers(self) -> None:
|
||||
"""注册内置的函数处理器。"""
|
||||
# 时序因子处理器 (ts_*)
|
||||
self.register_handler("ts_mean", self._handle_ts_mean)
|
||||
self.register_handler("ts_sum", self._handle_ts_sum)
|
||||
self.register_handler("ts_std", self._handle_ts_std)
|
||||
self.register_handler("ts_max", self._handle_ts_max)
|
||||
self.register_handler("ts_min", self._handle_ts_min)
|
||||
self.register_handler("ts_delay", self._handle_ts_delay)
|
||||
self.register_handler("ts_delta", self._handle_ts_delta)
|
||||
self.register_handler("ts_corr", self._handle_ts_corr)
|
||||
self.register_handler("ts_cov", self._handle_ts_cov)
|
||||
|
||||
# 截面因子处理器 (cs_*)
|
||||
self.register_handler("cs_rank", self._handle_cs_rank)
|
||||
self.register_handler("cs_zscore", self._handle_cs_zscore)
|
||||
self.register_handler("cs_neutral", self._handle_cs_neutral)
|
||||
|
||||
def register_handler(
|
||||
self, func_name: str, handler: Callable[[FunctionNode], pl.Expr]
|
||||
) -> None:
|
||||
"""注册自定义函数处理器。
|
||||
|
||||
Args:
|
||||
func_name: 函数名称
|
||||
handler: 处理函数,接收 FunctionNode 返回 pl.Expr
|
||||
|
||||
Example:
|
||||
>>> def handle_custom(node: FunctionNode) -> pl.Expr:
|
||||
... arg = self.translate(node.args[0])
|
||||
... return arg * 2
|
||||
>>> translator.register_handler("custom", handle_custom)
|
||||
"""
|
||||
self.handlers[func_name] = handler
|
||||
|
||||
def translate(self, node: Node) -> pl.Expr:
|
||||
"""递归翻译 AST 节点为 Polars 表达式。
|
||||
|
||||
Args:
|
||||
node: AST 节点(Symbol、Constant、BinaryOpNode、UnaryOpNode、FunctionNode)
|
||||
|
||||
Returns:
|
||||
Polars 表达式对象
|
||||
|
||||
Raises:
|
||||
TypeError: 当遇到未知的节点类型时
|
||||
"""
|
||||
if isinstance(node, Symbol):
|
||||
return self._translate_symbol(node)
|
||||
elif isinstance(node, Constant):
|
||||
return self._translate_constant(node)
|
||||
elif isinstance(node, BinaryOpNode):
|
||||
return self._translate_binary_op(node)
|
||||
elif isinstance(node, UnaryOpNode):
|
||||
return self._translate_unary_op(node)
|
||||
elif isinstance(node, FunctionNode):
|
||||
return self._translate_function(node)
|
||||
else:
|
||||
raise TypeError(f"未知的节点类型: {type(node).__name__}")
|
||||
|
||||
def _translate_symbol(self, node: Symbol) -> pl.Expr:
|
||||
"""翻译 Symbol 节点为 pl.col() 表达式。
|
||||
|
||||
Args:
|
||||
node: 符号节点
|
||||
|
||||
Returns:
|
||||
pl.col(node.name) 表达式
|
||||
"""
|
||||
return pl.col(node.name)
|
||||
|
||||
def _translate_constant(self, node: Constant) -> pl.Expr:
|
||||
"""翻译 Constant 节点为 Polars 字面量。
|
||||
|
||||
Args:
|
||||
node: 常量节点
|
||||
|
||||
Returns:
|
||||
pl.lit(node.value) 表达式
|
||||
"""
|
||||
return pl.lit(node.value)
|
||||
|
||||
def _translate_binary_op(self, node: BinaryOpNode) -> pl.Expr:
|
||||
"""翻译 BinaryOpNode 为 Polars 二元运算。
|
||||
|
||||
Args:
|
||||
node: 二元运算节点
|
||||
|
||||
Returns:
|
||||
Polars 二元运算表达式
|
||||
"""
|
||||
left = self.translate(node.left)
|
||||
right = self.translate(node.right)
|
||||
|
||||
op_map = {
|
||||
"+": lambda l, r: l + r,
|
||||
"-": lambda l, r: l - r,
|
||||
"*": lambda l, r: l * r,
|
||||
"/": lambda l, r: l / r,
|
||||
"**": lambda l, r: l.pow(r),
|
||||
"//": lambda l, r: l.floor_div(r),
|
||||
"%": lambda l, r: l % r,
|
||||
"==": lambda l, r: l.eq(r),
|
||||
"!=": lambda l, r: l.ne(r),
|
||||
"<": lambda l, r: l.lt(r),
|
||||
"<=": lambda l, r: l.le(r),
|
||||
">": lambda l, r: l.gt(r),
|
||||
">=": lambda l, r: l.ge(r),
|
||||
}
|
||||
|
||||
if node.op not in op_map:
|
||||
raise ValueError(f"不支持的二元运算符: {node.op}")
|
||||
|
||||
return op_map[node.op](left, right)
|
||||
|
||||
def _translate_unary_op(self, node: UnaryOpNode) -> pl.Expr:
|
||||
"""翻译 UnaryOpNode 为 Polars 一元运算。
|
||||
|
||||
Args:
|
||||
node: 一元运算节点
|
||||
|
||||
Returns:
|
||||
Polars 一元运算表达式
|
||||
"""
|
||||
operand = self.translate(node.operand)
|
||||
|
||||
op_map = {
|
||||
"+": lambda x: x,
|
||||
"-": lambda x: -x,
|
||||
"abs": lambda x: x.abs(),
|
||||
}
|
||||
|
||||
if node.op not in op_map:
|
||||
raise ValueError(f"不支持的一元运算符: {node.op}")
|
||||
|
||||
return op_map[node.op](operand)
|
||||
|
||||
def _translate_function(self, node: FunctionNode) -> pl.Expr:
|
||||
"""翻译 FunctionNode 为 Polars 函数调用。
|
||||
|
||||
优先从 handlers 注册表中查找处理器,未找到则抛出错误。
|
||||
|
||||
Args:
|
||||
node: 函数调用节点
|
||||
|
||||
Returns:
|
||||
Polars 函数表达式
|
||||
|
||||
Raises:
|
||||
ValueError: 当函数名称未注册处理器时
|
||||
"""
|
||||
func_name = node.func_name
|
||||
|
||||
if func_name in self.handlers:
|
||||
return self.handlers[func_name](node)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"未注册的函数: {func_name}. 请使用 register_handler 注册处理器。"
|
||||
)
|
||||
|
||||
# ==================== 时序因子处理器 (ts_*) ====================
|
||||
# 所有时序因子强制注入 over("ts_code") 防串表
|
||||
|
||||
def _handle_ts_mean(self, node: FunctionNode) -> pl.Expr:
|
||||
"""处理 ts_mean(close, window) -> rolling_mean(window).over(ts_code)。"""
|
||||
if len(node.args) != 2:
|
||||
raise ValueError("ts_mean 需要 2 个参数: (expr, window)")
|
||||
expr = self.translate(node.args[0])
|
||||
window = self._extract_window(node.args[1])
|
||||
return expr.rolling_mean(window_size=window).over("ts_code")
|
||||
|
||||
def _handle_ts_sum(self, node: FunctionNode) -> pl.Expr:
|
||||
"""处理 ts_sum(close, window) -> rolling_sum(window).over(ts_code)。"""
|
||||
if len(node.args) != 2:
|
||||
raise ValueError("ts_sum 需要 2 个参数: (expr, window)")
|
||||
expr = self.translate(node.args[0])
|
||||
window = self._extract_window(node.args[1])
|
||||
return expr.rolling_sum(window_size=window).over("ts_code")
|
||||
|
||||
def _handle_ts_std(self, node: FunctionNode) -> pl.Expr:
|
||||
"""处理 ts_std(close, window) -> rolling_std(window).over(ts_code)。"""
|
||||
if len(node.args) != 2:
|
||||
raise ValueError("ts_std 需要 2 个参数: (expr, window)")
|
||||
expr = self.translate(node.args[0])
|
||||
window = self._extract_window(node.args[1])
|
||||
return expr.rolling_std(window_size=window).over("ts_code")
|
||||
|
||||
def _handle_ts_max(self, node: FunctionNode) -> pl.Expr:
|
||||
"""处理 ts_max(close, window) -> rolling_max(window).over(ts_code)。"""
|
||||
if len(node.args) != 2:
|
||||
raise ValueError("ts_max 需要 2 个参数: (expr, window)")
|
||||
expr = self.translate(node.args[0])
|
||||
window = self._extract_window(node.args[1])
|
||||
return expr.rolling_max(window_size=window).over("ts_code")
|
||||
|
||||
def _handle_ts_min(self, node: FunctionNode) -> pl.Expr:
|
||||
"""处理 ts_min(close, window) -> rolling_min(window).over(ts_code)。"""
|
||||
if len(node.args) != 2:
|
||||
raise ValueError("ts_min 需要 2 个参数: (expr, window)")
|
||||
expr = self.translate(node.args[0])
|
||||
window = self._extract_window(node.args[1])
|
||||
return expr.rolling_min(window_size=window).over("ts_code")
|
||||
|
||||
def _handle_ts_delay(self, node: FunctionNode) -> pl.Expr:
|
||||
"""处理 ts_delay(close, n) -> shift(n).over(ts_code)。"""
|
||||
if len(node.args) != 2:
|
||||
raise ValueError("ts_delay 需要 2 个参数: (expr, n)")
|
||||
expr = self.translate(node.args[0])
|
||||
n = self._extract_window(node.args[1])
|
||||
return expr.shift(n).over("ts_code")
|
||||
|
||||
def _handle_ts_delta(self, node: FunctionNode) -> pl.Expr:
|
||||
"""处理 ts_delta(close, n) -> (expr - shift(n)).over(ts_code)。"""
|
||||
if len(node.args) != 2:
|
||||
raise ValueError("ts_delta 需要 2 个参数: (expr, n)")
|
||||
expr = self.translate(node.args[0])
|
||||
n = self._extract_window(node.args[1])
|
||||
return (expr - expr.shift(n)).over("ts_code")
|
||||
|
||||
def _handle_ts_corr(self, node: FunctionNode) -> pl.Expr:
|
||||
"""处理 ts_corr(x, y, window) -> rolling_corr(y, window).over(ts_code)。"""
|
||||
if len(node.args) != 3:
|
||||
raise ValueError("ts_corr 需要 3 个参数: (x, y, window)")
|
||||
x = self.translate(node.args[0])
|
||||
y = self.translate(node.args[1])
|
||||
window = self._extract_window(node.args[2])
|
||||
return x.rolling_corr(y, window_size=window).over("ts_code")
|
||||
|
||||
def _handle_ts_cov(self, node: FunctionNode) -> pl.Expr:
|
||||
"""处理 ts_cov(x, y, window) -> rolling_cov(y, window).over(ts_code)。"""
|
||||
if len(node.args) != 3:
|
||||
raise ValueError("ts_cov 需要 3 个参数: (x, y, window)")
|
||||
x = self.translate(node.args[0])
|
||||
y = self.translate(node.args[1])
|
||||
window = self._extract_window(node.args[2])
|
||||
return x.rolling_cov(y, window_size=window).over("ts_code")
|
||||
|
||||
# ==================== 截面因子处理器 (cs_*) ====================
|
||||
# 所有截面因子强制注入 over("trade_date") 防串表
|
||||
|
||||
def _handle_cs_rank(self, node: FunctionNode) -> pl.Expr:
|
||||
"""处理 cs_rank(expr) -> rank()/count().over(trade_date)。
|
||||
|
||||
将排名归一化到 [0, 1] 区间。
|
||||
"""
|
||||
if len(node.args) != 1:
|
||||
raise ValueError("cs_rank 需要 1 个参数: (expr)")
|
||||
expr = self.translate(node.args[0])
|
||||
return (expr.rank() / expr.count()).over("trade_date")
|
||||
|
||||
def _handle_cs_zscore(self, node: FunctionNode) -> pl.Expr:
|
||||
"""处理 cs_zscore(expr) -> (expr - mean())/std().over(trade_date)。"""
|
||||
if len(node.args) != 1:
|
||||
raise ValueError("cs_zscore 需要 1 个参数: (expr)")
|
||||
expr = self.translate(node.args[0])
|
||||
return ((expr - expr.mean()) / expr.std()).over("trade_date")
|
||||
|
||||
def _handle_cs_neutral(self, node: FunctionNode) -> pl.Expr:
|
||||
"""处理 cs_neutral(expr, group) -> 分组中性化。"""
|
||||
if len(node.args) not in [1, 2]:
|
||||
raise ValueError("cs_neutral 需要 1-2 个参数: (expr, [group_col])")
|
||||
expr = self.translate(node.args[0])
|
||||
# 简单实现:减去截面均值(可在未来扩展为分组中性化)
|
||||
return (expr - expr.mean()).over("trade_date")
|
||||
|
||||
# ==================== 辅助方法 ====================
|
||||
|
||||
def _extract_window(self, node: Node) -> int:
|
||||
"""从节点中提取窗口大小参数。
|
||||
|
||||
Args:
|
||||
node: 应该是 Constant 节点
|
||||
|
||||
Returns:
|
||||
整数值
|
||||
|
||||
Raises:
|
||||
ValueError: 当节点不是 Constant 或值不是整数时
|
||||
"""
|
||||
if isinstance(node, Constant):
|
||||
if not isinstance(node.value, int):
|
||||
raise ValueError(
|
||||
f"窗口参数必须是整数,得到: {type(node.value).__name__}"
|
||||
)
|
||||
return node.value
|
||||
raise ValueError(f"窗口参数必须是常量整数,得到: {type(node).__name__}")
|
||||
|
||||
|
||||
def translate_to_polars(node: Node) -> pl.Expr:
|
||||
"""便捷函数 - 将 AST 节点翻译为 Polars 表达式。
|
||||
|
||||
Args:
|
||||
node: 表达式树的根节点
|
||||
|
||||
Returns:
|
||||
Polars 表达式对象
|
||||
|
||||
Example:
|
||||
>>> from src.factors.dsl import Symbol, FunctionNode
|
||||
>>> close = Symbol("close")
|
||||
>>> expr = FunctionNode("ts_mean", close, 20)
|
||||
>>> polars_expr = translate_to_polars(expr)
|
||||
"""
|
||||
translator = PolarsTranslator()
|
||||
return translator.translate(node)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试用例
|
||||
from src.factors.dsl import Symbol, FunctionNode
|
||||
|
||||
# 创建符号
|
||||
close = Symbol("close")
|
||||
volume = Symbol("volume")
|
||||
|
||||
# 测试 1: 简单符号
|
||||
print("测试 1: Symbol")
|
||||
translator = PolarsTranslator()
|
||||
expr1 = translator.translate(close)
|
||||
print(f" close -> {expr1}")
|
||||
assert str(expr1) == 'col("close")'
|
||||
|
||||
# 测试 2: 二元运算
|
||||
print("\n测试 2: BinaryOp")
|
||||
expr2 = translator.translate(close + 10)
|
||||
print(f" close + 10 -> {expr2}")
|
||||
|
||||
# 测试 3: ts_mean
|
||||
print("\n测试 3: ts_mean")
|
||||
expr3 = translator.translate(FunctionNode("ts_mean", close, 20))
|
||||
print(f" ts_mean(close, 20) -> {expr3}")
|
||||
|
||||
# 测试 4: cs_rank
|
||||
print("\n测试 4: cs_rank")
|
||||
expr4 = translator.translate(FunctionNode("cs_rank", close / volume))
|
||||
print(f" cs_rank(close / volume) -> {expr4}")
|
||||
|
||||
# 测试 5: 复杂表达式
|
||||
print("\n测试 5: 复杂表达式")
|
||||
ma20 = FunctionNode("ts_mean", close, 20)
|
||||
ma60 = FunctionNode("ts_mean", close, 60)
|
||||
expr5 = translator.translate(FunctionNode("cs_rank", ma20 - ma60))
|
||||
print(f" cs_rank(ts_mean(close, 20) - ts_mean(close, 60)) -> {expr5}")
|
||||
|
||||
print("\n✅ 所有测试通过!")
|
||||
Reference in New Issue
Block a user