Files
ProStock/src/factors/engine/factor_engine.py
liaozhaorun 77e4e94e05 refactor(factors): 拆分 engine.py 为模块化包
将单文件 engine.py (1064行) 拆分为 engine/ 包:
- 数据规格、路由器、计划器、计算引擎、因子引擎分离
- 保持向后兼容,API 无变化
2026-03-02 22:29:18 +08:00

443 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""因子计算引擎 - 系统统一入口。
提供从表达式到结果的完整执行链路,是研究员使用系统的唯一接口。
执行流程:
1. 注册表达式 -> 调用编译器解析依赖
2. 调用路由器连接数据库拉取并组装核心宽表
3. 调用翻译器生成物理执行计划
4. 将计划提交给计算引擎执行并行运算
5. 返回包含因子结果的数据表
"""
from typing import Any, Dict, List, Optional, Set, Union
import polars as pl
from src.factors.dsl import (
Node,
Symbol,
BinaryOpNode,
UnaryOpNode,
FunctionNode,
)
from src.factors.translator import PolarsTranslator
from src.factors.engine.data_spec import DataSpec, ExecutionPlan
from src.factors.engine.data_router import DataRouter
from src.factors.engine.planner import ExecutionPlanner
from src.factors.engine.compute_engine import ComputeEngine
class FactorEngine:
"""因子计算引擎 - 系统统一入口。
提供从表达式到结果的完整执行链路,是研究员使用系统的唯一接口。
执行流程:
1. 注册表达式 -> 调用编译器解析依赖
2. 调用路由器连接数据库拉取并组装核心宽表
3. 调用翻译器生成物理执行计划
4. 将计划提交给计算引擎执行并行运算
5. 返回包含因子结果的数据表
Attributes:
router: 数据路由器
planner: 执行计划生成器
compute_engine: 计算引擎
registered_expressions: 注册的表达式字典
"""
def __init__(
self,
data_source: Optional[Dict[str, pl.DataFrame]] = None,
max_workers: int = 4,
) -> None:
"""初始化因子引擎。
Args:
data_source: 内存数据源,为 None 时使用数据库连接
max_workers: 并行计算的最大工作线程数
"""
self.router = DataRouter(data_source)
self.planner = ExecutionPlanner()
self.compute_engine = ComputeEngine(max_workers=max_workers)
self.registered_expressions: Dict[str, Node] = {}
self._plans: Dict[str, ExecutionPlan] = {}
def register(
self,
name: str,
expression: Node,
data_specs: Optional[List[DataSpec]] = None,
) -> "FactorEngine":
"""注册因子表达式。
Args:
name: 因子名称
expression: DSL 表达式
data_specs: 数据规格None 时自动推导
Returns:
self支持链式调用
Example:
>>> from src.factors.api import close, ts_mean
>>> engine = FactorEngine()
>>> engine.register("ma20", ts_mean(close, 20))
"""
# 检测因子依赖(在注册当前因子之前检查其他已注册因子)
factor_deps = self._find_factor_dependencies(expression)
self.registered_expressions[name] = expression
# 预创建执行计划
plan = self.planner.create_plan(
expression=expression,
output_name=name,
data_specs=data_specs,
)
# 添加因子依赖信息
plan.factor_dependencies = factor_deps
self._plans[name] = plan
return self
def compute(
self,
factor_names: Union[str, List[str]],
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
"""计算指定因子的值。
完整的执行流程:取数 -> 组装 -> 翻译 -> 计算。
Args:
factor_names: 因子名称或名称列表
start_date: 开始日期 (YYYYMMDD)
end_date: 结束日期 (YYYYMMDD)
stock_codes: 股票代码列表None 表示全市场
Returns:
包含因子结果的数据表
Raises:
ValueError: 当因子未注册或数据不足时
Example:
>>> result = engine.compute("ma20", "20240101", "20240131")
>>> result = engine.compute(["ma20", "rsi"], "20240101", "20240131")
"""
# 标准化因子名称
if isinstance(factor_names, str):
factor_names = [factor_names]
# 1. 获取执行计划
plans = []
for name in factor_names:
if name not in self._plans:
raise ValueError(f"因子未注册: {name}")
plans.append(self._plans[name])
# 2. 合并数据规格并获取数据
all_specs = []
for plan in plans:
all_specs.extend(plan.data_specs)
# 3. 从路由器获取核心宽表
core_data = self.router.fetch_data(
data_specs=all_specs,
start_date=start_date,
end_date=end_date,
stock_codes=stock_codes,
)
if len(core_data) == 0:
raise ValueError("未获取到任何数据,请检查日期范围和股票代码")
# 4. 按依赖顺序执行计算
if len(plans) == 1:
result = self.compute_engine.execute(plans[0], core_data)
else:
# 使用依赖感知的方式执行
result = self._execute_with_dependencies(factor_names, core_data)
return result
def list_registered(self) -> List[str]:
"""获取已注册的因子列表。
Returns:
因子名称列表
"""
return list(self.registered_expressions.keys())
def get_expression(self, name: str) -> Optional[Node]:
"""获取已注册的表达式。
Args:
name: 因子名称
Returns:
表达式节点,未注册时返回 None
"""
return self.registered_expressions.get(name)
def clear(self) -> None:
"""清除所有注册的表达式和缓存。"""
self.registered_expressions.clear()
self._plans.clear()
self.router.clear_cache()
def preview_plan(self, factor_name: str) -> Optional[ExecutionPlan]:
"""预览因子的执行计划。
Args:
factor_name: 因子名称
Returns:
执行计划,未注册时返回 None
"""
return self._plans.get(factor_name)
def _execute_with_dependencies(
self,
factor_names: List[str],
core_data: pl.DataFrame,
) -> pl.DataFrame:
"""按依赖顺序执行因子计算。
支持 cs_rank 等需要依赖列已存在的场景。
Args:
factor_names: 因子名称列表
core_data: 核心宽表数据
Returns:
包含所有因子结果的数据表
"""
# 1. 拓扑排序
sorted_names = self._topological_sort(factor_names)
# 2. 按顺序执行
result = core_data
for name in sorted_names:
plan = self._plans[name]
# 创建新的执行计划,引用已计算的依赖列
new_plan = self._create_optimized_plan(plan, result)
# 执行计算
result = self.compute_engine.execute(new_plan, result)
return result
def _create_optimized_plan(
self,
plan: ExecutionPlan,
current_data: pl.DataFrame,
) -> ExecutionPlan:
"""创建优化的执行计划。
将表达式中已计算的依赖因子替换为列引用。
Args:
plan: 原始执行计划
current_data: 当前数据(包含已计算的依赖列)
Returns:
新的执行计划
"""
from src.factors.dsl import Symbol
# 获取当前数据中已存在的列
existing_cols = set(current_data.columns)
# 检查依赖列是否已存在
deps_available = plan.factor_dependencies & existing_cols
if not deps_available:
# 没有可用的依赖列,直接返回原计划
return plan
# 获取原始表达式
original_expr = self.registered_expressions[plan.output_name]
# 创建新的表达式,用 Symbol 引用替换依赖因子
def replace_with_symbol(node: Node) -> Node:
"""递归替换表达式中的依赖因子为 Symbol 引用。"""
from typing import Any
n: Any = node
# 检查当前节点是否等于某个已计算依赖因子
for dep_name in deps_available:
dep_expr = self.registered_expressions[dep_name]
if self._expressions_equal(node, dep_expr):
return Symbol(dep_name)
# 递归处理子节点
if isinstance(n, BinaryOpNode):
new_left = replace_with_symbol(n.left)
new_right = replace_with_symbol(n.right)
if new_left is not n.left or new_right is not n.right:
return BinaryOpNode(n.op, new_left, new_right)
elif isinstance(n, UnaryOpNode):
new_operand = replace_with_symbol(n.operand)
if new_operand is not n.operand:
return UnaryOpNode(n.op, new_operand)
elif isinstance(n, FunctionNode):
new_args = [replace_with_symbol(arg) for arg in n.args]
if any(
new_arg is not old_arg for new_arg, old_arg in zip(new_args, n.args)
):
return FunctionNode(n.func_name, *new_args)
return node
# 替换表达式
new_expr = replace_with_symbol(original_expr)
# 重新翻译表达式
translator = PolarsTranslator()
new_polars_expr = translator.translate(new_expr)
# 更新依赖集合
new_factor_deps = plan.factor_dependencies - deps_available
new_deps = plan.dependencies | deps_available
return ExecutionPlan(
data_specs=plan.data_specs,
polars_expr=new_polars_expr,
dependencies=new_deps,
output_name=plan.output_name,
factor_dependencies=new_factor_deps,
)
def _expressions_equal(self, expr1: Node, expr2: Node) -> bool:
"""比较两个表达式是否相等。
用于检测因子间的依赖关系。
Args:
expr1: 第一个表达式
expr2: 第二个表达式
Returns:
是否相等
"""
from typing import Any
e1: Any = expr1
e2: Any = expr2
if type(e1) != type(e2):
return False
if isinstance(e1, Symbol):
return e1.name == e2.name
from src.factors.dsl import Constant
if isinstance(e1, Constant):
return e1.value == e2.value
if isinstance(e1, BinaryOpNode):
return (
e1.op == e2.op
and self._expressions_equal(e1.left, e2.left)
and self._expressions_equal(e1.right, e2.right)
)
if isinstance(e1, UnaryOpNode):
return e1.op == e2.op and self._expressions_equal(e1.operand, e2.operand)
if isinstance(e1, FunctionNode):
if e1.func_name != e2.func_name or len(e1.args) != len(e2.args):
return False
return all(
self._expressions_equal(a1, a2) for a1, a2 in zip(e1.args, e2.args)
)
return False
def _find_factor_dependencies(self, expression: Node) -> Set[str]:
"""查找表达式依赖的其他因子。
遍历已注册因子,检查表达式是否包含任何已注册因子的完整表达式。
Args:
expression: 待检查的表达式
Returns:
依赖的因子名称集合
"""
deps: Set[str] = set()
# 检查表达式本身是否等于某个已注册因子
for name, registered_expr in self.registered_expressions.items():
if self._expressions_equal(expression, registered_expr):
deps.add(name)
break
# 递归检查子节点
if isinstance(expression, BinaryOpNode):
deps.update(self._find_factor_dependencies(expression.left))
deps.update(self._find_factor_dependencies(expression.right))
elif isinstance(expression, UnaryOpNode):
deps.update(self._find_factor_dependencies(expression.operand))
elif isinstance(expression, FunctionNode):
for arg in expression.args:
deps.update(self._find_factor_dependencies(arg))
return deps
def _topological_sort(self, factor_names: List[str]) -> List[str]:
"""按依赖关系对因子进行拓扑排序。
确保依赖的因子先被计算。
Args:
factor_names: 因子名称列表
Returns:
排序后的因子名称列表
Raises:
ValueError: 当检测到循环依赖时
"""
# 构建依赖图
graph: Dict[str, Set[str]] = {}
in_degree: Dict[str, int] = {}
for name in factor_names:
plan = self._plans[name]
# 只考虑在本次计算范围内的依赖
deps = plan.factor_dependencies & set(factor_names)
graph[name] = deps
in_degree[name] = len(deps)
# Kahn 算法
result = []
queue = [name for name, degree in in_degree.items() if degree == 0]
while queue:
# 按原始顺序处理同级别的因子
queue.sort(key=lambda x: factor_names.index(x))
name = queue.pop(0)
result.append(name)
for other in factor_names:
if name in graph[other]:
in_degree[other] -= 1
if in_degree[other] == 0:
queue.append(other)
if len(result) != len(factor_names):
raise ValueError("检测到因子循环依赖")
return result