- 新增 config_models.py: 使用 Pydantic 提供强类型配置校验 - QMTConfig, QMTTerminalConfig, StrategyConfig 等数据模型 - 支持 slots/percentage 两种下单模式 - 兼容旧版配置格式迁移 - 新增 validate_config.py: 配置检测 CLI 工具 - 重构 TradingUnit 和 MultiEngineManager 使用新配置模型 - 新增百分比模式买卖逻辑 (_execute_percentage_buy/sell) - 完善日志记录和错误处理 - 删除 TODO_FIX.md: 清理已完成的缺陷修复任务清单
263 lines
8.5 KiB
Python
263 lines
8.5 KiB
Python
# coding: utf-8
|
||
"""
|
||
QMT 配置数据模型
|
||
|
||
使用 Pydantic 提供强类型配置校验,确保配置在加载时就被验证,
|
||
而不是在运行时才暴露问题。
|
||
"""
|
||
|
||
from pathlib import Path
|
||
from typing import Dict, List, Literal, Optional, Any
|
||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||
|
||
|
||
class ConfigError(Exception):
|
||
"""配置错误异常"""
|
||
|
||
pass
|
||
|
||
|
||
class RedisConfig(BaseModel):
|
||
"""Redis 配置"""
|
||
|
||
host: str = Field(default="localhost", description="Redis 主机地址")
|
||
port: int = Field(default=6379, ge=1, le=65535, description="Redis 端口")
|
||
password: Optional[str] = Field(default=None, description="Redis 密码")
|
||
db: int = Field(default=0, ge=0, description="Redis 数据库编号")
|
||
|
||
|
||
class ExecutionConfig(BaseModel):
|
||
"""交易执行配置"""
|
||
|
||
buy_price_offset: float = Field(default=0.0, description="买入价格偏移")
|
||
sell_price_offset: float = Field(default=0.0, description="卖出价格偏移")
|
||
|
||
|
||
class StrategyConfig(BaseModel):
|
||
"""策略配置"""
|
||
|
||
qmt_id: str = Field(description="关联的 QMT 终端 ID")
|
||
order_mode: Literal["slots", "percentage"] = Field(
|
||
default="slots", description="下单模式: slots(槽位) 或 percentage(百分比)"
|
||
)
|
||
total_slots: Optional[int] = Field(
|
||
default=None, ge=1, description="总槽位数 (slots 模式必需)"
|
||
)
|
||
weight: int = Field(default=1, ge=1, description="资金权重")
|
||
execution: ExecutionConfig = Field(
|
||
default_factory=ExecutionConfig, description="交易执行配置"
|
||
)
|
||
|
||
@model_validator(mode="after")
|
||
def validate_slots_requirement(self):
|
||
"""校验 slots 模式必须有 total_slots"""
|
||
if self.order_mode == "slots" and self.total_slots is None:
|
||
raise ValueError(f"策略 '{self}' 使用 slots 模式,必须配置 total_slots")
|
||
return self
|
||
|
||
|
||
class QMTTerminalConfig(BaseModel):
|
||
"""QMT 终端配置"""
|
||
|
||
qmt_id: str = Field(description="终端唯一标识")
|
||
alias: Optional[str] = Field(default=None, description="终端别名")
|
||
path: str = Field(description="QMT 安装路径")
|
||
account_id: str = Field(description="资金账号")
|
||
account_type: Literal["STOCK", "FUTURE", "CREDIT"] = Field(
|
||
default="STOCK", description="账户类型"
|
||
)
|
||
|
||
@field_validator("path")
|
||
@classmethod
|
||
def validate_path_exists(cls, v: str) -> str:
|
||
"""校验路径存在性"""
|
||
path = Path(v)
|
||
if not path.exists():
|
||
raise ValueError(f"QMT 路径不存在: {v}")
|
||
return v
|
||
|
||
@field_validator("alias", mode="before")
|
||
@classmethod
|
||
def set_default_alias(cls, v: Optional[str], info) -> str:
|
||
"""如果 alias 未设置,使用 qmt_id 作为默认值"""
|
||
if v is None or v == "":
|
||
# 从其他字段获取 qmt_id
|
||
data = info.data
|
||
return data.get("qmt_id", "unknown")
|
||
return v
|
||
|
||
|
||
class AutoReconnectConfig(BaseModel):
|
||
"""自动重连配置"""
|
||
|
||
enabled: bool = Field(default=True, description="是否启用自动重连")
|
||
reconnect_time: str = Field(default="22:00", description="重连时间 (HH:MM)")
|
||
|
||
@field_validator("reconnect_time")
|
||
@classmethod
|
||
def validate_time_format(cls, v: str) -> str:
|
||
"""校验时间格式"""
|
||
import datetime
|
||
|
||
try:
|
||
datetime.datetime.strptime(v, "%H:%M")
|
||
except ValueError:
|
||
raise ValueError(f"时间格式错误: {v},应为 HH:MM 格式")
|
||
return v
|
||
|
||
|
||
class QMTConfig(BaseModel):
|
||
"""QMT 主配置"""
|
||
|
||
qmt_terminals: List[QMTTerminalConfig] = Field(description="QMT 终端列表")
|
||
strategies: Dict[str, StrategyConfig] = Field(
|
||
description="策略配置字典,key 为策略名"
|
||
)
|
||
auto_reconnect: Optional[AutoReconnectConfig] = Field(
|
||
default=None, description="自动重连配置"
|
||
)
|
||
|
||
@model_validator(mode="after")
|
||
def validate_strategy_terminal_refs(self):
|
||
"""校验策略引用的终端是否存在"""
|
||
terminal_ids = {t.qmt_id for t in self.qmt_terminals}
|
||
for name, strat in self.strategies.items():
|
||
if strat.qmt_id not in terminal_ids:
|
||
raise ValueError(
|
||
f"策略 '{name}' 引用了不存在的终端: '{strat.qmt_id}',"
|
||
f"可用终端: {list(terminal_ids)}"
|
||
)
|
||
return self
|
||
|
||
@model_validator(mode="after")
|
||
def validate_at_least_one_terminal(self):
|
||
"""校验至少有一个终端配置"""
|
||
if not self.qmt_terminals:
|
||
raise ValueError("必须配置至少一个 QMT 终端")
|
||
return self
|
||
|
||
@model_validator(mode="after")
|
||
def validate_at_least_one_strategy(self):
|
||
"""校验至少有一个策略配置"""
|
||
if not self.strategies:
|
||
raise ValueError("必须配置至少一个策略")
|
||
return self
|
||
|
||
def get_terminal(self, qmt_id: str) -> Optional[QMTTerminalConfig]:
|
||
"""根据 ID 获取终端配置"""
|
||
for t in self.qmt_terminals:
|
||
if t.qmt_id == qmt_id:
|
||
return t
|
||
return None
|
||
|
||
def get_strategies_by_terminal(self, qmt_id: str) -> List[str]:
|
||
"""获取指定终端关联的所有策略名"""
|
||
return [
|
||
name for name, strat in self.strategies.items() if strat.qmt_id == qmt_id
|
||
]
|
||
|
||
def get_strategy(self, name: str) -> Optional[StrategyConfig]:
|
||
"""获取策略配置"""
|
||
return self.strategies.get(name)
|
||
|
||
|
||
class ConfigLoader:
|
||
"""配置加载器"""
|
||
|
||
# 已知的顶层配置键
|
||
KNOWN_TOP_KEYS = {"qmt_terminals", "strategies", "auto_reconnect", "qmt"}
|
||
|
||
def __init__(self, config_path: str):
|
||
self.config_path = Path(config_path)
|
||
self._raw_data: Optional[Dict[str, Any]] = None
|
||
|
||
def load(self) -> QMTConfig:
|
||
"""
|
||
加载并校验配置
|
||
|
||
Returns:
|
||
QMTConfig: 校验后的配置对象
|
||
|
||
Raises:
|
||
ConfigError: 配置加载或校验失败
|
||
"""
|
||
# 1. 读取文件
|
||
if not self.config_path.exists():
|
||
raise ConfigError(f"配置文件不存在: {self.config_path}")
|
||
|
||
try:
|
||
import json
|
||
|
||
with open(self.config_path, "r", encoding="utf-8") as f:
|
||
self._raw_data = json.load(f)
|
||
except json.JSONDecodeError as e:
|
||
raise ConfigError(f"配置文件 JSON 格式错误: {e}")
|
||
except Exception as e:
|
||
raise ConfigError(f"读取配置文件失败: {e}")
|
||
|
||
# 2. 检查未知键(警告但不阻止)
|
||
unknown_keys = set(self._raw_data.keys()) - self.KNOWN_TOP_KEYS
|
||
if unknown_keys:
|
||
import logging
|
||
|
||
logger = logging.getLogger("QMT_Config")
|
||
logger.warning(f"配置文件中有未知的配置项将被忽略: {unknown_keys}")
|
||
|
||
# 3. 兼容旧版配置格式:将 qmt 转换为 qmt_terminals
|
||
data = self._migrate_legacy_config(self._raw_data)
|
||
|
||
# 4. Pydantic 校验
|
||
try:
|
||
config = QMTConfig.model_validate(data)
|
||
except Exception as e:
|
||
raise ConfigError(f"配置校验失败: {e}")
|
||
|
||
return config
|
||
|
||
def _migrate_legacy_config(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||
"""
|
||
兼容旧版配置格式
|
||
|
||
旧版使用单个 qmt 配置,新版使用 qmt_terminals 列表
|
||
"""
|
||
result = dict(data)
|
||
|
||
# 如果存在旧版 qmt 配置且没有 qmt_terminals
|
||
if "qmt" in result and "qmt_terminals" not in result:
|
||
legacy_qmt = result.pop("qmt")
|
||
# 转换为列表格式
|
||
result["qmt_terminals"] = [
|
||
{
|
||
"qmt_id": "default",
|
||
"alias": "default",
|
||
"path": legacy_qmt.get("path", ""),
|
||
"account_id": legacy_qmt.get("account_id", ""),
|
||
"account_type": legacy_qmt.get("account_type", "STOCK"),
|
||
}
|
||
]
|
||
|
||
# 为策略添加默认的 qmt_id
|
||
if "strategies" in result:
|
||
for name, strat in result["strategies"].items():
|
||
if isinstance(strat, dict) and "qmt_id" not in strat:
|
||
strat["qmt_id"] = "default"
|
||
|
||
return result
|
||
|
||
|
||
def load_config(config_path: str) -> QMTConfig:
|
||
"""
|
||
便捷函数:加载 QMT 配置
|
||
|
||
Args:
|
||
config_path: 配置文件路径
|
||
|
||
Returns:
|
||
QMTConfig: 校验后的配置对象
|
||
|
||
Raises:
|
||
ConfigError: 配置加载或校验失败
|
||
"""
|
||
loader = ConfigLoader(config_path)
|
||
return loader.load()
|