Files
NewStock/qmt/config_models.py

263 lines
8.5 KiB
Python
Raw Normal View History

# coding: utf-8
"""
QMT 配置数据模型
使用 Pydantic 提供强类型配置校验确保配置在加载时就被验证
而不是在运行时才暴露问题
"""
from pathlib import Path
from typing import Dict, List, Literal, Optional, Any
from pydantic import BaseModel, Field, field_validator, model_validator
class ConfigError(Exception):
"""配置错误异常"""
pass
class RedisConfig(BaseModel):
"""Redis 配置"""
host: str = Field(default="localhost", description="Redis 主机地址")
port: int = Field(default=6379, ge=1, le=65535, description="Redis 端口")
password: Optional[str] = Field(default=None, description="Redis 密码")
db: int = Field(default=0, ge=0, description="Redis 数据库编号")
class ExecutionConfig(BaseModel):
"""交易执行配置"""
buy_price_offset: float = Field(default=0.0, description="买入价格偏移")
sell_price_offset: float = Field(default=0.0, description="卖出价格偏移")
class StrategyConfig(BaseModel):
"""策略配置"""
qmt_id: str = Field(description="关联的 QMT 终端 ID")
order_mode: Literal["slots", "percentage"] = Field(
default="slots", description="下单模式: slots(槽位) 或 percentage(百分比)"
)
total_slots: Optional[int] = Field(
default=None, ge=1, description="总槽位数 (slots 模式必需)"
)
weight: int = Field(default=1, ge=1, description="资金权重")
execution: ExecutionConfig = Field(
default_factory=ExecutionConfig, description="交易执行配置"
)
@model_validator(mode="after")
def validate_slots_requirement(self):
"""校验 slots 模式必须有 total_slots"""
if self.order_mode == "slots" and self.total_slots is None:
raise ValueError(f"策略 '{self}' 使用 slots 模式,必须配置 total_slots")
return self
class QMTTerminalConfig(BaseModel):
"""QMT 终端配置"""
qmt_id: str = Field(description="终端唯一标识")
alias: Optional[str] = Field(default=None, description="终端别名")
path: str = Field(description="QMT 安装路径")
account_id: str = Field(description="资金账号")
account_type: Literal["STOCK", "FUTURE", "CREDIT"] = Field(
default="STOCK", description="账户类型"
)
@field_validator("path")
@classmethod
def validate_path_exists(cls, v: str) -> str:
"""校验路径存在性"""
path = Path(v)
if not path.exists():
raise ValueError(f"QMT 路径不存在: {v}")
return v
@field_validator("alias", mode="before")
@classmethod
def set_default_alias(cls, v: Optional[str], info) -> str:
"""如果 alias 未设置,使用 qmt_id 作为默认值"""
if v is None or v == "":
# 从其他字段获取 qmt_id
data = info.data
return data.get("qmt_id", "unknown")
return v
class AutoReconnectConfig(BaseModel):
"""自动重连配置"""
enabled: bool = Field(default=True, description="是否启用自动重连")
reconnect_time: str = Field(default="22:00", description="重连时间 (HH:MM)")
@field_validator("reconnect_time")
@classmethod
def validate_time_format(cls, v: str) -> str:
"""校验时间格式"""
import datetime
try:
datetime.datetime.strptime(v, "%H:%M")
except ValueError:
raise ValueError(f"时间格式错误: {v},应为 HH:MM 格式")
return v
class QMTConfig(BaseModel):
"""QMT 主配置"""
qmt_terminals: List[QMTTerminalConfig] = Field(description="QMT 终端列表")
strategies: Dict[str, StrategyConfig] = Field(
description="策略配置字典key 为策略名"
)
auto_reconnect: Optional[AutoReconnectConfig] = Field(
default=None, description="自动重连配置"
)
@model_validator(mode="after")
def validate_strategy_terminal_refs(self):
"""校验策略引用的终端是否存在"""
terminal_ids = {t.qmt_id for t in self.qmt_terminals}
for name, strat in self.strategies.items():
if strat.qmt_id not in terminal_ids:
raise ValueError(
f"策略 '{name}' 引用了不存在的终端: '{strat.qmt_id}'"
f"可用终端: {list(terminal_ids)}"
)
return self
@model_validator(mode="after")
def validate_at_least_one_terminal(self):
"""校验至少有一个终端配置"""
if not self.qmt_terminals:
raise ValueError("必须配置至少一个 QMT 终端")
return self
@model_validator(mode="after")
def validate_at_least_one_strategy(self):
"""校验至少有一个策略配置"""
if not self.strategies:
raise ValueError("必须配置至少一个策略")
return self
def get_terminal(self, qmt_id: str) -> Optional[QMTTerminalConfig]:
"""根据 ID 获取终端配置"""
for t in self.qmt_terminals:
if t.qmt_id == qmt_id:
return t
return None
def get_strategies_by_terminal(self, qmt_id: str) -> List[str]:
"""获取指定终端关联的所有策略名"""
return [
name for name, strat in self.strategies.items() if strat.qmt_id == qmt_id
]
def get_strategy(self, name: str) -> Optional[StrategyConfig]:
"""获取策略配置"""
return self.strategies.get(name)
class ConfigLoader:
"""配置加载器"""
# 已知的顶层配置键
KNOWN_TOP_KEYS = {"qmt_terminals", "strategies", "auto_reconnect", "qmt"}
def __init__(self, config_path: str):
self.config_path = Path(config_path)
self._raw_data: Optional[Dict[str, Any]] = None
def load(self) -> QMTConfig:
"""
加载并校验配置
Returns:
QMTConfig: 校验后的配置对象
Raises:
ConfigError: 配置加载或校验失败
"""
# 1. 读取文件
if not self.config_path.exists():
raise ConfigError(f"配置文件不存在: {self.config_path}")
try:
import json
with open(self.config_path, "r", encoding="utf-8") as f:
self._raw_data = json.load(f)
except json.JSONDecodeError as e:
raise ConfigError(f"配置文件 JSON 格式错误: {e}")
except Exception as e:
raise ConfigError(f"读取配置文件失败: {e}")
# 2. 检查未知键(警告但不阻止)
unknown_keys = set(self._raw_data.keys()) - self.KNOWN_TOP_KEYS
if unknown_keys:
import logging
logger = logging.getLogger("QMT_Config")
logger.warning(f"配置文件中有未知的配置项将被忽略: {unknown_keys}")
# 3. 兼容旧版配置格式:将 qmt 转换为 qmt_terminals
data = self._migrate_legacy_config(self._raw_data)
# 4. Pydantic 校验
try:
config = QMTConfig.model_validate(data)
except Exception as e:
raise ConfigError(f"配置校验失败: {e}")
return config
def _migrate_legacy_config(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""
兼容旧版配置格式
旧版使用单个 qmt 配置新版使用 qmt_terminals 列表
"""
result = dict(data)
# 如果存在旧版 qmt 配置且没有 qmt_terminals
if "qmt" in result and "qmt_terminals" not in result:
legacy_qmt = result.pop("qmt")
# 转换为列表格式
result["qmt_terminals"] = [
{
"qmt_id": "default",
"alias": "default",
"path": legacy_qmt.get("path", ""),
"account_id": legacy_qmt.get("account_id", ""),
"account_type": legacy_qmt.get("account_type", "STOCK"),
}
]
# 为策略添加默认的 qmt_id
if "strategies" in result:
for name, strat in result["strategies"].items():
if isinstance(strat, dict) and "qmt_id" not in strat:
strat["qmt_id"] = "default"
return result
def load_config(config_path: str) -> QMTConfig:
"""
便捷函数加载 QMT 配置
Args:
config_path: 配置文件路径
Returns:
QMTConfig: 校验后的配置对象
Raises:
ConfigError: 配置加载或校验失败
"""
loader = ConfigLoader(config_path)
return loader.load()