Files
NewStock/qmt/config_models.py
liaozhaorun e88ba5bcf9 feat(qmt): 新增 Pydantic 配置模型并重构引擎架构
- 新增 config_models.py: 使用 Pydantic 提供强类型配置校验
  - QMTConfig, QMTTerminalConfig, StrategyConfig 等数据模型
  - 支持 slots/percentage 两种下单模式
  - 兼容旧版配置格式迁移
- 新增 validate_config.py: 配置检测 CLI 工具
- 重构 TradingUnit 和 MultiEngineManager 使用新配置模型
- 新增百分比模式买卖逻辑 (_execute_percentage_buy/sell)
- 完善日志记录和错误处理
- 删除 TODO_FIX.md: 清理已完成的缺陷修复任务清单
2026-02-25 21:48:22 +08:00

263 lines
8.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# coding: utf-8
"""
QMT 配置数据模型
使用 Pydantic 提供强类型配置校验,确保配置在加载时就被验证,
而不是在运行时才暴露问题。
"""
from pathlib import Path
from typing import Dict, List, Literal, Optional, Any
from pydantic import BaseModel, Field, field_validator, model_validator
class ConfigError(Exception):
"""配置错误异常"""
pass
class RedisConfig(BaseModel):
"""Redis 配置"""
host: str = Field(default="localhost", description="Redis 主机地址")
port: int = Field(default=6379, ge=1, le=65535, description="Redis 端口")
password: Optional[str] = Field(default=None, description="Redis 密码")
db: int = Field(default=0, ge=0, description="Redis 数据库编号")
class ExecutionConfig(BaseModel):
"""交易执行配置"""
buy_price_offset: float = Field(default=0.0, description="买入价格偏移")
sell_price_offset: float = Field(default=0.0, description="卖出价格偏移")
class StrategyConfig(BaseModel):
"""策略配置"""
qmt_id: str = Field(description="关联的 QMT 终端 ID")
order_mode: Literal["slots", "percentage"] = Field(
default="slots", description="下单模式: slots(槽位) 或 percentage(百分比)"
)
total_slots: Optional[int] = Field(
default=None, ge=1, description="总槽位数 (slots 模式必需)"
)
weight: int = Field(default=1, ge=1, description="资金权重")
execution: ExecutionConfig = Field(
default_factory=ExecutionConfig, description="交易执行配置"
)
@model_validator(mode="after")
def validate_slots_requirement(self):
"""校验 slots 模式必须有 total_slots"""
if self.order_mode == "slots" and self.total_slots is None:
raise ValueError(f"策略 '{self}' 使用 slots 模式,必须配置 total_slots")
return self
class QMTTerminalConfig(BaseModel):
"""QMT 终端配置"""
qmt_id: str = Field(description="终端唯一标识")
alias: Optional[str] = Field(default=None, description="终端别名")
path: str = Field(description="QMT 安装路径")
account_id: str = Field(description="资金账号")
account_type: Literal["STOCK", "FUTURE", "CREDIT"] = Field(
default="STOCK", description="账户类型"
)
@field_validator("path")
@classmethod
def validate_path_exists(cls, v: str) -> str:
"""校验路径存在性"""
path = Path(v)
if not path.exists():
raise ValueError(f"QMT 路径不存在: {v}")
return v
@field_validator("alias", mode="before")
@classmethod
def set_default_alias(cls, v: Optional[str], info) -> str:
"""如果 alias 未设置,使用 qmt_id 作为默认值"""
if v is None or v == "":
# 从其他字段获取 qmt_id
data = info.data
return data.get("qmt_id", "unknown")
return v
class AutoReconnectConfig(BaseModel):
"""自动重连配置"""
enabled: bool = Field(default=True, description="是否启用自动重连")
reconnect_time: str = Field(default="22:00", description="重连时间 (HH:MM)")
@field_validator("reconnect_time")
@classmethod
def validate_time_format(cls, v: str) -> str:
"""校验时间格式"""
import datetime
try:
datetime.datetime.strptime(v, "%H:%M")
except ValueError:
raise ValueError(f"时间格式错误: {v},应为 HH:MM 格式")
return v
class QMTConfig(BaseModel):
"""QMT 主配置"""
qmt_terminals: List[QMTTerminalConfig] = Field(description="QMT 终端列表")
strategies: Dict[str, StrategyConfig] = Field(
description="策略配置字典key 为策略名"
)
auto_reconnect: Optional[AutoReconnectConfig] = Field(
default=None, description="自动重连配置"
)
@model_validator(mode="after")
def validate_strategy_terminal_refs(self):
"""校验策略引用的终端是否存在"""
terminal_ids = {t.qmt_id for t in self.qmt_terminals}
for name, strat in self.strategies.items():
if strat.qmt_id not in terminal_ids:
raise ValueError(
f"策略 '{name}' 引用了不存在的终端: '{strat.qmt_id}'"
f"可用终端: {list(terminal_ids)}"
)
return self
@model_validator(mode="after")
def validate_at_least_one_terminal(self):
"""校验至少有一个终端配置"""
if not self.qmt_terminals:
raise ValueError("必须配置至少一个 QMT 终端")
return self
@model_validator(mode="after")
def validate_at_least_one_strategy(self):
"""校验至少有一个策略配置"""
if not self.strategies:
raise ValueError("必须配置至少一个策略")
return self
def get_terminal(self, qmt_id: str) -> Optional[QMTTerminalConfig]:
"""根据 ID 获取终端配置"""
for t in self.qmt_terminals:
if t.qmt_id == qmt_id:
return t
return None
def get_strategies_by_terminal(self, qmt_id: str) -> List[str]:
"""获取指定终端关联的所有策略名"""
return [
name for name, strat in self.strategies.items() if strat.qmt_id == qmt_id
]
def get_strategy(self, name: str) -> Optional[StrategyConfig]:
"""获取策略配置"""
return self.strategies.get(name)
class ConfigLoader:
"""配置加载器"""
# 已知的顶层配置键
KNOWN_TOP_KEYS = {"qmt_terminals", "strategies", "auto_reconnect", "qmt"}
def __init__(self, config_path: str):
self.config_path = Path(config_path)
self._raw_data: Optional[Dict[str, Any]] = None
def load(self) -> QMTConfig:
"""
加载并校验配置
Returns:
QMTConfig: 校验后的配置对象
Raises:
ConfigError: 配置加载或校验失败
"""
# 1. 读取文件
if not self.config_path.exists():
raise ConfigError(f"配置文件不存在: {self.config_path}")
try:
import json
with open(self.config_path, "r", encoding="utf-8") as f:
self._raw_data = json.load(f)
except json.JSONDecodeError as e:
raise ConfigError(f"配置文件 JSON 格式错误: {e}")
except Exception as e:
raise ConfigError(f"读取配置文件失败: {e}")
# 2. 检查未知键(警告但不阻止)
unknown_keys = set(self._raw_data.keys()) - self.KNOWN_TOP_KEYS
if unknown_keys:
import logging
logger = logging.getLogger("QMT_Config")
logger.warning(f"配置文件中有未知的配置项将被忽略: {unknown_keys}")
# 3. 兼容旧版配置格式:将 qmt 转换为 qmt_terminals
data = self._migrate_legacy_config(self._raw_data)
# 4. Pydantic 校验
try:
config = QMTConfig.model_validate(data)
except Exception as e:
raise ConfigError(f"配置校验失败: {e}")
return config
def _migrate_legacy_config(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""
兼容旧版配置格式
旧版使用单个 qmt 配置,新版使用 qmt_terminals 列表
"""
result = dict(data)
# 如果存在旧版 qmt 配置且没有 qmt_terminals
if "qmt" in result and "qmt_terminals" not in result:
legacy_qmt = result.pop("qmt")
# 转换为列表格式
result["qmt_terminals"] = [
{
"qmt_id": "default",
"alias": "default",
"path": legacy_qmt.get("path", ""),
"account_id": legacy_qmt.get("account_id", ""),
"account_type": legacy_qmt.get("account_type", "STOCK"),
}
]
# 为策略添加默认的 qmt_id
if "strategies" in result:
for name, strat in result["strategies"].items():
if isinstance(strat, dict) and "qmt_id" not in strat:
strat["qmt_id"] = "default"
return result
def load_config(config_path: str) -> QMTConfig:
"""
便捷函数:加载 QMT 配置
Args:
config_path: 配置文件路径
Returns:
QMTConfig: 校验后的配置对象
Raises:
ConfigError: 配置加载或校验失败
"""
loader = ConfigLoader(config_path)
return loader.load()