refactor(factorminer): 将 LLM Prompt 和解析器改造为直接输出本地 DSL
- DSL 规范改为 snake_case、中缀运算符,示例同步替换 - 移除 ExpressionTree 依赖,改为括号匹配等基础校验 - retry prompt 适配本地 DSL 规则
This commit is contained in:
72
tests/test_factorminer_prompt.py
Normal file
72
tests/test_factorminer_prompt.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""Tests for local DSL prompt and output parser in Step 3."""
|
||||
|
||||
from src.factorminer.agent.output_parser import (
|
||||
CandidateFactor,
|
||||
_infer_category,
|
||||
parse_llm_output,
|
||||
)
|
||||
from src.factorminer.agent.prompt_builder import SYSTEM_PROMPT
|
||||
|
||||
|
||||
def test_system_prompt_uses_local_dsl():
|
||||
assert "$close" not in SYSTEM_PROMPT
|
||||
assert "CsRank(" not in SYSTEM_PROMPT
|
||||
assert "cs_rank(" in SYSTEM_PROMPT
|
||||
assert "close / ts_delay(close, 1) - 1" in SYSTEM_PROMPT
|
||||
assert "ts_mean(close, 20)" in SYSTEM_PROMPT
|
||||
|
||||
|
||||
def test_parse_local_dsl_numbered_list():
|
||||
raw = (
|
||||
"1. momentum: -cs_rank(ts_delta(close, 5))\n"
|
||||
"2. volume: cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20))\n"
|
||||
"3. vwap_dev: cs_rank((close - amount / vol) / (amount / vol))\n"
|
||||
)
|
||||
candidates, failed = parse_llm_output(raw)
|
||||
assert len(candidates) == 3
|
||||
assert candidates[0].name == "momentum"
|
||||
assert candidates[0].formula == "-cs_rank(ts_delta(close, 5))"
|
||||
assert candidates[0].is_valid
|
||||
assert candidates[1].name == "volume"
|
||||
assert (
|
||||
candidates[1].formula == "cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20))"
|
||||
)
|
||||
assert candidates[2].name == "vwap_dev"
|
||||
assert not failed
|
||||
|
||||
|
||||
def test_parse_local_dsl_formula_only():
|
||||
raw = "cs_rank(close / ts_delay(close, 5) - 1)"
|
||||
candidates, failed = parse_llm_output(raw)
|
||||
assert len(candidates) == 1
|
||||
assert candidates[0].formula == "cs_rank(close / ts_delay(close, 5) - 1)"
|
||||
assert not failed
|
||||
|
||||
|
||||
def test_parse_invalid_parentheses():
|
||||
candidates, failed = parse_llm_output("1. bad: cs_rank(ts_delta(close, 5)")
|
||||
assert len(candidates) == 1
|
||||
assert not candidates[0].is_valid
|
||||
assert "括号" in candidates[0].parse_error
|
||||
|
||||
|
||||
def test_infer_category_local_dsl():
|
||||
assert _infer_category("cs_rank(ts_delta(close, 5))") == "cross_sectional_momentum"
|
||||
assert _infer_category("ts_corr(vol, close, 10)") == "regression"
|
||||
assert _infer_category("ts_std(close, 20)") == "volatility"
|
||||
assert _infer_category("if_(close > open, 1, -1)") == "conditional"
|
||||
|
||||
|
||||
def test_candidate_factor_is_valid_without_tree():
|
||||
cf = CandidateFactor(name="test", formula="cs_rank(close)")
|
||||
assert cf.is_valid
|
||||
assert cf.category == "unknown"
|
||||
|
||||
|
||||
def test_parse_json_local_dsl():
|
||||
raw = '{"name": "mom", "formula": "cs_rank(close / ts_delay(close, 5) - 1)"}'
|
||||
candidates, failed = parse_llm_output(raw)
|
||||
assert len(candidates) == 1
|
||||
assert candidates[0].name == "mom"
|
||||
assert candidates[0].formula == "cs_rank(close / ts_delay(close, 5) - 1)"
|
||||
assert not failed
|
||||
Reference in New Issue
Block a user