Files
ProStock/src/factors/engine/factor_engine.py
liaozhaorun 6927d20de1 feat(training): LightGBM支持验证集早停
- 为fit方法添加eval_set参数,支持验证集评估和早停

- 因子引擎简化初始化,移除metadata_path参数

- 回归实验精简因子定义,移除冗余因子库
2026-03-14 22:51:24 +08:00

675 lines
23 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""因子计算引擎 - 系统统一入口。
提供从表达式到结果的完整执行链路,是研究员使用系统的唯一接口。
执行流程:
1. 注册表达式 -> 调用编译器解析依赖
2. 调用路由器连接数据库拉取并组装核心宽表
3. 调用翻译器生成物理执行计划
4. 将计划提交给计算引擎执行并行运算
5. 返回包含因子结果的数据表
"""
from typing import Any, Dict, List, Optional, Set, Union, TYPE_CHECKING
import polars as pl
if TYPE_CHECKING:
from src.factors.registry import FunctionRegistry
from src.factors.metadata import FactorManager
from src.factors.dsl import (
Node,
Symbol,
BinaryOpNode,
UnaryOpNode,
FunctionNode,
)
from src.factors.translator import PolarsTranslator
from src.factors.engine.data_spec import DataSpec, ExecutionPlan
from src.factors.engine.data_router import DataRouter
from src.factors.engine.planner import ExecutionPlanner
from src.factors.engine.compute_engine import ComputeEngine
from src.factors.engine.ast_optimizer import ExpressionFlattener
class FactorEngine:
"""因子计算引擎 - 系统统一入口。
提供从表达式到结果的完整执行链路,是研究员使用系统的唯一接口。
执行流程:
1. 注册表达式 -> 调用编译器解析依赖
2. 调用路由器连接数据库拉取并组装核心宽表
3. 调用翻译器生成物理执行计划
4. 将计划提交给计算引擎执行并行运算
5. 返回包含因子结果的数据表
Attributes:
router: 数据路由器
planner: 执行计划生成器
compute_engine: 计算引擎
registered_expressions: 注册的表达式字典
_registry: 函数注册表
_parser: 公式解析器
"""
def __init__(
self,
data_source: Optional[Dict[str, pl.DataFrame]] = None,
registry: Optional["FunctionRegistry"] = None,
) -> None:
"""初始化因子引擎。
Args:
data_source: 内存数据源,为 None 时使用数据库连接
registry: 函数注册表None 时创建独立实例
"""
from src.factors.registry import FunctionRegistry
from src.factors.parser import FormulaParser
self.router = DataRouter(data_source)
self.planner = ExecutionPlanner()
self.compute_engine = ComputeEngine()
self.registered_expressions: Dict[str, Node] = {}
self._plans: Dict[str, ExecutionPlan] = {}
# 初始化注册表和解析器(支持注入外部注册表实现共享)
self._registry = registry if registry is not None else FunctionRegistry()
self._parser = FormulaParser(self._registry)
# 初始化 metadata 管理器(使用默认路径)
from src.factors.metadata import FactorManager
self._metadata = FactorManager()
def _register_internal(
self,
name: str,
expression: Node,
data_specs: Optional[List[DataSpec]] = None,
) -> "FactorEngine":
"""内部注册方法,直接注册因子表达式。
Args:
name: 因子名称
expression: DSL 表达式
data_specs: 数据规格None 时自动推导
Returns:
self支持链式调用
"""
# 检测因子依赖(在注册当前因子之前检查其他已注册因子)
factor_deps = self._find_factor_dependencies(expression)
# 获取当前所有已注册的因子名称(作为免疫名单,防止被当作数据库字段)
known_factors = set(self.registered_expressions.keys())
self.registered_expressions[name] = expression
# 预创建执行计划,过滤掉已注册的因子,防止被当作数据库字段
plan = self.planner.create_plan(
expression=expression,
output_name=name,
data_specs=data_specs,
ignore_dependencies=known_factors,
)
# 添加因子依赖信息
plan.factor_dependencies = factor_deps
# 如果数据规格为空,继承依赖因子(包括临时因子)的数据规格
if not plan.data_specs and factor_deps:
merged_specs: List[DataSpec] = []
for dep_name in factor_deps:
if dep_name in self._plans:
merged_specs.extend(self._plans[dep_name].data_specs)
# 去重(基于表名)
seen_tables: set = set()
unique_specs: List[DataSpec] = []
for spec in merged_specs:
if spec.table not in seen_tables:
seen_tables.add(spec.table)
unique_specs.append(spec)
plan.data_specs = unique_specs
self._plans[name] = plan
return self
def register(
self,
name: str,
expression: Node,
data_specs: Optional[List[DataSpec]] = None,
) -> "FactorEngine":
"""注册因子表达式(自动处理嵌套窗口函数)。
Args:
name: 因子名称
expression: DSL 表达式
data_specs: 数据规格None 时自动推导
Returns:
self支持链式调用
Example:
>>> from src.factors.api import close, ts_mean
>>> engine = FactorEngine()
>>> engine.register("ma20", ts_mean(close, 20))
"""
# 使用 AST 优化器拍平嵌套窗口函数
flattener = ExpressionFlattener()
flat_expression, tmp_factors = flattener.flatten(expression)
# 先注册所有临时因子(自动推导数据规格)
for tmp_name, tmp_node in tmp_factors.items():
self._register_internal(tmp_name, tmp_node, data_specs=None)
# 最后注册主因子
self._register_internal(name, flat_expression, data_specs)
return self
def _add_factor_from_metadata(
self,
name: str,
factor_name_in_metadata: str,
data_specs: Optional[List[DataSpec]] = None,
) -> "FactorEngine":
"""从 metadata 中查询并注册因子(内部方法)。
Args:
name: 要注册的因子名称(引擎中使用的名称)
factor_name_in_metadata: metadata 中的因子名称
data_specs: 可选的数据规格
Returns:
self支持链式调用
Raises:
RuntimeError: 当引擎未配置 metadata 路径时
ValueError: 当在 metadata 中未找到因子时
FormulaParseError: 当 DSL 表达式解析失败时
"""
if self._metadata is None:
raise RuntimeError(
"引擎未配置 metadata 路径。请在初始化时传入 metadata_path 参数,"
+ "例如FactorEngine(metadata_path='data/factors.jsonl')"
)
# 从 metadata 查询因子
df = self._metadata.get_factors_by_name(factor_name_in_metadata)
if len(df) == 0:
raise ValueError(
f"在 metadata 中未找到因子 '{factor_name_in_metadata}'"
+ "请确认因子名称正确,或先使用 FactorManager 添加该因子。"
)
# 获取 DSL 表达式
dsl_expr = df["dsl"][0]
# 解析表达式为 Node
node = self._parser.parse(dsl_expr)
# 委托给 register 方法register 会处理嵌套窗口函数拍平)
return self.register(name, node, data_specs)
def add_factor(
self,
name: str,
expression: Optional[Union[str, Node]] = None,
data_specs: Optional[List[DataSpec]] = None,
) -> "FactorEngine":
"""注册因子(支持多种调用方式)。
这是 register 方法的增强版,支持以下调用方式:
1. 传入 name 和 expression直接注册表达式字符串或 Node
2. 只传入 name从 metadata 中查询表达式并注册
遵循 Fail-Fast 原则:字符串表达式会立即解析,失败时立即抛出异常。
Args:
name: 因子名称(引擎中使用的名称)
expression: 字符串表达式或 Node 对象,为 None 时从 metadata 查询
data_specs: 可选的数据规格
Returns:
self支持链式调用
Raises:
TypeError: 当 expression 类型不支持时
FormulaParseError: 当字符串解析失败时(立即报错)
RuntimeError: 当 expression 为 None 但未配置 metadata 时
ValueError: 当在 metadata 中未找到因子时
Example:
>>> engine = FactorEngine()
>>>
>>> # 方式1字符串表达式
>>> engine.add_factor("ma20", "ts_mean(close, 20)")
>>>
>>> # 方式2Node 表达式
>>> from src.factors.api import close, ts_mean
>>> engine.add_factor("ma20", ts_mean(close, 20))
>>>
>>> # 方式3从 metadata 查询(需要初始化时配置 metadata_path
>>> engine.add_factor("return_5") # 从 metadata 查询名为 return_5 的因子
>>>
>>> # 链式调用
>>> (engine
... .add_factor("ma5", "ts_mean(close, 5)")
... .add_factor("ma10", "ts_mean(close, 10)")
... .add_factor("golden_cross", "ma5 > ma10"))
"""
if expression is None:
# 从 metadata 查询表达式
return self._add_factor_from_metadata(name, name, data_specs)
if isinstance(expression, str):
# Fail-Fast立即解析失败立即报错
node = self._parser.parse(expression)
elif isinstance(expression, Node):
node = expression
else:
raise TypeError(
f"表达式必须是 str 或 Node 类型,收到 {type(expression).__name__}"
)
# 委托给现有的 register 方法
return self.register(name, node, data_specs)
def compute(
self,
factor_names: Union[str, List[str]],
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
"""计算指定因子的值。
完整的执行流程:取数 -> 组装 -> 翻译 -> 计算。
Args:
factor_names: 因子名称或名称列表
start_date: 开始日期 (YYYYMMDD)
end_date: 结束日期 (YYYYMMDD)
stock_codes: 股票代码列表None 表示全市场
Returns:
包含因子结果的数据表
Raises:
ValueError: 当因子未注册或数据不足时
Example:
>>> result = engine.compute("ma20", "20240101", "20240131")
>>> result = engine.compute(["ma20", "rsi"], "20240101", "20240131")
"""
# 标准化因子名称
if isinstance(factor_names, str):
factor_names = [factor_names]
# 1. 收集所有需要的因子(包括临时因子依赖)
all_factor_names = self._collect_all_dependencies(factor_names)
# 2. 获取执行计划
plans = []
for name in all_factor_names:
if name not in self._plans:
raise ValueError(f"因子未注册: {name}")
plans.append(self._plans[name])
# 3. 合并数据规格并获取数据
all_specs = []
for plan in plans:
all_specs.extend(plan.data_specs)
# 合并相同表的字段(而不是简单地去重)
table_to_columns: Dict[str, Set[str]] = {}
table_to_spec: Dict[str, DataSpec] = {}
for spec in all_specs:
if spec.table not in table_to_columns:
table_to_columns[spec.table] = set()
table_to_spec[spec.table] = spec
table_to_columns[spec.table].update(spec.columns)
# 创建合并后的数据规格
unique_specs: List[DataSpec] = []
for table_name, columns in table_to_columns.items():
original_spec = table_to_spec[table_name]
unique_specs.append(
DataSpec(
table=table_name,
columns=list(columns),
join_type=original_spec.join_type,
left_on=original_spec.left_on,
right_on=original_spec.right_on,
)
)
# 4. 从路由器获取核心宽表
core_data = self.router.fetch_data(
data_specs=unique_specs,
start_date=start_date,
end_date=end_date,
stock_codes=stock_codes,
)
if len(core_data) == 0:
raise ValueError("未获取到任何数据,请检查日期范围和股票代码")
# 5. 按依赖顺序执行计算(包含临时因子)
result = self._execute_with_dependencies(all_factor_names, core_data)
# 6. 清理内存宽表过滤掉临时因子列__tmp_X
# 保留所有非临时因子列(包括原始数据列和用户请求的因子列)
cols_to_keep = [col for col in result.columns if not col.startswith("__tmp_")]
return result.select(cols_to_keep)
def list_registered(self) -> List[str]:
"""获取已注册的因子列表。
Returns:
因子名称列表
"""
return list(self.registered_expressions.keys())
def get_expression(self, name: str) -> Optional[Node]:
"""获取已注册的表达式。
Args:
name: 因子名称
Returns:
表达式节点,未注册时返回 None
"""
return self.registered_expressions.get(name)
def clear(self) -> None:
"""清除所有注册的表达式和缓存。"""
self.registered_expressions.clear()
self._plans.clear()
self.router.clear_cache()
def preview_plan(self, factor_name: str) -> Optional[ExecutionPlan]:
"""预览因子的执行计划。
Args:
factor_name: 因子名称
Returns:
执行计划,未注册时返回 None
"""
return self._plans.get(factor_name)
def _execute_with_dependencies(
self,
factor_names: List[str],
core_data: pl.DataFrame,
) -> pl.DataFrame:
"""按依赖顺序执行因子计算。
支持 cs_rank 等需要依赖列已存在的场景。
Args:
factor_names: 因子名称列表
core_data: 核心宽表数据
Returns:
包含所有因子结果的数据表
"""
# 1. 拓扑排序
sorted_names = self._topological_sort(factor_names)
# 2. 按顺序执行
result = core_data
for name in sorted_names:
plan = self._plans[name]
# 创建新的执行计划,引用已计算的依赖列
new_plan = self._create_optimized_plan(plan, result)
# 执行计算
result = self.compute_engine.execute(new_plan, result)
return result
def _create_optimized_plan(
self,
plan: ExecutionPlan,
current_data: pl.DataFrame,
) -> ExecutionPlan:
"""创建优化的执行计划。
将表达式中已计算的依赖因子替换为列引用。
Args:
plan: 原始执行计划
current_data: 当前数据(包含已计算的依赖列)
Returns:
新的执行计划
"""
from src.factors.dsl import Symbol
# 获取当前数据中已存在的列
existing_cols = set(current_data.columns)
# 检查依赖列是否已存在
deps_available = plan.factor_dependencies & existing_cols
if not deps_available:
# 没有可用的依赖列,直接返回原计划
return plan
# 获取原始表达式
original_expr = self.registered_expressions[plan.output_name]
# 创建新的表达式,用 Symbol 引用替换依赖因子
def replace_with_symbol(node: Node) -> Node:
"""递归替换表达式中的依赖因子为 Symbol 引用。"""
from typing import Any
n: Any = node
# 检查当前节点是否等于某个已计算依赖因子
for dep_name in deps_available:
dep_expr = self.registered_expressions[dep_name]
if self._expressions_equal(node, dep_expr):
return Symbol(dep_name)
# 递归处理子节点
if isinstance(n, BinaryOpNode):
new_left = replace_with_symbol(n.left)
new_right = replace_with_symbol(n.right)
if new_left is not n.left or new_right is not n.right:
return BinaryOpNode(n.op, new_left, new_right)
elif isinstance(n, UnaryOpNode):
new_operand = replace_with_symbol(n.operand)
if new_operand is not n.operand:
return UnaryOpNode(n.op, new_operand)
elif isinstance(n, FunctionNode):
new_args = [replace_with_symbol(arg) for arg in n.args]
if any(
new_arg is not old_arg for new_arg, old_arg in zip(new_args, n.args)
):
return FunctionNode(n.func_name, *new_args)
return node
# 替换表达式
new_expr = replace_with_symbol(original_expr)
# 重新翻译表达式
translator = PolarsTranslator()
new_polars_expr = translator.translate(new_expr)
# 更新依赖集合
new_factor_deps = plan.factor_dependencies - deps_available
new_deps = plan.dependencies | deps_available
return ExecutionPlan(
data_specs=plan.data_specs,
polars_expr=new_polars_expr,
dependencies=new_deps,
output_name=plan.output_name,
factor_dependencies=new_factor_deps,
)
def _expressions_equal(self, expr1: Node, expr2: Node) -> bool:
"""比较两个表达式是否相等。
用于检测因子间的依赖关系。
Args:
expr1: 第一个表达式
expr2: 第二个表达式
Returns:
是否相等
"""
from typing import Any
e1: Any = expr1
e2: Any = expr2
if type(e1) != type(e2):
return False
if isinstance(e1, Symbol):
return e1.name == e2.name
from src.factors.dsl import Constant
if isinstance(e1, Constant):
return e1.value == e2.value
if isinstance(e1, BinaryOpNode):
return (
e1.op == e2.op
and self._expressions_equal(e1.left, e2.left)
and self._expressions_equal(e1.right, e2.right)
)
if isinstance(e1, UnaryOpNode):
return e1.op == e2.op and self._expressions_equal(e1.operand, e2.operand)
if isinstance(e1, FunctionNode):
if e1.func_name != e2.func_name or len(e1.args) != len(e2.args):
return False
return all(
self._expressions_equal(a1, a2) for a1, a2 in zip(e1.args, e2.args)
)
return False
def _collect_all_dependencies(self, factor_names: List[str]) -> List[str]:
"""收集所有因子及其依赖(包括用户定义的因子和临时因子)。"""
collected: Set[str] = set()
result: List[str] = []
def collect_recursive(name: str):
if name in collected:
return
collected.add(name)
# 获取执行计划并递归收集强依赖
plan = self._plans.get(name)
if plan:
for dep_name in plan.factor_dependencies:
collect_recursive(dep_name)
# 依赖收集完毕,再将自己加入列表(天然形成安全的计算顺序)
result.append(name)
for name in factor_names:
collect_recursive(name)
return result
def _find_factor_dependencies(self, expression: Node) -> Set[str]:
"""查找表达式依赖的其他因子(包括临时因子和用户因子引用)。
Args:
expression: 待检查的表达式
Returns:
依赖的因子名称集合
"""
deps: Set[str] = set()
# 1. 【新增】如果直接引用了已注册的因子名称(包含 __tmp_X 或用户因子)
if (
isinstance(expression, Symbol)
and expression.name in self.registered_expressions
):
deps.add(expression.name)
# 2. 检查表达式本身是否等于某个已注册因子的完整 AST
for name, registered_expr in self.registered_expressions.items():
if self._expressions_equal(expression, registered_expr):
deps.add(name)
break
# 3. 递归检查子节点
if isinstance(expression, BinaryOpNode):
deps.update(self._find_factor_dependencies(expression.left))
deps.update(self._find_factor_dependencies(expression.right))
elif isinstance(expression, UnaryOpNode):
deps.update(self._find_factor_dependencies(expression.operand))
elif isinstance(expression, FunctionNode):
for arg in expression.args:
deps.update(self._find_factor_dependencies(arg))
return deps
def _topological_sort(self, factor_names: List[str]) -> List[str]:
"""按依赖关系对因子进行拓扑排序。
确保依赖的因子先被计算。
Args:
factor_names: 因子名称列表
Returns:
排序后的因子名称列表
Raises:
ValueError: 当检测到循环依赖时
"""
# 构建依赖图
graph: Dict[str, Set[str]] = {}
in_degree: Dict[str, int] = {}
for name in factor_names:
plan = self._plans[name]
# 只考虑在本次计算范围内的依赖
deps = plan.factor_dependencies & set(factor_names)
graph[name] = deps
in_degree[name] = len(deps)
# Kahn 算法
result = []
queue = [name for name, degree in in_degree.items() if degree == 0]
while queue:
# 按原始顺序处理同级别的因子
queue.sort(key=lambda x: factor_names.index(x))
name = queue.pop(0)
result.append(name)
for other in factor_names:
if name in graph[other]:
in_degree[other] -= 1
if in_degree[other] == 0:
queue.append(other)
if len(result) != len(factor_names):
raise ValueError("检测到因子循环依赖")
return result