From dd2e8a4a8ed0a2e12656412e730b8a2e992a5f46 Mon Sep 17 00:00:00 2001 From: liaozhaorun <1300336796@qq.com> Date: Wed, 8 Apr 2026 22:27:33 +0800 Subject: [PATCH] =?UTF-8?q?refactor(factorminer):=20=E5=B0=86=20LLM=20Prom?= =?UTF-8?q?pt=20=E5=92=8C=E8=A7=A3=E6=9E=90=E5=99=A8=E6=94=B9=E9=80=A0?= =?UTF-8?q?=E4=B8=BA=E7=9B=B4=E6=8E=A5=E8=BE=93=E5=87=BA=E6=9C=AC=E5=9C=B0?= =?UTF-8?q?=20DSL=20-=20DSL=20=E8=A7=84=E8=8C=83=E6=94=B9=E4=B8=BA=20snake?= =?UTF-8?q?=5Fcase=E3=80=81=E4=B8=AD=E7=BC=80=E8=BF=90=E7=AE=97=E7=AC=A6?= =?UTF-8?q?=EF=BC=8C=E7=A4=BA=E4=BE=8B=E5=90=8C=E6=AD=A5=E6=9B=BF=E6=8D=A2?= =?UTF-8?q?=20-=20=E7=A7=BB=E9=99=A4=20ExpressionTree=20=E4=BE=9D=E8=B5=96?= =?UTF-8?q?=EF=BC=8C=E6=94=B9=E4=B8=BA=E6=8B=AC=E5=8F=B7=E5=8C=B9=E9=85=8D?= =?UTF-8?q?=E7=AD=89=E5=9F=BA=E7=A1=80=E6=A0=A1=E9=AA=8C=20-=20retry=20pro?= =?UTF-8?q?mpt=20=E9=80=82=E9=85=8D=E6=9C=AC=E5=9C=B0=20DSL=20=E8=A7=84?= =?UTF-8?q?=E5=88=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../2026-04-08-step3-llm-prompt-local-dsl.md | 458 ++++++++++++++++++ src/factorminer/agent/factor_generator.py | 13 +- src/factorminer/agent/output_parser.py | 101 ++-- src/factorminer/agent/prompt_builder.py | 219 ++++----- tests/test_factorminer_prompt.py | 72 +++ 5 files changed, 686 insertions(+), 177 deletions(-) create mode 100644 docs/plans/2026-04-08-step3-llm-prompt-local-dsl.md create mode 100644 tests/test_factorminer_prompt.py diff --git a/docs/plans/2026-04-08-step3-llm-prompt-local-dsl.md b/docs/plans/2026-04-08-step3-llm-prompt-local-dsl.md new file mode 100644 index 0000000..964b0d7 --- /dev/null +++ b/docs/plans/2026-04-08-step3-llm-prompt-local-dsl.md @@ -0,0 +1,458 @@ +# Step 3: LLM Prompt 改造(直接生成本地 DSL)实施计划 + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** 将 FactorMiner 的 LLM Prompt 和输出解析器从 CamelCase + `$` 前缀 DSL 改造为直接生成本地 snake_case DSL,移除运行时翻译层。 + +**Architecture:** Prompt 直接使用本地 `FactorEngine` 支持的 snake_case 函数名和字段名;`OutputParser` 仅做字符串提取和轻量清洗,不再调用 FactorMiner 的 `ExpressionTree` 解析;`factor_generator.py` 配合返回原始 DSL 字符串。 + +**Tech Stack:** Python, ProStock `src.factors` 本地 DSL (`FactorEngine`) + +--- + +## Task 1: 重写 `src/factorminer/agent/prompt_builder.py` + +**Files:** +- Modify: `src/factorminer/agent/prompt_builder.py` +- Test: `tests/test_factorminer_prompt.py` + +**Step 1: 重写字段列表函数 `_format_feature_list()`** + +将 `$` 前缀字段替换为本地字段,并添加计算说明: + +```python +def _format_feature_list() -> str: + descriptions = { + "open": "开盘价", + "high": "最高价", + "low": "最低价", + "close": "收盘价", + "vol": "成交量(股数)", + "amount": "成交额(金额)", + "vwap": "可用 amount / vol 计算", + "returns": "可用 close / ts_delay(close, 1) - 1 计算", + } + lines = [] + for feat, desc in descriptions.items(): + lines.append(f" {feat}: {desc}") + return "\n".join(lines) +``` + +**Step 2: 定义本地 DSL 算子表映射** + +在 `prompt_builder.py` 中新增 `LOCAL_OPERATOR_TABLE` 常量,列出 prompt 中需要展示的本地可用算子(按类别分组),不再依赖 `OPERATOR_REGISTRY` 遍历: + +```python +LOCAL_OPERATOR_TABLE = { + "ARITHMETIC": [ + ("+", "二元", "x + y"), + ("-", "二元/一元", "x - y 或 -x"), + ("*", "二元", "x * y"), + ("/", "二元", "x / y"), + ("**", "二元", "x ** y (幂运算)"), + (">", "二元", "x > y (条件判断,返回 0/1)"), + ("<", "二元", "x < y (条件判断,返回 0/1)"), + ("abs(x)", "一元", "绝对值"), + ("sign(x)", "一元", "符号函数"), + ("max_(x, y)", "二元", "逐元素最大值"), + ("min_(x, y)", "二元", "逐元素最小值"), + ("clip(x, lower, upper)", "一元带参", "截断"), + ("log(x)", "一元", "自然对数"), + ("sqrt(x)", "一元", "平方根"), + ("exp(x)", "一元", "指数函数"), + ], + "TIMESERIES": [ + ("ts_mean(x, window)", "一元+窗口", "滚动均值"), + ("ts_std(x, window)", "一元+窗口", "滚动标准差"), + ("ts_var(x, window)", "一元+窗口", "滚动方差"), + ("ts_max(x, window)", "一元+窗口", "滚动最大值"), + ("ts_min(x, window)", "一元+窗口", "滚动最小值"), + ("ts_sum(x, window)", "一元+窗口", "滚动求和"), + ("ts_delay(x, periods)", "一元+周期", "滞后 N 期"), + ("ts_delta(x, periods)", "一元+周期", "差分 N 期"), + ("ts_corr(x, y, window)", "二元+窗口", "滚动相关系数"), + ("ts_cov(x, y, window)", "二元+窗口", "滚动协方差"), + ("ts_pct_change(x, periods)", "一元+周期", "N 期百分比变化"), + ("ts_ema(x, window)", "一元+窗口", "指数移动平均"), + ("ts_wma(x, window)", "一元+窗口", "加权移动平均"), + ("ts_skew(x, window)", "一元+窗口", "滚动偏度"), + ("ts_kurt(x, window)", "一元+窗口", "滚动峰度"), + ("ts_rank(x, window)", "一元+窗口", "滚动分位排名"), + ], + "CROSS_SECTIONAL": [ + ("cs_rank(x)", "一元", "截面排名(分位数)"), + ("cs_zscore(x)", "一元", "截面 Z-Score 标准化"), + ("cs_demean(x)", "一元", "截面去均值"), + ("cs_neutralize(x, group)", "一元", "行业/市值中性化"), + ("cs_winsorize(x, lower, upper)", "一元", "截面缩尾处理"), + ], + "CONDITIONAL": [ + ("if_(condition, true_val, false_val)", "三元", "条件选择"), + ("where(condition, true_val, false_val)", "三元", "if_ 的别名"), + ], +} +``` + +然后重写 `_format_operator_table()`: + +```python +def _format_operator_table() -> str: + lines = [] + for cat_name, ops in LOCAL_OPERATOR_TABLE.items(): + lines.append(f"\n### {cat_name} operators") + for op_sig, arity, desc in ops: + lines.append(f"- {op_sig}: {desc} ({arity})") + return "\n".join(lines) +``` + +**Step 3: 重写 `SYSTEM_PROMPT`** + +替换语法规则段落和示例: + +```python +SYSTEM_PROMPT = f"""You are a quantitative researcher mining formulaic alpha factors for stock selection. + +Your goal is to generate novel, predictive factor expressions using the local ProStock DSL. Each factor is a composition of operators applied to raw market features. + +## RAW FEATURES (leaf nodes) +{_format_feature_list()} + +## OPERATOR LIBRARY +{_format_operator_table()} + +## EXPRESSION SYNTAX RULES +1. Expressions use Python-style infix operators: +, -, *, /, **, >, < +2. Function calls use snake_case names with comma-separated arguments: ts_mean(close, 20) +3. Window sizes and periods are numeric arguments placed last in function calls. +4. Valid window sizes are integers, typically in range [2, 250]. +5. Cross-sectional operators (cs_rank, cs_zscore, cs_demean) operate across all stocks at each time step -- they are crucial for making factors comparable. +6. Do NOT use $ prefix for features. Use `close`, `vol`, `amount`, etc. directly. +7. `vwap` is not a raw feature; use `amount / vol` if you need it. +8. `returns` is not a raw feature; use `close / ts_delay(close, 1) - 1` if you need returns. + +## EXAMPLES OF WELL-FORMED FACTORS +- -cs_rank(ts_delta(close, 5)) + Short-term reversal: rank of 5-day price change, negated. +- cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20)) + Volume surprise: standardized deviation from 20-day mean volume. +- cs_rank((close - amount / vol) / (amount / vol)) + Intraday deviation from VWAP, cross-sectionally ranked. +- -ts_corr(vol, close, 10) + Negative price-volume correlation over 10 days. +- if_(close / ts_delay(close, 1) - 1 > 0, ts_std(close / ts_delay(close, 1) - 1, 10), -ts_std(close / ts_delay(close, 1) - 1, 10)) + Conditional volatility: positive for up-moves, negative for down-moves. +- cs_rank((close - ts_min(low, 20)) / (ts_max(high, 20) - ts_min(low, 20))) + Position within 20-day price range, ranked. + +## KEY PRINCIPLES FOR HIGH-QUALITY FACTORS +- Always wrap the outermost expression with a cross-sectional operator (cs_rank, cs_zscore) for comparability. +- Combine DIFFERENT operator types for novelty (e.g., time-series + cross-sectional + arithmetic). +- Use diverse window sizes; avoid always defaulting to 10. +- Explore uncommon feature combinations (amount, amount/vol are underused). +- Factors with depth 3-7 tend to be best: deep enough to capture non-trivial patterns but not so deep they overfit. +- Prefer economically meaningful combinations over random nesting. +- IMPORTANT: Avoid operators that are NOT listed above (e.g., Decay, TsLinRegSlope, HMA, DEMA, Resid). If you use them, the factor will be rejected. +""" +``` + +**Step 4: 更新所有输出格式示例** + +在 `build_user_prompt`(约第333行)中,将示例公式替换为本地 DSL: + +``` +1. momentum_reversal: -cs_rank(ts_delta(close, 5)) +2. volume_surprise: cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20)) +``` + +在 `build_specialist_prompt`(约第529行)中同步替换: + +``` +Example: 1. momentum_reversal: -cs_rank(ts_delta(close, 5)) +``` + +**Step 5: 运行 prompt_builder 相关测试(若已有)** + +```bash +uv run pytest tests/test_factorminer_prompt.py -v -k prompt +``` + +--- + +## Task 2: 修改 `src/factorminer/agent/output_parser.py` + +**Files:** +- Modify: `src/factorminer/agent/output_parser.py` +- Test: `tests/test_factorminer_prompt.py` + +**Step 1: 移除 FactorMiner 解析器依赖** + +删除以下导入: + +```python +from src.factorminer.core.expression_tree import ExpressionTree +from src.factorminer.core.parser import parse, try_parse +from src.factorminer.core.types import OperatorType, OPERATOR_REGISTRY +``` + +**Step 2: 修改 `CandidateFactor`** + +```python +@dataclass +class CandidateFactor: + """A candidate factor parsed from LLM output. + + Attributes + ---------- + name : str + Descriptive snake_case name. + formula : str + DSL formula string. + category : str + Inferred category based on outermost operators. + parse_error : str + Error message if formula failed basic validation. + """ + + name: str + formula: str + category: str = "unknown" + parse_error: str = "" + + @property + def is_valid(self) -> bool: + return not self.parse_error and bool(self.formula.strip()) +``` + +**Step 3: 修改 `_infer_category()`** + +将所有 CamelCase 算子名替换为 snake_case: + +```python +def _infer_category(formula: str) -> str: + """Infer a rough category from the outermost operators in the formula.""" + if any(op in formula for op in ("cs_rank", "cs_zscore", "cs_demean", "cs_neutralize", "cs_winsorize")): + if any(op in formula for op in ("ts_corr", "ts_cov")): + return "cross_sectional_regression" + if any(op in formula for op in ("ts_delta", "ts_delay", "ts_pct_change")): + return "cross_sectional_momentum" + if any(op in formula for op in ("ts_std", "ts_var", "ts_skew", "ts_kurt")): + return "cross_sectional_volatility" + if any(op in formula for op in ("ts_mean", "ts_sum", "ts_ema", "ts_wma")): + return "cross_sectional_smoothing" + return "cross_sectional" + if any(op in formula for op in ("ts_corr", "ts_cov")): + return "regression" + if any(op in formula for op in ("ts_delta", "ts_delay", "ts_pct_change")): + return "momentum" + if any(op in formula for op in ("ts_std", "ts_var", "ts_skew", "ts_kurt")): + return "volatility" + if any(op in formula for op in ("if_", "where", ">", "<")): + return "conditional" + return "general" +``` + +**Step 4: 修改 `_FORMULA_ONLY_PATTERN`** + +本地 DSL 公式可能以 `cs_`, `ts_` 开头,也可能以 `-` 开头(如 `-cs_rank(...)`),或字段名/数字开头: + +```python +_FORMULA_ONLY_PATTERN = re.compile( + r"^\s*([a-zA-Z_][a-zA-Z0-9_]*\s*\(.*\)|-.*|\d.*)\s*$" +) +``` + +**Step 5: 修改 `_clean_formula()`** + +移除 `$` 清洗逻辑(当前已不需要替换 `$` 前缀),保留注释、标点和反引号清理: + +```python +def _clean_formula(formula: str) -> str: + """Clean up a formula string before parsing.""" + formula = formula.strip() + # Remove trailing comments + if " #" in formula: + formula = formula[: formula.index(" #")] + if " //" in formula: + formula = formula[: formula.index(" //")] + # Remove trailing punctuation + formula = formula.rstrip(";,.") + # Remove surrounding backticks + formula = formula.strip("`") + return formula.strip() +``` + +**Step 6: 重写 `_try_build_candidate()`** + +不再调用 `try_parse(formula)` 或 `ExpressionTree`,仅做基础校验: + +```python +def _try_build_candidate(name: str, formula: str) -> CandidateFactor: + """Attempt to validate a formula and build a CandidateFactor.""" + # Basic validation: parenthesis balance + if formula.count("(") != formula.count(")"): + return CandidateFactor( + name=name, + formula=formula, + category="unknown", + parse_error="括号不匹配", + ) + category = _infer_category(formula) + return CandidateFactor(name=name, formula=formula, category=category) +``` + +**Step 7: 修改 `_generate_name_from_formula()`** + +正则提取的逻辑调整为适配 snake_case 函数名(第一个括号前的部分): + +```python +def _generate_name_from_formula(formula: str, index: int) -> str: + """Generate a descriptive name from a formula.""" + # Extract the outermost operator (snake_case) + m = re.match(r"([a-zA-Z_][a-zA-Z0-9_]*)\s*\(", formula) + if m: + outer_op = m.group(1).lower() + return f"{outer_op}_factor_{index + 1}" + # Handle unary minus + m = re.match(r"-([a-zA-Z_][a-zA-Z0-9_]*)\s*\(", formula) + if m: + outer_op = m.group(1).lower() + return f"neg_{outer_op}_factor_{index + 1}" + return f"factor_{index + 1}" +``` + +--- + +## Task 3: 适配 `src/factorminer/agent/factor_generator.py` + +**Files:** +- Modify: `src/factorminer/agent/factor_generator.py` + +**Step 1: 更新 retry prompt 的 DSL 规则描述** + +在 `_retry_failed_parses` 方法中(约第199行),将 repair_prompt 中的描述改为本地 DSL 规则: + +```python +repair_prompt = ( + "The following factor formulas failed to parse. " + "Fix each one so it uses ONLY valid local DSL operators and features " + "from the library. Return them in the same numbered format:\n" + ". : \n\n" + "Broken formulas:\n" + + "\n".join(f" {i+1}. {f}" for i, f in enumerate(failed)) + + "\n\nFix all syntax errors, unknown operators, and invalid " + "feature names. Use snake_case functions (e.g., ts_mean, cs_rank), " + "infix operators (+, -, *, /, >, <), and raw features without $ prefix. " + "Every formula must be valid in the local DSL." +) +``` + +**Step 2: 确认 `generate_batch` 无需修改** + +因为 `CandidateFactor.is_valid` 已改为基于字符串校验,`generate_batch` 中的过滤逻辑自然兼容。 + +--- + +## Task 4: 编写测试 `tests/test_factorminer_prompt.py` + +**Files:** +- Create: `tests/test_factorminer_prompt.py` + +**Step 1: 测试 system prompt 使用本地 DSL** + +```python +import pytest +from src.factorminer.agent.prompt_builder import SYSTEM_PROMPT + +def test_system_prompt_uses_local_dsl(): + assert "$close" not in SYSTEM_PROMPT + assert "CsRank(" not in SYSTEM_PROMPT + assert "cs_rank(" in SYSTEM_PROMPT + assert "close / ts_delay(close, 1) - 1" in SYSTEM_PROMPT + assert "ts_mean(close, 20)" in SYSTEM_PROMPT +``` + +**Step 2: 测试 OutputParser 正确提取本地 DSL** + +```python +from src.factorminer.agent.output_parser import parse_llm_output, CandidateFactor + +def test_parse_local_dsl_numbered_list(): + raw = ( + "1. momentum: -cs_rank(ts_delta(close, 5))\n" + "2. volume: cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20))\n" + "3. vwap_dev: cs_rank((close - amount / vol) / (amount / vol))\n" + ) + candidates, failed = parse_llm_output(raw) + assert len(candidates) == 3 + assert candidates[0].name == "momentum" + assert candidates[0].formula == "-cs_rank(ts_delta(close, 5))" + assert candidates[0].is_valid + assert candidates[1].name == "volume" + assert candidates[1].formula == "cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20))" + assert candidates[2].name == "vwap_dev" + assert not failed +``` + +**Step 3: 测试 formula-only 行** + +```python +def test_parse_local_dsl_formula_only(): + raw = "cs_rank(close / ts_delay(close, 5) - 1)" + candidates, failed = parse_llm_output(raw) + assert len(candidates) == 1 + assert candidates[0].formula == "cs_rank(close / ts_delay(close, 5) - 1)" + assert not failed +``` + +**Step 4: 测试括号不匹配标记为无效** + +```python +def test_parse_invalid_parentheses(): + candidates, failed = parse_llm_output("1. bad: cs_rank(ts_delta(close, 5)") + assert len(candidates) == 1 + assert not candidates[0].is_valid + assert "括号" in candidates[0].parse_error +``` + +**Step 5: 测试分类推断** + +```python +def test_infer_category_local_dsl(): + from src.factorminer.agent.output_parser import _infer_category + assert _infer_category("cs_rank(ts_delta(close, 5))") == "cross_sectional_momentum" + assert _infer_category("ts_corr(vol, close, 10)") == "regression" + assert _infer_category("ts_std(close, 20)") == "volatility" + assert _infer_category("if_(close > open, 1, -1)") == "conditional" +``` + +**Step 6: 运行测试** + +```bash +uv run pytest tests/test_factorminer_prompt.py -v +``` + +预期:所有测试通过。 + +--- + +## 执行命令汇总 + +```bash +# 安装依赖(若尚未安装) +uv pip install -e . + +# 运行新增测试 +uv run pytest tests/test_factorminer_prompt.py -v + +# 运行 factorminer 相关测试 +uv run pytest tests/test_factorminer_* -v +``` + +--- + +## 提交建议 + +修改完成后建议拆分为两个 commits: + +1. `refactor(factorminer): rewrite LLM prompts to output local snake_case DSL` +2. `test(factorminer): add prompt and output parser tests for local DSL` diff --git a/src/factorminer/agent/factor_generator.py b/src/factorminer/agent/factor_generator.py index e0cf4a9..3e9d1cc 100644 --- a/src/factorminer/agent/factor_generator.py +++ b/src/factorminer/agent/factor_generator.py @@ -117,7 +117,9 @@ class FactorGenerator: max_tokens=self.max_tokens, ) elapsed = time.monotonic() - t0 - logger.info("LLM response received in %.1fs (%d chars)", elapsed, len(raw_output)) + logger.info( + "LLM response received in %.1fs (%d chars)", elapsed, len(raw_output) + ) # 3. Parse output candidates, failed_lines = parse_llm_output(raw_output) @@ -198,14 +200,15 @@ class FactorGenerator: repair_prompt = ( "The following factor formulas failed to parse. " - "Fix each one so it uses ONLY valid operators and features " + "Fix each one so it uses ONLY valid local DSL operators and features " "from the library. Return them in the same numbered format:\n" ". : \n\n" "Broken formulas:\n" - + "\n".join(f" {i+1}. {f}" for i, f in enumerate(failed)) + + "\n".join(f" {i + 1}. {f}" for i, f in enumerate(failed)) + "\n\nFix all syntax errors, unknown operators, and invalid " - "feature names. Every formula must be a valid nested function " - "call using only operators from the library." + "feature names. Use snake_case functions (e.g., ts_mean, cs_rank), " + "infix operators (+, -, *, /, >, <), and raw features without $ prefix. " + "Every formula must be valid in the local DSL." ) try: diff --git a/src/factorminer/agent/output_parser.py b/src/factorminer/agent/output_parser.py index 3634045..03a928f 100644 --- a/src/factorminer/agent/output_parser.py +++ b/src/factorminer/agent/output_parser.py @@ -1,20 +1,16 @@ """Parse LLM output into structured CandidateFactor objects. Handles various output formats from LLMs: numbered lists, JSON, -markdown code blocks, and raw text. Validates each formula against -the expression tree parser. +markdown code blocks, and raw text. Validates each formula with +basic string checks; no FactorMiner-specific parsing is performed. """ from __future__ import annotations import logging import re -from dataclasses import dataclass, field -from typing import Dict, List, Optional, Tuple - -from src.factorminer.core.expression_tree import ExpressionTree -from src.factorminer.core.parser import parse, try_parse -from src.factorminer.core.types import OperatorType, OPERATOR_REGISTRY +from dataclasses import dataclass +from typing import List, Optional, Tuple logger = logging.getLogger(__name__) @@ -29,49 +25,44 @@ class CandidateFactor: Descriptive snake_case name. formula : str DSL formula string. - expression_tree : ExpressionTree or None - Parsed expression tree (None if parsing failed). category : str Inferred category based on outermost operators. parse_error : str - Error message if formula failed to parse. + Error message if formula failed basic validation. """ name: str formula: str - expression_tree: Optional[ExpressionTree] = None category: str = "unknown" parse_error: str = "" @property def is_valid(self) -> bool: - return self.expression_tree is not None + return not self.parse_error and bool(self.formula.strip()) def _infer_category(formula: str) -> str: """Infer a rough category from the outermost operators in the formula.""" - lower = formula.lower() - # Check for cross-sectional operators at the top - if any(op in formula for op in ("CsRank", "CsZScore", "CsDemean", "CsScale", "CsNeutralize", "CsQuantile")): - # Look deeper for sub-category - if any(op in formula for op in ("Corr", "Cov", "Beta", "Resid")): + if any( + op in formula + for op in ("cs_rank", "cs_zscore", "cs_demean", "cs_neutralize", "cs_winsorize") + ): + if any(op in formula for op in ("ts_corr", "ts_cov")): return "cross_sectional_regression" - if any(op in formula for op in ("Delta", "Delay", "Return", "LogReturn")): + if any(op in formula for op in ("ts_delta", "ts_delay", "ts_pct_change")): return "cross_sectional_momentum" - if any(op in formula for op in ("Std", "Var", "Skew", "Kurt")): + if any(op in formula for op in ("ts_std", "ts_var", "ts_skew", "ts_kurt")): return "cross_sectional_volatility" - if any(op in formula for op in ("Mean", "Sum", "EMA", "SMA", "WMA", "DEMA", "HMA", "KAMA")): + if any(op in formula for op in ("ts_mean", "ts_sum", "ts_ema", "ts_wma")): return "cross_sectional_smoothing" - if any(op in formula for op in ("TsLinReg", "TsLinRegSlope")): - return "cross_sectional_trend" return "cross_sectional" - if any(op in formula for op in ("Corr", "Cov", "Beta", "Resid")): + if any(op in formula for op in ("ts_corr", "ts_cov")): return "regression" - if any(op in formula for op in ("Delta", "Delay", "Return", "LogReturn")): + if any(op in formula for op in ("ts_delta", "ts_delay", "ts_pct_change")): return "momentum" - if any(op in formula for op in ("Std", "Var", "Skew", "Kurt")): + if any(op in formula for op in ("ts_std", "ts_var", "ts_skew", "ts_kurt")): return "volatility" - if any(op in formula for op in ("IfElse", "Greater", "Less")): + if any(op in formula for op in ("if_", "where", ">", "<")): return "conditional" return "general" @@ -82,28 +73,26 @@ def _infer_category(formula: str) -> str: # Pattern: "1. name: formula" or "1) name: formula" _NUMBERED_PATTERN = re.compile( - r"^\s*\d+[\.\)]\s*" # numbered prefix + r"^\s*\d+[\.\)]\s*" # numbered prefix r"([a-zA-Z_][a-zA-Z0-9_]*)" # factor name - r"\s*:\s*" # colon separator - r"(.+)$" # formula + r"\s*:\s*" # colon separator + r"(.+)$" # formula ) # Pattern: "name: formula" (no number) _PLAIN_PATTERN = re.compile( r"^\s*([a-zA-Z_][a-zA-Z0-9_]*)" # factor name - r"\s*:\s*" # colon separator - r"(.+)$" # formula + r"\s*:\s*" # colon separator + r"(.+)$" # formula ) -# Pattern: just a formula starting with an operator +# Pattern: just a formula starting with a function call, unary minus, or number _FORMULA_ONLY_PATTERN = re.compile( - r"^\s*([A-Z][a-zA-Z]*\(.+\))\s*$" + r"^\s*([a-zA-Z_][a-zA-Z0-9_]*\s*\(.*\)|-.*|\d.*)\s*$" ) # Pattern: JSON-like {"name": "...", "formula": "..."} -_JSON_PATTERN = re.compile( - r'"name"\s*:\s*"([^"]+)"\s*,\s*"formula"\s*:\s*"([^"]+)"' -) +_JSON_PATTERN = re.compile(r'"name"\s*:\s*"([^"]+)"\s*,\s*"formula"\s*:\s*"([^"]+)"') def _strip_markdown(text: str) -> str: @@ -152,6 +141,8 @@ def parse_llm_output(raw_text: str) -> Tuple[List[CandidateFactor], List[str]]: json_matches = _JSON_PATTERN.findall(text) if json_matches: for name, formula in json_matches: + if name is None or formula is None: + continue formula = _clean_formula(formula) candidate = _try_build_candidate(name, formula) if candidate.name not in seen_names: @@ -184,6 +175,7 @@ def parse_llm_output(raw_text: str) -> Tuple[List[CandidateFactor], List[str]]: m = _FORMULA_ONLY_PATTERN.match(line) if m: formula = m.group(1) + assert formula is not None # Generate name from formula name = _generate_name_from_formula(formula, len(candidates)) @@ -222,38 +214,29 @@ def parse_llm_output(raw_text: str) -> Tuple[List[CandidateFactor], List[str]]: def _try_build_candidate(name: str, formula: str) -> CandidateFactor: - """Attempt to parse a formula and build a CandidateFactor.""" - tree = try_parse(formula) - if tree is not None: - category = _infer_category(formula) - return CandidateFactor( - name=name, - formula=tree.to_string(), # canonicalize - expression_tree=tree, - category=category, - ) - else: - # Try to get a useful error message - error_msg = "" - try: - parse(formula) - except (SyntaxError, KeyError, ValueError) as e: - error_msg = str(e) - + """Attempt to validate a formula and build a CandidateFactor.""" + # Basic validation: parenthesis balance + if formula.count("(") != formula.count(")"): return CandidateFactor( name=name, formula=formula, - expression_tree=None, category="unknown", - parse_error=error_msg, + parse_error="括号不匹配", ) + category = _infer_category(formula) + return CandidateFactor(name=name, formula=formula, category=category) def _generate_name_from_formula(formula: str, index: int) -> str: """Generate a descriptive name from a formula.""" - # Extract the outermost operator - m = re.match(r"([A-Z][a-zA-Z]*)\(", formula) + # Extract the outermost operator (snake_case) + m = re.match(r"([a-zA-Z_][a-zA-Z0-9_]*)\s*\(", formula) if m: outer_op = m.group(1).lower() return f"{outer_op}_factor_{index + 1}" + # Handle unary minus + m = re.match(r"-([a-zA-Z_][a-zA-Z0-9_]*)\s*\(", formula) + if m: + outer_op = m.group(1).lower() + return f"neg_{outer_op}_factor_{index + 1}" return f"factor_{index + 1}" diff --git a/src/factorminer/agent/prompt_builder.py b/src/factorminer/agent/prompt_builder.py index 3b9a1b8..0a0dad6 100644 --- a/src/factorminer/agent/prompt_builder.py +++ b/src/factorminer/agent/prompt_builder.py @@ -9,68 +9,81 @@ from __future__ import annotations from typing import Any, Dict, List, Optional -from src.factorminer.core.types import ( - FEATURES, - OPERATOR_REGISTRY, - OperatorSpec, - OperatorType, -) + +LOCAL_OPERATOR_TABLE = { + "ARITHMETIC": [ + ("+", "二元", "x + y"), + ("-", "二元/一元", "x - y 或 -x"), + ("*", "二元", "x * y"), + ("/", "二元", "x / y"), + ("**", "二元", "x ** y (幂运算)"), + (">", "二元", "x > y (条件判断,返回 0/1)"), + ("<", "二元", "x < y (条件判断,返回 0/1)"), + ("abs(x)", "一元", "绝对值"), + ("sign(x)", "一元", "符号函数"), + ("max_(x, y)", "二元", "逐元素最大值"), + ("min_(x, y)", "二元", "逐元素最小值"), + ("clip(x, lower, upper)", "一元带参", "截断"), + ("log(x)", "一元", "自然对数"), + ("sqrt(x)", "一元", "平方根"), + ("exp(x)", "一元", "指数函数"), + ], + "TIMESERIES": [ + ("ts_mean(x, window)", "一元+窗口", "滚动均值"), + ("ts_std(x, window)", "一元+窗口", "滚动标准差"), + ("ts_var(x, window)", "一元+窗口", "滚动方差"), + ("ts_max(x, window)", "一元+窗口", "滚动最大值"), + ("ts_min(x, window)", "一元+窗口", "滚动最小值"), + ("ts_sum(x, window)", "一元+窗口", "滚动求和"), + ("ts_delay(x, periods)", "一元+周期", "滞后 N 期"), + ("ts_delta(x, periods)", "一元+周期", "差分 N 期"), + ("ts_corr(x, y, window)", "二元+窗口", "滚动相关系数"), + ("ts_cov(x, y, window)", "二元+窗口", "滚动协方差"), + ("ts_pct_change(x, periods)", "一元+周期", "N 期百分比变化"), + ("ts_ema(x, window)", "一元+窗口", "指数移动平均"), + ("ts_wma(x, window)", "一元+窗口", "加权移动平均"), + ("ts_skew(x, window)", "一元+窗口", "滚动偏度"), + ("ts_kurt(x, window)", "一元+窗口", "滚动峰度"), + ("ts_rank(x, window)", "一元+窗口", "滚动分位排名"), + ], + "CROSS_SECTIONAL": [ + ("cs_rank(x)", "一元", "截面排名(分位数)"), + ("cs_zscore(x)", "一元", "截面 Z-Score 标准化"), + ("cs_demean(x)", "一元", "截面去均值"), + ("cs_neutralize(x, group)", "一元", "行业/市值中性化"), + ("cs_winsorize(x, lower, upper)", "一元", "截面缩尾处理"), + ], + "CONDITIONAL": [ + ("if_(condition, true_val, false_val)", "三元", "条件选择"), + ("where(condition, true_val, false_val)", "三元", "if_ 的别名"), + ], +} def _format_operator_table() -> str: """Build a human-readable operator reference table grouped by category.""" - grouped: Dict[str, List[OperatorSpec]] = {} - for spec in OPERATOR_REGISTRY.values(): - cat = spec.category.name - grouped.setdefault(cat, []).append(spec) - - lines: List[str] = [] - for cat_name in [ - "ARITHMETIC", - "STATISTICAL", - "TIMESERIES", - "SMOOTHING", - "CROSS_SECTIONAL", - "REGRESSION", - "LOGICAL", - "AUTO_INVENTED", - ]: - specs = grouped.get(cat_name, []) - if not specs: - continue + lines = [] + for cat_name, ops in LOCAL_OPERATOR_TABLE.items(): lines.append(f"\n### {cat_name} operators") - for spec in sorted(specs, key=lambda s: s.name): - params_str = "" - if spec.param_names: - parts = [] - for pname in spec.param_names: - default = spec.param_defaults.get(pname, "") - lo, hi = spec.param_ranges.get(pname, (None, None)) - range_str = f"[{lo}-{hi}]" if lo is not None else "" - parts.append(f"{pname}={default}{range_str}") - params_str = f" params: {', '.join(parts)}" - arity_args = ", ".join([f"expr{i+1}" for i in range(spec.arity)]) - if spec.param_names: - arity_args += ", " + ", ".join(spec.param_names) - lines.append(f"- {spec.name}({arity_args}): {spec.description}{params_str}") + for op_sig, arity, desc in ops: + lines.append(f"- {op_sig}: {desc} ({arity})") return "\n".join(lines) def _format_feature_list() -> str: """Build a description of available raw features.""" descriptions = { - "$open": "opening price", - "$high": "highest price in the bar", - "$low": "lowest price in the bar", - "$close": "closing price", - "$volume": "trading volume (shares)", - "$amt": "trading amount (currency value)", - "$vwap": "volume-weighted average price", - "$returns": "close-to-close returns", + "open": "开盘价", + "high": "最高价", + "low": "最低价", + "close": "收盘价", + "vol": "成交量(股数)", + "amount": "成交额(金额)", + "vwap": "可用 amount / vol 计算", + "returns": "可用 close / ts_delay(close, 1) - 1 计算", } lines = [] - for feat in FEATURES: - desc = descriptions.get(feat, "") + for feat, desc in descriptions.items(): lines.append(f" {feat}: {desc}") return "\n".join(lines) @@ -81,7 +94,7 @@ def _format_feature_list() -> str: SYSTEM_PROMPT = f"""You are a quantitative researcher mining formulaic alpha factors for stock selection. -Your goal is to generate novel, predictive factor expressions using a tree-structured domain-specific language (DSL). Each factor is a composition of operators applied to raw market features. +Your goal is to generate novel, predictive factor expressions using the local ProStock DSL. Each factor is a composition of operators applied to raw market features. ## RAW FEATURES (leaf nodes) {_format_feature_list()} @@ -90,40 +103,37 @@ Your goal is to generate novel, predictive factor expressions using a tree-struc {_format_operator_table()} ## EXPRESSION SYNTAX RULES -1. Every expression is a nested function call: Operator(args...) -2. Leaf nodes are raw features ($close, $volume, etc.) or numeric constants. -3. Operators are called by name with expression arguments first, then numeric parameters: - - Mean($close, 20) = 20-day rolling mean of $close - - Corr($close, $volume, 10) = 10-day rolling correlation of close and volume - - IfElse(Greater($returns, 0), $volume, Neg($volume)) = conditional -4. No infix operators; use Add(x,y) instead of x+y, Sub(x,y) instead of x-y, etc. -5. Parameters like window sizes are trailing numeric arguments after expression children. -6. Valid window sizes are integers; check each operator's parameter ranges above. -7. Cross-sectional operators (CsRank, CsZScore, CsDemean, CsScale, CsNeutralize) operate across all stocks at each time step -- they are crucial for making factors comparable. +1. Expressions use Python-style infix operators: +, -, *, /, **, >, < +2. Function calls use snake_case names with comma-separated arguments: ts_mean(close, 20) +3. Window sizes and periods are numeric arguments placed last in function calls. +4. Valid window sizes are integers, typically in range [2, 250]. +5. Cross-sectional operators (cs_rank, cs_zscore, cs_demean) operate across all stocks at each time step -- they are crucial for making factors comparable. +6. Do NOT use $ prefix for features. Use `close`, `vol`, `amount`, etc. directly. +7. `vwap` is not a raw feature; use `amount / vol` if you need it. +8. `returns` is not a raw feature; use `close / ts_delay(close, 1) - 1` if you need returns. ## EXAMPLES OF WELL-FORMED FACTORS -- Neg(CsRank(Delta($close, 5))) +- -cs_rank(ts_delta(close, 5)) Short-term reversal: rank of 5-day price change, negated. -- CsZScore(Div(Sub($volume, Mean($volume, 20)), Std($volume, 20))) +- cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20)) Volume surprise: standardized deviation from 20-day mean volume. -- CsRank(Div(Sub($close, $vwap), $vwap)) +- cs_rank((close - amount / vol) / (amount / vol)) Intraday deviation from VWAP, cross-sectionally ranked. -- Neg(Corr($volume, $close, 10)) +- -ts_corr(vol, close, 10) Negative price-volume correlation over 10 days. -- CsRank(TsLinRegSlope($volume, 20)) - Trend in trading volume over 20 days, ranked. -- IfElse(Greater($returns, 0), Std($returns, 10), Neg(Std($returns, 10))) +- if_(close / ts_delay(close, 1) - 1 > 0, ts_std(close / ts_delay(close, 1) - 1, 10), -ts_std(close / ts_delay(close, 1) - 1, 10)) Conditional volatility: positive for up-moves, negative for down-moves. -- CsRank(Div(Sub($close, TsMin($low, 20)), Sub(TsMax($high, 20), TsMin($low, 20)))) +- cs_rank((close - ts_min(low, 20)) / (ts_max(high, 20) - ts_min(low, 20))) Position within 20-day price range, ranked. ## KEY PRINCIPLES FOR HIGH-QUALITY FACTORS -- Always wrap the outermost expression with a cross-sectional operator (CsRank, CsZScore) for comparability. +- Always wrap the outermost expression with a cross-sectional operator (cs_rank, cs_zscore) for comparability. - Combine DIFFERENT operator types for novelty (e.g., time-series + cross-sectional + arithmetic). - Use diverse window sizes; avoid always defaulting to 10. -- Explore uncommon feature combinations ($amt, $vwap are underused). +- Explore uncommon feature combinations (amount, amount/vol are underused). - Factors with depth 3-7 tend to be best: deep enough to capture non-trivial patterns but not so deep they overfit. - Prefer economically meaningful combinations over random nesting. +- IMPORTANT: Avoid operators that are NOT listed above (e.g., Decay, TsLinRegSlope, HMA, DEMA, Resid). If you use them, the factor will be rejected. """ @@ -131,6 +141,7 @@ Your goal is to generate novel, predictive factor expressions using a tree-struc # PromptBuilder # --------------------------------------------------------------------------- + def normalize_factor_references(entries: Optional[List[Any]]) -> List[str]: """Convert mixed factor metadata into prompt-safe string references.""" if not entries: @@ -220,13 +231,10 @@ class PromptBuilder: lib_size = library_state.get("size", 0) target = library_state.get("target_size", 110) sections.append( - f"\n## CURRENT LIBRARY STATUS\n" - f"Library size: {lib_size} / {target} factors." + f"\n## CURRENT LIBRARY STATUS\nLibrary size: {lib_size} / {target} factors." ) - recent = normalize_factor_references( - library_state.get("recent_admissions", []) - ) + recent = normalize_factor_references(library_state.get("recent_admissions", [])) if recent: sections.append( "Recently admitted factors:\n" @@ -235,10 +243,10 @@ class PromptBuilder: saturation = library_state.get("domain_saturation", {}) if saturation: - sat_lines = [f" {domain}: {pct:.0%} saturated" for domain, pct in saturation.items()] - sections.append( - "Domain saturation:\n" + "\n".join(sat_lines) - ) + sat_lines = [ + f" {domain}: {pct:.0%} saturated" for domain, pct in saturation.items() + ] + sections.append("Domain saturation:\n" + "\n".join(sat_lines)) # --- Memory signal: recommended directions --- rec_dirs = memory_signal.get("recommended_directions", []) @@ -266,10 +274,7 @@ class PromptBuilder: helix_prompt_text = memory_signal.get("prompt_text", "").strip() if helix_prompt_text: - sections.append( - "\n## HELIX RETRIEVAL SUMMARY\n" - f"{helix_prompt_text}" - ) + sections.append(f"\n## HELIX RETRIEVAL SUMMARY\n{helix_prompt_text}") complementary_patterns = memory_signal.get("complementary_patterns", []) if complementary_patterns: @@ -330,8 +335,8 @@ class PromptBuilder: f"Output exactly {batch_size} factors, one per line.\n" f"Format each line as: . : \n" f"Example:\n" - f"1. momentum_reversal: Neg(CsRank(Delta($close, 5)))\n" - f"2. volume_surprise: CsZScore(Div(Sub($volume, Mean($volume, 20)), Std($volume, 20)))\n" + f"1. momentum_reversal: -cs_rank(ts_delta(close, 5))\n" + f"2. volume_surprise: cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20))\n" f"\nRules:\n" f"- factor_name: lowercase_with_underscores, descriptive, unique\n" f"- formula: valid DSL expression using ONLY operators and features listed above\n" @@ -346,6 +351,7 @@ class PromptBuilder: # New specialist/critic/debate prompt builder functions # --------------------------------------------------------------------------- + def build_specialist_prompt( specialist_name: str, specialist_domain: str, @@ -413,16 +419,12 @@ def build_specialist_prompt( # Regime context if regime_context: - sections.append( - f"\n## CURRENT MARKET REGIME\n{regime_context}" - ) + sections.append(f"\n## CURRENT MARKET REGIME\n{regime_context}") # Library state lib_size = library_diagnostics.get("size", 0) target = library_diagnostics.get("target_size", 110) - sections.append( - f"\n## LIBRARY STATUS\nCurrent: {lib_size}/{target} factors." - ) + sections.append(f"\n## LIBRARY STATUS\nCurrent: {lib_size}/{target} factors.") recent = normalize_factor_references( library_diagnostics.get("recent_admissions", []) @@ -435,31 +437,26 @@ def build_specialist_prompt( saturation = library_diagnostics.get("domain_saturation", {}) if saturation: - sat_lines = [ - f" {d}: {p:.0%} saturated" for d, p in saturation.items() - ] + sat_lines = [f" {d}: {p:.0%} saturated" for d, p in saturation.items()] sections.append("Domain saturation:\n" + "\n".join(sat_lines)) # Memory signal injections rec_dirs = memory_signal.get("recommended_directions", []) if rec_dirs: sections.append( - "\n## RECOMMENDED DIRECTIONS\n" - + "\n".join(f" * {d}" for d in rec_dirs) + "\n## RECOMMENDED DIRECTIONS\n" + "\n".join(f" * {d}" for d in rec_dirs) ) forbidden = memory_signal.get("forbidden_directions", []) if forbidden: sections.append( - "\n## FORBIDDEN DIRECTIONS\n" - + "\n".join(f" X {d}" for d in forbidden) + "\n## FORBIDDEN DIRECTIONS\n" + "\n".join(f" X {d}" for d in forbidden) ) insights = memory_signal.get("strategic_insights", []) if insights: sections.append( - "\n## STRATEGIC INSIGHTS\n" - + "\n".join(f" - {ins}" for ins in insights) + "\n## STRATEGIC INSIGHTS\n" + "\n".join(f" - {ins}" for ins in insights) ) helix_text = memory_signal.get("prompt_text", "").strip() @@ -476,8 +473,7 @@ def build_specialist_prompt( warn = memory_signal.get("conflict_warnings", []) if warn: sections.append( - "\n## SATURATION WARNINGS\n" - + "\n".join(f" ! {w}" for w in warn) + "\n## SATURATION WARNINGS\n" + "\n".join(f" ! {w}" for w in warn) ) gaps = memory_signal.get("semantic_gaps", []) @@ -508,8 +504,7 @@ def build_specialist_prompt( # Avoid patterns if avoid_patterns: sections.append( - "\n## PATTERNS TO AVOID\n" - + "\n".join(f" X {av}" for av in avoid_patterns) + "\n## PATTERNS TO AVOID\n" + "\n".join(f" X {av}" for av in avoid_patterns) ) # Few-shot patterns from memory @@ -526,7 +521,7 @@ def build_specialist_prompt( f"\n## OUTPUT FORMAT\n" f"Generate exactly {n_proposals} novel factor candidates.\n" f"Format: . : \n" - f"Example: 1. momentum_reversal: Neg(CsRank(Delta($close, 5)))\n" + f"Example: 1. momentum_reversal: -cs_rank(ts_delta(close, 5))\n" f"Rules:\n" f"- factor_name: lowercase_with_underscores, unique, descriptive\n" f"- formula: valid DSL expression only\n" @@ -581,16 +576,16 @@ def build_critic_scoring_prompt( ) if memory_signal: - sections.append(f"\n## MEMORY CONTEXT (success patterns)\n{memory_signal[:600]}") + sections.append( + f"\n## MEMORY CONTEXT (success patterns)\n{memory_signal[:600]}" + ) sections.append("\n## CANDIDATES") for c in candidates: name = c.get("name", "unknown") formula = c.get("formula", "") specialist = c.get("specialist", "unknown") - sections.append( - f" [{specialist}] {name}: {formula}" - ) + sections.append(f" [{specialist}] {name}: {formula}") sections.append( "\n## SCORING CRITERIA\n" @@ -647,7 +642,7 @@ def build_debate_synthesis_prompt( all_proposals, key=lambda p: score_map.get(p.get("name", ""), 0.0), reverse=True, - )[:top_k * 2] # take 2x top_k for synthesis + )[: top_k * 2] # take 2x top_k for synthesis sections: List[str] = [] sections.append( @@ -664,9 +659,7 @@ def build_debate_synthesis_prompt( formula = p.get("formula", "?") specialist = p.get("specialist", "?") score = score_map.get(name, 0.5) - sections.append( - f" [{specialist}, score={score:.2f}] {name}: {formula}" - ) + sections.append(f" [{specialist}, score={score:.2f}] {name}: {formula}") sections.append( f"\n## SELECTION CRITERIA\n" diff --git a/tests/test_factorminer_prompt.py b/tests/test_factorminer_prompt.py new file mode 100644 index 0000000..5c0d6d1 --- /dev/null +++ b/tests/test_factorminer_prompt.py @@ -0,0 +1,72 @@ +"""Tests for local DSL prompt and output parser in Step 3.""" + +from src.factorminer.agent.output_parser import ( + CandidateFactor, + _infer_category, + parse_llm_output, +) +from src.factorminer.agent.prompt_builder import SYSTEM_PROMPT + + +def test_system_prompt_uses_local_dsl(): + assert "$close" not in SYSTEM_PROMPT + assert "CsRank(" not in SYSTEM_PROMPT + assert "cs_rank(" in SYSTEM_PROMPT + assert "close / ts_delay(close, 1) - 1" in SYSTEM_PROMPT + assert "ts_mean(close, 20)" in SYSTEM_PROMPT + + +def test_parse_local_dsl_numbered_list(): + raw = ( + "1. momentum: -cs_rank(ts_delta(close, 5))\n" + "2. volume: cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20))\n" + "3. vwap_dev: cs_rank((close - amount / vol) / (amount / vol))\n" + ) + candidates, failed = parse_llm_output(raw) + assert len(candidates) == 3 + assert candidates[0].name == "momentum" + assert candidates[0].formula == "-cs_rank(ts_delta(close, 5))" + assert candidates[0].is_valid + assert candidates[1].name == "volume" + assert ( + candidates[1].formula == "cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20))" + ) + assert candidates[2].name == "vwap_dev" + assert not failed + + +def test_parse_local_dsl_formula_only(): + raw = "cs_rank(close / ts_delay(close, 5) - 1)" + candidates, failed = parse_llm_output(raw) + assert len(candidates) == 1 + assert candidates[0].formula == "cs_rank(close / ts_delay(close, 5) - 1)" + assert not failed + + +def test_parse_invalid_parentheses(): + candidates, failed = parse_llm_output("1. bad: cs_rank(ts_delta(close, 5)") + assert len(candidates) == 1 + assert not candidates[0].is_valid + assert "括号" in candidates[0].parse_error + + +def test_infer_category_local_dsl(): + assert _infer_category("cs_rank(ts_delta(close, 5))") == "cross_sectional_momentum" + assert _infer_category("ts_corr(vol, close, 10)") == "regression" + assert _infer_category("ts_std(close, 20)") == "volatility" + assert _infer_category("if_(close > open, 1, -1)") == "conditional" + + +def test_candidate_factor_is_valid_without_tree(): + cf = CandidateFactor(name="test", formula="cs_rank(close)") + assert cf.is_valid + assert cf.category == "unknown" + + +def test_parse_json_local_dsl(): + raw = '{"name": "mom", "formula": "cs_rank(close / ts_delay(close, 5) - 1)"}' + candidates, failed = parse_llm_output(raw) + assert len(candidates) == 1 + assert candidates[0].name == "mom" + assert candidates[0].formula == "cs_rank(close / ts_delay(close, 5) - 1)" + assert not failed