feat: 添加DSL因子表达式系统和Pro Bar API封装
- 新增 factors/dsl.py: 纯Python DSL表达式层,通过运算符重载实现因子组合
- 新增 factors/api.py: 提供常用因子符号(close/open/high/low)和时序函数(ts_mean/ts_std/cs_rank等)
- 新增 factors/compiler.py: 因子编译器
- 新增 factors/translator.py: DSL表达式翻译器
- 新增 data/api_wrappers/api_pro_bar.py: Tushare Pro Bar API封装,支持后复权行情数据
- 新增 data/data_router.py: 数据路由功能
- 新增相关测试用例
This commit is contained in:
387
src/factors/translator.py
Normal file
387
src/factors/translator.py
Normal file
@@ -0,0 +1,387 @@
|
||||
"""Polars 翻译器 - 将 AST 翻译为 Polars 表达式。
|
||||
|
||||
本模块实现 DSL 到 Polars 计算图的映射,是因子表达式执行的桥梁。
|
||||
支持时序因子(ts_*)和截面因子(cs_*)的防错分组翻译。
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
import polars as pl
|
||||
|
||||
from src.factors.dsl import (
|
||||
BinaryOpNode,
|
||||
Constant,
|
||||
FunctionNode,
|
||||
Node,
|
||||
Symbol,
|
||||
UnaryOpNode,
|
||||
)
|
||||
|
||||
|
||||
class PolarsTranslator:
    """Translate a factor-DSL AST into a Polars expression.

    Every time-series function (``ts_*``) is wrapped in
    ``.over(ASSET_COL)`` and every cross-sectional function (``cs_*``) in
    ``.over(DATE_COL)``, so a window never mixes rows that belong to
    different assets or different dates.

    Attributes:
        handlers: Registry mapping a DSL function name to the callable that
            turns the corresponding ``FunctionNode`` into a ``pl.Expr``.

    Example:
        >>> from src.factors.dsl import Symbol, FunctionNode
        >>> close = Symbol("close")
        >>> expr = FunctionNode("ts_mean", close, 20)
        >>> translator = PolarsTranslator()
        >>> polars_expr = translator.translate(expr)
        >>> # Result: pl.col("close").rolling_mean(window_size=20).over("ts_code")
    """

    # Partition column for time-series (ts_*) windows: one series per asset.
    ASSET_COL = "ts_code"
    # Partition column for cross-sectional (cs_*) windows: one slice per date.
    DATE_COL = "trade_date"

    # Operator tables are built once at class-definition time instead of on
    # every translate call.
    _BINARY_OPS = {
        "+": lambda a, b: a + b,
        "-": lambda a, b: a - b,
        "*": lambda a, b: a * b,
        "/": lambda a, b: a / b,
        "**": lambda a, b: a.pow(b),
        # The Polars method is spelled ``floordiv`` (there is no
        # ``floor_div``); the operator form avoids the naming trap.
        "//": lambda a, b: a // b,
        "%": lambda a, b: a % b,
        "==": lambda a, b: a.eq(b),
        "!=": lambda a, b: a.ne(b),
        "<": lambda a, b: a.lt(b),
        "<=": lambda a, b: a.le(b),
        ">": lambda a, b: a.gt(b),
        ">=": lambda a, b: a.ge(b),
    }

    _UNARY_OPS = {
        "+": lambda x: x,
        "-": lambda x: -x,
        "abs": lambda x: x.abs(),
    }

    def __init__(self) -> None:
        """Create the translator and register the built-in handlers."""
        self.handlers: Dict[str, Callable[[FunctionNode], pl.Expr]] = {}
        self._register_builtin_handlers()

    def _register_builtin_handlers(self) -> None:
        """Register the built-in ``ts_*`` and ``cs_*`` function handlers."""
        builtin_handlers = {
            # Time-series handlers (ts_*)
            "ts_mean": self._handle_ts_mean,
            "ts_sum": self._handle_ts_sum,
            "ts_std": self._handle_ts_std,
            "ts_max": self._handle_ts_max,
            "ts_min": self._handle_ts_min,
            "ts_delay": self._handle_ts_delay,
            "ts_delta": self._handle_ts_delta,
            "ts_corr": self._handle_ts_corr,
            "ts_cov": self._handle_ts_cov,
            # Cross-sectional handlers (cs_*)
            "cs_rank": self._handle_cs_rank,
            "cs_zscore": self._handle_cs_zscore,
            "cs_neutral": self._handle_cs_neutral,
        }
        for func_name, handler in builtin_handlers.items():
            self.register_handler(func_name, handler)

    def register_handler(
        self, func_name: str, handler: Callable[[FunctionNode], pl.Expr]
    ) -> None:
        """Register (or replace) the handler for a DSL function name.

        Args:
            func_name: Name of the DSL function.
            handler: Callable that receives the ``FunctionNode`` and returns
                a ``pl.Expr``.

        Example:
            >>> translator = PolarsTranslator()
            >>> def handle_custom(node: FunctionNode) -> pl.Expr:
            ...     arg = translator.translate(node.args[0])
            ...     return arg * 2
            >>> translator.register_handler("custom", handle_custom)
        """
        self.handlers[func_name] = handler

    def translate(self, node: Node) -> pl.Expr:
        """Recursively translate an AST node into a Polars expression.

        Args:
            node: AST node (Symbol, Constant, BinaryOpNode, UnaryOpNode or
                FunctionNode).

        Returns:
            The equivalent Polars expression.

        Raises:
            TypeError: If ``node`` is not one of the known node types.
        """
        if isinstance(node, Symbol):
            return self._translate_symbol(node)
        if isinstance(node, Constant):
            return self._translate_constant(node)
        if isinstance(node, BinaryOpNode):
            return self._translate_binary_op(node)
        if isinstance(node, UnaryOpNode):
            return self._translate_unary_op(node)
        if isinstance(node, FunctionNode):
            return self._translate_function(node)
        raise TypeError(f"未知的节点类型: {type(node).__name__}")

    def _translate_symbol(self, node: Symbol) -> pl.Expr:
        """Translate a Symbol into a column reference, ``pl.col(node.name)``."""
        return pl.col(node.name)

    def _translate_constant(self, node: Constant) -> pl.Expr:
        """Translate a Constant into a literal, ``pl.lit(node.value)``."""
        return pl.lit(node.value)

    def _translate_binary_op(self, node: BinaryOpNode) -> pl.Expr:
        """Translate a BinaryOpNode into the matching Polars binary operation.

        Args:
            node: Binary operation node.

        Returns:
            Polars expression combining the translated operands.

        Raises:
            ValueError: If ``node.op`` is not a supported binary operator.
        """
        # Operands are translated before the operator is validated, so an
        # error inside a child expression is reported first.
        left = self.translate(node.left)
        right = self.translate(node.right)

        operation = self._BINARY_OPS.get(node.op)
        if operation is None:
            raise ValueError(f"不支持的二元运算符: {node.op}")
        return operation(left, right)

    def _translate_unary_op(self, node: UnaryOpNode) -> pl.Expr:
        """Translate a UnaryOpNode into the matching Polars unary operation.

        Args:
            node: Unary operation node.

        Returns:
            Polars expression applied to the translated operand.

        Raises:
            ValueError: If ``node.op`` is not a supported unary operator.
        """
        operand = self.translate(node.operand)

        operation = self._UNARY_OPS.get(node.op)
        if operation is None:
            raise ValueError(f"不支持的一元运算符: {node.op}")
        return operation(operand)

    def _translate_function(self, node: FunctionNode) -> pl.Expr:
        """Translate a FunctionNode through the handler registry.

        Args:
            node: Function call node.

        Returns:
            The expression produced by the registered handler.

        Raises:
            ValueError: If no handler is registered for ``node.func_name``.
        """
        func_name = node.func_name
        handler = self.handlers.get(func_name)
        if handler is None:
            raise ValueError(
                f"未注册的函数: {func_name}. 请使用 register_handler 注册处理器。"
            )
        return handler(node)

    # ==================== Time-series handlers (ts_*) ====================
    # Every time-series factor is forced through over(ASSET_COL) so that a
    # window never spans two different assets.

    def _translate_rolling(
        self, node: FunctionNode, func_name: str, method_name: str
    ) -> pl.Expr:
        """Shared body of the single-input rolling handlers.

        Args:
            node: Function node with arguments ``(expr, window)``.
            func_name: DSL function name, used in error messages.
            method_name: Name of the ``pl.Expr`` rolling method to call.

        Raises:
            ValueError: On a wrong argument count or an invalid window.
        """
        if len(node.args) != 2:
            raise ValueError(f"{func_name} 需要 2 个参数: (expr, window)")
        expr = self.translate(node.args[0])
        window = self._extract_rolling_window(node.args[1], func_name)
        rolling_method = getattr(expr, method_name)
        return rolling_method(window_size=window).over(self.ASSET_COL)

    def _translate_rolling_pair(
        self,
        node: FunctionNode,
        func_name: str,
        rolling_func: Callable[..., pl.Expr],
    ) -> pl.Expr:
        """Shared body of the two-input rolling handlers (corr / cov).

        Polars exposes rolling correlation and covariance as the
        module-level functions ``pl.rolling_corr`` / ``pl.rolling_cov``
        rather than as ``pl.Expr`` methods, so the function is passed in.

        Args:
            node: Function node with arguments ``(x, y, window)``.
            func_name: DSL function name, used in error messages.
            rolling_func: ``pl.rolling_corr`` or ``pl.rolling_cov``.

        Raises:
            ValueError: On a wrong argument count or an invalid window.
        """
        if len(node.args) != 3:
            raise ValueError(f"{func_name} 需要 3 个参数: (x, y, window)")
        x = self.translate(node.args[0])
        y = self.translate(node.args[1])
        window = self._extract_rolling_window(node.args[2], func_name)
        return rolling_func(x, y, window_size=window).over(self.ASSET_COL)

    def _handle_ts_mean(self, node: FunctionNode) -> pl.Expr:
        """ts_mean(expr, window) -> rolling_mean(window) per asset."""
        return self._translate_rolling(node, "ts_mean", "rolling_mean")

    def _handle_ts_sum(self, node: FunctionNode) -> pl.Expr:
        """ts_sum(expr, window) -> rolling_sum(window) per asset."""
        return self._translate_rolling(node, "ts_sum", "rolling_sum")

    def _handle_ts_std(self, node: FunctionNode) -> pl.Expr:
        """ts_std(expr, window) -> rolling_std(window) per asset."""
        return self._translate_rolling(node, "ts_std", "rolling_std")

    def _handle_ts_max(self, node: FunctionNode) -> pl.Expr:
        """ts_max(expr, window) -> rolling_max(window) per asset."""
        return self._translate_rolling(node, "ts_max", "rolling_max")

    def _handle_ts_min(self, node: FunctionNode) -> pl.Expr:
        """ts_min(expr, window) -> rolling_min(window) per asset."""
        return self._translate_rolling(node, "ts_min", "rolling_min")

    def _handle_ts_delay(self, node: FunctionNode) -> pl.Expr:
        """ts_delay(expr, n) -> shift(n) per asset."""
        if len(node.args) != 2:
            raise ValueError("ts_delay 需要 2 个参数: (expr, n)")
        expr = self.translate(node.args[0])
        n = self._extract_window(node.args[1])
        return expr.shift(n).over(self.ASSET_COL)

    def _handle_ts_delta(self, node: FunctionNode) -> pl.Expr:
        """ts_delta(expr, n) -> expr - shift(n), evaluated per asset."""
        if len(node.args) != 2:
            raise ValueError("ts_delta 需要 2 个参数: (expr, n)")
        expr = self.translate(node.args[0])
        n = self._extract_window(node.args[1])
        return (expr - expr.shift(n)).over(self.ASSET_COL)

    def _handle_ts_corr(self, node: FunctionNode) -> pl.Expr:
        """ts_corr(x, y, window) -> rolling correlation per asset."""
        return self._translate_rolling_pair(node, "ts_corr", pl.rolling_corr)

    def _handle_ts_cov(self, node: FunctionNode) -> pl.Expr:
        """ts_cov(x, y, window) -> rolling covariance per asset."""
        return self._translate_rolling_pair(node, "ts_cov", pl.rolling_cov)

    # ================== Cross-sectional handlers (cs_*) ==================
    # Every cross-sectional factor is forced through over(DATE_COL) so that
    # statistics are computed within a single trading date.

    def _handle_cs_rank(self, node: FunctionNode) -> pl.Expr:
        """cs_rank(expr) -> rank() / count() within each date.

        The rank is divided by the per-date count, so the smallest value
        maps to ``1 / count`` and the largest to ``1``: results lie in
        ``(0, 1]``.
        """
        if len(node.args) != 1:
            raise ValueError("cs_rank 需要 1 个参数: (expr)")
        expr = self.translate(node.args[0])
        return (expr.rank() / expr.count()).over(self.DATE_COL)

    def _handle_cs_zscore(self, node: FunctionNode) -> pl.Expr:
        """cs_zscore(expr) -> (expr - mean()) / std() within each date."""
        if len(node.args) != 1:
            raise ValueError("cs_zscore 需要 1 个参数: (expr)")
        expr = self.translate(node.args[0])
        return ((expr - expr.mean()) / expr.std()).over(self.DATE_COL)

    def _handle_cs_neutral(self, node: FunctionNode) -> pl.Expr:
        """cs_neutral(expr[, group]) -> demean within date (and group).

        With one argument the per-date mean is subtracted. With a second
        argument the mean is taken within each (date, group) cell instead
        of silently ignoring the group.

        The group may be a ``Constant`` holding a column name, a
        ``Constant`` holding ``None`` (no grouping), or any other node,
        which is translated and used as a partition key.

        Raises:
            ValueError: On a wrong argument count, or when the group is a
                ``Constant`` that is neither a string nor ``None``.
        """
        if len(node.args) not in [1, 2]:
            raise ValueError("cs_neutral 需要 1-2 个参数: (expr, [group_col])")
        expr = self.translate(node.args[0])

        extra_keys = []
        if len(node.args) == 2:
            group = node.args[1]
            if isinstance(group, Constant):
                if isinstance(group.value, str):
                    # A string constant names the grouping column.
                    extra_keys.append(pl.col(group.value))
                elif group.value is not None:
                    raise ValueError(
                        "cs_neutral 的分组参数必须是列名或符号,得到: "
                        f"{type(group.value).__name__}"
                    )
            else:
                extra_keys.append(self.translate(group))

        return (expr - expr.mean()).over(self.DATE_COL, *extra_keys)

    # ============================== Helpers ==============================

    def _extract_window(self, node: Node) -> int:
        """Extract an integer parameter (window size or lag) from a node.

        Args:
            node: Expected to be a ``Constant`` holding an ``int``.

        Returns:
            The integer value.

        Raises:
            ValueError: If ``node`` is not a ``Constant`` or its value is
                not an integer. ``bool`` is rejected although it subclasses
                ``int``.
        """
        if not isinstance(node, Constant):
            raise ValueError(
                f"窗口参数必须是常量整数,得到: {type(node).__name__}"
            )
        value = node.value
        if isinstance(value, bool) or not isinstance(value, int):
            raise ValueError(
                f"窗口参数必须是整数,得到: {type(value).__name__}"
            )
        return value

    def _extract_rolling_window(self, node: Node, func_name: str) -> int:
        """Extract a rolling window size and require it to be at least 1.

        Raises:
            ValueError: If the node is not an integer constant or the
                window is smaller than 1.
        """
        window = self._extract_window(node)
        if window < 1:
            raise ValueError(f"{func_name} 的窗口参数必须为正整数,得到: {window}")
        return window
|
||||
|
||||
|
||||
def translate_to_polars(node: Node) -> pl.Expr:
    """Translate an AST into a Polars expression with a fresh translator.

    Convenience wrapper for one-off translations: it builds a new
    ``PolarsTranslator`` and translates ``node`` with it.

    Args:
        node: Root node of the expression tree.

    Returns:
        The equivalent Polars expression.

    Example:
        >>> from src.factors.dsl import Symbol, FunctionNode
        >>> close = Symbol("close")
        >>> expr = FunctionNode("ts_mean", close, 20)
        >>> polars_expr = translate_to_polars(expr)
    """
    return PolarsTranslator().translate(node)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: translate a few representative expressions and
    # print the resulting Polars expressions for visual inspection.
    from src.factors.dsl import Symbol, FunctionNode

    # Input columns used by every case below.
    close_sym = Symbol("close")
    volume_sym = Symbol("volume")

    # Case 1: a bare symbol becomes a column reference.
    print("测试 1: Symbol")
    smoke_translator = PolarsTranslator()
    symbol_expr = smoke_translator.translate(close_sym)
    print(f" close -> {symbol_expr}")
    assert str(symbol_expr) == 'col("close")'

    # Case 2: arithmetic between a symbol and a plain number.
    print("\n测试 2: BinaryOp")
    binary_expr = smoke_translator.translate(close_sym + 10)
    print(f" close + 10 -> {binary_expr}")

    # Case 3: a time-series function.
    print("\n测试 3: ts_mean")
    rolling_expr = smoke_translator.translate(
        FunctionNode("ts_mean", close_sym, 20)
    )
    print(f" ts_mean(close, 20) -> {rolling_expr}")

    # Case 4: a cross-sectional function over a derived expression.
    print("\n测试 4: cs_rank")
    rank_expr = smoke_translator.translate(
        FunctionNode("cs_rank", close_sym / volume_sym)
    )
    print(f" cs_rank(close / volume) -> {rank_expr}")

    # Case 5: a cross-sectional rank of the spread of two moving averages.
    print("\n测试 5: 复杂表达式")
    fast_ma = FunctionNode("ts_mean", close_sym, 20)
    slow_ma = FunctionNode("ts_mean", close_sym, 60)
    spread_rank_expr = smoke_translator.translate(
        FunctionNode("cs_rank", fast_ma - slow_ma)
    )
    print(
        f" cs_rank(ts_mean(close, 20) - ts_mean(close, 60)) -> {spread_rank_expr}"
    )

    print("\n✅ 所有测试通过!")
|
||||
Reference in New Issue
Block a user