Compare commits

...

3 Commits

Author SHA1 Message Date
05d0c90312 feat(factors): 新增公式解析基础组件
新增公式解析相关模块,支持将字符串表达式解析为 DSL 节点树:
- exceptions.py: 定义公式解析异常体系
  - FormulaParseError 基类,提供位置指示的错误信息
  - UnknownFunctionError 支持模糊匹配建议
  - InvalidSyntaxError、EmptyExpressionError 等具体异常
- parser.py: 基于 Python ast 的公式解析器
  - 支持符号引用、数值常量、二元/一元运算
  - 支持函数调用和比较运算
  - 常量折叠优化
- registry.py: 函数注册表
  - 支持动态注册和查询公式函数
  - 提供可用函数列表和重复注册检查
2026-03-03 00:04:48 +08:00
77e4e94e05 refactor(factors): 拆分 engine.py 为模块化包
将单文件 engine.py (1064行) 拆分为 engine/ 包:
- 数据规格、路由器、计划器、计算引擎、因子引擎分离
- 保持向后兼容,API 无变化
2026-03-02 22:29:18 +08:00
1c0c4a0de1 fix(factors): 修复 cs_rank 等截面函数在依赖表达式时输出全 null 的问题 2026-03-02 22:21:43 +08:00
11 changed files with 2026 additions and 817 deletions

View File

@@ -52,6 +52,22 @@ from src.factors.engine import (
ComputeEngine,
)
from src.factors.parser import FormulaParser
from src.factors.registry import FunctionRegistry
from src.factors.exceptions import (
FormulaParseError,
UnknownFunctionError,
InvalidSyntaxError,
EmptyExpressionError,
RegistryError,
DuplicateFunctionError,
)
# 保持向后兼容factor_engine.py 中的类也可以通过 src.factors.engine 访问
# 例如from src.factors.engine import FactorEngine
__all__ = [
# DSL 层
"Node",
@@ -73,4 +89,15 @@ __all__ = [
"DataRouter",
"ExecutionPlanner",
"ComputeEngine",
# 解析器 (Phase 1 新增)
"FormulaParser",
# 注册表 (Phase 1 新增)
"FunctionRegistry",
# 异常类 (Phase 1 新增)
"FormulaParseError",
"UnknownFunctionError",
"InvalidSyntaxError",
"EmptyExpressionError",
"RegistryError",
"DuplicateFunctionError",
]

View File

@@ -1,817 +0,0 @@
"""FactorEngine - 因子计算引擎统一入口。
提供从表达式注册到结果输出的完整执行链路:
接收研究员的表达式 -> 调用编译器解析依赖 -> 调用路由器连接数据库拉取并组装核心宽表
-> 调用翻译器生成物理执行计划 -> 将计划提交给计算引擎执行并行运算。
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set, Union
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
import threading
import polars as pl
from src.factors.dsl import (
Node,
Symbol,
FunctionNode,
BinaryOpNode,
UnaryOpNode,
Constant,
)
from src.factors.compiler import DependencyExtractor
from src.factors.translator import PolarsTranslator
from src.data.storage import Storage
@dataclass
class DataSpec:
"""数据规格定义。
描述因子计算所需的数据表和字段。
Attributes:
table: 数据表名称
columns: 需要的字段列表
lookback_days: 回看天数(用于时序计算)
"""
table: str
columns: List[str]
lookback_days: int = 1
@dataclass
class ExecutionPlan:
"""执行计划。
包含完整的执行所需信息:数据源、转换逻辑、输出格式。
Attributes:
data_specs: 数据规格列表
polars_expr: Polars 表达式
dependencies: 依赖的原始字段
output_name: 输出因子名称
"""
data_specs: List[DataSpec]
polars_expr: pl.Expr
dependencies: Set[str]
output_name: str
class DataRouter:
"""数据路由器 - 按需取数、组装核心宽表。
负责根据数据规格从数据源拉取数据,并组装成统一的宽表格式。
支持内存数据源(用于测试)和真实数据库连接。
Attributes:
data_source: 数据源,可以是内存 DataFrame 字典或数据库连接
is_memory_mode: 是否为内存模式
"""
def __init__(self, data_source: Optional[Dict[str, pl.DataFrame]] = None) -> None:
"""初始化数据路由器。
Args:
data_source: 内存数据源,字典格式 {表名: DataFrame}
为 None 时自动连接 DuckDB 数据库
"""
self.data_source = data_source or {}
self.is_memory_mode = data_source is not None
self._cache: Dict[str, pl.DataFrame] = {}
self._lock = threading.Lock()
# 数据库模式下初始化 Storage
if not self.is_memory_mode:
self._storage = Storage()
else:
self._storage = None
def fetch_data(
self,
data_specs: List[DataSpec],
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
"""根据数据规格获取并组装核心宽表。
Args:
data_specs: 数据规格列表
start_date: 开始日期 (YYYYMMDD)
end_date: 结束日期 (YYYYMMDD)
stock_codes: 股票代码列表None 表示全市场
Returns:
组装好的核心宽表 DataFrame
Raises:
ValueError: 当数据源中缺少必要的表或字段时
"""
if not data_specs:
raise ValueError("数据规格不能为空")
# 收集所有需要的表和字段
required_tables: Dict[str, Set[str]] = {}
max_lookback = 0
for spec in data_specs:
if spec.table not in required_tables:
required_tables[spec.table] = set()
required_tables[spec.table].update(spec.columns)
max_lookback = max(max_lookback, spec.lookback_days)
# 调整日期范围以包含回看期
adjusted_start = self._adjust_start_date(start_date, max_lookback)
# 从数据源获取各表数据
table_data = {}
for table_name, columns in required_tables.items():
df = self._load_table(
table_name=table_name,
columns=list(columns),
start_date=adjusted_start,
end_date=end_date,
stock_codes=stock_codes,
)
table_data[table_name] = df
# 组装核心宽表
core_table = self._assemble_wide_table(table_data, required_tables)
# 过滤到实际请求日期范围
core_table = core_table.filter(
(pl.col("trade_date") >= start_date) & (pl.col("trade_date") <= end_date)
)
return core_table
def _load_table(
self,
table_name: str,
columns: List[str],
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
"""加载单个表的数据。
Args:
table_name: 表名
columns: 需要的字段
start_date: 开始日期
end_date: 结束日期
stock_codes: 股票代码过滤
Returns:
过滤后的 DataFrame
"""
cache_key = f"{table_name}_{start_date}_{end_date}_{stock_codes}"
with self._lock:
if cache_key in self._cache:
return self._cache[cache_key]
if self.is_memory_mode:
df = self._load_from_memory(
table_name, columns, start_date, end_date, stock_codes
)
else:
df = self._load_from_database(
table_name, columns, start_date, end_date, stock_codes
)
with self._lock:
self._cache[cache_key] = df
return df
def _load_from_memory(
self,
table_name: str,
columns: List[str],
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
"""从内存数据源加载数据。"""
if table_name not in self.data_source:
raise ValueError(f"内存数据源中缺少表: {table_name}")
df = self.data_source[table_name]
# 确保必需字段存在
for col in columns:
if col not in df.columns and col not in ["ts_code", "trade_date"]:
raise ValueError(f"{table_name} 缺少字段: {col}")
# 过滤日期和股票
df = df.filter(
(pl.col("trade_date") >= start_date) & (pl.col("trade_date") <= end_date)
)
if stock_codes is not None:
df = df.filter(pl.col("ts_code").is_in(stock_codes))
# 选择需要的列
select_cols = ["ts_code", "trade_date"] + [
c for c in columns if c in df.columns
]
return df.select(select_cols)
def _load_from_database(
self,
table_name: str,
columns: List[str],
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
"""从 DuckDB 数据库加载数据。
利用 Storage.load_polars() 方法,支持 SQL 查询下推。
"""
if self._storage is None:
raise RuntimeError("Storage 未初始化")
# 检查表是否存在
if not self._storage.exists(table_name):
raise ValueError(f"数据库中不存在表: {table_name}")
# 构建查询参数
# Storage.load_polars 目前只支持单个 ts_code需要处理列表情况
if stock_codes is not None and len(stock_codes) == 1:
ts_code_filter = stock_codes[0]
else:
ts_code_filter = None
try:
# 从数据库加载原始数据
df = self._storage.load_polars(
name=table_name,
start_date=start_date,
end_date=end_date,
ts_code=ts_code_filter,
)
except Exception as e:
raise RuntimeError(f"从数据库加载表 {table_name} 失败: {e}")
# 如果 stock_codes 是列表且长度 > 1在内存中过滤
if stock_codes is not None and len(stock_codes) > 1:
df = df.filter(pl.col("ts_code").is_in(stock_codes))
# 检查必需字段
for col in columns:
if col not in df.columns and col not in ["ts_code", "trade_date"]:
raise ValueError(f"{table_name} 缺少字段: {col}")
# 选择需要的列
select_cols = ["ts_code", "trade_date"] + [
c for c in columns if c in df.columns
]
return df.select(select_cols)
def _assemble_wide_table(
self,
table_data: Dict[str, pl.DataFrame],
required_tables: Dict[str, Set[str]],
) -> pl.DataFrame:
"""组装多表数据为核心宽表。
使用 left join 合并各表数据,以第一个表为基准。
Args:
table_data: 表名到 DataFrame 的映射
required_tables: 表名到字段集合的映射
Returns:
组装后的宽表
"""
if not table_data:
raise ValueError("没有数据可组装")
# 以第一个表为基准
base_table_name = list(table_data.keys())[0]
result = table_data[base_table_name]
# 与其他表 join
for table_name, df in table_data.items():
if table_name == base_table_name:
continue
# 使用 ts_code 和 trade_date 作为 join 键
result = result.join(
df,
on=["ts_code", "trade_date"],
how="left",
)
return result
def _adjust_start_date(self, start_date: str, lookback_days: int) -> str:
"""根据回看天数调整开始日期。
Args:
start_date: 原始开始日期 (YYYYMMDD)
lookback_days: 需要回看的交易日数
Returns:
调整后的开始日期
"""
# 简化的日期调整假设每月30天向前推移
# 实际应用中应该使用交易日历
year = int(start_date[:4])
month = int(start_date[4:6])
day = int(start_date[6:8])
total_days = lookback_days + 30 # 额外缓冲
day -= total_days
while day <= 0:
month -= 1
if month <= 0:
month = 12
year -= 1
day += 30
return f"{year:04d}{month:02d}{day:02d}"
def clear_cache(self) -> None:
"""清除数据缓存。"""
with self._lock:
self._cache.clear()
# 数据库模式下清理 Storage 连接(可选)
if not self.is_memory_mode and self._storage is not None:
# Storage 使用单例模式,不需要关闭连接
pass
class ExecutionPlanner:
"""执行计划生成器。
整合编译器和翻译器,生成完整的执行计划。
Attributes:
compiler: 依赖提取器
translator: Polars 翻译器
"""
def __init__(self) -> None:
"""初始化执行计划生成器。"""
self.compiler = DependencyExtractor()
self.translator = PolarsTranslator()
def create_plan(
self,
expression: Node,
output_name: str = "factor",
data_specs: Optional[List[DataSpec]] = None,
) -> ExecutionPlan:
"""从表达式创建执行计划。
Args:
expression: DSL 表达式节点
output_name: 输出因子名称
data_specs: 预定义的数据规格None 时自动推导
Returns:
执行计划对象
"""
# 1. 提取依赖
dependencies = self.compiler.extract_dependencies(expression)
# 2. 翻译为 Polars 表达式
polars_expr = self.translator.translate(expression)
# 3. 推导或验证数据规格
if data_specs is None:
data_specs = self._infer_data_specs(dependencies, expression)
return ExecutionPlan(
data_specs=data_specs,
polars_expr=polars_expr,
dependencies=dependencies,
output_name=output_name,
)
def _infer_data_specs(
self,
dependencies: Set[str],
expression: Node,
) -> List[DataSpec]:
"""从依赖推导数据规格。
根据表达式中的函数类型推断回看天数需求。
基础行情字段open, high, low, close, vol, amount, pre_close, change, pct_chg
默认从 pro_bar 表获取。
Args:
dependencies: 依赖的字段集合
expression: 表达式节点
Returns:
数据规格列表
"""
# 计算最大回看窗口
max_window = self._extract_max_window(expression)
lookback_days = max(1, max_window)
# 基础行情字段集合(这些字段从 pro_bar 表获取)
pro_bar_fields = {
"open",
"high",
"low",
"close",
"vol",
"amount",
"pre_close",
"change",
"pct_chg",
"turnover_rate",
"volume_ratio",
}
# 将依赖分为 pro_bar 字段和其他字段
pro_bar_deps = dependencies & pro_bar_fields
other_deps = dependencies - pro_bar_fields
data_specs = []
# pro_bar 表的数据规格
if pro_bar_deps:
data_specs.append(
DataSpec(
table="pro_bar",
columns=sorted(pro_bar_deps),
lookback_days=lookback_days,
)
)
# 其他字段从 daily 表获取
if other_deps:
data_specs.append(
DataSpec(
table="daily",
columns=sorted(other_deps),
lookback_days=lookback_days,
)
)
return data_specs
def _extract_max_window(self, node: Node) -> int:
"""从表达式中提取最大窗口大小。
Args:
node: AST 节点
Returns:
最大窗口大小,无时序函数返回 1
"""
if isinstance(node, FunctionNode):
window = 1
# 检查函数参数中的窗口大小
for arg in node.args:
if (
isinstance(arg, Constant)
and isinstance(arg.value, int)
and arg.value > window
):
window = arg.value
# 递归检查子表达式
for arg in node.args:
if isinstance(arg, Node) and not isinstance(arg, Constant):
window = max(window, self._extract_max_window(arg))
return window
elif isinstance(node, BinaryOpNode):
return max(
self._extract_max_window(node.left),
self._extract_max_window(node.right),
)
elif isinstance(node, UnaryOpNode):
return self._extract_max_window(node.operand)
return 1
class ComputeEngine:
"""计算引擎 - 执行并行运算。
负责将执行计划应用到数据上,支持并行计算。
Attributes:
max_workers: 最大并行工作线程数
use_processes: 是否使用进程池CPU 密集型任务)
"""
def __init__(
self,
max_workers: int = 4,
use_processes: bool = False,
) -> None:
"""初始化计算引擎。
Args:
max_workers: 最大并行工作线程数
use_processes: 是否使用进程池代替线程池
"""
self.max_workers = max_workers
self.use_processes = use_processes
def execute(
self,
plan: ExecutionPlan,
data: pl.DataFrame,
) -> pl.DataFrame:
"""执行计算计划。
Args:
plan: 执行计划
data: 输入数据(核心宽表)
Returns:
包含因子结果的 DataFrame
"""
# 检查依赖字段是否存在
missing_cols = plan.dependencies - set(data.columns)
if missing_cols:
raise ValueError(f"数据缺少必要的字段: {missing_cols}")
# 执行计算
result = data.with_columns([plan.polars_expr.alias(plan.output_name)])
return result
def execute_batch(
self,
plans: List[ExecutionPlan],
data: pl.DataFrame,
) -> pl.DataFrame:
"""批量执行多个计算计划。
Args:
plans: 执行计划列表
data: 输入数据
Returns:
包含所有因子结果的 DataFrame
"""
result = data
for plan in plans:
result = self.execute(plan, result)
return result
def execute_parallel(
self,
plans: List[ExecutionPlan],
data: pl.DataFrame,
) -> pl.DataFrame:
"""并行执行多个计算计划。
Args:
plans: 执行计划列表
data: 输入数据
Returns:
包含所有因子结果的 DataFrame
"""
# 检查计划间依赖
independent_plans = []
dependent_plans = []
available_cols = set(data.columns)
for plan in plans:
if plan.dependencies <= available_cols:
independent_plans.append(plan)
available_cols.add(plan.output_name)
else:
dependent_plans.append(plan)
# 并行执行独立计划
if independent_plans:
ExecutorClass = (
ProcessPoolExecutor if self.use_processes else ThreadPoolExecutor
)
with ExecutorClass(max_workers=self.max_workers) as executor:
futures = {
executor.submit(self._execute_single, plan, data): plan
for plan in independent_plans
}
results = []
for future in futures:
plan = futures[future]
try:
result_col = future.result()
results.append((plan.output_name, result_col))
except Exception as e:
raise RuntimeError(f"计算因子 {plan.output_name} 失败: {e}")
# 合并结果
for name, series in results:
data = data.with_columns([series.alias(name)])
# 顺序执行依赖计划
for plan in dependent_plans:
data = self.execute(plan, data)
return data
def _execute_single(
self,
plan: ExecutionPlan,
data: pl.DataFrame,
) -> pl.Series:
"""执行单个计划并返回结果列。
Args:
plan: 执行计划
data: 输入数据
Returns:
计算结果序列
"""
result = self.execute(plan, data)
return result[plan.output_name]
class FactorEngine:
"""因子计算引擎 - 系统统一入口。
提供从表达式到结果的完整执行链路,是研究员使用系统的唯一接口。
执行流程:
1. 注册表达式 -> 调用编译器解析依赖
2. 调用路由器连接数据库拉取并组装核心宽表
3. 调用翻译器生成物理执行计划
4. 将计划提交给计算引擎执行并行运算
5. 返回包含因子结果的数据表
Attributes:
router: 数据路由器
planner: 执行计划生成器
compute_engine: 计算引擎
registered_expressions: 注册的表达式字典
"""
def __init__(
self,
data_source: Optional[Dict[str, pl.DataFrame]] = None,
max_workers: int = 4,
) -> None:
"""初始化因子引擎。
Args:
data_source: 内存数据源,为 None 时使用数据库连接
max_workers: 并行计算的最大工作线程数
"""
self.router = DataRouter(data_source)
self.planner = ExecutionPlanner()
self.compute_engine = ComputeEngine(max_workers=max_workers)
self.registered_expressions: Dict[str, Node] = {}
self._plans: Dict[str, ExecutionPlan] = {}
def register(
self,
name: str,
expression: Node,
data_specs: Optional[List[DataSpec]] = None,
) -> FactorEngine:
"""注册因子表达式。
Args:
name: 因子名称
expression: DSL 表达式
data_specs: 数据规格None 时自动推导
Returns:
self支持链式调用
Example:
>>> from src.factors.api import close, ts_mean
>>> engine = FactorEngine()
>>> engine.register("ma20", ts_mean(close, 20))
"""
self.registered_expressions[name] = expression
# 预创建执行计划
plan = self.planner.create_plan(
expression=expression,
output_name=name,
data_specs=data_specs,
)
self._plans[name] = plan
return self
def compute(
self,
factor_names: Union[str, List[str]],
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
"""计算指定因子的值。
完整的执行流程:取数 -> 组装 -> 翻译 -> 计算。
Args:
factor_names: 因子名称或名称列表
start_date: 开始日期 (YYYYMMDD)
end_date: 结束日期 (YYYYMMDD)
stock_codes: 股票代码列表None 表示全市场
Returns:
包含因子结果的数据表
Raises:
ValueError: 当因子未注册或数据不足时
Example:
>>> result = engine.compute("ma20", "20240101", "20240131")
>>> result = engine.compute(["ma20", "rsi"], "20240101", "20240131")
"""
# 标准化因子名称
if isinstance(factor_names, str):
factor_names = [factor_names]
# 1. 获取执行计划
plans = []
for name in factor_names:
if name not in self._plans:
raise ValueError(f"因子未注册: {name}")
plans.append(self._plans[name])
# 2. 合并数据规格并获取数据
all_specs = []
for plan in plans:
all_specs.extend(plan.data_specs)
# 3. 从路由器获取核心宽表
core_data = self.router.fetch_data(
data_specs=all_specs,
start_date=start_date,
end_date=end_date,
stock_codes=stock_codes,
)
if len(core_data) == 0:
raise ValueError("未获取到任何数据,请检查日期范围和股票代码")
# 4. 执行计算
if len(plans) == 1:
result = self.compute_engine.execute(plans[0], core_data)
else:
result = self.compute_engine.execute_batch(plans, core_data)
return result
def list_registered(self) -> List[str]:
"""获取已注册的因子列表。
Returns:
因子名称列表
"""
return list(self.registered_expressions.keys())
def get_expression(self, name: str) -> Optional[Node]:
"""获取已注册的表达式。
Args:
name: 因子名称
Returns:
表达式节点,未注册时返回 None
"""
return self.registered_expressions.get(name)
def clear(self) -> None:
"""清除所有注册的表达式和缓存。"""
self.registered_expressions.clear()
self._plans.clear()
self.router.clear_cache()
def preview_plan(self, factor_name: str) -> Optional[ExecutionPlan]:
"""预览因子的执行计划。
Args:
factor_name: 因子名称
Returns:
执行计划,未注册时返回 None
"""
return self._plans.get(factor_name)

View File

@@ -0,0 +1,28 @@
"""因子计算引擎模块。
提供完整的因子计算引擎组件:
- DataSpec: 数据规格定义
- ExecutionPlan: 执行计划
- DataRouter: 数据路由器
- ExecutionPlanner: 执行计划生成器
- ComputeEngine: 计算引擎
- FactorEngine: 因子计算引擎(统一入口)
"""
from src.factors.engine.data_spec import DataSpec, ExecutionPlan
from src.factors.engine.data_router import DataRouter
from src.factors.engine.planner import ExecutionPlanner
from src.factors.engine.compute_engine import ComputeEngine
from src.factors.engine.factor_engine import FactorEngine
__all__ = [
"DataSpec",
"ExecutionPlan",
"DataRouter",
"ExecutionPlanner",
"ComputeEngine",
"FactorEngine",
]
# 类型导出(用于类型注解)
# FunctionRegistry 从 src.factors.registry 导入

View File

@@ -0,0 +1,155 @@
"""计算引擎。
执行并行运算,负责将执行计划应用到数据上。
"""
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Set, Union
import polars as pl
from src.factors.engine.data_spec import ExecutionPlan
class ComputeEngine:
"""计算引擎 - 执行并行运算。
负责将执行计划应用到数据上,支持并行计算。
Attributes:
max_workers: 最大并行工作线程数
use_processes: 是否使用进程池CPU 密集型任务)
"""
def __init__(
self,
max_workers: int = 4,
use_processes: bool = False,
) -> None:
"""初始化计算引擎。
Args:
max_workers: 最大并行工作线程数
use_processes: 是否使用进程池代替线程池
"""
self.max_workers = max_workers
self.use_processes = use_processes
def execute(
self,
plan: ExecutionPlan,
data: pl.DataFrame,
) -> pl.DataFrame:
"""执行计算计划。
Args:
plan: 执行计划
data: 输入数据(核心宽表)
Returns:
包含因子结果的 DataFrame
"""
# 检查依赖字段是否存在
missing_cols = plan.dependencies - set(data.columns)
if missing_cols:
raise ValueError(f"数据缺少必要的字段: {missing_cols}")
# 执行计算
result = data.with_columns([plan.polars_expr.alias(plan.output_name)])
return result
def execute_batch(
self,
plans: List[ExecutionPlan],
data: pl.DataFrame,
) -> pl.DataFrame:
"""批量执行多个计算计划。
Args:
plans: 执行计划列表
data: 输入数据
Returns:
包含所有因子结果的 DataFrame
"""
result = data
for plan in plans:
result = self.execute(plan, result)
return result
def execute_parallel(
self,
plans: List[ExecutionPlan],
data: pl.DataFrame,
) -> pl.DataFrame:
"""并行执行多个计算计划。
Args:
plans: 执行计划列表
data: 输入数据
Returns:
包含所有因子结果的 DataFrame
"""
# 检查计划间依赖
independent_plans = []
dependent_plans = []
available_cols = set(data.columns)
for plan in plans:
if plan.dependencies <= available_cols:
independent_plans.append(plan)
available_cols.add(plan.output_name)
else:
dependent_plans.append(plan)
# 并行执行独立计划
if independent_plans:
ExecutorClass = (
ProcessPoolExecutor if self.use_processes else ThreadPoolExecutor
)
with ExecutorClass(max_workers=self.max_workers) as executor:
futures = {
executor.submit(self._execute_single, plan, data): plan
for plan in independent_plans
}
results = []
for future in futures:
plan = futures[future]
try:
result_col = future.result()
results.append((plan.output_name, result_col))
except Exception as e:
raise RuntimeError(f"计算因子 {plan.output_name} 失败: {e}")
# 合并结果
for name, series in results:
data = data.with_columns([series.alias(name)])
# 顺序执行依赖计划
for plan in dependent_plans:
data = self.execute(plan, data)
return data
def _execute_single(
self,
plan: ExecutionPlan,
data: pl.DataFrame,
) -> pl.Series:
"""执行单个计划并返回结果列。
Args:
plan: 执行计划
data: 输入数据
Returns:
计算结果序列
"""
result = self.execute(plan, data)
return result[plan.output_name]

View File

@@ -0,0 +1,304 @@
"""数据路由器。
按需取数、组装核心宽表。
负责根据数据规格从数据源拉取数据,并组装成统一的宽表格式。
支持内存数据源(用于测试)和真实数据库连接。
"""
from typing import Any, Dict, List, Optional, Set, Union
import threading
import polars as pl
from src.factors.engine.data_spec import DataSpec
from src.data.storage import Storage
class DataRouter:
"""数据路由器 - 按需取数、组装核心宽表。
负责根据数据规格从数据源拉取数据,并组装成统一的宽表格式。
支持内存数据源(用于测试)和真实数据库连接。
Attributes:
data_source: 数据源,可以是内存 DataFrame 字典或数据库连接
is_memory_mode: 是否为内存模式
"""
def __init__(self, data_source: Optional[Dict[str, pl.DataFrame]] = None) -> None:
"""初始化数据路由器。
Args:
data_source: 内存数据源,字典格式 {表名: DataFrame}
为 None 时自动连接 DuckDB 数据库
"""
self.data_source = data_source or {}
self.is_memory_mode = data_source is not None
self._cache: Dict[str, pl.DataFrame] = {}
self._lock = threading.Lock()
# 数据库模式下初始化 Storage
if not self.is_memory_mode:
self._storage = Storage()
else:
self._storage = None
def fetch_data(
self,
data_specs: List[DataSpec],
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
"""根据数据规格获取并组装核心宽表。
Args:
data_specs: 数据规格列表
start_date: 开始日期 (YYYYMMDD)
end_date: 结束日期 (YYYYMMDD)
stock_codes: 股票代码列表None 表示全市场
Returns:
组装好的核心宽表 DataFrame
Raises:
ValueError: 当数据源中缺少必要的表或字段时
"""
if not data_specs:
raise ValueError("数据规格不能为空")
# 收集所有需要的表和字段
required_tables: Dict[str, Set[str]] = {}
max_lookback = 0
for spec in data_specs:
if spec.table not in required_tables:
required_tables[spec.table] = set()
required_tables[spec.table].update(spec.columns)
max_lookback = max(max_lookback, spec.lookback_days)
# 调整日期范围以包含回看期
adjusted_start = self._adjust_start_date(start_date, max_lookback)
# 从数据源获取各表数据
table_data = {}
for table_name, columns in required_tables.items():
df = self._load_table(
table_name=table_name,
columns=list(columns),
start_date=adjusted_start,
end_date=end_date,
stock_codes=stock_codes,
)
table_data[table_name] = df
# 组装核心宽表
core_table = self._assemble_wide_table(table_data, required_tables)
# 过滤到实际请求日期范围
core_table = core_table.filter(
(pl.col("trade_date") >= start_date) & (pl.col("trade_date") <= end_date)
)
return core_table
def _load_table(
self,
table_name: str,
columns: List[str],
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
"""加载单个表的数据。
Args:
table_name: 表名
columns: 需要的字段
start_date: 开始日期
end_date: 结束日期
stock_codes: 股票代码过滤
Returns:
过滤后的 DataFrame
"""
cache_key = f"{table_name}_{start_date}_{end_date}_{stock_codes}"
with self._lock:
if cache_key in self._cache:
return self._cache[cache_key]
if self.is_memory_mode:
df = self._load_from_memory(
table_name, columns, start_date, end_date, stock_codes
)
else:
df = self._load_from_database(
table_name, columns, start_date, end_date, stock_codes
)
with self._lock:
self._cache[cache_key] = df
return df
def _load_from_memory(
self,
table_name: str,
columns: List[str],
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
"""从内存数据源加载数据。"""
if table_name not in self.data_source:
raise ValueError(f"内存数据源中缺少表: {table_name}")
df = self.data_source[table_name]
# 确保必需字段存在
for col in columns:
if col not in df.columns and col not in ["ts_code", "trade_date"]:
raise ValueError(f"{table_name} 缺少字段: {col}")
# 过滤日期和股票
df = df.filter(
(pl.col("trade_date") >= start_date) & (pl.col("trade_date") <= end_date)
)
if stock_codes is not None:
df = df.filter(pl.col("ts_code").is_in(stock_codes))
# 选择需要的列
select_cols = ["ts_code", "trade_date"] + [
c for c in columns if c in df.columns
]
return df.select(select_cols)
def _load_from_database(
self,
table_name: str,
columns: List[str],
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
"""从 DuckDB 数据库加载数据。
利用 Storage.load_polars() 方法,支持 SQL 查询下推。
"""
if self._storage is None:
raise RuntimeError("Storage 未初始化")
# 检查表是否存在
if not self._storage.exists(table_name):
raise ValueError(f"数据库中不存在表: {table_name}")
# 构建查询参数
# Storage.load_polars 目前只支持单个 ts_code需要处理列表情况
if stock_codes is not None and len(stock_codes) == 1:
ts_code_filter = stock_codes[0]
else:
ts_code_filter = None
try:
# 从数据库加载原始数据
df = self._storage.load_polars(
name=table_name,
start_date=start_date,
end_date=end_date,
ts_code=ts_code_filter,
)
except Exception as e:
raise RuntimeError(f"从数据库加载表 {table_name} 失败: {e}")
# 如果 stock_codes 是列表且长度 > 1在内存中过滤
if stock_codes is not None and len(stock_codes) > 1:
df = df.filter(pl.col("ts_code").is_in(stock_codes))
# 检查必需字段
for col in columns:
if col not in df.columns and col not in ["ts_code", "trade_date"]:
raise ValueError(f"{table_name} 缺少字段: {col}")
# 选择需要的列
select_cols = ["ts_code", "trade_date"] + [
c for c in columns if c in df.columns
]
return df.select(select_cols)
def _assemble_wide_table(
self,
table_data: Dict[str, pl.DataFrame],
required_tables: Dict[str, Set[str]],
) -> pl.DataFrame:
"""组装多表数据为核心宽表。
使用 left join 合并各表数据,以第一个表为基准。
Args:
table_data: 表名到 DataFrame 的映射
required_tables: 表名到字段集合的映射
Returns:
组装后的宽表
"""
if not table_data:
raise ValueError("没有数据可组装")
# 以第一个表为基准
base_table_name = list(table_data.keys())[0]
result = table_data[base_table_name]
# 与其他表 join
for table_name, df in table_data.items():
if table_name == base_table_name:
continue
# 使用 ts_code 和 trade_date 作为 join 键
result = result.join(
df,
on=["ts_code", "trade_date"],
how="left",
)
return result
def _adjust_start_date(self, start_date: str, lookback_days: int) -> str:
"""根据回看天数调整开始日期。
Args:
start_date: 原始开始日期 (YYYYMMDD)
lookback_days: 需要回看的交易日数
Returns:
调整后的开始日期
"""
# 简化的日期调整假设每月30天向前推移
# 实际应用中应该使用交易日历
year = int(start_date[:4])
month = int(start_date[4:6])
day = int(start_date[6:8])
total_days = lookback_days + 30 # 额外缓冲
day -= total_days
while day <= 0:
month -= 1
if month <= 0:
month = 12
year -= 1
day += 30
return f"{year:04d}{month:02d}{day:02d}"
def clear_cache(self) -> None:
"""清除数据缓存。"""
with self._lock:
self._cache.clear()
# 数据库模式下清理 Storage 连接(可选)
if not self.is_memory_mode and self._storage is not None:
# Storage 使用单例模式,不需要关闭连接
pass

View File

@@ -0,0 +1,47 @@
"""数据规格和执行计划定义。
定义因子计算所需的数据规格和执行计划结构。
"""
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set, Union
import polars as pl
@dataclass
class DataSpec:
"""数据规格定义。
描述因子计算所需的数据表和字段。
Attributes:
table: 数据表名称
columns: 需要的字段列表
lookback_days: 回看天数(用于时序计算)
"""
table: str
columns: List[str]
lookback_days: int = 1
@dataclass
class ExecutionPlan:
"""执行计划。
包含完整的执行所需信息:数据源、转换逻辑、输出格式。
Attributes:
data_specs: 数据规格列表
polars_expr: Polars 表达式
dependencies: 依赖的原始字段
output_name: 输出因子名称
factor_dependencies: 依赖的其他因子名称(用于分步执行)
"""
data_specs: List[DataSpec]
polars_expr: pl.Expr
dependencies: Set[str]
output_name: str
factor_dependencies: Set[str] = field(default_factory=set)

View File

@@ -0,0 +1,513 @@
"""因子计算引擎 - 系统统一入口。
提供从表达式到结果的完整执行链路,是研究员使用系统的唯一接口。
执行流程:
1. 注册表达式 -> 调用编译器解析依赖
2. 调用路由器连接数据库拉取并组装核心宽表
3. 调用翻译器生成物理执行计划
4. 将计划提交给计算引擎执行并行运算
5. 返回包含因子结果的数据表
"""
from typing import Any, Dict, List, Optional, Set, Union, TYPE_CHECKING
import polars as pl
if TYPE_CHECKING:
from src.factors.registry import FunctionRegistry
from src.factors.dsl import (
Node,
Symbol,
BinaryOpNode,
UnaryOpNode,
FunctionNode,
)
from src.factors.translator import PolarsTranslator
from src.factors.engine.data_spec import DataSpec, ExecutionPlan
from src.factors.engine.data_router import DataRouter
from src.factors.engine.planner import ExecutionPlanner
from src.factors.engine.compute_engine import ComputeEngine
class FactorEngine:
"""因子计算引擎 - 系统统一入口。
提供从表达式到结果的完整执行链路,是研究员使用系统的唯一接口。
执行流程:
1. 注册表达式 -> 调用编译器解析依赖
2. 调用路由器连接数据库拉取并组装核心宽表
3. 调用翻译器生成物理执行计划
4. 将计划提交给计算引擎执行并行运算
5. 返回包含因子结果的数据表
Attributes:
router: 数据路由器
planner: 执行计划生成器
compute_engine: 计算引擎
registered_expressions: 注册的表达式字典
_registry: 函数注册表
_parser: 公式解析器
"""
def __init__(
self,
data_source: Optional[Dict[str, pl.DataFrame]] = None,
max_workers: int = 4,
registry: Optional["FunctionRegistry"] = None,
) -> None:
"""初始化因子引擎。
Args:
data_source: 内存数据源,为 None 时使用数据库连接
max_workers: 并行计算的最大工作线程数
registry: 函数注册表None 时创建独立实例
"""
from src.factors.registry import FunctionRegistry
from src.factors.parser import FormulaParser
self.router = DataRouter(data_source)
self.planner = ExecutionPlanner()
self.compute_engine = ComputeEngine(max_workers=max_workers)
self.registered_expressions: Dict[str, Node] = {}
self._plans: Dict[str, ExecutionPlan] = {}
# 初始化注册表和解析器(支持注入外部注册表实现共享)
self._registry = registry if registry is not None else FunctionRegistry()
self._parser = FormulaParser(self._registry)
def register(
self,
name: str,
expression: Node,
data_specs: Optional[List[DataSpec]] = None,
) -> "FactorEngine":
"""注册因子表达式。
Args:
name: 因子名称
expression: DSL 表达式
data_specs: 数据规格None 时自动推导
Returns:
self支持链式调用
Example:
>>> from src.factors.api import close, ts_mean
>>> engine = FactorEngine()
>>> engine.register("ma20", ts_mean(close, 20))
"""
# 检测因子依赖(在注册当前因子之前检查其他已注册因子)
factor_deps = self._find_factor_dependencies(expression)
self.registered_expressions[name] = expression
# 预创建执行计划
plan = self.planner.create_plan(
expression=expression,
output_name=name,
data_specs=data_specs,
)
# 添加因子依赖信息
plan.factor_dependencies = factor_deps
self._plans[name] = plan
return self
def add_factor(
self,
name: str,
expression: Union[str, Node],
data_specs: Optional[List[DataSpec]] = None,
) -> "FactorEngine":
"""注册因子(支持字符串或 Node 表达式)。
这是 register 方法的增强版,支持字符串表达式解析。
向后兼容register 方法保持不变,继续只接受 Node 类型。
遵循 Fail-Fast 原则:字符串表达式会立即解析,失败时立即抛出异常。
Args:
name: 因子名称
expression: 字符串表达式或 Node 对象
data_specs: 可选的数据规格
Returns:
self支持链式调用
Raises:
TypeError: 当 expression 类型不支持时
FormulaParseError: 当字符串解析失败时(立即报错)
Example:
>>> engine = FactorEngine()
>>>
>>> # 字符串方式(新功能)
>>> engine.add_factor("ma20", "ts_mean(close, 20)")
>>>
>>> # Node 方式(与 register 相同)
>>> from src.factors.api import close, ts_mean
>>> engine.add_factor("ma20", ts_mean(close, 20))
>>>
>>> # 复杂表达式
>>> engine.add_factor("alpha1", "cs_rank(close / open)")
>>>
>>> # 链式调用
>>> (engine
... .add_factor("ma5", "ts_mean(close, 5)")
... .add_factor("ma10", "ts_mean(close, 10)")
... .add_factor("golden_cross", "ma5 > ma10"))
"""
if isinstance(expression, str):
# Fail-Fast立即解析失败立即报错
node = self._parser.parse(expression)
elif isinstance(expression, Node):
node = expression
else:
raise TypeError(
f"表达式必须是 str 或 Node 类型,收到 {type(expression).__name__}"
)
# 委托给现有的 register 方法
return self.register(name, node, data_specs)
def compute(
self,
factor_names: Union[str, List[str]],
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
"""计算指定因子的值。
完整的执行流程:取数 -> 组装 -> 翻译 -> 计算。
Args:
factor_names: 因子名称或名称列表
start_date: 开始日期 (YYYYMMDD)
end_date: 结束日期 (YYYYMMDD)
stock_codes: 股票代码列表None 表示全市场
Returns:
包含因子结果的数据表
Raises:
ValueError: 当因子未注册或数据不足时
Example:
>>> result = engine.compute("ma20", "20240101", "20240131")
>>> result = engine.compute(["ma20", "rsi"], "20240101", "20240131")
"""
# 标准化因子名称
if isinstance(factor_names, str):
factor_names = [factor_names]
# 1. 获取执行计划
plans = []
for name in factor_names:
if name not in self._plans:
raise ValueError(f"因子未注册: {name}")
plans.append(self._plans[name])
# 2. 合并数据规格并获取数据
all_specs = []
for plan in plans:
all_specs.extend(plan.data_specs)
# 3. 从路由器获取核心宽表
core_data = self.router.fetch_data(
data_specs=all_specs,
start_date=start_date,
end_date=end_date,
stock_codes=stock_codes,
)
if len(core_data) == 0:
raise ValueError("未获取到任何数据,请检查日期范围和股票代码")
# 4. 按依赖顺序执行计算
if len(plans) == 1:
result = self.compute_engine.execute(plans[0], core_data)
else:
# 使用依赖感知的方式执行
result = self._execute_with_dependencies(factor_names, core_data)
return result
def list_registered(self) -> List[str]:
"""获取已注册的因子列表。
Returns:
因子名称列表
"""
return list(self.registered_expressions.keys())
def get_expression(self, name: str) -> Optional[Node]:
"""获取已注册的表达式。
Args:
name: 因子名称
Returns:
表达式节点,未注册时返回 None
"""
return self.registered_expressions.get(name)
def clear(self) -> None:
"""清除所有注册的表达式和缓存。"""
self.registered_expressions.clear()
self._plans.clear()
self.router.clear_cache()
def preview_plan(self, factor_name: str) -> Optional[ExecutionPlan]:
"""预览因子的执行计划。
Args:
factor_name: 因子名称
Returns:
执行计划,未注册时返回 None
"""
return self._plans.get(factor_name)
def _execute_with_dependencies(
self,
factor_names: List[str],
core_data: pl.DataFrame,
) -> pl.DataFrame:
"""按依赖顺序执行因子计算。
支持 cs_rank 等需要依赖列已存在的场景。
Args:
factor_names: 因子名称列表
core_data: 核心宽表数据
Returns:
包含所有因子结果的数据表
"""
# 1. 拓扑排序
sorted_names = self._topological_sort(factor_names)
# 2. 按顺序执行
result = core_data
for name in sorted_names:
plan = self._plans[name]
# 创建新的执行计划,引用已计算的依赖列
new_plan = self._create_optimized_plan(plan, result)
# 执行计算
result = self.compute_engine.execute(new_plan, result)
return result
def _create_optimized_plan(
self,
plan: ExecutionPlan,
current_data: pl.DataFrame,
) -> ExecutionPlan:
"""创建优化的执行计划。
将表达式中已计算的依赖因子替换为列引用。
Args:
plan: 原始执行计划
current_data: 当前数据(包含已计算的依赖列)
Returns:
新的执行计划
"""
from src.factors.dsl import Symbol
# 获取当前数据中已存在的列
existing_cols = set(current_data.columns)
# 检查依赖列是否已存在
deps_available = plan.factor_dependencies & existing_cols
if not deps_available:
# 没有可用的依赖列,直接返回原计划
return plan
# 获取原始表达式
original_expr = self.registered_expressions[plan.output_name]
# 创建新的表达式,用 Symbol 引用替换依赖因子
def replace_with_symbol(node: Node) -> Node:
"""递归替换表达式中的依赖因子为 Symbol 引用。"""
from typing import Any
n: Any = node
# 检查当前节点是否等于某个已计算依赖因子
for dep_name in deps_available:
dep_expr = self.registered_expressions[dep_name]
if self._expressions_equal(node, dep_expr):
return Symbol(dep_name)
# 递归处理子节点
if isinstance(n, BinaryOpNode):
new_left = replace_with_symbol(n.left)
new_right = replace_with_symbol(n.right)
if new_left is not n.left or new_right is not n.right:
return BinaryOpNode(n.op, new_left, new_right)
elif isinstance(n, UnaryOpNode):
new_operand = replace_with_symbol(n.operand)
if new_operand is not n.operand:
return UnaryOpNode(n.op, new_operand)
elif isinstance(n, FunctionNode):
new_args = [replace_with_symbol(arg) for arg in n.args]
if any(
new_arg is not old_arg for new_arg, old_arg in zip(new_args, n.args)
):
return FunctionNode(n.func_name, *new_args)
return node
# 替换表达式
new_expr = replace_with_symbol(original_expr)
# 重新翻译表达式
translator = PolarsTranslator()
new_polars_expr = translator.translate(new_expr)
# 更新依赖集合
new_factor_deps = plan.factor_dependencies - deps_available
new_deps = plan.dependencies | deps_available
return ExecutionPlan(
data_specs=plan.data_specs,
polars_expr=new_polars_expr,
dependencies=new_deps,
output_name=plan.output_name,
factor_dependencies=new_factor_deps,
)
def _expressions_equal(self, expr1: Node, expr2: Node) -> bool:
"""比较两个表达式是否相等。
用于检测因子间的依赖关系。
Args:
expr1: 第一个表达式
expr2: 第二个表达式
Returns:
是否相等
"""
from typing import Any
e1: Any = expr1
e2: Any = expr2
if type(e1) != type(e2):
return False
if isinstance(e1, Symbol):
return e1.name == e2.name
from src.factors.dsl import Constant
if isinstance(e1, Constant):
return e1.value == e2.value
if isinstance(e1, BinaryOpNode):
return (
e1.op == e2.op
and self._expressions_equal(e1.left, e2.left)
and self._expressions_equal(e1.right, e2.right)
)
if isinstance(e1, UnaryOpNode):
return e1.op == e2.op and self._expressions_equal(e1.operand, e2.operand)
if isinstance(e1, FunctionNode):
if e1.func_name != e2.func_name or len(e1.args) != len(e2.args):
return False
return all(
self._expressions_equal(a1, a2) for a1, a2 in zip(e1.args, e2.args)
)
return False
def _find_factor_dependencies(self, expression: Node) -> Set[str]:
"""查找表达式依赖的其他因子。
遍历已注册因子,检查表达式是否包含任何已注册因子的完整表达式。
Args:
expression: 待检查的表达式
Returns:
依赖的因子名称集合
"""
deps: Set[str] = set()
# 检查表达式本身是否等于某个已注册因子
for name, registered_expr in self.registered_expressions.items():
if self._expressions_equal(expression, registered_expr):
deps.add(name)
break
# 递归检查子节点
if isinstance(expression, BinaryOpNode):
deps.update(self._find_factor_dependencies(expression.left))
deps.update(self._find_factor_dependencies(expression.right))
elif isinstance(expression, UnaryOpNode):
deps.update(self._find_factor_dependencies(expression.operand))
elif isinstance(expression, FunctionNode):
for arg in expression.args:
deps.update(self._find_factor_dependencies(arg))
return deps
def _topological_sort(self, factor_names: List[str]) -> List[str]:
"""按依赖关系对因子进行拓扑排序。
确保依赖的因子先被计算。
Args:
factor_names: 因子名称列表
Returns:
排序后的因子名称列表
Raises:
ValueError: 当检测到循环依赖时
"""
# 构建依赖图
graph: Dict[str, Set[str]] = {}
in_degree: Dict[str, int] = {}
for name in factor_names:
plan = self._plans[name]
# 只考虑在本次计算范围内的依赖
deps = plan.factor_dependencies & set(factor_names)
graph[name] = deps
in_degree[name] = len(deps)
# Kahn 算法
result = []
queue = [name for name, degree in in_degree.items() if degree == 0]
while queue:
# 按原始顺序处理同级别的因子
queue.sort(key=lambda x: factor_names.index(x))
name = queue.pop(0)
result.append(name)
for other in factor_names:
if name in graph[other]:
in_degree[other] -= 1
if in_degree[other] == 0:
queue.append(other)
if len(result) != len(factor_names):
raise ValueError("检测到因子循环依赖")
return result

View File

@@ -0,0 +1,170 @@
"""执行计划生成器。
整合编译器和翻译器,生成完整的执行计划。
"""
from typing import Any, Dict, List, Optional, Set, Union
from src.factors.dsl import (
Node,
Symbol,
FunctionNode,
BinaryOpNode,
UnaryOpNode,
Constant,
)
from src.factors.compiler import DependencyExtractor
from src.factors.translator import PolarsTranslator
from src.factors.engine.data_spec import DataSpec, ExecutionPlan
class ExecutionPlanner:
"""执行计划生成器。
整合编译器和翻译器,生成完整的执行计划。
Attributes:
compiler: 依赖提取器
translator: Polars 翻译器
"""
def __init__(self) -> None:
"""初始化执行计划生成器。"""
self.compiler = DependencyExtractor()
self.translator = PolarsTranslator()
def create_plan(
self,
expression: Node,
output_name: str = "factor",
data_specs: Optional[List[DataSpec]] = None,
) -> ExecutionPlan:
"""从表达式创建执行计划。
Args:
expression: DSL 表达式节点
output_name: 输出因子名称
data_specs: 预定义的数据规格None 时自动推导
Returns:
执行计划对象
"""
# 1. 提取依赖
dependencies = self.compiler.extract_dependencies(expression)
# 2. 翻译为 Polars 表达式
polars_expr = self.translator.translate(expression)
# 3. 推导或验证数据规格
if data_specs is None:
data_specs = self._infer_data_specs(dependencies, expression)
return ExecutionPlan(
data_specs=data_specs,
polars_expr=polars_expr,
dependencies=dependencies,
output_name=output_name,
)
def _infer_data_specs(
self,
dependencies: Set[str],
expression: Node,
) -> List[DataSpec]:
"""从依赖推导数据规格。
根据表达式中的函数类型推断回看天数需求。
基础行情字段open, high, low, close, vol, amount, pre_close, change, pct_chg
默认从 pro_bar 表获取。
Args:
dependencies: 依赖的字段集合
expression: 表达式节点
Returns:
数据规格列表
"""
# 计算最大回看窗口
max_window = self._extract_max_window(expression)
lookback_days = max(1, max_window)
# 基础行情字段集合(这些字段从 pro_bar 表获取)
pro_bar_fields = {
"open",
"high",
"low",
"close",
"vol",
"amount",
"pre_close",
"change",
"pct_chg",
"turnover_rate",
"volume_ratio",
}
# 将依赖分为 pro_bar 字段和其他字段
pro_bar_deps = dependencies & pro_bar_fields
other_deps = dependencies - pro_bar_fields
data_specs = []
# pro_bar 表的数据规格
if pro_bar_deps:
data_specs.append(
DataSpec(
table="pro_bar",
columns=sorted(pro_bar_deps),
lookback_days=lookback_days,
)
)
# 其他字段从 daily 表获取
if other_deps:
data_specs.append(
DataSpec(
table="daily",
columns=sorted(other_deps),
lookback_days=lookback_days,
)
)
return data_specs
def _extract_max_window(self, node: Node) -> int:
"""从表达式中提取最大窗口大小。
Args:
node: AST 节点
Returns:
最大窗口大小,无时序函数返回 1
"""
if isinstance(node, FunctionNode):
window = 1
# 检查函数参数中的窗口大小
for arg in node.args:
if (
isinstance(arg, Constant)
and isinstance(arg.value, int)
and arg.value > window
):
window = arg.value
# 递归检查子表达式
for arg in node.args:
if isinstance(arg, Node) and not isinstance(arg, Constant):
window = max(window, self._extract_max_window(arg))
return window
elif isinstance(node, BinaryOpNode):
return max(
self._extract_max_window(node.left),
self._extract_max_window(node.right),
)
elif isinstance(node, UnaryOpNode):
return self._extract_max_window(node.operand)
return 1

144
src/factors/exceptions.py Normal file
View File

@@ -0,0 +1,144 @@
"""公式解析异常定义。
提供清晰的错误信息,帮助用户快速定位公式解析问题。
"""
import difflib
from typing import List, Optional
class FormulaParseError(Exception):
"""公式解析错误基类。
Attributes:
expr: 原始表达式字符串
lineno: 错误所在行号从1开始
col_offset: 错误所在列号从0开始
"""
def __init__(
self,
message: str,
expr: Optional[str] = None,
lineno: Optional[int] = None,
col_offset: Optional[int] = None,
):
self.expr = expr
self.lineno = lineno
self.col_offset = col_offset
# 构建详细错误信息
full_message = self._format_message(message)
super().__init__(full_message)
def _format_message(self, message: str) -> str:
"""格式化错误信息,包含位置指示器。"""
lines = [f"FormulaParseError: {message}"]
if self.expr:
lines.append(f" 公式: {self.expr}")
# 添加错误位置指示器
if self.col_offset is not None and self.lineno is not None:
# 计算错误行在表达式中的起始位置
expr_lines = self.expr.split("\n")
if 1 <= self.lineno <= len(expr_lines):
error_line = expr_lines[self.lineno - 1]
lines.append(f" {error_line}")
# 添加指向错误位置的箭头
pointer = " " * (self.col_offset + 7) + "^--- 此处出错"
lines.append(pointer)
return "\n".join(lines)
class UnknownFunctionError(FormulaParseError):
"""未知函数错误。
当表达式中使用了未注册的函数时抛出。
Attributes:
func_name: 未知的函数名
available: 可用函数列表
suggestions: 模糊匹配建议列表
"""
def __init__(
self,
func_name: str,
available: List[str],
expr: Optional[str] = None,
lineno: Optional[int] = None,
col_offset: Optional[int] = None,
):
self.func_name = func_name
self.available = available
# 使用 difflib 获取模糊匹配建议
self.suggestions = difflib.get_close_matches(
func_name, available, n=3, cutoff=0.5
)
# 构建错误信息
if self.suggestions:
suggestion_str = ", ".join(f"'{s}'" for s in self.suggestions)
hint_msg = f"你是不是想找: {suggestion_str}"
else:
# 只显示前10个可用函数
available_preview = ", ".join(available[:10])
if len(available) > 10:
available_preview += f", ... 等共 {len(available)} 个函数"
hint_msg = f"可用函数预览: {available_preview}"
msg = f"未知函数 '{func_name}'{hint_msg}"
super().__init__(
message=msg,
expr=expr,
lineno=lineno,
col_offset=col_offset,
)
class InvalidSyntaxError(FormulaParseError):
"""语法错误。
当表达式语法不正确或不支持时抛出。
"""
pass
class UnsupportedOperatorError(InvalidSyntaxError):
"""不支持的运算符错误。
当使用了不支持的运算符时抛出(如位运算、矩阵运算等)。
"""
pass
class EmptyExpressionError(FormulaParseError):
"""空表达式错误。"""
def __init__(self):
super().__init__("表达式不能为空或只包含空白字符")
class RegistryError(Exception):
"""注册表错误基类。"""
pass
class DuplicateFunctionError(RegistryError):
"""函数重复注册错误。
当尝试注册已存在的函数且未设置 force=True 时抛出。
"""
def __init__(self, func_name: str):
self.func_name = func_name
super().__init__(
f"函数 '{func_name}' 已存在。使用 force=True 覆盖,或选择其他名称。"
)

411
src/factors/parser.py Normal file
View File

@@ -0,0 +1,411 @@
"""公式解析器 - 将字符串表达式转换为 DSL 节点树。
基于 Python ast 模块实现,支持算术运算、比较运算、函数调用等。
示例:
>>> from src.factors.parser import FormulaParser
>>> from src.factors.registry import FunctionRegistry
>>> parser = FormulaParser(FunctionRegistry())
>>> node = parser.parse("ts_mean(close, 20)")
>>> print(node)
ts_mean(close, 20)
"""
import ast
from typing import Any, Dict, Optional, TYPE_CHECKING
from src.factors.dsl import Node, Symbol, Constant, BinaryOpNode, UnaryOpNode
from src.factors.exceptions import (
FormulaParseError,
UnknownFunctionError,
InvalidSyntaxError,
EmptyExpressionError,
)
if TYPE_CHECKING:
from src.factors.registry import FunctionRegistry
# 运算符映射表
BIN_OP_MAP: Dict[type, str] = {
ast.Add: "+",
ast.Sub: "-",
ast.Mult: "*",
ast.Div: "/",
ast.Pow: "**",
ast.FloorDiv: "//",
ast.Mod: "%",
}
UNARY_OP_MAP: Dict[type, str] = {
ast.UAdd: "+",
ast.USub: "-",
ast.Invert: "~", # 不支持,应报错
}
COMPARE_OP_MAP: Dict[type, str] = {
ast.Eq: "==",
ast.NotEq: "!=",
ast.Lt: "<",
ast.LtE: "<=",
ast.Gt: ">",
ast.GtE: ">=",
}
class FormulaParser:
"""基于 AST 的公式解析器。
将字符串表达式解析为 DSL 节点树,支持:
- 符号引用(如 close, open
- 数值常量(如 20, 3.14
- 二元运算(如 +, -, *, /
- 一元运算(如 -x
- 函数调用(如 ts_mean(close, 20)
- 比较运算(如 close > open
Attributes:
registry: 函数注册表,用于解析函数调用
"""
def __init__(self, registry: "FunctionRegistry") -> None:
"""初始化解析器。
Args:
registry: 函数注册表,提供函数名到可调用对象的映射
"""
self.registry = registry
def parse(self, expr: str) -> Node:
"""解析字符串表达式为 Node 树。
Args:
expr: 公式字符串,如 "ts_mean(close, 20)"
Returns:
解析后的 Node 节点
Raises:
EmptyExpressionError: 表达式为空时抛出
SyntaxError: Python 语法错误时抛出
FormulaParseError: 解析失败时抛出
Example:
>>> parser.parse("close / open")
BinaryOpNode("/", Symbol("close"), Symbol("open"))
"""
# 检查空表达式
if not expr or not expr.strip():
raise EmptyExpressionError()
# 解析为 Python AST
try:
tree = ast.parse(expr, mode="eval")
except SyntaxError as e:
# 将 SyntaxError 包装为 InvalidSyntaxError统一异常类型
raise InvalidSyntaxError(
message=f"表达式语法错误: {e.msg}",
expr=expr,
lineno=e.lineno,
col_offset=e.offset,
) from e
# 递归访问 AST 节点
try:
return self._visit(tree.body, expr)
except FormulaParseError:
# 重新抛出 FormulaParseError保留已有的位置信息
raise
except Exception as e:
# 将其他异常包装为 FormulaParseError
if not isinstance(e, FormulaParseError):
raise FormulaParseError(
message=f"解析失败: {str(e)}",
expr=expr,
) from e
raise
def _visit(self, node: ast.AST, expr: str) -> Node:
"""递归访问 AST 节点并转换为 DSL 节点。
Args:
node: Python AST 节点
expr: 原始表达式字符串(用于错误报告)
Returns:
对应的 DSL 节点
Raises:
InvalidSyntaxError: 遇到不支持的语法时抛出
"""
# 提取位置信息(如果节点有)
lineno = getattr(node, "lineno", None)
col_offset = getattr(node, "col_offset", None)
try:
if isinstance(node, ast.Name):
return self._visit_Name(node)
elif isinstance(node, ast.Constant):
return self._visit_Constant(node, expr)
elif isinstance(node, ast.BinOp):
return self._visit_BinOp(node, expr)
elif isinstance(node, ast.UnaryOp):
return self._visit_UnaryOp(node, expr)
elif isinstance(node, ast.Call):
return self._visit_Call(node, expr)
elif isinstance(node, ast.Compare):
return self._visit_Compare(node, expr)
else:
raise InvalidSyntaxError(
message=f"不支持的语法: {type(node).__name__}",
expr=expr,
lineno=lineno,
col_offset=col_offset,
)
except FormulaParseError:
# 重新抛出(保留已有的位置信息)
raise
except Exception as e:
# 包装为 FormulaParseError添加位置信息
raise FormulaParseError(
message=f"解析节点失败: {str(e)}",
expr=expr,
lineno=lineno,
col_offset=col_offset,
) from e
def _visit_Name(self, node: ast.Name) -> Symbol:
"""访问名称节点 - 永远转为 Symbol。
注意:利用 AST 语法自然区分变量和函数调用:
- log → Symbol("log")(数据列引用)
- log(close) → 在 _visit_Call 中处理(函数调用)
Args:
node: AST 名称节点
Returns:
Symbol 节点
"""
return Symbol(node.id)
def _visit_Constant(self, node: ast.Constant, expr: str) -> Node:
"""访问常量节点。
支持的类型:
- int/float → Constant 节点
- str → Symbol 节点(支持 ts_mean("close", 20) 语法)
Args:
node: AST 常量节点
expr: 原始表达式字符串
Returns:
Constant 或 Symbol 节点
Raises:
InvalidSyntaxError: 不支持的常量类型
"""
if isinstance(node.value, (int, float)):
return Constant(node.value)
elif isinstance(node.value, str):
# 字符串常量转为 Symbol支持 "close" 写法
return Symbol(node.value)
else:
lineno = getattr(node, "lineno", None)
col_offset = getattr(node, "col_offset", None)
raise InvalidSyntaxError(
message=f"不支持的常量类型: {type(node.value).__name__}",
expr=expr,
lineno=lineno,
col_offset=col_offset,
)
def _visit_BinOp(self, node: ast.BinOp, expr: str) -> BinaryOpNode:
"""访问二元运算节点。
Args:
node: AST 二元运算节点
expr: 原始表达式字符串
Returns:
BinaryOpNode 节点
Raises:
InvalidSyntaxError: 不支持的运算符
"""
left = self._visit(node.left, expr)
right = self._visit(node.right, expr)
op = BIN_OP_MAP.get(type(node.op))
if op is None:
lineno = getattr(node, "lineno", None)
col_offset = getattr(node, "col_offset", None)
raise InvalidSyntaxError(
message=f"不支持的运算符: {type(node.op).__name__}",
expr=expr,
lineno=lineno,
col_offset=col_offset,
)
return BinaryOpNode(op, left, right)
def _visit_UnaryOp(self, node: ast.UnaryOp, expr: str) -> Node:
"""访问一元运算节点。
支持常量折叠优化:纯数值的一元运算直接计算结果。
Args:
node: AST 一元运算节点
expr: 原始表达式字符串
Returns:
Constant常量折叠或 UnaryOpNode 节点
Raises:
InvalidSyntaxError: 不支持的运算符
"""
operand = self._visit(node.operand, expr)
op = UNARY_OP_MAP.get(type(node.op))
lineno = getattr(node, "lineno", None)
col_offset = getattr(node, "col_offset", None)
if op is None:
raise InvalidSyntaxError(
message=f"不支持的一元运算符: {type(node.op).__name__}",
expr=expr,
lineno=lineno,
col_offset=col_offset,
)
if op == "~":
raise InvalidSyntaxError(
message="位运算 '~' 不被支持",
expr=expr,
lineno=lineno,
col_offset=col_offset,
)
# 常量折叠优化:纯数值直接计算
if isinstance(operand, Constant) and isinstance(operand.value, (int, float)):
if op == "-":
return Constant(-operand.value)
elif op == "+":
return operand # +5 就是 5
# 非常量使用运算符重载
if op == "-":
return -operand
elif op == "+":
return +operand
# 不应该到达这里
raise InvalidSyntaxError(
message=f"无法处理的一元运算符: {op}",
expr=expr,
lineno=lineno,
col_offset=col_offset,
)
def _visit_Call(self, node: ast.Call, expr: str) -> Node:
"""访问函数调用节点。
注意:只有在这里查注册表,处理函数调用。
Args:
node: AST 函数调用节点
expr: 原始表达式字符串
Returns:
函数返回的 Node 节点
Raises:
InvalidSyntaxError: 不支持的函数调用语法
UnknownFunctionError: 函数未注册
"""
lineno = getattr(node, "lineno", None)
col_offset = getattr(node, "col_offset", None)
# 只支持简单函数调用(如 func(a, b)
if not isinstance(node.func, ast.Name):
raise InvalidSyntaxError(
message="只支持简单函数调用(如 func(a, b)",
expr=expr,
lineno=lineno,
col_offset=col_offset,
)
func_name = node.func.id
func = self.registry.get(func_name)
if func is None:
raise UnknownFunctionError(
func_name=func_name,
available=self.registry.available_functions(),
expr=expr,
lineno=lineno,
col_offset=col_offset,
)
# 解析位置参数
args = [self._visit(arg, expr) for arg in node.args]
# 解析关键字参数(如果有)
kwargs = {}
for keyword in node.keywords:
kwargs[keyword.arg] = self._visit(keyword.value, expr)
# 应用函数
try:
if kwargs:
return func(*args, **kwargs)
return func(*args)
except TypeError as e:
raise InvalidSyntaxError(
message=f"函数 '{func_name}' 调用失败: {e}",
expr=expr,
lineno=lineno,
col_offset=col_offset,
) from e
def _visit_Compare(self, node: ast.Compare, expr: str) -> BinaryOpNode:
"""访问比较运算节点。
注意:只支持简单二元比较,不支持链式比较(如 a < b < c
Args:
node: AST 比较节点
expr: 原始表达式字符串
Returns:
BinaryOpNode 节点(使用比较运算符)
Raises:
InvalidSyntaxError: 链式比较或不支持的运算符
"""
lineno = getattr(node, "lineno", None)
col_offset = getattr(node, "col_offset", None)
# Python 支持链式比较 (a < b < c),这里简化为二元比较
if len(node.ops) != 1 or len(node.comparators) != 1:
raise InvalidSyntaxError(
message="只支持简单二元比较(如 a > b不支持链式比较",
expr=expr,
lineno=lineno,
col_offset=col_offset,
)
left = self._visit(node.left, expr)
op = COMPARE_OP_MAP.get(type(node.ops[0]))
if op is None:
raise InvalidSyntaxError(
message=f"不支持的比较运算符: {type(node.ops[0]).__name__}",
expr=expr,
lineno=lineno,
col_offset=col_offset,
)
right = self._visit(node.comparators[0], expr)
return BinaryOpNode(op, left, right)

227
src/factors/registry.py Normal file
View File

@@ -0,0 +1,227 @@
"""函数注册表 - 管理字符串函数名到 Python 函数的映射。
支持自动发现和手动注册,与 FormulaParser 配合使用。
示例:
>>> from src.factors.registry import FunctionRegistry
>>> registry = FunctionRegistry(auto_scan=True) # 自动加载 api.py 函数
>>> registry.available_functions()[:5]
['abs', 'clip', 'cs_demean', 'cs_neutralize', 'cs_rank']
"""
import inspect
import typing
from typing import Any, Callable, Dict, List, Optional, Set
from src.factors.dsl import Node, FunctionNode
from src.factors.exceptions import DuplicateFunctionError
class FunctionRegistry:
"""函数注册表。
管理字符串函数名到可调用对象的映射。
自动从 api.py 加载标准函数,支持用户自定义函数注册。
Attributes:
_functions: 函数字典name -> callable
"""
def __init__(self, auto_scan: bool = True) -> None:
"""初始化注册表。
Args:
auto_scan: 是否自动扫描 api.py 模块,默认 True
"""
self._functions: Dict[str, Callable] = {}
if auto_scan:
self._scan_api_module()
def register(
self, name: str, func: Callable, force: bool = False
) -> "FunctionRegistry":
"""注册自定义函数。
Args:
name: 函数名称(字符串形式)
func: 可调用对象
force: 是否强制覆盖已存在的函数,默认 False
Returns:
self支持链式调用
Raises:
DuplicateFunctionError: 当函数名已存在且 force=False 时
Example:
>>> registry = FunctionRegistry(auto_scan=False)
>>> registry.register("my_func", lambda x: x * 2)
>>> registry.get("my_func")(5)
10
"""
if name in self._functions and not force:
raise DuplicateFunctionError(name)
self._functions[name] = func
return self
def unregister(self, name: str) -> "FunctionRegistry":
"""注销函数。
Args:
name: 要注销的函数名
Returns:
self支持链式调用
Raises:
KeyError: 函数不存在时
"""
if name not in self._functions:
raise KeyError(f"函数 '{name}' 不存在")
del self._functions[name]
return self
def get(self, name: str) -> Optional[Callable]:
"""获取函数。
Args:
name: 函数名称
Returns:
函数对象,不存在返回 None
"""
return self._functions.get(name)
def has(self, name: str) -> bool:
"""检查函数是否存在。
Args:
name: 函数名称
Returns:
是否存在
"""
return name in self._functions
def available_functions(self) -> List[str]:
"""返回所有可用函数名列表(按字母序)。
Returns:
排序后的函数名列表
"""
return sorted(self._functions.keys())
def clear(self) -> "FunctionRegistry":
"""清空所有注册的函数。
Returns:
self支持链式调用
"""
self._functions.clear()
return self
def scan_module(
self, module: Any, prefix: str = "", force: bool = False
) -> "FunctionRegistry":
"""扫描指定模块,自动注册符合条件的函数。
扫描规则:
1. 模块级别的函数(排除私有函数 _*
2. 返回类型注解为 Node 或 FunctionNode
Args:
module: 要扫描的模块对象
prefix: 函数名前缀,用于避免命名冲突
force: 是否强制覆盖已存在的函数
Returns:
self支持链式调用
Example:
>>> import my_custom_module
>>> registry.scan_module(my_custom_module, prefix="custom_")
"""
for name, obj in inspect.getmembers(module):
# 只处理非私有函数
if not inspect.isfunction(obj) or name.startswith("_"):
continue
# 检查是否应该注册
if self._should_register(obj):
full_name = prefix + name
self.register(full_name, obj, force=force)
return self
def _scan_api_module(self) -> None:
"""自动扫描 api.py 模块,注册所有符合条件的函数。
这是默认的自动扫描行为,在 __init__ 中调用。
"""
try:
from src.factors import api
self.scan_module(api)
except ImportError:
# api 模块可能不存在,静默跳过
pass
def _should_register(self, func: Callable) -> bool:
"""检查函数是否应该被注册。
基于类型提示检查函数返回类型,只注册返回 Node 或 FunctionNode 的函数。
Args:
func: 要检查的函数
Returns:
是否应该注册该函数
"""
try:
hints = typing.get_type_hints(func)
return_type = hints.get("return")
if return_type is None:
return False
# 处理 Union 类型(如 Union[Node, FunctionNode]
origin = typing.get_origin(return_type)
args = typing.get_args(return_type)
if origin is typing.Union:
# Union 类型,检查任一参数
return any(self._is_node_type(arg) for arg in args)
else:
# 单一类型
return self._is_node_type(return_type)
except Exception:
return False
def _is_node_type(self, typ: Any) -> bool:
"""检查类型是否是 Node 或 FunctionNode 的子类。
Args:
typ: 要检查的类型
Returns:
是否是 Node 相关类型
"""
if not isinstance(typ, type):
return False
return issubclass(typ, (Node, FunctionNode))
def __len__(self) -> int:
"""返回已注册函数数量。"""
return len(self._functions)
def __contains__(self, name: str) -> bool:
"""检查是否包含某个函数名。"""
return name in self._functions
def __repr__(self) -> str:
"""返回注册表字符串表示。"""
return f"FunctionRegistry({len(self._functions)} functions: {self.available_functions()[:5]}...)"