ProStock/docs/plans/2026-04-08-step3-llm-prompt-local-dsl.md
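liaozhaorun dd2e8a4a8e refactor(factorminer): rework the LLM prompt and parser to output the local DSL directly
- Switch the DSL spec to snake_case and infix operators, and update the examples to match
- Remove the ExpressionTree dependency in favor of basic checks such as parenthesis matching
- Adapt the retry prompt to the local DSL rules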
2026-04-08 22:27:33 +08:00


Step 3: LLM Prompt Rework (Generate Local DSL Directly) Implementation Plan

For Claude: REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.

Goal: Rework FactorMiner's LLM prompt and output parser from the CamelCase + $-prefix DSL to generate the local snake_case DSL directly, and remove the runtime translation layer.

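Architecture: The prompt directly uses the snake_case function names and field names supported by the local FactorEngine; OutputParser only extracts strings and performs light cleanup, and no longer calls FactorMiner's ExpressionTree parser; factor_generator.py is adapted to return the raw DSL string.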

Tech Stack: Python, ProStock src.factors 本地 DSL (FactorEngine)


Task 1: Rewrite src/factorminer/agent/prompt_builder.py

Files:

  • Modify: src/factorminer/agent/prompt_builder.py
  • Test: tests/test_factorminer_prompt.py

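Step 1: Rewrite the feature-list function _format_feature_list()

Replace the $-prefixed fields with local field names, and add notes on how to compute the derived ones: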

def _format_feature_list() -> str:
    descriptions = {
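        "open": "opening price",
        "high": "highest price",
        "low": "lowest price",
        "close": "closing price",
        "vol": "trading volume (shares)",
        "amount": "turnover (currency amount)",
        "vwap": "derived: compute as amount / vol",
        "returns": "derived: compute as close / ts_delay(close, 1) - 1",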
    }
    lines = []
    for feat, desc in descriptions.items():
        lines.append(f"  {feat}: {desc}")
    return "\n".join(lines)
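The two derived entries can be checked by hand. The sketch below is a minimal pure-Python illustration of `amount / vol` and `close / ts_delay(close, 1) - 1` on made-up numbers; the `ts_delay` helper here is a hypothetical stand-in written for this example, not the FactorEngine implementation.

```python
# Toy data for a single stock; the values are invented for illustration.
close = [10.0, 10.5, 10.29]
vol = [1000.0, 2000.0, 1500.0]
amount = [10100.0, 20800.0, 15450.0]


def ts_delay(series, periods):
    """Hypothetical stand-in: shift a series back by `periods`, padding with None."""
    return [None] * periods + series[: len(series) - periods]


# vwap as described in the feature list: amount / vol
vwap = [a / v for a, v in zip(amount, vol)]

# returns as described in the feature list: close / ts_delay(close, 1) - 1
returns = [
    None if prev is None else c / prev - 1
    for c, prev in zip(close, ts_delay(close, 1))
]

print(vwap)
print(returns)
```

The first return is undefined because there is no previous close; how the real engine represents that gap (NaN, null, or dropped row) is not specified in this plan.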

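Step 2: Define the local DSL operator table

In prompt_builder.py, add a LOCAL_OPERATOR_TABLE constant that lists the locally available operators to show in the prompt (grouped by category), so the table no longer depends on iterating over OPERATOR_REGISTRY: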

LOCAL_OPERATOR_TABLE = {
    "ARITHMETIC": [
        ("+", "binary", "x + y"),
        ("-", "binary/unary", "x - y or -x"),
        ("*", "binary", "x * y"),
        ("/", "binary", "x / y"),
        ("**", "binary", "x ** y (power)"),
        (">", "binary", "x > y (comparison, returns 0/1)"),
        ("<", "binary", "x < y (comparison, returns 0/1)"),
        ("abs(x)", "unary", "absolute value"),
        ("sign(x)", "unary", "sign function"),
        ("max_(x, y)", "binary", "element-wise maximum"),
        ("min_(x, y)", "binary", "element-wise minimum"),
        ("clip(x, lower, upper)", "unary with parameters", "clipping"),
        ("log(x)", "unary", "natural logarithm"),
        ("sqrt(x)", "unary", "square root"),
        ("exp(x)", "unary", "exponential function"),
    ],
    "TIMESERIES": [
        ("ts_mean(x, window)", "unary + window", "rolling mean"),
        ("ts_std(x, window)", "unary + window", "rolling standard deviation"),
        ("ts_var(x, window)", "unary + window", "rolling variance"),
        ("ts_max(x, window)", "unary + window", "rolling maximum"),
        ("ts_min(x, window)", "unary + window", "rolling minimum"),
        ("ts_sum(x, window)", "unary + window", "rolling sum"),
        ("ts_delay(x, periods)", "unary + periods", "lag by N periods"),
        ("ts_delta(x, periods)", "unary + periods", "N-period difference"),
        ("ts_corr(x, y, window)", "binary + window", "rolling correlation"),
        ("ts_cov(x, y, window)", "binary + window", "rolling covariance"),
        ("ts_pct_change(x, periods)", "unary + periods", "N-period percentage change"),
        ("ts_ema(x, window)", "unary + window", "exponential moving average"),
        ("ts_wma(x, window)", "unary + window", "weighted moving average"),
        ("ts_skew(x, window)", "unary + window", "rolling skewness"),
        ("ts_kurt(x, window)", "unary + window", "rolling kurtosis"),
        ("ts_rank(x, window)", "unary + window", "rolling percentile rank"),
    ],
    "CROSS_SECTIONAL": [
        ("cs_rank(x)", "unary", "cross-sectional rank (percentile)"),
        ("cs_zscore(x)", "unary", "cross-sectional z-score standardization"),
        ("cs_demean(x)", "unary", "cross-sectional demeaning"),
        ("cs_neutralize(x, group)", "unary", "industry / market-cap neutralization"),
        ("cs_winsorize(x, lower, upper)", "unary", "cross-sectional winsorization"),
    ],
    "CONDITIONAL": [
        ("if_(condition, true_val, false_val)", "ternary", "conditional selection"),
        ("where(condition, true_val, false_val)", "ternary", "alias of if_"),
    ],
}

Then rewrite _format_operator_table():

def _format_operator_table() -> str:
    lines = []
    for cat_name, ops in LOCAL_OPERATOR_TABLE.items():
        lines.append(f"\n### {cat_name} operators")
        for op_sig, arity, desc in ops:
            lines.append(f"- {op_sig}: {desc} ({arity})")
    return "\n".join(lines)
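To see what the model will actually read, the formatter can be run on a small table. The sketch below repeats the same loop over a two-entry `SAMPLE_TABLE`, which is a stand-in invented for this example rather than the full LOCAL_OPERATOR_TABLE.

```python
# Two-entry stand-in for LOCAL_OPERATOR_TABLE, for illustration only.
SAMPLE_TABLE = {
    "TIMESERIES": [("ts_mean(x, window)", "unary + window", "rolling mean")],
    "CROSS_SECTIONAL": [("cs_rank(x)", "unary", "cross-sectional rank")],
}


def format_table(table):
    """Same formatting loop as _format_operator_table, parameterized by table."""
    lines = []
    for cat_name, ops in table.items():
        lines.append(f"\n### {cat_name} operators")
        for op_sig, arity, desc in ops:
            lines.append(f"- {op_sig}: {desc} ({arity})")
    return "\n".join(lines)


rendered = format_table(SAMPLE_TABLE)
print(rendered)
```

Each category heading is preceded by a blank line because of the leading "\n" in the heading string, so the rendered section also starts with an empty line.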

Step 3: Rewrite SYSTEM_PROMPT

Replace the syntax-rules section and the examples:

SYSTEM_PROMPT = f"""You are a quantitative researcher mining formulaic alpha factors for stock selection.

Your goal is to generate novel, predictive factor expressions using the local ProStock DSL. Each factor is a composition of operators applied to raw market features.

## RAW FEATURES (leaf nodes)
{_format_feature_list()}

## OPERATOR LIBRARY
{_format_operator_table()}

## EXPRESSION SYNTAX RULES
1. Expressions use Python-style infix operators: +, -, *, /, **, >, <
2. Function calls use snake_case names with comma-separated arguments: ts_mean(close, 20)
3. Window sizes and periods are numeric arguments placed last in function calls.
4. Valid window sizes are integers, typically in range [2, 250].
5. Cross-sectional operators (cs_rank, cs_zscore, cs_demean) operate across all stocks at each time step -- they are crucial for making factors comparable.
6. Do NOT use $ prefix for features. Use `close`, `vol`, `amount`, etc. directly.
7. `vwap` is not a raw feature; use `amount / vol` if you need it.
8. `returns` is not a raw feature; use `close / ts_delay(close, 1) - 1` if you need returns.

## EXAMPLES OF WELL-FORMED FACTORS
- -cs_rank(ts_delta(close, 5))
  Short-term reversal: rank of 5-day price change, negated.
- cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20))
  Volume surprise: standardized deviation from 20-day mean volume.
- cs_rank((close - amount / vol) / (amount / vol))
  Intraday deviation from VWAP, cross-sectionally ranked.
- -ts_corr(vol, close, 10)
  Negative price-volume correlation over 10 days.
- if_(close / ts_delay(close, 1) - 1 > 0, ts_std(close / ts_delay(close, 1) - 1, 10), -ts_std(close / ts_delay(close, 1) - 1, 10))
  Conditional volatility: positive for up-moves, negative for down-moves.
- cs_rank((close - ts_min(low, 20)) / (ts_max(high, 20) - ts_min(low, 20)))
  Position within 20-day price range, ranked.

## KEY PRINCIPLES FOR HIGH-QUALITY FACTORS
- Always wrap the outermost expression with a cross-sectional operator (cs_rank, cs_zscore) for comparability.
- Combine DIFFERENT operator types for novelty (e.g., time-series + cross-sectional + arithmetic).
- Use diverse window sizes; avoid always defaulting to 10.
- Explore uncommon feature combinations (amount, amount/vol are underused).
- Factors with depth 3-7 tend to be best: deep enough to capture non-trivial patterns but not so deep they overfit.
- Prefer economically meaningful combinations over random nesting.
- IMPORTANT: Avoid operators that are NOT listed above (e.g., Decay, TsLinRegSlope, HMA, DEMA, Resid). If you use them, the factor will be rejected.
"""

Step 4: Update all output-format examples

In build_user_prompt (around line 333), replace the example formulas with local DSL:

1. momentum_reversal: -cs_rank(ts_delta(close, 5))
2. volume_surprise: cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20))

In build_specialist_prompt (around line 529), make the same replacement:

Example: 1. momentum_reversal: -cs_rank(ts_delta(close, 5))

Step 5: Run the prompt_builder tests (if they already exist)

uv run pytest tests/test_factorminer_prompt.py -v -k prompt

Task 2: Modify src/factorminer/agent/output_parser.py

Files:

  • Modify: src/factorminer/agent/output_parser.py
  • Test: tests/test_factorminer_prompt.py

Step 1: Remove the FactorMiner parser dependency

Delete the following imports:

from src.factorminer.core.expression_tree import ExpressionTree
from src.factorminer.core.parser import parse, try_parse
from src.factorminer.core.types import OperatorType, OPERATOR_REGISTRY

Step 2: Modify CandidateFactor

@dataclass
class CandidateFactor:
    """A candidate factor parsed from LLM output.

    Attributes
    ----------
    name : str
        Descriptive snake_case name.
    formula : str
        DSL formula string.
    category : str
        Inferred category based on outermost operators.
    parse_error : str
        Error message if formula failed basic validation.
    """

    name: str
    formula: str
    category: str = "unknown"
    parse_error: str = ""

    @property
    def is_valid(self) -> bool:
        return not self.parse_error and bool(self.formula.strip())

Step 3: Modify _infer_category()

Replace all CamelCase operator names with snake_case:

def _infer_category(formula: str) -> str:
    """Infer a rough category from the outermost operators in the formula."""
    if any(op in formula for op in ("cs_rank", "cs_zscore", "cs_demean", "cs_neutralize", "cs_winsorize")):
        if any(op in formula for op in ("ts_corr", "ts_cov")):
            return "cross_sectional_regression"
        if any(op in formula for op in ("ts_delta", "ts_delay", "ts_pct_change")):
            return "cross_sectional_momentum"
        if any(op in formula for op in ("ts_std", "ts_var", "ts_skew", "ts_kurt")):
            return "cross_sectional_volatility"
        if any(op in formula for op in ("ts_mean", "ts_sum", "ts_ema", "ts_wma")):
            return "cross_sectional_smoothing"
        return "cross_sectional"
    if any(op in formula for op in ("ts_corr", "ts_cov")):
        return "regression"
    if any(op in formula for op in ("ts_delta", "ts_delay", "ts_pct_change")):
        return "momentum"
    if any(op in formula for op in ("ts_std", "ts_var", "ts_skew", "ts_kurt")):
        return "volatility"
    if any(op in formula for op in ("if_", "where", ">", "<")):
        return "conditional"
    return "general"

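Step 4: Modify _FORMULA_ONLY_PATTERN

A local DSL formula may start with cs_ or ts_, with - (e.g. -cs_rank(...)), or with a digit; a leading identifier is matched only when it is immediately followed by a parenthesized call, so a formula that begins with a bare field name is not treated as a formula-only line: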

_FORMULA_ONLY_PATTERN = re.compile(
    r"^\s*([a-zA-Z_][a-zA-Z0-9_]*\s*\(.*\)|-.*|\d.*)\s*$"
)
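A quick way to sanity-check the pattern is to run it against a few representative lines. The sketch below copies the regex from this step and records which sample strings it accepts; the samples are chosen for illustration and are not taken from real LLM output.

```python
import re

# Same pattern as in this step, copied for a standalone check.
FORMULA_ONLY = re.compile(
    r"^\s*([a-zA-Z_][a-zA-Z0-9_]*\s*\(.*\)|-.*|\d.*)\s*$"
)

samples = [
    "cs_rank(close / ts_delay(close, 5) - 1)",  # function call
    "-cs_rank(ts_delta(close, 5))",             # unary minus
    "1 / ts_std(close, 20)",                    # leading digit
    "close / ts_delay(close, 1) - 1",           # bare field name first
    "(close - open) / open",                    # leading parenthesis
    "- not a formula, just a bullet",           # Markdown bullet
]

matched = {s: bool(FORMULA_ONLY.match(s)) for s in samples}
for text, ok in matched.items():
    print(f"{ok!s:5}  {text}")
```

Two behaviors may be worth deciding on before implementation: expressions that start with a field name or an opening parenthesis are not matched, and any line starting with `-` is matched, including a Markdown bullet. Whether either matters depends on how parse_llm_output uses the pattern, which this plan does not show.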

Step 5: Modify _clean_formula()

Remove the $ cleanup logic (replacing the $ prefix is no longer needed); keep the cleanup of comments, punctuation, and backticks:

def _clean_formula(formula: str) -> str:
    """Clean up a formula string before parsing."""
    formula = formula.strip()
    # Remove trailing comments
    if " #" in formula:
        formula = formula[: formula.index(" #")]
    if " //" in formula:
        formula = formula[: formula.index(" //")]
    # Remove trailing punctuation
    formula = formula.rstrip(";,.")
    # Remove surrounding backticks
    formula = formula.strip("`")
    return formula.strip()
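The cleanup order has a small effect on the result, which the following sketch makes visible. It repeats the function from this step under the name `clean_formula` so it runs on its own; the inputs are invented examples.

```python
def clean_formula(formula: str) -> str:
    """Copy of _clean_formula from this step, for a standalone demonstration."""
    formula = formula.strip()
    if " #" in formula:
        formula = formula[: formula.index(" #")]
    if " //" in formula:
        formula = formula[: formula.index(" //")]
    formula = formula.rstrip(";,.")
    formula = formula.strip("`")
    return formula.strip()


results = {
    raw: clean_formula(raw)
    for raw in [
        "ts_mean(close, 20).",      # trailing period
        "cs_rank(close); # rank",   # trailing comment and semicolon
        "`cs_rank(close)`",         # wrapped in backticks
        "`cs_rank(close);`",        # punctuation inside the backticks
    ]
}
for raw, cleaned in results.items():
    print(f"{raw!r} -> {cleaned!r}")
```

In the last case the semicolon survives, because trailing punctuation is stripped before the backticks are removed. If LLM output in that shape turns out to be common, stripping backticks first would be one possible adjustment.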

Step 6: Rewrite _try_build_candidate()

No longer call try_parse(formula) or ExpressionTree; perform basic validation only:

def _try_build_candidate(name: str, formula: str) -> CandidateFactor:
    """Attempt to validate a formula and build a CandidateFactor."""
    # Basic validation: parenthesis balance
    if formula.count("(") != formula.count(")"):
        return CandidateFactor(
            name=name,
            formula=formula,
            category="unknown",
            parse_error="括号不匹配",
        )
    category = _infer_category(formula)
    return CandidateFactor(name=name, formula=formula, category=category)
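Comparing the counts of "(" and ")" accepts strings whose parentheses are equal in number but closed in the wrong order. If that case matters, an order-aware check is a small extension. The helper below, `parens_balanced`, is a hypothetical name and is not part of the plan; it is only a sketch of the idea.

```python
def parens_balanced(formula: str) -> bool:
    """Return True if every ')' closes an earlier '(' and none are left open."""
    depth = 0
    for ch in formula:
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth < 0:
                # A closing parenthesis appeared before its opening one.
                return False
    return depth == 0


checks = {
    "cs_rank(ts_delta(close, 5))": parens_balanced("cs_rank(ts_delta(close, 5))"),
    "cs_rank(ts_delta(close, 5)": parens_balanced("cs_rank(ts_delta(close, 5)"),
    ")close(": parens_balanced(")close("),
}
print(checks)
```

The third input has one opening and one closing parenthesis, so a pure count comparison would let it through, while the depth counter rejects it.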

Step 7: Modify _generate_name_from_formula()

Adjust the regex extraction to handle snake_case function names (the part before the first parenthesis):

def _generate_name_from_formula(formula: str, index: int) -> str:
    """Generate a descriptive name from a formula."""
    # Extract the outermost operator (snake_case)
    m = re.match(r"([a-zA-Z_][a-zA-Z0-9_]*)\s*\(", formula)
    if m:
        outer_op = m.group(1).lower()
        return f"{outer_op}_factor_{index + 1}"
    # Handle unary minus
    m = re.match(r"-([a-zA-Z_][a-zA-Z0-9_]*)\s*\(", formula)
    if m:
        outer_op = m.group(1).lower()
        return f"neg_{outer_op}_factor_{index + 1}"
    return f"factor_{index + 1}"
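For reference, these are the names the function produces for the three shapes of formula it distinguishes. The sketch repeats the function under the name `generate_name` so it can run standalone; the formulas are examples taken from the prompt in Task 1.

```python
import re


def generate_name(formula: str, index: int) -> str:
    """Copy of _generate_name_from_formula from this step."""
    m = re.match(r"([a-zA-Z_][a-zA-Z0-9_]*)\s*\(", formula)
    if m:
        return f"{m.group(1).lower()}_factor_{index + 1}"
    m = re.match(r"-([a-zA-Z_][a-zA-Z0-9_]*)\s*\(", formula)
    if m:
        return f"neg_{m.group(1).lower()}_factor_{index + 1}"
    return f"factor_{index + 1}"


names = [
    generate_name("cs_rank(ts_delta(close, 5))", 0),     # outer function call
    generate_name("-ts_corr(vol, close, 10)", 1),        # unary minus
    generate_name("close / ts_delay(close, 1) - 1", 2),  # starts with a field
]
print(names)
```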

Task 3: Adapt src/factorminer/agent/factor_generator.py

Files:

  • Modify: src/factorminer/agent/factor_generator.py

Step 1: Update the DSL rule description in the retry prompt

In the _retry_failed_parses method (around line 199), change the description in repair_prompt to the local DSL rules:

repair_prompt = (
    "The following factor formulas failed to parse. "
    "Fix each one so it uses ONLY valid local DSL operators and features "
    "from the library. Return them in the same numbered format:\n"
    "<number>. <name>: <corrected_formula>\n\n"
    "Broken formulas:\n"
    + "\n".join(f"  {i+1}. {f}" for i, f in enumerate(failed))
    + "\n\nFix all syntax errors, unknown operators, and invalid "
    "feature names. Use snake_case functions (e.g., ts_mean, cs_rank), "
    "infix operators (+, -, *, /, **, >, <), and raw features without $ prefix. "
    "Every formula must be valid in the local DSL."
)
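The repair prompt asks the model to fix unknown operators, but with ExpressionTree removed the parser only checks parentheses. One lightweight way to flag unknown function names is an allow-list lookup, sketched below. `KNOWN_FUNCTIONS` and `find_unknown_functions` are hypothetical names introduced for this example and are not part of the plan; the allow-list is assumed to mirror the function names in LOCAL_OPERATOR_TABLE.

```python
import re

# Assumption: mirrors the function names listed in LOCAL_OPERATOR_TABLE.
KNOWN_FUNCTIONS = {
    "abs", "sign", "max_", "min_", "clip", "log", "sqrt", "exp",
    "ts_mean", "ts_std", "ts_var", "ts_max", "ts_min", "ts_sum",
    "ts_delay", "ts_delta", "ts_corr", "ts_cov", "ts_pct_change",
    "ts_ema", "ts_wma", "ts_skew", "ts_kurt", "ts_rank",
    "cs_rank", "cs_zscore", "cs_demean", "cs_neutralize", "cs_winsorize",
    "if_", "where",
}

# An identifier directly followed by "(" is treated as a function call.
_CALL = re.compile(r"([A-Za-z_][A-Za-z0-9_]*)\s*\(")


def find_unknown_functions(formula: str) -> list:
    """Return the sorted, de-duplicated function names not in the allow-list."""
    called = _CALL.findall(formula)
    return sorted({name for name in called if name not in KNOWN_FUNCTIONS})


ok = find_unknown_functions("-cs_rank(ts_delta(close, 5))")
bad = find_unknown_functions("CsRank(Decay($close, 5))")
print(ok, bad)
```

This only inspects function names; it does not validate field names, argument counts, or operator precedence, which remain the job of the local FactorEngine.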

Step 2: Confirm that generate_batch needs no changes

Because CandidateFactor.is_valid is now based on string validation, the filtering logic in generate_batch remains compatible as-is.
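The compatibility claim rests on is_valid behaving the same from the caller's point of view. The sketch below reproduces the dataclass from Task 2 and applies a filter of the kind generate_batch is assumed to use; the actual filtering code in generate_batch is not shown in this plan, so the list comprehension is an assumption.

```python
from dataclasses import dataclass


@dataclass
class CandidateFactor:
    """Copy of the dataclass from Task 2, Step 2 (docstring shortened)."""

    name: str
    formula: str
    category: str = "unknown"
    parse_error: str = ""

    @property
    def is_valid(self) -> bool:
        return not self.parse_error and bool(self.formula.strip())


candidates = [
    CandidateFactor("momentum", "-cs_rank(ts_delta(close, 5))"),
    CandidateFactor("bad", "cs_rank(ts_delta(close, 5)", parse_error="unbalanced"),
    CandidateFactor("empty", "   "),
]

# Assumed shape of the filter in generate_batch: keep only valid candidates.
valid_names = [c.name for c in candidates if c.is_valid]
print(valid_names)
```

A candidate is dropped either when it carries a parse_error or when its formula is empty after stripping whitespace.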


Task 4: Write the tests in tests/test_factorminer_prompt.py

Files:

  • Create: tests/test_factorminer_prompt.py

Step 1: Test that the system prompt uses the local DSL

import pytest
from src.factorminer.agent.prompt_builder import SYSTEM_PROMPT

def test_system_prompt_uses_local_dsl():
    assert "$close" not in SYSTEM_PROMPT
    assert "CsRank(" not in SYSTEM_PROMPT
    assert "cs_rank(" in SYSTEM_PROMPT
    assert "close / ts_delay(close, 1) - 1" in SYSTEM_PROMPT
    assert "ts_mean(close, 20)" in SYSTEM_PROMPT

Step 2: Test that OutputParser extracts local DSL correctly

from src.factorminer.agent.output_parser import parse_llm_output, CandidateFactor

def test_parse_local_dsl_numbered_list():
    raw = (
        "1. momentum: -cs_rank(ts_delta(close, 5))\n"
        "2. volume: cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20))\n"
        "3. vwap_dev: cs_rank((close - amount / vol) / (amount / vol))\n"
    )
    candidates, failed = parse_llm_output(raw)
    assert len(candidates) == 3
    assert candidates[0].name == "momentum"
    assert candidates[0].formula == "-cs_rank(ts_delta(close, 5))"
    assert candidates[0].is_valid
    assert candidates[1].name == "volume"
    assert candidates[1].formula == "cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20))"
    assert candidates[2].name == "vwap_dev"
    assert not failed

Step 3: Test formula-only lines

def test_parse_local_dsl_formula_only():
    raw = "cs_rank(close / ts_delay(close, 5) - 1)"
    candidates, failed = parse_llm_output(raw)
    assert len(candidates) == 1
    assert candidates[0].formula == "cs_rank(close / ts_delay(close, 5) - 1)"
    assert not failed

Step 4: Test that unbalanced parentheses are marked invalid

def test_parse_invalid_parentheses():
    candidates, failed = parse_llm_output("1. bad: cs_rank(ts_delta(close, 5)")
    assert len(candidates) == 1
    assert not candidates[0].is_valid
    assert "括号" in candidates[0].parse_error

Step 5: Test category inference

def test_infer_category_local_dsl():
    from src.factorminer.agent.output_parser import _infer_category
    assert _infer_category("cs_rank(ts_delta(close, 5))") == "cross_sectional_momentum"
    assert _infer_category("ts_corr(vol, close, 10)") == "regression"
    assert _infer_category("ts_std(close, 20)") == "volatility"
    assert _infer_category("if_(close > open, 1, -1)") == "conditional"

Step 6: Run the tests

uv run pytest tests/test_factorminer_prompt.py -v

Expected: all tests pass.


Command summary

# Install dependencies (if not already installed)
uv pip install -e .

# Run the new tests
uv run pytest tests/test_factorminer_prompt.py -v

# Run the factorminer-related tests
uv run pytest tests/test_factorminer_* -v

Commit suggestions

Once the changes are complete, split them into two commits:

  1. refactor(factorminer): rewrite LLM prompts to output local snake_case DSL
  2. test(factorminer): add prompt and output parser tests for local DSL