diff --git a/src/factors/__init__.py b/src/factors/__init__.py index 305f2b2..c3746be 100644 --- a/src/factors/__init__.py +++ b/src/factors/__init__.py @@ -52,6 +52,19 @@ from src.factors.engine import ( ComputeEngine, ) +from src.factors.parser import FormulaParser + +from src.factors.registry import FunctionRegistry + +from src.factors.exceptions import ( + FormulaParseError, + UnknownFunctionError, + InvalidSyntaxError, + EmptyExpressionError, + RegistryError, + DuplicateFunctionError, +) + # 保持向后兼容:factor_engine.py 中的类也可以通过 src.factors.engine 访问 # 例如:from src.factors.engine import FactorEngine @@ -76,4 +89,15 @@ __all__ = [ "DataRouter", "ExecutionPlanner", "ComputeEngine", + # 解析器 (Phase 1 新增) + "FormulaParser", + # 注册表 (Phase 1 新增) + "FunctionRegistry", + # 异常类 (Phase 1 新增) + "FormulaParseError", + "UnknownFunctionError", + "InvalidSyntaxError", + "EmptyExpressionError", + "RegistryError", + "DuplicateFunctionError", ] diff --git a/src/factors/engine.py b/src/factors/engine.py deleted file mode 100644 index 827d380..0000000 --- a/src/factors/engine.py +++ /dev/null @@ -1,1063 +0,0 @@ -"""FactorEngine - 因子计算引擎统一入口。 - -提供从表达式注册到结果输出的完整执行链路: -接收研究员的表达式 -> 调用编译器解析依赖 -> 调用路由器连接数据库拉取并组装核心宽表 --> 调用翻译器生成物理执行计划 -> 将计划提交给计算引擎执行并行运算。 -""" - -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Set, Union -from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor -import threading - -import polars as pl - -from src.factors.dsl import ( - Node, - Symbol, - FunctionNode, - BinaryOpNode, - UnaryOpNode, - Constant, -) -from src.factors.compiler import DependencyExtractor -from src.factors.translator import PolarsTranslator -from src.data.storage import Storage - - -@dataclass -class DataSpec: - """数据规格定义。 - - 描述因子计算所需的数据表和字段。 - - Attributes: - table: 数据表名称 - columns: 需要的字段列表 - lookback_days: 回看天数(用于时序计算) - """ - - table: str - columns: List[str] - lookback_days: int = 1 - - -@dataclass -class ExecutionPlan: - """执行计划。 - - 包含完整的执行所需信息:数据源、转换逻辑、输出格式。 - - Attributes: - data_specs: 数据规格列表 - polars_expr: Polars 表达式 - dependencies: 依赖的原始字段 - output_name: 输出因子名称 - factor_dependencies: 依赖的其他因子名称(用于分步执行) - """ - - data_specs: List[DataSpec] - polars_expr: pl.Expr - dependencies: Set[str] - output_name: str - factor_dependencies: Set[str] = field(default_factory=set) - - -class DataRouter: - """数据路由器 - 按需取数、组装核心宽表。 - - 负责根据数据规格从数据源拉取数据,并组装成统一的宽表格式。 - 支持内存数据源(用于测试)和真实数据库连接。 - - Attributes: - data_source: 数据源,可以是内存 DataFrame 字典或数据库连接 - is_memory_mode: 是否为内存模式 - """ - - def __init__(self, data_source: Optional[Dict[str, pl.DataFrame]] = None) -> None: - """初始化数据路由器。 - - Args: - data_source: 内存数据源,字典格式 {表名: DataFrame} - 为 None 时自动连接 DuckDB 数据库 - """ - self.data_source = data_source or {} - self.is_memory_mode = data_source is not None - self._cache: Dict[str, pl.DataFrame] = {} - self._lock = threading.Lock() - - # 数据库模式下初始化 Storage - if not self.is_memory_mode: - self._storage = Storage() - else: - self._storage = None - - def fetch_data( - self, - data_specs: List[DataSpec], - start_date: str, - end_date: str, - stock_codes: Optional[List[str]] = None, - ) -> pl.DataFrame: - """根据数据规格获取并组装核心宽表。 - - Args: - data_specs: 数据规格列表 - start_date: 开始日期 (YYYYMMDD) - end_date: 结束日期 (YYYYMMDD) - stock_codes: 股票代码列表,None 表示全市场 - - Returns: - 组装好的核心宽表 DataFrame - - Raises: - ValueError: 当数据源中缺少必要的表或字段时 - """ - if not data_specs: - raise ValueError("数据规格不能为空") - - # 收集所有需要的表和字段 - required_tables: Dict[str, Set[str]] = {} - max_lookback = 0 - - for spec in data_specs: - if spec.table not in required_tables: - required_tables[spec.table] = set() - required_tables[spec.table].update(spec.columns) - max_lookback = max(max_lookback, spec.lookback_days) - - # 调整日期范围以包含回看期 - adjusted_start = self._adjust_start_date(start_date, max_lookback) - - # 从数据源获取各表数据 - table_data = {} - for table_name, columns in required_tables.items(): - df = self._load_table( - table_name=table_name, - columns=list(columns), - start_date=adjusted_start, - end_date=end_date, - stock_codes=stock_codes, - ) - table_data[table_name] = df - - # 组装核心宽表 - core_table = self._assemble_wide_table(table_data, required_tables) - - # 过滤到实际请求日期范围 - core_table = core_table.filter( - (pl.col("trade_date") >= start_date) & (pl.col("trade_date") <= end_date) - ) - - return core_table - - def _load_table( - self, - table_name: str, - columns: List[str], - start_date: str, - end_date: str, - stock_codes: Optional[List[str]] = None, - ) -> pl.DataFrame: - """加载单个表的数据。 - - Args: - table_name: 表名 - columns: 需要的字段 - start_date: 开始日期 - end_date: 结束日期 - stock_codes: 股票代码过滤 - - Returns: - 过滤后的 DataFrame - """ - cache_key = f"{table_name}_{start_date}_{end_date}_{stock_codes}" - - with self._lock: - if cache_key in self._cache: - return self._cache[cache_key] - - if self.is_memory_mode: - df = self._load_from_memory( - table_name, columns, start_date, end_date, stock_codes - ) - else: - df = self._load_from_database( - table_name, columns, start_date, end_date, stock_codes - ) - - with self._lock: - self._cache[cache_key] = df - - return df - - def _load_from_memory( - self, - table_name: str, - columns: List[str], - start_date: str, - end_date: str, - stock_codes: Optional[List[str]] = None, - ) -> pl.DataFrame: - """从内存数据源加载数据。""" - if table_name not in self.data_source: - raise ValueError(f"内存数据源中缺少表: {table_name}") - - df = self.data_source[table_name] - - # 确保必需字段存在 - for col in columns: - if col not in df.columns and col not in ["ts_code", "trade_date"]: - raise ValueError(f"表 {table_name} 缺少字段: {col}") - - # 过滤日期和股票 - df = df.filter( - (pl.col("trade_date") >= start_date) & (pl.col("trade_date") <= end_date) - ) - - if stock_codes is not None: - df = df.filter(pl.col("ts_code").is_in(stock_codes)) - - # 选择需要的列 - select_cols = ["ts_code", "trade_date"] + [ - c for c in columns if c in df.columns - ] - return df.select(select_cols) - - def _load_from_database( - self, - table_name: str, - columns: List[str], - start_date: str, - end_date: str, - stock_codes: Optional[List[str]] = None, - ) -> pl.DataFrame: - """从 DuckDB 数据库加载数据。 - - 利用 Storage.load_polars() 方法,支持 SQL 查询下推。 - """ - if self._storage is None: - raise RuntimeError("Storage 未初始化") - - # 检查表是否存在 - if not self._storage.exists(table_name): - raise ValueError(f"数据库中不存在表: {table_name}") - - # 构建查询参数 - # Storage.load_polars 目前只支持单个 ts_code,需要处理列表情况 - if stock_codes is not None and len(stock_codes) == 1: - ts_code_filter = stock_codes[0] - else: - ts_code_filter = None - - try: - # 从数据库加载原始数据 - df = self._storage.load_polars( - name=table_name, - start_date=start_date, - end_date=end_date, - ts_code=ts_code_filter, - ) - except Exception as e: - raise RuntimeError(f"从数据库加载表 {table_name} 失败: {e}") - - # 如果 stock_codes 是列表且长度 > 1,在内存中过滤 - if stock_codes is not None and len(stock_codes) > 1: - df = df.filter(pl.col("ts_code").is_in(stock_codes)) - - # 检查必需字段 - for col in columns: - if col not in df.columns and col not in ["ts_code", "trade_date"]: - raise ValueError(f"表 {table_name} 缺少字段: {col}") - - # 选择需要的列 - select_cols = ["ts_code", "trade_date"] + [ - c for c in columns if c in df.columns - ] - - return df.select(select_cols) - - def _assemble_wide_table( - self, - table_data: Dict[str, pl.DataFrame], - required_tables: Dict[str, Set[str]], - ) -> pl.DataFrame: - """组装多表数据为核心宽表。 - - 使用 left join 合并各表数据,以第一个表为基准。 - - Args: - table_data: 表名到 DataFrame 的映射 - required_tables: 表名到字段集合的映射 - - Returns: - 组装后的宽表 - """ - if not table_data: - raise ValueError("没有数据可组装") - - # 以第一个表为基准 - base_table_name = list(table_data.keys())[0] - result = table_data[base_table_name] - - # 与其他表 join - for table_name, df in table_data.items(): - if table_name == base_table_name: - continue - - # 使用 ts_code 和 trade_date 作为 join 键 - result = result.join( - df, - on=["ts_code", "trade_date"], - how="left", - ) - - return result - - def _adjust_start_date(self, start_date: str, lookback_days: int) -> str: - """根据回看天数调整开始日期。 - - Args: - start_date: 原始开始日期 (YYYYMMDD) - lookback_days: 需要回看的交易日数 - - Returns: - 调整后的开始日期 - """ - # 简化的日期调整:假设每月30天,向前推移 - # 实际应用中应该使用交易日历 - year = int(start_date[:4]) - month = int(start_date[4:6]) - day = int(start_date[6:8]) - - total_days = lookback_days + 30 # 额外缓冲 - - day -= total_days - while day <= 0: - month -= 1 - if month <= 0: - month = 12 - year -= 1 - day += 30 - - return f"{year:04d}{month:02d}{day:02d}" - - def clear_cache(self) -> None: - """清除数据缓存。""" - with self._lock: - self._cache.clear() - - # 数据库模式下清理 Storage 连接(可选) - if not self.is_memory_mode and self._storage is not None: - # Storage 使用单例模式,不需要关闭连接 - pass - - -class ExecutionPlanner: - """执行计划生成器。 - - 整合编译器和翻译器,生成完整的执行计划。 - - Attributes: - compiler: 依赖提取器 - translator: Polars 翻译器 - """ - - def __init__(self) -> None: - """初始化执行计划生成器。""" - self.compiler = DependencyExtractor() - self.translator = PolarsTranslator() - - def create_plan( - self, - expression: Node, - output_name: str = "factor", - data_specs: Optional[List[DataSpec]] = None, - ) -> ExecutionPlan: - """从表达式创建执行计划。 - - Args: - expression: DSL 表达式节点 - output_name: 输出因子名称 - data_specs: 预定义的数据规格,None 时自动推导 - - Returns: - 执行计划对象 - """ - # 1. 提取依赖 - dependencies = self.compiler.extract_dependencies(expression) - - # 2. 翻译为 Polars 表达式 - polars_expr = self.translator.translate(expression) - - # 3. 推导或验证数据规格 - if data_specs is None: - data_specs = self._infer_data_specs(dependencies, expression) - - return ExecutionPlan( - data_specs=data_specs, - polars_expr=polars_expr, - dependencies=dependencies, - output_name=output_name, - ) - - def _infer_data_specs( - self, - dependencies: Set[str], - expression: Node, - ) -> List[DataSpec]: - """从依赖推导数据规格。 - - 根据表达式中的函数类型推断回看天数需求。 - 基础行情字段(open, high, low, close, vol, amount, pre_close, change, pct_chg) - 默认从 pro_bar 表获取。 - - Args: - dependencies: 依赖的字段集合 - expression: 表达式节点 - - Returns: - 数据规格列表 - """ - # 计算最大回看窗口 - max_window = self._extract_max_window(expression) - lookback_days = max(1, max_window) - - # 基础行情字段集合(这些字段从 pro_bar 表获取) - pro_bar_fields = { - "open", - "high", - "low", - "close", - "vol", - "amount", - "pre_close", - "change", - "pct_chg", - "turnover_rate", - "volume_ratio", - } - - # 将依赖分为 pro_bar 字段和其他字段 - pro_bar_deps = dependencies & pro_bar_fields - other_deps = dependencies - pro_bar_fields - - data_specs = [] - - # pro_bar 表的数据规格 - if pro_bar_deps: - data_specs.append( - DataSpec( - table="pro_bar", - columns=sorted(pro_bar_deps), - lookback_days=lookback_days, - ) - ) - - # 其他字段从 daily 表获取 - if other_deps: - data_specs.append( - DataSpec( - table="daily", - columns=sorted(other_deps), - lookback_days=lookback_days, - ) - ) - - return data_specs - - def _extract_max_window(self, node: Node) -> int: - """从表达式中提取最大窗口大小。 - - Args: - node: AST 节点 - - Returns: - 最大窗口大小,无时序函数返回 1 - """ - if isinstance(node, FunctionNode): - window = 1 - # 检查函数参数中的窗口大小 - for arg in node.args: - if ( - isinstance(arg, Constant) - and isinstance(arg.value, int) - and arg.value > window - ): - window = arg.value - - # 递归检查子表达式 - for arg in node.args: - if isinstance(arg, Node) and not isinstance(arg, Constant): - window = max(window, self._extract_max_window(arg)) - - return window - - elif isinstance(node, BinaryOpNode): - return max( - self._extract_max_window(node.left), - self._extract_max_window(node.right), - ) - - elif isinstance(node, UnaryOpNode): - return self._extract_max_window(node.operand) - - return 1 - - -class ComputeEngine: - """计算引擎 - 执行并行运算。 - - 负责将执行计划应用到数据上,支持并行计算。 - - Attributes: - max_workers: 最大并行工作线程数 - use_processes: 是否使用进程池(CPU 密集型任务) - """ - - def __init__( - self, - max_workers: int = 4, - use_processes: bool = False, - ) -> None: - """初始化计算引擎。 - - Args: - max_workers: 最大并行工作线程数 - use_processes: 是否使用进程池代替线程池 - """ - self.max_workers = max_workers - self.use_processes = use_processes - - def execute( - self, - plan: ExecutionPlan, - data: pl.DataFrame, - ) -> pl.DataFrame: - """执行计算计划。 - - Args: - plan: 执行计划 - data: 输入数据(核心宽表) - - Returns: - 包含因子结果的 DataFrame - """ - # 检查依赖字段是否存在 - missing_cols = plan.dependencies - set(data.columns) - if missing_cols: - raise ValueError(f"数据缺少必要的字段: {missing_cols}") - - # 执行计算 - result = data.with_columns([plan.polars_expr.alias(plan.output_name)]) - - return result - - def execute_batch( - self, - plans: List[ExecutionPlan], - data: pl.DataFrame, - ) -> pl.DataFrame: - """批量执行多个计算计划。 - - Args: - plans: 执行计划列表 - data: 输入数据 - - Returns: - 包含所有因子结果的 DataFrame - """ - result = data - - for plan in plans: - result = self.execute(plan, result) - - return result - - def execute_parallel( - self, - plans: List[ExecutionPlan], - data: pl.DataFrame, - ) -> pl.DataFrame: - """并行执行多个计算计划。 - - Args: - plans: 执行计划列表 - data: 输入数据 - - Returns: - 包含所有因子结果的 DataFrame - """ - # 检查计划间依赖 - independent_plans = [] - dependent_plans = [] - available_cols = set(data.columns) - - for plan in plans: - if plan.dependencies <= available_cols: - independent_plans.append(plan) - available_cols.add(plan.output_name) - else: - dependent_plans.append(plan) - - # 并行执行独立计划 - if independent_plans: - ExecutorClass = ( - ProcessPoolExecutor if self.use_processes else ThreadPoolExecutor - ) - - with ExecutorClass(max_workers=self.max_workers) as executor: - futures = { - executor.submit(self._execute_single, plan, data): plan - for plan in independent_plans - } - - results = [] - for future in futures: - plan = futures[future] - try: - result_col = future.result() - results.append((plan.output_name, result_col)) - except Exception as e: - raise RuntimeError(f"计算因子 {plan.output_name} 失败: {e}") - - # 合并结果 - for name, series in results: - data = data.with_columns([series.alias(name)]) - - # 顺序执行依赖计划 - for plan in dependent_plans: - data = self.execute(plan, data) - - return data - - def _execute_single( - self, - plan: ExecutionPlan, - data: pl.DataFrame, - ) -> pl.Series: - """执行单个计划并返回结果列。 - - Args: - plan: 执行计划 - data: 输入数据 - - Returns: - 计算结果序列 - """ - result = self.execute(plan, data) - return result[plan.output_name] - - -class FactorEngine: - """因子计算引擎 - 系统统一入口。 - - 提供从表达式到结果的完整执行链路,是研究员使用系统的唯一接口。 - - 执行流程: - 1. 注册表达式 -> 调用编译器解析依赖 - 2. 调用路由器连接数据库拉取并组装核心宽表 - 3. 调用翻译器生成物理执行计划 - 4. 将计划提交给计算引擎执行并行运算 - 5. 返回包含因子结果的数据表 - - Attributes: - router: 数据路由器 - planner: 执行计划生成器 - compute_engine: 计算引擎 - registered_expressions: 注册的表达式字典 - """ - - def __init__( - self, - data_source: Optional[Dict[str, pl.DataFrame]] = None, - max_workers: int = 4, - ) -> None: - """初始化因子引擎。 - - Args: - data_source: 内存数据源,为 None 时使用数据库连接 - max_workers: 并行计算的最大工作线程数 - """ - self.router = DataRouter(data_source) - self.planner = ExecutionPlanner() - self.compute_engine = ComputeEngine(max_workers=max_workers) - self.registered_expressions: Dict[str, Node] = {} - self._plans: Dict[str, ExecutionPlan] = {} - - def register( - self, - name: str, - expression: Node, - data_specs: Optional[List[DataSpec]] = None, - ) -> FactorEngine: - """注册因子表达式。 - - Args: - name: 因子名称 - expression: DSL 表达式 - data_specs: 数据规格,None 时自动推导 - - Returns: - self,支持链式调用 - - Example: - >>> from src.factors.api import close, ts_mean - >>> engine = FactorEngine() - >>> engine.register("ma20", ts_mean(close, 20)) - """ - # 检测因子依赖(在注册当前因子之前检查其他已注册因子) - factor_deps = self._find_factor_dependencies(expression) - - self.registered_expressions[name] = expression - - # 预创建执行计划 - plan = self.planner.create_plan( - expression=expression, - output_name=name, - data_specs=data_specs, - ) - - # 添加因子依赖信息 - plan.factor_dependencies = factor_deps - - self._plans[name] = plan - - return self - - def compute( - self, - factor_names: Union[str, List[str]], - start_date: str, - end_date: str, - stock_codes: Optional[List[str]] = None, - ) -> pl.DataFrame: - """计算指定因子的值。 - - 完整的执行流程:取数 -> 组装 -> 翻译 -> 计算。 - - Args: - factor_names: 因子名称或名称列表 - start_date: 开始日期 (YYYYMMDD) - end_date: 结束日期 (YYYYMMDD) - stock_codes: 股票代码列表,None 表示全市场 - - Returns: - 包含因子结果的数据表 - - Raises: - ValueError: 当因子未注册或数据不足时 - - Example: - >>> result = engine.compute("ma20", "20240101", "20240131") - >>> result = engine.compute(["ma20", "rsi"], "20240101", "20240131") - """ - # 标准化因子名称 - if isinstance(factor_names, str): - factor_names = [factor_names] - - # 1. 获取执行计划 - plans = [] - for name in factor_names: - if name not in self._plans: - raise ValueError(f"因子未注册: {name}") - plans.append(self._plans[name]) - - # 2. 合并数据规格并获取数据 - all_specs = [] - for plan in plans: - all_specs.extend(plan.data_specs) - - # 3. 从路由器获取核心宽表 - core_data = self.router.fetch_data( - data_specs=all_specs, - start_date=start_date, - end_date=end_date, - stock_codes=stock_codes, - ) - - if len(core_data) == 0: - raise ValueError("未获取到任何数据,请检查日期范围和股票代码") - - # 4. 按依赖顺序执行计算 - if len(plans) == 1: - result = self.compute_engine.execute(plans[0], core_data) - else: - # 使用依赖感知的方式执行 - result = self._execute_with_dependencies(factor_names, core_data) - - return result - - def list_registered(self) -> List[str]: - """获取已注册的因子列表。 - - Returns: - 因子名称列表 - """ - return list(self.registered_expressions.keys()) - - def get_expression(self, name: str) -> Optional[Node]: - """获取已注册的表达式。 - - Args: - name: 因子名称 - - Returns: - 表达式节点,未注册时返回 None - """ - return self.registered_expressions.get(name) - - def clear(self) -> None: - """清除所有注册的表达式和缓存。""" - self.registered_expressions.clear() - self._plans.clear() - self.router.clear_cache() - - def preview_plan(self, factor_name: str) -> Optional[ExecutionPlan]: - """预览因子的执行计划。 - - Args: - factor_name: 因子名称 - - Returns: - 执行计划,未注册时返回 None - """ - return self._plans.get(factor_name) - - def _execute_with_dependencies( - self, - factor_names: List[str], - core_data: pl.DataFrame, - ) -> pl.DataFrame: - """按依赖顺序执行因子计算。 - - 支持 cs_rank 等需要依赖列已存在的场景。 - - Args: - factor_names: 因子名称列表 - core_data: 核心宽表数据 - - Returns: - 包含所有因子结果的数据表 - """ - # 1. 拓扑排序 - sorted_names = self._topological_sort(factor_names) - - # 2. 按顺序执行 - result = core_data - for name in sorted_names: - plan = self._plans[name] - - # 创建新的执行计划,引用已计算的依赖列 - new_plan = self._create_optimized_plan(plan, result) - - # 执行计算 - result = self.compute_engine.execute(new_plan, result) - - return result - - def _create_optimized_plan( - self, - plan: ExecutionPlan, - current_data: pl.DataFrame, - ) -> ExecutionPlan: - """创建优化的执行计划。 - - 将表达式中已计算的依赖因子替换为列引用。 - - Args: - plan: 原始执行计划 - current_data: 当前数据(包含已计算的依赖列) - - Returns: - 新的执行计划 - """ - from src.factors.dsl import Symbol - - # 获取当前数据中已存在的列 - existing_cols = set(current_data.columns) - - # 检查依赖列是否已存在 - deps_available = plan.factor_dependencies & existing_cols - - if not deps_available: - # 没有可用的依赖列,直接返回原计划 - return plan - - # 获取原始表达式 - original_expr = self.registered_expressions[plan.output_name] - - # 创建新的表达式,用 Symbol 引用替换依赖因子 - def replace_with_symbol(node: Node) -> Node: - """递归替换表达式中的依赖因子为 Symbol 引用。""" - from typing import Any - - n: Any = node - - # 检查当前节点是否等于某个已计算依赖因子 - for dep_name in deps_available: - dep_expr = self.registered_expressions[dep_name] - if self._expressions_equal(node, dep_expr): - return Symbol(dep_name) - - # 递归处理子节点 - if isinstance(n, BinaryOpNode): - new_left = replace_with_symbol(n.left) - new_right = replace_with_symbol(n.right) - if new_left is not n.left or new_right is not n.right: - return BinaryOpNode(n.op, new_left, new_right) - elif isinstance(n, UnaryOpNode): - new_operand = replace_with_symbol(n.operand) - if new_operand is not n.operand: - return UnaryOpNode(n.op, new_operand) - elif isinstance(n, FunctionNode): - new_args = [replace_with_symbol(arg) for arg in n.args] - if any( - new_arg is not old_arg for new_arg, old_arg in zip(new_args, n.args) - ): - return FunctionNode(n.func_name, *new_args) - - return node - - # 替换表达式 - new_expr = replace_with_symbol(original_expr) - - # 重新翻译表达式 - translator = PolarsTranslator() - new_polars_expr = translator.translate(new_expr) - - # 更新依赖集合 - new_factor_deps = plan.factor_dependencies - deps_available - new_deps = plan.dependencies | deps_available - - return ExecutionPlan( - data_specs=plan.data_specs, - polars_expr=new_polars_expr, - dependencies=new_deps, - output_name=plan.output_name, - factor_dependencies=new_factor_deps, - ) - - def _expressions_equal(self, expr1: Node, expr2: Node) -> bool: - """比较两个表达式是否相等。 - - 用于检测因子间的依赖关系。 - - Args: - expr1: 第一个表达式 - expr2: 第二个表达式 - - Returns: - 是否相等 - """ - from typing import Any - - e1: Any = expr1 - e2: Any = expr2 - - if type(e1) != type(e2): - return False - - if isinstance(e1, Symbol): - return e1.name == e2.name - - if isinstance(e1, Constant): - return e1.value == e2.value - - if isinstance(e1, BinaryOpNode): - return ( - e1.op == e2.op - and self._expressions_equal(e1.left, e2.left) - and self._expressions_equal(e1.right, e2.right) - ) - - if isinstance(e1, UnaryOpNode): - return e1.op == e2.op and self._expressions_equal(e1.operand, e2.operand) - - if isinstance(e1, FunctionNode): - if e1.func_name != e2.func_name or len(e1.args) != len(e2.args): - return False - return all( - self._expressions_equal(a1, a2) for a1, a2 in zip(e1.args, e2.args) - ) - - return False - - def _find_factor_dependencies(self, expression: Node) -> Set[str]: - """查找表达式依赖的其他因子。 - - 遍历已注册因子,检查表达式是否包含任何已注册因子的完整表达式。 - - Args: - expression: 待检查的表达式 - - Returns: - 依赖的因子名称集合 - """ - deps: Set[str] = set() - - # 检查表达式本身是否等于某个已注册因子 - for name, registered_expr in self.registered_expressions.items(): - if self._expressions_equal(expression, registered_expr): - deps.add(name) - break - - # 递归检查子节点 - if isinstance(expression, BinaryOpNode): - deps.update(self._find_factor_dependencies(expression.left)) - deps.update(self._find_factor_dependencies(expression.right)) - elif isinstance(expression, UnaryOpNode): - deps.update(self._find_factor_dependencies(expression.operand)) - elif isinstance(expression, FunctionNode): - for arg in expression.args: - deps.update(self._find_factor_dependencies(arg)) - - return deps - - def _topological_sort(self, factor_names: List[str]) -> List[str]: - """按依赖关系对因子进行拓扑排序。 - - 确保依赖的因子先被计算。 - - Args: - factor_names: 因子名称列表 - - Returns: - 排序后的因子名称列表 - - Raises: - ValueError: 当检测到循环依赖时 - """ - # 构建依赖图 - graph: Dict[str, Set[str]] = {} - in_degree: Dict[str, int] = {} - - for name in factor_names: - plan = self._plans[name] - # 只考虑在本次计算范围内的依赖 - deps = plan.factor_dependencies & set(factor_names) - graph[name] = deps - in_degree[name] = len(deps) - - # Kahn 算法 - result = [] - queue = [name for name, degree in in_degree.items() if degree == 0] - - while queue: - # 按原始顺序处理同级别的因子 - queue.sort(key=lambda x: factor_names.index(x)) - name = queue.pop(0) - result.append(name) - - for other in factor_names: - if name in graph[other]: - in_degree[other] -= 1 - if in_degree[other] == 0: - queue.append(other) - - if len(result) != len(factor_names): - raise ValueError("检测到因子循环依赖") - - return result diff --git a/src/factors/engine/__init__.py b/src/factors/engine/__init__.py index 63cdace..81dca8a 100644 --- a/src/factors/engine/__init__.py +++ b/src/factors/engine/__init__.py @@ -23,3 +23,6 @@ __all__ = [ "ComputeEngine", "FactorEngine", ] + +# 类型导出(用于类型注解) +# FunctionRegistry 从 src.factors.registry 导入 diff --git a/src/factors/engine/factor_engine.py b/src/factors/engine/factor_engine.py index 0c1cd3f..160abcd 100644 --- a/src/factors/engine/factor_engine.py +++ b/src/factors/engine/factor_engine.py @@ -10,10 +10,13 @@ 5. 返回包含因子结果的数据表 """ -from typing import Any, Dict, List, Optional, Set, Union +from typing import Any, Dict, List, Optional, Set, Union, TYPE_CHECKING import polars as pl +if TYPE_CHECKING: + from src.factors.registry import FunctionRegistry + from src.factors.dsl import ( Node, Symbol, @@ -45,25 +48,36 @@ class FactorEngine: planner: 执行计划生成器 compute_engine: 计算引擎 registered_expressions: 注册的表达式字典 + _registry: 函数注册表 + _parser: 公式解析器 """ def __init__( self, data_source: Optional[Dict[str, pl.DataFrame]] = None, max_workers: int = 4, + registry: Optional["FunctionRegistry"] = None, ) -> None: """初始化因子引擎。 Args: data_source: 内存数据源,为 None 时使用数据库连接 max_workers: 并行计算的最大工作线程数 + registry: 函数注册表,None 时创建独立实例 """ + from src.factors.registry import FunctionRegistry + from src.factors.parser import FormulaParser + self.router = DataRouter(data_source) self.planner = ExecutionPlanner() self.compute_engine = ComputeEngine(max_workers=max_workers) self.registered_expressions: Dict[str, Node] = {} self._plans: Dict[str, ExecutionPlan] = {} + # 初始化注册表和解析器(支持注入外部注册表实现共享) + self._registry = registry if registry is not None else FunctionRegistry() + self._parser = FormulaParser(self._registry) + def register( self, name: str, @@ -104,6 +118,63 @@ class FactorEngine: return self + def add_factor( + self, + name: str, + expression: Union[str, Node], + data_specs: Optional[List[DataSpec]] = None, + ) -> "FactorEngine": + """注册因子(支持字符串或 Node 表达式)。 + + 这是 register 方法的增强版,支持字符串表达式解析。 + 向后兼容:register 方法保持不变,继续只接受 Node 类型。 + + 遵循 Fail-Fast 原则:字符串表达式会立即解析,失败时立即抛出异常。 + + Args: + name: 因子名称 + expression: 字符串表达式或 Node 对象 + data_specs: 可选的数据规格 + + Returns: + self,支持链式调用 + + Raises: + TypeError: 当 expression 类型不支持时 + FormulaParseError: 当字符串解析失败时(立即报错) + + Example: + >>> engine = FactorEngine() + >>> + >>> # 字符串方式(新功能) + >>> engine.add_factor("ma20", "ts_mean(close, 20)") + >>> + >>> # Node 方式(与 register 相同) + >>> from src.factors.api import close, ts_mean + >>> engine.add_factor("ma20", ts_mean(close, 20)) + >>> + >>> # 复杂表达式 + >>> engine.add_factor("alpha1", "cs_rank(close / open)") + >>> + >>> # 链式调用 + >>> (engine + ... .add_factor("ma5", "ts_mean(close, 5)") + ... .add_factor("ma10", "ts_mean(close, 10)") + ... .add_factor("golden_cross", "ma5 > ma10")) + """ + if isinstance(expression, str): + # Fail-Fast:立即解析,失败立即报错 + node = self._parser.parse(expression) + elif isinstance(expression, Node): + node = expression + else: + raise TypeError( + f"表达式必须是 str 或 Node 类型,收到 {type(expression).__name__}" + ) + + # 委托给现有的 register 方法 + return self.register(name, node, data_specs) + def compute( self, factor_names: Union[str, List[str]], diff --git a/src/factors/exceptions.py b/src/factors/exceptions.py new file mode 100644 index 0000000..1686862 --- /dev/null +++ b/src/factors/exceptions.py @@ -0,0 +1,144 @@ +"""公式解析异常定义。 + +提供清晰的错误信息,帮助用户快速定位公式解析问题。 +""" + +import difflib +from typing import List, Optional + + +class FormulaParseError(Exception): + """公式解析错误基类。 + + Attributes: + expr: 原始表达式字符串 + lineno: 错误所在行号(从1开始) + col_offset: 错误所在列号(从0开始) + """ + + def __init__( + self, + message: str, + expr: Optional[str] = None, + lineno: Optional[int] = None, + col_offset: Optional[int] = None, + ): + self.expr = expr + self.lineno = lineno + self.col_offset = col_offset + + # 构建详细错误信息 + full_message = self._format_message(message) + super().__init__(full_message) + + def _format_message(self, message: str) -> str: + """格式化错误信息,包含位置指示器。""" + lines = [f"FormulaParseError: {message}"] + + if self.expr: + lines.append(f" 公式: {self.expr}") + + # 添加错误位置指示器 + if self.col_offset is not None and self.lineno is not None: + # 计算错误行在表达式中的起始位置 + expr_lines = self.expr.split("\n") + if 1 <= self.lineno <= len(expr_lines): + error_line = expr_lines[self.lineno - 1] + lines.append(f" {error_line}") + # 添加指向错误位置的箭头 + pointer = " " * (self.col_offset + 7) + "^--- 此处出错" + lines.append(pointer) + + return "\n".join(lines) + + +class UnknownFunctionError(FormulaParseError): + """未知函数错误。 + + 当表达式中使用了未注册的函数时抛出。 + + Attributes: + func_name: 未知的函数名 + available: 可用函数列表 + suggestions: 模糊匹配建议列表 + """ + + def __init__( + self, + func_name: str, + available: List[str], + expr: Optional[str] = None, + lineno: Optional[int] = None, + col_offset: Optional[int] = None, + ): + self.func_name = func_name + self.available = available + + # 使用 difflib 获取模糊匹配建议 + self.suggestions = difflib.get_close_matches( + func_name, available, n=3, cutoff=0.5 + ) + + # 构建错误信息 + if self.suggestions: + suggestion_str = ", ".join(f"'{s}'" for s in self.suggestions) + hint_msg = f"你是不是想找: {suggestion_str}?" + else: + # 只显示前10个可用函数 + available_preview = ", ".join(available[:10]) + if len(available) > 10: + available_preview += f", ... 等共 {len(available)} 个函数" + hint_msg = f"可用函数预览: {available_preview}" + + msg = f"未知函数 '{func_name}'。{hint_msg}" + + super().__init__( + message=msg, + expr=expr, + lineno=lineno, + col_offset=col_offset, + ) + + +class InvalidSyntaxError(FormulaParseError): + """语法错误。 + + 当表达式语法不正确或不支持时抛出。 + """ + + pass + + +class UnsupportedOperatorError(InvalidSyntaxError): + """不支持的运算符错误。 + + 当使用了不支持的运算符时抛出(如位运算、矩阵运算等)。 + """ + + pass + + +class EmptyExpressionError(FormulaParseError): + """空表达式错误。""" + + def __init__(self): + super().__init__("表达式不能为空或只包含空白字符") + + +class RegistryError(Exception): + """注册表错误基类。""" + + pass + + +class DuplicateFunctionError(RegistryError): + """函数重复注册错误。 + + 当尝试注册已存在的函数且未设置 force=True 时抛出。 + """ + + def __init__(self, func_name: str): + self.func_name = func_name + super().__init__( + f"函数 '{func_name}' 已存在。使用 force=True 覆盖,或选择其他名称。" + ) diff --git a/src/factors/parser.py b/src/factors/parser.py new file mode 100644 index 0000000..7d8f844 --- /dev/null +++ b/src/factors/parser.py @@ -0,0 +1,411 @@ +"""公式解析器 - 将字符串表达式转换为 DSL 节点树。 + +基于 Python ast 模块实现,支持算术运算、比较运算、函数调用等。 + +示例: + >>> from src.factors.parser import FormulaParser + >>> from src.factors.registry import FunctionRegistry + >>> parser = FormulaParser(FunctionRegistry()) + >>> node = parser.parse("ts_mean(close, 20)") + >>> print(node) + ts_mean(close, 20) +""" + +import ast +from typing import Any, Dict, Optional, TYPE_CHECKING + +from src.factors.dsl import Node, Symbol, Constant, BinaryOpNode, UnaryOpNode +from src.factors.exceptions import ( + FormulaParseError, + UnknownFunctionError, + InvalidSyntaxError, + EmptyExpressionError, +) + +if TYPE_CHECKING: + from src.factors.registry import FunctionRegistry + + +# 运算符映射表 +BIN_OP_MAP: Dict[type, str] = { + ast.Add: "+", + ast.Sub: "-", + ast.Mult: "*", + ast.Div: "/", + ast.Pow: "**", + ast.FloorDiv: "//", + ast.Mod: "%", +} + +UNARY_OP_MAP: Dict[type, str] = { + ast.UAdd: "+", + ast.USub: "-", + ast.Invert: "~", # 不支持,应报错 +} + +COMPARE_OP_MAP: Dict[type, str] = { + ast.Eq: "==", + ast.NotEq: "!=", + ast.Lt: "<", + ast.LtE: "<=", + ast.Gt: ">", + ast.GtE: ">=", +} + + +class FormulaParser: + """基于 AST 的公式解析器。 + + 将字符串表达式解析为 DSL 节点树,支持: + - 符号引用(如 close, open) + - 数值常量(如 20, 3.14) + - 二元运算(如 +, -, *, /) + - 一元运算(如 -x) + - 函数调用(如 ts_mean(close, 20)) + - 比较运算(如 close > open) + + Attributes: + registry: 函数注册表,用于解析函数调用 + """ + + def __init__(self, registry: "FunctionRegistry") -> None: + """初始化解析器。 + + Args: + registry: 函数注册表,提供函数名到可调用对象的映射 + """ + self.registry = registry + + def parse(self, expr: str) -> Node: + """解析字符串表达式为 Node 树。 + + Args: + expr: 公式字符串,如 "ts_mean(close, 20)" + + Returns: + 解析后的 Node 节点 + + Raises: + EmptyExpressionError: 表达式为空时抛出 + SyntaxError: Python 语法错误时抛出 + FormulaParseError: 解析失败时抛出 + + Example: + >>> parser.parse("close / open") + BinaryOpNode("/", Symbol("close"), Symbol("open")) + """ + # 检查空表达式 + if not expr or not expr.strip(): + raise EmptyExpressionError() + + # 解析为 Python AST + try: + tree = ast.parse(expr, mode="eval") + except SyntaxError as e: + # 将 SyntaxError 包装为 InvalidSyntaxError,统一异常类型 + raise InvalidSyntaxError( + message=f"表达式语法错误: {e.msg}", + expr=expr, + lineno=e.lineno, + col_offset=e.offset, + ) from e + + # 递归访问 AST 节点 + try: + return self._visit(tree.body, expr) + except FormulaParseError: + # 重新抛出 FormulaParseError(保留已有的位置信息) + raise + except Exception as e: + # 将其他异常包装为 FormulaParseError + if not isinstance(e, FormulaParseError): + raise FormulaParseError( + message=f"解析失败: {str(e)}", + expr=expr, + ) from e + raise + + def _visit(self, node: ast.AST, expr: str) -> Node: + """递归访问 AST 节点并转换为 DSL 节点。 + + Args: + node: Python AST 节点 + expr: 原始表达式字符串(用于错误报告) + + Returns: + 对应的 DSL 节点 + + Raises: + InvalidSyntaxError: 遇到不支持的语法时抛出 + """ + # 提取位置信息(如果节点有) + lineno = getattr(node, "lineno", None) + col_offset = getattr(node, "col_offset", None) + + try: + if isinstance(node, ast.Name): + return self._visit_Name(node) + elif isinstance(node, ast.Constant): + return self._visit_Constant(node, expr) + elif isinstance(node, ast.BinOp): + return self._visit_BinOp(node, expr) + elif isinstance(node, ast.UnaryOp): + return self._visit_UnaryOp(node, expr) + elif isinstance(node, ast.Call): + return self._visit_Call(node, expr) + elif isinstance(node, ast.Compare): + return self._visit_Compare(node, expr) + else: + raise InvalidSyntaxError( + message=f"不支持的语法: {type(node).__name__}", + expr=expr, + lineno=lineno, + col_offset=col_offset, + ) + except FormulaParseError: + # 重新抛出(保留已有的位置信息) + raise + except Exception as e: + # 包装为 FormulaParseError,添加位置信息 + raise FormulaParseError( + message=f"解析节点失败: {str(e)}", + expr=expr, + lineno=lineno, + col_offset=col_offset, + ) from e + + def _visit_Name(self, node: ast.Name) -> Symbol: + """访问名称节点 - 永远转为 Symbol。 + + 注意:利用 AST 语法自然区分变量和函数调用: + - log → Symbol("log")(数据列引用) + - log(close) → 在 _visit_Call 中处理(函数调用) + + Args: + node: AST 名称节点 + + Returns: + Symbol 节点 + """ + return Symbol(node.id) + + def _visit_Constant(self, node: ast.Constant, expr: str) -> Node: + """访问常量节点。 + + 支持的类型: + - int/float → Constant 节点 + - str → Symbol 节点(支持 ts_mean("close", 20) 语法) + + Args: + node: AST 常量节点 + expr: 原始表达式字符串 + + Returns: + Constant 或 Symbol 节点 + + Raises: + InvalidSyntaxError: 不支持的常量类型 + """ + if isinstance(node.value, (int, float)): + return Constant(node.value) + elif isinstance(node.value, str): + # 字符串常量转为 Symbol,支持 "close" 写法 + return Symbol(node.value) + else: + lineno = getattr(node, "lineno", None) + col_offset = getattr(node, "col_offset", None) + raise InvalidSyntaxError( + message=f"不支持的常量类型: {type(node.value).__name__}", + expr=expr, + lineno=lineno, + col_offset=col_offset, + ) + + def _visit_BinOp(self, node: ast.BinOp, expr: str) -> BinaryOpNode: + """访问二元运算节点。 + + Args: + node: AST 二元运算节点 + expr: 原始表达式字符串 + + Returns: + BinaryOpNode 节点 + + Raises: + InvalidSyntaxError: 不支持的运算符 + """ + left = self._visit(node.left, expr) + right = self._visit(node.right, expr) + + op = BIN_OP_MAP.get(type(node.op)) + if op is None: + lineno = getattr(node, "lineno", None) + col_offset = getattr(node, "col_offset", None) + raise InvalidSyntaxError( + message=f"不支持的运算符: {type(node.op).__name__}", + expr=expr, + lineno=lineno, + col_offset=col_offset, + ) + + return BinaryOpNode(op, left, right) + + def _visit_UnaryOp(self, node: ast.UnaryOp, expr: str) -> Node: + """访问一元运算节点。 + + 支持常量折叠优化:纯数值的一元运算直接计算结果。 + + Args: + node: AST 一元运算节点 + expr: 原始表达式字符串 + + Returns: + Constant(常量折叠)或 UnaryOpNode 节点 + + Raises: + InvalidSyntaxError: 不支持的运算符 + """ + operand = self._visit(node.operand, expr) + op = UNARY_OP_MAP.get(type(node.op)) + + lineno = getattr(node, "lineno", None) + col_offset = getattr(node, "col_offset", None) + + if op is None: + raise InvalidSyntaxError( + message=f"不支持的一元运算符: {type(node.op).__name__}", + expr=expr, + lineno=lineno, + col_offset=col_offset, + ) + + if op == "~": + raise InvalidSyntaxError( + message="位运算 '~' 不被支持", + expr=expr, + lineno=lineno, + col_offset=col_offset, + ) + + # 常量折叠优化:纯数值直接计算 + if isinstance(operand, Constant) and isinstance(operand.value, (int, float)): + if op == "-": + return Constant(-operand.value) + elif op == "+": + return operand # +5 就是 5 + + # 非常量使用运算符重载 + if op == "-": + return -operand + elif op == "+": + return +operand + + # 不应该到达这里 + raise InvalidSyntaxError( + message=f"无法处理的一元运算符: {op}", + expr=expr, + lineno=lineno, + col_offset=col_offset, + ) + + def _visit_Call(self, node: ast.Call, expr: str) -> Node: + """访问函数调用节点。 + + 注意:只有在这里查注册表,处理函数调用。 + + Args: + node: AST 函数调用节点 + expr: 原始表达式字符串 + + Returns: + 函数返回的 Node 节点 + + Raises: + InvalidSyntaxError: 不支持的函数调用语法 + UnknownFunctionError: 函数未注册 + """ + lineno = getattr(node, "lineno", None) + col_offset = getattr(node, "col_offset", None) + + # 只支持简单函数调用(如 func(a, b)) + if not isinstance(node.func, ast.Name): + raise InvalidSyntaxError( + message="只支持简单函数调用(如 func(a, b))", + expr=expr, + lineno=lineno, + col_offset=col_offset, + ) + + func_name = node.func.id + func = self.registry.get(func_name) + + if func is None: + raise UnknownFunctionError( + func_name=func_name, + available=self.registry.available_functions(), + expr=expr, + lineno=lineno, + col_offset=col_offset, + ) + + # 解析位置参数 + args = [self._visit(arg, expr) for arg in node.args] + + # 解析关键字参数(如果有) + kwargs = {} + for keyword in node.keywords: + kwargs[keyword.arg] = self._visit(keyword.value, expr) + + # 应用函数 + try: + if kwargs: + return func(*args, **kwargs) + return func(*args) + except TypeError as e: + raise InvalidSyntaxError( + message=f"函数 '{func_name}' 调用失败: {e}", + expr=expr, + lineno=lineno, + col_offset=col_offset, + ) from e + + def _visit_Compare(self, node: ast.Compare, expr: str) -> BinaryOpNode: + """访问比较运算节点。 + + 注意:只支持简单二元比较,不支持链式比较(如 a < b < c)。 + + Args: + node: AST 比较节点 + expr: 原始表达式字符串 + + Returns: + BinaryOpNode 节点(使用比较运算符) + + Raises: + InvalidSyntaxError: 链式比较或不支持的运算符 + """ + lineno = getattr(node, "lineno", None) + col_offset = getattr(node, "col_offset", None) + + # Python 支持链式比较 (a < b < c),这里简化为二元比较 + if len(node.ops) != 1 or len(node.comparators) != 1: + raise InvalidSyntaxError( + message="只支持简单二元比较(如 a > b),不支持链式比较", + expr=expr, + lineno=lineno, + col_offset=col_offset, + ) + + left = self._visit(node.left, expr) + op = COMPARE_OP_MAP.get(type(node.ops[0])) + + if op is None: + raise InvalidSyntaxError( + message=f"不支持的比较运算符: {type(node.ops[0]).__name__}", + expr=expr, + lineno=lineno, + col_offset=col_offset, + ) + + right = self._visit(node.comparators[0], expr) + return BinaryOpNode(op, left, right) diff --git a/src/factors/registry.py b/src/factors/registry.py new file mode 100644 index 0000000..5618652 --- /dev/null +++ b/src/factors/registry.py @@ -0,0 +1,227 @@ +"""函数注册表 - 管理字符串函数名到 Python 函数的映射。 + +支持自动发现和手动注册,与 FormulaParser 配合使用。 + +示例: + >>> from src.factors.registry import FunctionRegistry + >>> registry = FunctionRegistry(auto_scan=True) # 自动加载 api.py 函数 + >>> registry.available_functions()[:5] + ['abs', 'clip', 'cs_demean', 'cs_neutralize', 'cs_rank'] +""" + +import inspect +import typing +from typing import Any, Callable, Dict, List, Optional, Set + +from src.factors.dsl import Node, FunctionNode +from src.factors.exceptions import DuplicateFunctionError + + +class FunctionRegistry: + """函数注册表。 + + 管理字符串函数名到可调用对象的映射。 + 自动从 api.py 加载标准函数,支持用户自定义函数注册。 + + Attributes: + _functions: 函数字典,name -> callable + """ + + def __init__(self, auto_scan: bool = True) -> None: + """初始化注册表。 + + Args: + auto_scan: 是否自动扫描 api.py 模块,默认 True + """ + self._functions: Dict[str, Callable] = {} + + if auto_scan: + self._scan_api_module() + + def register( + self, name: str, func: Callable, force: bool = False + ) -> "FunctionRegistry": + """注册自定义函数。 + + Args: + name: 函数名称(字符串形式) + func: 可调用对象 + force: 是否强制覆盖已存在的函数,默认 False + + Returns: + self(支持链式调用) + + Raises: + DuplicateFunctionError: 当函数名已存在且 force=False 时 + + Example: + >>> registry = FunctionRegistry(auto_scan=False) + >>> registry.register("my_func", lambda x: x * 2) + >>> registry.get("my_func")(5) + 10 + """ + if name in self._functions and not force: + raise DuplicateFunctionError(name) + + self._functions[name] = func + return self + + def unregister(self, name: str) -> "FunctionRegistry": + """注销函数。 + + Args: + name: 要注销的函数名 + + Returns: + self(支持链式调用) + + Raises: + KeyError: 函数不存在时 + """ + if name not in self._functions: + raise KeyError(f"函数 '{name}' 不存在") + del self._functions[name] + return self + + def get(self, name: str) -> Optional[Callable]: + """获取函数。 + + Args: + name: 函数名称 + + Returns: + 函数对象,不存在返回 None + """ + return self._functions.get(name) + + def has(self, name: str) -> bool: + """检查函数是否存在。 + + Args: + name: 函数名称 + + Returns: + 是否存在 + """ + return name in self._functions + + def available_functions(self) -> List[str]: + """返回所有可用函数名列表(按字母序)。 + + Returns: + 排序后的函数名列表 + """ + return sorted(self._functions.keys()) + + def clear(self) -> "FunctionRegistry": + """清空所有注册的函数。 + + Returns: + self(支持链式调用) + """ + self._functions.clear() + return self + + def scan_module( + self, module: Any, prefix: str = "", force: bool = False + ) -> "FunctionRegistry": + """扫描指定模块,自动注册符合条件的函数。 + + 扫描规则: + 1. 模块级别的函数(排除私有函数 _*) + 2. 返回类型注解为 Node 或 FunctionNode + + Args: + module: 要扫描的模块对象 + prefix: 函数名前缀,用于避免命名冲突 + force: 是否强制覆盖已存在的函数 + + Returns: + self(支持链式调用) + + Example: + >>> import my_custom_module + >>> registry.scan_module(my_custom_module, prefix="custom_") + """ + for name, obj in inspect.getmembers(module): + # 只处理非私有函数 + if not inspect.isfunction(obj) or name.startswith("_"): + continue + + # 检查是否应该注册 + if self._should_register(obj): + full_name = prefix + name + self.register(full_name, obj, force=force) + + return self + + def _scan_api_module(self) -> None: + """自动扫描 api.py 模块,注册所有符合条件的函数。 + + 这是默认的自动扫描行为,在 __init__ 中调用。 + """ + try: + from src.factors import api + + self.scan_module(api) + except ImportError: + # api 模块可能不存在,静默跳过 + pass + + def _should_register(self, func: Callable) -> bool: + """检查函数是否应该被注册。 + + 基于类型提示检查函数返回类型,只注册返回 Node 或 FunctionNode 的函数。 + + Args: + func: 要检查的函数 + + Returns: + 是否应该注册该函数 + """ + try: + hints = typing.get_type_hints(func) + return_type = hints.get("return") + + if return_type is None: + return False + + # 处理 Union 类型(如 Union[Node, FunctionNode]) + origin = typing.get_origin(return_type) + args = typing.get_args(return_type) + + if origin is typing.Union: + # Union 类型,检查任一参数 + return any(self._is_node_type(arg) for arg in args) + else: + # 单一类型 + return self._is_node_type(return_type) + + except Exception: + return False + + def _is_node_type(self, typ: Any) -> bool: + """检查类型是否是 Node 或 FunctionNode 的子类。 + + Args: + typ: 要检查的类型 + + Returns: + 是否是 Node 相关类型 + """ + if not isinstance(typ, type): + return False + + return issubclass(typ, (Node, FunctionNode)) + + def __len__(self) -> int: + """返回已注册函数数量。""" + return len(self._functions) + + def __contains__(self, name: str) -> bool: + """检查是否包含某个函数名。""" + return name in self._functions + + def __repr__(self) -> str: + """返回注册表字符串表示。""" + return f"FunctionRegistry({len(self._functions)} functions: {self.available_functions()[:5]}...)"