"""因子计算引擎 - 系统统一入口。 提供从表达式到结果的完整执行链路,是研究员使用系统的唯一接口。 执行流程: 1. 注册表达式 -> 调用编译器解析依赖 2. 调用路由器连接数据库拉取并组装核心宽表 3. 调用翻译器生成物理执行计划 4. 将计划提交给计算引擎执行并行运算 5. 返回包含因子结果的数据表 """ from typing import Any, Dict, List, Optional, Set, Union, TYPE_CHECKING import polars as pl if TYPE_CHECKING: from src.factors.registry import FunctionRegistry from src.factors.metadata import FactorManager from src.factors.dsl import ( Node, Symbol, BinaryOpNode, UnaryOpNode, FunctionNode, ) from src.factors.translator import PolarsTranslator from src.factors.engine.data_spec import DataSpec, ExecutionPlan from src.factors.engine.data_router import DataRouter from src.factors.engine.planner import ExecutionPlanner from src.factors.engine.compute_engine import ComputeEngine class FactorEngine: """因子计算引擎 - 系统统一入口。 提供从表达式到结果的完整执行链路,是研究员使用系统的唯一接口。 执行流程: 1. 注册表达式 -> 调用编译器解析依赖 2. 调用路由器连接数据库拉取并组装核心宽表 3. 调用翻译器生成物理执行计划 4. 将计划提交给计算引擎执行并行运算 5. 返回包含因子结果的数据表 Attributes: router: 数据路由器 planner: 执行计划生成器 compute_engine: 计算引擎 registered_expressions: 注册的表达式字典 _registry: 函数注册表 _parser: 公式解析器 """ def __init__( self, data_source: Optional[Dict[str, pl.DataFrame]] = None, max_workers: int = 4, registry: Optional["FunctionRegistry"] = None, metadata_path: Optional[str] = None, ) -> None: """初始化因子引擎。 Args: data_source: 内存数据源,为 None 时使用数据库连接 max_workers: 并行计算的最大工作线程数 registry: 函数注册表,None 时创建独立实例 metadata_path: 因子元数据文件路径,为 None 时不启用 metadata 功能 """ from src.factors.registry import FunctionRegistry from src.factors.parser import FormulaParser self.router = DataRouter(data_source) self.planner = ExecutionPlanner() self.compute_engine = ComputeEngine(max_workers=max_workers) self.registered_expressions: Dict[str, Node] = {} self._plans: Dict[str, ExecutionPlan] = {} # 初始化注册表和解析器(支持注入外部注册表实现共享) self._registry = registry if registry is not None else FunctionRegistry() self._parser = FormulaParser(self._registry) # 初始化 metadata 管理器(可选) self._metadata: Optional["FactorManager"] = None if metadata_path is not None: from src.factors.metadata import FactorManager self._metadata = FactorManager(metadata_path) def register( self, name: str, expression: Node, data_specs: Optional[List[DataSpec]] = None, ) -> "FactorEngine": """注册因子表达式。 Args: name: 因子名称 expression: DSL 表达式 data_specs: 数据规格,None 时自动推导 Returns: self,支持链式调用 Example: >>> from src.factors.api import close, ts_mean >>> engine = FactorEngine() >>> engine.register("ma20", ts_mean(close, 20)) """ # 检测因子依赖(在注册当前因子之前检查其他已注册因子) factor_deps = self._find_factor_dependencies(expression) self.registered_expressions[name] = expression # 预创建执行计划 plan = self.planner.create_plan( expression=expression, output_name=name, data_specs=data_specs, ) # 添加因子依赖信息 plan.factor_dependencies = factor_deps self._plans[name] = plan return self def add_factor( self, name: str, expression: Union[str, Node], data_specs: Optional[List[DataSpec]] = None, ) -> "FactorEngine": """注册因子(支持字符串或 Node 表达式)。 这是 register 方法的增强版,支持字符串表达式解析。 向后兼容:register 方法保持不变,继续只接受 Node 类型。 遵循 Fail-Fast 原则:字符串表达式会立即解析,失败时立即抛出异常。 Args: name: 因子名称 expression: 字符串表达式或 Node 对象 data_specs: 可选的数据规格 Returns: self,支持链式调用 Raises: TypeError: 当 expression 类型不支持时 FormulaParseError: 当字符串解析失败时(立即报错) Example: >>> engine = FactorEngine() >>> >>> # 字符串方式(新功能) >>> engine.add_factor("ma20", "ts_mean(close, 20)") >>> >>> # Node 方式(与 register 相同) >>> from src.factors.api import close, ts_mean >>> engine.add_factor("ma20", ts_mean(close, 20)) >>> >>> # 复杂表达式 >>> engine.add_factor("alpha1", "cs_rank(close / open)") >>> >>> # 链式调用 >>> (engine ... .add_factor("ma5", "ts_mean(close, 5)") ... .add_factor("ma10", "ts_mean(close, 10)") ... .add_factor("golden_cross", "ma5 > ma10")) """ if isinstance(expression, str): # Fail-Fast:立即解析,失败立即报错 node = self._parser.parse(expression) elif isinstance(expression, Node): node = expression else: raise TypeError( f"表达式必须是 str 或 Node 类型,收到 {type(expression).__name__}" ) # 委托给现有的 register 方法 return self.register(name, node, data_specs) def add_factor_by_name( self, name: str, factor_name_in_metadata: Optional[str] = None, data_specs: Optional[List[DataSpec]] = None, ) -> "FactorEngine": """根据 metadata 中的因子名称注册因子。 从 metadata 管理器中根据因子名称查询 DSL 表达式, 然后解析并注册到引擎中。 Args: name: 要注册的因子名称(引擎中使用的名称) factor_name_in_metadata: metadata 中的因子名称, 为 None 时默认使用 name 参数 data_specs: 可选的数据规格 Returns: self,支持链式调用 Raises: RuntimeError: 当引擎未配置 metadata 路径时 ValueError: 当在 metadata 中未找到因子时 FormulaParseError: 当 DSL 表达式解析失败时 Example: >>> # 初始化时启用 metadata >>> engine = FactorEngine(metadata_path="data/factors.jsonl") >>> >>> # 注册 metadata 中的因子(使用相同名称) >>> engine.add_factor_by_name("return_5") >>> >>> # 使用不同名称注册 >>> engine.add_factor_by_name("my_mom", "momentum_5d") >>> >>> # 链式调用 >>> (engine ... .add_factor_by_name("ma20") ... .add_factor_by_name("rsi14") ... .compute(["ma20", "rsi14"], "20240101", "20240131")) """ if self._metadata is None: raise RuntimeError( "引擎未配置 metadata 路径。请在初始化时传入 metadata_path 参数," + "例如:FactorEngine(metadata_path='data/factors.jsonl')" ) # 使用传入的名称或默认使用 name query_name = ( factor_name_in_metadata if factor_name_in_metadata is not None else name ) # 从 metadata 查询因子 df = self._metadata.get_factors_by_name(query_name) if len(df) == 0: raise ValueError( f"在 metadata 中未找到因子 '{query_name}'。" + "请确认因子名称正确,或先使用 FactorManager 添加该因子。" ) # 获取 DSL 表达式 dsl_expr = df["dsl"][0] # 解析表达式为 Node node = self._parser.parse(dsl_expr) # 委托给 register 方法 return self.register(name, node, data_specs) def compute( self, factor_names: Union[str, List[str]], start_date: str, end_date: str, stock_codes: Optional[List[str]] = None, ) -> pl.DataFrame: """计算指定因子的值。 完整的执行流程:取数 -> 组装 -> 翻译 -> 计算。 Args: factor_names: 因子名称或名称列表 start_date: 开始日期 (YYYYMMDD) end_date: 结束日期 (YYYYMMDD) stock_codes: 股票代码列表,None 表示全市场 Returns: 包含因子结果的数据表 Raises: ValueError: 当因子未注册或数据不足时 Example: >>> result = engine.compute("ma20", "20240101", "20240131") >>> result = engine.compute(["ma20", "rsi"], "20240101", "20240131") """ # 标准化因子名称 if isinstance(factor_names, str): factor_names = [factor_names] # 1. 获取执行计划 plans = [] for name in factor_names: if name not in self._plans: raise ValueError(f"因子未注册: {name}") plans.append(self._plans[name]) # 2. 合并数据规格并获取数据 all_specs = [] for plan in plans: all_specs.extend(plan.data_specs) # 3. 从路由器获取核心宽表 core_data = self.router.fetch_data( data_specs=all_specs, start_date=start_date, end_date=end_date, stock_codes=stock_codes, ) if len(core_data) == 0: raise ValueError("未获取到任何数据,请检查日期范围和股票代码") # 4. 按依赖顺序执行计算 if len(plans) == 1: result = self.compute_engine.execute(plans[0], core_data) else: # 使用依赖感知的方式执行 result = self._execute_with_dependencies(factor_names, core_data) return result def list_registered(self) -> List[str]: """获取已注册的因子列表。 Returns: 因子名称列表 """ return list(self.registered_expressions.keys()) def get_expression(self, name: str) -> Optional[Node]: """获取已注册的表达式。 Args: name: 因子名称 Returns: 表达式节点,未注册时返回 None """ return self.registered_expressions.get(name) def clear(self) -> None: """清除所有注册的表达式和缓存。""" self.registered_expressions.clear() self._plans.clear() self.router.clear_cache() def preview_plan(self, factor_name: str) -> Optional[ExecutionPlan]: """预览因子的执行计划。 Args: factor_name: 因子名称 Returns: 执行计划,未注册时返回 None """ return self._plans.get(factor_name) def _execute_with_dependencies( self, factor_names: List[str], core_data: pl.DataFrame, ) -> pl.DataFrame: """按依赖顺序执行因子计算。 支持 cs_rank 等需要依赖列已存在的场景。 Args: factor_names: 因子名称列表 core_data: 核心宽表数据 Returns: 包含所有因子结果的数据表 """ # 1. 拓扑排序 sorted_names = self._topological_sort(factor_names) # 2. 按顺序执行 result = core_data for name in sorted_names: plan = self._plans[name] # 创建新的执行计划,引用已计算的依赖列 new_plan = self._create_optimized_plan(plan, result) # 执行计算 result = self.compute_engine.execute(new_plan, result) return result def _create_optimized_plan( self, plan: ExecutionPlan, current_data: pl.DataFrame, ) -> ExecutionPlan: """创建优化的执行计划。 将表达式中已计算的依赖因子替换为列引用。 Args: plan: 原始执行计划 current_data: 当前数据(包含已计算的依赖列) Returns: 新的执行计划 """ from src.factors.dsl import Symbol # 获取当前数据中已存在的列 existing_cols = set(current_data.columns) # 检查依赖列是否已存在 deps_available = plan.factor_dependencies & existing_cols if not deps_available: # 没有可用的依赖列,直接返回原计划 return plan # 获取原始表达式 original_expr = self.registered_expressions[plan.output_name] # 创建新的表达式,用 Symbol 引用替换依赖因子 def replace_with_symbol(node: Node) -> Node: """递归替换表达式中的依赖因子为 Symbol 引用。""" from typing import Any n: Any = node # 检查当前节点是否等于某个已计算依赖因子 for dep_name in deps_available: dep_expr = self.registered_expressions[dep_name] if self._expressions_equal(node, dep_expr): return Symbol(dep_name) # 递归处理子节点 if isinstance(n, BinaryOpNode): new_left = replace_with_symbol(n.left) new_right = replace_with_symbol(n.right) if new_left is not n.left or new_right is not n.right: return BinaryOpNode(n.op, new_left, new_right) elif isinstance(n, UnaryOpNode): new_operand = replace_with_symbol(n.operand) if new_operand is not n.operand: return UnaryOpNode(n.op, new_operand) elif isinstance(n, FunctionNode): new_args = [replace_with_symbol(arg) for arg in n.args] if any( new_arg is not old_arg for new_arg, old_arg in zip(new_args, n.args) ): return FunctionNode(n.func_name, *new_args) return node # 替换表达式 new_expr = replace_with_symbol(original_expr) # 重新翻译表达式 translator = PolarsTranslator() new_polars_expr = translator.translate(new_expr) # 更新依赖集合 new_factor_deps = plan.factor_dependencies - deps_available new_deps = plan.dependencies | deps_available return ExecutionPlan( data_specs=plan.data_specs, polars_expr=new_polars_expr, dependencies=new_deps, output_name=plan.output_name, factor_dependencies=new_factor_deps, ) def _expressions_equal(self, expr1: Node, expr2: Node) -> bool: """比较两个表达式是否相等。 用于检测因子间的依赖关系。 Args: expr1: 第一个表达式 expr2: 第二个表达式 Returns: 是否相等 """ from typing import Any e1: Any = expr1 e2: Any = expr2 if type(e1) != type(e2): return False if isinstance(e1, Symbol): return e1.name == e2.name from src.factors.dsl import Constant if isinstance(e1, Constant): return e1.value == e2.value if isinstance(e1, BinaryOpNode): return ( e1.op == e2.op and self._expressions_equal(e1.left, e2.left) and self._expressions_equal(e1.right, e2.right) ) if isinstance(e1, UnaryOpNode): return e1.op == e2.op and self._expressions_equal(e1.operand, e2.operand) if isinstance(e1, FunctionNode): if e1.func_name != e2.func_name or len(e1.args) != len(e2.args): return False return all( self._expressions_equal(a1, a2) for a1, a2 in zip(e1.args, e2.args) ) return False def _find_factor_dependencies(self, expression: Node) -> Set[str]: """查找表达式依赖的其他因子。 遍历已注册因子,检查表达式是否包含任何已注册因子的完整表达式。 Args: expression: 待检查的表达式 Returns: 依赖的因子名称集合 """ deps: Set[str] = set() # 检查表达式本身是否等于某个已注册因子 for name, registered_expr in self.registered_expressions.items(): if self._expressions_equal(expression, registered_expr): deps.add(name) break # 递归检查子节点 if isinstance(expression, BinaryOpNode): deps.update(self._find_factor_dependencies(expression.left)) deps.update(self._find_factor_dependencies(expression.right)) elif isinstance(expression, UnaryOpNode): deps.update(self._find_factor_dependencies(expression.operand)) elif isinstance(expression, FunctionNode): for arg in expression.args: deps.update(self._find_factor_dependencies(arg)) return deps def _topological_sort(self, factor_names: List[str]) -> List[str]: """按依赖关系对因子进行拓扑排序。 确保依赖的因子先被计算。 Args: factor_names: 因子名称列表 Returns: 排序后的因子名称列表 Raises: ValueError: 当检测到循环依赖时 """ # 构建依赖图 graph: Dict[str, Set[str]] = {} in_degree: Dict[str, int] = {} for name in factor_names: plan = self._plans[name] # 只考虑在本次计算范围内的依赖 deps = plan.factor_dependencies & set(factor_names) graph[name] = deps in_degree[name] = len(deps) # Kahn 算法 result = [] queue = [name for name, degree in in_degree.items() if degree == 0] while queue: # 按原始顺序处理同级别的因子 queue.sort(key=lambda x: factor_names.index(x)) name = queue.pop(0) result.append(name) for other in factor_names: if name in graph[other]: in_degree[other] -= 1 if in_degree[other] == 0: queue.append(other) if len(result) != len(factor_names): raise ValueError("检测到因子循环依赖") return result