feat(scripts): 添加因子批量注册脚本
- 新增 register_factors.py,支持通过 name/desc/dsl 自动注册因子 - 自动生成 F_XXX 格式 factor_id - 默认保存到 src/experiment/data/factors.jsonl
This commit is contained in:
244
src/scripts/register_factors.py
Normal file
244
src/scripts/register_factors.py
Normal file
@@ -0,0 +1,244 @@
|
|||||||
|
"""因子批量注册脚本。
|
||||||
|
|
||||||
|
使用 FactorManager 批量注册因子,用户只需提供 name、desc 和表达式,
|
||||||
|
自动生成 factor_id 并保存到 factors.jsonl。
|
||||||
|
|
||||||
|
使用方法:
|
||||||
|
1. 在 FACTORS 列表中添加因子定义
|
||||||
|
2. 运行: uv run python src/scripts/register_factors.py
|
||||||
|
|
||||||
|
示例:
|
||||||
|
FACTORS = [
|
||||||
|
{
|
||||||
|
"name": "mom_5d",
|
||||||
|
"desc": "5日价格动量",
|
||||||
|
"dsl": "cs_rank(close / ts_delay(close, 5) - 1)",
|
||||||
|
"category": "momentum", # 可选扩展字段
|
||||||
|
},
|
||||||
|
...
|
||||||
|
]
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from src.factors.metadata import FactorManager
|
||||||
|
from src.factors.metadata.exceptions import DuplicateFactorError, ValidationError
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# 用户配置区域 - 在这里添加要注册的因子
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
FACTORS: List[Dict[str, Any]] = [
|
||||||
|
# 示例因子,请根据实际需要修改或添加
|
||||||
|
{
|
||||||
|
"name": "mom_5d",
|
||||||
|
"desc": "5日价格动量,收盘价相对于5日前收盘价的涨跌幅进行截面排名",
|
||||||
|
"dsl": "cs_rank(close / ts_delay(close, 5) - 1)",
|
||||||
|
"category": "momentum",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "mom_20d",
|
||||||
|
"desc": "20日价格动量,收盘价相对于20日前收盘价的涨跌幅进行截面排名",
|
||||||
|
"dsl": "cs_rank(close / ts_delay(close, 20) - 1)",
|
||||||
|
"category": "momentum",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "volatility_20d",
|
||||||
|
"desc": "20日价格波动率,收益率的20日滚动标准差",
|
||||||
|
"dsl": "ts_std(ts_delta(close, 1) / ts_delay(close, 1), 20)",
|
||||||
|
"category": "volatility",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "price_ma_ratio",
|
||||||
|
"desc": "价格与20日均线的偏离度",
|
||||||
|
"dsl": "close / ts_mean(close, 20) - 1",
|
||||||
|
"category": "mean_reversion",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "volume_ratio",
|
||||||
|
"desc": "成交量比率,当日成交量相对于20日均量的比值",
|
||||||
|
"dsl": "volume / ts_mean(volume, 20)",
|
||||||
|
"category": "volume",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
# 因子存储路径(默认使用实验目录)
|
||||||
|
OUTPUT_PATH = Path(__file__).parent.parent / "experiment" / "data" / "factors.jsonl"
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# 核心实现
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def get_next_factor_id(filepath: Path) -> str:
|
||||||
|
"""生成下一个 factor_id。
|
||||||
|
|
||||||
|
从现有文件中提取最大序号,生成新的 F_XXX 格式 ID。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filepath: JSONL 文件路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
新的 factor_id,如 "F_001"
|
||||||
|
"""
|
||||||
|
if not filepath.exists():
|
||||||
|
return "F_001"
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(filepath, "r", encoding="utf-8") as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
except Exception:
|
||||||
|
return "F_001"
|
||||||
|
|
||||||
|
max_num = 0
|
||||||
|
pattern = re.compile(r"^F_(\d+)$")
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
data = json.loads(line)
|
||||||
|
factor_id = data.get("factor_id", "")
|
||||||
|
match = pattern.match(factor_id)
|
||||||
|
if match:
|
||||||
|
num = int(match.group(1))
|
||||||
|
max_num = max(max_num, num)
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
continue
|
||||||
|
|
||||||
|
return f"F_{max_num + 1:03d}"
|
||||||
|
|
||||||
|
|
||||||
|
def validate_factor(factor: Dict[str, Any]) -> None:
|
||||||
|
"""验证因子定义是否有效。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
factor: 因子定义字典
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 验证失败时抛出
|
||||||
|
"""
|
||||||
|
required_fields = ["name", "desc", "dsl"]
|
||||||
|
for field in required_fields:
|
||||||
|
if field not in factor or not factor[field]:
|
||||||
|
raise ValueError(f"因子缺少必填字段 '{field}'")
|
||||||
|
|
||||||
|
# 验证 name 格式(只允许字母、数字、下划线)
|
||||||
|
if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", factor["name"]):
|
||||||
|
raise ValueError(
|
||||||
|
f"因子名称 '{factor['name']}' 格式无效,只允许字母、数字、下划线"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def register_factors(
|
||||||
|
factors: List[Dict[str, Any]],
|
||||||
|
output_path: Optional[Path] = None,
|
||||||
|
skip_duplicates: bool = True,
|
||||||
|
) -> Dict[str, List[str]]:
|
||||||
|
"""批量注册因子。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
factors: 因子定义列表
|
||||||
|
output_path: 输出文件路径,默认使用 OUTPUT_PATH
|
||||||
|
skip_duplicates: 遇到重复因子是否跳过而不是报错
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
注册结果统计,包含成功列表和失败列表
|
||||||
|
"""
|
||||||
|
output_path = output_path or OUTPUT_PATH
|
||||||
|
manager = FactorManager(str(output_path))
|
||||||
|
|
||||||
|
results = {
|
||||||
|
"success": [],
|
||||||
|
"failed": [],
|
||||||
|
"skipped": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
for factor in factors:
|
||||||
|
try:
|
||||||
|
# 验证因子定义
|
||||||
|
validate_factor(factor)
|
||||||
|
|
||||||
|
# 检查 name 是否已存在
|
||||||
|
existing = manager.get_factors_by_name(factor["name"])
|
||||||
|
if len(existing) > 0:
|
||||||
|
if skip_duplicates:
|
||||||
|
results["skipped"].append(factor["name"])
|
||||||
|
print(f"[跳过] 因子 '{factor['name']}' 已存在")
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
raise DuplicateFactorError(factor["name"])
|
||||||
|
|
||||||
|
# 生成 factor_id
|
||||||
|
factor_id = get_next_factor_id(output_path)
|
||||||
|
|
||||||
|
# 构建完整的因子记录
|
||||||
|
factor_record = {
|
||||||
|
"factor_id": factor_id,
|
||||||
|
"name": factor["name"],
|
||||||
|
"desc": factor["desc"],
|
||||||
|
"dsl": factor["dsl"],
|
||||||
|
}
|
||||||
|
|
||||||
|
# 添加可选扩展字段
|
||||||
|
for key in ["category", "author", "tags", "notes"]:
|
||||||
|
if key in factor:
|
||||||
|
factor_record[key] = factor[key]
|
||||||
|
|
||||||
|
# 注册因子
|
||||||
|
manager.add_factor(factor_record)
|
||||||
|
results["success"].append(factor["name"])
|
||||||
|
print(f"[成功] {factor_id}: {factor['name']}")
|
||||||
|
|
||||||
|
except DuplicateFactorError as e:
|
||||||
|
results["failed"].append(factor.get("name", "unknown"))
|
||||||
|
print(f"[失败] 因子 '{factor.get('name', 'unknown')}': {e}")
|
||||||
|
|
||||||
|
except (ValidationError, ValueError) as e:
|
||||||
|
results["failed"].append(factor.get("name", "unknown"))
|
||||||
|
print(f"[失败] 因子 '{factor.get('name', 'unknown')}': {e}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
results["failed"].append(factor.get("name", "unknown"))
|
||||||
|
print(f"[错误] 因子 '{factor.get('name', 'unknown')}': {e}")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""主函数。"""
|
||||||
|
print("=" * 60)
|
||||||
|
print("因子批量注册工具")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"目标文件: {OUTPUT_PATH}")
|
||||||
|
print(f"待注册因子数: {len(FACTORS)}")
|
||||||
|
print("-" * 60)
|
||||||
|
|
||||||
|
if not FACTORS:
|
||||||
|
print("[警告] FACTORS 列表为空,请在脚本中配置要注册的因子")
|
||||||
|
return
|
||||||
|
|
||||||
|
results = register_factors(FACTORS)
|
||||||
|
|
||||||
|
print("-" * 60)
|
||||||
|
print("注册完成:")
|
||||||
|
print(f" 成功: {len(results['success'])} 个")
|
||||||
|
print(f" 跳过: {len(results['skipped'])} 个")
|
||||||
|
print(f" 失败: {len(results['failed'])} 个")
|
||||||
|
|
||||||
|
if results["success"]:
|
||||||
|
print(f"\n已注册因子: {', '.join(results['success'])}")
|
||||||
|
if results["skipped"]:
|
||||||
|
print(f"已跳过因子: {', '.join(results['skipped'])}")
|
||||||
|
if results["failed"]:
|
||||||
|
print(f"失败因子: {', '.join(results['failed'])}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user