refactor(factorminer): 将 LLM Prompt 和解析器改造为直接输出本地 DSL

- DSL 规范改为 snake_case、中缀运算符,示例同步替换
- 移除 ExpressionTree 依赖,改为括号匹配等基础校验
- retry prompt 适配本地 DSL 规则
This commit is contained in:
2026-04-08 22:27:33 +08:00
parent 65500cce27
commit dd2e8a4a8e
5 changed files with 686 additions and 177 deletions

View File

@@ -0,0 +1,72 @@
"""Tests for local DSL prompt and output parser in Step 3."""
from src.factorminer.agent.output_parser import (
CandidateFactor,
_infer_category,
parse_llm_output,
)
from src.factorminer.agent.prompt_builder import SYSTEM_PROMPT
def test_system_prompt_uses_local_dsl():
    """SYSTEM_PROMPT must contain the snake_case DSL snippets and none of the rejected ones.

    The prompt may not mention ``$close`` or ``CsRank(``; it must contain
    ``cs_rank(`` and the two example expressions listed below.
    """
    forbidden_snippets = ("$close", "CsRank(")
    required_snippets = (
        "cs_rank(",
        "close / ts_delay(close, 1) - 1",
        "ts_mean(close, 20)",
    )
    for snippet in forbidden_snippets:
        assert snippet not in SYSTEM_PROMPT
    for snippet in required_snippets:
        assert snippet in SYSTEM_PROMPT
def test_parse_local_dsl_numbered_list():
    """A numbered ``N. name: formula`` list is parsed into one candidate per line.

    Checks the name of all three candidates, the formula of the first two,
    the validity flag of the first, and that the second return value is empty.
    """
    numbered_lines = [
        "1. momentum: -cs_rank(ts_delta(close, 5))",
        "2. volume: cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20))",
        "3. vwap_dev: cs_rank((close - amount / vol) / (amount / vol))",
    ]
    # Every line, including the last one, is newline-terminated.
    llm_text = "".join(line + "\n" for line in numbered_lines)

    parsed, rejected = parse_llm_output(llm_text)

    assert len(parsed) == 3
    first, second, third = parsed[0], parsed[1], parsed[2]

    assert first.name == "momentum"
    assert first.formula == "-cs_rank(ts_delta(close, 5))"
    assert first.is_valid

    assert second.name == "volume"
    expected_volume_formula = "cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20))"
    assert second.formula == expected_volume_formula

    assert third.name == "vwap_dev"
    assert not rejected
def test_parse_local_dsl_formula_only():
    """A bare formula (no number, no name) yields exactly one candidate carrying that formula."""
    bare_formula = "cs_rank(close / ts_delay(close, 5) - 1)"

    parsed, rejected = parse_llm_output(bare_formula)

    assert len(parsed) == 1
    # The formula text is expected to come back unchanged.
    assert parsed[0].formula == bare_formula
    assert not rejected
def test_parse_invalid_parentheses():
    """An entry with an unclosed parenthesis is returned as a candidate but marked invalid."""
    # One opening parenthesis is never closed.
    unbalanced_entry = "1. bad: cs_rank(ts_delta(close, 5)"

    # The second return value is not inspected by this test.
    parsed, _ = parse_llm_output(unbalanced_entry)

    assert len(parsed) == 1
    bad_candidate = parsed[0]
    assert not bad_candidate.is_valid
    # "括号" is Chinese for "parenthesis"; the error text must mention it.
    assert "括号" in bad_candidate.parse_error
def test_infer_category_local_dsl():
    """_infer_category maps each sample local-DSL formula to its expected category label."""
    expected_category_by_formula = [
        ("cs_rank(ts_delta(close, 5))", "cross_sectional_momentum"),
        ("ts_corr(vol, close, 10)", "regression"),
        ("ts_std(close, 20)", "volatility"),
        ("if_(close > open, 1, -1)", "conditional"),
    ]
    for formula, expected_category in expected_category_by_formula:
        assert _infer_category(formula) == expected_category
def test_candidate_factor_is_valid_without_tree():
    """A CandidateFactor built from only a name and formula is valid with category "unknown"."""
    candidate = CandidateFactor(
        name="test",
        formula="cs_rank(close)",
    )
    assert candidate.is_valid
    # No category was passed in, so the value checked here is the default.
    assert candidate.category == "unknown"
def test_parse_json_local_dsl():
    """A single JSON object with "name" and "formula" keys is parsed into one candidate."""
    formula = "cs_rank(close / ts_delay(close, 5) - 1)"
    # Assemble the JSON text by hand so the payload is exactly the intended string.
    json_text = '{"name": "mom", "formula": "' + formula + '"}'

    parsed, rejected = parse_llm_output(json_text)

    assert len(parsed) == 1
    only_candidate = parsed[0]
    assert only_candidate.name == "mom"
    assert only_candidate.formula == formula
    assert not rejected