2026-03-02 22:29:18 +08:00
|
|
|
|
"""因子计算引擎 - 系统统一入口。
|
|
|
|
|
|
|
|
|
|
|
|
提供从表达式到结果的完整执行链路,是研究员使用系统的唯一接口。
|
|
|
|
|
|
|
|
|
|
|
|
执行流程:
|
|
|
|
|
|
1. 注册表达式 -> 调用编译器解析依赖
|
|
|
|
|
|
2. 调用路由器连接数据库拉取并组装核心宽表
|
|
|
|
|
|
3. 调用翻译器生成物理执行计划
|
|
|
|
|
|
4. 将计划提交给计算引擎执行并行运算
|
|
|
|
|
|
5. 返回包含因子结果的数据表
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
2026-03-03 00:04:48 +08:00
|
|
|
|
from typing import Any, Dict, List, Optional, Set, Union, TYPE_CHECKING
|
2026-03-02 22:29:18 +08:00
|
|
|
|
|
|
|
|
|
|
import polars as pl
|
|
|
|
|
|
|
2026-03-03 00:04:48 +08:00
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
|
|
from src.factors.registry import FunctionRegistry
|
2026-03-11 22:54:52 +08:00
|
|
|
|
from src.factors.metadata import FactorManager
|
2026-03-03 00:04:48 +08:00
|
|
|
|
|
2026-03-02 22:29:18 +08:00
|
|
|
|
from src.factors.dsl import (
|
|
|
|
|
|
Node,
|
|
|
|
|
|
Symbol,
|
|
|
|
|
|
BinaryOpNode,
|
|
|
|
|
|
UnaryOpNode,
|
|
|
|
|
|
FunctionNode,
|
|
|
|
|
|
)
|
|
|
|
|
|
from src.factors.translator import PolarsTranslator
|
|
|
|
|
|
from src.factors.engine.data_spec import DataSpec, ExecutionPlan
|
|
|
|
|
|
from src.factors.engine.data_router import DataRouter
|
|
|
|
|
|
from src.factors.engine.planner import ExecutionPlanner
|
|
|
|
|
|
from src.factors.engine.compute_engine import ComputeEngine
|
2026-03-14 01:06:17 +08:00
|
|
|
|
from src.factors.engine.ast_optimizer import ExpressionFlattener
|
2026-03-02 22:29:18 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FactorEngine:
    """Factor computation engine - the unified entry point of the system.

    Ties together expression registration, data fetching, plan optimisation
    and execution so that callers only need this one class.

    Execution flow:
        1. Register an expression: it is flattened and an ExecutionPlan is
           pre-built through the planner.
        2. ``compute`` asks the data router for the core wide table.
        3. Expressions that reference already-computed factors are
           re-translated with ``PolarsTranslator``.
        4. Each plan is handed to the compute engine for execution.
        5. A DataFrame containing the factor results is returned.

    Attributes:
        router: Data router used to fetch the core wide table.
        planner: Execution plan builder.
        compute_engine: Engine that executes plans against the data.
        registered_expressions: Mapping of factor name -> registered AST.
        _plans: Mapping of factor name -> pre-built ExecutionPlan.
        _registry: Function registry handed to the formula parser.
        _parser: Formula parser used for string expressions.
        _metadata: FactorManager used to look up factors by name.
    """

    def __init__(
        self,
        data_source: Optional[Dict[str, pl.DataFrame]] = None,
        max_workers: int = 4,
        registry: Optional["FunctionRegistry"] = None,
        metadata_path: Optional[str] = None,
    ) -> None:
        """Initialise the factor engine.

        Args:
            data_source: In-memory data source passed to ``DataRouter``.
                The original documentation states that ``None`` means a
                database connection is used instead.
            max_workers: Forwarded to ``ComputeEngine(max_workers=...)``.
            registry: Function registry to share; a fresh
                ``FunctionRegistry`` is created when ``None``.
            metadata_path: Path of the factor metadata file. When ``None``
                the ``FactorManager`` is constructed without arguments, i.e.
                with whatever default it defines (metadata is NOT disabled).
        """
        from src.factors.registry import FunctionRegistry
        from src.factors.parser import FormulaParser

        self.router = DataRouter(data_source)
        self.planner = ExecutionPlanner()
        self.compute_engine = ComputeEngine(max_workers=max_workers)
        self.registered_expressions: Dict[str, Node] = {}
        self._plans: Dict[str, ExecutionPlan] = {}

        # Registry and parser; an external registry may be injected so that
        # several engines share the same function set.
        self._registry = registry if registry is not None else FunctionRegistry()
        self._parser = FormulaParser(self._registry)

        # Metadata manager: always created, either on the given path or on
        # FactorManager's own default.
        from src.factors.metadata import FactorManager

        if metadata_path is not None:
            self._metadata = FactorManager(metadata_path)
        else:
            self._metadata = FactorManager()

    @staticmethod
    def _dedupe_specs_by_table(specs: List[DataSpec]) -> List[DataSpec]:
        """Return ``specs`` with only the first spec kept for each table.

        NOTE(review): later specs for an already-seen table are dropped even
        if they differ from the first one - confirm that DataSpec objects for
        the same table are always interchangeable.
        """
        seen_tables: Set[Any] = set()
        unique_specs: List[DataSpec] = []
        for spec in specs:
            if spec.table not in seen_tables:
                seen_tables.add(spec.table)
                unique_specs.append(spec)
        return unique_specs

    def _register_internal(
        self,
        name: str,
        expression: Node,
        data_specs: Optional[List[DataSpec]] = None,
    ) -> "FactorEngine":
        """Register a single (already flattened) factor expression.

        Args:
            name: Factor name.
            expression: DSL expression.
            data_specs: Data specs; passed through to the planner as-is
                (the original documentation says ``None`` means they are
                derived automatically).

        Returns:
            self, to allow chaining.
        """
        # Detect dependencies on OTHER registered factors. The factor's own
        # previous definition is excluded: re-registering a name would
        # otherwise match itself and record a self-dependency, which later
        # surfaces as a bogus cycle error in _topological_sort.
        factor_deps = self._find_factor_dependencies(expression, exclude=name)

        # Names of already-registered factors, so the planner does not treat
        # them as database fields. The name being (re)defined is excluded so
        # that a re-registration is planned exactly like a first registration.
        known_factors = set(self.registered_expressions)
        known_factors.discard(name)

        self.registered_expressions[name] = expression

        plan = self.planner.create_plan(
            expression=expression,
            output_name=name,
            data_specs=data_specs,
            ignore_dependencies=known_factors,
        )
        plan.factor_dependencies = factor_deps

        # A factor built purely from other factors has no specs of its own;
        # inherit the specs of the factors it depends on.
        if not plan.data_specs and factor_deps:
            inherited: List[DataSpec] = []
            for dep_name in sorted(factor_deps):
                if dep_name in self._plans:
                    inherited.extend(self._plans[dep_name].data_specs)
            plan.data_specs = self._dedupe_specs_by_table(inherited)

        self._plans[name] = plan
        return self

    def register(
        self,
        name: str,
        expression: Node,
        data_specs: Optional[List[DataSpec]] = None,
    ) -> "FactorEngine":
        """Register a factor expression, flattening it first.

        The expression is passed through ``ExpressionFlattener.flatten``,
        which returns a flattened expression plus a mapping of temporary
        factors. The temporary factors are registered first, then the main
        factor.

        Args:
            name: Factor name.
            expression: DSL expression.
            data_specs: Data specs for the main factor; temporary factors are
                always registered with ``data_specs=None``.

        Returns:
            self, to allow chaining.

        Example:
            >>> from src.factors.api import close, ts_mean
            >>> engine = FactorEngine()
            >>> engine.register("ma20", ts_mean(close, 20))
        """
        flattener = ExpressionFlattener()
        flat_expression, tmp_factors = flattener.flatten(expression)

        # NOTE(review): a new flattener is created per call; confirm that the
        # temporary names it generates cannot collide across calls.
        for tmp_name, tmp_node in tmp_factors.items():
            self._register_internal(tmp_name, tmp_node, data_specs=None)

        self._register_internal(name, flat_expression, data_specs)
        return self

    def _add_factor_from_metadata(
        self,
        name: str,
        factor_name_in_metadata: str,
        data_specs: Optional[List[DataSpec]] = None,
    ) -> "FactorEngine":
        """Look a factor up in the metadata store and register it.

        Args:
            name: Name under which the factor is registered in this engine.
            factor_name_in_metadata: Name used for the metadata lookup.
            data_specs: Optional data specs.

        Returns:
            self, to allow chaining.

        Raises:
            RuntimeError: If ``self._metadata`` is ``None``.
            ValueError: If the metadata lookup returns no rows.
            FormulaParseError: Per the original documentation, when the DSL
                string stored in metadata cannot be parsed.
        """
        if self._metadata is None:
            raise RuntimeError(
                "引擎未配置 metadata 路径。请在初始化时传入 metadata_path 参数,"
                + "例如:FactorEngine(metadata_path='data/factors.jsonl')"
            )

        df = self._metadata.get_factors_by_name(factor_name_in_metadata)

        if len(df) == 0:
            raise ValueError(
                f"在 metadata 中未找到因子 '{factor_name_in_metadata}'。"
                + "请确认因子名称正确,或先使用 FactorManager 添加该因子。"
            )

        # Only the first matching row is used.
        dsl_expr = df["dsl"][0]
        node = self._parser.parse(dsl_expr)

        # register() takes care of flattening.
        return self.register(name, node, data_specs)

    def add_factor(
        self,
        name: str,
        expression: Optional[Union[str, Node]] = None,
        data_specs: Optional[List[DataSpec]] = None,
    ) -> "FactorEngine":
        """Register a factor from a string, a Node, or the metadata store.

        Supported call styles:
            1. ``name`` + ``expression`` (str or Node): register directly.
            2. ``name`` only: look the expression up in metadata by ``name``.

        String expressions are parsed immediately, so a parse failure is
        raised at registration time rather than at compute time.

        Args:
            name: Factor name used inside the engine.
            expression: String expression or Node; ``None`` triggers the
                metadata lookup.
            data_specs: Optional data specs.

        Returns:
            self, to allow chaining.

        Raises:
            TypeError: If ``expression`` is neither ``str``, ``Node`` nor
                ``None``.
            FormulaParseError: Per the original documentation, when a string
                expression cannot be parsed.
            RuntimeError: If ``expression`` is ``None`` and no metadata
                manager is available.
            ValueError: If the factor is not found in metadata.

        Example:
            >>> engine = FactorEngine()
            >>>
            >>> # Style 1: string expression
            >>> engine.add_factor("ma20", "ts_mean(close, 20)")
            >>>
            >>> # Style 2: Node expression
            >>> from src.factors.api import close, ts_mean
            >>> engine.add_factor("ma20", ts_mean(close, 20))
            >>>
            >>> # Style 3: metadata lookup
            >>> engine.add_factor("return_5")
            >>>
            >>> # Chaining
            >>> (engine
            ...     .add_factor("ma5", "ts_mean(close, 5)")
            ...     .add_factor("ma10", "ts_mean(close, 10)")
            ...     .add_factor("golden_cross", "ma5 > ma10"))
        """
        if expression is None:
            return self._add_factor_from_metadata(name, name, data_specs)

        if isinstance(expression, str):
            # Fail fast: parse now so errors surface immediately.
            node = self._parser.parse(expression)
        elif isinstance(expression, Node):
            node = expression
        else:
            raise TypeError(
                f"表达式必须是 str 或 Node 类型,收到 {type(expression).__name__}"
            )

        return self.register(name, node, data_specs)

    def compute(
        self,
        factor_names: Union[str, List[str]],
        start_date: str,
        end_date: str,
        stock_codes: Optional[List[str]] = None,
    ) -> pl.DataFrame:
        """Compute the requested factors.

        Args:
            factor_names: One factor name or a list of names.
            start_date: Start date (YYYYMMDD per the original documentation).
            end_date: End date (YYYYMMDD per the original documentation).
            stock_codes: Stock codes forwarded to the router; the original
                documentation says ``None`` means the whole market.

        Returns:
            The computed table with every column whose name does not start
            with ``"__tmp_"``.

        Raises:
            ValueError: If a requested or required factor is not registered,
                if the router returns no rows, or if the dependency graph
                contains a cycle.

        Example:
            >>> result = engine.compute("ma20", "20240101", "20240131")
            >>> result = engine.compute(["ma20", "rsi"], "20240101", "20240131")
        """
        if isinstance(factor_names, str):
            factor_names = [factor_names]

        # 1. Requested factors plus everything they depend on.
        all_factor_names = self._collect_all_dependencies(factor_names)

        # 2. Gather the data specs of every plan involved.
        all_specs: List[DataSpec] = []
        for name in all_factor_names:
            if name not in self._plans:
                raise ValueError(f"因子未注册: {name}")
            all_specs.extend(self._plans[name].data_specs)

        # 3. Fetch the core wide table once for the de-duplicated specs.
        core_data = self.router.fetch_data(
            data_specs=self._dedupe_specs_by_table(all_specs),
            start_date=start_date,
            end_date=end_date,
            stock_codes=stock_codes,
        )

        if len(core_data) == 0:
            raise ValueError("未获取到任何数据,请检查日期范围和股票代码")

        # 4. Execute in dependency order (temporary factors included).
        result = self._execute_with_dependencies(all_factor_names, core_data)

        # 5. Drop the temporary factor columns before returning.
        cols_to_keep = [col for col in result.columns if not col.startswith("__tmp_")]
        return result.select(cols_to_keep)

    def list_registered(self) -> List[str]:
        """Return the names of all registered expressions.

        This includes any temporary factors created by flattening.
        """
        return list(self.registered_expressions.keys())

    def get_expression(self, name: str) -> Optional[Node]:
        """Return the expression registered under ``name``, or ``None``."""
        return self.registered_expressions.get(name)

    def clear(self) -> None:
        """Remove all registered expressions and plans and clear the router cache."""
        self.registered_expressions.clear()
        self._plans.clear()
        self.router.clear_cache()

    def preview_plan(self, factor_name: str) -> Optional[ExecutionPlan]:
        """Return the pre-built plan for ``factor_name``, or ``None``."""
        return self._plans.get(factor_name)

    def _execute_with_dependencies(
        self,
        factor_names: List[str],
        core_data: pl.DataFrame,
    ) -> pl.DataFrame:
        """Execute the given factors in dependency order.

        Each factor's plan is executed against the table produced by the
        previous step, so later factors can reference earlier result columns.

        Args:
            factor_names: Names of the factors to execute (all must have a
                plan).
            core_data: Core wide table to start from.

        Returns:
            The table returned by the last ``compute_engine.execute`` call.
        """
        result = core_data
        for name in self._topological_sort(factor_names):
            # Re-target the plan at columns that are already computed.
            plan = self._create_optimized_plan(self._plans[name], result)
            result = self.compute_engine.execute(plan, result)
        return result

    def _create_optimized_plan(
        self,
        plan: ExecutionPlan,
        current_data: pl.DataFrame,
    ) -> ExecutionPlan:
        """Build a plan that reuses dependency columns already in the data.

        Sub-expressions structurally equal to an already-computed dependency
        are replaced by a ``Symbol`` referencing that column, and the result
        is translated again.

        Args:
            plan: The pre-built plan.
            current_data: Current table, possibly containing dependency
                columns.

        Returns:
            ``plan`` itself when none of its factor dependencies is present
            as a column, otherwise a new ExecutionPlan.
        """
        deps_available = plan.factor_dependencies & set(current_data.columns)
        if not deps_available:
            return plan

        original_expr = self.registered_expressions[plan.output_name]

        def replace_with_symbol(node: Node) -> Node:
            """Recursively swap computed dependencies for Symbol references."""
            for dep_name in deps_available:
                if self._expressions_equal(node, self.registered_expressions[dep_name]):
                    return Symbol(dep_name)

            # Rebuild a node only when one of its children actually changed,
            # so untouched sub-trees keep their identity.
            if isinstance(node, BinaryOpNode):
                new_left = replace_with_symbol(node.left)
                new_right = replace_with_symbol(node.right)
                if new_left is not node.left or new_right is not node.right:
                    return BinaryOpNode(node.op, new_left, new_right)
            elif isinstance(node, UnaryOpNode):
                new_operand = replace_with_symbol(node.operand)
                if new_operand is not node.operand:
                    return UnaryOpNode(node.op, new_operand)
            elif isinstance(node, FunctionNode):
                new_args = [replace_with_symbol(arg) for arg in node.args]
                if any(new is not old for new, old in zip(new_args, node.args)):
                    return FunctionNode(node.func_name, *new_args)

            return node

        new_expr = replace_with_symbol(original_expr)
        new_polars_expr = PolarsTranslator().translate(new_expr)

        # The reused factors become plain column dependencies.
        return ExecutionPlan(
            data_specs=plan.data_specs,
            polars_expr=new_polars_expr,
            dependencies=plan.dependencies | deps_available,
            output_name=plan.output_name,
            factor_dependencies=plan.factor_dependencies - deps_available,
        )

    def _expressions_equal(self, expr1: Node, expr2: Node) -> bool:
        """Return True when two expression trees are structurally identical.

        Both nodes must have exactly the same type. Symbols compare by name,
        Constants by value, operator nodes by operator and operands, function
        nodes by name and arguments. Any other node type compares unequal.

        Args:
            expr1: First expression.
            expr2: Second expression.
        """
        from src.factors.dsl import Constant

        if type(expr1) is not type(expr2):
            return False

        if isinstance(expr1, Symbol):
            return expr1.name == expr2.name

        if isinstance(expr1, Constant):
            return expr1.value == expr2.value

        if isinstance(expr1, BinaryOpNode):
            return (
                expr1.op == expr2.op
                and self._expressions_equal(expr1.left, expr2.left)
                and self._expressions_equal(expr1.right, expr2.right)
            )

        if isinstance(expr1, UnaryOpNode):
            return expr1.op == expr2.op and self._expressions_equal(
                expr1.operand, expr2.operand
            )

        if isinstance(expr1, FunctionNode):
            if expr1.func_name != expr2.func_name or len(expr1.args) != len(expr2.args):
                return False
            return all(
                self._expressions_equal(a1, a2)
                for a1, a2 in zip(expr1.args, expr2.args)
            )

        return False

    def _collect_all_dependencies(self, factor_names: List[str]) -> List[str]:
        """Return the requested factors plus all transitive factor dependencies.

        The list is built depth-first with each factor appended after its
        dependencies. Names without a plan are still included so the caller
        can report them. Sibling dependencies are visited in sorted order so
        the result does not depend on set iteration order.
        """
        collected: Set[str] = set()
        result: List[str] = []

        def collect_recursive(name: str) -> None:
            if name in collected:
                return
            collected.add(name)

            plan = self._plans.get(name)
            if plan:
                for dep_name in sorted(plan.factor_dependencies):
                    collect_recursive(dep_name)

            # Appended only after its dependencies.
            result.append(name)

        for name in factor_names:
            collect_recursive(name)

        return result

    def _find_factor_dependencies(
        self,
        expression: Node,
        exclude: Optional[str] = None,
    ) -> Set[str]:
        """Find the registered factors that ``expression`` depends on.

        A factor is a dependency when the expression (or any sub-expression)
        is a ``Symbol`` carrying its name, or is structurally equal to its
        registered expression.

        Args:
            expression: Expression to inspect.
            exclude: Optional factor name to ignore, used when re-registering
                a factor so it cannot depend on its own old definition.

        Returns:
            Set of factor names.
        """
        deps: Set[str] = set()

        # 1. Direct reference by name to a registered factor.
        if (
            isinstance(expression, Symbol)
            and expression.name != exclude
            and expression.name in self.registered_expressions
        ):
            deps.add(expression.name)

        # 2. Whole-AST match against a registered factor (first match only).
        for name, registered_expr in self.registered_expressions.items():
            if name == exclude:
                continue
            if self._expressions_equal(expression, registered_expr):
                deps.add(name)
                break

        # 3. Recurse into children.
        if isinstance(expression, BinaryOpNode):
            deps.update(self._find_factor_dependencies(expression.left, exclude))
            deps.update(self._find_factor_dependencies(expression.right, exclude))
        elif isinstance(expression, UnaryOpNode):
            deps.update(self._find_factor_dependencies(expression.operand, exclude))
        elif isinstance(expression, FunctionNode):
            for arg in expression.args:
                deps.update(self._find_factor_dependencies(arg, exclude))

        return deps

    def _topological_sort(self, factor_names: List[str]) -> List[str]:
        """Order factors so that every factor comes after its dependencies.

        Only dependencies that are themselves in ``factor_names`` are
        considered. Among factors that are ready at the same time, the one
        appearing earliest in ``factor_names`` goes first. Duplicate names
        are collapsed to their first occurrence.

        Args:
            factor_names: Factor names; each must have a plan in
                ``self._plans``.

        Returns:
            The sorted factor names (without duplicates).

        Raises:
            ValueError: If a dependency cycle (including a self-dependency)
                is detected.
        """
        import heapq

        ordered = list(dict.fromkeys(factor_names))
        position = {name: idx for idx, name in enumerate(ordered)}
        in_scope = set(ordered)

        pending: Dict[str, int] = {}
        dependents: Dict[str, List[str]] = {name: [] for name in ordered}
        for name in ordered:
            deps = self._plans[name].factor_dependencies & in_scope
            pending[name] = len(deps)
            for dep in deps:
                dependents[dep].append(name)

        # Kahn's algorithm; the heap holds original positions so ties are
        # broken by the caller's ordering.
        ready = [position[name] for name in ordered if pending[name] == 0]
        heapq.heapify(ready)

        result: List[str] = []
        while ready:
            name = ordered[heapq.heappop(ready)]
            result.append(name)
            for child in dependents[name]:
                pending[child] -= 1
                if pending[child] == 0:
                    heapq.heappush(ready, position[child])

        if len(result) != len(ordered):
            raise ValueError("检测到因子循环依赖")

        return result
|