From 77e4e94e05937356237afd93c59491b1199c2ee7 Mon Sep 17 00:00:00 2001 From: liaozhaorun <1300336796@qq.com> Date: Mon, 2 Mar 2026 22:29:18 +0800 Subject: [PATCH] =?UTF-8?q?refactor(factors):=20=E6=8B=86=E5=88=86=20engin?= =?UTF-8?q?e.py=20=E4=B8=BA=E6=A8=A1=E5=9D=97=E5=8C=96=E5=8C=85=20?= =?UTF-8?q?=E5=B0=86=E5=8D=95=E6=96=87=E4=BB=B6=20engine.py=20(1064?= =?UTF-8?q?=E8=A1=8C)=20=E6=8B=86=E5=88=86=E4=B8=BA=20engine/=20=E5=8C=85?= =?UTF-8?q?=EF=BC=9A=20-=20=E6=95=B0=E6=8D=AE=E8=A7=84=E6=A0=BC=E3=80=81?= =?UTF-8?q?=E8=B7=AF=E7=94=B1=E5=99=A8=E3=80=81=E8=AE=A1=E5=88=92=E5=99=A8?= =?UTF-8?q?=E3=80=81=E8=AE=A1=E7=AE=97=E5=BC=95=E6=93=8E=E3=80=81=E5=9B=A0?= =?UTF-8?q?=E5=AD=90=E5=BC=95=E6=93=8E=E5=88=86=E7=A6=BB=20-=20=E4=BF=9D?= =?UTF-8?q?=E6=8C=81=E5=90=91=E5=90=8E=E5=85=BC=E5=AE=B9=EF=BC=8CAPI=20?= =?UTF-8?q?=E6=97=A0=E5=8F=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/factors/__init__.py | 3 + src/factors/engine/__init__.py | 25 ++ src/factors/engine/compute_engine.py | 155 ++++++++++ src/factors/engine/data_router.py | 304 ++++++++++++++++++ src/factors/engine/data_spec.py | 47 +++ src/factors/engine/factor_engine.py | 442 +++++++++++++++++++++++++++ src/factors/engine/planner.py | 170 +++++++++++ 7 files changed, 1146 insertions(+) create mode 100644 src/factors/engine/__init__.py create mode 100644 src/factors/engine/compute_engine.py create mode 100644 src/factors/engine/data_router.py create mode 100644 src/factors/engine/data_spec.py create mode 100644 src/factors/engine/factor_engine.py create mode 100644 src/factors/engine/planner.py diff --git a/src/factors/__init__.py b/src/factors/__init__.py index 6d9f579..305f2b2 100644 --- a/src/factors/__init__.py +++ b/src/factors/__init__.py @@ -52,6 +52,9 @@ from src.factors.engine import ( ComputeEngine, ) +# 保持向后兼容:factor_engine.py 中的类也可以通过 src.factors.engine 访问 +# 例如:from src.factors.engine import FactorEngine + __all__ = [ # DSL 层 "Node", diff --git a/src/factors/engine/__init__.py b/src/factors/engine/__init__.py new file mode 100644 index 0000000..63cdace --- /dev/null +++ b/src/factors/engine/__init__.py @@ -0,0 +1,25 @@ +"""因子计算引擎模块。 + +提供完整的因子计算引擎组件: +- DataSpec: 数据规格定义 +- ExecutionPlan: 执行计划 +- DataRouter: 数据路由器 +- ExecutionPlanner: 执行计划生成器 +- ComputeEngine: 计算引擎 +- FactorEngine: 因子计算引擎(统一入口) +""" + +from src.factors.engine.data_spec import DataSpec, ExecutionPlan +from src.factors.engine.data_router import DataRouter +from src.factors.engine.planner import ExecutionPlanner +from src.factors.engine.compute_engine import ComputeEngine +from src.factors.engine.factor_engine import FactorEngine + +__all__ = [ + "DataSpec", + "ExecutionPlan", + "DataRouter", + "ExecutionPlanner", + "ComputeEngine", + "FactorEngine", +] diff --git a/src/factors/engine/compute_engine.py b/src/factors/engine/compute_engine.py new file mode 100644 index 0000000..8d1da46 --- /dev/null +++ b/src/factors/engine/compute_engine.py @@ -0,0 +1,155 @@ +"""计算引擎。 + +执行并行运算,负责将执行计划应用到数据上。 +""" + +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor +from typing import Any, Dict, List, Optional, Set, Union + +import polars as pl + +from src.factors.engine.data_spec import ExecutionPlan + + +class ComputeEngine: + """计算引擎 - 执行并行运算。 + + 负责将执行计划应用到数据上,支持并行计算。 + + Attributes: + max_workers: 最大并行工作线程数 + use_processes: 是否使用进程池(CPU 密集型任务) + """ + + def __init__( + self, + max_workers: int = 4, + use_processes: bool = False, + ) -> None: + """初始化计算引擎。 + + Args: + max_workers: 最大并行工作线程数 + use_processes: 是否使用进程池代替线程池 + """ + self.max_workers = max_workers + self.use_processes = use_processes + + def execute( + self, + plan: ExecutionPlan, + data: pl.DataFrame, + ) -> pl.DataFrame: + """执行计算计划。 + + Args: + plan: 执行计划 + data: 输入数据(核心宽表) + + Returns: + 包含因子结果的 DataFrame + """ + # 检查依赖字段是否存在 + missing_cols = plan.dependencies - set(data.columns) + if missing_cols: + raise ValueError(f"数据缺少必要的字段: {missing_cols}") + + # 执行计算 + result = data.with_columns([plan.polars_expr.alias(plan.output_name)]) + + return result + + def execute_batch( + self, + plans: List[ExecutionPlan], + data: pl.DataFrame, + ) -> pl.DataFrame: + """批量执行多个计算计划。 + + Args: + plans: 执行计划列表 + data: 输入数据 + + Returns: + 包含所有因子结果的 DataFrame + """ + result = data + + for plan in plans: + result = self.execute(plan, result) + + return result + + def execute_parallel( + self, + plans: List[ExecutionPlan], + data: pl.DataFrame, + ) -> pl.DataFrame: + """并行执行多个计算计划。 + + Args: + plans: 执行计划列表 + data: 输入数据 + + Returns: + 包含所有因子结果的 DataFrame + """ + # 检查计划间依赖 + independent_plans = [] + dependent_plans = [] + available_cols = set(data.columns) + + for plan in plans: + if plan.dependencies <= available_cols: + independent_plans.append(plan) + available_cols.add(plan.output_name) + else: + dependent_plans.append(plan) + + # 并行执行独立计划 + if independent_plans: + ExecutorClass = ( + ProcessPoolExecutor if self.use_processes else ThreadPoolExecutor + ) + + with ExecutorClass(max_workers=self.max_workers) as executor: + futures = { + executor.submit(self._execute_single, plan, data): plan + for plan in independent_plans + } + + results = [] + for future in futures: + plan = futures[future] + try: + result_col = future.result() + results.append((plan.output_name, result_col)) + except Exception as e: + raise RuntimeError(f"计算因子 {plan.output_name} 失败: {e}") + + # 合并结果 + for name, series in results: + data = data.with_columns([series.alias(name)]) + + # 顺序执行依赖计划 + for plan in dependent_plans: + data = self.execute(plan, data) + + return data + + def _execute_single( + self, + plan: ExecutionPlan, + data: pl.DataFrame, + ) -> pl.Series: + """执行单个计划并返回结果列。 + + Args: + plan: 执行计划 + data: 输入数据 + + Returns: + 计算结果序列 + """ + result = self.execute(plan, data) + return result[plan.output_name] diff --git a/src/factors/engine/data_router.py b/src/factors/engine/data_router.py new file mode 100644 index 0000000..6cccc29 --- /dev/null +++ b/src/factors/engine/data_router.py @@ -0,0 +1,304 @@ +"""数据路由器。 + +按需取数、组装核心宽表。 +负责根据数据规格从数据源拉取数据,并组装成统一的宽表格式。 +支持内存数据源(用于测试)和真实数据库连接。 +""" + +from typing import Any, Dict, List, Optional, Set, Union +import threading + +import polars as pl + +from src.factors.engine.data_spec import DataSpec +from src.data.storage import Storage + + +class DataRouter: + """数据路由器 - 按需取数、组装核心宽表。 + + 负责根据数据规格从数据源拉取数据,并组装成统一的宽表格式。 + 支持内存数据源(用于测试)和真实数据库连接。 + + Attributes: + data_source: 数据源,可以是内存 DataFrame 字典或数据库连接 + is_memory_mode: 是否为内存模式 + """ + + def __init__(self, data_source: Optional[Dict[str, pl.DataFrame]] = None) -> None: + """初始化数据路由器。 + + Args: + data_source: 内存数据源,字典格式 {表名: DataFrame} + 为 None 时自动连接 DuckDB 数据库 + """ + self.data_source = data_source or {} + self.is_memory_mode = data_source is not None + self._cache: Dict[str, pl.DataFrame] = {} + self._lock = threading.Lock() + + # 数据库模式下初始化 Storage + if not self.is_memory_mode: + self._storage = Storage() + else: + self._storage = None + + def fetch_data( + self, + data_specs: List[DataSpec], + start_date: str, + end_date: str, + stock_codes: Optional[List[str]] = None, + ) -> pl.DataFrame: + """根据数据规格获取并组装核心宽表。 + + Args: + data_specs: 数据规格列表 + start_date: 开始日期 (YYYYMMDD) + end_date: 结束日期 (YYYYMMDD) + stock_codes: 股票代码列表,None 表示全市场 + + Returns: + 组装好的核心宽表 DataFrame + + Raises: + ValueError: 当数据源中缺少必要的表或字段时 + """ + if not data_specs: + raise ValueError("数据规格不能为空") + + # 收集所有需要的表和字段 + required_tables: Dict[str, Set[str]] = {} + max_lookback = 0 + + for spec in data_specs: + if spec.table not in required_tables: + required_tables[spec.table] = set() + required_tables[spec.table].update(spec.columns) + max_lookback = max(max_lookback, spec.lookback_days) + + # 调整日期范围以包含回看期 + adjusted_start = self._adjust_start_date(start_date, max_lookback) + + # 从数据源获取各表数据 + table_data = {} + for table_name, columns in required_tables.items(): + df = self._load_table( + table_name=table_name, + columns=list(columns), + start_date=adjusted_start, + end_date=end_date, + stock_codes=stock_codes, + ) + table_data[table_name] = df + + # 组装核心宽表 + core_table = self._assemble_wide_table(table_data, required_tables) + + # 过滤到实际请求日期范围 + core_table = core_table.filter( + (pl.col("trade_date") >= start_date) & (pl.col("trade_date") <= end_date) + ) + + return core_table + + def _load_table( + self, + table_name: str, + columns: List[str], + start_date: str, + end_date: str, + stock_codes: Optional[List[str]] = None, + ) -> pl.DataFrame: + """加载单个表的数据。 + + Args: + table_name: 表名 + columns: 需要的字段 + start_date: 开始日期 + end_date: 结束日期 + stock_codes: 股票代码过滤 + + Returns: + 过滤后的 DataFrame + """ + cache_key = f"{table_name}_{start_date}_{end_date}_{stock_codes}" + + with self._lock: + if cache_key in self._cache: + return self._cache[cache_key] + + if self.is_memory_mode: + df = self._load_from_memory( + table_name, columns, start_date, end_date, stock_codes + ) + else: + df = self._load_from_database( + table_name, columns, start_date, end_date, stock_codes + ) + + with self._lock: + self._cache[cache_key] = df + + return df + + def _load_from_memory( + self, + table_name: str, + columns: List[str], + start_date: str, + end_date: str, + stock_codes: Optional[List[str]] = None, + ) -> pl.DataFrame: + """从内存数据源加载数据。""" + if table_name not in self.data_source: + raise ValueError(f"内存数据源中缺少表: {table_name}") + + df = self.data_source[table_name] + + # 确保必需字段存在 + for col in columns: + if col not in df.columns and col not in ["ts_code", "trade_date"]: + raise ValueError(f"表 {table_name} 缺少字段: {col}") + + # 过滤日期和股票 + df = df.filter( + (pl.col("trade_date") >= start_date) & (pl.col("trade_date") <= end_date) + ) + + if stock_codes is not None: + df = df.filter(pl.col("ts_code").is_in(stock_codes)) + + # 选择需要的列 + select_cols = ["ts_code", "trade_date"] + [ + c for c in columns if c in df.columns + ] + return df.select(select_cols) + + def _load_from_database( + self, + table_name: str, + columns: List[str], + start_date: str, + end_date: str, + stock_codes: Optional[List[str]] = None, + ) -> pl.DataFrame: + """从 DuckDB 数据库加载数据。 + + 利用 Storage.load_polars() 方法,支持 SQL 查询下推。 + """ + if self._storage is None: + raise RuntimeError("Storage 未初始化") + + # 检查表是否存在 + if not self._storage.exists(table_name): + raise ValueError(f"数据库中不存在表: {table_name}") + + # 构建查询参数 + # Storage.load_polars 目前只支持单个 ts_code,需要处理列表情况 + if stock_codes is not None and len(stock_codes) == 1: + ts_code_filter = stock_codes[0] + else: + ts_code_filter = None + + try: + # 从数据库加载原始数据 + df = self._storage.load_polars( + name=table_name, + start_date=start_date, + end_date=end_date, + ts_code=ts_code_filter, + ) + except Exception as e: + raise RuntimeError(f"从数据库加载表 {table_name} 失败: {e}") + + # 如果 stock_codes 是列表且长度 > 1,在内存中过滤 + if stock_codes is not None and len(stock_codes) > 1: + df = df.filter(pl.col("ts_code").is_in(stock_codes)) + + # 检查必需字段 + for col in columns: + if col not in df.columns and col not in ["ts_code", "trade_date"]: + raise ValueError(f"表 {table_name} 缺少字段: {col}") + + # 选择需要的列 + select_cols = ["ts_code", "trade_date"] + [ + c for c in columns if c in df.columns + ] + + return df.select(select_cols) + + def _assemble_wide_table( + self, + table_data: Dict[str, pl.DataFrame], + required_tables: Dict[str, Set[str]], + ) -> pl.DataFrame: + """组装多表数据为核心宽表。 + + 使用 left join 合并各表数据,以第一个表为基准。 + + Args: + table_data: 表名到 DataFrame 的映射 + required_tables: 表名到字段集合的映射 + + Returns: + 组装后的宽表 + """ + if not table_data: + raise ValueError("没有数据可组装") + + # 以第一个表为基准 + base_table_name = list(table_data.keys())[0] + result = table_data[base_table_name] + + # 与其他表 join + for table_name, df in table_data.items(): + if table_name == base_table_name: + continue + + # 使用 ts_code 和 trade_date 作为 join 键 + result = result.join( + df, + on=["ts_code", "trade_date"], + how="left", + ) + + return result + + def _adjust_start_date(self, start_date: str, lookback_days: int) -> str: + """根据回看天数调整开始日期。 + + Args: + start_date: 原始开始日期 (YYYYMMDD) + lookback_days: 需要回看的交易日数 + + Returns: + 调整后的开始日期 + """ + # 简化的日期调整:假设每月30天,向前推移 + # 实际应用中应该使用交易日历 + year = int(start_date[:4]) + month = int(start_date[4:6]) + day = int(start_date[6:8]) + + total_days = lookback_days + 30 # 额外缓冲 + + day -= total_days + while day <= 0: + month -= 1 + if month <= 0: + month = 12 + year -= 1 + day += 30 + + return f"{year:04d}{month:02d}{day:02d}" + + def clear_cache(self) -> None: + """清除数据缓存。""" + with self._lock: + self._cache.clear() + + # 数据库模式下清理 Storage 连接(可选) + if not self.is_memory_mode and self._storage is not None: + # Storage 使用单例模式,不需要关闭连接 + pass diff --git a/src/factors/engine/data_spec.py b/src/factors/engine/data_spec.py new file mode 100644 index 0000000..8a7e81e --- /dev/null +++ b/src/factors/engine/data_spec.py @@ -0,0 +1,47 @@ +"""数据规格和执行计划定义。 + +定义因子计算所需的数据规格和执行计划结构。 +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Set, Union + +import polars as pl + + +@dataclass +class DataSpec: + """数据规格定义。 + + 描述因子计算所需的数据表和字段。 + + Attributes: + table: 数据表名称 + columns: 需要的字段列表 + lookback_days: 回看天数(用于时序计算) + """ + + table: str + columns: List[str] + lookback_days: int = 1 + + +@dataclass +class ExecutionPlan: + """执行计划。 + + 包含完整的执行所需信息:数据源、转换逻辑、输出格式。 + + Attributes: + data_specs: 数据规格列表 + polars_expr: Polars 表达式 + dependencies: 依赖的原始字段 + output_name: 输出因子名称 + factor_dependencies: 依赖的其他因子名称(用于分步执行) + """ + + data_specs: List[DataSpec] + polars_expr: pl.Expr + dependencies: Set[str] + output_name: str + factor_dependencies: Set[str] = field(default_factory=set) diff --git a/src/factors/engine/factor_engine.py b/src/factors/engine/factor_engine.py new file mode 100644 index 0000000..0c1cd3f --- /dev/null +++ b/src/factors/engine/factor_engine.py @@ -0,0 +1,442 @@ +"""因子计算引擎 - 系统统一入口。 + +提供从表达式到结果的完整执行链路,是研究员使用系统的唯一接口。 + +执行流程: + 1. 注册表达式 -> 调用编译器解析依赖 + 2. 调用路由器连接数据库拉取并组装核心宽表 + 3. 调用翻译器生成物理执行计划 + 4. 将计划提交给计算引擎执行并行运算 + 5. 返回包含因子结果的数据表 +""" + +from typing import Any, Dict, List, Optional, Set, Union + +import polars as pl + +from src.factors.dsl import ( + Node, + Symbol, + BinaryOpNode, + UnaryOpNode, + FunctionNode, +) +from src.factors.translator import PolarsTranslator +from src.factors.engine.data_spec import DataSpec, ExecutionPlan +from src.factors.engine.data_router import DataRouter +from src.factors.engine.planner import ExecutionPlanner +from src.factors.engine.compute_engine import ComputeEngine + + +class FactorEngine: + """因子计算引擎 - 系统统一入口。 + + 提供从表达式到结果的完整执行链路,是研究员使用系统的唯一接口。 + + 执行流程: + 1. 注册表达式 -> 调用编译器解析依赖 + 2. 调用路由器连接数据库拉取并组装核心宽表 + 3. 调用翻译器生成物理执行计划 + 4. 将计划提交给计算引擎执行并行运算 + 5. 返回包含因子结果的数据表 + + Attributes: + router: 数据路由器 + planner: 执行计划生成器 + compute_engine: 计算引擎 + registered_expressions: 注册的表达式字典 + """ + + def __init__( + self, + data_source: Optional[Dict[str, pl.DataFrame]] = None, + max_workers: int = 4, + ) -> None: + """初始化因子引擎。 + + Args: + data_source: 内存数据源,为 None 时使用数据库连接 + max_workers: 并行计算的最大工作线程数 + """ + self.router = DataRouter(data_source) + self.planner = ExecutionPlanner() + self.compute_engine = ComputeEngine(max_workers=max_workers) + self.registered_expressions: Dict[str, Node] = {} + self._plans: Dict[str, ExecutionPlan] = {} + + def register( + self, + name: str, + expression: Node, + data_specs: Optional[List[DataSpec]] = None, + ) -> "FactorEngine": + """注册因子表达式。 + + Args: + name: 因子名称 + expression: DSL 表达式 + data_specs: 数据规格,None 时自动推导 + + Returns: + self,支持链式调用 + + Example: + >>> from src.factors.api import close, ts_mean + >>> engine = FactorEngine() + >>> engine.register("ma20", ts_mean(close, 20)) + """ + # 检测因子依赖(在注册当前因子之前检查其他已注册因子) + factor_deps = self._find_factor_dependencies(expression) + + self.registered_expressions[name] = expression + + # 预创建执行计划 + plan = self.planner.create_plan( + expression=expression, + output_name=name, + data_specs=data_specs, + ) + + # 添加因子依赖信息 + plan.factor_dependencies = factor_deps + + self._plans[name] = plan + + return self + + def compute( + self, + factor_names: Union[str, List[str]], + start_date: str, + end_date: str, + stock_codes: Optional[List[str]] = None, + ) -> pl.DataFrame: + """计算指定因子的值。 + + 完整的执行流程:取数 -> 组装 -> 翻译 -> 计算。 + + Args: + factor_names: 因子名称或名称列表 + start_date: 开始日期 (YYYYMMDD) + end_date: 结束日期 (YYYYMMDD) + stock_codes: 股票代码列表,None 表示全市场 + + Returns: + 包含因子结果的数据表 + + Raises: + ValueError: 当因子未注册或数据不足时 + + Example: + >>> result = engine.compute("ma20", "20240101", "20240131") + >>> result = engine.compute(["ma20", "rsi"], "20240101", "20240131") + """ + # 标准化因子名称 + if isinstance(factor_names, str): + factor_names = [factor_names] + + # 1. 获取执行计划 + plans = [] + for name in factor_names: + if name not in self._plans: + raise ValueError(f"因子未注册: {name}") + plans.append(self._plans[name]) + + # 2. 合并数据规格并获取数据 + all_specs = [] + for plan in plans: + all_specs.extend(plan.data_specs) + + # 3. 从路由器获取核心宽表 + core_data = self.router.fetch_data( + data_specs=all_specs, + start_date=start_date, + end_date=end_date, + stock_codes=stock_codes, + ) + + if len(core_data) == 0: + raise ValueError("未获取到任何数据,请检查日期范围和股票代码") + + # 4. 按依赖顺序执行计算 + if len(plans) == 1: + result = self.compute_engine.execute(plans[0], core_data) + else: + # 使用依赖感知的方式执行 + result = self._execute_with_dependencies(factor_names, core_data) + + return result + + def list_registered(self) -> List[str]: + """获取已注册的因子列表。 + + Returns: + 因子名称列表 + """ + return list(self.registered_expressions.keys()) + + def get_expression(self, name: str) -> Optional[Node]: + """获取已注册的表达式。 + + Args: + name: 因子名称 + + Returns: + 表达式节点,未注册时返回 None + """ + return self.registered_expressions.get(name) + + def clear(self) -> None: + """清除所有注册的表达式和缓存。""" + self.registered_expressions.clear() + self._plans.clear() + self.router.clear_cache() + + def preview_plan(self, factor_name: str) -> Optional[ExecutionPlan]: + """预览因子的执行计划。 + + Args: + factor_name: 因子名称 + + Returns: + 执行计划,未注册时返回 None + """ + return self._plans.get(factor_name) + + def _execute_with_dependencies( + self, + factor_names: List[str], + core_data: pl.DataFrame, + ) -> pl.DataFrame: + """按依赖顺序执行因子计算。 + + 支持 cs_rank 等需要依赖列已存在的场景。 + + Args: + factor_names: 因子名称列表 + core_data: 核心宽表数据 + + Returns: + 包含所有因子结果的数据表 + """ + # 1. 拓扑排序 + sorted_names = self._topological_sort(factor_names) + + # 2. 按顺序执行 + result = core_data + for name in sorted_names: + plan = self._plans[name] + + # 创建新的执行计划,引用已计算的依赖列 + new_plan = self._create_optimized_plan(plan, result) + + # 执行计算 + result = self.compute_engine.execute(new_plan, result) + + return result + + def _create_optimized_plan( + self, + plan: ExecutionPlan, + current_data: pl.DataFrame, + ) -> ExecutionPlan: + """创建优化的执行计划。 + + 将表达式中已计算的依赖因子替换为列引用。 + + Args: + plan: 原始执行计划 + current_data: 当前数据(包含已计算的依赖列) + + Returns: + 新的执行计划 + """ + from src.factors.dsl import Symbol + + # 获取当前数据中已存在的列 + existing_cols = set(current_data.columns) + + # 检查依赖列是否已存在 + deps_available = plan.factor_dependencies & existing_cols + + if not deps_available: + # 没有可用的依赖列,直接返回原计划 + return plan + + # 获取原始表达式 + original_expr = self.registered_expressions[plan.output_name] + + # 创建新的表达式,用 Symbol 引用替换依赖因子 + def replace_with_symbol(node: Node) -> Node: + """递归替换表达式中的依赖因子为 Symbol 引用。""" + from typing import Any + + n: Any = node + + # 检查当前节点是否等于某个已计算依赖因子 + for dep_name in deps_available: + dep_expr = self.registered_expressions[dep_name] + if self._expressions_equal(node, dep_expr): + return Symbol(dep_name) + + # 递归处理子节点 + if isinstance(n, BinaryOpNode): + new_left = replace_with_symbol(n.left) + new_right = replace_with_symbol(n.right) + if new_left is not n.left or new_right is not n.right: + return BinaryOpNode(n.op, new_left, new_right) + elif isinstance(n, UnaryOpNode): + new_operand = replace_with_symbol(n.operand) + if new_operand is not n.operand: + return UnaryOpNode(n.op, new_operand) + elif isinstance(n, FunctionNode): + new_args = [replace_with_symbol(arg) for arg in n.args] + if any( + new_arg is not old_arg for new_arg, old_arg in zip(new_args, n.args) + ): + return FunctionNode(n.func_name, *new_args) + + return node + + # 替换表达式 + new_expr = replace_with_symbol(original_expr) + + # 重新翻译表达式 + translator = PolarsTranslator() + new_polars_expr = translator.translate(new_expr) + + # 更新依赖集合 + new_factor_deps = plan.factor_dependencies - deps_available + new_deps = plan.dependencies | deps_available + + return ExecutionPlan( + data_specs=plan.data_specs, + polars_expr=new_polars_expr, + dependencies=new_deps, + output_name=plan.output_name, + factor_dependencies=new_factor_deps, + ) + + def _expressions_equal(self, expr1: Node, expr2: Node) -> bool: + """比较两个表达式是否相等。 + + 用于检测因子间的依赖关系。 + + Args: + expr1: 第一个表达式 + expr2: 第二个表达式 + + Returns: + 是否相等 + """ + from typing import Any + + e1: Any = expr1 + e2: Any = expr2 + + if type(e1) != type(e2): + return False + + if isinstance(e1, Symbol): + return e1.name == e2.name + + from src.factors.dsl import Constant + + if isinstance(e1, Constant): + return e1.value == e2.value + + if isinstance(e1, BinaryOpNode): + return ( + e1.op == e2.op + and self._expressions_equal(e1.left, e2.left) + and self._expressions_equal(e1.right, e2.right) + ) + + if isinstance(e1, UnaryOpNode): + return e1.op == e2.op and self._expressions_equal(e1.operand, e2.operand) + + if isinstance(e1, FunctionNode): + if e1.func_name != e2.func_name or len(e1.args) != len(e2.args): + return False + return all( + self._expressions_equal(a1, a2) for a1, a2 in zip(e1.args, e2.args) + ) + + return False + + def _find_factor_dependencies(self, expression: Node) -> Set[str]: + """查找表达式依赖的其他因子。 + + 遍历已注册因子,检查表达式是否包含任何已注册因子的完整表达式。 + + Args: + expression: 待检查的表达式 + + Returns: + 依赖的因子名称集合 + """ + deps: Set[str] = set() + + # 检查表达式本身是否等于某个已注册因子 + for name, registered_expr in self.registered_expressions.items(): + if self._expressions_equal(expression, registered_expr): + deps.add(name) + break + + # 递归检查子节点 + if isinstance(expression, BinaryOpNode): + deps.update(self._find_factor_dependencies(expression.left)) + deps.update(self._find_factor_dependencies(expression.right)) + elif isinstance(expression, UnaryOpNode): + deps.update(self._find_factor_dependencies(expression.operand)) + elif isinstance(expression, FunctionNode): + for arg in expression.args: + deps.update(self._find_factor_dependencies(arg)) + + return deps + + def _topological_sort(self, factor_names: List[str]) -> List[str]: + """按依赖关系对因子进行拓扑排序。 + + 确保依赖的因子先被计算。 + + Args: + factor_names: 因子名称列表 + + Returns: + 排序后的因子名称列表 + + Raises: + ValueError: 当检测到循环依赖时 + """ + # 构建依赖图 + graph: Dict[str, Set[str]] = {} + in_degree: Dict[str, int] = {} + + for name in factor_names: + plan = self._plans[name] + # 只考虑在本次计算范围内的依赖 + deps = plan.factor_dependencies & set(factor_names) + graph[name] = deps + in_degree[name] = len(deps) + + # Kahn 算法 + result = [] + queue = [name for name, degree in in_degree.items() if degree == 0] + + while queue: + # 按原始顺序处理同级别的因子 + queue.sort(key=lambda x: factor_names.index(x)) + name = queue.pop(0) + result.append(name) + + for other in factor_names: + if name in graph[other]: + in_degree[other] -= 1 + if in_degree[other] == 0: + queue.append(other) + + if len(result) != len(factor_names): + raise ValueError("检测到因子循环依赖") + + return result diff --git a/src/factors/engine/planner.py b/src/factors/engine/planner.py new file mode 100644 index 0000000..aece8bc --- /dev/null +++ b/src/factors/engine/planner.py @@ -0,0 +1,170 @@ +"""执行计划生成器。 + +整合编译器和翻译器,生成完整的执行计划。 +""" + +from typing import Any, Dict, List, Optional, Set, Union + +from src.factors.dsl import ( + Node, + Symbol, + FunctionNode, + BinaryOpNode, + UnaryOpNode, + Constant, +) +from src.factors.compiler import DependencyExtractor +from src.factors.translator import PolarsTranslator +from src.factors.engine.data_spec import DataSpec, ExecutionPlan + + +class ExecutionPlanner: + """执行计划生成器。 + + 整合编译器和翻译器,生成完整的执行计划。 + + Attributes: + compiler: 依赖提取器 + translator: Polars 翻译器 + """ + + def __init__(self) -> None: + """初始化执行计划生成器。""" + self.compiler = DependencyExtractor() + self.translator = PolarsTranslator() + + def create_plan( + self, + expression: Node, + output_name: str = "factor", + data_specs: Optional[List[DataSpec]] = None, + ) -> ExecutionPlan: + """从表达式创建执行计划。 + + Args: + expression: DSL 表达式节点 + output_name: 输出因子名称 + data_specs: 预定义的数据规格,None 时自动推导 + + Returns: + 执行计划对象 + """ + # 1. 提取依赖 + dependencies = self.compiler.extract_dependencies(expression) + + # 2. 翻译为 Polars 表达式 + polars_expr = self.translator.translate(expression) + + # 3. 推导或验证数据规格 + if data_specs is None: + data_specs = self._infer_data_specs(dependencies, expression) + + return ExecutionPlan( + data_specs=data_specs, + polars_expr=polars_expr, + dependencies=dependencies, + output_name=output_name, + ) + + def _infer_data_specs( + self, + dependencies: Set[str], + expression: Node, + ) -> List[DataSpec]: + """从依赖推导数据规格。 + + 根据表达式中的函数类型推断回看天数需求。 + 基础行情字段(open, high, low, close, vol, amount, pre_close, change, pct_chg) + 默认从 pro_bar 表获取。 + + Args: + dependencies: 依赖的字段集合 + expression: 表达式节点 + + Returns: + 数据规格列表 + """ + # 计算最大回看窗口 + max_window = self._extract_max_window(expression) + lookback_days = max(1, max_window) + + # 基础行情字段集合(这些字段从 pro_bar 表获取) + pro_bar_fields = { + "open", + "high", + "low", + "close", + "vol", + "amount", + "pre_close", + "change", + "pct_chg", + "turnover_rate", + "volume_ratio", + } + + # 将依赖分为 pro_bar 字段和其他字段 + pro_bar_deps = dependencies & pro_bar_fields + other_deps = dependencies - pro_bar_fields + + data_specs = [] + + # pro_bar 表的数据规格 + if pro_bar_deps: + data_specs.append( + DataSpec( + table="pro_bar", + columns=sorted(pro_bar_deps), + lookback_days=lookback_days, + ) + ) + + # 其他字段从 daily 表获取 + if other_deps: + data_specs.append( + DataSpec( + table="daily", + columns=sorted(other_deps), + lookback_days=lookback_days, + ) + ) + + return data_specs + + def _extract_max_window(self, node: Node) -> int: + """从表达式中提取最大窗口大小。 + + Args: + node: AST 节点 + + Returns: + 最大窗口大小,无时序函数返回 1 + """ + if isinstance(node, FunctionNode): + window = 1 + # 检查函数参数中的窗口大小 + for arg in node.args: + if ( + isinstance(arg, Constant) + and isinstance(arg.value, int) + and arg.value > window + ): + window = arg.value + + # 递归检查子表达式 + for arg in node.args: + if isinstance(arg, Node) and not isinstance(arg, Constant): + window = max(window, self._extract_max_window(arg)) + + return window + + elif isinstance(node, BinaryOpNode): + return max( + self._extract_max_window(node.left), + self._extract_max_window(node.right), + ) + + elif isinstance(node, UnaryOpNode): + return self._extract_max_window(node.operand) + + return 1