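refactor(factorminer): rework the LLM prompt and parser to emit the local DSL directly

- Switch the DSL spec to snake_case and infix operators; update the examples to match
- Remove the ExpressionTree dependency in favor of basic checks such as parenthesis matching
- Adapt the retry prompt to the local DSL rules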
2026-04-08 22:27:33 +08:00
parent 65500cce27
commit dd2e8a4a8e
5 changed files with 686 additions and 177 deletions


@@ -0,0 +1,458 @@
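# Step 3: Implementation Plan for the LLM Prompt Refactor (Generate the Local DSL Directly)
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
**Goal:** Rework FactorMiner's LLM prompt and output parser from the CamelCase, `$`-prefixed DSL so that they produce the local snake_case DSL directly, and remove the runtime translation layer.
**Architecture:** The prompt uses the snake_case function names and field names supported by the local `FactorEngine`; `OutputParser` only extracts strings and applies light cleanup, and no longer calls FactorMiner's `ExpressionTree` parser; `factor_generator.py` is adapted to return the raw DSL string.
**Tech Stack:** Python, ProStock `src.factors` local DSL (`FactorEngine`)
---
## Task 1: Rewrite `src/factorminer/agent/prompt_builder.py`
**Files:**
- Modify: `src/factorminer/agent/prompt_builder.py`
- Test: `tests/test_factorminer_prompt.py`
**Step 1: Rewrite the feature-list function `_format_feature_list()`**
Replace the `$`-prefixed fields with the local field names, and add notes on how the derived fields are computed: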
```python
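def _format_feature_list() -> str:
    # Descriptions are Chinese strings that are shown verbatim in the prompt.
    descriptions = {
        "open": "开盘价",  # opening price
        "high": "最高价",  # highest price
        "low": "最低价",  # lowest price
        "close": "收盘价",  # closing price
        "vol": "成交量(股数)",  # trading volume (shares)
        "amount": "成交额(金额)",  # trading amount (currency value)
        "vwap": "可用 amount / vol 计算",  # derive as amount / vol
        "returns": "可用 close / ts_delay(close, 1) - 1 计算",  # derive from close
    }
    lines = []
    for feat, desc in descriptions.items():
        lines.append(f" {feat}: {desc}")
    return "\n".join(lines)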
```
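**Step 2: Define the local DSL operator table**
Add a `LOCAL_OPERATOR_TABLE` constant to `prompt_builder.py` that lists the locally available operators to show in the prompt, grouped by category, so that the table no longer depends on iterating over `OPERATOR_REGISTRY`: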
```python
LOCAL_OPERATOR_TABLE = {
    "ARITHMETIC": [
        ("+", "二元", "x + y"),
        ("-", "二元/一元", "x - y 或 -x"),
        ("*", "二元", "x * y"),
        ("/", "二元", "x / y"),
        ("**", "二元", "x ** y (幂运算)"),
        (">", "二元", "x > y (条件判断,返回 0/1)"),
        ("<", "二元", "x < y (条件判断,返回 0/1)"),
        ("abs(x)", "一元", "绝对值"),
        ("sign(x)", "一元", "符号函数"),
        ("max_(x, y)", "二元", "逐元素最大值"),
        ("min_(x, y)", "二元", "逐元素最小值"),
        ("clip(x, lower, upper)", "一元带参", "截断"),
        ("log(x)", "一元", "自然对数"),
        ("sqrt(x)", "一元", "平方根"),
        ("exp(x)", "一元", "指数函数"),
    ],
    "TIMESERIES": [
        ("ts_mean(x, window)", "一元+窗口", "滚动均值"),
        ("ts_std(x, window)", "一元+窗口", "滚动标准差"),
        ("ts_var(x, window)", "一元+窗口", "滚动方差"),
        ("ts_max(x, window)", "一元+窗口", "滚动最大值"),
        ("ts_min(x, window)", "一元+窗口", "滚动最小值"),
        ("ts_sum(x, window)", "一元+窗口", "滚动求和"),
        ("ts_delay(x, periods)", "一元+周期", "滞后 N 期"),
        ("ts_delta(x, periods)", "一元+周期", "差分 N 期"),
        ("ts_corr(x, y, window)", "二元+窗口", "滚动相关系数"),
        ("ts_cov(x, y, window)", "二元+窗口", "滚动协方差"),
        ("ts_pct_change(x, periods)", "一元+周期", "N 期百分比变化"),
        ("ts_ema(x, window)", "一元+窗口", "指数移动平均"),
        ("ts_wma(x, window)", "一元+窗口", "加权移动平均"),
        ("ts_skew(x, window)", "一元+窗口", "滚动偏度"),
        ("ts_kurt(x, window)", "一元+窗口", "滚动峰度"),
        ("ts_rank(x, window)", "一元+窗口", "滚动分位排名"),
    ],
    "CROSS_SECTIONAL": [
        ("cs_rank(x)", "一元", "截面排名(分位数)"),
        ("cs_zscore(x)", "一元", "截面 Z-Score 标准化"),
        ("cs_demean(x)", "一元", "截面去均值"),
        ("cs_neutralize(x, group)", "一元", "行业/市值中性化"),
        ("cs_winsorize(x, lower, upper)", "一元", "截面缩尾处理"),
    ],
    "CONDITIONAL": [
        ("if_(condition, true_val, false_val)", "三元", "条件选择"),
        ("where(condition, true_val, false_val)", "三元", "if_ 的别名"),
    ],
}
```
Then rewrite `_format_operator_table()`:
```python
def _format_operator_table() -> str:
    lines = []
    for cat_name, ops in LOCAL_OPERATOR_TABLE.items():
        lines.append(f"\n### {cat_name} operators")
        for op_sig, arity, desc in ops:
            lines.append(f"- {op_sig}: {desc} ({arity})")
    return "\n".join(lines)
```
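To make the rendered format concrete, the sketch below runs the same loop over a small stand-in table. `SAMPLE_TABLE` and `format_table` are illustrative names with English descriptions; they are assumptions for this example and are not part of `prompt_builder.py`.

```python
from typing import Dict, List, Tuple

# Hypothetical two-category excerpt; the real table lives in prompt_builder.py.
SAMPLE_TABLE: Dict[str, List[Tuple[str, str, str]]] = {
    "TIMESERIES": [("ts_mean(x, window)", "unary+window", "rolling mean")],
    "CROSS_SECTIONAL": [("cs_rank(x)", "unary", "cross-sectional rank")],
}


def format_table(table: Dict[str, List[Tuple[str, str, str]]]) -> str:
    """Render one heading per category and one bullet per operator."""
    lines: List[str] = []
    for cat_name, ops in table.items():
        lines.append(f"\n### {cat_name} operators")
        for op_sig, arity, desc in ops:
            lines.append(f"- {op_sig}: {desc} ({arity})")
    return "\n".join(lines)


rendered = format_table(SAMPLE_TABLE)
print(rendered)
```

Because each heading starts with `\n`, the rendered text begins with an empty line and categories are separated by blank lines, which keeps the section readable inside the f-string prompt.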
**Step 3: Rewrite `SYSTEM_PROMPT`**
Replace the syntax-rules section and the examples:
```python
SYSTEM_PROMPT = f"""You are a quantitative researcher mining formulaic alpha factors for stock selection.
Your goal is to generate novel, predictive factor expressions using the local ProStock DSL. Each factor is a composition of operators applied to raw market features.
## RAW FEATURES (leaf nodes)
{_format_feature_list()}
## OPERATOR LIBRARY
{_format_operator_table()}
## EXPRESSION SYNTAX RULES
1. Expressions use Python-style infix operators: +, -, *, /, **, >, <
2. Function calls use snake_case names with comma-separated arguments: ts_mean(close, 20)
3. Window sizes and periods are numeric arguments placed last in function calls.
4. Valid window sizes are integers, typically in range [2, 250].
5. Cross-sectional operators (cs_rank, cs_zscore, cs_demean) operate across all stocks at each time step -- they are crucial for making factors comparable.
6. Do NOT use $ prefix for features. Use `close`, `vol`, `amount`, etc. directly.
7. `vwap` is not a raw feature; use `amount / vol` if you need it.
8. `returns` is not a raw feature; use `close / ts_delay(close, 1) - 1` if you need returns.
## EXAMPLES OF WELL-FORMED FACTORS
- -cs_rank(ts_delta(close, 5))
Short-term reversal: rank of 5-day price change, negated.
- cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20))
Volume surprise: standardized deviation from 20-day mean volume.
- cs_rank((close - amount / vol) / (amount / vol))
Intraday deviation from VWAP, cross-sectionally ranked.
- -ts_corr(vol, close, 10)
Negative price-volume correlation over 10 days.
- if_(close / ts_delay(close, 1) - 1 > 0, ts_std(close / ts_delay(close, 1) - 1, 10), -ts_std(close / ts_delay(close, 1) - 1, 10))
Conditional volatility: positive for up-moves, negative for down-moves.
- cs_rank((close - ts_min(low, 20)) / (ts_max(high, 20) - ts_min(low, 20)))
Position within 20-day price range, ranked.
## KEY PRINCIPLES FOR HIGH-QUALITY FACTORS
- Always wrap the outermost expression with a cross-sectional operator (cs_rank, cs_zscore) for comparability.
- Combine DIFFERENT operator types for novelty (e.g., time-series + cross-sectional + arithmetic).
- Use diverse window sizes; avoid always defaulting to 10.
- Explore uncommon feature combinations (amount, amount/vol are underused).
- Factors with depth 3-7 tend to be best: deep enough to capture non-trivial patterns but not so deep they overfit.
- Prefer economically meaningful combinations over random nesting.
- IMPORTANT: Avoid operators that are NOT listed above (e.g., Decay, TsLinRegSlope, HMA, DEMA, Resid). If you use them, the factor will be rejected.
"""
```
**Step 4: Update all output-format examples**
In `build_user_prompt` (around line 333), replace the example formulas with local DSL:
```
1. momentum_reversal: -cs_rank(ts_delta(close, 5))
2. volume_surprise: cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20))
```
Make the same replacement in `build_specialist_prompt` (around line 529):
```
Example: 1. momentum_reversal: -cs_rank(ts_delta(close, 5))
```
**Step 5: Run the prompt_builder tests (if any exist)**
```bash
uv run pytest tests/test_factorminer_prompt.py -v -k prompt
```
---
## Task 2: Modify `src/factorminer/agent/output_parser.py`
**Files:**
- Modify: `src/factorminer/agent/output_parser.py`
- Test: `tests/test_factorminer_prompt.py`
**Step 1: Remove the FactorMiner parser dependency**
Delete the following imports:
```python
from src.factorminer.core.expression_tree import ExpressionTree
from src.factorminer.core.parser import parse, try_parse
from src.factorminer.core.types import OperatorType, OPERATOR_REGISTRY
```
**Step 2: Modify `CandidateFactor`**
```python
@dataclass
class CandidateFactor:
    """A candidate factor parsed from LLM output.

    Attributes
    ----------
    name : str
        Descriptive snake_case name.
    formula : str
        DSL formula string.
    category : str
        Inferred category based on outermost operators.
    parse_error : str
        Error message if formula failed basic validation.
    """

    name: str
    formula: str
    category: str = "unknown"
    parse_error: str = ""

    @property
    def is_valid(self) -> bool:
        return not self.parse_error and bool(self.formula.strip())
```
**Step 3: Modify `_infer_category()`**
Replace every CamelCase operator name with its snake_case equivalent:
```python
def _infer_category(formula: str) -> str:
    """Infer a rough category from the outermost operators in the formula."""
    if any(op in formula for op in ("cs_rank", "cs_zscore", "cs_demean", "cs_neutralize", "cs_winsorize")):
        if any(op in formula for op in ("ts_corr", "ts_cov")):
            return "cross_sectional_regression"
        if any(op in formula for op in ("ts_delta", "ts_delay", "ts_pct_change")):
            return "cross_sectional_momentum"
        if any(op in formula for op in ("ts_std", "ts_var", "ts_skew", "ts_kurt")):
            return "cross_sectional_volatility"
        if any(op in formula for op in ("ts_mean", "ts_sum", "ts_ema", "ts_wma")):
            return "cross_sectional_smoothing"
        return "cross_sectional"
    if any(op in formula for op in ("ts_corr", "ts_cov")):
        return "regression"
    if any(op in formula for op in ("ts_delta", "ts_delay", "ts_pct_change")):
        return "momentum"
    if any(op in formula for op in ("ts_std", "ts_var", "ts_skew", "ts_kurt")):
        return "volatility"
    if any(op in formula for op in ("if_", "where", ">", "<")):
        return "conditional"
    return "general"
```
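**Step 4: Modify `_FORMULA_ONLY_PATTERN`**
A local DSL formula line may start with a function name such as `cs_...` or `ts_...`, with `-` (e.g. `-cs_rank(...)`), or with a digit. Note that the pattern as written does not match a line that starts with a bare field name (e.g. `close / ts_delay(close, 1) - 1`), and a function-call line must also end with `)`: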
```python
_FORMULA_ONLY_PATTERN = re.compile(
    r"^\s*([a-zA-Z_][a-zA-Z0-9_]*\s*\(.*\)|-.*|\d.*)\s*$"
)
```
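As a quick check of which lines this pattern accepts, the sketch below applies the same regex to four sample formulas. The `FORMULA_ONLY` and `samples` names are illustrative; the regex string is copied from the step above.

```python
import re

# Same regex as _FORMULA_ONLY_PATTERN in this step.
FORMULA_ONLY = re.compile(
    r"^\s*([a-zA-Z_][a-zA-Z0-9_]*\s*\(.*\)|-.*|\d.*)\s*$"
)

samples = [
    "cs_rank(close / ts_delay(close, 5) - 1)",  # starts with a function call
    "-cs_rank(ts_delta(close, 5))",             # starts with unary minus
    "1 / ts_std(close, 20)",                    # starts with a digit
    "close / ts_delay(close, 1) - 1",           # starts with a bare field name
]
matched = [bool(FORMULA_ONLY.match(s)) for s in samples]
print(matched)  # [True, True, True, False]
```

One thing to watch: the `-.*` alternative also accepts any line that begins with a hyphen, including a Markdown bullet such as `- note`, so callers may want to filter those lines before this pattern is tried.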
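**Step 5: Modify `_clean_formula()`**
Remove the `$` cleanup logic (the `$` prefix no longer needs to be replaced); keep the cleanup of comments, trailing punctuation, and backticks: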
```python
def _clean_formula(formula: str) -> str:
    """Clean up a formula string before parsing."""
    formula = formula.strip()
    # Remove trailing comments
    if " #" in formula:
        formula = formula[: formula.index(" #")]
    if " //" in formula:
        formula = formula[: formula.index(" //")]
    # Remove trailing punctuation
    formula = formula.rstrip(";,.")
    # Remove surrounding backticks
    formula = formula.strip("`")
    return formula.strip()
```
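The cleanup order matters, so here is a small sketch that mirrors the function above under the illustrative name `clean_formula` and runs it on a few inputs an LLM might plausibly produce. The sample strings are made up for this example.

```python
def clean_formula(formula: str) -> str:
    """Mirror of the cleanup steps above: comments, punctuation, backticks."""
    formula = formula.strip()
    if " #" in formula:
        formula = formula[: formula.index(" #")]
    if " //" in formula:
        formula = formula[: formula.index(" //")]
    formula = formula.rstrip(";,.")
    formula = formula.strip("`")
    return formula.strip()


cleaned = [
    clean_formula("`cs_rank(ts_delta(close, 5))`;"),
    clean_formula("ts_mean(close, 20)  # 20-day mean"),
    clean_formula("  -ts_corr(vol, close, 10).  "),
]
print(cleaned)
```

Since trailing punctuation is stripped before the backticks, a semicolon placed inside the backticks (for example `` `cs_rank(close);` ``) survives the cleanup. If that case shows up in practice, running the punctuation strip a second time after removing backticks would be one way to handle it.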
**Step 6: Rewrite `_try_build_candidate()`**
Stop calling `try_parse(formula)` and `ExpressionTree`; perform only basic validation:
```python
def _try_build_candidate(name: str, formula: str) -> CandidateFactor:
    """Attempt to validate a formula and build a CandidateFactor."""
    # Basic validation: parenthesis balance
    if formula.count("(") != formula.count(")"):
        return CandidateFactor(
            name=name,
            formula=formula,
            category="unknown",
            parse_error="括号不匹配",  # "unbalanced parentheses"; tests match on "括号"
        )
    category = _infer_category(formula)
    return CandidateFactor(name=name, formula=formula, category=category)
```
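Comparing the counts of `(` and `)` catches a missing parenthesis but not a misordered one. If a stricter check is ever wanted, a depth counter is enough; the sketch below is a hypothetical helper (`parens_balanced` is not in the codebase) and not something this plan requires.

```python
def parens_balanced(formula: str) -> bool:
    """Return True if every ')' closes an earlier '(' and none stay open."""
    depth = 0
    for ch in formula:
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth < 0:  # a ')' appeared before its matching '('
                return False
    return depth == 0


print(parens_balanced("cs_rank(ts_delta(close, 5))"))  # True
print(parens_balanced("cs_rank(ts_delta(close, 5)"))   # False
print(parens_balanced(")close("))                      # False
```

The last input has equal counts of both characters, so the count-based check in `_try_build_candidate()` would accept it, while the depth counter rejects it.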
**Step 7: Modify `_generate_name_from_formula()`**
Adjust the regex so that it extracts a snake_case function name (the part before the first parenthesis):
```python
def _generate_name_from_formula(formula: str, index: int) -> str:
    """Generate a descriptive name from a formula."""
    # Extract the outermost operator (snake_case)
    m = re.match(r"([a-zA-Z_][a-zA-Z0-9_]*)\s*\(", formula)
    if m:
        outer_op = m.group(1).lower()
        return f"{outer_op}_factor_{index + 1}"
    # Handle unary minus
    m = re.match(r"-([a-zA-Z_][a-zA-Z0-9_]*)\s*\(", formula)
    if m:
        outer_op = m.group(1).lower()
        return f"neg_{outer_op}_factor_{index + 1}"
    return f"factor_{index + 1}"
```
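For reference, this is what the three branches produce on sample inputs. The sketch mirrors the function above under the illustrative name `name_from_formula`; the formulas are examples, not fixtures from the repository.

```python
import re


def name_from_formula(formula: str, index: int) -> str:
    """Mirror of the naming rule above: function name, negated name, fallback."""
    m = re.match(r"([a-zA-Z_][a-zA-Z0-9_]*)\s*\(", formula)
    if m:
        return f"{m.group(1).lower()}_factor_{index + 1}"
    m = re.match(r"-([a-zA-Z_][a-zA-Z0-9_]*)\s*\(", formula)
    if m:
        return f"neg_{m.group(1).lower()}_factor_{index + 1}"
    return f"factor_{index + 1}"


names = [
    name_from_formula("cs_rank(ts_delta(close, 5))", 0),
    name_from_formula("-cs_rank(ts_delta(close, 5))", 1),
    name_from_formula("1 / ts_std(close, 20)", 2),
]
print(names)  # ['cs_rank_factor_1', 'neg_cs_rank_factor_2', 'factor_3']
```

Because `re.match` anchors at the start of the string, any formula that does not begin with a function call or a negated function call falls through to the generic `factor_N` name.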
---
## Task 3: Adapt `src/factorminer/agent/factor_generator.py`
**Files:**
- Modify: `src/factorminer/agent/factor_generator.py`
**Step 1: Update the DSL rule description in the retry prompt**
In the `_retry_failed_parses` method (around line 199), change the description in `repair_prompt` to the local DSL rules:
```python
repair_prompt = (
    "The following factor formulas failed to parse. "
    "Fix each one so it uses ONLY valid local DSL operators and features "
    "from the library. Return them in the same numbered format:\n"
    "<number>. <name>: <corrected_formula>\n\n"
    "Broken formulas:\n"
    + "\n".join(f" {i+1}. {f}" for i, f in enumerate(failed))
    + "\n\nFix all syntax errors, unknown operators, and invalid "
    "feature names. Use snake_case functions (e.g., ts_mean, cs_rank), "
    "infix operators (+, -, *, /, >, <), and raw features without $ prefix. "
    "Every formula must be valid in the local DSL."
)
```
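To see the text the model actually receives, the sketch below wraps the same string expression in a standalone function and renders it for two made-up failures. `build_repair_prompt` and the sample `failed` list are assumptions for illustration; in the real code the expression is built inline in `_retry_failed_parses`.

```python
from typing import List


def build_repair_prompt(failed: List[str]) -> str:
    """Assemble the repair prompt for formulas that failed validation."""
    return (
        "The following factor formulas failed to parse. "
        "Fix each one so it uses ONLY valid local DSL operators and features "
        "from the library. Return them in the same numbered format:\n"
        "<number>. <name>: <corrected_formula>\n\n"
        "Broken formulas:\n"
        + "\n".join(f"  {i + 1}. {f}" for i, f in enumerate(failed))
        + "\n\nFix all syntax errors, unknown operators, and invalid "
        "feature names. Use snake_case functions (e.g., ts_mean, cs_rank), "
        "infix operators (+, -, *, /, >, <), and raw features without $ prefix. "
        "Every formula must be valid in the local DSL."
    )


# Hypothetical failures: an unbalanced formula and a legacy CamelCase one.
failed = ["bad: cs_rank(ts_delta(close, 5)", "legacy: CsRank($close)"]
prompt = build_repair_prompt(failed)
print(prompt)
```

Adjacent string literals are concatenated before the `+` operators apply, so the numbered list ends up between the "Broken formulas:" header and the closing instructions.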
**Step 2: Confirm that `generate_batch` needs no changes**
Because `CandidateFactor.is_valid` is now based on string validation, the filtering logic in `generate_batch` remains compatible as is.
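A minimal sketch of why filtering keeps working, assuming `generate_batch` selects candidates by `is_valid` (the plan does not show that code, so the list comprehension below is an assumption). The dataclass is a trimmed copy of the one defined in Task 2.

```python
from dataclasses import dataclass


@dataclass
class CandidateFactor:
    name: str
    formula: str
    category: str = "unknown"
    parse_error: str = ""

    @property
    def is_valid(self) -> bool:
        # Valid means: no recorded error and a non-blank formula string.
        return not self.parse_error and bool(self.formula.strip())


candidates = [
    CandidateFactor("momentum", "-cs_rank(ts_delta(close, 5))", "cross_sectional_momentum"),
    CandidateFactor("bad", "cs_rank(ts_delta(close, 5)", parse_error="括号不匹配"),
    CandidateFactor("empty", "   "),
]
valid = [c for c in candidates if c.is_valid]
print([c.name for c in valid])  # ['momentum']
```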
---
## Task 4: Write the tests in `tests/test_factorminer_prompt.py`
**Files:**
- Create: `tests/test_factorminer_prompt.py`
**Step 1: Test that the system prompt uses the local DSL**
```python
import pytest
from src.factorminer.agent.prompt_builder import SYSTEM_PROMPT


def test_system_prompt_uses_local_dsl():
    assert "$close" not in SYSTEM_PROMPT
    assert "CsRank(" not in SYSTEM_PROMPT
    assert "cs_rank(" in SYSTEM_PROMPT
    assert "close / ts_delay(close, 1) - 1" in SYSTEM_PROMPT
    assert "ts_mean(close, 20)" in SYSTEM_PROMPT
```
**Step 2: Test that OutputParser extracts the local DSL correctly**
```python
from src.factorminer.agent.output_parser import parse_llm_output, CandidateFactor


def test_parse_local_dsl_numbered_list():
    raw = (
        "1. momentum: -cs_rank(ts_delta(close, 5))\n"
        "2. volume: cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20))\n"
        "3. vwap_dev: cs_rank((close - amount / vol) / (amount / vol))\n"
    )
    candidates, failed = parse_llm_output(raw)
    assert len(candidates) == 3
    assert candidates[0].name == "momentum"
    assert candidates[0].formula == "-cs_rank(ts_delta(close, 5))"
    assert candidates[0].is_valid
    assert candidates[1].name == "volume"
    assert candidates[1].formula == "cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20))"
    assert candidates[2].name == "vwap_dev"
    assert not failed
```
**Step 3: Test formula-only lines**
```python
def test_parse_local_dsl_formula_only():
    raw = "cs_rank(close / ts_delay(close, 5) - 1)"
    candidates, failed = parse_llm_output(raw)
    assert len(candidates) == 1
    assert candidates[0].formula == "cs_rank(close / ts_delay(close, 5) - 1)"
    assert not failed
```
**Step 4: Test that unbalanced parentheses are flagged as invalid**
```python
def test_parse_invalid_parentheses():
    candidates, failed = parse_llm_output("1. bad: cs_rank(ts_delta(close, 5)")
    assert len(candidates) == 1
    assert not candidates[0].is_valid
    assert "括号" in candidates[0].parse_error
```
**Step 5: Test category inference**
```python
def test_infer_category_local_dsl():
    from src.factorminer.agent.output_parser import _infer_category

    assert _infer_category("cs_rank(ts_delta(close, 5))") == "cross_sectional_momentum"
    assert _infer_category("ts_corr(vol, close, 10)") == "regression"
    assert _infer_category("ts_std(close, 20)") == "volatility"
    assert _infer_category("if_(close > open, 1, -1)") == "conditional"
```
**Step 6: Run the tests**
```bash
uv run pytest tests/test_factorminer_prompt.py -v
```
Expected: all tests pass.
---
## Command summary
```bash
# Install dependencies (if not already installed)
uv pip install -e .
# Run the new tests
uv run pytest tests/test_factorminer_prompt.py -v
# Run the factorminer-related tests
uv run pytest tests/test_factorminer_* -v
```
---
## Commit suggestions
Once the changes are complete, split them into two commits:
1. `refactor(factorminer): rewrite LLM prompts to output local snake_case DSL`
2. `test(factorminer): add prompt and output parser tests for local DSL`


@@ -117,7 +117,9 @@ class FactorGenerator:
             max_tokens=self.max_tokens,
         )
         elapsed = time.monotonic() - t0
-        logger.info("LLM response received in %.1fs (%d chars)", elapsed, len(raw_output))
+        logger.info(
+            "LLM response received in %.1fs (%d chars)", elapsed, len(raw_output)
+        )
 
         # 3. Parse output
         candidates, failed_lines = parse_llm_output(raw_output)
@@ -198,14 +200,15 @@ class FactorGenerator:
 
         repair_prompt = (
             "The following factor formulas failed to parse. "
-            "Fix each one so it uses ONLY valid operators and features "
+            "Fix each one so it uses ONLY valid local DSL operators and features "
             "from the library. Return them in the same numbered format:\n"
             "<number>. <name>: <corrected_formula>\n\n"
             "Broken formulas:\n"
-            + "\n".join(f" {i+1}. {f}" for i, f in enumerate(failed))
+            + "\n".join(f" {i + 1}. {f}" for i, f in enumerate(failed))
             + "\n\nFix all syntax errors, unknown operators, and invalid "
-            "feature names. Every formula must be a valid nested function "
-            "call using only operators from the library."
+            "feature names. Use snake_case functions (e.g., ts_mean, cs_rank), "
+            "infix operators (+, -, *, /, >, <), and raw features without $ prefix. "
+            "Every formula must be valid in the local DSL."
         )
 
         try:


@@ -1,20 +1,16 @@
 """Parse LLM output into structured CandidateFactor objects.
 
 Handles various output formats from LLMs: numbered lists, JSON,
-markdown code blocks, and raw text. Validates each formula against
-the expression tree parser.
+markdown code blocks, and raw text. Validates each formula with
+basic string checks; no FactorMiner-specific parsing is performed.
 """
 
 from __future__ import annotations
 
 import logging
 import re
-from dataclasses import dataclass, field
-from typing import Dict, List, Optional, Tuple
+from dataclasses import dataclass
+from typing import List, Optional, Tuple
 
-from src.factorminer.core.expression_tree import ExpressionTree
-from src.factorminer.core.parser import parse, try_parse
-from src.factorminer.core.types import OperatorType, OPERATOR_REGISTRY
-
 logger = logging.getLogger(__name__)
 
@@ -29,49 +25,44 @@ class CandidateFactor:
         Descriptive snake_case name.
     formula : str
         DSL formula string.
-    expression_tree : ExpressionTree or None
-        Parsed expression tree (None if parsing failed).
     category : str
         Inferred category based on outermost operators.
     parse_error : str
-        Error message if formula failed to parse.
+        Error message if formula failed basic validation.
     """
 
     name: str
     formula: str
-    expression_tree: Optional[ExpressionTree] = None
     category: str = "unknown"
     parse_error: str = ""
 
     @property
     def is_valid(self) -> bool:
-        return self.expression_tree is not None
+        return not self.parse_error and bool(self.formula.strip())
 
 
 def _infer_category(formula: str) -> str:
     """Infer a rough category from the outermost operators in the formula."""
-    lower = formula.lower()
-    # Check for cross-sectional operators at the top
-    if any(op in formula for op in ("CsRank", "CsZScore", "CsDemean", "CsScale", "CsNeutralize", "CsQuantile")):
-        # Look deeper for sub-category
-        if any(op in formula for op in ("Corr", "Cov", "Beta", "Resid")):
+    if any(
+        op in formula
+        for op in ("cs_rank", "cs_zscore", "cs_demean", "cs_neutralize", "cs_winsorize")
+    ):
+        if any(op in formula for op in ("ts_corr", "ts_cov")):
             return "cross_sectional_regression"
-        if any(op in formula for op in ("Delta", "Delay", "Return", "LogReturn")):
+        if any(op in formula for op in ("ts_delta", "ts_delay", "ts_pct_change")):
             return "cross_sectional_momentum"
-        if any(op in formula for op in ("Std", "Var", "Skew", "Kurt")):
+        if any(op in formula for op in ("ts_std", "ts_var", "ts_skew", "ts_kurt")):
             return "cross_sectional_volatility"
-        if any(op in formula for op in ("Mean", "Sum", "EMA", "SMA", "WMA", "DEMA", "HMA", "KAMA")):
+        if any(op in formula for op in ("ts_mean", "ts_sum", "ts_ema", "ts_wma")):
             return "cross_sectional_smoothing"
-        if any(op in formula for op in ("TsLinReg", "TsLinRegSlope")):
-            return "cross_sectional_trend"
         return "cross_sectional"
-    if any(op in formula for op in ("Corr", "Cov", "Beta", "Resid")):
+    if any(op in formula for op in ("ts_corr", "ts_cov")):
         return "regression"
-    if any(op in formula for op in ("Delta", "Delay", "Return", "LogReturn")):
+    if any(op in formula for op in ("ts_delta", "ts_delay", "ts_pct_change")):
         return "momentum"
-    if any(op in formula for op in ("Std", "Var", "Skew", "Kurt")):
+    if any(op in formula for op in ("ts_std", "ts_var", "ts_skew", "ts_kurt")):
         return "volatility"
-    if any(op in formula for op in ("IfElse", "Greater", "Less")):
+    if any(op in formula for op in ("if_", "where", ">", "<")):
         return "conditional"
 
     return "general"
@@ -95,15 +86,13 @@ _PLAIN_PATTERN = re.compile(
     r"(.+)$"  # formula
 )
 
-# Pattern: just a formula starting with an operator
+# Pattern: just a formula starting with a function call, unary minus, or number
 _FORMULA_ONLY_PATTERN = re.compile(
-    r"^\s*([A-Z][a-zA-Z]*\(.+\))\s*$"
+    r"^\s*([a-zA-Z_][a-zA-Z0-9_]*\s*\(.*\)|-.*|\d.*)\s*$"
 )
 
 # Pattern: JSON-like {"name": "...", "formula": "..."}
-_JSON_PATTERN = re.compile(
-    r'"name"\s*:\s*"([^"]+)"\s*,\s*"formula"\s*:\s*"([^"]+)"'
-)
+_JSON_PATTERN = re.compile(r'"name"\s*:\s*"([^"]+)"\s*,\s*"formula"\s*:\s*"([^"]+)"')
 
 
 def _strip_markdown(text: str) -> str:
@@ -152,6 +141,8 @@ def parse_llm_output(raw_text: str) -> Tuple[List[CandidateFactor], List[str]]:
     json_matches = _JSON_PATTERN.findall(text)
     if json_matches:
         for name, formula in json_matches:
+            if name is None or formula is None:
+                continue
             formula = _clean_formula(formula)
             candidate = _try_build_candidate(name, formula)
             if candidate.name not in seen_names:
@@ -184,6 +175,7 @@ def parse_llm_output(raw_text: str) -> Tuple[List[CandidateFactor], List[str]]:
         m = _FORMULA_ONLY_PATTERN.match(line)
         if m:
             formula = m.group(1)
+            assert formula is not None
             # Generate name from formula
             name = _generate_name_from_formula(formula, len(candidates))
 
@@ -222,38 +214,29 @@ def parse_llm_output(raw_text: str) -> Tuple[List[CandidateFactor], List[str]]:
 
 
 def _try_build_candidate(name: str, formula: str) -> CandidateFactor:
-    """Attempt to parse a formula and build a CandidateFactor."""
-    tree = try_parse(formula)
-    if tree is not None:
-        category = _infer_category(formula)
-        return CandidateFactor(
-            name=name,
-            formula=tree.to_string(),  # canonicalize
-            expression_tree=tree,
-            category=category,
-        )
-    else:
-        # Try to get a useful error message
-        error_msg = ""
-        try:
-            parse(formula)
-        except (SyntaxError, KeyError, ValueError) as e:
-            error_msg = str(e)
-
+    """Attempt to validate a formula and build a CandidateFactor."""
+    # Basic validation: parenthesis balance
+    if formula.count("(") != formula.count(")"):
         return CandidateFactor(
             name=name,
             formula=formula,
-            expression_tree=None,
             category="unknown",
-            parse_error=error_msg,
+            parse_error="括号不匹配",
         )
+    category = _infer_category(formula)
+    return CandidateFactor(name=name, formula=formula, category=category)
 
 
 def _generate_name_from_formula(formula: str, index: int) -> str:
     """Generate a descriptive name from a formula."""
-    # Extract the outermost operator
-    m = re.match(r"([A-Z][a-zA-Z]*)\(", formula)
+    # Extract the outermost operator (snake_case)
+    m = re.match(r"([a-zA-Z_][a-zA-Z0-9_]*)\s*\(", formula)
     if m:
         outer_op = m.group(1).lower()
         return f"{outer_op}_factor_{index + 1}"
+    # Handle unary minus
+    m = re.match(r"-([a-zA-Z_][a-zA-Z0-9_]*)\s*\(", formula)
+    if m:
+        outer_op = m.group(1).lower()
+        return f"neg_{outer_op}_factor_{index + 1}"
     return f"factor_{index + 1}"


@@ -9,68 +9,81 @@ from __future__ import annotations
 
 from typing import Any, Dict, List, Optional
 
-from src.factorminer.core.types import (
-    FEATURES,
-    OPERATOR_REGISTRY,
-    OperatorSpec,
-    OperatorType,
-)
+
+LOCAL_OPERATOR_TABLE = {
+    "ARITHMETIC": [
+        ("+", "二元", "x + y"),
+        ("-", "二元/一元", "x - y 或 -x"),
+        ("*", "二元", "x * y"),
+        ("/", "二元", "x / y"),
+        ("**", "二元", "x ** y (幂运算)"),
+        (">", "二元", "x > y (条件判断,返回 0/1)"),
+        ("<", "二元", "x < y (条件判断,返回 0/1)"),
+        ("abs(x)", "一元", "绝对值"),
+        ("sign(x)", "一元", "符号函数"),
+        ("max_(x, y)", "二元", "逐元素最大值"),
+        ("min_(x, y)", "二元", "逐元素最小值"),
+        ("clip(x, lower, upper)", "一元带参", "截断"),
+        ("log(x)", "一元", "自然对数"),
+        ("sqrt(x)", "一元", "平方根"),
+        ("exp(x)", "一元", "指数函数"),
+    ],
+    "TIMESERIES": [
+        ("ts_mean(x, window)", "一元+窗口", "滚动均值"),
+        ("ts_std(x, window)", "一元+窗口", "滚动标准差"),
+        ("ts_var(x, window)", "一元+窗口", "滚动方差"),
+        ("ts_max(x, window)", "一元+窗口", "滚动最大值"),
+        ("ts_min(x, window)", "一元+窗口", "滚动最小值"),
+        ("ts_sum(x, window)", "一元+窗口", "滚动求和"),
+        ("ts_delay(x, periods)", "一元+周期", "滞后 N 期"),
+        ("ts_delta(x, periods)", "一元+周期", "差分 N 期"),
+        ("ts_corr(x, y, window)", "二元+窗口", "滚动相关系数"),
+        ("ts_cov(x, y, window)", "二元+窗口", "滚动协方差"),
+        ("ts_pct_change(x, periods)", "一元+周期", "N 期百分比变化"),
+        ("ts_ema(x, window)", "一元+窗口", "指数移动平均"),
+        ("ts_wma(x, window)", "一元+窗口", "加权移动平均"),
+        ("ts_skew(x, window)", "一元+窗口", "滚动偏度"),
+        ("ts_kurt(x, window)", "一元+窗口", "滚动峰度"),
+        ("ts_rank(x, window)", "一元+窗口", "滚动分位排名"),
+    ],
+    "CROSS_SECTIONAL": [
+        ("cs_rank(x)", "一元", "截面排名(分位数)"),
+        ("cs_zscore(x)", "一元", "截面 Z-Score 标准化"),
+        ("cs_demean(x)", "一元", "截面去均值"),
+        ("cs_neutralize(x, group)", "一元", "行业/市值中性化"),
+        ("cs_winsorize(x, lower, upper)", "一元", "截面缩尾处理"),
+    ],
+    "CONDITIONAL": [
+        ("if_(condition, true_val, false_val)", "三元", "条件选择"),
+        ("where(condition, true_val, false_val)", "三元", "if_ 的别名"),
+    ],
+}
 
 
 def _format_operator_table() -> str:
     """Build a human-readable operator reference table grouped by category."""
-    grouped: Dict[str, List[OperatorSpec]] = {}
-    for spec in OPERATOR_REGISTRY.values():
-        cat = spec.category.name
-        grouped.setdefault(cat, []).append(spec)
-
-    lines: List[str] = []
-    for cat_name in [
-        "ARITHMETIC",
-        "STATISTICAL",
-        "TIMESERIES",
-        "SMOOTHING",
-        "CROSS_SECTIONAL",
-        "REGRESSION",
-        "LOGICAL",
-        "AUTO_INVENTED",
-    ]:
-        specs = grouped.get(cat_name, [])
-        if not specs:
-            continue
+    lines = []
+    for cat_name, ops in LOCAL_OPERATOR_TABLE.items():
         lines.append(f"\n### {cat_name} operators")
-        for spec in sorted(specs, key=lambda s: s.name):
-            params_str = ""
-            if spec.param_names:
-                parts = []
-                for pname in spec.param_names:
-                    default = spec.param_defaults.get(pname, "")
-                    lo, hi = spec.param_ranges.get(pname, (None, None))
-                    range_str = f"[{lo}-{hi}]" if lo is not None else ""
-                    parts.append(f"{pname}={default}{range_str}")
-                params_str = f" params: {', '.join(parts)}"
-            arity_args = ", ".join([f"expr{i+1}" for i in range(spec.arity)])
-            if spec.param_names:
-                arity_args += ", " + ", ".join(spec.param_names)
-            lines.append(f"- {spec.name}({arity_args}): {spec.description}{params_str}")
+        for op_sig, arity, desc in ops:
+            lines.append(f"- {op_sig}: {desc} ({arity})")
     return "\n".join(lines)
 
 
 def _format_feature_list() -> str:
     """Build a description of available raw features."""
     descriptions = {
-        "$open": "opening price",
-        "$high": "highest price in the bar",
-        "$low": "lowest price in the bar",
-        "$close": "closing price",
-        "$volume": "trading volume (shares)",
-        "$amt": "trading amount (currency value)",
-        "$vwap": "volume-weighted average price",
-        "$returns": "close-to-close returns",
+        "open": "开盘价",
+        "high": "最高价",
+        "low": "最低价",
+        "close": "收盘价",
+        "vol": "成交量(股数)",
+        "amount": "成交额(金额)",
+        "vwap": "可用 amount / vol 计算",
+        "returns": "可用 close / ts_delay(close, 1) - 1 计算",
     }
     lines = []
-    for feat in FEATURES:
-        desc = descriptions.get(feat, "")
+    for feat, desc in descriptions.items():
         lines.append(f" {feat}: {desc}")
     return "\n".join(lines)
 
@@ -81,7 +94,7 @@ def _format_feature_list() -> str:
 
 SYSTEM_PROMPT = f"""You are a quantitative researcher mining formulaic alpha factors for stock selection.
 
-Your goal is to generate novel, predictive factor expressions using a tree-structured domain-specific language (DSL). Each factor is a composition of operators applied to raw market features.
+Your goal is to generate novel, predictive factor expressions using the local ProStock DSL. Each factor is a composition of operators applied to raw market features.
 
 ## RAW FEATURES (leaf nodes)
 {_format_feature_list()}
@@ -90,40 +103,37 @@ Your goal is to generate novel, predictive factor expressions using a tree-struc
 {_format_operator_table()}
 
 ## EXPRESSION SYNTAX RULES
-1. Every expression is a nested function call: Operator(args...)
-2. Leaf nodes are raw features ($close, $volume, etc.) or numeric constants.
-3. Operators are called by name with expression arguments first, then numeric parameters:
-   - Mean($close, 20) = 20-day rolling mean of $close
-   - Corr($close, $volume, 10) = 10-day rolling correlation of close and volume
-   - IfElse(Greater($returns, 0), $volume, Neg($volume)) = conditional
-4. No infix operators; use Add(x,y) instead of x+y, Sub(x,y) instead of x-y, etc.
-5. Parameters like window sizes are trailing numeric arguments after expression children.
-6. Valid window sizes are integers; check each operator's parameter ranges above.
-7. Cross-sectional operators (CsRank, CsZScore, CsDemean, CsScale, CsNeutralize) operate across all stocks at each time step -- they are crucial for making factors comparable.
+1. Expressions use Python-style infix operators: +, -, *, /, **, >, <
+2. Function calls use snake_case names with comma-separated arguments: ts_mean(close, 20)
+3. Window sizes and periods are numeric arguments placed last in function calls.
+4. Valid window sizes are integers, typically in range [2, 250].
+5. Cross-sectional operators (cs_rank, cs_zscore, cs_demean) operate across all stocks at each time step -- they are crucial for making factors comparable.
+6. Do NOT use $ prefix for features. Use `close`, `vol`, `amount`, etc. directly.
+7. `vwap` is not a raw feature; use `amount / vol` if you need it.
+8. `returns` is not a raw feature; use `close / ts_delay(close, 1) - 1` if you need returns.
 
 ## EXAMPLES OF WELL-FORMED FACTORS
-- Neg(CsRank(Delta($close, 5)))
+- -cs_rank(ts_delta(close, 5))
   Short-term reversal: rank of 5-day price change, negated.
-- CsZScore(Div(Sub($volume, Mean($volume, 20)), Std($volume, 20)))
+- cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20))
   Volume surprise: standardized deviation from 20-day mean volume.
-- CsRank(Div(Sub($close, $vwap), $vwap))
+- cs_rank((close - amount / vol) / (amount / vol))
   Intraday deviation from VWAP, cross-sectionally ranked.
-- Neg(Corr($volume, $close, 10))
+- -ts_corr(vol, close, 10)
   Negative price-volume correlation over 10 days.
-- CsRank(TsLinRegSlope($volume, 20))
-  Trend in trading volume over 20 days, ranked.
-- IfElse(Greater($returns, 0), Std($returns, 10), Neg(Std($returns, 10)))
+- if_(close / ts_delay(close, 1) - 1 > 0, ts_std(close / ts_delay(close, 1) - 1, 10), -ts_std(close / ts_delay(close, 1) - 1, 10))
   Conditional volatility: positive for up-moves, negative for down-moves.
-- CsRank(Div(Sub($close, TsMin($low, 20)), Sub(TsMax($high, 20), TsMin($low, 20))))
+- cs_rank((close - ts_min(low, 20)) / (ts_max(high, 20) - ts_min(low, 20)))
   Position within 20-day price range, ranked.
 
 ## KEY PRINCIPLES FOR HIGH-QUALITY FACTORS
-- Always wrap the outermost expression with a cross-sectional operator (CsRank, CsZScore) for comparability.
+- Always wrap the outermost expression with a cross-sectional operator (cs_rank, cs_zscore) for comparability.
 - Combine DIFFERENT operator types for novelty (e.g., time-series + cross-sectional + arithmetic).
 - Use diverse window sizes; avoid always defaulting to 10.
-- Explore uncommon feature combinations ($amt, $vwap are underused).
+- Explore uncommon feature combinations (amount, amount/vol are underused).
 - Factors with depth 3-7 tend to be best: deep enough to capture non-trivial patterns but not so deep they overfit.
 - Prefer economically meaningful combinations over random nesting.
+- IMPORTANT: Avoid operators that are NOT listed above (e.g., Decay, TsLinRegSlope, HMA, DEMA, Resid). If you use them, the factor will be rejected.
 """
@@ -131,6 +141,7 @@ Your goal is to generate novel, predictive factor expressions using a tree-struc
# PromptBuilder # PromptBuilder
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def normalize_factor_references(entries: Optional[List[Any]]) -> List[str]: def normalize_factor_references(entries: Optional[List[Any]]) -> List[str]:
"""Convert mixed factor metadata into prompt-safe string references.""" """Convert mixed factor metadata into prompt-safe string references."""
if not entries: if not entries:
@@ -220,13 +231,10 @@ class PromptBuilder:
lib_size = library_state.get("size", 0) lib_size = library_state.get("size", 0)
target = library_state.get("target_size", 110) target = library_state.get("target_size", 110)
sections.append( sections.append(
f"\n## CURRENT LIBRARY STATUS\n" f"\n## CURRENT LIBRARY STATUS\nLibrary size: {lib_size} / {target} factors."
f"Library size: {lib_size} / {target} factors."
) )
recent = normalize_factor_references( recent = normalize_factor_references(library_state.get("recent_admissions", []))
library_state.get("recent_admissions", [])
)
if recent: if recent:
sections.append( sections.append(
"Recently admitted factors:\n" "Recently admitted factors:\n"
@@ -235,10 +243,10 @@ class PromptBuilder:
saturation = library_state.get("domain_saturation", {}) saturation = library_state.get("domain_saturation", {})
if saturation: if saturation:
sat_lines = [f" {domain}: {pct:.0%} saturated" for domain, pct in saturation.items()] sat_lines = [
sections.append( f" {domain}: {pct:.0%} saturated" for domain, pct in saturation.items()
"Domain saturation:\n" + "\n".join(sat_lines) ]
) sections.append("Domain saturation:\n" + "\n".join(sat_lines))
# --- Memory signal: recommended directions --- # --- Memory signal: recommended directions ---
rec_dirs = memory_signal.get("recommended_directions", []) rec_dirs = memory_signal.get("recommended_directions", [])
@@ -266,10 +274,7 @@ class PromptBuilder:
helix_prompt_text = memory_signal.get("prompt_text", "").strip() helix_prompt_text = memory_signal.get("prompt_text", "").strip()
if helix_prompt_text: if helix_prompt_text:
sections.append( sections.append(f"\n## HELIX RETRIEVAL SUMMARY\n{helix_prompt_text}")
"\n## HELIX RETRIEVAL SUMMARY\n"
f"{helix_prompt_text}"
)
complementary_patterns = memory_signal.get("complementary_patterns", []) complementary_patterns = memory_signal.get("complementary_patterns", [])
if complementary_patterns: if complementary_patterns:
@@ -330,8 +335,8 @@ class PromptBuilder:
f"Output exactly {batch_size} factors, one per line.\n" f"Output exactly {batch_size} factors, one per line.\n"
f"Format each line as: <number>. <factor_name>: <formula>\n" f"Format each line as: <number>. <factor_name>: <formula>\n"
f"Example:\n" f"Example:\n"
f"1. momentum_reversal: Neg(CsRank(Delta($close, 5)))\n" f"1. momentum_reversal: -cs_rank(ts_delta(close, 5))\n"
f"2. volume_surprise: CsZScore(Div(Sub($volume, Mean($volume, 20)), Std($volume, 20)))\n" f"2. volume_surprise: cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20))\n"
f"\nRules:\n" f"\nRules:\n"
f"- factor_name: lowercase_with_underscores, descriptive, unique\n" f"- factor_name: lowercase_with_underscores, descriptive, unique\n"
f"- formula: valid DSL expression using ONLY operators and features listed above\n" f"- formula: valid DSL expression using ONLY operators and features listed above\n"
@@ -346,6 +351,7 @@ class PromptBuilder:
# New specialist/critic/debate prompt builder functions # New specialist/critic/debate prompt builder functions
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def build_specialist_prompt( def build_specialist_prompt(
specialist_name: str, specialist_name: str,
specialist_domain: str, specialist_domain: str,
@@ -413,16 +419,12 @@ def build_specialist_prompt(
# Regime context # Regime context
if regime_context: if regime_context:
sections.append( sections.append(f"\n## CURRENT MARKET REGIME\n{regime_context}")
f"\n## CURRENT MARKET REGIME\n{regime_context}"
)
# Library state # Library state
lib_size = library_diagnostics.get("size", 0) lib_size = library_diagnostics.get("size", 0)
target = library_diagnostics.get("target_size", 110) target = library_diagnostics.get("target_size", 110)
sections.append( sections.append(f"\n## LIBRARY STATUS\nCurrent: {lib_size}/{target} factors.")
f"\n## LIBRARY STATUS\nCurrent: {lib_size}/{target} factors."
)
recent = normalize_factor_references( recent = normalize_factor_references(
library_diagnostics.get("recent_admissions", []) library_diagnostics.get("recent_admissions", [])
@@ -435,31 +437,26 @@ def build_specialist_prompt(
saturation = library_diagnostics.get("domain_saturation", {}) saturation = library_diagnostics.get("domain_saturation", {})
if saturation: if saturation:
sat_lines = [ sat_lines = [f" {d}: {p:.0%} saturated" for d, p in saturation.items()]
f" {d}: {p:.0%} saturated" for d, p in saturation.items()
]
sections.append("Domain saturation:\n" + "\n".join(sat_lines)) sections.append("Domain saturation:\n" + "\n".join(sat_lines))
# Memory signal injections # Memory signal injections
rec_dirs = memory_signal.get("recommended_directions", []) rec_dirs = memory_signal.get("recommended_directions", [])
if rec_dirs: if rec_dirs:
sections.append( sections.append(
"\n## RECOMMENDED DIRECTIONS\n" "\n## RECOMMENDED DIRECTIONS\n" + "\n".join(f" * {d}" for d in rec_dirs)
+ "\n".join(f" * {d}" for d in rec_dirs)
) )
forbidden = memory_signal.get("forbidden_directions", []) forbidden = memory_signal.get("forbidden_directions", [])
if forbidden: if forbidden:
sections.append( sections.append(
"\n## FORBIDDEN DIRECTIONS\n" "\n## FORBIDDEN DIRECTIONS\n" + "\n".join(f" X {d}" for d in forbidden)
+ "\n".join(f" X {d}" for d in forbidden)
) )
insights = memory_signal.get("strategic_insights", []) insights = memory_signal.get("strategic_insights", [])
if insights: if insights:
sections.append( sections.append(
"\n## STRATEGIC INSIGHTS\n" "\n## STRATEGIC INSIGHTS\n" + "\n".join(f" - {ins}" for ins in insights)
+ "\n".join(f" - {ins}" for ins in insights)
) )
helix_text = memory_signal.get("prompt_text", "").strip() helix_text = memory_signal.get("prompt_text", "").strip()
@@ -476,8 +473,7 @@ def build_specialist_prompt(
warn = memory_signal.get("conflict_warnings", []) warn = memory_signal.get("conflict_warnings", [])
if warn: if warn:
sections.append( sections.append(
"\n## SATURATION WARNINGS\n" "\n## SATURATION WARNINGS\n" + "\n".join(f" ! {w}" for w in warn)
+ "\n".join(f" ! {w}" for w in warn)
) )
gaps = memory_signal.get("semantic_gaps", []) gaps = memory_signal.get("semantic_gaps", [])
@@ -508,8 +504,7 @@ def build_specialist_prompt(
# Avoid patterns # Avoid patterns
if avoid_patterns: if avoid_patterns:
sections.append( sections.append(
"\n## PATTERNS TO AVOID\n" "\n## PATTERNS TO AVOID\n" + "\n".join(f" X {av}" for av in avoid_patterns)
+ "\n".join(f" X {av}" for av in avoid_patterns)
) )
# Few-shot patterns from memory # Few-shot patterns from memory
@@ -526,7 +521,7 @@ def build_specialist_prompt(
f"\n## OUTPUT FORMAT\n" f"\n## OUTPUT FORMAT\n"
f"Generate exactly {n_proposals} novel factor candidates.\n" f"Generate exactly {n_proposals} novel factor candidates.\n"
f"Format: <number>. <factor_name>: <formula>\n" f"Format: <number>. <factor_name>: <formula>\n"
f"Example: 1. momentum_reversal: Neg(CsRank(Delta($close, 5)))\n" f"Example: 1. momentum_reversal: -cs_rank(ts_delta(close, 5))\n"
f"Rules:\n" f"Rules:\n"
f"- factor_name: lowercase_with_underscores, unique, descriptive\n" f"- factor_name: lowercase_with_underscores, unique, descriptive\n"
f"- formula: valid DSL expression only\n" f"- formula: valid DSL expression only\n"
@@ -581,16 +576,16 @@ def build_critic_scoring_prompt(
) )
if memory_signal: if memory_signal:
sections.append(f"\n## MEMORY CONTEXT (success patterns)\n{memory_signal[:600]}") sections.append(
f"\n## MEMORY CONTEXT (success patterns)\n{memory_signal[:600]}"
)
sections.append("\n## CANDIDATES") sections.append("\n## CANDIDATES")
for c in candidates: for c in candidates:
name = c.get("name", "unknown") name = c.get("name", "unknown")
formula = c.get("formula", "") formula = c.get("formula", "")
specialist = c.get("specialist", "unknown") specialist = c.get("specialist", "unknown")
sections.append( sections.append(f" [{specialist}] {name}: {formula}")
f" [{specialist}] {name}: {formula}"
)
sections.append( sections.append(
"\n## SCORING CRITERIA\n" "\n## SCORING CRITERIA\n"
@@ -647,7 +642,7 @@ def build_debate_synthesis_prompt(
all_proposals, all_proposals,
key=lambda p: score_map.get(p.get("name", ""), 0.0), key=lambda p: score_map.get(p.get("name", ""), 0.0),
reverse=True, reverse=True,
)[:top_k * 2] # take 2x top_k for synthesis )[: top_k * 2] # take 2x top_k for synthesis
sections: List[str] = [] sections: List[str] = []
sections.append( sections.append(
@@ -664,9 +659,7 @@ def build_debate_synthesis_prompt(
formula = p.get("formula", "?") formula = p.get("formula", "?")
specialist = p.get("specialist", "?") specialist = p.get("specialist", "?")
score = score_map.get(name, 0.5) score = score_map.get(name, 0.5)
sections.append( sections.append(f" [{specialist}, score={score:.2f}] {name}: {formula}")
f" [{specialist}, score={score:.2f}] {name}: {formula}"
)
sections.append( sections.append(
f"\n## SELECTION CRITERIA\n" f"\n## SELECTION CRITERIA\n"


@@ -0,0 +1,72 @@
"""Tests for local DSL prompt and output parser in Step 3."""

from src.factorminer.agent.output_parser import (
    CandidateFactor,
    _infer_category,
    parse_llm_output,
)
from src.factorminer.agent.prompt_builder import SYSTEM_PROMPT


def test_system_prompt_uses_local_dsl():
    assert "$close" not in SYSTEM_PROMPT
    assert "CsRank(" not in SYSTEM_PROMPT
    assert "cs_rank(" in SYSTEM_PROMPT
    assert "close / ts_delay(close, 1) - 1" in SYSTEM_PROMPT
    assert "ts_mean(close, 20)" in SYSTEM_PROMPT


def test_parse_local_dsl_numbered_list():
    raw = (
        "1. momentum: -cs_rank(ts_delta(close, 5))\n"
        "2. volume: cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20))\n"
        "3. vwap_dev: cs_rank((close - amount / vol) / (amount / vol))\n"
    )
    candidates, failed = parse_llm_output(raw)
    assert len(candidates) == 3
    assert candidates[0].name == "momentum"
    assert candidates[0].formula == "-cs_rank(ts_delta(close, 5))"
    assert candidates[0].is_valid
    assert candidates[1].name == "volume"
    assert (
        candidates[1].formula == "cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20))"
    )
    assert candidates[2].name == "vwap_dev"
    assert not failed


def test_parse_local_dsl_formula_only():
    raw = "cs_rank(close / ts_delay(close, 5) - 1)"
    candidates, failed = parse_llm_output(raw)
    assert len(candidates) == 1
    assert candidates[0].formula == "cs_rank(close / ts_delay(close, 5) - 1)"
    assert not failed


def test_parse_invalid_parentheses():
    candidates, failed = parse_llm_output("1. bad: cs_rank(ts_delta(close, 5)")
    assert len(candidates) == 1
    assert not candidates[0].is_valid
    assert "括号" in candidates[0].parse_error


def test_infer_category_local_dsl():
    assert _infer_category("cs_rank(ts_delta(close, 5))") == "cross_sectional_momentum"
    assert _infer_category("ts_corr(vol, close, 10)") == "regression"
    assert _infer_category("ts_std(close, 20)") == "volatility"
    assert _infer_category("if_(close > open, 1, -1)") == "conditional"


def test_candidate_factor_is_valid_without_tree():
    cf = CandidateFactor(name="test", formula="cs_rank(close)")
    assert cf.is_valid
    assert cf.category == "unknown"


def test_parse_json_local_dsl():
    raw = '{"name": "mom", "formula": "cs_rank(close / ts_delay(close, 5) - 1)"}'
    candidates, failed = parse_llm_output(raw)
    assert len(candidates) == 1
    assert candidates[0].name == "mom"
    assert candidates[0].formula == "cs_rank(close / ts_delay(close, 5) - 1)"
    assert not failed
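
Reviewer note: the tests above exercise two lightweight checks that replace the `ExpressionTree` parse. One splits a `<number>. <factor_name>: <formula>` line, the other rejects formulas with unbalanced parentheses. The sketch below shows one way such checks could look. The helper names `split_numbered_line` and `parens_balanced` and the regular expression are assumptions made for illustration. They are not taken from `output_parser.py`, whose actual implementation may differ.

```python
import re
from typing import Optional, Tuple

# Hypothetical pattern for "<number>. <factor_name>: <formula>" lines.
# The real parser may accept more variants (JSON, bare formulas, etc.).
_LINE_RE = re.compile(r"^\s*\d+\.\s*([a-z_][a-z0-9_]*)\s*:\s*(.+?)\s*$")


def split_numbered_line(line: str) -> Optional[Tuple[str, str]]:
    """Return (name, formula) for a numbered line, or None if it does not match."""
    match = _LINE_RE.match(line)
    if match is None:
        return None
    return match.group(1), match.group(2)


def parens_balanced(formula: str) -> bool:
    """Return True when every '(' is closed by a later ')' and none closes early."""
    depth = 0
    for ch in formula:
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth < 0:  # a ')' appeared before its '('
                return False
    return depth == 0
```

A check like this only catches structural damage such as truncated output. Whether the operators and fields in a formula actually exist is presumably still decided when `FactorEngine` evaluates it.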