refactor(factors): 拆分 engine.py 为模块化包

将单文件 engine.py (1064行) 拆分为 engine/ 包:
- 数据规格、路由器、计划器、计算引擎、因子引擎分离
- 保持向后兼容,API 无变化
This commit is contained in:
2026-03-02 22:29:18 +08:00
parent 1c0c4a0de1
commit 77e4e94e05
7 changed files with 1146 additions and 0 deletions

View File

@@ -52,6 +52,9 @@ from src.factors.engine import (
ComputeEngine,
)
# 保持向后兼容factor_engine.py 中的类也可以通过 src.factors.engine 访问
# 例如from src.factors.engine import FactorEngine
__all__ = [
# DSL 层
"Node",

View File

@@ -0,0 +1,25 @@
"""因子计算引擎模块。
提供完整的因子计算引擎组件:
- DataSpec: 数据规格定义
- ExecutionPlan: 执行计划
- DataRouter: 数据路由器
- ExecutionPlanner: 执行计划生成器
- ComputeEngine: 计算引擎
- FactorEngine: 因子计算引擎(统一入口)
"""
from src.factors.engine.data_spec import DataSpec, ExecutionPlan
from src.factors.engine.data_router import DataRouter
from src.factors.engine.planner import ExecutionPlanner
from src.factors.engine.compute_engine import ComputeEngine
from src.factors.engine.factor_engine import FactorEngine
__all__ = [
"DataSpec",
"ExecutionPlan",
"DataRouter",
"ExecutionPlanner",
"ComputeEngine",
"FactorEngine",
]

View File

@@ -0,0 +1,155 @@
"""计算引擎。
执行并行运算,负责将执行计划应用到数据上。
"""
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Set, Union
import polars as pl
from src.factors.engine.data_spec import ExecutionPlan
class ComputeEngine:
"""计算引擎 - 执行并行运算。
负责将执行计划应用到数据上,支持并行计算。
Attributes:
max_workers: 最大并行工作线程数
use_processes: 是否使用进程池CPU 密集型任务)
"""
def __init__(
self,
max_workers: int = 4,
use_processes: bool = False,
) -> None:
"""初始化计算引擎。
Args:
max_workers: 最大并行工作线程数
use_processes: 是否使用进程池代替线程池
"""
self.max_workers = max_workers
self.use_processes = use_processes
def execute(
self,
plan: ExecutionPlan,
data: pl.DataFrame,
) -> pl.DataFrame:
"""执行计算计划。
Args:
plan: 执行计划
data: 输入数据(核心宽表)
Returns:
包含因子结果的 DataFrame
"""
# 检查依赖字段是否存在
missing_cols = plan.dependencies - set(data.columns)
if missing_cols:
raise ValueError(f"数据缺少必要的字段: {missing_cols}")
# 执行计算
result = data.with_columns([plan.polars_expr.alias(plan.output_name)])
return result
def execute_batch(
self,
plans: List[ExecutionPlan],
data: pl.DataFrame,
) -> pl.DataFrame:
"""批量执行多个计算计划。
Args:
plans: 执行计划列表
data: 输入数据
Returns:
包含所有因子结果的 DataFrame
"""
result = data
for plan in plans:
result = self.execute(plan, result)
return result
def execute_parallel(
self,
plans: List[ExecutionPlan],
data: pl.DataFrame,
) -> pl.DataFrame:
"""并行执行多个计算计划。
Args:
plans: 执行计划列表
data: 输入数据
Returns:
包含所有因子结果的 DataFrame
"""
# 检查计划间依赖
independent_plans = []
dependent_plans = []
available_cols = set(data.columns)
for plan in plans:
if plan.dependencies <= available_cols:
independent_plans.append(plan)
available_cols.add(plan.output_name)
else:
dependent_plans.append(plan)
# 并行执行独立计划
if independent_plans:
ExecutorClass = (
ProcessPoolExecutor if self.use_processes else ThreadPoolExecutor
)
with ExecutorClass(max_workers=self.max_workers) as executor:
futures = {
executor.submit(self._execute_single, plan, data): plan
for plan in independent_plans
}
results = []
for future in futures:
plan = futures[future]
try:
result_col = future.result()
results.append((plan.output_name, result_col))
except Exception as e:
raise RuntimeError(f"计算因子 {plan.output_name} 失败: {e}")
# 合并结果
for name, series in results:
data = data.with_columns([series.alias(name)])
# 顺序执行依赖计划
for plan in dependent_plans:
data = self.execute(plan, data)
return data
def _execute_single(
self,
plan: ExecutionPlan,
data: pl.DataFrame,
) -> pl.Series:
"""执行单个计划并返回结果列。
Args:
plan: 执行计划
data: 输入数据
Returns:
计算结果序列
"""
result = self.execute(plan, data)
return result[plan.output_name]

View File

@@ -0,0 +1,304 @@
"""数据路由器。
按需取数、组装核心宽表。
负责根据数据规格从数据源拉取数据,并组装成统一的宽表格式。
支持内存数据源(用于测试)和真实数据库连接。
"""
from typing import Any, Dict, List, Optional, Set, Union
import threading
import polars as pl
from src.factors.engine.data_spec import DataSpec
from src.data.storage import Storage
class DataRouter:
"""数据路由器 - 按需取数、组装核心宽表。
负责根据数据规格从数据源拉取数据,并组装成统一的宽表格式。
支持内存数据源(用于测试)和真实数据库连接。
Attributes:
data_source: 数据源,可以是内存 DataFrame 字典或数据库连接
is_memory_mode: 是否为内存模式
"""
def __init__(self, data_source: Optional[Dict[str, pl.DataFrame]] = None) -> None:
"""初始化数据路由器。
Args:
data_source: 内存数据源,字典格式 {表名: DataFrame}
为 None 时自动连接 DuckDB 数据库
"""
self.data_source = data_source or {}
self.is_memory_mode = data_source is not None
self._cache: Dict[str, pl.DataFrame] = {}
self._lock = threading.Lock()
# 数据库模式下初始化 Storage
if not self.is_memory_mode:
self._storage = Storage()
else:
self._storage = None
def fetch_data(
self,
data_specs: List[DataSpec],
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
"""根据数据规格获取并组装核心宽表。
Args:
data_specs: 数据规格列表
start_date: 开始日期 (YYYYMMDD)
end_date: 结束日期 (YYYYMMDD)
stock_codes: 股票代码列表None 表示全市场
Returns:
组装好的核心宽表 DataFrame
Raises:
ValueError: 当数据源中缺少必要的表或字段时
"""
if not data_specs:
raise ValueError("数据规格不能为空")
# 收集所有需要的表和字段
required_tables: Dict[str, Set[str]] = {}
max_lookback = 0
for spec in data_specs:
if spec.table not in required_tables:
required_tables[spec.table] = set()
required_tables[spec.table].update(spec.columns)
max_lookback = max(max_lookback, spec.lookback_days)
# 调整日期范围以包含回看期
adjusted_start = self._adjust_start_date(start_date, max_lookback)
# 从数据源获取各表数据
table_data = {}
for table_name, columns in required_tables.items():
df = self._load_table(
table_name=table_name,
columns=list(columns),
start_date=adjusted_start,
end_date=end_date,
stock_codes=stock_codes,
)
table_data[table_name] = df
# 组装核心宽表
core_table = self._assemble_wide_table(table_data, required_tables)
# 过滤到实际请求日期范围
core_table = core_table.filter(
(pl.col("trade_date") >= start_date) & (pl.col("trade_date") <= end_date)
)
return core_table
def _load_table(
self,
table_name: str,
columns: List[str],
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
"""加载单个表的数据。
Args:
table_name: 表名
columns: 需要的字段
start_date: 开始日期
end_date: 结束日期
stock_codes: 股票代码过滤
Returns:
过滤后的 DataFrame
"""
cache_key = f"{table_name}_{start_date}_{end_date}_{stock_codes}"
with self._lock:
if cache_key in self._cache:
return self._cache[cache_key]
if self.is_memory_mode:
df = self._load_from_memory(
table_name, columns, start_date, end_date, stock_codes
)
else:
df = self._load_from_database(
table_name, columns, start_date, end_date, stock_codes
)
with self._lock:
self._cache[cache_key] = df
return df
def _load_from_memory(
self,
table_name: str,
columns: List[str],
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
"""从内存数据源加载数据。"""
if table_name not in self.data_source:
raise ValueError(f"内存数据源中缺少表: {table_name}")
df = self.data_source[table_name]
# 确保必需字段存在
for col in columns:
if col not in df.columns and col not in ["ts_code", "trade_date"]:
raise ValueError(f"{table_name} 缺少字段: {col}")
# 过滤日期和股票
df = df.filter(
(pl.col("trade_date") >= start_date) & (pl.col("trade_date") <= end_date)
)
if stock_codes is not None:
df = df.filter(pl.col("ts_code").is_in(stock_codes))
# 选择需要的列
select_cols = ["ts_code", "trade_date"] + [
c for c in columns if c in df.columns
]
return df.select(select_cols)
def _load_from_database(
self,
table_name: str,
columns: List[str],
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
"""从 DuckDB 数据库加载数据。
利用 Storage.load_polars() 方法,支持 SQL 查询下推。
"""
if self._storage is None:
raise RuntimeError("Storage 未初始化")
# 检查表是否存在
if not self._storage.exists(table_name):
raise ValueError(f"数据库中不存在表: {table_name}")
# 构建查询参数
# Storage.load_polars 目前只支持单个 ts_code需要处理列表情况
if stock_codes is not None and len(stock_codes) == 1:
ts_code_filter = stock_codes[0]
else:
ts_code_filter = None
try:
# 从数据库加载原始数据
df = self._storage.load_polars(
name=table_name,
start_date=start_date,
end_date=end_date,
ts_code=ts_code_filter,
)
except Exception as e:
raise RuntimeError(f"从数据库加载表 {table_name} 失败: {e}")
# 如果 stock_codes 是列表且长度 > 1在内存中过滤
if stock_codes is not None and len(stock_codes) > 1:
df = df.filter(pl.col("ts_code").is_in(stock_codes))
# 检查必需字段
for col in columns:
if col not in df.columns and col not in ["ts_code", "trade_date"]:
raise ValueError(f"{table_name} 缺少字段: {col}")
# 选择需要的列
select_cols = ["ts_code", "trade_date"] + [
c for c in columns if c in df.columns
]
return df.select(select_cols)
def _assemble_wide_table(
self,
table_data: Dict[str, pl.DataFrame],
required_tables: Dict[str, Set[str]],
) -> pl.DataFrame:
"""组装多表数据为核心宽表。
使用 left join 合并各表数据,以第一个表为基准。
Args:
table_data: 表名到 DataFrame 的映射
required_tables: 表名到字段集合的映射
Returns:
组装后的宽表
"""
if not table_data:
raise ValueError("没有数据可组装")
# 以第一个表为基准
base_table_name = list(table_data.keys())[0]
result = table_data[base_table_name]
# 与其他表 join
for table_name, df in table_data.items():
if table_name == base_table_name:
continue
# 使用 ts_code 和 trade_date 作为 join 键
result = result.join(
df,
on=["ts_code", "trade_date"],
how="left",
)
return result
def _adjust_start_date(self, start_date: str, lookback_days: int) -> str:
"""根据回看天数调整开始日期。
Args:
start_date: 原始开始日期 (YYYYMMDD)
lookback_days: 需要回看的交易日数
Returns:
调整后的开始日期
"""
# 简化的日期调整假设每月30天向前推移
# 实际应用中应该使用交易日历
year = int(start_date[:4])
month = int(start_date[4:6])
day = int(start_date[6:8])
total_days = lookback_days + 30 # 额外缓冲
day -= total_days
while day <= 0:
month -= 1
if month <= 0:
month = 12
year -= 1
day += 30
return f"{year:04d}{month:02d}{day:02d}"
def clear_cache(self) -> None:
"""清除数据缓存。"""
with self._lock:
self._cache.clear()
# 数据库模式下清理 Storage 连接(可选)
if not self.is_memory_mode and self._storage is not None:
# Storage 使用单例模式,不需要关闭连接
pass

View File

@@ -0,0 +1,47 @@
"""数据规格和执行计划定义。
定义因子计算所需的数据规格和执行计划结构。
"""
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set, Union
import polars as pl
@dataclass
class DataSpec:
"""数据规格定义。
描述因子计算所需的数据表和字段。
Attributes:
table: 数据表名称
columns: 需要的字段列表
lookback_days: 回看天数(用于时序计算)
"""
table: str
columns: List[str]
lookback_days: int = 1
@dataclass
class ExecutionPlan:
"""执行计划。
包含完整的执行所需信息:数据源、转换逻辑、输出格式。
Attributes:
data_specs: 数据规格列表
polars_expr: Polars 表达式
dependencies: 依赖的原始字段
output_name: 输出因子名称
factor_dependencies: 依赖的其他因子名称(用于分步执行)
"""
data_specs: List[DataSpec]
polars_expr: pl.Expr
dependencies: Set[str]
output_name: str
factor_dependencies: Set[str] = field(default_factory=set)

View File

@@ -0,0 +1,442 @@
"""因子计算引擎 - 系统统一入口。
提供从表达式到结果的完整执行链路,是研究员使用系统的唯一接口。
执行流程:
1. 注册表达式 -> 调用编译器解析依赖
2. 调用路由器连接数据库拉取并组装核心宽表
3. 调用翻译器生成物理执行计划
4. 将计划提交给计算引擎执行并行运算
5. 返回包含因子结果的数据表
"""
from typing import Any, Dict, List, Optional, Set, Union
import polars as pl
from src.factors.dsl import (
Node,
Symbol,
BinaryOpNode,
UnaryOpNode,
FunctionNode,
)
from src.factors.translator import PolarsTranslator
from src.factors.engine.data_spec import DataSpec, ExecutionPlan
from src.factors.engine.data_router import DataRouter
from src.factors.engine.planner import ExecutionPlanner
from src.factors.engine.compute_engine import ComputeEngine
class FactorEngine:
"""因子计算引擎 - 系统统一入口。
提供从表达式到结果的完整执行链路,是研究员使用系统的唯一接口。
执行流程:
1. 注册表达式 -> 调用编译器解析依赖
2. 调用路由器连接数据库拉取并组装核心宽表
3. 调用翻译器生成物理执行计划
4. 将计划提交给计算引擎执行并行运算
5. 返回包含因子结果的数据表
Attributes:
router: 数据路由器
planner: 执行计划生成器
compute_engine: 计算引擎
registered_expressions: 注册的表达式字典
"""
def __init__(
self,
data_source: Optional[Dict[str, pl.DataFrame]] = None,
max_workers: int = 4,
) -> None:
"""初始化因子引擎。
Args:
data_source: 内存数据源,为 None 时使用数据库连接
max_workers: 并行计算的最大工作线程数
"""
self.router = DataRouter(data_source)
self.planner = ExecutionPlanner()
self.compute_engine = ComputeEngine(max_workers=max_workers)
self.registered_expressions: Dict[str, Node] = {}
self._plans: Dict[str, ExecutionPlan] = {}
def register(
self,
name: str,
expression: Node,
data_specs: Optional[List[DataSpec]] = None,
) -> "FactorEngine":
"""注册因子表达式。
Args:
name: 因子名称
expression: DSL 表达式
data_specs: 数据规格None 时自动推导
Returns:
self支持链式调用
Example:
>>> from src.factors.api import close, ts_mean
>>> engine = FactorEngine()
>>> engine.register("ma20", ts_mean(close, 20))
"""
# 检测因子依赖(在注册当前因子之前检查其他已注册因子)
factor_deps = self._find_factor_dependencies(expression)
self.registered_expressions[name] = expression
# 预创建执行计划
plan = self.planner.create_plan(
expression=expression,
output_name=name,
data_specs=data_specs,
)
# 添加因子依赖信息
plan.factor_dependencies = factor_deps
self._plans[name] = plan
return self
def compute(
self,
factor_names: Union[str, List[str]],
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
"""计算指定因子的值。
完整的执行流程:取数 -> 组装 -> 翻译 -> 计算。
Args:
factor_names: 因子名称或名称列表
start_date: 开始日期 (YYYYMMDD)
end_date: 结束日期 (YYYYMMDD)
stock_codes: 股票代码列表None 表示全市场
Returns:
包含因子结果的数据表
Raises:
ValueError: 当因子未注册或数据不足时
Example:
>>> result = engine.compute("ma20", "20240101", "20240131")
>>> result = engine.compute(["ma20", "rsi"], "20240101", "20240131")
"""
# 标准化因子名称
if isinstance(factor_names, str):
factor_names = [factor_names]
# 1. 获取执行计划
plans = []
for name in factor_names:
if name not in self._plans:
raise ValueError(f"因子未注册: {name}")
plans.append(self._plans[name])
# 2. 合并数据规格并获取数据
all_specs = []
for plan in plans:
all_specs.extend(plan.data_specs)
# 3. 从路由器获取核心宽表
core_data = self.router.fetch_data(
data_specs=all_specs,
start_date=start_date,
end_date=end_date,
stock_codes=stock_codes,
)
if len(core_data) == 0:
raise ValueError("未获取到任何数据,请检查日期范围和股票代码")
# 4. 按依赖顺序执行计算
if len(plans) == 1:
result = self.compute_engine.execute(plans[0], core_data)
else:
# 使用依赖感知的方式执行
result = self._execute_with_dependencies(factor_names, core_data)
return result
def list_registered(self) -> List[str]:
"""获取已注册的因子列表。
Returns:
因子名称列表
"""
return list(self.registered_expressions.keys())
def get_expression(self, name: str) -> Optional[Node]:
"""获取已注册的表达式。
Args:
name: 因子名称
Returns:
表达式节点,未注册时返回 None
"""
return self.registered_expressions.get(name)
def clear(self) -> None:
"""清除所有注册的表达式和缓存。"""
self.registered_expressions.clear()
self._plans.clear()
self.router.clear_cache()
def preview_plan(self, factor_name: str) -> Optional[ExecutionPlan]:
"""预览因子的执行计划。
Args:
factor_name: 因子名称
Returns:
执行计划,未注册时返回 None
"""
return self._plans.get(factor_name)
def _execute_with_dependencies(
self,
factor_names: List[str],
core_data: pl.DataFrame,
) -> pl.DataFrame:
"""按依赖顺序执行因子计算。
支持 cs_rank 等需要依赖列已存在的场景。
Args:
factor_names: 因子名称列表
core_data: 核心宽表数据
Returns:
包含所有因子结果的数据表
"""
# 1. 拓扑排序
sorted_names = self._topological_sort(factor_names)
# 2. 按顺序执行
result = core_data
for name in sorted_names:
plan = self._plans[name]
# 创建新的执行计划,引用已计算的依赖列
new_plan = self._create_optimized_plan(plan, result)
# 执行计算
result = self.compute_engine.execute(new_plan, result)
return result
def _create_optimized_plan(
self,
plan: ExecutionPlan,
current_data: pl.DataFrame,
) -> ExecutionPlan:
"""创建优化的执行计划。
将表达式中已计算的依赖因子替换为列引用。
Args:
plan: 原始执行计划
current_data: 当前数据(包含已计算的依赖列)
Returns:
新的执行计划
"""
from src.factors.dsl import Symbol
# 获取当前数据中已存在的列
existing_cols = set(current_data.columns)
# 检查依赖列是否已存在
deps_available = plan.factor_dependencies & existing_cols
if not deps_available:
# 没有可用的依赖列,直接返回原计划
return plan
# 获取原始表达式
original_expr = self.registered_expressions[plan.output_name]
# 创建新的表达式,用 Symbol 引用替换依赖因子
def replace_with_symbol(node: Node) -> Node:
"""递归替换表达式中的依赖因子为 Symbol 引用。"""
from typing import Any
n: Any = node
# 检查当前节点是否等于某个已计算依赖因子
for dep_name in deps_available:
dep_expr = self.registered_expressions[dep_name]
if self._expressions_equal(node, dep_expr):
return Symbol(dep_name)
# 递归处理子节点
if isinstance(n, BinaryOpNode):
new_left = replace_with_symbol(n.left)
new_right = replace_with_symbol(n.right)
if new_left is not n.left or new_right is not n.right:
return BinaryOpNode(n.op, new_left, new_right)
elif isinstance(n, UnaryOpNode):
new_operand = replace_with_symbol(n.operand)
if new_operand is not n.operand:
return UnaryOpNode(n.op, new_operand)
elif isinstance(n, FunctionNode):
new_args = [replace_with_symbol(arg) for arg in n.args]
if any(
new_arg is not old_arg for new_arg, old_arg in zip(new_args, n.args)
):
return FunctionNode(n.func_name, *new_args)
return node
# 替换表达式
new_expr = replace_with_symbol(original_expr)
# 重新翻译表达式
translator = PolarsTranslator()
new_polars_expr = translator.translate(new_expr)
# 更新依赖集合
new_factor_deps = plan.factor_dependencies - deps_available
new_deps = plan.dependencies | deps_available
return ExecutionPlan(
data_specs=plan.data_specs,
polars_expr=new_polars_expr,
dependencies=new_deps,
output_name=plan.output_name,
factor_dependencies=new_factor_deps,
)
def _expressions_equal(self, expr1: Node, expr2: Node) -> bool:
"""比较两个表达式是否相等。
用于检测因子间的依赖关系。
Args:
expr1: 第一个表达式
expr2: 第二个表达式
Returns:
是否相等
"""
from typing import Any
e1: Any = expr1
e2: Any = expr2
if type(e1) != type(e2):
return False
if isinstance(e1, Symbol):
return e1.name == e2.name
from src.factors.dsl import Constant
if isinstance(e1, Constant):
return e1.value == e2.value
if isinstance(e1, BinaryOpNode):
return (
e1.op == e2.op
and self._expressions_equal(e1.left, e2.left)
and self._expressions_equal(e1.right, e2.right)
)
if isinstance(e1, UnaryOpNode):
return e1.op == e2.op and self._expressions_equal(e1.operand, e2.operand)
if isinstance(e1, FunctionNode):
if e1.func_name != e2.func_name or len(e1.args) != len(e2.args):
return False
return all(
self._expressions_equal(a1, a2) for a1, a2 in zip(e1.args, e2.args)
)
return False
def _find_factor_dependencies(self, expression: Node) -> Set[str]:
"""查找表达式依赖的其他因子。
遍历已注册因子,检查表达式是否包含任何已注册因子的完整表达式。
Args:
expression: 待检查的表达式
Returns:
依赖的因子名称集合
"""
deps: Set[str] = set()
# 检查表达式本身是否等于某个已注册因子
for name, registered_expr in self.registered_expressions.items():
if self._expressions_equal(expression, registered_expr):
deps.add(name)
break
# 递归检查子节点
if isinstance(expression, BinaryOpNode):
deps.update(self._find_factor_dependencies(expression.left))
deps.update(self._find_factor_dependencies(expression.right))
elif isinstance(expression, UnaryOpNode):
deps.update(self._find_factor_dependencies(expression.operand))
elif isinstance(expression, FunctionNode):
for arg in expression.args:
deps.update(self._find_factor_dependencies(arg))
return deps
def _topological_sort(self, factor_names: List[str]) -> List[str]:
"""按依赖关系对因子进行拓扑排序。
确保依赖的因子先被计算。
Args:
factor_names: 因子名称列表
Returns:
排序后的因子名称列表
Raises:
ValueError: 当检测到循环依赖时
"""
# 构建依赖图
graph: Dict[str, Set[str]] = {}
in_degree: Dict[str, int] = {}
for name in factor_names:
plan = self._plans[name]
# 只考虑在本次计算范围内的依赖
deps = plan.factor_dependencies & set(factor_names)
graph[name] = deps
in_degree[name] = len(deps)
# Kahn 算法
result = []
queue = [name for name, degree in in_degree.items() if degree == 0]
while queue:
# 按原始顺序处理同级别的因子
queue.sort(key=lambda x: factor_names.index(x))
name = queue.pop(0)
result.append(name)
for other in factor_names:
if name in graph[other]:
in_degree[other] -= 1
if in_degree[other] == 0:
queue.append(other)
if len(result) != len(factor_names):
raise ValueError("检测到因子循环依赖")
return result

View File

@@ -0,0 +1,170 @@
"""执行计划生成器。
整合编译器和翻译器,生成完整的执行计划。
"""
from typing import Any, Dict, List, Optional, Set, Union
from src.factors.dsl import (
Node,
Symbol,
FunctionNode,
BinaryOpNode,
UnaryOpNode,
Constant,
)
from src.factors.compiler import DependencyExtractor
from src.factors.translator import PolarsTranslator
from src.factors.engine.data_spec import DataSpec, ExecutionPlan
class ExecutionPlanner:
"""执行计划生成器。
整合编译器和翻译器,生成完整的执行计划。
Attributes:
compiler: 依赖提取器
translator: Polars 翻译器
"""
def __init__(self) -> None:
"""初始化执行计划生成器。"""
self.compiler = DependencyExtractor()
self.translator = PolarsTranslator()
def create_plan(
self,
expression: Node,
output_name: str = "factor",
data_specs: Optional[List[DataSpec]] = None,
) -> ExecutionPlan:
"""从表达式创建执行计划。
Args:
expression: DSL 表达式节点
output_name: 输出因子名称
data_specs: 预定义的数据规格None 时自动推导
Returns:
执行计划对象
"""
# 1. 提取依赖
dependencies = self.compiler.extract_dependencies(expression)
# 2. 翻译为 Polars 表达式
polars_expr = self.translator.translate(expression)
# 3. 推导或验证数据规格
if data_specs is None:
data_specs = self._infer_data_specs(dependencies, expression)
return ExecutionPlan(
data_specs=data_specs,
polars_expr=polars_expr,
dependencies=dependencies,
output_name=output_name,
)
def _infer_data_specs(
self,
dependencies: Set[str],
expression: Node,
) -> List[DataSpec]:
"""从依赖推导数据规格。
根据表达式中的函数类型推断回看天数需求。
基础行情字段open, high, low, close, vol, amount, pre_close, change, pct_chg
默认从 pro_bar 表获取。
Args:
dependencies: 依赖的字段集合
expression: 表达式节点
Returns:
数据规格列表
"""
# 计算最大回看窗口
max_window = self._extract_max_window(expression)
lookback_days = max(1, max_window)
# 基础行情字段集合(这些字段从 pro_bar 表获取)
pro_bar_fields = {
"open",
"high",
"low",
"close",
"vol",
"amount",
"pre_close",
"change",
"pct_chg",
"turnover_rate",
"volume_ratio",
}
# 将依赖分为 pro_bar 字段和其他字段
pro_bar_deps = dependencies & pro_bar_fields
other_deps = dependencies - pro_bar_fields
data_specs = []
# pro_bar 表的数据规格
if pro_bar_deps:
data_specs.append(
DataSpec(
table="pro_bar",
columns=sorted(pro_bar_deps),
lookback_days=lookback_days,
)
)
# 其他字段从 daily 表获取
if other_deps:
data_specs.append(
DataSpec(
table="daily",
columns=sorted(other_deps),
lookback_days=lookback_days,
)
)
return data_specs
def _extract_max_window(self, node: Node) -> int:
"""从表达式中提取最大窗口大小。
Args:
node: AST 节点
Returns:
最大窗口大小,无时序函数返回 1
"""
if isinstance(node, FunctionNode):
window = 1
# 检查函数参数中的窗口大小
for arg in node.args:
if (
isinstance(arg, Constant)
and isinstance(arg.value, int)
and arg.value > window
):
window = arg.value
# 递归检查子表达式
for arg in node.args:
if isinstance(arg, Node) and not isinstance(arg, Constant):
window = max(window, self._extract_max_window(arg))
return window
elif isinstance(node, BinaryOpNode):
return max(
self._extract_max_window(node.left),
self._extract_max_window(node.right),
)
elif isinstance(node, UnaryOpNode):
return self._extract_max_window(node.operand)
return 1