refactor(factors): 简化 add_factor API 并默认启用 metadata
- 合并 add_factor_by_name 到 add_factor,支持三种调用方式 - FactorManager 构造函数改为可选参数,使用默认路径 - FactorEngine 默认启用 metadata,无需手动配置路径
This commit is contained in:
@@ -81,12 +81,16 @@ class FactorEngine:
|
||||
self._registry = registry if registry is not None else FunctionRegistry()
|
||||
self._parser = FormulaParser(self._registry)
|
||||
|
||||
# 初始化 metadata 管理器(可选)
|
||||
self._metadata: Optional["FactorManager"] = None
|
||||
# 初始化 metadata 管理器(可选,默认启用)
|
||||
if metadata_path is not None:
|
||||
from src.factors.metadata import FactorManager
|
||||
|
||||
self._metadata = FactorManager(metadata_path)
|
||||
else:
|
||||
# 使用 FactorManager 的默认路径
|
||||
from src.factors.metadata import FactorManager
|
||||
|
||||
self._metadata = FactorManager()
|
||||
|
||||
def register(
|
||||
self,
|
||||
@@ -128,22 +132,68 @@ class FactorEngine:
|
||||
|
||||
return self
|
||||
|
||||
def _add_factor_from_metadata(
|
||||
self,
|
||||
name: str,
|
||||
factor_name_in_metadata: str,
|
||||
data_specs: Optional[List[DataSpec]] = None,
|
||||
) -> "FactorEngine":
|
||||
"""从 metadata 中查询并注册因子(内部方法)。
|
||||
|
||||
Args:
|
||||
name: 要注册的因子名称(引擎中使用的名称)
|
||||
factor_name_in_metadata: metadata 中的因子名称
|
||||
data_specs: 可选的数据规格
|
||||
|
||||
Returns:
|
||||
self,支持链式调用
|
||||
|
||||
Raises:
|
||||
RuntimeError: 当引擎未配置 metadata 路径时
|
||||
ValueError: 当在 metadata 中未找到因子时
|
||||
FormulaParseError: 当 DSL 表达式解析失败时
|
||||
"""
|
||||
if self._metadata is None:
|
||||
raise RuntimeError(
|
||||
"引擎未配置 metadata 路径。请在初始化时传入 metadata_path 参数,"
|
||||
+ "例如:FactorEngine(metadata_path='data/factors.jsonl')"
|
||||
)
|
||||
|
||||
# 从 metadata 查询因子
|
||||
df = self._metadata.get_factors_by_name(factor_name_in_metadata)
|
||||
|
||||
if len(df) == 0:
|
||||
raise ValueError(
|
||||
f"在 metadata 中未找到因子 '{factor_name_in_metadata}'。"
|
||||
+ "请确认因子名称正确,或先使用 FactorManager 添加该因子。"
|
||||
)
|
||||
|
||||
# 获取 DSL 表达式
|
||||
dsl_expr = df["dsl"][0]
|
||||
|
||||
# 解析表达式为 Node
|
||||
node = self._parser.parse(dsl_expr)
|
||||
|
||||
# 委托给 register 方法
|
||||
return self.register(name, node, data_specs)
|
||||
|
||||
def add_factor(
|
||||
self,
|
||||
name: str,
|
||||
expression: Union[str, Node],
|
||||
expression: Optional[Union[str, Node]] = None,
|
||||
data_specs: Optional[List[DataSpec]] = None,
|
||||
) -> "FactorEngine":
|
||||
"""注册因子(支持字符串或 Node 表达式)。
|
||||
"""注册因子(支持多种调用方式)。
|
||||
|
||||
这是 register 方法的增强版,支持字符串表达式解析。
|
||||
向后兼容:register 方法保持不变,继续只接受 Node 类型。
|
||||
这是 register 方法的增强版,支持以下调用方式:
|
||||
1. 传入 name 和 expression:直接注册表达式(字符串或 Node)
|
||||
2. 只传入 name:从 metadata 中查询表达式并注册
|
||||
|
||||
遵循 Fail-Fast 原则:字符串表达式会立即解析,失败时立即抛出异常。
|
||||
|
||||
Args:
|
||||
name: 因子名称
|
||||
expression: 字符串表达式或 Node 对象
|
||||
name: 因子名称(引擎中使用的名称)
|
||||
expression: 字符串表达式或 Node 对象,为 None 时从 metadata 查询
|
||||
data_specs: 可选的数据规格
|
||||
|
||||
Returns:
|
||||
@@ -152,19 +202,21 @@ class FactorEngine:
|
||||
Raises:
|
||||
TypeError: 当 expression 类型不支持时
|
||||
FormulaParseError: 当字符串解析失败时(立即报错)
|
||||
RuntimeError: 当 expression 为 None 但未配置 metadata 时
|
||||
ValueError: 当在 metadata 中未找到因子时
|
||||
|
||||
Example:
|
||||
>>> engine = FactorEngine()
|
||||
>>>
|
||||
>>> # 字符串方式(新功能)
|
||||
>>> # 方式1:字符串表达式
|
||||
>>> engine.add_factor("ma20", "ts_mean(close, 20)")
|
||||
>>>
|
||||
>>> # Node 方式(与 register 相同)
|
||||
>>> # 方式2:Node 表达式
|
||||
>>> from src.factors.api import close, ts_mean
|
||||
>>> engine.add_factor("ma20", ts_mean(close, 20))
|
||||
>>>
|
||||
>>> # 复杂表达式
|
||||
>>> engine.add_factor("alpha1", "cs_rank(close / open)")
|
||||
>>> # 方式3:从 metadata 查询(需要初始化时配置 metadata_path)
|
||||
>>> engine.add_factor("return_5") # 从 metadata 查询名为 return_5 的因子
|
||||
>>>
|
||||
>>> # 链式调用
|
||||
>>> (engine
|
||||
@@ -172,6 +224,10 @@ class FactorEngine:
|
||||
... .add_factor("ma10", "ts_mean(close, 10)")
|
||||
... .add_factor("golden_cross", "ma5 > ma10"))
|
||||
"""
|
||||
if expression is None:
|
||||
# 从 metadata 查询表达式
|
||||
return self._add_factor_from_metadata(name, name, data_specs)
|
||||
|
||||
if isinstance(expression, str):
|
||||
# Fail-Fast:立即解析,失败立即报错
|
||||
node = self._parser.parse(expression)
|
||||
@@ -185,76 +241,6 @@ class FactorEngine:
|
||||
# 委托给现有的 register 方法
|
||||
return self.register(name, node, data_specs)
|
||||
|
||||
def add_factor_by_name(
|
||||
self,
|
||||
name: str,
|
||||
factor_name_in_metadata: Optional[str] = None,
|
||||
data_specs: Optional[List[DataSpec]] = None,
|
||||
) -> "FactorEngine":
|
||||
"""根据 metadata 中的因子名称注册因子。
|
||||
|
||||
从 metadata 管理器中根据因子名称查询 DSL 表达式,
|
||||
然后解析并注册到引擎中。
|
||||
|
||||
Args:
|
||||
name: 要注册的因子名称(引擎中使用的名称)
|
||||
factor_name_in_metadata: metadata 中的因子名称,
|
||||
为 None 时默认使用 name 参数
|
||||
data_specs: 可选的数据规格
|
||||
|
||||
Returns:
|
||||
self,支持链式调用
|
||||
|
||||
Raises:
|
||||
RuntimeError: 当引擎未配置 metadata 路径时
|
||||
ValueError: 当在 metadata 中未找到因子时
|
||||
FormulaParseError: 当 DSL 表达式解析失败时
|
||||
|
||||
Example:
|
||||
>>> # 初始化时启用 metadata
|
||||
>>> engine = FactorEngine(metadata_path="data/factors.jsonl")
|
||||
>>>
|
||||
>>> # 注册 metadata 中的因子(使用相同名称)
|
||||
>>> engine.add_factor_by_name("return_5")
|
||||
>>>
|
||||
>>> # 使用不同名称注册
|
||||
>>> engine.add_factor_by_name("my_mom", "momentum_5d")
|
||||
>>>
|
||||
>>> # 链式调用
|
||||
>>> (engine
|
||||
... .add_factor_by_name("ma20")
|
||||
... .add_factor_by_name("rsi14")
|
||||
... .compute(["ma20", "rsi14"], "20240101", "20240131"))
|
||||
"""
|
||||
if self._metadata is None:
|
||||
raise RuntimeError(
|
||||
"引擎未配置 metadata 路径。请在初始化时传入 metadata_path 参数,"
|
||||
+ "例如:FactorEngine(metadata_path='data/factors.jsonl')"
|
||||
)
|
||||
|
||||
# 使用传入的名称或默认使用 name
|
||||
query_name = (
|
||||
factor_name_in_metadata if factor_name_in_metadata is not None else name
|
||||
)
|
||||
|
||||
# 从 metadata 查询因子
|
||||
df = self._metadata.get_factors_by_name(query_name)
|
||||
|
||||
if len(df) == 0:
|
||||
raise ValueError(
|
||||
f"在 metadata 中未找到因子 '{query_name}'。"
|
||||
+ "请确认因子名称正确,或先使用 FactorManager 添加该因子。"
|
||||
)
|
||||
|
||||
# 获取 DSL 表达式
|
||||
dsl_expr = df["dsl"][0]
|
||||
|
||||
# 解析表达式为 Node
|
||||
node = self._parser.parse(dsl_expr)
|
||||
|
||||
# 委托给 register 方法
|
||||
return self.register(name, node, data_specs)
|
||||
|
||||
def compute(
|
||||
self,
|
||||
factor_names: Union[str, List[str]],
|
||||
|
||||
Reference in New Issue
Block a user