Compare commits

...

3 Commits

Author SHA1 Message Date
181994f063 perf(factors/engine): 重构计算引擎使用 Polars 原生并行
- 移除 Python 多进程/多线程池,消除 DataFrame 序列化开销
- 采用 BFS 分层执行策略,每层表达式通过单次 with_columns 提交
- 利用 Polars Rust 引擎实现零拷贝并行计算
- 添加死锁检测机制处理依赖环
2026-03-14 01:24:52 +08:00
2034d60fbb fix(factors): 修复 AST 优化器并发命名冲突及逻辑运算支持
- 修复 ExpressionFlattener 跨实例临时名称冲突
- 添加 & 和 | 逻辑运算符的 DSL/Parser/Translator 支持
- 增加回归测试验证修复
2026-03-14 01:17:14 +08:00
c8808d07eb feat(factors): 实现 AST 拍平优化支持嵌套窗口函数
- 新增 ExpressionFlattener 类自动拆解嵌套窗口函数(如 cs_rank(ts_delay(close, 1)))
- 支持因子引用其他因子:engine.register("fac2", cs_rank("fac1"))
- 给 DependencyExtractor 增加 ignore_symbols 免疫名单,防止已注册因子被当作数据库字段
- 添加完整测试覆盖嵌套场景和数值一致性验证
2026-03-14 01:06:17 +08:00
12 changed files with 1031 additions and 121 deletions

View File

@@ -937,6 +937,37 @@ LSP 报错Syntax error on line 45
python tests/test_sync.py
```
### 因子编写规范
**⚠️ 强制要求:编写因子时,优先使用字符串表达式而非 DSL 表达式。**
1. **推荐方式(字符串表达式)**
```python
from src.factors import FactorEngine
engine = FactorEngine()
engine.add_factor("ma20", "ts_mean(close, 20)")
engine.add_factor("alpha", "cs_rank(ts_mean(close, 5) - ts_mean(close, 20))")
```
2. **不推荐方式DSL 表达式)**
```python
from src.factors.api import close, ts_mean, cs_rank
engine.register("ma20", ts_mean(close, 20)) # 不推荐
```
3. **原因说明**
- 字符串表达式更易于序列化存储到因子元数据(`factors.jsonl`
- 字符串表达式支持从元数据动态加载和复用
- 字符串表达式便于在配置文件中定义和维护
- 与 `src/scripts/register_factors.py` 批量注册脚本兼容
4. **使用场景**
- ✅ 在 `register_factors.py` 的 `FACTORS` 列表中定义因子
- ✅ 动态添加因子到 FactorEngine
- ✅ 从因子元数据查询并注册因子
### Emoji 表情禁用规则
**⚠️ 强制要求:代码和测试文件中禁止出现 emoji 表情。**

View File

@@ -3,7 +3,7 @@
本模块实现 AST 遍历器模式,用于从 DSL 表达式中提取依赖的符号。
"""
from typing import Set
from typing import Set, Optional
from src.factors.dsl import Node, Symbol, BinaryOpNode, UnaryOpNode, FunctionNode
@@ -24,9 +24,14 @@ class DependencyExtractor:
{'close', 'pe_ratio'}
"""
def __init__(self) -> None:
"""初始化依赖提取器。"""
def __init__(self, ignore_symbols: Optional[Set[str]] = None) -> None:
"""初始化依赖提取器。
Args:
ignore_symbols: 需要忽略的符号集合(如已注册的因子名)
"""
self.dependencies: Set[str] = set()
self.ignore_symbols: Set[str] = ignore_symbols or set()
def visit(self, node: Node) -> None:
"""访问节点,根据节点类型分发到具体处理方法。
@@ -47,10 +52,14 @@ class DependencyExtractor:
def _visit_symbol(self, node: Symbol) -> None:
"""访问 Symbol 节点,提取符号名称。
排除临时因子(以 __tmp_ 开头的符号)和已在免疫名单中的因子。
Args:
node: 符号节点
"""
self.dependencies.add(node.name)
# 排除临时因子引用 和 已在免疫名单中的因子
if not node.name.startswith("__tmp_") and node.name not in self.ignore_symbols:
self.dependencies.add(node.name)
def _visit_binary_op(self, node: BinaryOpNode) -> None:
"""访问 BinaryOpNode 节点,递归遍历左右子节点。
@@ -92,13 +101,16 @@ class DependencyExtractor:
return self.dependencies.copy()
@classmethod
def extract_dependencies(cls, node: Node) -> Set[str]:
def extract_dependencies(
cls, node: Node, ignore_symbols: Optional[Set[str]] = None
) -> Set[str]:
"""类方法 - 从 AST 节点中提取所有依赖的符号名称。
这是一个便捷方法,无需手动实例化 DependencyExtractor。
Args:
node: 表达式树的根节点
ignore_symbols: 需要忽略的符号集合(如已注册的因子名)
Returns:
依赖的符号名称集合
@@ -112,17 +124,20 @@ class DependencyExtractor:
>>> print(deps)
{'close', 'open'}
"""
extractor = cls()
extractor = cls(ignore_symbols=ignore_symbols)
return extractor.extract(node)
def extract_dependencies(node: Node) -> Set[str]:
def extract_dependencies(
node: Node, ignore_symbols: Optional[Set[str]] = None
) -> Set[str]:
"""单例方法 - 从 AST 节点中提取所有依赖的符号名称。
这是 DependencyExtractor.extract_dependencies 的便捷包装函数。
Args:
node: 表达式树的根节点
ignore_symbols: 需要忽略的符号集合(如已注册的因子名)
Returns:
依赖的符号名称集合
@@ -136,7 +151,7 @@ def extract_dependencies(node: Node) -> Set[str]:
>>> print(deps)
{'close', 'pe_ratio'}
"""
return DependencyExtractor.extract_dependencies(node)
return DependencyExtractor.extract_dependencies(node, ignore_symbols=ignore_symbols)
if __name__ == "__main__":

View File

@@ -115,6 +115,24 @@ class Node(ABC):
"""大于等于: self >= other"""
return BinaryOpNode(">=", self, _ensure_node(other))
# ==================== 位运算符重载(用于逻辑运算) ====================
def __and__(self, other: Any) -> BinaryOpNode:
"""逻辑与: self & other"""
return BinaryOpNode("&", self, _ensure_node(other))
def __rand__(self, other: Any) -> BinaryOpNode:
"""右逻辑与: other & self"""
return BinaryOpNode("&", _ensure_node(other), self)
def __or__(self, other: Any) -> BinaryOpNode:
"""逻辑或: self | other"""
return BinaryOpNode("|", self, _ensure_node(other))
def __ror__(self, other: Any) -> BinaryOpNode:
"""右逻辑或: other | self"""
return BinaryOpNode("|", _ensure_node(other), self)
# ==================== 抽象方法 ====================
@abstractmethod

View File

@@ -0,0 +1,238 @@
"""AST 优化器 - 表达式拍平。
本模块实现将嵌套的窗口函数表达式自动提取为中间临时因子,
解决多维窗口函数over嵌套导致计算为空的问题。
核心思想:
通过 AST 变换,将嵌套在窗口函数内的窗口函数表达式提取出来,
作为独立的临时因子先行计算,然后主表达式引用这些临时因子。
示例:
原始表达式: cs_rank(ts_delay(close, 1))
拍平后:
- 临时因子: __tmp_0 = ts_delay(close, 1)
- 主表达式: cs_rank(__tmp_0)
"""
import threading
from typing import Dict, Tuple
from src.factors.dsl import (
BinaryOpNode,
Constant,
FunctionNode,
Node,
Symbol,
UnaryOpNode,
)
# 模块级全局计数器,用于生成唯一的临时因子名称
_global_counter: int = 0
_counter_lock = threading.Lock()
def _get_next_counter() -> int:
"""获取下一个全局计数器值。
Returns:
递增后的全局计数器值
"""
global _global_counter
with _counter_lock:
_global_counter += 1
return _global_counter
class ExpressionFlattener:
"""表达式拍平器。
遍历 AST 并自动提取嵌套的窗口函数为独立临时因子。
Attributes:
_extracted_nodes: 存储已提取的临时因子字典
"""
def __init__(self) -> None:
"""初始化拍平器。"""
self._extracted_nodes: Dict[str, Node] = {}
def _generate_temp_name(self) -> str:
"""生成唯一的临时因子名称。
使用模块级全局计数器确保跨因子注册时的唯一性。
Returns:
格式为 "__tmp_X" 的临时名称,其中 X 是全局递增数字
"""
return f"__tmp_{_get_next_counter()}"
def _is_window_function(self, func_name: str) -> bool:
"""判断是否为窗口函数。
窗口函数以 "ts_"(时序)或 "cs_"(截面)开头。
Args:
func_name: 函数名称
Returns:
是否是窗口函数
"""
return func_name.startswith("ts_") or func_name.startswith("cs_")
def flatten(self, node: Node) -> Tuple[Node, Dict[str, Node]]:
"""拍平表达式。
遍历 AST将嵌套的窗口函数提取为临时因子。
Args:
node: 原始表达式根节点
Returns:
Tuple[拍平后的主表达式节点, 临时因子字典]
临时因子字典: {临时名称 -> 被提取的节点}
Example:
>>> flattener = ExpressionFlattener()
>>> from src.factors.dsl import Symbol, FunctionNode
>>> close = Symbol("close")
>>> expr = FunctionNode("cs_rank", FunctionNode("ts_delay", close, 1))
>>> flat_expr, tmp_factors = flattener.flatten(expr)
>>> # flat_expr = cs_rank(__tmp_0)
>>> # tmp_factors = {"__tmp_0": ts_delay(close, 1)}
"""
# 重置状态(只重置提取的节点字典,计数器保持全局递增)
self._extracted_nodes = {}
# 从根节点开始遍历,初始状态为不在窗口函数内部
new_node = self._flatten_recursive(node, inside_window=False)
return new_node, self._extracted_nodes.copy()
def _flatten_recursive(self, node: Node, inside_window: bool) -> Node:
"""递归拍平节点。
Args:
node: 当前处理的节点
inside_window: 当前是否处于窗口函数内部
Returns:
处理后的节点(可能是原节点或替换为 Symbol
"""
# Symbol 和 Constant 是叶子节点,直接返回
if isinstance(node, Symbol):
return node
if isinstance(node, Constant):
return node
# 处理二元运算节点
if isinstance(node, BinaryOpNode):
return self._flatten_binary_op(node, inside_window)
# 处理一元运算节点
if isinstance(node, UnaryOpNode):
return self._flatten_unary_op(node, inside_window)
# 处理函数调用节点
if isinstance(node, FunctionNode):
return self._flatten_function(node, inside_window)
# 未知节点类型,直接返回
return node
def _flatten_binary_op(self, node: BinaryOpNode, inside_window: bool) -> Node:
"""拍平二元运算节点。
Args:
node: 二元运算节点
inside_window: 当前是否处于窗口函数内部
Returns:
处理后的节点
"""
# 递归处理左右子节点
new_left = self._flatten_recursive(node.left, inside_window)
new_right = self._flatten_recursive(node.right, inside_window)
# 如果子节点没有变化,返回原节点
if new_left is node.left and new_right is node.right:
return node
# 创建新的二元运算节点
return BinaryOpNode(node.op, new_left, new_right)
def _flatten_unary_op(self, node: UnaryOpNode, inside_window: bool) -> Node:
"""拍平一元运算节点。
Args:
node: 一元运算节点
inside_window: 当前是否处于窗口函数内部
Returns:
处理后的节点
"""
# 递归处理操作数
new_operand = self._flatten_recursive(node.operand, inside_window)
# 如果操作数没有变化,返回原节点
if new_operand is node.operand:
return node
# 创建新的一元运算节点
return UnaryOpNode(node.op, new_operand)
def _flatten_function(self, node: FunctionNode, inside_window: bool) -> Node:
"""拍平函数调用节点。
修正为后序遍历Bottom-Up先递归拍平参数再决定是否提取当前节点。
确保深层嵌套(如 3层以上也能被彻底逐层拆解。
Args:
node: 函数调用节点
inside_window: 当前是否处于窗口函数内部
Returns:
处理后的节点
"""
is_window = self._is_window_function(node.func_name)
next_inside_window = inside_window or is_window
# 1. 优先递归处理所有参数
new_args = []
has_change = False
for arg in node.args:
new_arg = self._flatten_recursive(arg, next_inside_window)
new_args.append(new_arg)
if new_arg is not arg:
has_change = True
# 2. 只有当参数发生变化时,才创建新的当前节点
current_node = FunctionNode(node.func_name, *new_args) if has_change else node
# 3. 判断是否需要提取(此时子节点肯定已经被彻底拍平了)
if inside_window and is_window:
temp_name = self._generate_temp_name()
self._extracted_nodes[temp_name] = current_node
return Symbol(temp_name)
return current_node
def flatten_expression(node: Node) -> Tuple[Node, Dict[str, Node]]:
"""便捷函数 - 拍平表达式。
Args:
node: 表达式树的根节点
Returns:
Tuple[拍平后的主表达式节点, 临时因子字典]
Example:
>>> from src.factors.dsl import Symbol, FunctionNode
>>> close = Symbol("close")
>>> expr = FunctionNode("cs_rank", FunctionNode("ts_delay", close, 1))
>>> flat_expr, tmp_factors = flatten_expression(expr)
"""
flattener = ExpressionFlattener()
return flattener.flatten(node)

View File

@@ -1,10 +1,12 @@
"""计算引擎。
执行并行运算,负责将执行计划应用到数据上。
利用 Polars 底层 Rust 引擎的原生并行能力,通过 BFS 分层执行策略
避免 Python 层面的多进程/多线程开销。
"""
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Set, Union
from typing import Dict, List, Set
import polars as pl
@@ -14,33 +16,25 @@ from src.factors.engine.data_spec import ExecutionPlan
class ComputeEngine:
"""计算引擎 - 执行并行运算。
负责将执行计划应用到数据上,支持并行计算
负责将执行计划应用到数据上,利用 Polars 底层 Rust 引擎的原生并行能力
Attributes:
max_workers: 最大并行工作线程数
use_processes: 是否使用进程池CPU 密集型任务)
采用 BFS 分层执行策略:
1. 构建依赖图,识别各计划间的依赖关系
2. 按拓扑排序分层,每层包含互不依赖的计划
3. 将每层计划打包为表达式列表,通过单次 with_columns 提交
4. Polars 自动在所有 CPU 核心上并行计算,零拷贝内存
"""
def __init__(
self,
max_workers: int = 4,
use_processes: bool = False,
) -> None:
"""初始化计算引擎。
Args:
max_workers: 最大并行工作线程数
use_processes: 是否使用进程池代替线程池
"""
self.max_workers = max_workers
self.use_processes = use_processes
def __init__(self) -> None:
"""初始化计算引擎。"""
pass
def execute(
self,
plan: ExecutionPlan,
data: pl.DataFrame,
) -> pl.DataFrame:
"""执行计算计划。
"""执行单个计算计划。
Args:
plan: 执行计划
@@ -55,16 +49,14 @@ class ComputeEngine:
raise ValueError(f"数据缺少必要的字段: {missing_cols}")
# 执行计算
result = data.with_columns([plan.polars_expr.alias(plan.output_name)])
return result
return data.with_columns([plan.polars_expr.alias(plan.output_name)])
def execute_batch(
self,
plans: List[ExecutionPlan],
data: pl.DataFrame,
) -> pl.DataFrame:
"""批量执行多个计算计划。
"""顺序批量执行多个计算计划。
Args:
plans: 执行计划列表
@@ -74,10 +66,8 @@ class ComputeEngine:
包含所有因子结果的 DataFrame
"""
result = data
for plan in plans:
result = self.execute(plan, result)
return result
def execute_parallel(
@@ -85,7 +75,11 @@ class ComputeEngine:
plans: List[ExecutionPlan],
data: pl.DataFrame,
) -> pl.DataFrame:
"""并行执行多个计算计划。
"""分层并行执行计算计划(利用 Polars 原生并发优化)
抛弃 Python 的多进程/多线程池采用计算图拓扑分层BFS DAG
将每一层互不依赖的表达式列表打包,通过单次 with_columns 交给 Polars
由底层 Rust 引擎自动调度并行计算,实现零拷贝性能最大化。
Args:
plans: 执行计划列表
@@ -93,63 +87,70 @@ class ComputeEngine:
Returns:
包含所有因子结果的 DataFrame
Raises:
RuntimeError: 当存在依赖环或缺少基础依赖字段时
"""
# 检查计划间依赖
independent_plans = []
dependent_plans = []
available_cols = set(data.columns)
if not plans:
return data
for plan in plans:
if plan.dependencies <= available_cols:
independent_plans.append(plan)
result = data
available_cols: Set[str] = set(result.columns)
# 复制一份计划列表用于迭代
remaining_plans = plans.copy()
while remaining_plans:
# 找出当前可以执行的所有独立计划(即依赖的所有列都已就绪)
current_layer: List[ExecutionPlan] = []
next_remaining: List[ExecutionPlan] = []
for plan in remaining_plans:
if plan.dependencies <= available_cols:
current_layer.append(plan)
else:
next_remaining.append(plan)
# 安全兜底:如果一轮遍历后没找到任何可执行计划,说明存在依赖环或数据缺失
if not current_layer:
missing = remaining_plans[0].dependencies - available_cols
raise RuntimeError(
f"计算发生死锁或缺少基础依赖字段!\n"
f"因子 '{remaining_plans[0].output_name}' 缺少: {missing}"
)
# 核心优化:利用 Polars 内部 Rust 级多线程引擎执行当前层
exprs = [plan.polars_expr.alias(plan.output_name) for plan in current_layer]
result = result.with_columns(exprs)
# 更新已就绪字段集合,为计算下一层做准备
for plan in current_layer:
available_cols.add(plan.output_name)
else:
dependent_plans.append(plan)
# 并行执行独立计划
if independent_plans:
ExecutorClass = (
ProcessPoolExecutor if self.use_processes else ThreadPoolExecutor
)
remaining_plans = next_remaining
with ExecutorClass(max_workers=self.max_workers) as executor:
futures = {
executor.submit(self._execute_single, plan, data): plan
for plan in independent_plans
}
return result
results = []
for future in futures:
plan = futures[future]
try:
result_col = future.result()
results.append((plan.output_name, result_col))
except Exception as e:
raise RuntimeError(f"计算因子 {plan.output_name} 失败: {e}")
# 合并结果
for name, series in results:
data = data.with_columns([series.alias(name)])
# 顺序执行依赖计划
for plan in dependent_plans:
data = self.execute(plan, data)
return data
def _execute_single(
def compute(
self,
plan: ExecutionPlan,
plans: List[ExecutionPlan],
data: pl.DataFrame,
) -> pl.Series:
"""执行单个计划并返回结果列。
parallel: bool = True,
) -> pl.DataFrame:
"""智能计算入口。
根据 parallel 参数自动选择执行模式:
- True: 使用分层并行执行(推荐)
- False: 使用顺序执行
Args:
plan: 执行计划
plans: 执行计划列表
data: 输入数据
parallel: 是否使用并行执行
Returns:
计算结果序列
包含所有因子结果的 DataFrame
"""
result = self.execute(plan, data)
return result[plan.output_name]
if parallel:
return self.execute_parallel(plans, data)
return self.execute_batch(plans, data)

View File

@@ -30,6 +30,7 @@ from src.factors.engine.data_spec import DataSpec, ExecutionPlan
from src.factors.engine.data_router import DataRouter
from src.factors.engine.planner import ExecutionPlanner
from src.factors.engine.compute_engine import ComputeEngine
from src.factors.engine.ast_optimizer import ExpressionFlattener
class FactorEngine:
@@ -56,7 +57,6 @@ class FactorEngine:
def __init__(
self,
data_source: Optional[Dict[str, pl.DataFrame]] = None,
max_workers: int = 4,
registry: Optional["FunctionRegistry"] = None,
metadata_path: Optional[str] = None,
) -> None:
@@ -64,16 +64,15 @@ class FactorEngine:
Args:
data_source: 内存数据源,为 None 时使用数据库连接
max_workers: 并行计算的最大工作线程数
registry: 函数注册表None 时创建独立实例
metadata_path: 因子元数据文件路径,为 None 时启用 metadata 功能
metadata_path: 因子元数据文件路径,为 None 时启用默认 metadata 功能
"""
from src.factors.registry import FunctionRegistry
from src.factors.parser import FormulaParser
self.router = DataRouter(data_source)
self.planner = ExecutionPlanner()
self.compute_engine = ComputeEngine(max_workers=max_workers)
self.compute_engine = ComputeEngine()
self.registered_expressions: Dict[str, Node] = {}
self._plans: Dict[str, ExecutionPlan] = {}
@@ -92,13 +91,68 @@ class FactorEngine:
self._metadata = FactorManager()
def _register_internal(
self,
name: str,
expression: Node,
data_specs: Optional[List[DataSpec]] = None,
) -> "FactorEngine":
"""内部注册方法,直接注册因子表达式。
Args:
name: 因子名称
expression: DSL 表达式
data_specs: 数据规格None 时自动推导
Returns:
self支持链式调用
"""
# 检测因子依赖(在注册当前因子之前检查其他已注册因子)
factor_deps = self._find_factor_dependencies(expression)
# 获取当前所有已注册的因子名称(作为免疫名单,防止被当作数据库字段)
known_factors = set(self.registered_expressions.keys())
self.registered_expressions[name] = expression
# 预创建执行计划,过滤掉已注册的因子,防止被当作数据库字段
plan = self.planner.create_plan(
expression=expression,
output_name=name,
data_specs=data_specs,
ignore_dependencies=known_factors,
)
# 添加因子依赖信息
plan.factor_dependencies = factor_deps
# 如果数据规格为空,继承依赖因子(包括临时因子)的数据规格
if not plan.data_specs and factor_deps:
merged_specs: List[DataSpec] = []
for dep_name in factor_deps:
if dep_name in self._plans:
merged_specs.extend(self._plans[dep_name].data_specs)
# 去重(基于表名)
seen_tables: set = set()
unique_specs: List[DataSpec] = []
for spec in merged_specs:
if spec.table not in seen_tables:
seen_tables.add(spec.table)
unique_specs.append(spec)
plan.data_specs = unique_specs
self._plans[name] = plan
return self
def register(
self,
name: str,
expression: Node,
data_specs: Optional[List[DataSpec]] = None,
) -> "FactorEngine":
"""注册因子表达式。
"""注册因子表达式(自动处理嵌套窗口函数)
Args:
name: 因子名称
@@ -113,22 +167,16 @@ class FactorEngine:
>>> engine = FactorEngine()
>>> engine.register("ma20", ts_mean(close, 20))
"""
# 检测因子依赖(在注册当前因子之前检查其他已注册因子)
factor_deps = self._find_factor_dependencies(expression)
# 使用 AST 优化器拍平嵌套窗口函数
flattener = ExpressionFlattener()
flat_expression, tmp_factors = flattener.flatten(expression)
self.registered_expressions[name] = expression
# 先注册所有临时因子(自动推导数据规格)
for tmp_name, tmp_node in tmp_factors.items():
self._register_internal(tmp_name, tmp_node, data_specs=None)
# 预创建执行计划
plan = self.planner.create_plan(
expression=expression,
output_name=name,
data_specs=data_specs,
)
# 添加因子依赖信息
plan.factor_dependencies = factor_deps
self._plans[name] = plan
# 最后注册主因子
self._register_internal(name, flat_expression, data_specs)
return self
@@ -174,7 +222,7 @@ class FactorEngine:
# 解析表达式为 Node
node = self._parser.parse(dsl_expr)
# 委托给 register 方法
# 委托给 register 方法register 会处理嵌套窗口函数拍平)
return self.register(name, node, data_specs)
def add_factor(
@@ -272,21 +320,32 @@ class FactorEngine:
if isinstance(factor_names, str):
factor_names = [factor_names]
# 1. 获取执行计划
# 1. 收集所有需要的因子(包括临时因子依赖)
all_factor_names = self._collect_all_dependencies(factor_names)
# 2. 获取执行计划
plans = []
for name in factor_names:
for name in all_factor_names:
if name not in self._plans:
raise ValueError(f"因子未注册: {name}")
plans.append(self._plans[name])
# 2. 合并数据规格并获取数据
# 3. 合并数据规格并获取数据
all_specs = []
for plan in plans:
all_specs.extend(plan.data_specs)
# 3. 从路由器获取核心宽表
# 去重数据规格(基于表名)
seen_tables: set = set()
unique_specs: List[DataSpec] = []
for spec in all_specs:
if spec.table not in seen_tables:
seen_tables.add(spec.table)
unique_specs.append(spec)
# 4. 从路由器获取核心宽表
core_data = self.router.fetch_data(
data_specs=all_specs,
data_specs=unique_specs,
start_date=start_date,
end_date=end_date,
stock_codes=stock_codes,
@@ -295,14 +354,14 @@ class FactorEngine:
if len(core_data) == 0:
raise ValueError("未获取到任何数据,请检查日期范围和股票代码")
# 4. 按依赖顺序执行计算
if len(plans) == 1:
result = self.compute_engine.execute(plans[0], core_data)
else:
# 使用依赖感知的方式执行
result = self._execute_with_dependencies(factor_names, core_data)
# 5. 按依赖顺序执行计算(包含临时因子)
result = self._execute_with_dependencies(all_factor_names, core_data)
return result
# 6. 清理内存宽表过滤掉临时因子列__tmp_X
# 保留所有非临时因子列(包括原始数据列和用户请求的因子列)
cols_to_keep = [col for col in result.columns if not col.startswith("__tmp_")]
return result.select(cols_to_keep)
def list_registered(self) -> List[str]:
"""获取已注册的因子列表。
@@ -501,10 +560,32 @@ class FactorEngine:
return False
def _find_factor_dependencies(self, expression: Node) -> Set[str]:
"""查找表达式依赖的其他因子。
def _collect_all_dependencies(self, factor_names: List[str]) -> List[str]:
"""收集所有因子及其依赖(包括用户定义的因子和临时因子)。"""
collected: Set[str] = set()
result: List[str] = []
遍历已注册因子,检查表达式是否包含任何已注册因子的完整表达式。
def collect_recursive(name: str):
if name in collected:
return
collected.add(name)
# 获取执行计划并递归收集强依赖
plan = self._plans.get(name)
if plan:
for dep_name in plan.factor_dependencies:
collect_recursive(dep_name)
# 依赖收集完毕,再将自己加入列表(天然形成安全的计算顺序)
result.append(name)
for name in factor_names:
collect_recursive(name)
return result
def _find_factor_dependencies(self, expression: Node) -> Set[str]:
"""查找表达式依赖的其他因子(包括临时因子和用户因子引用)。
Args:
expression: 待检查的表达式
@@ -514,13 +595,20 @@ class FactorEngine:
"""
deps: Set[str] = set()
# 检查表达式本身是否等于某个已注册因子
# 1. 【新增】如果直接引用了已注册因子名称(包含 __tmp_X 或用户因子)
if (
isinstance(expression, Symbol)
and expression.name in self.registered_expressions
):
deps.add(expression.name)
# 2. 检查表达式本身是否等于某个已注册因子的完整 AST
for name, registered_expr in self.registered_expressions.items():
if self._expressions_equal(expression, registered_expr):
deps.add(name)
break
# 递归检查子节点
# 3. 递归检查子节点
if isinstance(expression, BinaryOpNode):
deps.update(self._find_factor_dependencies(expression.left))
deps.update(self._find_factor_dependencies(expression.right))

View File

@@ -39,6 +39,7 @@ class ExecutionPlanner:
expression: Node,
output_name: str = "factor",
data_specs: Optional[List[DataSpec]] = None,
ignore_dependencies: Optional[Set[str]] = None,
) -> ExecutionPlan:
"""从表达式创建执行计划。
@@ -46,12 +47,15 @@ class ExecutionPlanner:
expression: DSL 表达式节点
output_name: 输出因子名称
data_specs: 预定义的数据规格None 时自动推导
ignore_dependencies: 需要忽略的依赖符号集合(如已注册因子名)
Returns:
执行计划对象
"""
# 1. 提取依赖
dependencies = self.compiler.extract_dependencies(expression)
# 1. 提取依赖时传入要忽略的符号
dependencies = self.compiler.extract_dependencies(
expression, ignore_symbols=ignore_dependencies
)
# 2. 翻译为 Polars 表达式
polars_expr = self.translator.translate(expression)

View File

@@ -35,6 +35,8 @@ BIN_OP_MAP: Dict[type, str] = {
ast.Pow: "**",
ast.FloorDiv: "//",
ast.Mod: "%",
ast.BitAnd: "&",
ast.BitOr: "|",
}
UNARY_OP_MAP: Dict[type, str] = {

View File

@@ -178,6 +178,8 @@ class PolarsTranslator:
"<=": lambda l, r: l.le(r),
">": lambda l, r: l.gt(r),
">=": lambda l, r: l.ge(r),
"&": lambda l, r: l & r,
"|": lambda l, r: l | r,
}
if node.op not in op_map:

367
tests/test_ast_optimizer.py Normal file
View File

@@ -0,0 +1,367 @@
"""AST 优化器测试 - 验证嵌套窗口函数拍平功能。
测试因子: cs_rank(ts_delay(close, 1))
这是一个典型的窗口函数嵌套场景,应该被自动拍平为临时因子。
"""
import pytest
import polars as pl
import numpy as np
from datetime import datetime, timedelta
from src.factors.engine import FactorEngine
from src.factors.api import close, ts_delay, cs_rank
from src.factors.dsl import FunctionNode
from src.factors.engine.ast_optimizer import ExpressionFlattener
def create_mock_data(
start_date: str = "20240101",
end_date: str = "20240131",
n_stocks: int = 5,
) -> pl.DataFrame:
"""创建模拟的日线数据。"""
start = datetime.strptime(start_date, "%Y%m%d")
end = datetime.strptime(end_date, "%Y%m%d")
dates = []
current = start
while current <= end:
if current.weekday() < 5: # 周一到周五
dates.append(current.strftime("%Y%m%d"))
current += timedelta(days=1)
stocks = [f"{600000 + i:06d}.SH" for i in range(n_stocks)]
np.random.seed(42)
rows = []
for date in dates:
for stock in stocks:
base_price = 10 + np.random.randn() * 5
close_val = base_price + np.random.randn() * 0.5
open_val = close_val + np.random.randn() * 0.2
high_val = max(open_val, close_val) + abs(np.random.randn()) * 0.3
low_val = min(open_val, close_val) - abs(np.random.randn()) * 0.3
vol = int(1000000 + np.random.exponential(500000))
rows.append(
{
"ts_code": stock,
"trade_date": date,
"open": round(open_val, 2),
"high": round(high_val, 2),
"low": round(low_val, 2),
"close": round(close_val, 2),
"volume": vol,
}
)
return pl.DataFrame(rows)
class TestASTOptimizer:
"""AST 优化器测试类。"""
def test_flattener_basic(self):
"""测试拍平器基本功能。"""
from src.factors.api import close
flattener = ExpressionFlattener()
# 创建嵌套表达式: cs_rank(ts_delay(close, 1))
expr = FunctionNode("cs_rank", FunctionNode("ts_delay", close, 1))
flat_expr, tmp_factors = flattener.flatten(expr)
# 验证临时因子被提取
assert len(tmp_factors) == 1
assert "__tmp_0" in tmp_factors
# 验证主表达式使用了 Symbol 引用
assert isinstance(flat_expr, FunctionNode)
assert flat_expr.func_name == "cs_rank"
# 验证第一个参数是临时因子引用(通过 name 属性检查)
assert hasattr(flat_expr.args[0], "name")
assert flat_expr.args[0].name == "__tmp_0"
# 验证临时因子内容
tmp_node = tmp_factors["__tmp_0"]
assert isinstance(tmp_node, FunctionNode)
assert tmp_node.func_name == "ts_delay"
print("[PASS] 拍平器基本功能测试")
def test_flattener_no_nested(self):
"""测试非嵌套表达式不会被拍平。"""
from src.factors.api import close, ts_mean
flattener = ExpressionFlattener()
# 非嵌套表达式: ts_mean(close, 20)
expr = FunctionNode("ts_mean", close, 20)
flat_expr, tmp_factors = flattener.flatten(expr)
# 验证没有临时因子被提取
assert len(tmp_factors) == 0
# 验证表达式保持不变
assert isinstance(flat_expr, FunctionNode)
assert flat_expr.func_name == "ts_mean"
print("[PASS] 非嵌套表达式测试")
def test_flattener_deeply_nested(self):
"""测试多层嵌套表达式拍平。"""
from src.factors.api import close, ts_mean
flattener = ExpressionFlattener()
# 深层嵌套: cs_rank(ts_mean(ts_delay(close, 1), 5))
expr = FunctionNode(
"cs_rank", FunctionNode("ts_mean", FunctionNode("ts_delay", close, 1), 5)
)
flat_expr, tmp_factors = flattener.flatten(expr)
# 验证提取了两个临时因子(修复后正确行为)
# ts_delay(close, 1) 被提取为 __tmp_0
# ts_mean(__tmp_0, 5) 被提取为 __tmp_1
assert len(tmp_factors) == 2
assert "__tmp_0" in tmp_factors
assert "__tmp_1" in tmp_factors
# 验证 __tmp_0 内容是 ts_delay(close, 1)
tmp0_node = tmp_factors["__tmp_0"]
assert isinstance(tmp0_node, FunctionNode)
assert tmp0_node.func_name == "ts_delay"
# 验证 __tmp_1 内容是 ts_mean(__tmp_0, 5)
tmp1_node = tmp_factors["__tmp_1"]
assert isinstance(tmp1_node, FunctionNode)
assert tmp1_node.func_name == "ts_mean"
from src.factors.dsl import Symbol
assert isinstance(tmp1_node.args[0], Symbol)
assert tmp1_node.args[0].name == "__tmp_0"
# 验证主表达式引用 __tmp_1
assert isinstance(flat_expr, FunctionNode)
assert flat_expr.func_name == "cs_rank"
assert isinstance(flat_expr.args[0], Symbol)
assert flat_expr.args[0].name == "__tmp_1"
print("[PASS] 多层嵌套表达式拍平测试")
def test_nested_window_function_engine(self):
"""测试引擎正确处理嵌套窗口函数 cs_rank(ts_delay(close, 1))。"""
print("\n" + "=" * 60)
print("测试嵌套窗口函数: cs_rank(ts_delay(close, 1))")
print("=" * 60)
# 1. 准备数据
mock_data = create_mock_data("20240101", "20240131", n_stocks=5)
print(f"\n生成模拟数据: {len(mock_data)}")
# 2. 初始化引擎
engine = FactorEngine(data_source={"pro_bar": mock_data})
print("引擎初始化完成")
# 3. 使用字符串表达式注册嵌套窗口函数
print("\n注册因子: cs_rank(ts_delay(close, 1))")
engine.add_factor("delayed_rank", "cs_rank(ts_delay(close, 1))")
# 4. 检查临时因子是否被创建
registered_factors = engine.list_registered()
print(f"已注册因子: {registered_factors}")
# 验证有临时因子被创建
tmp_factors = [name for name in registered_factors if name.startswith("__tmp_")]
assert len(tmp_factors) >= 1, "应该有临时因子被创建"
print(f"临时因子: {tmp_factors}")
# 5. 执行计算
print("\n执行计算...")
result = engine.compute("delayed_rank", "20240115", "20240131")
print(f"计算完成: {len(result)}")
# 6. 验证结果
assert "delayed_rank" in result.columns, "结果中应该有 delayed_rank 列"
# 检查结果值是否在合理范围内(排名因子应该在 0-1 之间,但可能由于滞后有 null
non_null_values = result["delayed_rank"].drop_nulls()
if len(non_null_values) > 0:
assert non_null_values.min() >= 0, "排名应该在 [0, 1] 之间"
assert non_null_values.max() <= 1, "排名应该在 [0, 1] 之间"
# 检查没有过多空值(考虑到开头的滞后期)
null_count = result["delayed_rank"].is_null().sum()
print(f"空值数量: {null_count}")
# 展示部分结果
print("\n前 10 行结果:")
sample = result.select(["ts_code", "trade_date", "close", "delayed_rank"]).head(
10
)
print(sample.to_pandas().to_string(index=False))
print("\n" + "=" * 60)
print("嵌套窗口函数测试通过!")
print("=" * 60)
def test_multiple_nested_factors(self):
"""测试同时注册多个嵌套因子。"""
print("\n" + "=" * 60)
print("测试多个嵌套因子")
print("=" * 60)
mock_data = create_mock_data("20240101", "20240131", n_stocks=5)
engine = FactorEngine(data_source={"pro_bar": mock_data})
# 注册多个嵌套因子(使用字符串表达式)
print("\n注册因子1: cs_rank(ts_delay(close, 1))")
engine.add_factor("rank1", "cs_rank(ts_delay(close, 1))")
print("注册因子2: ts_mean(cs_rank(close), 5)")
engine.add_factor("rank_mean", "ts_mean(cs_rank(close), 5)")
# 检查已注册因子
factors = engine.list_registered()
print(f"\n已注册因子: {factors}")
# 计算所有因子
result = engine.compute(["rank1", "rank_mean"], "20240115", "20240131")
assert "rank1" in result.columns
assert "rank_mean" in result.columns
print(f"\n结果行数: {len(result)}")
print(f"rank1 空值数: {result['rank1'].is_null().sum()}")
print(f"rank_mean 空值数: {result['rank_mean'].is_null().sum()}")
print("\n" + "=" * 60)
print("多个嵌套因子测试通过!")
print("=" * 60)
def test_nested_vs_native_polars(self):
"""对比测试:嵌套窗口函数 vs 原生 Polars 计算,验证数值一致性。"""
print("\n" + "=" * 60)
print("对比测试cs_rank(ts_delay(close, 1)) vs 原生 Polars")
print("=" * 60)
# 1. 准备数据
mock_data = create_mock_data("20240101", "20240131", n_stocks=5)
print(f"\n生成模拟数据: {len(mock_data)}")
# 2. 使用 FactorEngine 计算嵌套因子
engine = FactorEngine(data_source={"pro_bar": mock_data})
print("\n使用 FactorEngine 计算 cs_rank(ts_delay(close, 1))...")
engine.register("delayed_rank", cs_rank(ts_delay(close, 1)))
engine_result = engine.compute("delayed_rank", "20240115", "20240131")
print(f"FactorEngine 结果: {len(engine_result)}")
# 3. 使用原生 Polars 计算(手动分步)
print("\n使用原生 Polars 手动计算...")
# 先计算 ts_delay(close, 1)
native_result = mock_data.sort(["ts_code", "trade_date"]).with_columns(
[pl.col("close").shift(1).over("ts_code").alias("delayed_close")]
)
# 再计算 cs_rank
native_result = native_result.with_columns(
[
(pl.col("delayed_close").rank() / pl.col("delayed_close").count())
.over("trade_date")
.alias("native_delayed_rank")
]
)
print(f"原生 Polars 结果: {len(native_result)}")
# 4. 合并结果进行对比
comparison = engine_result.join(
native_result.select(["ts_code", "trade_date", "native_delayed_rank"]),
on=["ts_code", "trade_date"],
how="inner",
)
# 5. 验证数值一致性(允许微小浮点误差)
diff = comparison.with_columns(
[
(pl.col("delayed_rank") - pl.col("native_delayed_rank"))
.abs()
.alias("diff")
]
)
max_diff = diff["diff"].max()
print(f"\n最大差异: {max_diff}")
# 过滤掉空值后比较(开头的滞后期会有空值)
non_null_diff = diff.filter(pl.col("diff").is_not_null())
assert non_null_diff["diff"].max() < 1e-10, (
f"数值差异过大: {non_null_diff['diff'].max()}"
)
print("\n" + "=" * 60)
print("数值一致性验证通过!")
print("=" * 60)
def test_factor_reference_factor(self):
"""测试因子引用另一个因子fac2 = cs_rank(fac1)。"""
print("\n" + "=" * 60)
print("测试因子引用其他因子: fac2 = cs_rank(fac1)")
print("=" * 60)
# 准备数据
mock_data = create_mock_data("20240101", "20240131", n_stocks=5)
engine = FactorEngine(data_source={"pro_bar": mock_data})
# 1. 注册基础因子 fac1
print("\n注册基础因子 fac1 = ts_mean(close, 5)")
from src.factors.api import ts_mean
engine.register("fac1", ts_mean(close, 5))
# 2. 注册引用因子 fac2引用 fac1
print("注册引用因子 fac2 = cs_rank(fac1)")
engine.register("fac2", cs_rank("fac1")) # 字符串引用另一个因子
# 3. 验证依赖关系
registered = engine.list_registered()
print(f"\n已注册因子: {registered}")
assert "fac1" in registered
assert "fac2" in registered
# 4. 执行计算
print("\n执行计算...")
result = engine.compute(["fac1", "fac2"], "20240115", "20240131")
print(f"计算完成: {len(result)}")
# 5. 验证结果
assert "fac1" in result.columns, "结果中应有 fac1 列"
assert "fac2" in result.columns, "结果中应有 fac2 列"
# fac2 是排名,应在 [0, 1] 之间
assert result["fac2"].min() >= 0, "排名应在 [0, 1] 之间"
assert result["fac2"].max() <= 1, "排名应在 [0, 1] 之间"
print("\n前 10 行结果:")
sample = result.select(["ts_code", "trade_date", "close", "fac1", "fac2"]).head(
10
)
print(sample.to_pandas().to_string(index=False))
print("\n" + "=" * 60)
print("因子引用功能测试通过!")
print("=" * 60)
if __name__ == "__main__":
test = TestASTOptimizer()
test.test_flattener_basic()
test.test_flattener_no_nested()
test.test_flattener_deeply_nested()
test.test_nested_window_function_engine()
test.test_multiple_nested_factors()
test.test_nested_vs_native_polars()
test.test_factor_reference_factor()
print("\n所有测试通过!")

144
tests/test_bugfixes.py Normal file
View File

@@ -0,0 +1,144 @@
"""测试 Bug 修复:
1. 临时因子命名冲突修复验证
2. 逻辑运算符支持验证
"""
import sys
sys.path.insert(0, "D:/PyProject/ProStock")
from src.factors.dsl import Symbol, BinaryOpNode
from src.factors.engine.ast_optimizer import ExpressionFlattener, flatten_expression
def test_temp_name_uniqueness():
"""测试:临时因子名称全局唯一性。"""
print("测试 1: 临时因子命名冲突修复")
print("-" * 50)
close = Symbol("close")
open_price = Symbol("open")
# 创建两个表达式拍平器实例
flattener1 = ExpressionFlattener()
flattener2 = ExpressionFlattener()
# 模拟因子 A: cs_rank(ts_delay(close, 1))
from src.factors.dsl import FunctionNode
expr_a = FunctionNode("cs_rank", FunctionNode("ts_delay", close, 1))
flat_a, temps_a = flattener1.flatten(expr_a)
# 模拟因子 B: cs_mean(ts_delay(open, 2))
expr_b = FunctionNode("cs_mean", FunctionNode("ts_delay", open_price, 2))
flat_b, temps_b = flattener2.flatten(expr_b)
# 验证临时名称不冲突
temp_names_a = set(temps_a.keys())
temp_names_b = set(temps_b.keys())
print(f"因子 A 临时名称: {temp_names_a}")
print(f"因子 B 临时名称: {temp_names_b}")
# 检查是否有名称冲突
common_names = temp_names_a & temp_names_b
if common_names:
print(f"[失败] 发现命名冲突: {common_names}")
return False
print("[通过] 临时因子名称全局唯一,无冲突")
return True
def test_logical_operators():
"""测试:逻辑运算符支持。"""
print("\n测试 2: 逻辑运算符支持")
print("-" * 50)
# 测试 DSL 层
close = Symbol("close")
open_price = Symbol("open")
# 测试 & 运算符(注意 Python 运算符优先级,需要用括号)
and_expr = (close > open_price) & (close > 0)
print(f"DSL 表达式 ((close > open) & (close > 0)): {and_expr}")
assert isinstance(and_expr, BinaryOpNode), "& 应生成 BinaryOpNode"
assert and_expr.op == "&", "运算符应为 &"
print("[通过] DSL 层支持 & 运算符")
# 测试 | 运算符(注意 Python 运算符优先级,需要用括号)
or_expr = (close < open_price) | (close < 0)
print(f"DSL 表达式 ((close < open) | (close < 0)): {or_expr}")
assert isinstance(or_expr, BinaryOpNode), "| 应生成 BinaryOpNode"
assert or_expr.op == "|", "运算符应为 |"
print("[通过] DSL 层支持 | 运算符")
# 测试字符串解析
from src.factors.parser import FormulaParser
from src.factors.registry import FunctionRegistry
parser = FormulaParser(FunctionRegistry())
# 解析包含 & 的表达式
try:
parsed_and = parser.parse("(close > open) & (volume > 0)")
print(f"解析器支持 & 运算符: {parsed_and}")
print("[通过] Parser 支持 & 运算符")
except Exception as e:
print(f"[失败] Parser 解析 & 失败: {e}")
return False
# 解析包含 | 的表达式
try:
parsed_or = parser.parse("(close < open) | (volume < 0)")
print(f"解析器支持 | 运算符: {parsed_or}")
print("[通过] Parser 支持 | 运算符")
except Exception as e:
print(f"[失败] Parser 解析 | 失败: {e}")
return False
# 测试翻译到 Polars
from src.factors.translator import PolarsTranslator
import polars as pl
translator = PolarsTranslator()
try:
polars_and = translator.translate(parsed_and)
print(f"Polars 表达式 (&): {polars_and}")
print("[通过] Translator 支持 & 运算符")
except Exception as e:
print(f"[失败] Translator 翻译 & 失败: {e}")
return False
try:
polars_or = translator.translate(parsed_or)
print(f"Polars 表达式 (|): {polars_or}")
print("[通过] Translator 支持 | 运算符")
except Exception as e:
print(f"[失败] Translator 翻译 | 失败: {e}")
return False
return True
if __name__ == "__main__":
print("=" * 60)
print("Bug 修复验证测试")
print("=" * 60)
test1_passed = test_temp_name_uniqueness()
test2_passed = test_logical_operators()
print("\n" + "=" * 60)
print("测试结果汇总")
print("=" * 60)
print(f"临时因子命名冲突修复: {'[通过]' if test1_passed else '[失败]'}")
print(f"逻辑运算符支持: {'[通过]' if test2_passed else '[失败]'}")
if test1_passed and test2_passed:
print("\n所有测试通过!")
sys.exit(0)
else:
print("\n存在失败的测试!")
sys.exit(1)

View File

@@ -72,7 +72,7 @@ class TestFactorEngineEndToEnd:
def engine(self, mock_data):
"""提供配置好的 FactorEngine fixture。"""
data_source = {"pro_bar": mock_data}
return FactorEngine(data_source=data_source, max_workers=2)
return FactorEngine(data_source=data_source)
def test_simple_symbol_expression(self, engine):
"""测试简单的符号表达式。"""