Files
ProStock/src/factors/engine.py
liaozhaorun b461a4940d refactor(factor): 完成因子框架 DSL 化重构
- 重构 FactorEngine 实现完整的 DSL 表达式执行链路
- 新增 DataRouter 数据路由器,支持内存模式和核心宽表组装
- 新增 ExecutionPlanner 执行计划生成器,整合编译器和翻译器
- 新增 ComputeEngine 计算引擎,支持并行运算
- 完善 factors/__init__.py 公开 API 导出
- 新增 test_factor_engine.py 引擎单元测试
- 移除旧引擎实现和废弃的 DSL promotion 测试
- 更新 AGENTS.md 添加 v2.2 架构变更历史和 Factors 框架设计说明
2026-03-01 15:03:56 +08:00

707 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""FactorEngine - 因子计算引擎统一入口。
提供从表达式注册到结果输出的完整执行链路:
接收研究员的表达式 -> 调用编译器解析依赖 -> 调用路由器连接数据库拉取并组装核心宽表
-> 调用翻译器生成物理执行计划 -> 将计划提交给计算引擎执行并行运算。
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set, Union
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
import threading
import polars as pl
from src.factors.dsl import (
Node,
Symbol,
FunctionNode,
BinaryOpNode,
UnaryOpNode,
Constant,
)
from src.factors.compiler import DependencyExtractor
from src.factors.translator import PolarsTranslator
@dataclass
class DataSpec:
"""数据规格定义。
描述因子计算所需的数据表和字段。
Attributes:
table: 数据表名称
columns: 需要的字段列表
lookback_days: 回看天数(用于时序计算)
"""
table: str
columns: List[str]
lookback_days: int = 1
@dataclass
class ExecutionPlan:
"""执行计划。
包含完整的执行所需信息:数据源、转换逻辑、输出格式。
Attributes:
data_specs: 数据规格列表
polars_expr: Polars 表达式
dependencies: 依赖的原始字段
output_name: 输出因子名称
"""
data_specs: List[DataSpec]
polars_expr: pl.Expr
dependencies: Set[str]
output_name: str
class DataRouter:
"""数据路由器 - 按需取数、组装核心宽表。
负责根据数据规格从数据源拉取数据,并组装成统一的宽表格式。
支持内存数据源(用于测试)和真实数据库连接。
Attributes:
data_source: 数据源,可以是内存 DataFrame 字典或数据库连接
is_memory_mode: 是否为内存模式
"""
def __init__(self, data_source: Optional[Dict[str, pl.DataFrame]] = None) -> None:
"""初始化数据路由器。
Args:
data_source: 内存数据源,字典格式 {表名: DataFrame}
为 None 时需要在子类中实现数据库连接
"""
self.data_source = data_source or {}
self.is_memory_mode = data_source is not None
self._cache: Dict[str, pl.DataFrame] = {}
self._lock = threading.Lock()
def fetch_data(
self,
data_specs: List[DataSpec],
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
"""根据数据规格获取并组装核心宽表。
Args:
data_specs: 数据规格列表
start_date: 开始日期 (YYYYMMDD)
end_date: 结束日期 (YYYYMMDD)
stock_codes: 股票代码列表None 表示全市场
Returns:
组装好的核心宽表 DataFrame
Raises:
ValueError: 当数据源中缺少必要的表或字段时
"""
if not data_specs:
raise ValueError("数据规格不能为空")
# 收集所有需要的表和字段
required_tables: Dict[str, Set[str]] = {}
max_lookback = 0
for spec in data_specs:
if spec.table not in required_tables:
required_tables[spec.table] = set()
required_tables[spec.table].update(spec.columns)
max_lookback = max(max_lookback, spec.lookback_days)
# 调整日期范围以包含回看期
adjusted_start = self._adjust_start_date(start_date, max_lookback)
# 从数据源获取各表数据
table_data = {}
for table_name, columns in required_tables.items():
df = self._load_table(
table_name=table_name,
columns=list(columns),
start_date=adjusted_start,
end_date=end_date,
stock_codes=stock_codes,
)
table_data[table_name] = df
# 组装核心宽表
core_table = self._assemble_wide_table(table_data, required_tables)
# 过滤到实际请求日期范围
core_table = core_table.filter(
(pl.col("trade_date") >= start_date) & (pl.col("trade_date") <= end_date)
)
return core_table
def _load_table(
self,
table_name: str,
columns: List[str],
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
"""加载单个表的数据。
Args:
table_name: 表名
columns: 需要的字段
start_date: 开始日期
end_date: 结束日期
stock_codes: 股票代码过滤
Returns:
过滤后的 DataFrame
"""
cache_key = f"{table_name}_{start_date}_{end_date}_{stock_codes}"
with self._lock:
if cache_key in self._cache:
return self._cache[cache_key]
if self.is_memory_mode:
if table_name not in self.data_source:
raise ValueError(f"内存数据源中缺少表: {table_name}")
df = self.data_source[table_name]
# 确保必需字段存在
for col in columns:
if col not in df.columns and col not in ["ts_code", "trade_date"]:
raise ValueError(f"{table_name} 缺少字段: {col}")
# 过滤日期和股票
df = df.filter(
(pl.col("trade_date") >= start_date)
& (pl.col("trade_date") <= end_date)
)
if stock_codes is not None:
df = df.filter(pl.col("ts_code").is_in(stock_codes))
# 选择需要的列
select_cols = ["ts_code", "trade_date"] + [
c for c in columns if c in df.columns
]
df = df.select(select_cols)
else:
# TODO: 实现真实数据库连接DuckDB
raise NotImplementedError("数据库连接模式尚未实现")
with self._lock:
self._cache[cache_key] = df
return df
def _assemble_wide_table(
self,
table_data: Dict[str, pl.DataFrame],
required_tables: Dict[str, Set[str]],
) -> pl.DataFrame:
"""组装多表数据为核心宽表。
使用 left join 合并各表数据,以第一个表为基准。
Args:
table_data: 表名到 DataFrame 的映射
required_tables: 表名到字段集合的映射
Returns:
组装后的宽表
"""
if not table_data:
raise ValueError("没有数据可组装")
# 以第一个表为基准
base_table_name = list(table_data.keys())[0]
result = table_data[base_table_name]
# 与其他表 join
for table_name, df in table_data.items():
if table_name == base_table_name:
continue
# 使用 ts_code 和 trade_date 作为 join 键
result = result.join(
df,
on=["ts_code", "trade_date"],
how="left",
)
return result
def _adjust_start_date(self, start_date: str, lookback_days: int) -> str:
"""根据回看天数调整开始日期。
Args:
start_date: 原始开始日期 (YYYYMMDD)
lookback_days: 需要回看的交易日数
Returns:
调整后的开始日期
"""
# 简化的日期调整假设每月30天向前推移
# 实际应用中应该使用交易日历
year = int(start_date[:4])
month = int(start_date[4:6])
day = int(start_date[6:8])
total_days = lookback_days + 30 # 额外缓冲
day -= total_days
while day <= 0:
month -= 1
if month <= 0:
month = 12
year -= 1
day += 30
return f"{year:04d}{month:02d}{day:02d}"
def clear_cache(self) -> None:
"""清除数据缓存。"""
with self._lock:
self._cache.clear()
class ExecutionPlanner:
"""执行计划生成器。
整合编译器和翻译器,生成完整的执行计划。
Attributes:
compiler: 依赖提取器
translator: Polars 翻译器
"""
def __init__(self) -> None:
"""初始化执行计划生成器。"""
self.compiler = DependencyExtractor()
self.translator = PolarsTranslator()
def create_plan(
self,
expression: Node,
output_name: str = "factor",
data_specs: Optional[List[DataSpec]] = None,
) -> ExecutionPlan:
"""从表达式创建执行计划。
Args:
expression: DSL 表达式节点
output_name: 输出因子名称
data_specs: 预定义的数据规格None 时自动推导
Returns:
执行计划对象
"""
# 1. 提取依赖
dependencies = self.compiler.extract_dependencies(expression)
# 2. 翻译为 Polars 表达式
polars_expr = self.translator.translate(expression)
# 3. 推导或验证数据规格
if data_specs is None:
data_specs = self._infer_data_specs(dependencies, expression)
return ExecutionPlan(
data_specs=data_specs,
polars_expr=polars_expr,
dependencies=dependencies,
output_name=output_name,
)
def _infer_data_specs(
self,
dependencies: Set[str],
expression: Node,
) -> List[DataSpec]:
"""从依赖推导数据规格。
根据表达式中的函数类型推断回看天数需求。
Args:
dependencies: 依赖的字段集合
expression: 表达式节点
Returns:
数据规格列表
"""
# 计算最大回看窗口
max_window = self._extract_max_window(expression)
lookback_days = max(1, max_window)
# 假设所有字段都来自 daily 表
columns = list(dependencies)
return [
DataSpec(
table="daily",
columns=columns,
lookback_days=lookback_days,
)
]
def _extract_max_window(self, node: Node) -> int:
"""从表达式中提取最大窗口大小。
Args:
node: AST 节点
Returns:
最大窗口大小,无时序函数返回 1
"""
if isinstance(node, FunctionNode):
window = 1
# 检查函数参数中的窗口大小
for arg in node.args:
if (
isinstance(arg, Constant)
and isinstance(arg.value, int)
and arg.value > window
):
window = arg.value
# 递归检查子表达式
for arg in node.args:
if isinstance(arg, Node) and not isinstance(arg, Constant):
window = max(window, self._extract_max_window(arg))
return window
elif isinstance(node, BinaryOpNode):
return max(
self._extract_max_window(node.left),
self._extract_max_window(node.right),
)
elif isinstance(node, UnaryOpNode):
return self._extract_max_window(node.operand)
return 1
class ComputeEngine:
"""计算引擎 - 执行并行运算。
负责将执行计划应用到数据上,支持并行计算。
Attributes:
max_workers: 最大并行工作线程数
use_processes: 是否使用进程池CPU 密集型任务)
"""
def __init__(
self,
max_workers: int = 4,
use_processes: bool = False,
) -> None:
"""初始化计算引擎。
Args:
max_workers: 最大并行工作线程数
use_processes: 是否使用进程池代替线程池
"""
self.max_workers = max_workers
self.use_processes = use_processes
def execute(
self,
plan: ExecutionPlan,
data: pl.DataFrame,
) -> pl.DataFrame:
"""执行计算计划。
Args:
plan: 执行计划
data: 输入数据(核心宽表)
Returns:
包含因子结果的 DataFrame
"""
# 检查依赖字段是否存在
missing_cols = plan.dependencies - set(data.columns)
if missing_cols:
raise ValueError(f"数据缺少必要的字段: {missing_cols}")
# 执行计算
result = data.with_columns([plan.polars_expr.alias(plan.output_name)])
return result
def execute_batch(
self,
plans: List[ExecutionPlan],
data: pl.DataFrame,
) -> pl.DataFrame:
"""批量执行多个计算计划。
Args:
plans: 执行计划列表
data: 输入数据
Returns:
包含所有因子结果的 DataFrame
"""
result = data
for plan in plans:
result = self.execute(plan, result)
return result
def execute_parallel(
self,
plans: List[ExecutionPlan],
data: pl.DataFrame,
) -> pl.DataFrame:
"""并行执行多个计算计划。
Args:
plans: 执行计划列表
data: 输入数据
Returns:
包含所有因子结果的 DataFrame
"""
# 检查计划间依赖
independent_plans = []
dependent_plans = []
available_cols = set(data.columns)
for plan in plans:
if plan.dependencies <= available_cols:
independent_plans.append(plan)
available_cols.add(plan.output_name)
else:
dependent_plans.append(plan)
# 并行执行独立计划
if independent_plans:
ExecutorClass = (
ProcessPoolExecutor if self.use_processes else ThreadPoolExecutor
)
with ExecutorClass(max_workers=self.max_workers) as executor:
futures = {
executor.submit(self._execute_single, plan, data): plan
for plan in independent_plans
}
results = []
for future in futures:
plan = futures[future]
try:
result_col = future.result()
results.append((plan.output_name, result_col))
except Exception as e:
raise RuntimeError(f"计算因子 {plan.output_name} 失败: {e}")
# 合并结果
for name, series in results:
data = data.with_columns([series.alias(name)])
# 顺序执行依赖计划
for plan in dependent_plans:
data = self.execute(plan, data)
return data
def _execute_single(
self,
plan: ExecutionPlan,
data: pl.DataFrame,
) -> pl.Series:
"""执行单个计划并返回结果列。
Args:
plan: 执行计划
data: 输入数据
Returns:
计算结果序列
"""
result = self.execute(plan, data)
return result[plan.output_name]
class FactorEngine:
"""因子计算引擎 - 系统统一入口。
提供从表达式到结果的完整执行链路,是研究员使用系统的唯一接口。
执行流程:
1. 注册表达式 -> 调用编译器解析依赖
2. 调用路由器连接数据库拉取并组装核心宽表
3. 调用翻译器生成物理执行计划
4. 将计划提交给计算引擎执行并行运算
5. 返回包含因子结果的数据表
Attributes:
router: 数据路由器
planner: 执行计划生成器
compute_engine: 计算引擎
registered_expressions: 注册的表达式字典
"""
def __init__(
self,
data_source: Optional[Dict[str, pl.DataFrame]] = None,
max_workers: int = 4,
) -> None:
"""初始化因子引擎。
Args:
data_source: 内存数据源,为 None 时使用数据库连接
max_workers: 并行计算的最大工作线程数
"""
self.router = DataRouter(data_source)
self.planner = ExecutionPlanner()
self.compute_engine = ComputeEngine(max_workers=max_workers)
self.registered_expressions: Dict[str, Node] = {}
self._plans: Dict[str, ExecutionPlan] = {}
def register(
self,
name: str,
expression: Node,
data_specs: Optional[List[DataSpec]] = None,
) -> FactorEngine:
"""注册因子表达式。
Args:
name: 因子名称
expression: DSL 表达式
data_specs: 数据规格None 时自动推导
Returns:
self支持链式调用
Example:
>>> from src.factors.api import close, ts_mean
>>> engine = FactorEngine()
>>> engine.register("ma20", ts_mean(close, 20))
"""
self.registered_expressions[name] = expression
# 预创建执行计划
plan = self.planner.create_plan(
expression=expression,
output_name=name,
data_specs=data_specs,
)
self._plans[name] = plan
return self
def compute(
self,
factor_names: Union[str, List[str]],
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
"""计算指定因子的值。
完整的执行流程:取数 -> 组装 -> 翻译 -> 计算。
Args:
factor_names: 因子名称或名称列表
start_date: 开始日期 (YYYYMMDD)
end_date: 结束日期 (YYYYMMDD)
stock_codes: 股票代码列表None 表示全市场
Returns:
包含因子结果的数据表
Raises:
ValueError: 当因子未注册或数据不足时
Example:
>>> result = engine.compute("ma20", "20240101", "20240131")
>>> result = engine.compute(["ma20", "rsi"], "20240101", "20240131")
"""
# 标准化因子名称
if isinstance(factor_names, str):
factor_names = [factor_names]
# 1. 获取执行计划
plans = []
for name in factor_names:
if name not in self._plans:
raise ValueError(f"因子未注册: {name}")
plans.append(self._plans[name])
# 2. 合并数据规格并获取数据
all_specs = []
for plan in plans:
all_specs.extend(plan.data_specs)
# 3. 从路由器获取核心宽表
core_data = self.router.fetch_data(
data_specs=all_specs,
start_date=start_date,
end_date=end_date,
stock_codes=stock_codes,
)
if len(core_data) == 0:
raise ValueError("未获取到任何数据,请检查日期范围和股票代码")
# 4. 执行计算
if len(plans) == 1:
result = self.compute_engine.execute(plans[0], core_data)
else:
result = self.compute_engine.execute_batch(plans, core_data)
return result
def list_registered(self) -> List[str]:
"""获取已注册的因子列表。
Returns:
因子名称列表
"""
return list(self.registered_expressions.keys())
def get_expression(self, name: str) -> Optional[Node]:
"""获取已注册的表达式。
Args:
name: 因子名称
Returns:
表达式节点,未注册时返回 None
"""
return self.registered_expressions.get(name)
def clear(self) -> None:
"""清除所有注册的表达式和缓存。"""
self.registered_expressions.clear()
self._plans.clear()
self.router.clear_cache()
def preview_plan(self, factor_name: str) -> Optional[ExecutionPlan]:
"""预览因子的执行计划。
Args:
factor_name: 因子名称
Returns:
执行计划,未注册时返回 None
"""
return self._plans.get(factor_name)