# 因子计算流程详解 本文档详细描述 ProStock 项目中因子计算引擎的完整数据流,以因子表达式 `(close / ts_delay(close, 5)) - 1` 为例,说明从字符串解析到最终计算结果的完整流程。 ## 目录 1. [整体架构概览](#1-整体架构概览) 2. [阶段一:字符串解析](#2-阶段一字符串解析) 3. [阶段二:AST 编译与依赖提取](#3-阶段二ast-编译与依赖提取) 4. [阶段三:执行计划生成](#4-阶段三执行计划生成) 5. [阶段四:数据获取](#5-阶段四数据获取) 6. [阶段五:Polars 翻译与计算](#6-阶段五polars-翻译与计算) 7. [完整调用链示例](#7-完整调用链示例) 8. [数据流时序图](#8-数据流时序图) --- ## 1. 整体架构概览 ### 1.1 架构层次 因子框架采用分层设计,从上到下依次为: ``` API 层 (api.py) | v DSL 层 (dsl.py) ← 因子表达式 (Node) | v Parser (parser.py) ← 字符串公式解析 | v Registry (registry.py) ← 函数注册表 | v Compiler (compiler.py) ← AST 依赖提取 | v Planner (planner.py) ← 执行计划生成 | v Translator (translator.py) ← 翻译为 Polars 表达式 | v Engine (engine/) ← 执行引擎 | - FactorEngine: 统一入口 | - DataRouter: 数据路由 | - ExecutionPlanner: 执行计划 | - ComputeEngine: 计算引擎 | v 数据层 (data/storage.py) ← DuckDB 数据获取和存储 ``` ### 1.2 核心组件职责 | 组件 | 文件路径 | 主要职责 | |------|----------|----------| | **FormulaParser** | `factors/parser.py` | 将字符串表达式解析为 DSL AST 节点树 | | **FunctionRegistry** | `factors/registry.py` | 管理函数名到 Python 实现的映射 | | **DSL Nodes** | `factors/dsl.py` | 定义表达式节点(Symbol, FunctionNode, BinaryOpNode 等)| | **DependencyExtractor** | `factors/compiler.py` | 从 AST 提取依赖的数据字段 | | **ExecutionPlanner** | `factors/engine/planner.py` | 整合编译器和翻译器生成执行计划 | | **DataRouter** | `factors/engine/data_router.py` | 按需取数、组装核心宽表 | | **PolarsTranslator** | `factors/translator.py` | 将 DSL AST 翻译为 Polars 表达式 | | **ComputeEngine** | `factors/engine/compute_engine.py` | 执行 Polars 表达式计算 | | **FactorEngine** | `factors/engine/factor_engine.py` | 系统统一入口,协调各组件 | | **Storage** | `data/storage.py` | DuckDB 数据存储和查询接口 | --- ## 2. 阶段一:字符串解析 ### 2.1 解析流程 当用户调用 `parser.parse("(close / ts_delay(close, 5)) - 1")` 时,解析流程如下: ```python # 1. FormulaParser.parse() 方法 formula = "(close / ts_delay(close, 5)) - 1" ast_tree = ast.parse(formula, mode='eval') # Python AST dsl_node = self._visit(ast_tree.body) # 递归转换为 DSL Node ``` **解析步骤:** 1. **Python AST 解析**:使用 Python 标准库 `ast.parse()` 将字符串解析为 Python AST 2. **AST 遍历转换**:通过 `_visit()` 方法递归遍历 Python AST,映射为 DSL 节点 3. **节点类型映射**: - `ast.Name` → `Symbol` - `ast.Constant` → `Constant` - `ast.BinOp` → `BinaryOpNode` - `ast.UnaryOp` → `UnaryOpNode` - `ast.Call` → `FunctionNode` ### 2.2 解析结果 AST 结构 ``` BinaryOpNode(op='-', left, right=Constant(1)) ├── left: BinaryOpNode(op='/', left, right) │ ├── left: Symbol('close') │ └── right: FunctionNode('ts_delay', [Symbol('close'), Constant(5)]) └── right: Constant(1) ``` ### 2.3 函数解析机制 对于函数调用 `ts_delay(close, 5)`: ```python # _visit_Call 方法处理逻辑 def _visit_Call(self, node: ast.Call) -> FunctionNode: func_name = node.func.id # "ts_delay" # 从注册表获取函数实现 if self.registry.has(func_name): func_impl = self.registry.get(func_name) # 递归解析参数 args = [self._visit(arg) for arg in node.args] # 调用函数实现,返回 FunctionNode return func_impl(*args) ``` **关键点**:`ts_delay` 函数在 `api.py` 中定义: ```python def ts_delay(x: NodeOrStr, periods: int) -> FunctionNode: return FunctionNode('ts_delay', to_node(x), Constant(periods)) ``` --- ## 3. 阶段二:AST 编译与依赖提取 ### 3.1 依赖提取流程 解析后的 AST 需要提取依赖的原始数据字段: ```python # DependencyExtractor.extract_dependencies(node) extractor = DependencyExtractor() dependencies = extractor.extract_dependencies(dsl_node) # 结果: {'close'} ``` **提取逻辑**: ```python def _visit(self, node): if isinstance(node, Symbol): return {node.name} # 收集字段名 elif isinstance(node, (BinaryOpNode, FunctionNode)): # 递归收集子节点依赖 deps = set() for child in node.args: deps.update(self._visit(child)) return deps # ... 其他节点类型 ``` ### 3.2 依赖的作用 提取的依赖 `{close}` 用于: 1. **数据规格推导**:确定需要从数据库读取哪些字段 2. **执行计划生成**:明确数据需求,避免读取不必要的字段 --- ## 4. 阶段三:执行计划生成 ### 4.1 ExecutionPlanner 的作用 `ExecutionPlanner.create_plan()` 将 AST 转换为可执行的 `ExecutionPlan`: ```python planner = ExecutionPlanner() plan = planner.create_plan( node=dsl_node, # 解析后的 DSL 节点 output_name="returns_5d" # 输出列名 ) ``` ### 4.2 计划生成流程 ```python def create_plan(self, node, output_name): # 1. 提取依赖 dependencies = self.dependency_extractor.extract_dependencies(node) # 2. 推导数据规格 data_specs = self._infer_data_specs(node, dependencies) # 结果: [DataSpec(table='pro_bar', columns=['close'])] # 3. 翻译为 Polars 表达式 polars_expr = self.polars_translator.translate(node) # 结果: (pl.col('close') / pl.col('close').shift(5)) - 1 # 4. 构建执行计划 return ExecutionPlan( data_specs=data_specs, polars_expr=polars_expr, dependencies=dependencies, output_name=output_name ) ``` ### 4.3 数据规格推导 根据依赖字段,`_infer_data_specs` 推导出需要的数据规格: ```python def _infer_data_specs(self, node, dependencies): return [ DataSpec( table='pro_bar', # 默认使用 pro_bar 表 columns=list(dependencies), # ['close'] ) ] ``` **DataSpec 说明**: - `table`: 数据表名(pro_bar 为主力行情表) - `columns`: 需要的字段列表 **注意**:数据获取使用用户传入的日期范围,不做自动扩展。时序因子(如 `ts_delay`、`ts_mean`)在数据不足时会返回 null,这是符合预期的行为。 --- ## 5. 阶段四:数据获取 ### 5.1 DataRouter 的核心职责 `DataRouter.fetch_data()` 按需取数、组装核心宽表: ```python data_router = DataRouter(storage, start_date, end_date) wide_data = data_router.fetch_data([plan.data_specs]) ``` ### 5.2 数据获取流程 ```python def fetch_data(self, data_specs, start_date, end_date): # 1. 合并数据规格,收集所需表和字段 required_tables = self._collect_required_tables(data_specs) # 2. 加载各表数据 table_data = {} for table, columns in required_tables.items(): table_data[table] = self._load_table(table, columns, start_date, end_date) # 3. 组装宽表(left join 合并) wide_table = self._assemble_wide_table(table_data, data_specs) return wide_table ``` ### 5.3 数据库查询 `_load_table` 方法从 DuckDB 读取数据: ```python def _load_table(self, table, columns, start_date, end_date): # 通过 Storage 查询数据库 df = self.storage.load_polars( table_name=table, columns=columns + ['ts_code', 'trade_date'], # 必须包含主键 start_date=start_date, end_date=end_date ) return df ``` **Storage.load_polars 内部实现**: ```python # data/storage.py SELECT {columns} FROM {table_name} WHERE trade_date BETWEEN '{start}' AND '{end}' ORDER BY ts_code, trade_date ``` ### 5.4 宽表组装 对于 `(close / ts_delay(close, 5)) - 1`,DataRouter 返回的宽表结构: ``` ┌──────────┬────────────┬───────┐ │ ts_code │ trade_date │ close │ ├──────────┼────────────┼───────┤ │ 000001.SZ│ 20240101 │ 10.5 │ │ 000001.SZ│ 20240102 │ 10.6 │ │ ... │ ... │ ... │ │ 000002.SZ│ 20240101 │ 20.1 │ └──────────┴────────────┴───────┘ ``` **注意**:数据获取使用用户传入的日期范围,不做自动扩展。对于时序因子(如 `ts_delay(close, 5)`),如果数据不足会返回 null,这是符合预期的行为。用户如需完整计算,应显式扩展日期范围。 --- ## 6. 阶段五:Polars 翻译与计算 ### 6.1 PolarsTranslator 的作用 将 DSL AST 翻译为 Polars 表达式(惰性计算图): ```python translator = PolarsTranslator() polars_expr = translator.translate(dsl_node) ``` ### 6.2 翻译规则 | DSL 节点类型 | Polars 表达式 | 说明 | |-------------|--------------|------| | `Symbol('close')` | `pl.col('close')` | 列引用 | | `Constant(5)` | `pl.lit(5)` | 字面量 | | `BinaryOpNode('/', a, b)` | `a / b` | 算术运算 | | `FunctionNode('ts_delay', x, n)` | `x.shift(n).over('ts_code')` | 时间序列滞后 | | `FunctionNode('ts_mean', x, n)` | `x.rolling_mean(n).over('ts_code')` | 时间序列均值 | | `FunctionNode('cs_rank', x)` | `x.rank().over('trade_date')` | 截面排名 | ### 6.3 时间序列函数翻译 `ts_delay(close, 5)` 翻译为: ```python pl.col('close').shift(5).over('ts_code') ``` **关键点**: - `.shift(5)`:向后偏移 5 个位置 - `.over('ts_code')`:按股票代码分组计算(每只股票独立计算) ### 6.4 计算执行 `ComputeEngine.execute()` 执行计算: ```python compute_engine = ComputeEngine() result = compute_engine.execute(plan, wide_data) ``` **执行逻辑**: ```python def execute(self, plan, data): # 使用 Polars with_columns 添加因子列 result = data.with_columns([ plan.polars_expr.alias(plan.output_name) ]) return result ``` ### 6.5 计算结果 最终结果包含原始列和计算出的因子列: ``` ┌──────────┬────────────┬───────┬─────────────┐ │ ts_code │ trade_date │ close │ returns_5d │ ├──────────┼────────────┼───────┼─────────────┤ │ 000001.SZ│ 20240101 │ 10.5 │ null │ # 前5天无数据 │ 000001.SZ│ 20240106 │ 10.8 │ 0.0286 │ # (10.8/10.5)-1 │ ... │ ... │ ... │ ... │ └──────────┴────────────┴───────┴─────────────┘ ``` --- ## 7. 完整调用链示例 ### 7.1 用户代码 ```python from src.factors import FactorEngine # 1. 创建引擎 engine = FactorEngine() # 2. 使用字符串表达式注册因子(推荐) engine.add_factor("returns_5d", "(close / ts_delay(close, 5)) - 1") # 或者使用 DSL 表达式 from src.factors.api import close, ts_delay engine.register("returns_5d", (close / ts_delay(close, 5)) - 1) # 3. 执行计算 result = engine.compute( factor_names=["returns_5d"], start_date="20240101", end_date="20240131" ) ``` ### 7.2 内部调用链 ``` FactorEngine.add_factor() / register() │ └── 创建并缓存 ExecutionPlan └── ExecutionPlanner.create_plan() ├── DependencyExtractor.extract_dependencies() → {'close'} ├── _infer_data_specs() → [DataSpec('pro_bar', ['close'], 5)] └── PolarsTranslator.translate() → pl.col('close').shift(5).over('ts_code')... FactorEngine.compute() │ ├── 1. 获取所有缓存的执行计划 ├── 2. 合并数据规格 │ └── _merge_data_specs() ├── 3. 获取数据 │ └── DataRouter.fetch_data(merged_specs) │ ├── _load_table('pro_bar', ['close'], start_date, end_date) │ │ └── Storage.load_polars() → 查询 DuckDB │ └── _assemble_wide_table() → Polars DataFrame └── 4. 执行计算 └── ComputeEngine.execute_plans(plans, data) └── data.with_columns([polars_exprs...]) └── Polars 执行表达式计算 ``` --- ## 8. 数据流时序图 ```mermaid sequenceDiagram participant User as 用户代码 participant FE as FactorEngine participant PL as ExecutionPlanner participant DR as DataRouter participant ST as Storage participant CE as ComputeEngine participant DB as DuckDB User->>FE: compute("(close/ts_delay(close,5))-1", start, end) Note over FE: 阶段1:创建执行计划 FE->>PL: create_plan(node, output_name) PL->>PL: extract_dependencies(node) PL->>PL: infer_data_specs(node) PL->>PL: translate_to_polars(node) PL-->>FE: ExecutionPlan Note over FE: 阶段2:数据获取 FE->>DR: fetch_data(data_specs) loop 每个需要的表 DR->>ST: load_polars(table, columns, adjusted_start, end) ST->>DB: SELECT ... WHERE trade_date BETWEEN ... DB-->>ST: 原始数据 ST-->>DR: Polars DataFrame end DR->>DR: assemble_wide_table() DR->>DR: filter_by_date_range() DR-->>FE: 核心宽表 Note over FE: 阶段3:执行计算 FE->>CE: execute(plan, data) CE->>CE: data.with_columns([polars_expr]) CE-->>FE: 计算结果 FE-->>User: Polars DataFrame (含因子列) ``` --- ## 附录:关键代码片段 ### A.1 FactorEngine.compute 方法 ```python def compute( self, factor_names: list[str], start_date: str, end_date: str ) -> pl.DataFrame: # 1. 收集所有执行计划 all_plans = [] for name in factor_names: plan = self.factors[name] all_plans.append(plan) # 2. 合并数据规格 merged_specs = self._merge_data_specs([p.data_specs for p in all_plans]) # 3. 获取数据 data_router = DataRouter(self.storage, start_date, end_date) data = data_router.fetch_data(merged_specs) # 4. 执行计算 compute_engine = ComputeEngine() result = compute_engine.execute_plans(all_plans, data) return result ``` ### A.2 PolarsTranslator 的函数处理器注册 ```python class PolarsTranslator: def __init__(self): self.handlers: dict[str, Callable] = {} self._register_default_handlers() def _register_default_handlers(self): # 时间序列函数 self.handlers['ts_delay'] = lambda x, n: x.shift(n) self.handlers['ts_mean'] = lambda x, n: x.rolling_mean(n) self.handlers['ts_std'] = lambda x, n: x.rolling_std(n) # 截面函数 self.handlers['cs_rank'] = lambda x: x.rank() self.handlers['cs_zscore'] = lambda x: (x - x.mean()) / x.std() ``` ### A.3 时序因子数据获取说明 由于移除了自动日期扩展机制,用户需要显式管理时序因子的日期范围: ```python # 示例:计算 2024-01-15 到 2024-01-20 的 5 日收益率 # 需要显式提供足够的历史数据 result = engine.compute( factor_names=["returns_5d"], # (close / ts_delay(close, 5)) - 1 start_date="20240108", # 向前扩展至少 5 个交易日 end_date="20240120" ) # 在结果中,2024-01-15 之前的日期会因数据不足而返回 null # 用户可以自行过滤到目标日期范围 result = result.filter( (pl.col("trade_date") >= "20240115") & (pl.col("trade_date") <= "20240120") ) ``` **设计原则**: - 显式优于隐式:数据来源透明,用户可以完全控制数据范围 - 符合 Polars 行为:rolling/shift 操作在窗口不足时返回 null - 可验证性:用户可以明确知道用了哪些数据计算因子 --- ## 总结 ProStock 的因子计算引擎采用 **DSL(领域特定语言)+ 延迟计算** 的架构设计,具有以下特点: 1. **声明式**:用户通过数学表达式描述因子逻辑,无需关心实现细节 2. **惰性求值**:表达式构建时不立即计算,生成执行计划后统一执行 3. **智能数据获取**:自动分析依赖、推导数据规格、按需取数 4. **向量化计算**:基于 Polars 的高性能向量化运算,支持时序和截面计算 5. **可扩展性**:通过 FunctionRegistry 可以轻松添加新的因子函数 整个流程从字符串到计算结果,经历了解析 → 编译 → 计划 → 取数 → 计算五个阶段,各组件职责清晰,便于维护和扩展。