feat(factors): 实现 AST 拍平优化支持嵌套窗口函数
- 新增 ExpressionFlattener 类自动拆解嵌套窗口函数(如 cs_rank(ts_delay(close, 1)))
- 支持因子引用其他因子:engine.register("fac2", cs_rank("fac1"))
- 给 DependencyExtractor 增加 ignore_symbols 免疫名单,防止已注册因子被当作数据库字段
- 添加完整测试覆盖嵌套场景和数值一致性验证
This commit is contained in:
@@ -30,6 +30,7 @@ from src.factors.engine.data_spec import DataSpec, ExecutionPlan
|
||||
from src.factors.engine.data_router import DataRouter
|
||||
from src.factors.engine.planner import ExecutionPlanner
|
||||
from src.factors.engine.compute_engine import ComputeEngine
|
||||
from src.factors.engine.ast_optimizer import ExpressionFlattener
|
||||
|
||||
|
||||
class FactorEngine:
|
||||
@@ -92,13 +93,68 @@ class FactorEngine:
|
||||
|
||||
self._metadata = FactorManager()
|
||||
|
||||
def _register_internal(
|
||||
self,
|
||||
name: str,
|
||||
expression: Node,
|
||||
data_specs: Optional[List[DataSpec]] = None,
|
||||
) -> "FactorEngine":
|
||||
"""内部注册方法,直接注册因子表达式。
|
||||
|
||||
Args:
|
||||
name: 因子名称
|
||||
expression: DSL 表达式
|
||||
data_specs: 数据规格,None 时自动推导
|
||||
|
||||
Returns:
|
||||
self,支持链式调用
|
||||
"""
|
||||
# 检测因子依赖(在注册当前因子之前检查其他已注册因子)
|
||||
factor_deps = self._find_factor_dependencies(expression)
|
||||
|
||||
# 获取当前所有已注册的因子名称(作为免疫名单,防止被当作数据库字段)
|
||||
known_factors = set(self.registered_expressions.keys())
|
||||
|
||||
self.registered_expressions[name] = expression
|
||||
|
||||
# 预创建执行计划,过滤掉已注册的因子,防止被当作数据库字段
|
||||
plan = self.planner.create_plan(
|
||||
expression=expression,
|
||||
output_name=name,
|
||||
data_specs=data_specs,
|
||||
ignore_dependencies=known_factors,
|
||||
)
|
||||
|
||||
# 添加因子依赖信息
|
||||
plan.factor_dependencies = factor_deps
|
||||
|
||||
# 如果数据规格为空,继承依赖因子(包括临时因子)的数据规格
|
||||
if not plan.data_specs and factor_deps:
|
||||
merged_specs: List[DataSpec] = []
|
||||
for dep_name in factor_deps:
|
||||
if dep_name in self._plans:
|
||||
merged_specs.extend(self._plans[dep_name].data_specs)
|
||||
|
||||
# 去重(基于表名)
|
||||
seen_tables: set = set()
|
||||
unique_specs: List[DataSpec] = []
|
||||
for spec in merged_specs:
|
||||
if spec.table not in seen_tables:
|
||||
seen_tables.add(spec.table)
|
||||
unique_specs.append(spec)
|
||||
plan.data_specs = unique_specs
|
||||
|
||||
self._plans[name] = plan
|
||||
|
||||
return self
|
||||
|
||||
def register(
|
||||
self,
|
||||
name: str,
|
||||
expression: Node,
|
||||
data_specs: Optional[List[DataSpec]] = None,
|
||||
) -> "FactorEngine":
|
||||
"""注册因子表达式。
|
||||
"""注册因子表达式(自动处理嵌套窗口函数)。
|
||||
|
||||
Args:
|
||||
name: 因子名称
|
||||
@@ -113,22 +169,16 @@ class FactorEngine:
|
||||
>>> engine = FactorEngine()
|
||||
>>> engine.register("ma20", ts_mean(close, 20))
|
||||
"""
|
||||
# 检测因子依赖(在注册当前因子之前检查其他已注册因子)
|
||||
factor_deps = self._find_factor_dependencies(expression)
|
||||
# 使用 AST 优化器拍平嵌套窗口函数
|
||||
flattener = ExpressionFlattener()
|
||||
flat_expression, tmp_factors = flattener.flatten(expression)
|
||||
|
||||
self.registered_expressions[name] = expression
|
||||
# 先注册所有临时因子(自动推导数据规格)
|
||||
for tmp_name, tmp_node in tmp_factors.items():
|
||||
self._register_internal(tmp_name, tmp_node, data_specs=None)
|
||||
|
||||
# 预创建执行计划
|
||||
plan = self.planner.create_plan(
|
||||
expression=expression,
|
||||
output_name=name,
|
||||
data_specs=data_specs,
|
||||
)
|
||||
|
||||
# 添加因子依赖信息
|
||||
plan.factor_dependencies = factor_deps
|
||||
|
||||
self._plans[name] = plan
|
||||
# 最后注册主因子
|
||||
self._register_internal(name, flat_expression, data_specs)
|
||||
|
||||
return self
|
||||
|
||||
@@ -174,7 +224,7 @@ class FactorEngine:
|
||||
# 解析表达式为 Node
|
||||
node = self._parser.parse(dsl_expr)
|
||||
|
||||
# 委托给 register 方法
|
||||
# 委托给 register 方法(register 会处理嵌套窗口函数拍平)
|
||||
return self.register(name, node, data_specs)
|
||||
|
||||
def add_factor(
|
||||
@@ -272,21 +322,32 @@ class FactorEngine:
|
||||
if isinstance(factor_names, str):
|
||||
factor_names = [factor_names]
|
||||
|
||||
# 1. 获取执行计划
|
||||
# 1. 收集所有需要的因子(包括临时因子依赖)
|
||||
all_factor_names = self._collect_all_dependencies(factor_names)
|
||||
|
||||
# 2. 获取执行计划
|
||||
plans = []
|
||||
for name in factor_names:
|
||||
for name in all_factor_names:
|
||||
if name not in self._plans:
|
||||
raise ValueError(f"因子未注册: {name}")
|
||||
plans.append(self._plans[name])
|
||||
|
||||
# 2. 合并数据规格并获取数据
|
||||
# 3. 合并数据规格并获取数据
|
||||
all_specs = []
|
||||
for plan in plans:
|
||||
all_specs.extend(plan.data_specs)
|
||||
|
||||
# 3. 从路由器获取核心宽表
|
||||
# 去重数据规格(基于表名)
|
||||
seen_tables: set = set()
|
||||
unique_specs: List[DataSpec] = []
|
||||
for spec in all_specs:
|
||||
if spec.table not in seen_tables:
|
||||
seen_tables.add(spec.table)
|
||||
unique_specs.append(spec)
|
||||
|
||||
# 4. 从路由器获取核心宽表
|
||||
core_data = self.router.fetch_data(
|
||||
data_specs=all_specs,
|
||||
data_specs=unique_specs,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
stock_codes=stock_codes,
|
||||
@@ -295,14 +356,14 @@ class FactorEngine:
|
||||
if len(core_data) == 0:
|
||||
raise ValueError("未获取到任何数据,请检查日期范围和股票代码")
|
||||
|
||||
# 4. 按依赖顺序执行计算
|
||||
if len(plans) == 1:
|
||||
result = self.compute_engine.execute(plans[0], core_data)
|
||||
else:
|
||||
# 使用依赖感知的方式执行
|
||||
result = self._execute_with_dependencies(factor_names, core_data)
|
||||
# 5. 按依赖顺序执行计算(包含临时因子)
|
||||
result = self._execute_with_dependencies(all_factor_names, core_data)
|
||||
|
||||
return result
|
||||
# 6. 清理内存宽表,过滤掉临时因子列(__tmp_X)
|
||||
# 保留所有非临时因子列(包括原始数据列和用户请求的因子列)
|
||||
cols_to_keep = [col for col in result.columns if not col.startswith("__tmp_")]
|
||||
|
||||
return result.select(cols_to_keep)
|
||||
|
||||
def list_registered(self) -> List[str]:
|
||||
"""获取已注册的因子列表。
|
||||
@@ -501,10 +562,32 @@ class FactorEngine:
|
||||
|
||||
return False
|
||||
|
||||
def _find_factor_dependencies(self, expression: Node) -> Set[str]:
|
||||
"""查找表达式依赖的其他因子。
|
||||
def _collect_all_dependencies(self, factor_names: List[str]) -> List[str]:
|
||||
"""收集所有因子及其依赖(包括用户定义的因子和临时因子)。"""
|
||||
collected: Set[str] = set()
|
||||
result: List[str] = []
|
||||
|
||||
遍历已注册因子,检查表达式是否包含任何已注册因子的完整表达式。
|
||||
def collect_recursive(name: str):
|
||||
if name in collected:
|
||||
return
|
||||
collected.add(name)
|
||||
|
||||
# 获取执行计划并递归收集强依赖
|
||||
plan = self._plans.get(name)
|
||||
if plan:
|
||||
for dep_name in plan.factor_dependencies:
|
||||
collect_recursive(dep_name)
|
||||
|
||||
# 依赖收集完毕,再将自己加入列表(天然形成安全的计算顺序)
|
||||
result.append(name)
|
||||
|
||||
for name in factor_names:
|
||||
collect_recursive(name)
|
||||
|
||||
return result
|
||||
|
||||
def _find_factor_dependencies(self, expression: Node) -> Set[str]:
|
||||
"""查找表达式依赖的其他因子(包括临时因子和用户因子引用)。
|
||||
|
||||
Args:
|
||||
expression: 待检查的表达式
|
||||
@@ -514,13 +597,20 @@ class FactorEngine:
|
||||
"""
|
||||
deps: Set[str] = set()
|
||||
|
||||
# 检查表达式本身是否等于某个已注册因子
|
||||
# 1. 【新增】如果直接引用了已注册的因子名称(包含 __tmp_X 或用户因子)
|
||||
if (
|
||||
isinstance(expression, Symbol)
|
||||
and expression.name in self.registered_expressions
|
||||
):
|
||||
deps.add(expression.name)
|
||||
|
||||
# 2. 检查表达式本身是否等于某个已注册因子的完整 AST
|
||||
for name, registered_expr in self.registered_expressions.items():
|
||||
if self._expressions_equal(expression, registered_expr):
|
||||
deps.add(name)
|
||||
break
|
||||
|
||||
# 递归检查子节点
|
||||
# 3. 递归检查子节点
|
||||
if isinstance(expression, BinaryOpNode):
|
||||
deps.update(self._find_factor_dependencies(expression.left))
|
||||
deps.update(self._find_factor_dependencies(expression.right))
|
||||
|
||||
Reference in New Issue
Block a user