refactor(factorminer): 将 LLM Prompt 和解析器改造为直接输出本地 DSL
- DSL 规范改为 snake_case、中缀运算符,示例同步替换 - 移除 ExpressionTree 依赖,改为括号匹配等基础校验 - retry prompt 适配本地 DSL 规则
This commit is contained in:
458
docs/plans/2026-04-08-step3-llm-prompt-local-dsl.md
Normal file
458
docs/plans/2026-04-08-step3-llm-prompt-local-dsl.md
Normal file
@@ -0,0 +1,458 @@
|
|||||||
|
# Step 3: LLM Prompt 改造(直接生成本地 DSL)实施计划
|
||||||
|
|
||||||
|
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||||
|
|
||||||
|
**Goal:** 将 FactorMiner 的 LLM Prompt 和输出解析器从 CamelCase + `$` 前缀 DSL 改造为直接生成本地 snake_case DSL,移除运行时翻译层。
|
||||||
|
|
||||||
|
**Architecture:** Prompt 直接使用本地 `FactorEngine` 支持的 snake_case 函数名和字段名;`OutputParser` 仅做字符串提取和轻量清洗,不再调用 FactorMiner 的 `ExpressionTree` 解析;`factor_generator.py` 配合返回原始 DSL 字符串。
|
||||||
|
|
||||||
|
**Tech Stack:** Python, ProStock `src.factors` 本地 DSL (`FactorEngine`)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Task 1: 重写 `src/factorminer/agent/prompt_builder.py`
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `src/factorminer/agent/prompt_builder.py`
|
||||||
|
- Test: `tests/test_factorminer_prompt.py`
|
||||||
|
|
||||||
|
**Step 1: 重写字段列表函数 `_format_feature_list()`**
|
||||||
|
|
||||||
|
将 `$` 前缀字段替换为本地字段,并添加计算说明:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def _format_feature_list() -> str:
|
||||||
|
descriptions = {
|
||||||
|
"open": "开盘价",
|
||||||
|
"high": "最高价",
|
||||||
|
"low": "最低价",
|
||||||
|
"close": "收盘价",
|
||||||
|
"vol": "成交量(股数)",
|
||||||
|
"amount": "成交额(金额)",
|
||||||
|
"vwap": "可用 amount / vol 计算",
|
||||||
|
"returns": "可用 close / ts_delay(close, 1) - 1 计算",
|
||||||
|
}
|
||||||
|
lines = []
|
||||||
|
for feat, desc in descriptions.items():
|
||||||
|
lines.append(f" {feat}: {desc}")
|
||||||
|
return "\n".join(lines)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: 定义本地 DSL 算子表映射**
|
||||||
|
|
||||||
|
在 `prompt_builder.py` 中新增 `LOCAL_OPERATOR_TABLE` 常量,列出 prompt 中需要展示的本地可用算子(按类别分组),不再依赖 `OPERATOR_REGISTRY` 遍历:
|
||||||
|
|
||||||
|
```python
|
||||||
|
LOCAL_OPERATOR_TABLE = {
|
||||||
|
"ARITHMETIC": [
|
||||||
|
("+", "二元", "x + y"),
|
||||||
|
("-", "二元/一元", "x - y 或 -x"),
|
||||||
|
("*", "二元", "x * y"),
|
||||||
|
("/", "二元", "x / y"),
|
||||||
|
("**", "二元", "x ** y (幂运算)"),
|
||||||
|
(">", "二元", "x > y (条件判断,返回 0/1)"),
|
||||||
|
("<", "二元", "x < y (条件判断,返回 0/1)"),
|
||||||
|
("abs(x)", "一元", "绝对值"),
|
||||||
|
("sign(x)", "一元", "符号函数"),
|
||||||
|
("max_(x, y)", "二元", "逐元素最大值"),
|
||||||
|
("min_(x, y)", "二元", "逐元素最小值"),
|
||||||
|
("clip(x, lower, upper)", "一元带参", "截断"),
|
||||||
|
("log(x)", "一元", "自然对数"),
|
||||||
|
("sqrt(x)", "一元", "平方根"),
|
||||||
|
("exp(x)", "一元", "指数函数"),
|
||||||
|
],
|
||||||
|
"TIMESERIES": [
|
||||||
|
("ts_mean(x, window)", "一元+窗口", "滚动均值"),
|
||||||
|
("ts_std(x, window)", "一元+窗口", "滚动标准差"),
|
||||||
|
("ts_var(x, window)", "一元+窗口", "滚动方差"),
|
||||||
|
("ts_max(x, window)", "一元+窗口", "滚动最大值"),
|
||||||
|
("ts_min(x, window)", "一元+窗口", "滚动最小值"),
|
||||||
|
("ts_sum(x, window)", "一元+窗口", "滚动求和"),
|
||||||
|
("ts_delay(x, periods)", "一元+周期", "滞后 N 期"),
|
||||||
|
("ts_delta(x, periods)", "一元+周期", "差分 N 期"),
|
||||||
|
("ts_corr(x, y, window)", "二元+窗口", "滚动相关系数"),
|
||||||
|
("ts_cov(x, y, window)", "二元+窗口", "滚动协方差"),
|
||||||
|
("ts_pct_change(x, periods)", "一元+周期", "N 期百分比变化"),
|
||||||
|
("ts_ema(x, window)", "一元+窗口", "指数移动平均"),
|
||||||
|
("ts_wma(x, window)", "一元+窗口", "加权移动平均"),
|
||||||
|
("ts_skew(x, window)", "一元+窗口", "滚动偏度"),
|
||||||
|
("ts_kurt(x, window)", "一元+窗口", "滚动峰度"),
|
||||||
|
("ts_rank(x, window)", "一元+窗口", "滚动分位排名"),
|
||||||
|
],
|
||||||
|
"CROSS_SECTIONAL": [
|
||||||
|
("cs_rank(x)", "一元", "截面排名(分位数)"),
|
||||||
|
("cs_zscore(x)", "一元", "截面 Z-Score 标准化"),
|
||||||
|
("cs_demean(x)", "一元", "截面去均值"),
|
||||||
|
("cs_neutralize(x, group)", "一元", "行业/市值中性化"),
|
||||||
|
("cs_winsorize(x, lower, upper)", "一元", "截面缩尾处理"),
|
||||||
|
],
|
||||||
|
"CONDITIONAL": [
|
||||||
|
("if_(condition, true_val, false_val)", "三元", "条件选择"),
|
||||||
|
("where(condition, true_val, false_val)", "三元", "if_ 的别名"),
|
||||||
|
],
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
然后重写 `_format_operator_table()`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def _format_operator_table() -> str:
|
||||||
|
lines = []
|
||||||
|
for cat_name, ops in LOCAL_OPERATOR_TABLE.items():
|
||||||
|
lines.append(f"\n### {cat_name} operators")
|
||||||
|
for op_sig, arity, desc in ops:
|
||||||
|
lines.append(f"- {op_sig}: {desc} ({arity})")
|
||||||
|
return "\n".join(lines)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: 重写 `SYSTEM_PROMPT`**
|
||||||
|
|
||||||
|
替换语法规则段落和示例:
|
||||||
|
|
||||||
|
```python
|
||||||
|
SYSTEM_PROMPT = f"""You are a quantitative researcher mining formulaic alpha factors for stock selection.
|
||||||
|
|
||||||
|
Your goal is to generate novel, predictive factor expressions using the local ProStock DSL. Each factor is a composition of operators applied to raw market features.
|
||||||
|
|
||||||
|
## RAW FEATURES (leaf nodes)
|
||||||
|
{_format_feature_list()}
|
||||||
|
|
||||||
|
## OPERATOR LIBRARY
|
||||||
|
{_format_operator_table()}
|
||||||
|
|
||||||
|
## EXPRESSION SYNTAX RULES
|
||||||
|
1. Expressions use Python-style infix operators: +, -, *, /, **, >, <
|
||||||
|
2. Function calls use snake_case names with comma-separated arguments: ts_mean(close, 20)
|
||||||
|
3. Window sizes and periods are numeric arguments placed last in function calls.
|
||||||
|
4. Valid window sizes are integers, typically in range [2, 250].
|
||||||
|
5. Cross-sectional operators (cs_rank, cs_zscore, cs_demean) operate across all stocks at each time step -- they are crucial for making factors comparable.
|
||||||
|
6. Do NOT use $ prefix for features. Use `close`, `vol`, `amount`, etc. directly.
|
||||||
|
7. `vwap` is not a raw feature; use `amount / vol` if you need it.
|
||||||
|
8. `returns` is not a raw feature; use `close / ts_delay(close, 1) - 1` if you need returns.
|
||||||
|
|
||||||
|
## EXAMPLES OF WELL-FORMED FACTORS
|
||||||
|
- -cs_rank(ts_delta(close, 5))
|
||||||
|
Short-term reversal: rank of 5-day price change, negated.
|
||||||
|
- cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20))
|
||||||
|
Volume surprise: standardized deviation from 20-day mean volume.
|
||||||
|
- cs_rank((close - amount / vol) / (amount / vol))
|
||||||
|
Intraday deviation from VWAP, cross-sectionally ranked.
|
||||||
|
- -ts_corr(vol, close, 10)
|
||||||
|
Negative price-volume correlation over 10 days.
|
||||||
|
- if_(close / ts_delay(close, 1) - 1 > 0, ts_std(close / ts_delay(close, 1) - 1, 10), -ts_std(close / ts_delay(close, 1) - 1, 10))
|
||||||
|
Conditional volatility: positive for up-moves, negative for down-moves.
|
||||||
|
- cs_rank((close - ts_min(low, 20)) / (ts_max(high, 20) - ts_min(low, 20)))
|
||||||
|
Position within 20-day price range, ranked.
|
||||||
|
|
||||||
|
## KEY PRINCIPLES FOR HIGH-QUALITY FACTORS
|
||||||
|
- Always wrap the outermost expression with a cross-sectional operator (cs_rank, cs_zscore) for comparability.
|
||||||
|
- Combine DIFFERENT operator types for novelty (e.g., time-series + cross-sectional + arithmetic).
|
||||||
|
- Use diverse window sizes; avoid always defaulting to 10.
|
||||||
|
- Explore uncommon feature combinations (amount, amount/vol are underused).
|
||||||
|
- Factors with depth 3-7 tend to be best: deep enough to capture non-trivial patterns but not so deep they overfit.
|
||||||
|
- Prefer economically meaningful combinations over random nesting.
|
||||||
|
- IMPORTANT: Avoid operators that are NOT listed above (e.g., Decay, TsLinRegSlope, HMA, DEMA, Resid). If you use them, the factor will be rejected.
|
||||||
|
"""
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: 更新所有输出格式示例**
|
||||||
|
|
||||||
|
在 `build_user_prompt`(约第333行)中,将示例公式替换为本地 DSL:
|
||||||
|
|
||||||
|
```
|
||||||
|
1. momentum_reversal: -cs_rank(ts_delta(close, 5))
|
||||||
|
2. volume_surprise: cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20))
|
||||||
|
```
|
||||||
|
|
||||||
|
在 `build_specialist_prompt`(约第529行)中同步替换:
|
||||||
|
|
||||||
|
```
|
||||||
|
Example: 1. momentum_reversal: -cs_rank(ts_delta(close, 5))
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 5: 运行 prompt_builder 相关测试(若已有)**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv run pytest tests/test_factorminer_prompt.py -v -k prompt
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Task 2: 修改 `src/factorminer/agent/output_parser.py`
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `src/factorminer/agent/output_parser.py`
|
||||||
|
- Test: `tests/test_factorminer_prompt.py`
|
||||||
|
|
||||||
|
**Step 1: 移除 FactorMiner 解析器依赖**
|
||||||
|
|
||||||
|
删除以下导入:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from src.factorminer.core.expression_tree import ExpressionTree
|
||||||
|
from src.factorminer.core.parser import parse, try_parse
|
||||||
|
from src.factorminer.core.types import OperatorType, OPERATOR_REGISTRY
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: 修改 `CandidateFactor`**
|
||||||
|
|
||||||
|
```python
|
||||||
|
@dataclass
|
||||||
|
class CandidateFactor:
|
||||||
|
"""A candidate factor parsed from LLM output.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
name : str
|
||||||
|
Descriptive snake_case name.
|
||||||
|
formula : str
|
||||||
|
DSL formula string.
|
||||||
|
category : str
|
||||||
|
Inferred category based on outermost operators.
|
||||||
|
parse_error : str
|
||||||
|
Error message if formula failed basic validation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
formula: str
|
||||||
|
category: str = "unknown"
|
||||||
|
parse_error: str = ""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_valid(self) -> bool:
|
||||||
|
return not self.parse_error and bool(self.formula.strip())
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: 修改 `_infer_category()`**
|
||||||
|
|
||||||
|
将所有 CamelCase 算子名替换为 snake_case:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def _infer_category(formula: str) -> str:
|
||||||
|
"""Infer a rough category from the outermost operators in the formula."""
|
||||||
|
if any(op in formula for op in ("cs_rank", "cs_zscore", "cs_demean", "cs_neutralize", "cs_winsorize")):
|
||||||
|
if any(op in formula for op in ("ts_corr", "ts_cov")):
|
||||||
|
return "cross_sectional_regression"
|
||||||
|
if any(op in formula for op in ("ts_delta", "ts_delay", "ts_pct_change")):
|
||||||
|
return "cross_sectional_momentum"
|
||||||
|
if any(op in formula for op in ("ts_std", "ts_var", "ts_skew", "ts_kurt")):
|
||||||
|
return "cross_sectional_volatility"
|
||||||
|
if any(op in formula for op in ("ts_mean", "ts_sum", "ts_ema", "ts_wma")):
|
||||||
|
return "cross_sectional_smoothing"
|
||||||
|
return "cross_sectional"
|
||||||
|
if any(op in formula for op in ("ts_corr", "ts_cov")):
|
||||||
|
return "regression"
|
||||||
|
if any(op in formula for op in ("ts_delta", "ts_delay", "ts_pct_change")):
|
||||||
|
return "momentum"
|
||||||
|
if any(op in formula for op in ("ts_std", "ts_var", "ts_skew", "ts_kurt")):
|
||||||
|
return "volatility"
|
||||||
|
if any(op in formula for op in ("if_", "where", ">", "<")):
|
||||||
|
return "conditional"
|
||||||
|
return "general"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: 修改 `_FORMULA_ONLY_PATTERN`**
|
||||||
|
|
||||||
|
本地 DSL 公式可能以 `cs_`, `ts_` 开头,也可能以 `-` 开头(如 `-cs_rank(...)`),或字段名/数字开头:
|
||||||
|
|
||||||
|
```python
|
||||||
|
_FORMULA_ONLY_PATTERN = re.compile(
|
||||||
|
r"^\s*([a-zA-Z_][a-zA-Z0-9_]*\s*\(.*\)|-.*|\d.*)\s*$"
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 5: 修改 `_clean_formula()`**
|
||||||
|
|
||||||
|
移除 `$` 清洗逻辑(当前已不需要替换 `$` 前缀),保留注释、标点和反引号清理:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def _clean_formula(formula: str) -> str:
|
||||||
|
"""Clean up a formula string before parsing."""
|
||||||
|
formula = formula.strip()
|
||||||
|
# Remove trailing comments
|
||||||
|
if " #" in formula:
|
||||||
|
formula = formula[: formula.index(" #")]
|
||||||
|
if " //" in formula:
|
||||||
|
formula = formula[: formula.index(" //")]
|
||||||
|
# Remove trailing punctuation
|
||||||
|
formula = formula.rstrip(";,.")
|
||||||
|
# Remove surrounding backticks
|
||||||
|
formula = formula.strip("`")
|
||||||
|
return formula.strip()
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 6: 重写 `_try_build_candidate()`**
|
||||||
|
|
||||||
|
不再调用 `try_parse(formula)` 或 `ExpressionTree`,仅做基础校验:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def _try_build_candidate(name: str, formula: str) -> CandidateFactor:
|
||||||
|
"""Attempt to validate a formula and build a CandidateFactor."""
|
||||||
|
# Basic validation: parenthesis balance
|
||||||
|
if formula.count("(") != formula.count(")"):
|
||||||
|
return CandidateFactor(
|
||||||
|
name=name,
|
||||||
|
formula=formula,
|
||||||
|
category="unknown",
|
||||||
|
parse_error="括号不匹配",
|
||||||
|
)
|
||||||
|
category = _infer_category(formula)
|
||||||
|
return CandidateFactor(name=name, formula=formula, category=category)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 7: 修改 `_generate_name_from_formula()`**
|
||||||
|
|
||||||
|
正则提取的逻辑调整为适配 snake_case 函数名(第一个括号前的部分):
|
||||||
|
|
||||||
|
```python
|
||||||
|
def _generate_name_from_formula(formula: str, index: int) -> str:
|
||||||
|
"""Generate a descriptive name from a formula."""
|
||||||
|
# Extract the outermost operator (snake_case)
|
||||||
|
m = re.match(r"([a-zA-Z_][a-zA-Z0-9_]*)\s*\(", formula)
|
||||||
|
if m:
|
||||||
|
outer_op = m.group(1).lower()
|
||||||
|
return f"{outer_op}_factor_{index + 1}"
|
||||||
|
# Handle unary minus
|
||||||
|
m = re.match(r"-([a-zA-Z_][a-zA-Z0-9_]*)\s*\(", formula)
|
||||||
|
if m:
|
||||||
|
outer_op = m.group(1).lower()
|
||||||
|
return f"neg_{outer_op}_factor_{index + 1}"
|
||||||
|
return f"factor_{index + 1}"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Task 3: 适配 `src/factorminer/agent/factor_generator.py`
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `src/factorminer/agent/factor_generator.py`
|
||||||
|
|
||||||
|
**Step 1: 更新 retry prompt 的 DSL 规则描述**
|
||||||
|
|
||||||
|
在 `_retry_failed_parses` 方法中(约第199行),将 repair_prompt 中的描述改为本地 DSL 规则:
|
||||||
|
|
||||||
|
```python
|
||||||
|
repair_prompt = (
|
||||||
|
"The following factor formulas failed to parse. "
|
||||||
|
"Fix each one so it uses ONLY valid local DSL operators and features "
|
||||||
|
"from the library. Return them in the same numbered format:\n"
|
||||||
|
"<number>. <name>: <corrected_formula>\n\n"
|
||||||
|
"Broken formulas:\n"
|
||||||
|
+ "\n".join(f" {i+1}. {f}" for i, f in enumerate(failed))
|
||||||
|
+ "\n\nFix all syntax errors, unknown operators, and invalid "
|
||||||
|
"feature names. Use snake_case functions (e.g., ts_mean, cs_rank), "
|
||||||
|
"infix operators (+, -, *, /, >, <), and raw features without $ prefix. "
|
||||||
|
"Every formula must be valid in the local DSL."
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: 确认 `generate_batch` 无需修改**
|
||||||
|
|
||||||
|
因为 `CandidateFactor.is_valid` 已改为基于字符串校验,`generate_batch` 中的过滤逻辑自然兼容。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Task 4: 编写测试 `tests/test_factorminer_prompt.py`
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `tests/test_factorminer_prompt.py`
|
||||||
|
|
||||||
|
**Step 1: 测试 system prompt 使用本地 DSL**
|
||||||
|
|
||||||
|
```python
|
||||||
|
import pytest
|
||||||
|
from src.factorminer.agent.prompt_builder import SYSTEM_PROMPT
|
||||||
|
|
||||||
|
def test_system_prompt_uses_local_dsl():
|
||||||
|
assert "$close" not in SYSTEM_PROMPT
|
||||||
|
assert "CsRank(" not in SYSTEM_PROMPT
|
||||||
|
assert "cs_rank(" in SYSTEM_PROMPT
|
||||||
|
assert "close / ts_delay(close, 1) - 1" in SYSTEM_PROMPT
|
||||||
|
assert "ts_mean(close, 20)" in SYSTEM_PROMPT
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2: 测试 OutputParser 正确提取本地 DSL**
|
||||||
|
|
||||||
|
```python
|
||||||
|
from src.factorminer.agent.output_parser import parse_llm_output, CandidateFactor
|
||||||
|
|
||||||
|
def test_parse_local_dsl_numbered_list():
|
||||||
|
raw = (
|
||||||
|
"1. momentum: -cs_rank(ts_delta(close, 5))\n"
|
||||||
|
"2. volume: cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20))\n"
|
||||||
|
"3. vwap_dev: cs_rank((close - amount / vol) / (amount / vol))\n"
|
||||||
|
)
|
||||||
|
candidates, failed = parse_llm_output(raw)
|
||||||
|
assert len(candidates) == 3
|
||||||
|
assert candidates[0].name == "momentum"
|
||||||
|
assert candidates[0].formula == "-cs_rank(ts_delta(close, 5))"
|
||||||
|
assert candidates[0].is_valid
|
||||||
|
assert candidates[1].name == "volume"
|
||||||
|
assert candidates[1].formula == "cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20))"
|
||||||
|
assert candidates[2].name == "vwap_dev"
|
||||||
|
assert not failed
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3: 测试 formula-only 行**
|
||||||
|
|
||||||
|
```python
|
||||||
|
def test_parse_local_dsl_formula_only():
|
||||||
|
raw = "cs_rank(close / ts_delay(close, 5) - 1)"
|
||||||
|
candidates, failed = parse_llm_output(raw)
|
||||||
|
assert len(candidates) == 1
|
||||||
|
assert candidates[0].formula == "cs_rank(close / ts_delay(close, 5) - 1)"
|
||||||
|
assert not failed
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 4: 测试括号不匹配标记为无效**
|
||||||
|
|
||||||
|
```python
|
||||||
|
def test_parse_invalid_parentheses():
|
||||||
|
candidates, failed = parse_llm_output("1. bad: cs_rank(ts_delta(close, 5)")
|
||||||
|
assert len(candidates) == 1
|
||||||
|
assert not candidates[0].is_valid
|
||||||
|
assert "括号" in candidates[0].parse_error
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 5: 测试分类推断**
|
||||||
|
|
||||||
|
```python
|
||||||
|
def test_infer_category_local_dsl():
|
||||||
|
from src.factorminer.agent.output_parser import _infer_category
|
||||||
|
assert _infer_category("cs_rank(ts_delta(close, 5))") == "cross_sectional_momentum"
|
||||||
|
assert _infer_category("ts_corr(vol, close, 10)") == "regression"
|
||||||
|
assert _infer_category("ts_std(close, 20)") == "volatility"
|
||||||
|
assert _infer_category("if_(close > open, 1, -1)") == "conditional"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 6: 运行测试**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv run pytest tests/test_factorminer_prompt.py -v
|
||||||
|
```
|
||||||
|
|
||||||
|
预期:所有测试通过。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 执行命令汇总
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 安装依赖(若尚未安装)
|
||||||
|
uv pip install -e .
|
||||||
|
|
||||||
|
# 运行新增测试
|
||||||
|
uv run pytest tests/test_factorminer_prompt.py -v
|
||||||
|
|
||||||
|
# 运行 factorminer 相关测试
|
||||||
|
uv run pytest tests/test_factorminer_* -v
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 提交建议
|
||||||
|
|
||||||
|
修改完成后建议拆分为两个 commits:
|
||||||
|
|
||||||
|
1. `refactor(factorminer): rewrite LLM prompts to output local snake_case DSL`
|
||||||
|
2. `test(factorminer): add prompt and output parser tests for local DSL`
|
||||||
@@ -117,7 +117,9 @@ class FactorGenerator:
|
|||||||
max_tokens=self.max_tokens,
|
max_tokens=self.max_tokens,
|
||||||
)
|
)
|
||||||
elapsed = time.monotonic() - t0
|
elapsed = time.monotonic() - t0
|
||||||
logger.info("LLM response received in %.1fs (%d chars)", elapsed, len(raw_output))
|
logger.info(
|
||||||
|
"LLM response received in %.1fs (%d chars)", elapsed, len(raw_output)
|
||||||
|
)
|
||||||
|
|
||||||
# 3. Parse output
|
# 3. Parse output
|
||||||
candidates, failed_lines = parse_llm_output(raw_output)
|
candidates, failed_lines = parse_llm_output(raw_output)
|
||||||
@@ -198,14 +200,15 @@ class FactorGenerator:
|
|||||||
|
|
||||||
repair_prompt = (
|
repair_prompt = (
|
||||||
"The following factor formulas failed to parse. "
|
"The following factor formulas failed to parse. "
|
||||||
"Fix each one so it uses ONLY valid operators and features "
|
"Fix each one so it uses ONLY valid local DSL operators and features "
|
||||||
"from the library. Return them in the same numbered format:\n"
|
"from the library. Return them in the same numbered format:\n"
|
||||||
"<number>. <name>: <corrected_formula>\n\n"
|
"<number>. <name>: <corrected_formula>\n\n"
|
||||||
"Broken formulas:\n"
|
"Broken formulas:\n"
|
||||||
+ "\n".join(f" {i+1}. {f}" for i, f in enumerate(failed))
|
+ "\n".join(f" {i + 1}. {f}" for i, f in enumerate(failed))
|
||||||
+ "\n\nFix all syntax errors, unknown operators, and invalid "
|
+ "\n\nFix all syntax errors, unknown operators, and invalid "
|
||||||
"feature names. Every formula must be a valid nested function "
|
"feature names. Use snake_case functions (e.g., ts_mean, cs_rank), "
|
||||||
"call using only operators from the library."
|
"infix operators (+, -, *, /, >, <), and raw features without $ prefix. "
|
||||||
|
"Every formula must be valid in the local DSL."
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -1,20 +1,16 @@
|
|||||||
"""Parse LLM output into structured CandidateFactor objects.
|
"""Parse LLM output into structured CandidateFactor objects.
|
||||||
|
|
||||||
Handles various output formats from LLMs: numbered lists, JSON,
|
Handles various output formats from LLMs: numbered lists, JSON,
|
||||||
markdown code blocks, and raw text. Validates each formula against
|
markdown code blocks, and raw text. Validates each formula with
|
||||||
the expression tree parser.
|
basic string checks; no FactorMiner-specific parsing is performed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from src.factorminer.core.expression_tree import ExpressionTree
|
|
||||||
from src.factorminer.core.parser import parse, try_parse
|
|
||||||
from src.factorminer.core.types import OperatorType, OPERATOR_REGISTRY
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -29,49 +25,44 @@ class CandidateFactor:
|
|||||||
Descriptive snake_case name.
|
Descriptive snake_case name.
|
||||||
formula : str
|
formula : str
|
||||||
DSL formula string.
|
DSL formula string.
|
||||||
expression_tree : ExpressionTree or None
|
|
||||||
Parsed expression tree (None if parsing failed).
|
|
||||||
category : str
|
category : str
|
||||||
Inferred category based on outermost operators.
|
Inferred category based on outermost operators.
|
||||||
parse_error : str
|
parse_error : str
|
||||||
Error message if formula failed to parse.
|
Error message if formula failed basic validation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
formula: str
|
formula: str
|
||||||
expression_tree: Optional[ExpressionTree] = None
|
|
||||||
category: str = "unknown"
|
category: str = "unknown"
|
||||||
parse_error: str = ""
|
parse_error: str = ""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_valid(self) -> bool:
|
def is_valid(self) -> bool:
|
||||||
return self.expression_tree is not None
|
return not self.parse_error and bool(self.formula.strip())
|
||||||
|
|
||||||
|
|
||||||
def _infer_category(formula: str) -> str:
|
def _infer_category(formula: str) -> str:
|
||||||
"""Infer a rough category from the outermost operators in the formula."""
|
"""Infer a rough category from the outermost operators in the formula."""
|
||||||
lower = formula.lower()
|
if any(
|
||||||
# Check for cross-sectional operators at the top
|
op in formula
|
||||||
if any(op in formula for op in ("CsRank", "CsZScore", "CsDemean", "CsScale", "CsNeutralize", "CsQuantile")):
|
for op in ("cs_rank", "cs_zscore", "cs_demean", "cs_neutralize", "cs_winsorize")
|
||||||
# Look deeper for sub-category
|
):
|
||||||
if any(op in formula for op in ("Corr", "Cov", "Beta", "Resid")):
|
if any(op in formula for op in ("ts_corr", "ts_cov")):
|
||||||
return "cross_sectional_regression"
|
return "cross_sectional_regression"
|
||||||
if any(op in formula for op in ("Delta", "Delay", "Return", "LogReturn")):
|
if any(op in formula for op in ("ts_delta", "ts_delay", "ts_pct_change")):
|
||||||
return "cross_sectional_momentum"
|
return "cross_sectional_momentum"
|
||||||
if any(op in formula for op in ("Std", "Var", "Skew", "Kurt")):
|
if any(op in formula for op in ("ts_std", "ts_var", "ts_skew", "ts_kurt")):
|
||||||
return "cross_sectional_volatility"
|
return "cross_sectional_volatility"
|
||||||
if any(op in formula for op in ("Mean", "Sum", "EMA", "SMA", "WMA", "DEMA", "HMA", "KAMA")):
|
if any(op in formula for op in ("ts_mean", "ts_sum", "ts_ema", "ts_wma")):
|
||||||
return "cross_sectional_smoothing"
|
return "cross_sectional_smoothing"
|
||||||
if any(op in formula for op in ("TsLinReg", "TsLinRegSlope")):
|
|
||||||
return "cross_sectional_trend"
|
|
||||||
return "cross_sectional"
|
return "cross_sectional"
|
||||||
if any(op in formula for op in ("Corr", "Cov", "Beta", "Resid")):
|
if any(op in formula for op in ("ts_corr", "ts_cov")):
|
||||||
return "regression"
|
return "regression"
|
||||||
if any(op in formula for op in ("Delta", "Delay", "Return", "LogReturn")):
|
if any(op in formula for op in ("ts_delta", "ts_delay", "ts_pct_change")):
|
||||||
return "momentum"
|
return "momentum"
|
||||||
if any(op in formula for op in ("Std", "Var", "Skew", "Kurt")):
|
if any(op in formula for op in ("ts_std", "ts_var", "ts_skew", "ts_kurt")):
|
||||||
return "volatility"
|
return "volatility"
|
||||||
if any(op in formula for op in ("IfElse", "Greater", "Less")):
|
if any(op in formula for op in ("if_", "where", ">", "<")):
|
||||||
return "conditional"
|
return "conditional"
|
||||||
return "general"
|
return "general"
|
||||||
|
|
||||||
@@ -95,15 +86,13 @@ _PLAIN_PATTERN = re.compile(
|
|||||||
r"(.+)$" # formula
|
r"(.+)$" # formula
|
||||||
)
|
)
|
||||||
|
|
||||||
# Pattern: just a formula starting with an operator
|
# Pattern: just a formula starting with a function call, unary minus, or number
|
||||||
_FORMULA_ONLY_PATTERN = re.compile(
|
_FORMULA_ONLY_PATTERN = re.compile(
|
||||||
r"^\s*([A-Z][a-zA-Z]*\(.+\))\s*$"
|
r"^\s*([a-zA-Z_][a-zA-Z0-9_]*\s*\(.*\)|-.*|\d.*)\s*$"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Pattern: JSON-like {"name": "...", "formula": "..."}
|
# Pattern: JSON-like {"name": "...", "formula": "..."}
|
||||||
_JSON_PATTERN = re.compile(
|
_JSON_PATTERN = re.compile(r'"name"\s*:\s*"([^"]+)"\s*,\s*"formula"\s*:\s*"([^"]+)"')
|
||||||
r'"name"\s*:\s*"([^"]+)"\s*,\s*"formula"\s*:\s*"([^"]+)"'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _strip_markdown(text: str) -> str:
|
def _strip_markdown(text: str) -> str:
|
||||||
@@ -152,6 +141,8 @@ def parse_llm_output(raw_text: str) -> Tuple[List[CandidateFactor], List[str]]:
|
|||||||
json_matches = _JSON_PATTERN.findall(text)
|
json_matches = _JSON_PATTERN.findall(text)
|
||||||
if json_matches:
|
if json_matches:
|
||||||
for name, formula in json_matches:
|
for name, formula in json_matches:
|
||||||
|
if name is None or formula is None:
|
||||||
|
continue
|
||||||
formula = _clean_formula(formula)
|
formula = _clean_formula(formula)
|
||||||
candidate = _try_build_candidate(name, formula)
|
candidate = _try_build_candidate(name, formula)
|
||||||
if candidate.name not in seen_names:
|
if candidate.name not in seen_names:
|
||||||
@@ -184,6 +175,7 @@ def parse_llm_output(raw_text: str) -> Tuple[List[CandidateFactor], List[str]]:
|
|||||||
m = _FORMULA_ONLY_PATTERN.match(line)
|
m = _FORMULA_ONLY_PATTERN.match(line)
|
||||||
if m:
|
if m:
|
||||||
formula = m.group(1)
|
formula = m.group(1)
|
||||||
|
assert formula is not None
|
||||||
# Generate name from formula
|
# Generate name from formula
|
||||||
name = _generate_name_from_formula(formula, len(candidates))
|
name = _generate_name_from_formula(formula, len(candidates))
|
||||||
|
|
||||||
@@ -222,38 +214,29 @@ def parse_llm_output(raw_text: str) -> Tuple[List[CandidateFactor], List[str]]:
|
|||||||
|
|
||||||
|
|
||||||
def _try_build_candidate(name: str, formula: str) -> CandidateFactor:
|
def _try_build_candidate(name: str, formula: str) -> CandidateFactor:
|
||||||
"""Attempt to parse a formula and build a CandidateFactor."""
|
"""Attempt to validate a formula and build a CandidateFactor."""
|
||||||
tree = try_parse(formula)
|
# Basic validation: parenthesis balance
|
||||||
if tree is not None:
|
if formula.count("(") != formula.count(")"):
|
||||||
category = _infer_category(formula)
|
|
||||||
return CandidateFactor(
|
|
||||||
name=name,
|
|
||||||
formula=tree.to_string(), # canonicalize
|
|
||||||
expression_tree=tree,
|
|
||||||
category=category,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Try to get a useful error message
|
|
||||||
error_msg = ""
|
|
||||||
try:
|
|
||||||
parse(formula)
|
|
||||||
except (SyntaxError, KeyError, ValueError) as e:
|
|
||||||
error_msg = str(e)
|
|
||||||
|
|
||||||
return CandidateFactor(
|
return CandidateFactor(
|
||||||
name=name,
|
name=name,
|
||||||
formula=formula,
|
formula=formula,
|
||||||
expression_tree=None,
|
|
||||||
category="unknown",
|
category="unknown",
|
||||||
parse_error=error_msg,
|
parse_error="括号不匹配",
|
||||||
)
|
)
|
||||||
|
category = _infer_category(formula)
|
||||||
|
return CandidateFactor(name=name, formula=formula, category=category)
|
||||||
|
|
||||||
|
|
||||||
def _generate_name_from_formula(formula: str, index: int) -> str:
|
def _generate_name_from_formula(formula: str, index: int) -> str:
|
||||||
"""Generate a descriptive name from a formula."""
|
"""Generate a descriptive name from a formula."""
|
||||||
# Extract the outermost operator
|
# Extract the outermost operator (snake_case)
|
||||||
m = re.match(r"([A-Z][a-zA-Z]*)\(", formula)
|
m = re.match(r"([a-zA-Z_][a-zA-Z0-9_]*)\s*\(", formula)
|
||||||
if m:
|
if m:
|
||||||
outer_op = m.group(1).lower()
|
outer_op = m.group(1).lower()
|
||||||
return f"{outer_op}_factor_{index + 1}"
|
return f"{outer_op}_factor_{index + 1}"
|
||||||
|
# Handle unary minus
|
||||||
|
m = re.match(r"-([a-zA-Z_][a-zA-Z0-9_]*)\s*\(", formula)
|
||||||
|
if m:
|
||||||
|
outer_op = m.group(1).lower()
|
||||||
|
return f"neg_{outer_op}_factor_{index + 1}"
|
||||||
return f"factor_{index + 1}"
|
return f"factor_{index + 1}"
|
||||||
|
|||||||
@@ -9,68 +9,81 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from src.factorminer.core.types import (
|
|
||||||
FEATURES,
|
LOCAL_OPERATOR_TABLE = {
|
||||||
OPERATOR_REGISTRY,
|
"ARITHMETIC": [
|
||||||
OperatorSpec,
|
("+", "二元", "x + y"),
|
||||||
OperatorType,
|
("-", "二元/一元", "x - y 或 -x"),
|
||||||
)
|
("*", "二元", "x * y"),
|
||||||
|
("/", "二元", "x / y"),
|
||||||
|
("**", "二元", "x ** y (幂运算)"),
|
||||||
|
(">", "二元", "x > y (条件判断,返回 0/1)"),
|
||||||
|
("<", "二元", "x < y (条件判断,返回 0/1)"),
|
||||||
|
("abs(x)", "一元", "绝对值"),
|
||||||
|
("sign(x)", "一元", "符号函数"),
|
||||||
|
("max_(x, y)", "二元", "逐元素最大值"),
|
||||||
|
("min_(x, y)", "二元", "逐元素最小值"),
|
||||||
|
("clip(x, lower, upper)", "一元带参", "截断"),
|
||||||
|
("log(x)", "一元", "自然对数"),
|
||||||
|
("sqrt(x)", "一元", "平方根"),
|
||||||
|
("exp(x)", "一元", "指数函数"),
|
||||||
|
],
|
||||||
|
"TIMESERIES": [
|
||||||
|
("ts_mean(x, window)", "一元+窗口", "滚动均值"),
|
||||||
|
("ts_std(x, window)", "一元+窗口", "滚动标准差"),
|
||||||
|
("ts_var(x, window)", "一元+窗口", "滚动方差"),
|
||||||
|
("ts_max(x, window)", "一元+窗口", "滚动最大值"),
|
||||||
|
("ts_min(x, window)", "一元+窗口", "滚动最小值"),
|
||||||
|
("ts_sum(x, window)", "一元+窗口", "滚动求和"),
|
||||||
|
("ts_delay(x, periods)", "一元+周期", "滞后 N 期"),
|
||||||
|
("ts_delta(x, periods)", "一元+周期", "差分 N 期"),
|
||||||
|
("ts_corr(x, y, window)", "二元+窗口", "滚动相关系数"),
|
||||||
|
("ts_cov(x, y, window)", "二元+窗口", "滚动协方差"),
|
||||||
|
("ts_pct_change(x, periods)", "一元+周期", "N 期百分比变化"),
|
||||||
|
("ts_ema(x, window)", "一元+窗口", "指数移动平均"),
|
||||||
|
("ts_wma(x, window)", "一元+窗口", "加权移动平均"),
|
||||||
|
("ts_skew(x, window)", "一元+窗口", "滚动偏度"),
|
||||||
|
("ts_kurt(x, window)", "一元+窗口", "滚动峰度"),
|
||||||
|
("ts_rank(x, window)", "一元+窗口", "滚动分位排名"),
|
||||||
|
],
|
||||||
|
"CROSS_SECTIONAL": [
|
||||||
|
("cs_rank(x)", "一元", "截面排名(分位数)"),
|
||||||
|
("cs_zscore(x)", "一元", "截面 Z-Score 标准化"),
|
||||||
|
("cs_demean(x)", "一元", "截面去均值"),
|
||||||
|
("cs_neutralize(x, group)", "一元", "行业/市值中性化"),
|
||||||
|
("cs_winsorize(x, lower, upper)", "一元", "截面缩尾处理"),
|
||||||
|
],
|
||||||
|
"CONDITIONAL": [
|
||||||
|
("if_(condition, true_val, false_val)", "三元", "条件选择"),
|
||||||
|
("where(condition, true_val, false_val)", "三元", "if_ 的别名"),
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def _format_operator_table() -> str:
|
def _format_operator_table() -> str:
|
||||||
"""Build a human-readable operator reference table grouped by category."""
|
"""Build a human-readable operator reference table grouped by category."""
|
||||||
grouped: Dict[str, List[OperatorSpec]] = {}
|
lines = []
|
||||||
for spec in OPERATOR_REGISTRY.values():
|
for cat_name, ops in LOCAL_OPERATOR_TABLE.items():
|
||||||
cat = spec.category.name
|
|
||||||
grouped.setdefault(cat, []).append(spec)
|
|
||||||
|
|
||||||
lines: List[str] = []
|
|
||||||
for cat_name in [
|
|
||||||
"ARITHMETIC",
|
|
||||||
"STATISTICAL",
|
|
||||||
"TIMESERIES",
|
|
||||||
"SMOOTHING",
|
|
||||||
"CROSS_SECTIONAL",
|
|
||||||
"REGRESSION",
|
|
||||||
"LOGICAL",
|
|
||||||
"AUTO_INVENTED",
|
|
||||||
]:
|
|
||||||
specs = grouped.get(cat_name, [])
|
|
||||||
if not specs:
|
|
||||||
continue
|
|
||||||
lines.append(f"\n### {cat_name} operators")
|
lines.append(f"\n### {cat_name} operators")
|
||||||
for spec in sorted(specs, key=lambda s: s.name):
|
for op_sig, arity, desc in ops:
|
||||||
params_str = ""
|
lines.append(f"- {op_sig}: {desc} ({arity})")
|
||||||
if spec.param_names:
|
|
||||||
parts = []
|
|
||||||
for pname in spec.param_names:
|
|
||||||
default = spec.param_defaults.get(pname, "")
|
|
||||||
lo, hi = spec.param_ranges.get(pname, (None, None))
|
|
||||||
range_str = f"[{lo}-{hi}]" if lo is not None else ""
|
|
||||||
parts.append(f"{pname}={default}{range_str}")
|
|
||||||
params_str = f" params: {', '.join(parts)}"
|
|
||||||
arity_args = ", ".join([f"expr{i+1}" for i in range(spec.arity)])
|
|
||||||
if spec.param_names:
|
|
||||||
arity_args += ", " + ", ".join(spec.param_names)
|
|
||||||
lines.append(f"- {spec.name}({arity_args}): {spec.description}{params_str}")
|
|
||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
def _format_feature_list() -> str:
|
def _format_feature_list() -> str:
|
||||||
"""Build a description of available raw features."""
|
"""Build a description of available raw features."""
|
||||||
descriptions = {
|
descriptions = {
|
||||||
"$open": "opening price",
|
"open": "开盘价",
|
||||||
"$high": "highest price in the bar",
|
"high": "最高价",
|
||||||
"$low": "lowest price in the bar",
|
"low": "最低价",
|
||||||
"$close": "closing price",
|
"close": "收盘价",
|
||||||
"$volume": "trading volume (shares)",
|
"vol": "成交量(股数)",
|
||||||
"$amt": "trading amount (currency value)",
|
"amount": "成交额(金额)",
|
||||||
"$vwap": "volume-weighted average price",
|
"vwap": "可用 amount / vol 计算",
|
||||||
"$returns": "close-to-close returns",
|
"returns": "可用 close / ts_delay(close, 1) - 1 计算",
|
||||||
}
|
}
|
||||||
lines = []
|
lines = []
|
||||||
for feat in FEATURES:
|
for feat, desc in descriptions.items():
|
||||||
desc = descriptions.get(feat, "")
|
|
||||||
lines.append(f" {feat}: {desc}")
|
lines.append(f" {feat}: {desc}")
|
||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|
||||||
@@ -81,7 +94,7 @@ def _format_feature_list() -> str:
|
|||||||
|
|
||||||
SYSTEM_PROMPT = f"""You are a quantitative researcher mining formulaic alpha factors for stock selection.
|
SYSTEM_PROMPT = f"""You are a quantitative researcher mining formulaic alpha factors for stock selection.
|
||||||
|
|
||||||
Your goal is to generate novel, predictive factor expressions using a tree-structured domain-specific language (DSL). Each factor is a composition of operators applied to raw market features.
|
Your goal is to generate novel, predictive factor expressions using the local ProStock DSL. Each factor is a composition of operators applied to raw market features.
|
||||||
|
|
||||||
## RAW FEATURES (leaf nodes)
|
## RAW FEATURES (leaf nodes)
|
||||||
{_format_feature_list()}
|
{_format_feature_list()}
|
||||||
@@ -90,40 +103,37 @@ Your goal is to generate novel, predictive factor expressions using a tree-struc
|
|||||||
{_format_operator_table()}
|
{_format_operator_table()}
|
||||||
|
|
||||||
## EXPRESSION SYNTAX RULES
|
## EXPRESSION SYNTAX RULES
|
||||||
1. Every expression is a nested function call: Operator(args...)
|
1. Expressions use Python-style infix operators: +, -, *, /, **, >, <
|
||||||
2. Leaf nodes are raw features ($close, $volume, etc.) or numeric constants.
|
2. Function calls use snake_case names with comma-separated arguments: ts_mean(close, 20)
|
||||||
3. Operators are called by name with expression arguments first, then numeric parameters:
|
3. Window sizes and periods are numeric arguments placed last in function calls.
|
||||||
- Mean($close, 20) = 20-day rolling mean of $close
|
4. Valid window sizes are integers, typically in range [2, 250].
|
||||||
- Corr($close, $volume, 10) = 10-day rolling correlation of close and volume
|
5. Cross-sectional operators (cs_rank, cs_zscore, cs_demean) operate across all stocks at each time step -- they are crucial for making factors comparable.
|
||||||
- IfElse(Greater($returns, 0), $volume, Neg($volume)) = conditional
|
6. Do NOT use $ prefix for features. Use `close`, `vol`, `amount`, etc. directly.
|
||||||
4. No infix operators; use Add(x,y) instead of x+y, Sub(x,y) instead of x-y, etc.
|
7. `vwap` is not a raw feature; use `amount / vol` if you need it.
|
||||||
5. Parameters like window sizes are trailing numeric arguments after expression children.
|
8. `returns` is not a raw feature; use `close / ts_delay(close, 1) - 1` if you need returns.
|
||||||
6. Valid window sizes are integers; check each operator's parameter ranges above.
|
|
||||||
7. Cross-sectional operators (CsRank, CsZScore, CsDemean, CsScale, CsNeutralize) operate across all stocks at each time step -- they are crucial for making factors comparable.
|
|
||||||
|
|
||||||
## EXAMPLES OF WELL-FORMED FACTORS
|
## EXAMPLES OF WELL-FORMED FACTORS
|
||||||
- Neg(CsRank(Delta($close, 5)))
|
- -cs_rank(ts_delta(close, 5))
|
||||||
Short-term reversal: rank of 5-day price change, negated.
|
Short-term reversal: rank of 5-day price change, negated.
|
||||||
- CsZScore(Div(Sub($volume, Mean($volume, 20)), Std($volume, 20)))
|
- cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20))
|
||||||
Volume surprise: standardized deviation from 20-day mean volume.
|
Volume surprise: standardized deviation from 20-day mean volume.
|
||||||
- CsRank(Div(Sub($close, $vwap), $vwap))
|
- cs_rank((close - amount / vol) / (amount / vol))
|
||||||
Intraday deviation from VWAP, cross-sectionally ranked.
|
Intraday deviation from VWAP, cross-sectionally ranked.
|
||||||
- Neg(Corr($volume, $close, 10))
|
- -ts_corr(vol, close, 10)
|
||||||
Negative price-volume correlation over 10 days.
|
Negative price-volume correlation over 10 days.
|
||||||
- CsRank(TsLinRegSlope($volume, 20))
|
- if_(close / ts_delay(close, 1) - 1 > 0, ts_std(close / ts_delay(close, 1) - 1, 10), -ts_std(close / ts_delay(close, 1) - 1, 10))
|
||||||
Trend in trading volume over 20 days, ranked.
|
|
||||||
- IfElse(Greater($returns, 0), Std($returns, 10), Neg(Std($returns, 10)))
|
|
||||||
Conditional volatility: positive for up-moves, negative for down-moves.
|
Conditional volatility: positive for up-moves, negative for down-moves.
|
||||||
- CsRank(Div(Sub($close, TsMin($low, 20)), Sub(TsMax($high, 20), TsMin($low, 20))))
|
- cs_rank((close - ts_min(low, 20)) / (ts_max(high, 20) - ts_min(low, 20)))
|
||||||
Position within 20-day price range, ranked.
|
Position within 20-day price range, ranked.
|
||||||
|
|
||||||
## KEY PRINCIPLES FOR HIGH-QUALITY FACTORS
|
## KEY PRINCIPLES FOR HIGH-QUALITY FACTORS
|
||||||
- Always wrap the outermost expression with a cross-sectional operator (CsRank, CsZScore) for comparability.
|
- Always wrap the outermost expression with a cross-sectional operator (cs_rank, cs_zscore) for comparability.
|
||||||
- Combine DIFFERENT operator types for novelty (e.g., time-series + cross-sectional + arithmetic).
|
- Combine DIFFERENT operator types for novelty (e.g., time-series + cross-sectional + arithmetic).
|
||||||
- Use diverse window sizes; avoid always defaulting to 10.
|
- Use diverse window sizes; avoid always defaulting to 10.
|
||||||
- Explore uncommon feature combinations ($amt, $vwap are underused).
|
- Explore uncommon feature combinations (amount, amount/vol are underused).
|
||||||
- Factors with depth 3-7 tend to be best: deep enough to capture non-trivial patterns but not so deep they overfit.
|
- Factors with depth 3-7 tend to be best: deep enough to capture non-trivial patterns but not so deep they overfit.
|
||||||
- Prefer economically meaningful combinations over random nesting.
|
- Prefer economically meaningful combinations over random nesting.
|
||||||
|
- IMPORTANT: Avoid operators that are NOT listed above (e.g., Decay, TsLinRegSlope, HMA, DEMA, Resid). If you use them, the factor will be rejected.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@@ -131,6 +141,7 @@ Your goal is to generate novel, predictive factor expressions using a tree-struc
|
|||||||
# PromptBuilder
|
# PromptBuilder
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def normalize_factor_references(entries: Optional[List[Any]]) -> List[str]:
|
def normalize_factor_references(entries: Optional[List[Any]]) -> List[str]:
|
||||||
"""Convert mixed factor metadata into prompt-safe string references."""
|
"""Convert mixed factor metadata into prompt-safe string references."""
|
||||||
if not entries:
|
if not entries:
|
||||||
@@ -220,13 +231,10 @@ class PromptBuilder:
|
|||||||
lib_size = library_state.get("size", 0)
|
lib_size = library_state.get("size", 0)
|
||||||
target = library_state.get("target_size", 110)
|
target = library_state.get("target_size", 110)
|
||||||
sections.append(
|
sections.append(
|
||||||
f"\n## CURRENT LIBRARY STATUS\n"
|
f"\n## CURRENT LIBRARY STATUS\nLibrary size: {lib_size} / {target} factors."
|
||||||
f"Library size: {lib_size} / {target} factors."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
recent = normalize_factor_references(
|
recent = normalize_factor_references(library_state.get("recent_admissions", []))
|
||||||
library_state.get("recent_admissions", [])
|
|
||||||
)
|
|
||||||
if recent:
|
if recent:
|
||||||
sections.append(
|
sections.append(
|
||||||
"Recently admitted factors:\n"
|
"Recently admitted factors:\n"
|
||||||
@@ -235,10 +243,10 @@ class PromptBuilder:
|
|||||||
|
|
||||||
saturation = library_state.get("domain_saturation", {})
|
saturation = library_state.get("domain_saturation", {})
|
||||||
if saturation:
|
if saturation:
|
||||||
sat_lines = [f" {domain}: {pct:.0%} saturated" for domain, pct in saturation.items()]
|
sat_lines = [
|
||||||
sections.append(
|
f" {domain}: {pct:.0%} saturated" for domain, pct in saturation.items()
|
||||||
"Domain saturation:\n" + "\n".join(sat_lines)
|
]
|
||||||
)
|
sections.append("Domain saturation:\n" + "\n".join(sat_lines))
|
||||||
|
|
||||||
# --- Memory signal: recommended directions ---
|
# --- Memory signal: recommended directions ---
|
||||||
rec_dirs = memory_signal.get("recommended_directions", [])
|
rec_dirs = memory_signal.get("recommended_directions", [])
|
||||||
@@ -266,10 +274,7 @@ class PromptBuilder:
|
|||||||
|
|
||||||
helix_prompt_text = memory_signal.get("prompt_text", "").strip()
|
helix_prompt_text = memory_signal.get("prompt_text", "").strip()
|
||||||
if helix_prompt_text:
|
if helix_prompt_text:
|
||||||
sections.append(
|
sections.append(f"\n## HELIX RETRIEVAL SUMMARY\n{helix_prompt_text}")
|
||||||
"\n## HELIX RETRIEVAL SUMMARY\n"
|
|
||||||
f"{helix_prompt_text}"
|
|
||||||
)
|
|
||||||
|
|
||||||
complementary_patterns = memory_signal.get("complementary_patterns", [])
|
complementary_patterns = memory_signal.get("complementary_patterns", [])
|
||||||
if complementary_patterns:
|
if complementary_patterns:
|
||||||
@@ -330,8 +335,8 @@ class PromptBuilder:
|
|||||||
f"Output exactly {batch_size} factors, one per line.\n"
|
f"Output exactly {batch_size} factors, one per line.\n"
|
||||||
f"Format each line as: <number>. <factor_name>: <formula>\n"
|
f"Format each line as: <number>. <factor_name>: <formula>\n"
|
||||||
f"Example:\n"
|
f"Example:\n"
|
||||||
f"1. momentum_reversal: Neg(CsRank(Delta($close, 5)))\n"
|
f"1. momentum_reversal: -cs_rank(ts_delta(close, 5))\n"
|
||||||
f"2. volume_surprise: CsZScore(Div(Sub($volume, Mean($volume, 20)), Std($volume, 20)))\n"
|
f"2. volume_surprise: cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20))\n"
|
||||||
f"\nRules:\n"
|
f"\nRules:\n"
|
||||||
f"- factor_name: lowercase_with_underscores, descriptive, unique\n"
|
f"- factor_name: lowercase_with_underscores, descriptive, unique\n"
|
||||||
f"- formula: valid DSL expression using ONLY operators and features listed above\n"
|
f"- formula: valid DSL expression using ONLY operators and features listed above\n"
|
||||||
@@ -346,6 +351,7 @@ class PromptBuilder:
|
|||||||
# New specialist/critic/debate prompt builder functions
|
# New specialist/critic/debate prompt builder functions
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def build_specialist_prompt(
|
def build_specialist_prompt(
|
||||||
specialist_name: str,
|
specialist_name: str,
|
||||||
specialist_domain: str,
|
specialist_domain: str,
|
||||||
@@ -413,16 +419,12 @@ def build_specialist_prompt(
|
|||||||
|
|
||||||
# Regime context
|
# Regime context
|
||||||
if regime_context:
|
if regime_context:
|
||||||
sections.append(
|
sections.append(f"\n## CURRENT MARKET REGIME\n{regime_context}")
|
||||||
f"\n## CURRENT MARKET REGIME\n{regime_context}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Library state
|
# Library state
|
||||||
lib_size = library_diagnostics.get("size", 0)
|
lib_size = library_diagnostics.get("size", 0)
|
||||||
target = library_diagnostics.get("target_size", 110)
|
target = library_diagnostics.get("target_size", 110)
|
||||||
sections.append(
|
sections.append(f"\n## LIBRARY STATUS\nCurrent: {lib_size}/{target} factors.")
|
||||||
f"\n## LIBRARY STATUS\nCurrent: {lib_size}/{target} factors."
|
|
||||||
)
|
|
||||||
|
|
||||||
recent = normalize_factor_references(
|
recent = normalize_factor_references(
|
||||||
library_diagnostics.get("recent_admissions", [])
|
library_diagnostics.get("recent_admissions", [])
|
||||||
@@ -435,31 +437,26 @@ def build_specialist_prompt(
|
|||||||
|
|
||||||
saturation = library_diagnostics.get("domain_saturation", {})
|
saturation = library_diagnostics.get("domain_saturation", {})
|
||||||
if saturation:
|
if saturation:
|
||||||
sat_lines = [
|
sat_lines = [f" {d}: {p:.0%} saturated" for d, p in saturation.items()]
|
||||||
f" {d}: {p:.0%} saturated" for d, p in saturation.items()
|
|
||||||
]
|
|
||||||
sections.append("Domain saturation:\n" + "\n".join(sat_lines))
|
sections.append("Domain saturation:\n" + "\n".join(sat_lines))
|
||||||
|
|
||||||
# Memory signal injections
|
# Memory signal injections
|
||||||
rec_dirs = memory_signal.get("recommended_directions", [])
|
rec_dirs = memory_signal.get("recommended_directions", [])
|
||||||
if rec_dirs:
|
if rec_dirs:
|
||||||
sections.append(
|
sections.append(
|
||||||
"\n## RECOMMENDED DIRECTIONS\n"
|
"\n## RECOMMENDED DIRECTIONS\n" + "\n".join(f" * {d}" for d in rec_dirs)
|
||||||
+ "\n".join(f" * {d}" for d in rec_dirs)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
forbidden = memory_signal.get("forbidden_directions", [])
|
forbidden = memory_signal.get("forbidden_directions", [])
|
||||||
if forbidden:
|
if forbidden:
|
||||||
sections.append(
|
sections.append(
|
||||||
"\n## FORBIDDEN DIRECTIONS\n"
|
"\n## FORBIDDEN DIRECTIONS\n" + "\n".join(f" X {d}" for d in forbidden)
|
||||||
+ "\n".join(f" X {d}" for d in forbidden)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
insights = memory_signal.get("strategic_insights", [])
|
insights = memory_signal.get("strategic_insights", [])
|
||||||
if insights:
|
if insights:
|
||||||
sections.append(
|
sections.append(
|
||||||
"\n## STRATEGIC INSIGHTS\n"
|
"\n## STRATEGIC INSIGHTS\n" + "\n".join(f" - {ins}" for ins in insights)
|
||||||
+ "\n".join(f" - {ins}" for ins in insights)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
helix_text = memory_signal.get("prompt_text", "").strip()
|
helix_text = memory_signal.get("prompt_text", "").strip()
|
||||||
@@ -476,8 +473,7 @@ def build_specialist_prompt(
|
|||||||
warn = memory_signal.get("conflict_warnings", [])
|
warn = memory_signal.get("conflict_warnings", [])
|
||||||
if warn:
|
if warn:
|
||||||
sections.append(
|
sections.append(
|
||||||
"\n## SATURATION WARNINGS\n"
|
"\n## SATURATION WARNINGS\n" + "\n".join(f" ! {w}" for w in warn)
|
||||||
+ "\n".join(f" ! {w}" for w in warn)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
gaps = memory_signal.get("semantic_gaps", [])
|
gaps = memory_signal.get("semantic_gaps", [])
|
||||||
@@ -508,8 +504,7 @@ def build_specialist_prompt(
|
|||||||
# Avoid patterns
|
# Avoid patterns
|
||||||
if avoid_patterns:
|
if avoid_patterns:
|
||||||
sections.append(
|
sections.append(
|
||||||
"\n## PATTERNS TO AVOID\n"
|
"\n## PATTERNS TO AVOID\n" + "\n".join(f" X {av}" for av in avoid_patterns)
|
||||||
+ "\n".join(f" X {av}" for av in avoid_patterns)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Few-shot patterns from memory
|
# Few-shot patterns from memory
|
||||||
@@ -526,7 +521,7 @@ def build_specialist_prompt(
|
|||||||
f"\n## OUTPUT FORMAT\n"
|
f"\n## OUTPUT FORMAT\n"
|
||||||
f"Generate exactly {n_proposals} novel factor candidates.\n"
|
f"Generate exactly {n_proposals} novel factor candidates.\n"
|
||||||
f"Format: <number>. <factor_name>: <formula>\n"
|
f"Format: <number>. <factor_name>: <formula>\n"
|
||||||
f"Example: 1. momentum_reversal: Neg(CsRank(Delta($close, 5)))\n"
|
f"Example: 1. momentum_reversal: -cs_rank(ts_delta(close, 5))\n"
|
||||||
f"Rules:\n"
|
f"Rules:\n"
|
||||||
f"- factor_name: lowercase_with_underscores, unique, descriptive\n"
|
f"- factor_name: lowercase_with_underscores, unique, descriptive\n"
|
||||||
f"- formula: valid DSL expression only\n"
|
f"- formula: valid DSL expression only\n"
|
||||||
@@ -581,16 +576,16 @@ def build_critic_scoring_prompt(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if memory_signal:
|
if memory_signal:
|
||||||
sections.append(f"\n## MEMORY CONTEXT (success patterns)\n{memory_signal[:600]}")
|
sections.append(
|
||||||
|
f"\n## MEMORY CONTEXT (success patterns)\n{memory_signal[:600]}"
|
||||||
|
)
|
||||||
|
|
||||||
sections.append("\n## CANDIDATES")
|
sections.append("\n## CANDIDATES")
|
||||||
for c in candidates:
|
for c in candidates:
|
||||||
name = c.get("name", "unknown")
|
name = c.get("name", "unknown")
|
||||||
formula = c.get("formula", "")
|
formula = c.get("formula", "")
|
||||||
specialist = c.get("specialist", "unknown")
|
specialist = c.get("specialist", "unknown")
|
||||||
sections.append(
|
sections.append(f" [{specialist}] {name}: {formula}")
|
||||||
f" [{specialist}] {name}: {formula}"
|
|
||||||
)
|
|
||||||
|
|
||||||
sections.append(
|
sections.append(
|
||||||
"\n## SCORING CRITERIA\n"
|
"\n## SCORING CRITERIA\n"
|
||||||
@@ -647,7 +642,7 @@ def build_debate_synthesis_prompt(
|
|||||||
all_proposals,
|
all_proposals,
|
||||||
key=lambda p: score_map.get(p.get("name", ""), 0.0),
|
key=lambda p: score_map.get(p.get("name", ""), 0.0),
|
||||||
reverse=True,
|
reverse=True,
|
||||||
)[:top_k * 2] # take 2x top_k for synthesis
|
)[: top_k * 2] # take 2x top_k for synthesis
|
||||||
|
|
||||||
sections: List[str] = []
|
sections: List[str] = []
|
||||||
sections.append(
|
sections.append(
|
||||||
@@ -664,9 +659,7 @@ def build_debate_synthesis_prompt(
|
|||||||
formula = p.get("formula", "?")
|
formula = p.get("formula", "?")
|
||||||
specialist = p.get("specialist", "?")
|
specialist = p.get("specialist", "?")
|
||||||
score = score_map.get(name, 0.5)
|
score = score_map.get(name, 0.5)
|
||||||
sections.append(
|
sections.append(f" [{specialist}, score={score:.2f}] {name}: {formula}")
|
||||||
f" [{specialist}, score={score:.2f}] {name}: {formula}"
|
|
||||||
)
|
|
||||||
|
|
||||||
sections.append(
|
sections.append(
|
||||||
f"\n## SELECTION CRITERIA\n"
|
f"\n## SELECTION CRITERIA\n"
|
||||||
|
|||||||
72
tests/test_factorminer_prompt.py
Normal file
72
tests/test_factorminer_prompt.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
"""Tests for local DSL prompt and output parser in Step 3."""
|
||||||
|
|
||||||
|
from src.factorminer.agent.output_parser import (
|
||||||
|
CandidateFactor,
|
||||||
|
_infer_category,
|
||||||
|
parse_llm_output,
|
||||||
|
)
|
||||||
|
from src.factorminer.agent.prompt_builder import SYSTEM_PROMPT
|
||||||
|
|
||||||
|
|
||||||
|
def test_system_prompt_uses_local_dsl():
|
||||||
|
assert "$close" not in SYSTEM_PROMPT
|
||||||
|
assert "CsRank(" not in SYSTEM_PROMPT
|
||||||
|
assert "cs_rank(" in SYSTEM_PROMPT
|
||||||
|
assert "close / ts_delay(close, 1) - 1" in SYSTEM_PROMPT
|
||||||
|
assert "ts_mean(close, 20)" in SYSTEM_PROMPT
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_local_dsl_numbered_list():
|
||||||
|
raw = (
|
||||||
|
"1. momentum: -cs_rank(ts_delta(close, 5))\n"
|
||||||
|
"2. volume: cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20))\n"
|
||||||
|
"3. vwap_dev: cs_rank((close - amount / vol) / (amount / vol))\n"
|
||||||
|
)
|
||||||
|
candidates, failed = parse_llm_output(raw)
|
||||||
|
assert len(candidates) == 3
|
||||||
|
assert candidates[0].name == "momentum"
|
||||||
|
assert candidates[0].formula == "-cs_rank(ts_delta(close, 5))"
|
||||||
|
assert candidates[0].is_valid
|
||||||
|
assert candidates[1].name == "volume"
|
||||||
|
assert (
|
||||||
|
candidates[1].formula == "cs_zscore((vol - ts_mean(vol, 20)) / ts_std(vol, 20))"
|
||||||
|
)
|
||||||
|
assert candidates[2].name == "vwap_dev"
|
||||||
|
assert not failed
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_local_dsl_formula_only():
|
||||||
|
raw = "cs_rank(close / ts_delay(close, 5) - 1)"
|
||||||
|
candidates, failed = parse_llm_output(raw)
|
||||||
|
assert len(candidates) == 1
|
||||||
|
assert candidates[0].formula == "cs_rank(close / ts_delay(close, 5) - 1)"
|
||||||
|
assert not failed
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_invalid_parentheses():
|
||||||
|
candidates, failed = parse_llm_output("1. bad: cs_rank(ts_delta(close, 5)")
|
||||||
|
assert len(candidates) == 1
|
||||||
|
assert not candidates[0].is_valid
|
||||||
|
assert "括号" in candidates[0].parse_error
|
||||||
|
|
||||||
|
|
||||||
|
def test_infer_category_local_dsl():
|
||||||
|
assert _infer_category("cs_rank(ts_delta(close, 5))") == "cross_sectional_momentum"
|
||||||
|
assert _infer_category("ts_corr(vol, close, 10)") == "regression"
|
||||||
|
assert _infer_category("ts_std(close, 20)") == "volatility"
|
||||||
|
assert _infer_category("if_(close > open, 1, -1)") == "conditional"
|
||||||
|
|
||||||
|
|
||||||
|
def test_candidate_factor_is_valid_without_tree():
|
||||||
|
cf = CandidateFactor(name="test", formula="cs_rank(close)")
|
||||||
|
assert cf.is_valid
|
||||||
|
assert cf.category == "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_json_local_dsl():
|
||||||
|
raw = '{"name": "mom", "formula": "cs_rank(close / ts_delay(close, 5) - 1)"}'
|
||||||
|
candidates, failed = parse_llm_output(raw)
|
||||||
|
assert len(candidates) == 1
|
||||||
|
assert candidates[0].name == "mom"
|
||||||
|
assert candidates[0].formula == "cs_rank(close / ts_delay(close, 5) - 1)"
|
||||||
|
assert not failed
|
||||||
Reference in New Issue
Block a user