73 lines
2.5 KiB
Python
73 lines
2.5 KiB
Python
|
|
"""Tests for local DSL prompt and output parser in Step 3."""
|
||
|
|
|
||
|
|
from src.factorminer.agent.output_parser import (
|
||
|
|
CandidateFactor,
|
||
|
|
_infer_category,
|
||
|
|
parse_llm_output,
|
||
|
|
)
|
||
|
|
from src.factorminer.agent.prompt_builder import SYSTEM_PROMPT
|
||
|
|
|
||
|
|
|
||
|
|
def test_system_prompt_uses_local_dsl():
    """SYSTEM_PROMPT must show local DSL snippets and omit the excluded spellings."""
    # Spellings that must not appear anywhere in the prompt text.
    excluded_snippets = ("$close", "CsRank(")
    # Snippets that must appear verbatim in the prompt text.
    expected_snippets = (
        "cs_rank(",
        "close / ts_delay(close, 1) - 1",
        "ts_mean(close, 20)",
    )

    for snippet in excluded_snippets:
        assert snippet not in SYSTEM_PROMPT
    for snippet in expected_snippets:
        assert snippet in SYSTEM_PROMPT
|
||
|
|
|
||
|
|
|
||
|
|
def test_parse_local_dsl_numbered_list():
    """A three-item numbered list yields three candidates and no failures."""
    momentum_formula = "-cs_rank(ts_delta(close, 5))"
    volume_formula = "cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20))"

    llm_text = (
        "1. momentum: -cs_rank(ts_delta(close, 5))\n"
        "2. volume: cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20))\n"
        "3. vwap_dev: cs_rank((close - amount / vol) / (amount / vol))\n"
    )

    parsed, rejected = parse_llm_output(llm_text)

    # Check the count first so the index lookups below cannot go out of range.
    assert len(parsed) == 3
    first, second, third = parsed[0], parsed[1], parsed[2]

    assert first.name == "momentum"
    assert first.formula == momentum_formula
    assert first.is_valid

    assert second.name == "volume"
    assert second.formula == volume_formula

    assert third.name == "vwap_dev"

    assert not rejected
|
||
|
|
|
||
|
|
|
||
|
|
def test_parse_local_dsl_formula_only():
    """A bare formula with no name or numbering parses into one candidate."""
    expression = "cs_rank(close / ts_delay(close, 5) - 1)"

    parsed, rejected = parse_llm_output(expression)

    assert len(parsed) == 1
    # The formula is expected to come back exactly as it was supplied.
    assert parsed[0].formula == expression
    assert not rejected
|
||
|
|
|
||
|
|
|
||
|
|
def test_parse_invalid_parentheses():
    """A formula with an unclosed parenthesis is returned but marked invalid."""
    # Two opening parentheses, only one closing.
    unbalanced_line = "1. bad: cs_rank(ts_delta(close, 5)"

    parsed, _ = parse_llm_output(unbalanced_line)

    assert len(parsed) == 1
    only_candidate = parsed[0]
    assert not only_candidate.is_valid
    # "括号" is Chinese for "parentheses"; the error message must mention it.
    assert "括号" in only_candidate.parse_error
|
||
|
|
|
||
|
|
|
||
|
|
def test_infer_category_local_dsl():
    """_infer_category maps each sample local DSL formula to its expected label."""
    # (formula, expected category) pairs, checked in this order.
    expectations = (
        ("cs_rank(ts_delta(close, 5))", "cross_sectional_momentum"),
        ("ts_corr(vol, close, 10)", "regression"),
        ("ts_std(close, 20)", "volatility"),
        ("if_(close > open, 1, -1)", "conditional"),
    )

    for formula, expected_category in expectations:
        assert _infer_category(formula) == expected_category
|
||
|
|
|
||
|
|
|
||
|
|
def test_candidate_factor_is_valid_without_tree():
    """A CandidateFactor built from only a name and formula is valid, category "unknown"."""
    init_fields = {"name": "test", "formula": "cs_rank(close)"}

    factor = CandidateFactor(**init_fields)

    assert factor.is_valid
    assert factor.category == "unknown"
|
||
|
|
|
||
|
|
|
||
|
|
def test_parse_json_local_dsl():
    """A single JSON object with name and formula parses into one named candidate."""
    expected_formula = "cs_rank(close / ts_delay(close, 5) - 1)"
    json_text = '{"name": "mom", "formula": "cs_rank(close / ts_delay(close, 5) - 1)"}'

    parsed, rejected = parse_llm_output(json_text)

    assert len(parsed) == 1
    candidate = parsed[0]
    assert candidate.name == "mom"
    assert candidate.formula == expected_formula
    assert not rejected
|