refactor(factorminer): 将 110 个 PAPER_FACTORS 迁移到本地 snake_case DSL
- 新增一次性翻译脚本 src/scripts/translate_paper_factors.py - 将 library_io.PAPER_FACTORS 中的 CamelCase DSL 公式替换为本地 DSL - 对使用未实现算子(Decay、TsLinRegSlope、TsLinRegResid、Resid、 Quantile、HMA、DEMA)的 16 个因子注释为 # TODO - 新增 test_factorminer_paper_factors.py 验证所有翻译后公式的 DSL 解析 - 更新整合计划中 step1 的状态
This commit is contained in:
File diff suppressed because it is too large
Load Diff
293
src/scripts/translate_paper_factors.py
Normal file
293
src/scripts/translate_paper_factors.py
Normal file
@@ -0,0 +1,293 @@
|
||||
"""一次性 Paper Factors DSL 迁移脚本。
|
||||
|
||||
将 src.factorminer.core.library_io 中硬编码的 110 个 PAPER_FACTORS
|
||||
的 CamelCase DSL 公式翻译为本地 snake_case DSL 字符串。
|
||||
翻译结果直接替换回原常量列表。
|
||||
"""
|
||||
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
# 确保能导入项目模块
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from src.factorminer.core.library_io import PAPER_FACTORS
|
||||
|
||||
|
||||
# CamelCase operators that this script does not translate. Any formula whose
# tree contains one of them is emitted as a "# TODO: <original formula>" entry.
# Members are listed alphabetically; set membership is all that matters.
UNSUPPORTED_OPS = {
    "DEMA",
    "Decay",
    "HMA",
    "Quantile",
    "Resid",
    "TsLinRegResid",
    "TsLinRegSlope",
}
|
||||
|
||||
# Maps the "$"-prefixed field tokens of the CamelCase DSL to the text emitted
# in the local DSL. Leaf values that are not keys of this mapping are emitted
# unchanged by _translate.
FIELD_MAP = {
    "$close": "close",
    "$volume": "vol",
    "$amt": "amount",
    # Emitted as a parenthesised ratio expression rather than a single field name.
    "$vwap": "(amount / vol)",
    # Emitted as close divided by its 1-step ts_delay, minus 1.
    "$returns": "(close / ts_delay(close, 1) - 1)",
    "$high": "high",
    "$low": "low",
    "$open": "open",
}
|
||||
|
||||
|
||||
class Node:
    """Common base type for every node of the parsed expression tree."""
|
||||
|
||||
|
||||
class LeafNode(Node):
    """Leaf of the expression tree: a field reference or a literal constant."""

    def __init__(self, value: str) -> None:
        """Store the raw token text.

        Args:
            value: Token exactly as it appeared in the formula, for example
                a "$"-prefixed field name or a numeric literal.
        """
        self.value = value

    def __repr__(self) -> str:
        """Return a readable form so parse trees can be inspected when debugging."""
        return f"{type(self).__name__}({self.value!r})"
|
||||
|
||||
|
||||
class FuncNode(Node):
    """Function-call node: an operator name plus its ordered argument nodes."""

    def __init__(self, name: str, args: List[Node]) -> None:
        """Store the call's name and arguments.

        Args:
            name: Operator name as written in the formula (text before "(").
            args: Parsed argument nodes, in source order.
        """
        self.name = name
        self.args = args

    def __repr__(self) -> str:
        """Return a readable form so parse trees can be inspected when debugging."""
        return f"{type(self).__name__}({self.name!r}, {self.args!r})"
|
||||
|
||||
|
||||
def _parse_expr(s: str, i: int = 0) -> Tuple[Node, int]:
    """Recursively parse one CamelCase DSL expression.

    A numeric literal is tried first, then an identifier made of alphanumerics,
    "$" and "_". An identifier immediately followed by "(" becomes a function
    call whose comma-separated arguments are parsed recursively.

    The parser is intentionally lenient: a missing closing ")" at the end of
    the input is tolerated, and the caller is expected to compare the returned
    index with ``len(s)`` to detect trailing garbage.

    Args:
        s: Expression string.
        i: Index at which parsing starts.

    Returns:
        A ``(node, next_index)`` tuple, where ``next_index`` is the position of
        the first character that was not consumed.

    Raises:
        ValueError: If an argument position holds text that cannot start an
            expression (the argument would consume zero characters).
    """
    j = i

    # Numeric literal such as 20, 0.75 or 1e-8 (lower-case exponent only).
    num_match = re.match(r"-?\d+(\.\d+)?(e[+-]?\d+)?", s[j:])
    if num_match:
        val = num_match.group(0)
        return LeafNode(val), j + len(val)

    # Identifier: field token ("$close") or function name ("TsMean").
    while j < len(s) and (s[j].isalnum() or s[j] == "$" or s[j] == "_"):
        j += 1
    name = s[i:j]

    # No opening parenthesis: this is a plain leaf.
    if j >= len(s) or s[j] != "(":
        return LeafNode(name), j

    j += 1  # skip '('
    args: List[Node] = []
    while j < len(s) and s[j] != ")":
        while j < len(s) and s[j] == " ":
            j += 1
        if j >= len(s) or s[j] == ")":
            break

        arg_start = j
        arg, j = _parse_expr(s, j)
        if j == arg_start:
            # The nested call consumed nothing; without this guard the loop
            # would spin forever appending empty leaves.
            raise ValueError(f"无法解析参数: {s!r} (位置 {j})")
        args.append(arg)

        while j < len(s) and s[j] == " ":
            j += 1
        if j < len(s) and s[j] == ",":
            j += 1
        elif j < len(s) and s[j] == ")":
            break

    if j < len(s) and s[j] == ")":
        j += 1  # skip ')'
    return FuncNode(name, args), j
|
||||
|
||||
|
||||
def _reconstruct_original(node: Node) -> str:
    """Render an AST back into CamelCase DSL text.

    Leaves yield their raw value; function nodes yield ``Name(arg, arg, ...)``
    with the arguments joined by ", ".
    """
    if not isinstance(node, LeafNode):
        rendered_args = [_reconstruct_original(child) for child in node.args]
        return "{}({})".format(node.name, ", ".join(rendered_args))
    return node.value
|
||||
|
||||
|
||||
def _contains_unsupported(node: Node) -> bool:
    """Report whether any function node in the tree is named in UNSUPPORTED_OPS."""
    # Walk the tree with an explicit stack instead of recursion.
    pending = [node]
    while pending:
        current = pending.pop()
        if isinstance(current, LeafNode):
            continue
        if current.name in UNSUPPORTED_OPS:
            return True
        pending.extend(current.args)
    return False
|
||||
|
||||
|
||||
# Binary operators rendered infix as "(lhs <op> rhs)".
_INFIX_OPS = {
    "Add": "+",
    "Sub": "-",
    "Mul": "*",
    "Div": "/",
    "Greater": ">",
}

# One-argument functions rendered as "fn(x)".
_UNARY_FUNCS = {
    "CsRank": "cs_rank",
    "CsZscore": "cs_zscore",
    "Abs": "abs",
    "Sign": "sign",
}

# Two-argument functions rendered as "fn(x, y)".
_BINARY_FUNCS = {
    "Max": "max_",
    "Min": "min_",
}

# Functions rendered as "fn(x, window)", where the window is a raw leaf value.
_WINDOW_FUNCS = {
    "TsMean": "ts_mean",
    "Mean": "ts_mean",
    "TsMax": "ts_max",
    "TsMin": "ts_min",
    "Std": "ts_std",
    "Delta": "ts_delta",
    "Delay": "ts_delay",
    "Sum": "ts_sum",
    "Return": "ts_pct_change",
    "EMA": "ts_ema",
    "WMA": "ts_wma",
    "SMA": "ts_mean",
    "Skew": "ts_skew",
    "Kurt": "ts_kurt",
    "TsRank": "ts_rank",
}

# Functions rendered as "fn(x, y, window)", where the window is a raw leaf value.
_PAIR_WINDOW_FUNCS = {
    "Corr": "ts_corr",
    "Cov": "ts_cov",
}


def _translate(node: Node, toplevel: bool = True) -> str:
    """Translate an AST into the local snake_case DSL.

    Args:
        node: AST node to translate.
        toplevel: Whether this is the outermost call. Only the outermost call
            scans the tree for unsupported operators; recursive calls always
            pass ``False`` so a "# TODO" marker can never end up embedded
            inside a larger expression.

    Returns:
        The local DSL string, or ``# TODO: <original formula>`` when
        ``toplevel`` is true and the tree contains an operator listed in
        UNSUPPORTED_OPS.

    Raises:
        ValueError: If a function name has no translation rule, or a known
            function is given fewer arguments than its rule consumes.
    """
    if toplevel and _contains_unsupported(node):
        return f"# TODO: {_reconstruct_original(node)}"

    if isinstance(node, LeafNode):
        # Known field tokens are mapped; everything else passes through as-is.
        return FIELD_MAP.get(node.value, node.value)

    name = node.name
    args = node.args

    def require(count: int) -> None:
        # Fail with a descriptive message instead of a bare IndexError.
        if len(args) < count:
            raise ValueError(f"{name} 需要至少 {count} 个参数,实际为 {len(args)} 个")

    def sub(index: int) -> str:
        return _translate(args[index], toplevel=False)

    if name == "Neg":
        require(1)
        return f"(-{sub(0)})"
    if name == "Square":
        require(1)
        return f"({sub(0)} ** 2)"
    if name in _INFIX_OPS:
        require(2)
        return f"({sub(0)} {_INFIX_OPS[name]} {sub(1)})"
    if name in _UNARY_FUNCS:
        require(1)
        return f"{_UNARY_FUNCS[name]}({sub(0)})"
    if name in _BINARY_FUNCS:
        require(2)
        return f"{_BINARY_FUNCS[name]}({sub(0)}, {sub(1)})"
    if name in _WINDOW_FUNCS:
        require(2)
        # The window is emitted verbatim from the leaf's raw value; it is not
        # passed through FIELD_MAP, and a non-leaf window raises AttributeError.
        return f"{_WINDOW_FUNCS[name]}({sub(0)}, {args[1].value})"
    if name in _PAIR_WINDOW_FUNCS:
        require(3)
        return f"{_PAIR_WINDOW_FUNCS[name]}({sub(0)}, {sub(1)}, {args[2].value})"
    if name == "IfElse":
        require(3)
        return f"if_({sub(0)}, {sub(1)}, {sub(2)})"

    raise ValueError(f"未知函数: {name}")
|
||||
|
||||
|
||||
def _indent_block(lines: List[str], spaces: int = 4) -> List[str]:
|
||||
"""给代码块增加统一缩进。"""
|
||||
prefix = " " * spaces
|
||||
return [prefix + line for line in lines]
|
||||
|
||||
|
||||
def main() -> None:
    """Translate every PAPER_FACTORS formula and rewrite library_io.py in place.

    Each formula is parsed and translated. Entries whose translation is a
    "# TODO:" marker are written as commented-out dictionary entries; all
    others are written as live entries. The resulting list literal replaces
    the existing ``PAPER_FACTORS: List[Dict[str, str]] = [...]`` block.

    Raises:
        SystemExit: With status 1 when a formula is not fully consumed by the
            parser, or when the PAPER_FACTORS block cannot be located.
    """
    success_count = 0
    todo_count = 0
    block_lines: List[str] = []

    for entry in PAPER_FACTORS:
        formula = entry["formula"]
        tree, next_pos = _parse_expr(formula)
        if next_pos != len(formula):
            print(f"[ERROR] 解析未完成: {formula} (停在 {next_pos})")
            raise SystemExit(1)

        translated = _translate(tree)

        if translated.startswith("# TODO:"):
            todo_count += 1
            # Comment out the whole dictionary entry so it cannot affect
            # downstream consumers of the list.
            entry_lines = [
                "# {",
                f'# "name": {entry["name"]!r},',
                f'# "formula": {translated!r},',
                f'# "category": {entry["category"]!r},',
                "# },",
            ]
        else:
            success_count += 1
            entry_lines = [
                "{",
                f' "name": {entry["name"]!r},',
                f' "formula": {translated!r},',
                f' "category": {entry["category"]!r},',
                " },",
            ]
        # Every generated line gets the same one-space prefix.
        block_lines.extend(f" {line}\n" for line in entry_lines)

    # The regex match below ends on the closing "]" without its newline, so the
    # replacement must not add one, or each run would insert a blank line.
    new_factor_block = (
        "PAPER_FACTORS: List[Dict[str, str]] = [\n" + "".join(block_lines) + "]"
    )

    # Resolve the target from this script's location (the same root that is put
    # on sys.path for the import), not from the current working directory.
    repo_root = Path(__file__).resolve().parent.parent.parent
    lib_io_path = repo_root / "src" / "factorminer" / "core" / "library_io.py"
    original_text = lib_io_path.read_text(encoding="utf-8")

    # Non-greedy match from the definition up to the first line starting with "]".
    pattern = r"PAPER_FACTORS: List\[Dict\[str, str\]\] = \[.*?^\]"
    match = re.search(pattern, original_text, re.DOTALL | re.MULTILINE)
    if not match:
        print("[ERROR] 未能在 library_io.py 中定位 PAPER_FACTORS 定义块")
        raise SystemExit(1)

    new_text = (
        original_text[: match.start()] + new_factor_block + original_text[match.end() :]
    )
    lib_io_path.write_text(new_text, encoding="utf-8")

    total = len(PAPER_FACTORS)
    print(f"[translate] 成功 {success_count}/{total},TODO {todo_count} 个(已注释)")
|
||||
|
||||
|
||||
# Run the migration only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user