ProStock/docs/plans/2026-04-08-step3-llm-prompt-local-dsl.md
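liaozhaorun dd2e8a4a8e refactor(factorminer): rework the LLM prompt and parser to output the local DSL directly
- DSL spec switched to snake_case and infix operators; examples updated accordingly
- Removed the ExpressionTree dependency in favor of basic checks such as parenthesis matching
- Retry prompt adapted to the local DSL rules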
2026-04-08 22:27:33 +08:00

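# Step 3: LLM Prompt Rework (Generate the Local DSL Directly) Implementation Plan

> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.

**Goal:** Rework FactorMiner's LLM prompts and output parser so that, instead of the CamelCase DSL with `$`-prefixed fields, they generate the local snake_case DSL directly, removing the runtime translation layer.

**Architecture:** The prompts use the snake_case function and field names supported by the local `FactorEngine` directly; `OutputParser` only extracts and lightly cleans strings and no longer calls FactorMiner's `ExpressionTree` parser; `factor_generator.py` is adjusted to return the raw DSL strings.

**Tech Stack:** Python, ProStock `src.factors` local DSL (`FactorEngine`)

---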
## Task 1: Rewrite `src/factorminer/agent/prompt_builder.py`
**Files:**
- Modify: `src/factorminer/agent/prompt_builder.py`
- Test: `tests/test_factorminer_prompt.py`
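**Step 1: Rewrite the field-list function `_format_feature_list()`**

Replace the `$`-prefixed fields with the local field names and add notes on how to compute the derived values: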
```python
def _format_feature_list() -> str:
    """Render the raw features (plus hints for derived fields) for the system prompt."""
    descriptions = {
        "open": "open price",
        "high": "high price",
        "low": "low price",
        "close": "close price",
        "vol": "volume (shares)",
        "amount": "turnover (monetary value)",
        "vwap": "compute as amount / vol",
        "returns": "compute as close / ts_delay(close, 1) - 1",
    }
    lines = []
    for feat, desc in descriptions.items():
        lines.append(f" {feat}: {desc}")
    return "\n".join(lines)
```
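The two derived-field hints are plain arithmetic on the raw fields. The sketch below uses plain Python lists for a single stock and a hypothetical `ts_delay` stand-in (not the `FactorEngine` implementation) to show what `amount / vol` and `close / ts_delay(close, 1) - 1` evaluate to, assuming bars are in chronological order and `vol` is in shares as described above.

```python
from typing import List, Optional


def ts_delay(xs: List[float], periods: int) -> List[Optional[float]]:
    """Hypothetical stand-in for ts_delay on one stock: shift back, padding with None."""
    n = min(periods, len(xs))
    return [None] * n + xs[: len(xs) - n]


def vwap(amount: List[float], vol: List[float]) -> List[float]:
    """amount / vol per bar: average traded price per unit of volume."""
    return [a / v for a, v in zip(amount, vol)]


def simple_returns(close: List[float]) -> List[Optional[float]]:
    """close / ts_delay(close, 1) - 1 per bar; None where no previous close exists."""
    prev = ts_delay(close, 1)
    return [None if p is None else c / p - 1 for c, p in zip(close, prev)]


close = [10.0, 11.0, 9.9]
print(vwap([1000.0, 2200.0, 990.0], [100.0, 200.0, 100.0]))  # [10.0, 11.0, 9.9]
print(simple_returns(close))  # first entry is None, then roughly 0.1 and -0.1
```

A bar with `vol == 0` (for example a suspended trading day) raises `ZeroDivisionError` in this sketch; how `FactorEngine` treats division by zero is not specified in this plan.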
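**Step 2: Define the local DSL operator table**

Add a `LOCAL_OPERATOR_TABLE` constant to `prompt_builder.py` that lists, grouped by category, the locally available operators to show in the prompt, instead of iterating over `OPERATOR_REGISTRY`: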
```python
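# (signature, arity, description) per operator, grouped by category.
LOCAL_OPERATOR_TABLE = {
    "ARITHMETIC": [
        ("+", "binary", "x + y"),
        ("-", "binary/unary", "x - y or -x"),
        ("*", "binary", "x * y"),
        ("/", "binary", "x / y"),
        ("**", "binary", "x ** y (power)"),
        (">", "binary", "x > y (comparison, returns 0/1)"),
        ("<", "binary", "x < y (comparison, returns 0/1)"),
        ("abs(x)", "unary", "absolute value"),
        ("sign(x)", "unary", "sign function"),
        ("max_(x, y)", "binary", "element-wise maximum"),
        ("min_(x, y)", "binary", "element-wise minimum"),
        ("clip(x, lower, upper)", "unary with params", "clipping"),
        ("log(x)", "unary", "natural logarithm"),
        ("sqrt(x)", "unary", "square root"),
        ("exp(x)", "unary", "exponential function"),
    ],
    "TIMESERIES": [
        ("ts_mean(x, window)", "unary + window", "rolling mean"),
        ("ts_std(x, window)", "unary + window", "rolling standard deviation"),
        ("ts_var(x, window)", "unary + window", "rolling variance"),
        ("ts_max(x, window)", "unary + window", "rolling maximum"),
        ("ts_min(x, window)", "unary + window", "rolling minimum"),
        ("ts_sum(x, window)", "unary + window", "rolling sum"),
        ("ts_delay(x, periods)", "unary + periods", "lag by N periods"),
        ("ts_delta(x, periods)", "unary + periods", "N-period difference"),
        ("ts_corr(x, y, window)", "binary + window", "rolling correlation"),
        ("ts_cov(x, y, window)", "binary + window", "rolling covariance"),
        ("ts_pct_change(x, periods)", "unary + periods", "N-period percentage change"),
        ("ts_ema(x, window)", "unary + window", "exponential moving average"),
        ("ts_wma(x, window)", "unary + window", "weighted moving average"),
        ("ts_skew(x, window)", "unary + window", "rolling skewness"),
        ("ts_kurt(x, window)", "unary + window", "rolling kurtosis"),
        ("ts_rank(x, window)", "unary + window", "rolling percentile rank"),
    ],
    "CROSS_SECTIONAL": [
        ("cs_rank(x)", "unary", "cross-sectional rank (quantile)"),
        ("cs_zscore(x)", "unary", "cross-sectional z-score standardization"),
        ("cs_demean(x)", "unary", "cross-sectional demeaning"),
        ("cs_neutralize(x, group)", "unary", "industry/market-cap neutralization"),
        ("cs_winsorize(x, lower, upper)", "unary", "cross-sectional winsorization"),
    ],
    "CONDITIONAL": [
        ("if_(condition, true_val, false_val)", "ternary", "conditional selection"),
        ("where(condition, true_val, false_val)", "ternary", "alias of if_"),
    ],
}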
```
Then rewrite `_format_operator_table()`:
```python
def _format_operator_table() -> str:
    lines = []
    for cat_name, ops in LOCAL_OPERATOR_TABLE.items():
        lines.append(f"\n### {cat_name} operators")
        for op_sig, arity, desc in ops:
            # e.g. "- ts_mean(x, window): rolling mean (unary + window)"
            lines.append(f"- {op_sig}: {desc} ({arity})")
    return "\n".join(lines)
```
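**Step 3: Rewrite `SYSTEM_PROMPT`**

Replace the syntax-rule section and the examples. Because `SYSTEM_PROMPT` is an f-string evaluated when the module is imported, `LOCAL_OPERATOR_TABLE`, `_format_feature_list()` and `_format_operator_table()` must be defined above it: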
```python
SYSTEM_PROMPT = f"""You are a quantitative researcher mining formulaic alpha factors for stock selection.
Your goal is to generate novel, predictive factor expressions using the local ProStock DSL. Each factor is a composition of operators applied to raw market features.
## RAW FEATURES (leaf nodes)
{_format_feature_list()}
## OPERATOR LIBRARY
{_format_operator_table()}
## EXPRESSION SYNTAX RULES
1. Expressions use Python-style infix operators: +, -, *, /, **, >, <
2. Function calls use snake_case names with comma-separated arguments: ts_mean(close, 20)
3. Window sizes and periods are numeric arguments placed last in function calls.
4. Valid window sizes are integers, typically in range [2, 250].
5. Cross-sectional operators (cs_rank, cs_zscore, cs_demean) operate across all stocks at each time step -- they are crucial for making factors comparable.
6. Do NOT use $ prefix for features. Use `close`, `vol`, `amount`, etc. directly.
7. `vwap` is not a raw feature; use `amount / vol` if you need it.
8. `returns` is not a raw feature; use `close / ts_delay(close, 1) - 1` if you need returns.
## EXAMPLES OF WELL-FORMED FACTORS
- -cs_rank(ts_delta(close, 5))
Short-term reversal: rank of 5-day price change, negated.
- cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20))
Volume surprise: standardized deviation from 20-day mean volume.
- cs_rank((close - amount / vol) / (amount / vol))
Intraday deviation from VWAP, cross-sectionally ranked.
- -ts_corr(vol, close, 10)
Negative price-volume correlation over 10 days.
- if_(close / ts_delay(close, 1) - 1 > 0, ts_std(close / ts_delay(close, 1) - 1, 10), -ts_std(close / ts_delay(close, 1) - 1, 10))
Conditional volatility: positive for up-moves, negative for down-moves.
- cs_rank((close - ts_min(low, 20)) / (ts_max(high, 20) - ts_min(low, 20)))
Position within 20-day price range, ranked.
## KEY PRINCIPLES FOR HIGH-QUALITY FACTORS
- Always wrap the outermost expression with a cross-sectional operator (cs_rank, cs_zscore) for comparability.
- Combine DIFFERENT operator types for novelty (e.g., time-series + cross-sectional + arithmetic).
- Use diverse window sizes; avoid always defaulting to 10.
- Explore uncommon feature combinations (amount, amount/vol are underused).
- Factors with depth 3-7 tend to be best: deep enough to capture non-trivial patterns but not so deep they overfit.
- Prefer economically meaningful combinations over random nesting.
- IMPORTANT: Avoid operators that are NOT listed above (e.g., Decay, TsLinRegSlope, HMA, DEMA, Resid). If you use them, the factor will be rejected.
"""
```
**Step 4: Update all output-format examples**

In `build_user_prompt` (around line 333), replace the example formulas with the local DSL:
```
1. momentum_reversal: -cs_rank(ts_delta(close, 5))
2. volume_surprise: cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20))
```
`build_specialist_prompt`约第529行中同步替换
```
Example: 1. momentum_reversal: -cs_rank(ts_delta(close, 5))
```
**Step 5: Run the prompt_builder tests (if any exist)**
```bash
uv run pytest tests/test_factorminer_prompt.py -v -k prompt
```
---
## Task 2: Modify `src/factorminer/agent/output_parser.py`
**Files:**
- Modify: `src/factorminer/agent/output_parser.py`
- Test: `tests/test_factorminer_prompt.py`
**Step 1: Remove the FactorMiner parser dependencies**

Delete the following imports:
```python
from src.factorminer.core.expression_tree import ExpressionTree
from src.factorminer.core.parser import parse, try_parse
from src.factorminer.core.types import OperatorType, OPERATOR_REGISTRY
```
**Step 2: Modify `CandidateFactor`**
```python
from dataclasses import dataclass


@dataclass
class CandidateFactor:
    """A candidate factor parsed from LLM output.

    Attributes
    ----------
    name : str
        Descriptive snake_case name.
    formula : str
        DSL formula string.
    category : str
        Rough category inferred from the operators that appear in the formula.
    parse_error : str
        Error message if the formula failed basic validation.
    """

    name: str
    formula: str
    category: str = "unknown"
    parse_error: str = ""

    @property
    def is_valid(self) -> bool:
        # Valid when basic validation reported no error and the formula is non-empty.
        return not self.parse_error and bool(self.formula.strip())
```
**Step 3: Modify `_infer_category()`**

Replace all CamelCase operator names with their snake_case equivalents:
```python
def _infer_category(formula: str) -> str:
    """Infer a rough category from the operators that appear in the formula.

    Matching is by substring, so nested operators count as well as the outermost one.
    """
    if any(op in formula for op in ("cs_rank", "cs_zscore", "cs_demean", "cs_neutralize", "cs_winsorize")):
        if any(op in formula for op in ("ts_corr", "ts_cov")):
            return "cross_sectional_regression"
        if any(op in formula for op in ("ts_delta", "ts_delay", "ts_pct_change")):
            return "cross_sectional_momentum"
        if any(op in formula for op in ("ts_std", "ts_var", "ts_skew", "ts_kurt")):
            return "cross_sectional_volatility"
        if any(op in formula for op in ("ts_mean", "ts_sum", "ts_ema", "ts_wma")):
            return "cross_sectional_smoothing"
        return "cross_sectional"
    if any(op in formula for op in ("ts_corr", "ts_cov")):
        return "regression"
    if any(op in formula for op in ("ts_delta", "ts_delay", "ts_pct_change")):
        return "momentum"
    if any(op in formula for op in ("ts_std", "ts_var", "ts_skew", "ts_kurt")):
        return "volatility"
    if any(op in formula for op in ("if_", "where", ">", "<")):
        return "conditional"
    return "general"
```
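**Step 4: Modify `_FORMULA_ONLY_PATTERN`**

A local DSL formula may start with a function call (such as a `cs_` or `ts_` operator), with a unary `-` (as in `-cs_rank(...)`), with an opening parenthesis, with a field name followed by an infix operator (as in `close / ts_delay(close, 5)`), or with a number:

```python
_FORMULA_ONLY_PATTERN = re.compile(
    r"^\s*("
    r"[a-zA-Z_][a-zA-Z0-9_]*\s*\(.*\)"        # function call: cs_rank(...)
    r"|[a-zA-Z_][a-zA-Z0-9_]*\s*[-+*/<>].*"   # field followed by an infix operator: close / ...
    r"|[-(].*"                                # unary minus or opening parenthesis
    r"|\d.*"                                  # numeric literal
    r")\s*$"
)
```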
```
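**Step 5: Modify `_clean_formula()`**

Remove the `$` cleanup logic (the `$` prefix no longer needs replacing) and keep the cleanup of trailing comments, punctuation, and backticks: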
```python
def _clean_formula(formula: str) -> str:
    """Clean up a formula string before parsing."""
    formula = formula.strip()
    # Remove trailing comments
    if " #" in formula:
        formula = formula[: formula.index(" #")]
    if " //" in formula:
        formula = formula[: formula.index(" //")]
    # Remove trailing punctuation
    formula = formula.rstrip(";,.")
    # Remove surrounding backticks
    formula = formula.strip("`")
    return formula.strip()
```
**Step 6: Rewrite `_try_build_candidate()`**

Stop calling `try_parse(formula)` and building an `ExpressionTree`; perform only basic validation:
```python
def _try_build_candidate(name: str, formula: str) -> CandidateFactor:
    """Attempt to validate a formula and build a CandidateFactor."""
    # Basic validation: equal numbers of opening and closing parentheses
    if formula.count("(") != formula.count(")"):
        return CandidateFactor(
            name=name,
            formula=formula,
            category="unknown",
            parse_error="括号不匹配",  # "parentheses do not match"; the Task 4 test checks for "括号"
        )
    category = _infer_category(formula)
    return CandidateFactor(name=name, formula=formula, category=category)
```
**Step 7: Modify `_generate_name_from_formula()`**

Adjust the regex extraction to handle snake_case function names (the part before the first opening parenthesis):
```python
def _generate_name_from_formula(formula: str, index: int) -> str:
    """Generate a descriptive name from a formula."""
    # Extract the outermost operator (snake_case)
    m = re.match(r"([a-zA-Z_][a-zA-Z0-9_]*)\s*\(", formula)
    if m:
        outer_op = m.group(1).lower()
        return f"{outer_op}_factor_{index + 1}"
    # Handle unary minus
    m = re.match(r"-([a-zA-Z_][a-zA-Z0-9_]*)\s*\(", formula)
    if m:
        outer_op = m.group(1).lower()
        return f"neg_{outer_op}_factor_{index + 1}"
    return f"factor_{index + 1}"
```
---
## Task 3: Adapt `src/factorminer/agent/factor_generator.py`
**Files:**
- Modify: `src/factorminer/agent/factor_generator.py`
**Step 1: Update the DSL rules in the retry prompt**

In the `_retry_failed_parses` method (around line 199), change the description in `repair_prompt` to the local DSL rules:
```python
repair_prompt = (
    "The following factor formulas failed to parse. "
    "Fix each one so it uses ONLY valid local DSL operators and features "
    "from the library. Return them in the same numbered format:\n"
    "<number>. <name>: <corrected_formula>\n\n"
    "Broken formulas:\n"
    + "\n".join(f" {i+1}. {f}" for i, f in enumerate(failed))
    + "\n\nFix all syntax errors, unknown operators, and invalid "
    "feature names. Use snake_case functions (e.g., ts_mean, cs_rank), "
    "infix operators (+, -, *, /, **, >, <), and raw features without $ prefix. "
    "Every formula must be valid in the local DSL."
)
```
**Step 2: Confirm that `generate_batch` needs no changes**

Because `CandidateFactor.is_valid` now relies only on string-based checks, the filtering logic in `generate_batch` remains compatible as-is.

---
## Task 4: Write tests in `tests/test_factorminer_prompt.py`
**Files:**
- Create: `tests/test_factorminer_prompt.py`
**Step 1: Test that the system prompt uses the local DSL**
```python
import pytest

from src.factorminer.agent.prompt_builder import SYSTEM_PROMPT


def test_system_prompt_uses_local_dsl():
    assert "$close" not in SYSTEM_PROMPT
    assert "CsRank(" not in SYSTEM_PROMPT
    assert "cs_rank(" in SYSTEM_PROMPT
    assert "close / ts_delay(close, 1) - 1" in SYSTEM_PROMPT
    assert "ts_mean(close, 20)" in SYSTEM_PROMPT
```
**Step 2: Test that OutputParser extracts the local DSL correctly**
```python
from src.factorminer.agent.output_parser import parse_llm_output, CandidateFactor


def test_parse_local_dsl_numbered_list():
    raw = (
        "1. momentum: -cs_rank(ts_delta(close, 5))\n"
        "2. volume: cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20))\n"
        "3. vwap_dev: cs_rank((close - amount / vol) / (amount / vol))\n"
    )
    candidates, failed = parse_llm_output(raw)
    assert len(candidates) == 3
    assert candidates[0].name == "momentum"
    assert candidates[0].formula == "-cs_rank(ts_delta(close, 5))"
    assert candidates[0].is_valid
    assert candidates[1].name == "volume"
    assert candidates[1].formula == "cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20))"
    assert candidates[2].name == "vwap_dev"
    assert not failed
```
**Step 3: Test formula-only lines**
```python
def test_parse_local_dsl_formula_only():
    raw = "cs_rank(close / ts_delay(close, 5) - 1)"
    candidates, failed = parse_llm_output(raw)
    assert len(candidates) == 1
    assert candidates[0].formula == "cs_rank(close / ts_delay(close, 5) - 1)"
    assert not failed
```
**Step 4: Test that unbalanced parentheses are marked invalid**
```python
def test_parse_invalid_parentheses():
    candidates, failed = parse_llm_output("1. bad: cs_rank(ts_delta(close, 5)")
    assert len(candidates) == 1
    assert not candidates[0].is_valid
    # "括号" ("parentheses") is part of the parse_error set by _try_build_candidate
    assert "括号" in candidates[0].parse_error
```
**Step 5: Test category inference**
```python
def test_infer_category_local_dsl():
    from src.factorminer.agent.output_parser import _infer_category

    assert _infer_category("cs_rank(ts_delta(close, 5))") == "cross_sectional_momentum"
    assert _infer_category("ts_corr(vol, close, 10)") == "regression"
    assert _infer_category("ts_std(close, 20)") == "volatility"
    assert _infer_category("if_(close > open, 1, -1)") == "conditional"
```
**Step 6: Run the tests**
```bash
uv run pytest tests/test_factorminer_prompt.py -v
```
Expected: all tests pass.

---
## Command Summary
```bash
# Install dependencies (if not already installed)
uv pip install -e .
# Run the new tests
uv run pytest tests/test_factorminer_prompt.py -v
# Run all factorminer-related tests
uv run pytest tests/test_factorminer_* -v
```
---
## Commit Suggestions

Once the changes are complete, split them into two commits:
1. `refactor(factorminer): rewrite LLM prompts to output local snake_case DSL`
2. `test(factorminer): add prompt and output parser tests for local DSL`