- 新增 ExpressionFlattener 类自动拆解嵌套窗口函数(如 cs_rank(ts_delay(close, 1)))
- 支持因子引用其他因子:engine.register("fac2", cs_rank("fac1"))
- 给 DependencyExtractor 增加 ignore_symbols 免疫名单,防止已注册因子被当作数据库字段
- 添加完整测试覆盖嵌套场景和数值一致性验证
116 lines
3.5 KiB
Python
116 lines
3.5 KiB
Python
"""执行计划生成器。
|
||
|
||
整合编译器和翻译器,生成完整的执行计划。
|
||
"""
|
||
|
||
from typing import Any, Dict, List, Optional, Set, Union
|
||
|
||
from src.factors.dsl import (
|
||
Node,
|
||
Symbol,
|
||
FunctionNode,
|
||
BinaryOpNode,
|
||
UnaryOpNode,
|
||
Constant,
|
||
)
|
||
from src.factors.compiler import DependencyExtractor
|
||
from src.factors.translator import PolarsTranslator
|
||
from src.factors.engine.data_spec import DataSpec, ExecutionPlan
|
||
from src.factors.engine.schema_cache import get_schema_cache
|
||
|
||
|
||
class ExecutionPlanner:
    """Build execution plans for factor expressions.

    Ties the dependency extractor and the Polars translator together so that
    one call produces a complete ``ExecutionPlan``.

    Attributes:
        compiler: ``DependencyExtractor`` instance used to collect the
            symbols an expression depends on.
        translator: ``PolarsTranslator`` instance used to translate the DSL
            tree into a Polars expression.
    """

    def __init__(self) -> None:
        """Create the extractor and translator this planner delegates to."""
        self.compiler = DependencyExtractor()
        self.translator = PolarsTranslator()

    def create_plan(
        self,
        expression: Node,
        output_name: str = "factor",
        data_specs: Optional[List[DataSpec]] = None,
        ignore_dependencies: Optional[Set[str]] = None,
    ) -> ExecutionPlan:
        """Create an execution plan from a DSL expression.

        Args:
            expression: Root node of the DSL expression.
            output_name: Name given to the resulting factor.
            data_specs: Pre-defined data specs. When ``None`` they are
                inferred from the extracted dependencies; any other value
                (including an empty list) is used unchanged.
            ignore_dependencies: Symbols forwarded to the extractor as
                ``ignore_symbols`` (e.g. names of already registered
                factors).

        Returns:
            The assembled ``ExecutionPlan``.
        """
        # Step 1: collect dependencies, forwarding the symbols to ignore.
        required = self.compiler.extract_dependencies(
            expression,
            ignore_symbols=ignore_dependencies,
        )

        # Step 2: translate the DSL tree into a Polars expression.
        translated = self.translator.translate(expression)

        # Step 3: caller-supplied specs win; inference only runs for None.
        if data_specs is not None:
            resolved_specs = data_specs
        else:
            resolved_specs = self._infer_data_specs(required, expression)

        return ExecutionPlan(
            data_specs=resolved_specs,
            polars_expr=translated,
            dependencies=required,
            output_name=output_name,
        )

    def _infer_data_specs(
        self,
        dependencies: Set[str],
        expression: Node,
    ) -> List[DataSpec]:
        """Infer data specs from the dependency set.

        Asks the schema cache to map each dependency to the table that
        provides it, then emits one ``DataSpec`` per table. Tables the cache
        reports as financial are joined with ``asof_backward`` on
        ``trade_date`` / ``f_ann_date``; every other table uses the
        ``DataSpec`` defaults.

        NOTE(review): the original documentation states that the table
        structure is scanned once and cached in memory. That behaviour lives
        in ``get_schema_cache`` and is not visible here - confirm there.

        Args:
            dependencies: Field names the expression depends on.
            expression: Expression node. Not read by this method; kept as
                part of the existing signature.

        Returns:
            One ``DataSpec`` per matched table, in the iteration order of
            the mapping returned by the schema cache.
        """
        cache = get_schema_cache()
        fields_by_table = cache.match_fields_to_tables(dependencies)

        # Extra join settings applied only to financial tables.
        financial_join = {
            "join_type": "asof_backward",
            "left_on": "trade_date",
            "right_on": "f_ann_date",
        }

        return [
            DataSpec(
                table=table,
                columns=fields,
                # Standard tables get no extra keyword arguments.
                **(financial_join if cache.is_financial_table(table) else {}),
            )
            for table, fields in fields_by_table.items()
        ]
|