feat(scripts): 添加因子批量注册脚本
- 新增 register_factors.py,支持通过 name/desc/dsl 自动注册因子 - 自动生成 F_XXX 格式 factor_id - 默认保存到 src/experiment/data/factors.jsonl
This commit is contained in:
244
src/scripts/register_factors.py
Normal file
244
src/scripts/register_factors.py
Normal file
@@ -0,0 +1,244 @@
|
||||
"""因子批量注册脚本。
|
||||
|
||||
使用 FactorManager 批量注册因子,用户只需提供 name、desc 和表达式,
|
||||
自动生成 factor_id 并保存到 factors.jsonl。
|
||||
|
||||
使用方法:
|
||||
1. 在 FACTORS 列表中添加因子定义
|
||||
2. 运行: uv run python src/scripts/register_factors.py
|
||||
|
||||
示例:
|
||||
FACTORS = [
|
||||
{
|
||||
"name": "mom_5d",
|
||||
"desc": "5日价格动量",
|
||||
"dsl": "cs_rank(close / ts_delay(close, 5) - 1)",
|
||||
"category": "momentum", # 可选扩展字段
|
||||
},
|
||||
...
|
||||
]
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from src.factors.metadata import FactorManager
|
||||
from src.factors.metadata.exceptions import DuplicateFactorError, ValidationError
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 用户配置区域 - 在这里添加要注册的因子
|
||||
# ============================================================================
|
||||
|
||||
FACTORS: List[Dict[str, Any]] = [
|
||||
# 示例因子,请根据实际需要修改或添加
|
||||
{
|
||||
"name": "mom_5d",
|
||||
"desc": "5日价格动量,收盘价相对于5日前收盘价的涨跌幅进行截面排名",
|
||||
"dsl": "cs_rank(close / ts_delay(close, 5) - 1)",
|
||||
"category": "momentum",
|
||||
},
|
||||
{
|
||||
"name": "mom_20d",
|
||||
"desc": "20日价格动量,收盘价相对于20日前收盘价的涨跌幅进行截面排名",
|
||||
"dsl": "cs_rank(close / ts_delay(close, 20) - 1)",
|
||||
"category": "momentum",
|
||||
},
|
||||
{
|
||||
"name": "volatility_20d",
|
||||
"desc": "20日价格波动率,收益率的20日滚动标准差",
|
||||
"dsl": "ts_std(ts_delta(close, 1) / ts_delay(close, 1), 20)",
|
||||
"category": "volatility",
|
||||
},
|
||||
{
|
||||
"name": "price_ma_ratio",
|
||||
"desc": "价格与20日均线的偏离度",
|
||||
"dsl": "close / ts_mean(close, 20) - 1",
|
||||
"category": "mean_reversion",
|
||||
},
|
||||
{
|
||||
"name": "volume_ratio",
|
||||
"desc": "成交量比率,当日成交量相对于20日均量的比值",
|
||||
"dsl": "volume / ts_mean(volume, 20)",
|
||||
"category": "volume",
|
||||
},
|
||||
]
|
||||
|
||||
# 因子存储路径(默认使用实验目录)
|
||||
OUTPUT_PATH = Path(__file__).parent.parent / "experiment" / "data" / "factors.jsonl"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 核心实现
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def get_next_factor_id(filepath: Path) -> str:
|
||||
"""生成下一个 factor_id。
|
||||
|
||||
从现有文件中提取最大序号,生成新的 F_XXX 格式 ID。
|
||||
|
||||
Args:
|
||||
filepath: JSONL 文件路径
|
||||
|
||||
Returns:
|
||||
新的 factor_id,如 "F_001"
|
||||
"""
|
||||
if not filepath.exists():
|
||||
return "F_001"
|
||||
|
||||
try:
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
except Exception:
|
||||
return "F_001"
|
||||
|
||||
max_num = 0
|
||||
pattern = re.compile(r"^F_(\d+)$")
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
data = json.loads(line)
|
||||
factor_id = data.get("factor_id", "")
|
||||
match = pattern.match(factor_id)
|
||||
if match:
|
||||
num = int(match.group(1))
|
||||
max_num = max(max_num, num)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
continue
|
||||
|
||||
return f"F_{max_num + 1:03d}"
|
||||
|
||||
|
||||
def validate_factor(factor: Dict[str, Any]) -> None:
|
||||
"""验证因子定义是否有效。
|
||||
|
||||
Args:
|
||||
factor: 因子定义字典
|
||||
|
||||
Raises:
|
||||
ValueError: 验证失败时抛出
|
||||
"""
|
||||
required_fields = ["name", "desc", "dsl"]
|
||||
for field in required_fields:
|
||||
if field not in factor or not factor[field]:
|
||||
raise ValueError(f"因子缺少必填字段 '{field}'")
|
||||
|
||||
# 验证 name 格式(只允许字母、数字、下划线)
|
||||
if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", factor["name"]):
|
||||
raise ValueError(
|
||||
f"因子名称 '{factor['name']}' 格式无效,只允许字母、数字、下划线"
|
||||
)
|
||||
|
||||
|
||||
def register_factors(
|
||||
factors: List[Dict[str, Any]],
|
||||
output_path: Optional[Path] = None,
|
||||
skip_duplicates: bool = True,
|
||||
) -> Dict[str, List[str]]:
|
||||
"""批量注册因子。
|
||||
|
||||
Args:
|
||||
factors: 因子定义列表
|
||||
output_path: 输出文件路径,默认使用 OUTPUT_PATH
|
||||
skip_duplicates: 遇到重复因子是否跳过而不是报错
|
||||
|
||||
Returns:
|
||||
注册结果统计,包含成功列表和失败列表
|
||||
"""
|
||||
output_path = output_path or OUTPUT_PATH
|
||||
manager = FactorManager(str(output_path))
|
||||
|
||||
results = {
|
||||
"success": [],
|
||||
"failed": [],
|
||||
"skipped": [],
|
||||
}
|
||||
|
||||
for factor in factors:
|
||||
try:
|
||||
# 验证因子定义
|
||||
validate_factor(factor)
|
||||
|
||||
# 检查 name 是否已存在
|
||||
existing = manager.get_factors_by_name(factor["name"])
|
||||
if len(existing) > 0:
|
||||
if skip_duplicates:
|
||||
results["skipped"].append(factor["name"])
|
||||
print(f"[跳过] 因子 '{factor['name']}' 已存在")
|
||||
continue
|
||||
else:
|
||||
raise DuplicateFactorError(factor["name"])
|
||||
|
||||
# 生成 factor_id
|
||||
factor_id = get_next_factor_id(output_path)
|
||||
|
||||
# 构建完整的因子记录
|
||||
factor_record = {
|
||||
"factor_id": factor_id,
|
||||
"name": factor["name"],
|
||||
"desc": factor["desc"],
|
||||
"dsl": factor["dsl"],
|
||||
}
|
||||
|
||||
# 添加可选扩展字段
|
||||
for key in ["category", "author", "tags", "notes"]:
|
||||
if key in factor:
|
||||
factor_record[key] = factor[key]
|
||||
|
||||
# 注册因子
|
||||
manager.add_factor(factor_record)
|
||||
results["success"].append(factor["name"])
|
||||
print(f"[成功] {factor_id}: {factor['name']}")
|
||||
|
||||
except DuplicateFactorError as e:
|
||||
results["failed"].append(factor.get("name", "unknown"))
|
||||
print(f"[失败] 因子 '{factor.get('name', 'unknown')}': {e}")
|
||||
|
||||
except (ValidationError, ValueError) as e:
|
||||
results["failed"].append(factor.get("name", "unknown"))
|
||||
print(f"[失败] 因子 '{factor.get('name', 'unknown')}': {e}")
|
||||
|
||||
except Exception as e:
|
||||
results["failed"].append(factor.get("name", "unknown"))
|
||||
print(f"[错误] 因子 '{factor.get('name', 'unknown')}': {e}")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数。"""
|
||||
print("=" * 60)
|
||||
print("因子批量注册工具")
|
||||
print("=" * 60)
|
||||
print(f"目标文件: {OUTPUT_PATH}")
|
||||
print(f"待注册因子数: {len(FACTORS)}")
|
||||
print("-" * 60)
|
||||
|
||||
if not FACTORS:
|
||||
print("[警告] FACTORS 列表为空,请在脚本中配置要注册的因子")
|
||||
return
|
||||
|
||||
results = register_factors(FACTORS)
|
||||
|
||||
print("-" * 60)
|
||||
print("注册完成:")
|
||||
print(f" 成功: {len(results['success'])} 个")
|
||||
print(f" 跳过: {len(results['skipped'])} 个")
|
||||
print(f" 失败: {len(results['failed'])} 个")
|
||||
|
||||
if results["success"]:
|
||||
print(f"\n已注册因子: {', '.join(results['success'])}")
|
||||
if results["skipped"]:
|
||||
print(f"已跳过因子: {', '.join(results['skipped'])}")
|
||||
if results["failed"]:
|
||||
print(f"失败因子: {', '.join(results['failed'])}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user