"""因子计算引擎 - 系统统一入口。 提供从表达式到结果的完整执行链路,是研究员使用系统的唯一接口。 执行流程: 1. 注册表达式 -> 调用编译器解析依赖 2. 调用路由器连接数据库拉取并组装核心宽表 3. 调用翻译器生成物理执行计划 4. 将计划提交给计算引擎执行并行运算 5. 返回包含因子结果的数据表 """ from typing import Any, Dict, List, Optional, Set, Union import polars as pl from src.factors.dsl import ( Node, Symbol, BinaryOpNode, UnaryOpNode, FunctionNode, ) from src.factors.translator import PolarsTranslator from src.factors.engine.data_spec import DataSpec, ExecutionPlan from src.factors.engine.data_router import DataRouter from src.factors.engine.planner import ExecutionPlanner from src.factors.engine.compute_engine import ComputeEngine class FactorEngine: """因子计算引擎 - 系统统一入口。 提供从表达式到结果的完整执行链路,是研究员使用系统的唯一接口。 执行流程: 1. 注册表达式 -> 调用编译器解析依赖 2. 调用路由器连接数据库拉取并组装核心宽表 3. 调用翻译器生成物理执行计划 4. 将计划提交给计算引擎执行并行运算 5. 返回包含因子结果的数据表 Attributes: router: 数据路由器 planner: 执行计划生成器 compute_engine: 计算引擎 registered_expressions: 注册的表达式字典 """ def __init__( self, data_source: Optional[Dict[str, pl.DataFrame]] = None, max_workers: int = 4, ) -> None: """初始化因子引擎。 Args: data_source: 内存数据源,为 None 时使用数据库连接 max_workers: 并行计算的最大工作线程数 """ self.router = DataRouter(data_source) self.planner = ExecutionPlanner() self.compute_engine = ComputeEngine(max_workers=max_workers) self.registered_expressions: Dict[str, Node] = {} self._plans: Dict[str, ExecutionPlan] = {} def register( self, name: str, expression: Node, data_specs: Optional[List[DataSpec]] = None, ) -> "FactorEngine": """注册因子表达式。 Args: name: 因子名称 expression: DSL 表达式 data_specs: 数据规格,None 时自动推导 Returns: self,支持链式调用 Example: >>> from src.factors.api import close, ts_mean >>> engine = FactorEngine() >>> engine.register("ma20", ts_mean(close, 20)) """ # 检测因子依赖(在注册当前因子之前检查其他已注册因子) factor_deps = self._find_factor_dependencies(expression) self.registered_expressions[name] = expression # 预创建执行计划 plan = self.planner.create_plan( expression=expression, output_name=name, data_specs=data_specs, ) # 添加因子依赖信息 plan.factor_dependencies = factor_deps self._plans[name] = plan return self def compute( self, factor_names: Union[str, List[str]], start_date: str, end_date: str, stock_codes: Optional[List[str]] = None, ) -> pl.DataFrame: """计算指定因子的值。 完整的执行流程:取数 -> 组装 -> 翻译 -> 计算。 Args: factor_names: 因子名称或名称列表 start_date: 开始日期 (YYYYMMDD) end_date: 结束日期 (YYYYMMDD) stock_codes: 股票代码列表,None 表示全市场 Returns: 包含因子结果的数据表 Raises: ValueError: 当因子未注册或数据不足时 Example: >>> result = engine.compute("ma20", "20240101", "20240131") >>> result = engine.compute(["ma20", "rsi"], "20240101", "20240131") """ # 标准化因子名称 if isinstance(factor_names, str): factor_names = [factor_names] # 1. 获取执行计划 plans = [] for name in factor_names: if name not in self._plans: raise ValueError(f"因子未注册: {name}") plans.append(self._plans[name]) # 2. 合并数据规格并获取数据 all_specs = [] for plan in plans: all_specs.extend(plan.data_specs) # 3. 从路由器获取核心宽表 core_data = self.router.fetch_data( data_specs=all_specs, start_date=start_date, end_date=end_date, stock_codes=stock_codes, ) if len(core_data) == 0: raise ValueError("未获取到任何数据,请检查日期范围和股票代码") # 4. 按依赖顺序执行计算 if len(plans) == 1: result = self.compute_engine.execute(plans[0], core_data) else: # 使用依赖感知的方式执行 result = self._execute_with_dependencies(factor_names, core_data) return result def list_registered(self) -> List[str]: """获取已注册的因子列表。 Returns: 因子名称列表 """ return list(self.registered_expressions.keys()) def get_expression(self, name: str) -> Optional[Node]: """获取已注册的表达式。 Args: name: 因子名称 Returns: 表达式节点,未注册时返回 None """ return self.registered_expressions.get(name) def clear(self) -> None: """清除所有注册的表达式和缓存。""" self.registered_expressions.clear() self._plans.clear() self.router.clear_cache() def preview_plan(self, factor_name: str) -> Optional[ExecutionPlan]: """预览因子的执行计划。 Args: factor_name: 因子名称 Returns: 执行计划,未注册时返回 None """ return self._plans.get(factor_name) def _execute_with_dependencies( self, factor_names: List[str], core_data: pl.DataFrame, ) -> pl.DataFrame: """按依赖顺序执行因子计算。 支持 cs_rank 等需要依赖列已存在的场景。 Args: factor_names: 因子名称列表 core_data: 核心宽表数据 Returns: 包含所有因子结果的数据表 """ # 1. 拓扑排序 sorted_names = self._topological_sort(factor_names) # 2. 按顺序执行 result = core_data for name in sorted_names: plan = self._plans[name] # 创建新的执行计划,引用已计算的依赖列 new_plan = self._create_optimized_plan(plan, result) # 执行计算 result = self.compute_engine.execute(new_plan, result) return result def _create_optimized_plan( self, plan: ExecutionPlan, current_data: pl.DataFrame, ) -> ExecutionPlan: """创建优化的执行计划。 将表达式中已计算的依赖因子替换为列引用。 Args: plan: 原始执行计划 current_data: 当前数据(包含已计算的依赖列) Returns: 新的执行计划 """ from src.factors.dsl import Symbol # 获取当前数据中已存在的列 existing_cols = set(current_data.columns) # 检查依赖列是否已存在 deps_available = plan.factor_dependencies & existing_cols if not deps_available: # 没有可用的依赖列,直接返回原计划 return plan # 获取原始表达式 original_expr = self.registered_expressions[plan.output_name] # 创建新的表达式,用 Symbol 引用替换依赖因子 def replace_with_symbol(node: Node) -> Node: """递归替换表达式中的依赖因子为 Symbol 引用。""" from typing import Any n: Any = node # 检查当前节点是否等于某个已计算依赖因子 for dep_name in deps_available: dep_expr = self.registered_expressions[dep_name] if self._expressions_equal(node, dep_expr): return Symbol(dep_name) # 递归处理子节点 if isinstance(n, BinaryOpNode): new_left = replace_with_symbol(n.left) new_right = replace_with_symbol(n.right) if new_left is not n.left or new_right is not n.right: return BinaryOpNode(n.op, new_left, new_right) elif isinstance(n, UnaryOpNode): new_operand = replace_with_symbol(n.operand) if new_operand is not n.operand: return UnaryOpNode(n.op, new_operand) elif isinstance(n, FunctionNode): new_args = [replace_with_symbol(arg) for arg in n.args] if any( new_arg is not old_arg for new_arg, old_arg in zip(new_args, n.args) ): return FunctionNode(n.func_name, *new_args) return node # 替换表达式 new_expr = replace_with_symbol(original_expr) # 重新翻译表达式 translator = PolarsTranslator() new_polars_expr = translator.translate(new_expr) # 更新依赖集合 new_factor_deps = plan.factor_dependencies - deps_available new_deps = plan.dependencies | deps_available return ExecutionPlan( data_specs=plan.data_specs, polars_expr=new_polars_expr, dependencies=new_deps, output_name=plan.output_name, factor_dependencies=new_factor_deps, ) def _expressions_equal(self, expr1: Node, expr2: Node) -> bool: """比较两个表达式是否相等。 用于检测因子间的依赖关系。 Args: expr1: 第一个表达式 expr2: 第二个表达式 Returns: 是否相等 """ from typing import Any e1: Any = expr1 e2: Any = expr2 if type(e1) != type(e2): return False if isinstance(e1, Symbol): return e1.name == e2.name from src.factors.dsl import Constant if isinstance(e1, Constant): return e1.value == e2.value if isinstance(e1, BinaryOpNode): return ( e1.op == e2.op and self._expressions_equal(e1.left, e2.left) and self._expressions_equal(e1.right, e2.right) ) if isinstance(e1, UnaryOpNode): return e1.op == e2.op and self._expressions_equal(e1.operand, e2.operand) if isinstance(e1, FunctionNode): if e1.func_name != e2.func_name or len(e1.args) != len(e2.args): return False return all( self._expressions_equal(a1, a2) for a1, a2 in zip(e1.args, e2.args) ) return False def _find_factor_dependencies(self, expression: Node) -> Set[str]: """查找表达式依赖的其他因子。 遍历已注册因子,检查表达式是否包含任何已注册因子的完整表达式。 Args: expression: 待检查的表达式 Returns: 依赖的因子名称集合 """ deps: Set[str] = set() # 检查表达式本身是否等于某个已注册因子 for name, registered_expr in self.registered_expressions.items(): if self._expressions_equal(expression, registered_expr): deps.add(name) break # 递归检查子节点 if isinstance(expression, BinaryOpNode): deps.update(self._find_factor_dependencies(expression.left)) deps.update(self._find_factor_dependencies(expression.right)) elif isinstance(expression, UnaryOpNode): deps.update(self._find_factor_dependencies(expression.operand)) elif isinstance(expression, FunctionNode): for arg in expression.args: deps.update(self._find_factor_dependencies(arg)) return deps def _topological_sort(self, factor_names: List[str]) -> List[str]: """按依赖关系对因子进行拓扑排序。 确保依赖的因子先被计算。 Args: factor_names: 因子名称列表 Returns: 排序后的因子名称列表 Raises: ValueError: 当检测到循环依赖时 """ # 构建依赖图 graph: Dict[str, Set[str]] = {} in_degree: Dict[str, int] = {} for name in factor_names: plan = self._plans[name] # 只考虑在本次计算范围内的依赖 deps = plan.factor_dependencies & set(factor_names) graph[name] = deps in_degree[name] = len(deps) # Kahn 算法 result = [] queue = [name for name, degree in in_degree.items() if degree == 0] while queue: # 按原始顺序处理同级别的因子 queue.sort(key=lambda x: factor_names.index(x)) name = queue.pop(0) result.append(name) for other in factor_names: if name in graph[other]: in_degree[other] -= 1 if in_degree[other] == 0: queue.append(other) if len(result) != len(factor_names): raise ValueError("检测到因子循环依赖") return result