Files
ProStock/src/factors/engine/factor_engine.py

675 lines
23 KiB
Python
Raw Normal View History

"""因子计算引擎 - 系统统一入口。
提供从表达式到结果的完整执行链路是研究员使用系统的唯一接口
执行流程:
1. 注册表达式 -> 调用编译器解析依赖
2. 调用路由器连接数据库拉取并组装核心宽表
3. 调用翻译器生成物理执行计划
4. 将计划提交给计算引擎执行并行运算
5. 返回包含因子结果的数据表
"""
from typing import Any, Dict, List, Optional, Set, Union, TYPE_CHECKING
import polars as pl
if TYPE_CHECKING:
from src.factors.registry import FunctionRegistry
from src.factors.metadata import FactorManager
from src.factors.dsl import (
Node,
Symbol,
BinaryOpNode,
UnaryOpNode,
FunctionNode,
)
from src.factors.translator import PolarsTranslator
from src.factors.engine.data_spec import DataSpec, ExecutionPlan
from src.factors.engine.data_router import DataRouter
from src.factors.engine.planner import ExecutionPlanner
from src.factors.engine.compute_engine import ComputeEngine
from src.factors.engine.ast_optimizer import ExpressionFlattener
class FactorEngine:
"""因子计算引擎 - 系统统一入口。
提供从表达式到结果的完整执行链路是研究员使用系统的唯一接口
执行流程:
1. 注册表达式 -> 调用编译器解析依赖
2. 调用路由器连接数据库拉取并组装核心宽表
3. 调用翻译器生成物理执行计划
4. 将计划提交给计算引擎执行并行运算
5. 返回包含因子结果的数据表
Attributes:
router: 数据路由器
planner: 执行计划生成器
compute_engine: 计算引擎
registered_expressions: 注册的表达式字典
_registry: 函数注册表
_parser: 公式解析器
"""
def __init__(
self,
data_source: Optional[Dict[str, pl.DataFrame]] = None,
registry: Optional["FunctionRegistry"] = None,
) -> None:
"""初始化因子引擎。
Args:
data_source: 内存数据源 None 时使用数据库连接
registry: 函数注册表None 时创建独立实例
"""
from src.factors.registry import FunctionRegistry
from src.factors.parser import FormulaParser
self.router = DataRouter(data_source)
self.planner = ExecutionPlanner()
self.compute_engine = ComputeEngine()
self.registered_expressions: Dict[str, Node] = {}
self._plans: Dict[str, ExecutionPlan] = {}
# 初始化注册表和解析器(支持注入外部注册表实现共享)
self._registry = registry if registry is not None else FunctionRegistry()
self._parser = FormulaParser(self._registry)
# 初始化 metadata 管理器(使用默认路径)
from src.factors.metadata import FactorManager
self._metadata = FactorManager()
def _register_internal(
self,
name: str,
expression: Node,
data_specs: Optional[List[DataSpec]] = None,
) -> "FactorEngine":
"""内部注册方法,直接注册因子表达式。
Args:
name: 因子名称
expression: DSL 表达式
data_specs: 数据规格None 时自动推导
Returns:
self支持链式调用
"""
# 检测因子依赖(在注册当前因子之前检查其他已注册因子)
factor_deps = self._find_factor_dependencies(expression)
# 获取当前所有已注册的因子名称(作为免疫名单,防止被当作数据库字段)
known_factors = set(self.registered_expressions.keys())
self.registered_expressions[name] = expression
# 预创建执行计划,过滤掉已注册的因子,防止被当作数据库字段
plan = self.planner.create_plan(
expression=expression,
output_name=name,
data_specs=data_specs,
ignore_dependencies=known_factors,
)
# 添加因子依赖信息
plan.factor_dependencies = factor_deps
# 如果数据规格为空,继承依赖因子(包括临时因子)的数据规格
if not plan.data_specs and factor_deps:
merged_specs: List[DataSpec] = []
for dep_name in factor_deps:
if dep_name in self._plans:
merged_specs.extend(self._plans[dep_name].data_specs)
# 去重(基于表名)
seen_tables: set = set()
unique_specs: List[DataSpec] = []
for spec in merged_specs:
if spec.table not in seen_tables:
seen_tables.add(spec.table)
unique_specs.append(spec)
plan.data_specs = unique_specs
self._plans[name] = plan
return self
def register(
self,
name: str,
expression: Node,
data_specs: Optional[List[DataSpec]] = None,
) -> "FactorEngine":
"""注册因子表达式(自动处理嵌套窗口函数)。
Args:
name: 因子名称
expression: DSL 表达式
data_specs: 数据规格None 时自动推导
Returns:
self支持链式调用
Example:
>>> from src.factors.api import close, ts_mean
>>> engine = FactorEngine()
>>> engine.register("ma20", ts_mean(close, 20))
"""
# 使用 AST 优化器拍平嵌套窗口函数
flattener = ExpressionFlattener()
flat_expression, tmp_factors = flattener.flatten(expression)
# 先注册所有临时因子(自动推导数据规格)
for tmp_name, tmp_node in tmp_factors.items():
self._register_internal(tmp_name, tmp_node, data_specs=None)
# 最后注册主因子
self._register_internal(name, flat_expression, data_specs)
return self
def _add_factor_from_metadata(
self,
name: str,
factor_name_in_metadata: str,
data_specs: Optional[List[DataSpec]] = None,
) -> "FactorEngine":
"""从 metadata 中查询并注册因子(内部方法)。
Args:
name: 要注册的因子名称引擎中使用的名称
factor_name_in_metadata: metadata 中的因子名称
data_specs: 可选的数据规格
Returns:
self支持链式调用
Raises:
RuntimeError: 当引擎未配置 metadata 路径时
ValueError: 当在 metadata 中未找到因子时
FormulaParseError: DSL 表达式解析失败时
"""
if self._metadata is None:
raise RuntimeError(
"引擎未配置 metadata 路径。请在初始化时传入 metadata_path 参数,"
+ "例如FactorEngine(metadata_path='data/factors.jsonl')"
)
# 从 metadata 查询因子
df = self._metadata.get_factors_by_name(factor_name_in_metadata)
if len(df) == 0:
raise ValueError(
f"在 metadata 中未找到因子 '{factor_name_in_metadata}'"
+ "请确认因子名称正确,或先使用 FactorManager 添加该因子。"
)
# 获取 DSL 表达式
dsl_expr = df["dsl"][0]
# 解析表达式为 Node
node = self._parser.parse(dsl_expr)
# 委托给 register 方法register 会处理嵌套窗口函数拍平)
return self.register(name, node, data_specs)
def add_factor(
self,
name: str,
expression: Optional[Union[str, Node]] = None,
data_specs: Optional[List[DataSpec]] = None,
) -> "FactorEngine":
"""注册因子(支持多种调用方式)。
这是 register 方法的增强版支持以下调用方式
1. 传入 name expression直接注册表达式字符串或 Node
2. 只传入 name metadata 中查询表达式并注册
遵循 Fail-Fast 原则字符串表达式会立即解析失败时立即抛出异常
Args:
name: 因子名称引擎中使用的名称
expression: 字符串表达式或 Node 对象 None 时从 metadata 查询
data_specs: 可选的数据规格
Returns:
self支持链式调用
Raises:
TypeError: expression 类型不支持时
FormulaParseError: 当字符串解析失败时立即报错
RuntimeError: expression None 但未配置 metadata
ValueError: 当在 metadata 中未找到因子时
Example:
>>> engine = FactorEngine()
>>>
>>> # 方式1字符串表达式
>>> engine.add_factor("ma20", "ts_mean(close, 20)")
>>>
>>> # 方式2Node 表达式
>>> from src.factors.api import close, ts_mean
>>> engine.add_factor("ma20", ts_mean(close, 20))
>>>
>>> # 方式3从 metadata 查询(需要初始化时配置 metadata_path
>>> engine.add_factor("return_5") # 从 metadata 查询名为 return_5 的因子
>>>
>>> # 链式调用
>>> (engine
... .add_factor("ma5", "ts_mean(close, 5)")
... .add_factor("ma10", "ts_mean(close, 10)")
... .add_factor("golden_cross", "ma5 > ma10"))
"""
if expression is None:
# 从 metadata 查询表达式
return self._add_factor_from_metadata(name, name, data_specs)
if isinstance(expression, str):
# Fail-Fast立即解析失败立即报错
node = self._parser.parse(expression)
elif isinstance(expression, Node):
node = expression
else:
raise TypeError(
f"表达式必须是 str 或 Node 类型,收到 {type(expression).__name__}"
)
# 委托给现有的 register 方法
return self.register(name, node, data_specs)
def compute(
self,
factor_names: Union[str, List[str]],
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
"""计算指定因子的值。
完整的执行流程取数 -> 组装 -> 翻译 -> 计算
Args:
factor_names: 因子名称或名称列表
start_date: 开始日期 (YYYYMMDD)
end_date: 结束日期 (YYYYMMDD)
stock_codes: 股票代码列表None 表示全市场
Returns:
包含因子结果的数据表
Raises:
ValueError: 当因子未注册或数据不足时
Example:
>>> result = engine.compute("ma20", "20240101", "20240131")
>>> result = engine.compute(["ma20", "rsi"], "20240101", "20240131")
"""
# 标准化因子名称
if isinstance(factor_names, str):
factor_names = [factor_names]
# 1. 收集所有需要的因子(包括临时因子依赖)
all_factor_names = self._collect_all_dependencies(factor_names)
# 2. 获取执行计划
plans = []
for name in all_factor_names:
if name not in self._plans:
raise ValueError(f"因子未注册: {name}")
plans.append(self._plans[name])
# 3. 合并数据规格并获取数据
all_specs = []
for plan in plans:
all_specs.extend(plan.data_specs)
# 合并相同表的字段(而不是简单地去重)
table_to_columns: Dict[str, Set[str]] = {}
table_to_spec: Dict[str, DataSpec] = {}
for spec in all_specs:
if spec.table not in table_to_columns:
table_to_columns[spec.table] = set()
table_to_spec[spec.table] = spec
table_to_columns[spec.table].update(spec.columns)
# 创建合并后的数据规格
unique_specs: List[DataSpec] = []
for table_name, columns in table_to_columns.items():
original_spec = table_to_spec[table_name]
unique_specs.append(
DataSpec(
table=table_name,
columns=list(columns),
join_type=original_spec.join_type,
left_on=original_spec.left_on,
right_on=original_spec.right_on,
)
)
# 4. 从路由器获取核心宽表
core_data = self.router.fetch_data(
data_specs=unique_specs,
start_date=start_date,
end_date=end_date,
stock_codes=stock_codes,
)
if len(core_data) == 0:
raise ValueError("未获取到任何数据,请检查日期范围和股票代码")
# 5. 按依赖顺序执行计算(包含临时因子)
result = self._execute_with_dependencies(all_factor_names, core_data)
# 6. 清理内存宽表过滤掉临时因子列__tmp_X
# 保留所有非临时因子列(包括原始数据列和用户请求的因子列)
cols_to_keep = [col for col in result.columns if not col.startswith("__tmp_")]
return result.select(cols_to_keep)
def list_registered(self) -> List[str]:
"""获取已注册的因子列表。
Returns:
因子名称列表
"""
return list(self.registered_expressions.keys())
def get_expression(self, name: str) -> Optional[Node]:
"""获取已注册的表达式。
Args:
name: 因子名称
Returns:
表达式节点未注册时返回 None
"""
return self.registered_expressions.get(name)
def clear(self) -> None:
"""清除所有注册的表达式和缓存。"""
self.registered_expressions.clear()
self._plans.clear()
self.router.clear_cache()
def preview_plan(self, factor_name: str) -> Optional[ExecutionPlan]:
"""预览因子的执行计划。
Args:
factor_name: 因子名称
Returns:
执行计划未注册时返回 None
"""
return self._plans.get(factor_name)
def _execute_with_dependencies(
self,
factor_names: List[str],
core_data: pl.DataFrame,
) -> pl.DataFrame:
"""按依赖顺序执行因子计算。
支持 cs_rank 等需要依赖列已存在的场景
Args:
factor_names: 因子名称列表
core_data: 核心宽表数据
Returns:
包含所有因子结果的数据表
"""
# 1. 拓扑排序
sorted_names = self._topological_sort(factor_names)
# 2. 按顺序执行
result = core_data
for name in sorted_names:
plan = self._plans[name]
# 创建新的执行计划,引用已计算的依赖列
new_plan = self._create_optimized_plan(plan, result)
# 执行计算
result = self.compute_engine.execute(new_plan, result)
return result
def _create_optimized_plan(
self,
plan: ExecutionPlan,
current_data: pl.DataFrame,
) -> ExecutionPlan:
"""创建优化的执行计划。
将表达式中已计算的依赖因子替换为列引用
Args:
plan: 原始执行计划
current_data: 当前数据包含已计算的依赖列
Returns:
新的执行计划
"""
from src.factors.dsl import Symbol
# 获取当前数据中已存在的列
existing_cols = set(current_data.columns)
# 检查依赖列是否已存在
deps_available = plan.factor_dependencies & existing_cols
if not deps_available:
# 没有可用的依赖列,直接返回原计划
return plan
# 获取原始表达式
original_expr = self.registered_expressions[plan.output_name]
# 创建新的表达式,用 Symbol 引用替换依赖因子
def replace_with_symbol(node: Node) -> Node:
"""递归替换表达式中的依赖因子为 Symbol 引用。"""
from typing import Any
n: Any = node
# 检查当前节点是否等于某个已计算依赖因子
for dep_name in deps_available:
dep_expr = self.registered_expressions[dep_name]
if self._expressions_equal(node, dep_expr):
return Symbol(dep_name)
# 递归处理子节点
if isinstance(n, BinaryOpNode):
new_left = replace_with_symbol(n.left)
new_right = replace_with_symbol(n.right)
if new_left is not n.left or new_right is not n.right:
return BinaryOpNode(n.op, new_left, new_right)
elif isinstance(n, UnaryOpNode):
new_operand = replace_with_symbol(n.operand)
if new_operand is not n.operand:
return UnaryOpNode(n.op, new_operand)
elif isinstance(n, FunctionNode):
new_args = [replace_with_symbol(arg) for arg in n.args]
if any(
new_arg is not old_arg for new_arg, old_arg in zip(new_args, n.args)
):
return FunctionNode(n.func_name, *new_args)
return node
# 替换表达式
new_expr = replace_with_symbol(original_expr)
# 重新翻译表达式
translator = PolarsTranslator()
new_polars_expr = translator.translate(new_expr)
# 更新依赖集合
new_factor_deps = plan.factor_dependencies - deps_available
new_deps = plan.dependencies | deps_available
return ExecutionPlan(
data_specs=plan.data_specs,
polars_expr=new_polars_expr,
dependencies=new_deps,
output_name=plan.output_name,
factor_dependencies=new_factor_deps,
)
def _expressions_equal(self, expr1: Node, expr2: Node) -> bool:
"""比较两个表达式是否相等。
用于检测因子间的依赖关系
Args:
expr1: 第一个表达式
expr2: 第二个表达式
Returns:
是否相等
"""
from typing import Any
e1: Any = expr1
e2: Any = expr2
if type(e1) != type(e2):
return False
if isinstance(e1, Symbol):
return e1.name == e2.name
from src.factors.dsl import Constant
if isinstance(e1, Constant):
return e1.value == e2.value
if isinstance(e1, BinaryOpNode):
return (
e1.op == e2.op
and self._expressions_equal(e1.left, e2.left)
and self._expressions_equal(e1.right, e2.right)
)
if isinstance(e1, UnaryOpNode):
return e1.op == e2.op and self._expressions_equal(e1.operand, e2.operand)
if isinstance(e1, FunctionNode):
if e1.func_name != e2.func_name or len(e1.args) != len(e2.args):
return False
return all(
self._expressions_equal(a1, a2) for a1, a2 in zip(e1.args, e2.args)
)
return False
def _collect_all_dependencies(self, factor_names: List[str]) -> List[str]:
"""收集所有因子及其依赖(包括用户定义的因子和临时因子)。"""
collected: Set[str] = set()
result: List[str] = []
def collect_recursive(name: str):
if name in collected:
return
collected.add(name)
# 获取执行计划并递归收集强依赖
plan = self._plans.get(name)
if plan:
for dep_name in plan.factor_dependencies:
collect_recursive(dep_name)
# 依赖收集完毕,再将自己加入列表(天然形成安全的计算顺序)
result.append(name)
for name in factor_names:
collect_recursive(name)
return result
def _find_factor_dependencies(self, expression: Node) -> Set[str]:
"""查找表达式依赖的其他因子(包括临时因子和用户因子引用)。
Args:
expression: 待检查的表达式
Returns:
依赖的因子名称集合
"""
deps: Set[str] = set()
# 1. 【新增】如果直接引用了已注册的因子名称(包含 __tmp_X 或用户因子)
if (
isinstance(expression, Symbol)
and expression.name in self.registered_expressions
):
deps.add(expression.name)
# 2. 检查表达式本身是否等于某个已注册因子的完整 AST
for name, registered_expr in self.registered_expressions.items():
if self._expressions_equal(expression, registered_expr):
deps.add(name)
break
# 3. 递归检查子节点
if isinstance(expression, BinaryOpNode):
deps.update(self._find_factor_dependencies(expression.left))
deps.update(self._find_factor_dependencies(expression.right))
elif isinstance(expression, UnaryOpNode):
deps.update(self._find_factor_dependencies(expression.operand))
elif isinstance(expression, FunctionNode):
for arg in expression.args:
deps.update(self._find_factor_dependencies(arg))
return deps
def _topological_sort(self, factor_names: List[str]) -> List[str]:
"""按依赖关系对因子进行拓扑排序。
确保依赖的因子先被计算
Args:
factor_names: 因子名称列表
Returns:
排序后的因子名称列表
Raises:
ValueError: 当检测到循环依赖时
"""
# 构建依赖图
graph: Dict[str, Set[str]] = {}
in_degree: Dict[str, int] = {}
for name in factor_names:
plan = self._plans[name]
# 只考虑在本次计算范围内的依赖
deps = plan.factor_dependencies & set(factor_names)
graph[name] = deps
in_degree[name] = len(deps)
# Kahn 算法
result = []
queue = [name for name, degree in in_degree.items() if degree == 0]
while queue:
# 按原始顺序处理同级别的因子
queue.sort(key=lambda x: factor_names.index(x))
name = queue.pop(0)
result.append(name)
for other in factor_names:
if name in graph[other]:
in_degree[other] -= 1
if in_degree[other] == 0:
queue.append(other)
if len(result) != len(factor_names):
raise ValueError("检测到因子循环依赖")
return result