"""因子计算引擎 - 系统统一入口。 提供从表达式到结果的完整执行链路,是研究员使用系统的唯一接口。 执行流程: 1. 注册表达式 -> 调用编译器解析依赖 2. 调用路由器连接数据库拉取并组装核心宽表 3. 调用翻译器生成物理执行计划 4. 将计划提交给计算引擎执行并行运算 5. 返回包含因子结果的数据表 """ from typing import Any, Dict, List, Optional, Set, Union, TYPE_CHECKING import polars as pl if TYPE_CHECKING: from src.factors.registry import FunctionRegistry from src.factors.metadata import FactorManager from src.factors.dsl import ( Node, Symbol, BinaryOpNode, UnaryOpNode, FunctionNode, ) from src.factors.translator import PolarsTranslator from src.factors.engine.data_spec import DataSpec, ExecutionPlan from src.factors.engine.data_router import DataRouter from src.factors.engine.planner import ExecutionPlanner from src.factors.engine.compute_engine import ComputeEngine from src.factors.engine.ast_optimizer import ExpressionFlattener class FactorEngine: """因子计算引擎 - 系统统一入口。 提供从表达式到结果的完整执行链路,是研究员使用系统的唯一接口。 执行流程: 1. 注册表达式 -> 调用编译器解析依赖 2. 调用路由器连接数据库拉取并组装核心宽表 3. 调用翻译器生成物理执行计划 4. 将计划提交给计算引擎执行并行运算 5. 返回包含因子结果的数据表 Attributes: router: 数据路由器 planner: 执行计划生成器 compute_engine: 计算引擎 registered_expressions: 注册的表达式字典 _registry: 函数注册表 _parser: 公式解析器 """ def __init__( self, data_source: Optional[Dict[str, pl.DataFrame]] = None, registry: Optional["FunctionRegistry"] = None, ) -> None: """初始化因子引擎。 Args: data_source: 内存数据源,为 None 时使用数据库连接 registry: 函数注册表,None 时创建独立实例 """ from src.factors.registry import FunctionRegistry from src.factors.parser import FormulaParser self.router = DataRouter(data_source) self.planner = ExecutionPlanner() self.compute_engine = ComputeEngine() self.registered_expressions: Dict[str, Node] = {} self._plans: Dict[str, ExecutionPlan] = {} # 初始化注册表和解析器(支持注入外部注册表实现共享) self._registry = registry if registry is not None else FunctionRegistry() self._parser = FormulaParser(self._registry) # 初始化 metadata 管理器(使用默认路径) from src.factors.metadata import FactorManager self._metadata = FactorManager() def _register_internal( self, name: str, expression: Node, data_specs: Optional[List[DataSpec]] = None, ) -> "FactorEngine": """内部注册方法,直接注册因子表达式。 Args: name: 因子名称 expression: DSL 表达式 data_specs: 数据规格,None 时自动推导 Returns: self,支持链式调用 """ # 检测因子依赖(在注册当前因子之前检查其他已注册因子) factor_deps = self._find_factor_dependencies(expression) # 获取当前所有已注册的因子名称(作为免疫名单,防止被当作数据库字段) known_factors = set(self.registered_expressions.keys()) self.registered_expressions[name] = expression # 预创建执行计划,过滤掉已注册的因子,防止被当作数据库字段 plan = self.planner.create_plan( expression=expression, output_name=name, data_specs=data_specs, ignore_dependencies=known_factors, ) # 添加因子依赖信息 plan.factor_dependencies = factor_deps # 如果数据规格为空,继承依赖因子(包括临时因子)的数据规格 if not plan.data_specs and factor_deps: merged_specs: List[DataSpec] = [] for dep_name in factor_deps: if dep_name in self._plans: merged_specs.extend(self._plans[dep_name].data_specs) # 去重(基于表名) seen_tables: set = set() unique_specs: List[DataSpec] = [] for spec in merged_specs: if spec.table not in seen_tables: seen_tables.add(spec.table) unique_specs.append(spec) plan.data_specs = unique_specs self._plans[name] = plan return self def register( self, name: str, expression: Node, data_specs: Optional[List[DataSpec]] = None, ) -> "FactorEngine": """注册因子表达式(自动处理嵌套窗口函数)。 Args: name: 因子名称 expression: DSL 表达式 data_specs: 数据规格,None 时自动推导 Returns: self,支持链式调用 Example: >>> from src.factors.api import close, ts_mean >>> engine = FactorEngine() >>> engine.register("ma20", ts_mean(close, 20)) """ # 使用 AST 优化器拍平嵌套窗口函数 flattener = ExpressionFlattener() flat_expression, tmp_factors = flattener.flatten(expression) # 先注册所有临时因子(自动推导数据规格) for tmp_name, tmp_node in tmp_factors.items(): self._register_internal(tmp_name, tmp_node, data_specs=None) # 最后注册主因子 self._register_internal(name, flat_expression, data_specs) return self def _add_factor_from_metadata( self, name: str, factor_name_in_metadata: str, data_specs: Optional[List[DataSpec]] = None, ) -> "FactorEngine": """从 metadata 中查询并注册因子(内部方法)。 Args: name: 要注册的因子名称(引擎中使用的名称) factor_name_in_metadata: metadata 中的因子名称 data_specs: 可选的数据规格 Returns: self,支持链式调用 Raises: RuntimeError: 当引擎未配置 metadata 路径时 ValueError: 当在 metadata 中未找到因子时 FormulaParseError: 当 DSL 表达式解析失败时 """ if self._metadata is None: raise RuntimeError( "引擎未配置 metadata 路径。请在初始化时传入 metadata_path 参数," + "例如:FactorEngine(metadata_path='data/factors.jsonl')" ) # 从 metadata 查询因子 df = self._metadata.get_factors_by_name(factor_name_in_metadata) if len(df) == 0: raise ValueError( f"在 metadata 中未找到因子 '{factor_name_in_metadata}'。" + "请确认因子名称正确,或先使用 FactorManager 添加该因子。" ) # 获取 DSL 表达式 dsl_expr = df["dsl"][0] # 解析表达式为 Node node = self._parser.parse(dsl_expr) # 委托给 register 方法(register 会处理嵌套窗口函数拍平) return self.register(name, node, data_specs) def add_factor( self, name: str, expression: Optional[Union[str, Node]] = None, data_specs: Optional[List[DataSpec]] = None, ) -> "FactorEngine": """注册因子(支持多种调用方式)。 这是 register 方法的增强版,支持以下调用方式: 1. 传入 name 和 expression:直接注册表达式(字符串或 Node) 2. 只传入 name:从 metadata 中查询表达式并注册 遵循 Fail-Fast 原则:字符串表达式会立即解析,失败时立即抛出异常。 Args: name: 因子名称(引擎中使用的名称) expression: 字符串表达式或 Node 对象,为 None 时从 metadata 查询 data_specs: 可选的数据规格 Returns: self,支持链式调用 Raises: TypeError: 当 expression 类型不支持时 FormulaParseError: 当字符串解析失败时(立即报错) RuntimeError: 当 expression 为 None 但未配置 metadata 时 ValueError: 当在 metadata 中未找到因子时 Example: >>> engine = FactorEngine() >>> >>> # 方式1:字符串表达式 >>> engine.add_factor("ma20", "ts_mean(close, 20)") >>> >>> # 方式2:Node 表达式 >>> from src.factors.api import close, ts_mean >>> engine.add_factor("ma20", ts_mean(close, 20)) >>> >>> # 方式3:从 metadata 查询(需要初始化时配置 metadata_path) >>> engine.add_factor("return_5") # 从 metadata 查询名为 return_5 的因子 >>> >>> # 链式调用 >>> (engine ... .add_factor("ma5", "ts_mean(close, 5)") ... .add_factor("ma10", "ts_mean(close, 10)") ... .add_factor("golden_cross", "ma5 > ma10")) """ if expression is None: # 从 metadata 查询表达式 return self._add_factor_from_metadata(name, name, data_specs) if isinstance(expression, str): # Fail-Fast:立即解析,失败立即报错 node = self._parser.parse(expression) elif isinstance(expression, Node): node = expression else: raise TypeError( f"表达式必须是 str 或 Node 类型,收到 {type(expression).__name__}" ) # 委托给现有的 register 方法 return self.register(name, node, data_specs) def compute( self, factor_names: Union[str, List[str]], start_date: str, end_date: str, stock_codes: Optional[List[str]] = None, ) -> pl.DataFrame: """计算指定因子的值。 完整的执行流程:取数 -> 组装 -> 翻译 -> 计算。 Args: factor_names: 因子名称或名称列表 start_date: 开始日期 (YYYYMMDD) end_date: 结束日期 (YYYYMMDD) stock_codes: 股票代码列表,None 表示全市场 Returns: 包含因子结果的数据表 Raises: ValueError: 当因子未注册或数据不足时 Example: >>> result = engine.compute("ma20", "20240101", "20240131") >>> result = engine.compute(["ma20", "rsi"], "20240101", "20240131") """ # 标准化因子名称 if isinstance(factor_names, str): factor_names = [factor_names] # 1. 收集所有需要的因子(包括临时因子依赖) all_factor_names = self._collect_all_dependencies(factor_names) # 2. 获取执行计划 plans = [] for name in all_factor_names: if name not in self._plans: raise ValueError(f"因子未注册: {name}") plans.append(self._plans[name]) # 3. 合并数据规格并获取数据 all_specs = [] for plan in plans: all_specs.extend(plan.data_specs) # 合并相同表的字段(而不是简单地去重) table_to_columns: Dict[str, Set[str]] = {} table_to_spec: Dict[str, DataSpec] = {} for spec in all_specs: if spec.table not in table_to_columns: table_to_columns[spec.table] = set() table_to_spec[spec.table] = spec table_to_columns[spec.table].update(spec.columns) # 创建合并后的数据规格 unique_specs: List[DataSpec] = [] for table_name, columns in table_to_columns.items(): original_spec = table_to_spec[table_name] unique_specs.append( DataSpec( table=table_name, columns=list(columns), join_type=original_spec.join_type, left_on=original_spec.left_on, right_on=original_spec.right_on, ) ) # 4. 从路由器获取核心宽表 core_data = self.router.fetch_data( data_specs=unique_specs, start_date=start_date, end_date=end_date, stock_codes=stock_codes, ) if len(core_data) == 0: raise ValueError("未获取到任何数据,请检查日期范围和股票代码") # 5. 按依赖顺序执行计算(包含临时因子) result = self._execute_with_dependencies(all_factor_names, core_data) # 6. 清理内存宽表,过滤掉临时因子列(__tmp_X) # 保留所有非临时因子列(包括原始数据列和用户请求的因子列) cols_to_keep = [col for col in result.columns if not col.startswith("__tmp_")] return result.select(cols_to_keep) def list_registered(self) -> List[str]: """获取已注册的因子列表。 Returns: 因子名称列表 """ return list(self.registered_expressions.keys()) def get_expression(self, name: str) -> Optional[Node]: """获取已注册的表达式。 Args: name: 因子名称 Returns: 表达式节点,未注册时返回 None """ return self.registered_expressions.get(name) def clear(self) -> None: """清除所有注册的表达式和缓存。""" self.registered_expressions.clear() self._plans.clear() self.router.clear_cache() def preview_plan(self, factor_name: str) -> Optional[ExecutionPlan]: """预览因子的执行计划。 Args: factor_name: 因子名称 Returns: 执行计划,未注册时返回 None """ return self._plans.get(factor_name) def _execute_with_dependencies( self, factor_names: List[str], core_data: pl.DataFrame, ) -> pl.DataFrame: """按依赖顺序执行因子计算。 支持 cs_rank 等需要依赖列已存在的场景。 Args: factor_names: 因子名称列表 core_data: 核心宽表数据 Returns: 包含所有因子结果的数据表 """ # 1. 拓扑排序 sorted_names = self._topological_sort(factor_names) # 2. 按顺序执行 result = core_data for name in sorted_names: plan = self._plans[name] # 创建新的执行计划,引用已计算的依赖列 new_plan = self._create_optimized_plan(plan, result) # 执行计算 result = self.compute_engine.execute(new_plan, result) return result def _create_optimized_plan( self, plan: ExecutionPlan, current_data: pl.DataFrame, ) -> ExecutionPlan: """创建优化的执行计划。 将表达式中已计算的依赖因子替换为列引用。 Args: plan: 原始执行计划 current_data: 当前数据(包含已计算的依赖列) Returns: 新的执行计划 """ from src.factors.dsl import Symbol # 获取当前数据中已存在的列 existing_cols = set(current_data.columns) # 检查依赖列是否已存在 deps_available = plan.factor_dependencies & existing_cols if not deps_available: # 没有可用的依赖列,直接返回原计划 return plan # 获取原始表达式 original_expr = self.registered_expressions[plan.output_name] # 创建新的表达式,用 Symbol 引用替换依赖因子 def replace_with_symbol(node: Node) -> Node: """递归替换表达式中的依赖因子为 Symbol 引用。""" from typing import Any n: Any = node # 检查当前节点是否等于某个已计算依赖因子 for dep_name in deps_available: dep_expr = self.registered_expressions[dep_name] if self._expressions_equal(node, dep_expr): return Symbol(dep_name) # 递归处理子节点 if isinstance(n, BinaryOpNode): new_left = replace_with_symbol(n.left) new_right = replace_with_symbol(n.right) if new_left is not n.left or new_right is not n.right: return BinaryOpNode(n.op, new_left, new_right) elif isinstance(n, UnaryOpNode): new_operand = replace_with_symbol(n.operand) if new_operand is not n.operand: return UnaryOpNode(n.op, new_operand) elif isinstance(n, FunctionNode): new_args = [replace_with_symbol(arg) for arg in n.args] if any( new_arg is not old_arg for new_arg, old_arg in zip(new_args, n.args) ): return FunctionNode(n.func_name, *new_args) return node # 替换表达式 new_expr = replace_with_symbol(original_expr) # 重新翻译表达式 translator = PolarsTranslator() new_polars_expr = translator.translate(new_expr) # 更新依赖集合 new_factor_deps = plan.factor_dependencies - deps_available new_deps = plan.dependencies | deps_available return ExecutionPlan( data_specs=plan.data_specs, polars_expr=new_polars_expr, dependencies=new_deps, output_name=plan.output_name, factor_dependencies=new_factor_deps, ) def _expressions_equal(self, expr1: Node, expr2: Node) -> bool: """比较两个表达式是否相等。 用于检测因子间的依赖关系。 Args: expr1: 第一个表达式 expr2: 第二个表达式 Returns: 是否相等 """ from typing import Any e1: Any = expr1 e2: Any = expr2 if type(e1) != type(e2): return False if isinstance(e1, Symbol): return e1.name == e2.name from src.factors.dsl import Constant if isinstance(e1, Constant): return e1.value == e2.value if isinstance(e1, BinaryOpNode): return ( e1.op == e2.op and self._expressions_equal(e1.left, e2.left) and self._expressions_equal(e1.right, e2.right) ) if isinstance(e1, UnaryOpNode): return e1.op == e2.op and self._expressions_equal(e1.operand, e2.operand) if isinstance(e1, FunctionNode): if e1.func_name != e2.func_name or len(e1.args) != len(e2.args): return False return all( self._expressions_equal(a1, a2) for a1, a2 in zip(e1.args, e2.args) ) return False def _collect_all_dependencies(self, factor_names: List[str]) -> List[str]: """收集所有因子及其依赖(包括用户定义的因子和临时因子)。""" collected: Set[str] = set() result: List[str] = [] def collect_recursive(name: str): if name in collected: return collected.add(name) # 获取执行计划并递归收集强依赖 plan = self._plans.get(name) if plan: for dep_name in plan.factor_dependencies: collect_recursive(dep_name) # 依赖收集完毕,再将自己加入列表(天然形成安全的计算顺序) result.append(name) for name in factor_names: collect_recursive(name) return result def _find_factor_dependencies(self, expression: Node) -> Set[str]: """查找表达式依赖的其他因子(包括临时因子和用户因子引用)。 Args: expression: 待检查的表达式 Returns: 依赖的因子名称集合 """ deps: Set[str] = set() # 1. 【新增】如果直接引用了已注册的因子名称(包含 __tmp_X 或用户因子) if ( isinstance(expression, Symbol) and expression.name in self.registered_expressions ): deps.add(expression.name) # 2. 检查表达式本身是否等于某个已注册因子的完整 AST for name, registered_expr in self.registered_expressions.items(): if self._expressions_equal(expression, registered_expr): deps.add(name) break # 3. 递归检查子节点 if isinstance(expression, BinaryOpNode): deps.update(self._find_factor_dependencies(expression.left)) deps.update(self._find_factor_dependencies(expression.right)) elif isinstance(expression, UnaryOpNode): deps.update(self._find_factor_dependencies(expression.operand)) elif isinstance(expression, FunctionNode): for arg in expression.args: deps.update(self._find_factor_dependencies(arg)) return deps def _topological_sort(self, factor_names: List[str]) -> List[str]: """按依赖关系对因子进行拓扑排序。 确保依赖的因子先被计算。 Args: factor_names: 因子名称列表 Returns: 排序后的因子名称列表 Raises: ValueError: 当检测到循环依赖时 """ # 构建依赖图 graph: Dict[str, Set[str]] = {} in_degree: Dict[str, int] = {} for name in factor_names: plan = self._plans[name] # 只考虑在本次计算范围内的依赖 deps = plan.factor_dependencies & set(factor_names) graph[name] = deps in_degree[name] = len(deps) # Kahn 算法 result = [] queue = [name for name, degree in in_degree.items() if degree == 0] while queue: # 按原始顺序处理同级别的因子 queue.sort(key=lambda x: factor_names.index(x)) name = queue.pop(0) result.append(name) for other in factor_names: if name in graph[other]: in_degree[other] -= 1 if in_degree[other] == 0: queue.append(other) if len(result) != len(factor_names): raise ValueError("检测到因子循环依赖") return result