feat(factors): 实现 AST 拍平优化支持嵌套窗口函数
- 新增 ExpressionFlattener 类自动拆解嵌套窗口函数(如 cs_rank(ts_delay(close, 1)))
- 支持因子引用其他因子:engine.register("fac2", cs_rank("fac1"))
- 给 DependencyExtractor 增加 ignore_symbols 免疫名单,防止已注册因子被当作数据库字段
- 添加完整测试覆盖嵌套场景和数值一致性验证
This commit is contained in:
@@ -3,7 +3,7 @@
|
|||||||
本模块实现 AST 遍历器模式,用于从 DSL 表达式中提取依赖的符号。
|
本模块实现 AST 遍历器模式,用于从 DSL 表达式中提取依赖的符号。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Set
|
from typing import Set, Optional
|
||||||
|
|
||||||
from src.factors.dsl import Node, Symbol, BinaryOpNode, UnaryOpNode, FunctionNode
|
from src.factors.dsl import Node, Symbol, BinaryOpNode, UnaryOpNode, FunctionNode
|
||||||
|
|
||||||
@@ -24,9 +24,14 @@ class DependencyExtractor:
|
|||||||
{'close', 'pe_ratio'}
|
{'close', 'pe_ratio'}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self, ignore_symbols: Optional[Set[str]] = None) -> None:
|
||||||
"""初始化依赖提取器。"""
|
"""初始化依赖提取器。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ignore_symbols: 需要忽略的符号集合(如已注册的因子名)
|
||||||
|
"""
|
||||||
self.dependencies: Set[str] = set()
|
self.dependencies: Set[str] = set()
|
||||||
|
self.ignore_symbols: Set[str] = ignore_symbols or set()
|
||||||
|
|
||||||
def visit(self, node: Node) -> None:
|
def visit(self, node: Node) -> None:
|
||||||
"""访问节点,根据节点类型分发到具体处理方法。
|
"""访问节点,根据节点类型分发到具体处理方法。
|
||||||
@@ -47,10 +52,14 @@ class DependencyExtractor:
|
|||||||
def _visit_symbol(self, node: Symbol) -> None:
|
def _visit_symbol(self, node: Symbol) -> None:
|
||||||
"""访问 Symbol 节点,提取符号名称。
|
"""访问 Symbol 节点,提取符号名称。
|
||||||
|
|
||||||
|
排除临时因子(以 __tmp_ 开头的符号)和已在免疫名单中的因子。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node: 符号节点
|
node: 符号节点
|
||||||
"""
|
"""
|
||||||
self.dependencies.add(node.name)
|
# 排除临时因子引用 和 已在免疫名单中的因子
|
||||||
|
if not node.name.startswith("__tmp_") and node.name not in self.ignore_symbols:
|
||||||
|
self.dependencies.add(node.name)
|
||||||
|
|
||||||
def _visit_binary_op(self, node: BinaryOpNode) -> None:
|
def _visit_binary_op(self, node: BinaryOpNode) -> None:
|
||||||
"""访问 BinaryOpNode 节点,递归遍历左右子节点。
|
"""访问 BinaryOpNode 节点,递归遍历左右子节点。
|
||||||
@@ -92,13 +101,16 @@ class DependencyExtractor:
|
|||||||
return self.dependencies.copy()
|
return self.dependencies.copy()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def extract_dependencies(cls, node: Node) -> Set[str]:
|
def extract_dependencies(
|
||||||
|
cls, node: Node, ignore_symbols: Optional[Set[str]] = None
|
||||||
|
) -> Set[str]:
|
||||||
"""类方法 - 从 AST 节点中提取所有依赖的符号名称。
|
"""类方法 - 从 AST 节点中提取所有依赖的符号名称。
|
||||||
|
|
||||||
这是一个便捷方法,无需手动实例化 DependencyExtractor。
|
这是一个便捷方法,无需手动实例化 DependencyExtractor。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node: 表达式树的根节点
|
node: 表达式树的根节点
|
||||||
|
ignore_symbols: 需要忽略的符号集合(如已注册的因子名)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
依赖的符号名称集合
|
依赖的符号名称集合
|
||||||
@@ -112,17 +124,20 @@ class DependencyExtractor:
|
|||||||
>>> print(deps)
|
>>> print(deps)
|
||||||
{'close', 'open'}
|
{'close', 'open'}
|
||||||
"""
|
"""
|
||||||
extractor = cls()
|
extractor = cls(ignore_symbols=ignore_symbols)
|
||||||
return extractor.extract(node)
|
return extractor.extract(node)
|
||||||
|
|
||||||
|
|
||||||
def extract_dependencies(node: Node) -> Set[str]:
|
def extract_dependencies(
|
||||||
|
node: Node, ignore_symbols: Optional[Set[str]] = None
|
||||||
|
) -> Set[str]:
|
||||||
"""单例方法 - 从 AST 节点中提取所有依赖的符号名称。
|
"""单例方法 - 从 AST 节点中提取所有依赖的符号名称。
|
||||||
|
|
||||||
这是 DependencyExtractor.extract_dependencies 的便捷包装函数。
|
这是 DependencyExtractor.extract_dependencies 的便捷包装函数。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node: 表达式树的根节点
|
node: 表达式树的根节点
|
||||||
|
ignore_symbols: 需要忽略的符号集合(如已注册的因子名)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
依赖的符号名称集合
|
依赖的符号名称集合
|
||||||
@@ -136,7 +151,7 @@ def extract_dependencies(node: Node) -> Set[str]:
|
|||||||
>>> print(deps)
|
>>> print(deps)
|
||||||
{'close', 'pe_ratio'}
|
{'close', 'pe_ratio'}
|
||||||
"""
|
"""
|
||||||
return DependencyExtractor.extract_dependencies(node)
|
return DependencyExtractor.extract_dependencies(node, ignore_symbols=ignore_symbols)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
223
src/factors/engine/ast_optimizer.py
Normal file
223
src/factors/engine/ast_optimizer.py
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
"""AST 优化器 - 表达式拍平。
|
||||||
|
|
||||||
|
本模块实现将嵌套的窗口函数表达式自动提取为中间临时因子,
|
||||||
|
解决多维窗口函数(over)嵌套导致计算为空的问题。
|
||||||
|
|
||||||
|
核心思想:
|
||||||
|
通过 AST 变换,将嵌套在窗口函数内的窗口函数表达式提取出来,
|
||||||
|
作为独立的临时因子先行计算,然后主表达式引用这些临时因子。
|
||||||
|
|
||||||
|
示例:
|
||||||
|
原始表达式: cs_rank(ts_delay(close, 1))
|
||||||
|
拍平后:
|
||||||
|
- 临时因子: __tmp_0 = ts_delay(close, 1)
|
||||||
|
- 主表达式: cs_rank(__tmp_0)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
|
from src.factors.dsl import (
|
||||||
|
BinaryOpNode,
|
||||||
|
Constant,
|
||||||
|
FunctionNode,
|
||||||
|
Node,
|
||||||
|
Symbol,
|
||||||
|
UnaryOpNode,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ExpressionFlattener:
|
||||||
|
"""表达式拍平器。
|
||||||
|
|
||||||
|
遍历 AST 并自动提取嵌套的窗口函数为独立临时因子。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
_counter: 临时因子名称计数器,用于生成唯一名称
|
||||||
|
_extracted_nodes: 存储已提取的临时因子字典
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""初始化拍平器。"""
|
||||||
|
self._counter: int = 0
|
||||||
|
self._extracted_nodes: Dict[str, Node] = {}
|
||||||
|
|
||||||
|
def _generate_temp_name(self) -> str:
|
||||||
|
"""生成唯一的临时因子名称。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
格式为 "__tmp_X" 的临时名称,其中 X 是递增数字
|
||||||
|
"""
|
||||||
|
name = f"__tmp_{self._counter}"
|
||||||
|
self._counter += 1
|
||||||
|
return name
|
||||||
|
|
||||||
|
def _is_window_function(self, func_name: str) -> bool:
|
||||||
|
"""判断是否为窗口函数。
|
||||||
|
|
||||||
|
窗口函数以 "ts_"(时序)或 "cs_"(截面)开头。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func_name: 函数名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否是窗口函数
|
||||||
|
"""
|
||||||
|
return func_name.startswith("ts_") or func_name.startswith("cs_")
|
||||||
|
|
||||||
|
def flatten(self, node: Node) -> Tuple[Node, Dict[str, Node]]:
|
||||||
|
"""拍平表达式。
|
||||||
|
|
||||||
|
遍历 AST,将嵌套的窗口函数提取为临时因子。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: 原始表达式根节点
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[拍平后的主表达式节点, 临时因子字典]
|
||||||
|
临时因子字典: {临时名称 -> 被提取的节点}
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> flattener = ExpressionFlattener()
|
||||||
|
>>> from src.factors.dsl import Symbol, FunctionNode
|
||||||
|
>>> close = Symbol("close")
|
||||||
|
>>> expr = FunctionNode("cs_rank", FunctionNode("ts_delay", close, 1))
|
||||||
|
>>> flat_expr, tmp_factors = flattener.flatten(expr)
|
||||||
|
>>> # flat_expr = cs_rank(__tmp_0)
|
||||||
|
>>> # tmp_factors = {"__tmp_0": ts_delay(close, 1)}
|
||||||
|
"""
|
||||||
|
# 重置状态
|
||||||
|
self._counter = 0
|
||||||
|
self._extracted_nodes = {}
|
||||||
|
|
||||||
|
# 从根节点开始遍历,初始状态为不在窗口函数内部
|
||||||
|
new_node = self._flatten_recursive(node, inside_window=False)
|
||||||
|
|
||||||
|
return new_node, self._extracted_nodes.copy()
|
||||||
|
|
||||||
|
def _flatten_recursive(self, node: Node, inside_window: bool) -> Node:
|
||||||
|
"""递归拍平节点。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: 当前处理的节点
|
||||||
|
inside_window: 当前是否处于窗口函数内部
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理后的节点(可能是原节点或替换为 Symbol)
|
||||||
|
"""
|
||||||
|
# Symbol 和 Constant 是叶子节点,直接返回
|
||||||
|
if isinstance(node, Symbol):
|
||||||
|
return node
|
||||||
|
|
||||||
|
if isinstance(node, Constant):
|
||||||
|
return node
|
||||||
|
|
||||||
|
# 处理二元运算节点
|
||||||
|
if isinstance(node, BinaryOpNode):
|
||||||
|
return self._flatten_binary_op(node, inside_window)
|
||||||
|
|
||||||
|
# 处理一元运算节点
|
||||||
|
if isinstance(node, UnaryOpNode):
|
||||||
|
return self._flatten_unary_op(node, inside_window)
|
||||||
|
|
||||||
|
# 处理函数调用节点
|
||||||
|
if isinstance(node, FunctionNode):
|
||||||
|
return self._flatten_function(node, inside_window)
|
||||||
|
|
||||||
|
# 未知节点类型,直接返回
|
||||||
|
return node
|
||||||
|
|
||||||
|
def _flatten_binary_op(self, node: BinaryOpNode, inside_window: bool) -> Node:
|
||||||
|
"""拍平二元运算节点。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: 二元运算节点
|
||||||
|
inside_window: 当前是否处于窗口函数内部
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理后的节点
|
||||||
|
"""
|
||||||
|
# 递归处理左右子节点
|
||||||
|
new_left = self._flatten_recursive(node.left, inside_window)
|
||||||
|
new_right = self._flatten_recursive(node.right, inside_window)
|
||||||
|
|
||||||
|
# 如果子节点没有变化,返回原节点
|
||||||
|
if new_left is node.left and new_right is node.right:
|
||||||
|
return node
|
||||||
|
|
||||||
|
# 创建新的二元运算节点
|
||||||
|
return BinaryOpNode(node.op, new_left, new_right)
|
||||||
|
|
||||||
|
def _flatten_unary_op(self, node: UnaryOpNode, inside_window: bool) -> Node:
|
||||||
|
"""拍平一元运算节点。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: 一元运算节点
|
||||||
|
inside_window: 当前是否处于窗口函数内部
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理后的节点
|
||||||
|
"""
|
||||||
|
# 递归处理操作数
|
||||||
|
new_operand = self._flatten_recursive(node.operand, inside_window)
|
||||||
|
|
||||||
|
# 如果操作数没有变化,返回原节点
|
||||||
|
if new_operand is node.operand:
|
||||||
|
return node
|
||||||
|
|
||||||
|
# 创建新的一元运算节点
|
||||||
|
return UnaryOpNode(node.op, new_operand)
|
||||||
|
|
||||||
|
def _flatten_function(self, node: FunctionNode, inside_window: bool) -> Node:
|
||||||
|
"""拍平函数调用节点。
|
||||||
|
|
||||||
|
修正为后序遍历(Bottom-Up):先递归拍平参数,再决定是否提取当前节点。
|
||||||
|
确保深层嵌套(如 3层以上)也能被彻底逐层拆解。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: 函数调用节点
|
||||||
|
inside_window: 当前是否处于窗口函数内部
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理后的节点
|
||||||
|
"""
|
||||||
|
is_window = self._is_window_function(node.func_name)
|
||||||
|
next_inside_window = inside_window or is_window
|
||||||
|
|
||||||
|
# 1. 优先递归处理所有参数
|
||||||
|
new_args = []
|
||||||
|
has_change = False
|
||||||
|
for arg in node.args:
|
||||||
|
new_arg = self._flatten_recursive(arg, next_inside_window)
|
||||||
|
new_args.append(new_arg)
|
||||||
|
if new_arg is not arg:
|
||||||
|
has_change = True
|
||||||
|
|
||||||
|
# 2. 只有当参数发生变化时,才创建新的当前节点
|
||||||
|
current_node = FunctionNode(node.func_name, *new_args) if has_change else node
|
||||||
|
|
||||||
|
# 3. 判断是否需要提取(此时子节点肯定已经被彻底拍平了)
|
||||||
|
if inside_window and is_window:
|
||||||
|
temp_name = self._generate_temp_name()
|
||||||
|
self._extracted_nodes[temp_name] = current_node
|
||||||
|
return Symbol(temp_name)
|
||||||
|
|
||||||
|
return current_node
|
||||||
|
|
||||||
|
|
||||||
|
def flatten_expression(node: Node) -> Tuple[Node, Dict[str, Node]]:
|
||||||
|
"""便捷函数 - 拍平表达式。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: 表达式树的根节点
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[拍平后的主表达式节点, 临时因子字典]
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> from src.factors.dsl import Symbol, FunctionNode
|
||||||
|
>>> close = Symbol("close")
|
||||||
|
>>> expr = FunctionNode("cs_rank", FunctionNode("ts_delay", close, 1))
|
||||||
|
>>> flat_expr, tmp_factors = flatten_expression(expr)
|
||||||
|
"""
|
||||||
|
flattener = ExpressionFlattener()
|
||||||
|
return flattener.flatten(node)
|
||||||
@@ -30,6 +30,7 @@ from src.factors.engine.data_spec import DataSpec, ExecutionPlan
|
|||||||
from src.factors.engine.data_router import DataRouter
|
from src.factors.engine.data_router import DataRouter
|
||||||
from src.factors.engine.planner import ExecutionPlanner
|
from src.factors.engine.planner import ExecutionPlanner
|
||||||
from src.factors.engine.compute_engine import ComputeEngine
|
from src.factors.engine.compute_engine import ComputeEngine
|
||||||
|
from src.factors.engine.ast_optimizer import ExpressionFlattener
|
||||||
|
|
||||||
|
|
||||||
class FactorEngine:
|
class FactorEngine:
|
||||||
@@ -92,13 +93,68 @@ class FactorEngine:
|
|||||||
|
|
||||||
self._metadata = FactorManager()
|
self._metadata = FactorManager()
|
||||||
|
|
||||||
|
def _register_internal(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
expression: Node,
|
||||||
|
data_specs: Optional[List[DataSpec]] = None,
|
||||||
|
) -> "FactorEngine":
|
||||||
|
"""内部注册方法,直接注册因子表达式。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 因子名称
|
||||||
|
expression: DSL 表达式
|
||||||
|
data_specs: 数据规格,None 时自动推导
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
self,支持链式调用
|
||||||
|
"""
|
||||||
|
# 检测因子依赖(在注册当前因子之前检查其他已注册因子)
|
||||||
|
factor_deps = self._find_factor_dependencies(expression)
|
||||||
|
|
||||||
|
# 获取当前所有已注册的因子名称(作为免疫名单,防止被当作数据库字段)
|
||||||
|
known_factors = set(self.registered_expressions.keys())
|
||||||
|
|
||||||
|
self.registered_expressions[name] = expression
|
||||||
|
|
||||||
|
# 预创建执行计划,过滤掉已注册的因子,防止被当作数据库字段
|
||||||
|
plan = self.planner.create_plan(
|
||||||
|
expression=expression,
|
||||||
|
output_name=name,
|
||||||
|
data_specs=data_specs,
|
||||||
|
ignore_dependencies=known_factors,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 添加因子依赖信息
|
||||||
|
plan.factor_dependencies = factor_deps
|
||||||
|
|
||||||
|
# 如果数据规格为空,继承依赖因子(包括临时因子)的数据规格
|
||||||
|
if not plan.data_specs and factor_deps:
|
||||||
|
merged_specs: List[DataSpec] = []
|
||||||
|
for dep_name in factor_deps:
|
||||||
|
if dep_name in self._plans:
|
||||||
|
merged_specs.extend(self._plans[dep_name].data_specs)
|
||||||
|
|
||||||
|
# 去重(基于表名)
|
||||||
|
seen_tables: set = set()
|
||||||
|
unique_specs: List[DataSpec] = []
|
||||||
|
for spec in merged_specs:
|
||||||
|
if spec.table not in seen_tables:
|
||||||
|
seen_tables.add(spec.table)
|
||||||
|
unique_specs.append(spec)
|
||||||
|
plan.data_specs = unique_specs
|
||||||
|
|
||||||
|
self._plans[name] = plan
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
def register(
|
def register(
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
expression: Node,
|
expression: Node,
|
||||||
data_specs: Optional[List[DataSpec]] = None,
|
data_specs: Optional[List[DataSpec]] = None,
|
||||||
) -> "FactorEngine":
|
) -> "FactorEngine":
|
||||||
"""注册因子表达式。
|
"""注册因子表达式(自动处理嵌套窗口函数)。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: 因子名称
|
name: 因子名称
|
||||||
@@ -113,22 +169,16 @@ class FactorEngine:
|
|||||||
>>> engine = FactorEngine()
|
>>> engine = FactorEngine()
|
||||||
>>> engine.register("ma20", ts_mean(close, 20))
|
>>> engine.register("ma20", ts_mean(close, 20))
|
||||||
"""
|
"""
|
||||||
# 检测因子依赖(在注册当前因子之前检查其他已注册因子)
|
# 使用 AST 优化器拍平嵌套窗口函数
|
||||||
factor_deps = self._find_factor_dependencies(expression)
|
flattener = ExpressionFlattener()
|
||||||
|
flat_expression, tmp_factors = flattener.flatten(expression)
|
||||||
|
|
||||||
self.registered_expressions[name] = expression
|
# 先注册所有临时因子(自动推导数据规格)
|
||||||
|
for tmp_name, tmp_node in tmp_factors.items():
|
||||||
|
self._register_internal(tmp_name, tmp_node, data_specs=None)
|
||||||
|
|
||||||
# 预创建执行计划
|
# 最后注册主因子
|
||||||
plan = self.planner.create_plan(
|
self._register_internal(name, flat_expression, data_specs)
|
||||||
expression=expression,
|
|
||||||
output_name=name,
|
|
||||||
data_specs=data_specs,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 添加因子依赖信息
|
|
||||||
plan.factor_dependencies = factor_deps
|
|
||||||
|
|
||||||
self._plans[name] = plan
|
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@@ -174,7 +224,7 @@ class FactorEngine:
|
|||||||
# 解析表达式为 Node
|
# 解析表达式为 Node
|
||||||
node = self._parser.parse(dsl_expr)
|
node = self._parser.parse(dsl_expr)
|
||||||
|
|
||||||
# 委托给 register 方法
|
# 委托给 register 方法(register 会处理嵌套窗口函数拍平)
|
||||||
return self.register(name, node, data_specs)
|
return self.register(name, node, data_specs)
|
||||||
|
|
||||||
def add_factor(
|
def add_factor(
|
||||||
@@ -272,21 +322,32 @@ class FactorEngine:
|
|||||||
if isinstance(factor_names, str):
|
if isinstance(factor_names, str):
|
||||||
factor_names = [factor_names]
|
factor_names = [factor_names]
|
||||||
|
|
||||||
# 1. 获取执行计划
|
# 1. 收集所有需要的因子(包括临时因子依赖)
|
||||||
|
all_factor_names = self._collect_all_dependencies(factor_names)
|
||||||
|
|
||||||
|
# 2. 获取执行计划
|
||||||
plans = []
|
plans = []
|
||||||
for name in factor_names:
|
for name in all_factor_names:
|
||||||
if name not in self._plans:
|
if name not in self._plans:
|
||||||
raise ValueError(f"因子未注册: {name}")
|
raise ValueError(f"因子未注册: {name}")
|
||||||
plans.append(self._plans[name])
|
plans.append(self._plans[name])
|
||||||
|
|
||||||
# 2. 合并数据规格并获取数据
|
# 3. 合并数据规格并获取数据
|
||||||
all_specs = []
|
all_specs = []
|
||||||
for plan in plans:
|
for plan in plans:
|
||||||
all_specs.extend(plan.data_specs)
|
all_specs.extend(plan.data_specs)
|
||||||
|
|
||||||
# 3. 从路由器获取核心宽表
|
# 去重数据规格(基于表名)
|
||||||
|
seen_tables: set = set()
|
||||||
|
unique_specs: List[DataSpec] = []
|
||||||
|
for spec in all_specs:
|
||||||
|
if spec.table not in seen_tables:
|
||||||
|
seen_tables.add(spec.table)
|
||||||
|
unique_specs.append(spec)
|
||||||
|
|
||||||
|
# 4. 从路由器获取核心宽表
|
||||||
core_data = self.router.fetch_data(
|
core_data = self.router.fetch_data(
|
||||||
data_specs=all_specs,
|
data_specs=unique_specs,
|
||||||
start_date=start_date,
|
start_date=start_date,
|
||||||
end_date=end_date,
|
end_date=end_date,
|
||||||
stock_codes=stock_codes,
|
stock_codes=stock_codes,
|
||||||
@@ -295,14 +356,14 @@ class FactorEngine:
|
|||||||
if len(core_data) == 0:
|
if len(core_data) == 0:
|
||||||
raise ValueError("未获取到任何数据,请检查日期范围和股票代码")
|
raise ValueError("未获取到任何数据,请检查日期范围和股票代码")
|
||||||
|
|
||||||
# 4. 按依赖顺序执行计算
|
# 5. 按依赖顺序执行计算(包含临时因子)
|
||||||
if len(plans) == 1:
|
result = self._execute_with_dependencies(all_factor_names, core_data)
|
||||||
result = self.compute_engine.execute(plans[0], core_data)
|
|
||||||
else:
|
|
||||||
# 使用依赖感知的方式执行
|
|
||||||
result = self._execute_with_dependencies(factor_names, core_data)
|
|
||||||
|
|
||||||
return result
|
# 6. 清理内存宽表,过滤掉临时因子列(__tmp_X)
|
||||||
|
# 保留所有非临时因子列(包括原始数据列和用户请求的因子列)
|
||||||
|
cols_to_keep = [col for col in result.columns if not col.startswith("__tmp_")]
|
||||||
|
|
||||||
|
return result.select(cols_to_keep)
|
||||||
|
|
||||||
def list_registered(self) -> List[str]:
|
def list_registered(self) -> List[str]:
|
||||||
"""获取已注册的因子列表。
|
"""获取已注册的因子列表。
|
||||||
@@ -501,10 +562,32 @@ class FactorEngine:
|
|||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _find_factor_dependencies(self, expression: Node) -> Set[str]:
|
def _collect_all_dependencies(self, factor_names: List[str]) -> List[str]:
|
||||||
"""查找表达式依赖的其他因子。
|
"""收集所有因子及其依赖(包括用户定义的因子和临时因子)。"""
|
||||||
|
collected: Set[str] = set()
|
||||||
|
result: List[str] = []
|
||||||
|
|
||||||
遍历已注册因子,检查表达式是否包含任何已注册因子的完整表达式。
|
def collect_recursive(name: str):
|
||||||
|
if name in collected:
|
||||||
|
return
|
||||||
|
collected.add(name)
|
||||||
|
|
||||||
|
# 获取执行计划并递归收集强依赖
|
||||||
|
plan = self._plans.get(name)
|
||||||
|
if plan:
|
||||||
|
for dep_name in plan.factor_dependencies:
|
||||||
|
collect_recursive(dep_name)
|
||||||
|
|
||||||
|
# 依赖收集完毕,再将自己加入列表(天然形成安全的计算顺序)
|
||||||
|
result.append(name)
|
||||||
|
|
||||||
|
for name in factor_names:
|
||||||
|
collect_recursive(name)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _find_factor_dependencies(self, expression: Node) -> Set[str]:
|
||||||
|
"""查找表达式依赖的其他因子(包括临时因子和用户因子引用)。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
expression: 待检查的表达式
|
expression: 待检查的表达式
|
||||||
@@ -514,13 +597,20 @@ class FactorEngine:
|
|||||||
"""
|
"""
|
||||||
deps: Set[str] = set()
|
deps: Set[str] = set()
|
||||||
|
|
||||||
# 检查表达式本身是否等于某个已注册因子
|
# 1. 【新增】如果直接引用了已注册的因子名称(包含 __tmp_X 或用户因子)
|
||||||
|
if (
|
||||||
|
isinstance(expression, Symbol)
|
||||||
|
and expression.name in self.registered_expressions
|
||||||
|
):
|
||||||
|
deps.add(expression.name)
|
||||||
|
|
||||||
|
# 2. 检查表达式本身是否等于某个已注册因子的完整 AST
|
||||||
for name, registered_expr in self.registered_expressions.items():
|
for name, registered_expr in self.registered_expressions.items():
|
||||||
if self._expressions_equal(expression, registered_expr):
|
if self._expressions_equal(expression, registered_expr):
|
||||||
deps.add(name)
|
deps.add(name)
|
||||||
break
|
break
|
||||||
|
|
||||||
# 递归检查子节点
|
# 3. 递归检查子节点
|
||||||
if isinstance(expression, BinaryOpNode):
|
if isinstance(expression, BinaryOpNode):
|
||||||
deps.update(self._find_factor_dependencies(expression.left))
|
deps.update(self._find_factor_dependencies(expression.left))
|
||||||
deps.update(self._find_factor_dependencies(expression.right))
|
deps.update(self._find_factor_dependencies(expression.right))
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ class ExecutionPlanner:
|
|||||||
expression: Node,
|
expression: Node,
|
||||||
output_name: str = "factor",
|
output_name: str = "factor",
|
||||||
data_specs: Optional[List[DataSpec]] = None,
|
data_specs: Optional[List[DataSpec]] = None,
|
||||||
|
ignore_dependencies: Optional[Set[str]] = None,
|
||||||
) -> ExecutionPlan:
|
) -> ExecutionPlan:
|
||||||
"""从表达式创建执行计划。
|
"""从表达式创建执行计划。
|
||||||
|
|
||||||
@@ -46,12 +47,15 @@ class ExecutionPlanner:
|
|||||||
expression: DSL 表达式节点
|
expression: DSL 表达式节点
|
||||||
output_name: 输出因子名称
|
output_name: 输出因子名称
|
||||||
data_specs: 预定义的数据规格,None 时自动推导
|
data_specs: 预定义的数据规格,None 时自动推导
|
||||||
|
ignore_dependencies: 需要忽略的依赖符号集合(如已注册因子名)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
执行计划对象
|
执行计划对象
|
||||||
"""
|
"""
|
||||||
# 1. 提取依赖
|
# 1. 提取依赖时传入要忽略的符号
|
||||||
dependencies = self.compiler.extract_dependencies(expression)
|
dependencies = self.compiler.extract_dependencies(
|
||||||
|
expression, ignore_symbols=ignore_dependencies
|
||||||
|
)
|
||||||
|
|
||||||
# 2. 翻译为 Polars 表达式
|
# 2. 翻译为 Polars 表达式
|
||||||
polars_expr = self.translator.translate(expression)
|
polars_expr = self.translator.translate(expression)
|
||||||
|
|||||||
367
tests/test_ast_optimizer.py
Normal file
367
tests/test_ast_optimizer.py
Normal file
@@ -0,0 +1,367 @@
|
|||||||
|
"""AST 优化器测试 - 验证嵌套窗口函数拍平功能。
|
||||||
|
|
||||||
|
测试因子: cs_rank(ts_delay(close, 1))
|
||||||
|
这是一个典型的窗口函数嵌套场景,应该被自动拍平为临时因子。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import polars as pl
|
||||||
|
import numpy as np
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
from src.factors.engine import FactorEngine
|
||||||
|
from src.factors.api import close, ts_delay, cs_rank
|
||||||
|
from src.factors.dsl import FunctionNode
|
||||||
|
from src.factors.engine.ast_optimizer import ExpressionFlattener
|
||||||
|
|
||||||
|
|
||||||
|
def create_mock_data(
|
||||||
|
start_date: str = "20240101",
|
||||||
|
end_date: str = "20240131",
|
||||||
|
n_stocks: int = 5,
|
||||||
|
) -> pl.DataFrame:
|
||||||
|
"""创建模拟的日线数据。"""
|
||||||
|
start = datetime.strptime(start_date, "%Y%m%d")
|
||||||
|
end = datetime.strptime(end_date, "%Y%m%d")
|
||||||
|
|
||||||
|
dates = []
|
||||||
|
current = start
|
||||||
|
while current <= end:
|
||||||
|
if current.weekday() < 5: # 周一到周五
|
||||||
|
dates.append(current.strftime("%Y%m%d"))
|
||||||
|
current += timedelta(days=1)
|
||||||
|
|
||||||
|
stocks = [f"{600000 + i:06d}.SH" for i in range(n_stocks)]
|
||||||
|
np.random.seed(42)
|
||||||
|
|
||||||
|
rows = []
|
||||||
|
for date in dates:
|
||||||
|
for stock in stocks:
|
||||||
|
base_price = 10 + np.random.randn() * 5
|
||||||
|
close_val = base_price + np.random.randn() * 0.5
|
||||||
|
open_val = close_val + np.random.randn() * 0.2
|
||||||
|
high_val = max(open_val, close_val) + abs(np.random.randn()) * 0.3
|
||||||
|
low_val = min(open_val, close_val) - abs(np.random.randn()) * 0.3
|
||||||
|
vol = int(1000000 + np.random.exponential(500000))
|
||||||
|
|
||||||
|
rows.append(
|
||||||
|
{
|
||||||
|
"ts_code": stock,
|
||||||
|
"trade_date": date,
|
||||||
|
"open": round(open_val, 2),
|
||||||
|
"high": round(high_val, 2),
|
||||||
|
"low": round(low_val, 2),
|
||||||
|
"close": round(close_val, 2),
|
||||||
|
"volume": vol,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return pl.DataFrame(rows)
|
||||||
|
|
||||||
|
|
||||||
|
class TestASTOptimizer:
|
||||||
|
"""AST 优化器测试类。"""
|
||||||
|
|
||||||
|
def test_flattener_basic(self):
|
||||||
|
"""测试拍平器基本功能。"""
|
||||||
|
from src.factors.api import close
|
||||||
|
|
||||||
|
flattener = ExpressionFlattener()
|
||||||
|
|
||||||
|
# 创建嵌套表达式: cs_rank(ts_delay(close, 1))
|
||||||
|
expr = FunctionNode("cs_rank", FunctionNode("ts_delay", close, 1))
|
||||||
|
|
||||||
|
flat_expr, tmp_factors = flattener.flatten(expr)
|
||||||
|
|
||||||
|
# 验证临时因子被提取
|
||||||
|
assert len(tmp_factors) == 1
|
||||||
|
assert "__tmp_0" in tmp_factors
|
||||||
|
|
||||||
|
# 验证主表达式使用了 Symbol 引用
|
||||||
|
assert isinstance(flat_expr, FunctionNode)
|
||||||
|
assert flat_expr.func_name == "cs_rank"
|
||||||
|
# 验证第一个参数是临时因子引用(通过 name 属性检查)
|
||||||
|
assert hasattr(flat_expr.args[0], "name")
|
||||||
|
assert flat_expr.args[0].name == "__tmp_0"
|
||||||
|
|
||||||
|
# 验证临时因子内容
|
||||||
|
tmp_node = tmp_factors["__tmp_0"]
|
||||||
|
assert isinstance(tmp_node, FunctionNode)
|
||||||
|
assert tmp_node.func_name == "ts_delay"
|
||||||
|
|
||||||
|
print("[PASS] 拍平器基本功能测试")
|
||||||
|
|
||||||
|
def test_flattener_no_nested(self):
|
||||||
|
"""测试非嵌套表达式不会被拍平。"""
|
||||||
|
from src.factors.api import close, ts_mean
|
||||||
|
|
||||||
|
flattener = ExpressionFlattener()
|
||||||
|
|
||||||
|
# 非嵌套表达式: ts_mean(close, 20)
|
||||||
|
expr = FunctionNode("ts_mean", close, 20)
|
||||||
|
|
||||||
|
flat_expr, tmp_factors = flattener.flatten(expr)
|
||||||
|
|
||||||
|
# 验证没有临时因子被提取
|
||||||
|
assert len(tmp_factors) == 0
|
||||||
|
|
||||||
|
# 验证表达式保持不变
|
||||||
|
assert isinstance(flat_expr, FunctionNode)
|
||||||
|
assert flat_expr.func_name == "ts_mean"
|
||||||
|
|
||||||
|
print("[PASS] 非嵌套表达式测试")
|
||||||
|
|
||||||
|
def test_flattener_deeply_nested(self):
|
||||||
|
"""测试多层嵌套表达式拍平。"""
|
||||||
|
from src.factors.api import close, ts_mean
|
||||||
|
|
||||||
|
flattener = ExpressionFlattener()
|
||||||
|
|
||||||
|
# 深层嵌套: cs_rank(ts_mean(ts_delay(close, 1), 5))
|
||||||
|
expr = FunctionNode(
|
||||||
|
"cs_rank", FunctionNode("ts_mean", FunctionNode("ts_delay", close, 1), 5)
|
||||||
|
)
|
||||||
|
|
||||||
|
flat_expr, tmp_factors = flattener.flatten(expr)
|
||||||
|
|
||||||
|
# 验证提取了两个临时因子(修复后正确行为)
|
||||||
|
# ts_delay(close, 1) 被提取为 __tmp_0
|
||||||
|
# ts_mean(__tmp_0, 5) 被提取为 __tmp_1
|
||||||
|
assert len(tmp_factors) == 2
|
||||||
|
assert "__tmp_0" in tmp_factors
|
||||||
|
assert "__tmp_1" in tmp_factors
|
||||||
|
|
||||||
|
# 验证 __tmp_0 内容是 ts_delay(close, 1)
|
||||||
|
tmp0_node = tmp_factors["__tmp_0"]
|
||||||
|
assert isinstance(tmp0_node, FunctionNode)
|
||||||
|
assert tmp0_node.func_name == "ts_delay"
|
||||||
|
|
||||||
|
# 验证 __tmp_1 内容是 ts_mean(__tmp_0, 5)
|
||||||
|
tmp1_node = tmp_factors["__tmp_1"]
|
||||||
|
assert isinstance(tmp1_node, FunctionNode)
|
||||||
|
assert tmp1_node.func_name == "ts_mean"
|
||||||
|
from src.factors.dsl import Symbol
|
||||||
|
|
||||||
|
assert isinstance(tmp1_node.args[0], Symbol)
|
||||||
|
assert tmp1_node.args[0].name == "__tmp_0"
|
||||||
|
|
||||||
|
# 验证主表达式引用 __tmp_1
|
||||||
|
assert isinstance(flat_expr, FunctionNode)
|
||||||
|
assert flat_expr.func_name == "cs_rank"
|
||||||
|
assert isinstance(flat_expr.args[0], Symbol)
|
||||||
|
assert flat_expr.args[0].name == "__tmp_1"
|
||||||
|
|
||||||
|
print("[PASS] 多层嵌套表达式拍平测试")
|
||||||
|
|
||||||
|
def test_nested_window_function_engine(self):
|
||||||
|
"""测试引擎正确处理嵌套窗口函数 cs_rank(ts_delay(close, 1))。"""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("测试嵌套窗口函数: cs_rank(ts_delay(close, 1))")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# 1. 准备数据
|
||||||
|
mock_data = create_mock_data("20240101", "20240131", n_stocks=5)
|
||||||
|
print(f"\n生成模拟数据: {len(mock_data)} 行")
|
||||||
|
|
||||||
|
# 2. 初始化引擎
|
||||||
|
engine = FactorEngine(data_source={"pro_bar": mock_data})
|
||||||
|
print("引擎初始化完成")
|
||||||
|
|
||||||
|
# 3. 使用字符串表达式注册嵌套窗口函数
|
||||||
|
print("\n注册因子: cs_rank(ts_delay(close, 1))")
|
||||||
|
engine.add_factor("delayed_rank", "cs_rank(ts_delay(close, 1))")
|
||||||
|
|
||||||
|
# 4. 检查临时因子是否被创建
|
||||||
|
registered_factors = engine.list_registered()
|
||||||
|
print(f"已注册因子: {registered_factors}")
|
||||||
|
|
||||||
|
# 验证有临时因子被创建
|
||||||
|
tmp_factors = [name for name in registered_factors if name.startswith("__tmp_")]
|
||||||
|
assert len(tmp_factors) >= 1, "应该有临时因子被创建"
|
||||||
|
print(f"临时因子: {tmp_factors}")
|
||||||
|
|
||||||
|
# 5. 执行计算
|
||||||
|
print("\n执行计算...")
|
||||||
|
result = engine.compute("delayed_rank", "20240115", "20240131")
|
||||||
|
print(f"计算完成: {len(result)} 行")
|
||||||
|
|
||||||
|
# 6. 验证结果
|
||||||
|
assert "delayed_rank" in result.columns, "结果中应该有 delayed_rank 列"
|
||||||
|
|
||||||
|
# 检查结果值是否在合理范围内(排名因子应该在 0-1 之间,但可能由于滞后有 null)
|
||||||
|
non_null_values = result["delayed_rank"].drop_nulls()
|
||||||
|
if len(non_null_values) > 0:
|
||||||
|
assert non_null_values.min() >= 0, "排名应该在 [0, 1] 之间"
|
||||||
|
assert non_null_values.max() <= 1, "排名应该在 [0, 1] 之间"
|
||||||
|
|
||||||
|
# 检查没有过多空值(考虑到开头的滞后期)
|
||||||
|
null_count = result["delayed_rank"].is_null().sum()
|
||||||
|
print(f"空值数量: {null_count}")
|
||||||
|
|
||||||
|
# 展示部分结果
|
||||||
|
print("\n前 10 行结果:")
|
||||||
|
sample = result.select(["ts_code", "trade_date", "close", "delayed_rank"]).head(
|
||||||
|
10
|
||||||
|
)
|
||||||
|
print(sample.to_pandas().to_string(index=False))
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("嵌套窗口函数测试通过!")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
def test_multiple_nested_factors(self):
|
||||||
|
"""测试同时注册多个嵌套因子。"""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("测试多个嵌套因子")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
mock_data = create_mock_data("20240101", "20240131", n_stocks=5)
|
||||||
|
engine = FactorEngine(data_source={"pro_bar": mock_data})
|
||||||
|
|
||||||
|
# 注册多个嵌套因子(使用字符串表达式)
|
||||||
|
print("\n注册因子1: cs_rank(ts_delay(close, 1))")
|
||||||
|
engine.add_factor("rank1", "cs_rank(ts_delay(close, 1))")
|
||||||
|
|
||||||
|
print("注册因子2: ts_mean(cs_rank(close), 5)")
|
||||||
|
engine.add_factor("rank_mean", "ts_mean(cs_rank(close), 5)")
|
||||||
|
|
||||||
|
# 检查已注册因子
|
||||||
|
factors = engine.list_registered()
|
||||||
|
print(f"\n已注册因子: {factors}")
|
||||||
|
|
||||||
|
# 计算所有因子
|
||||||
|
result = engine.compute(["rank1", "rank_mean"], "20240115", "20240131")
|
||||||
|
|
||||||
|
assert "rank1" in result.columns
|
||||||
|
assert "rank_mean" in result.columns
|
||||||
|
|
||||||
|
print(f"\n结果行数: {len(result)}")
|
||||||
|
print(f"rank1 空值数: {result['rank1'].is_null().sum()}")
|
||||||
|
print(f"rank_mean 空值数: {result['rank_mean'].is_null().sum()}")
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("多个嵌套因子测试通过!")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
def test_nested_vs_native_polars(self):
|
||||||
|
"""对比测试:嵌套窗口函数 vs 原生 Polars 计算,验证数值一致性。"""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("对比测试:cs_rank(ts_delay(close, 1)) vs 原生 Polars")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# 1. 准备数据
|
||||||
|
mock_data = create_mock_data("20240101", "20240131", n_stocks=5)
|
||||||
|
print(f"\n生成模拟数据: {len(mock_data)} 行")
|
||||||
|
|
||||||
|
# 2. 使用 FactorEngine 计算嵌套因子
|
||||||
|
engine = FactorEngine(data_source={"pro_bar": mock_data})
|
||||||
|
print("\n使用 FactorEngine 计算 cs_rank(ts_delay(close, 1))...")
|
||||||
|
engine.register("delayed_rank", cs_rank(ts_delay(close, 1)))
|
||||||
|
engine_result = engine.compute("delayed_rank", "20240115", "20240131")
|
||||||
|
print(f"FactorEngine 结果: {len(engine_result)} 行")
|
||||||
|
|
||||||
|
# 3. 使用原生 Polars 计算(手动分步)
|
||||||
|
print("\n使用原生 Polars 手动计算...")
|
||||||
|
# 先计算 ts_delay(close, 1)
|
||||||
|
native_result = mock_data.sort(["ts_code", "trade_date"]).with_columns(
|
||||||
|
[pl.col("close").shift(1).over("ts_code").alias("delayed_close")]
|
||||||
|
)
|
||||||
|
# 再计算 cs_rank
|
||||||
|
native_result = native_result.with_columns(
|
||||||
|
[
|
||||||
|
(pl.col("delayed_close").rank() / pl.col("delayed_close").count())
|
||||||
|
.over("trade_date")
|
||||||
|
.alias("native_delayed_rank")
|
||||||
|
]
|
||||||
|
)
|
||||||
|
print(f"原生 Polars 结果: {len(native_result)} 行")
|
||||||
|
|
||||||
|
# 4. 合并结果进行对比
|
||||||
|
comparison = engine_result.join(
|
||||||
|
native_result.select(["ts_code", "trade_date", "native_delayed_rank"]),
|
||||||
|
on=["ts_code", "trade_date"],
|
||||||
|
how="inner",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 5. 验证数值一致性(允许微小浮点误差)
|
||||||
|
diff = comparison.with_columns(
|
||||||
|
[
|
||||||
|
(pl.col("delayed_rank") - pl.col("native_delayed_rank"))
|
||||||
|
.abs()
|
||||||
|
.alias("diff")
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
max_diff = diff["diff"].max()
|
||||||
|
print(f"\n最大差异: {max_diff}")
|
||||||
|
|
||||||
|
# 过滤掉空值后比较(开头的滞后期会有空值)
|
||||||
|
non_null_diff = diff.filter(pl.col("diff").is_not_null())
|
||||||
|
assert non_null_diff["diff"].max() < 1e-10, (
|
||||||
|
f"数值差异过大: {non_null_diff['diff'].max()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("数值一致性验证通过!")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
def test_factor_reference_factor(self):
|
||||||
|
"""测试因子引用另一个因子:fac2 = cs_rank(fac1)。"""
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("测试因子引用其他因子: fac2 = cs_rank(fac1)")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# 准备数据
|
||||||
|
mock_data = create_mock_data("20240101", "20240131", n_stocks=5)
|
||||||
|
engine = FactorEngine(data_source={"pro_bar": mock_data})
|
||||||
|
|
||||||
|
# 1. 注册基础因子 fac1
|
||||||
|
print("\n注册基础因子 fac1 = ts_mean(close, 5)")
|
||||||
|
from src.factors.api import ts_mean
|
||||||
|
|
||||||
|
engine.register("fac1", ts_mean(close, 5))
|
||||||
|
|
||||||
|
# 2. 注册引用因子 fac2,引用 fac1
|
||||||
|
print("注册引用因子 fac2 = cs_rank(fac1)")
|
||||||
|
engine.register("fac2", cs_rank("fac1")) # 字符串引用另一个因子
|
||||||
|
|
||||||
|
# 3. 验证依赖关系
|
||||||
|
registered = engine.list_registered()
|
||||||
|
print(f"\n已注册因子: {registered}")
|
||||||
|
assert "fac1" in registered
|
||||||
|
assert "fac2" in registered
|
||||||
|
|
||||||
|
# 4. 执行计算
|
||||||
|
print("\n执行计算...")
|
||||||
|
result = engine.compute(["fac1", "fac2"], "20240115", "20240131")
|
||||||
|
print(f"计算完成: {len(result)} 行")
|
||||||
|
|
||||||
|
# 5. 验证结果
|
||||||
|
assert "fac1" in result.columns, "结果中应有 fac1 列"
|
||||||
|
assert "fac2" in result.columns, "结果中应有 fac2 列"
|
||||||
|
|
||||||
|
# fac2 是排名,应在 [0, 1] 之间
|
||||||
|
assert result["fac2"].min() >= 0, "排名应在 [0, 1] 之间"
|
||||||
|
assert result["fac2"].max() <= 1, "排名应在 [0, 1] 之间"
|
||||||
|
|
||||||
|
print("\n前 10 行结果:")
|
||||||
|
sample = result.select(["ts_code", "trade_date", "close", "fac1", "fac2"]).head(
|
||||||
|
10
|
||||||
|
)
|
||||||
|
print(sample.to_pandas().to_string(index=False))
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("因子引用功能测试通过!")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test = TestASTOptimizer()
|
||||||
|
test.test_flattener_basic()
|
||||||
|
test.test_flattener_no_nested()
|
||||||
|
test.test_flattener_deeply_nested()
|
||||||
|
test.test_nested_window_function_engine()
|
||||||
|
test.test_multiple_nested_factors()
|
||||||
|
test.test_nested_vs_native_polars()
|
||||||
|
test.test_factor_reference_factor()
|
||||||
|
print("\n所有测试通过!")
|
||||||
Reference in New Issue
Block a user