Files
NewQuant/src/execution_simulator.py
2025-07-10 15:07:31 +08:00

420 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# src/execution_simulator.py
from datetime import datetime
from typing import Dict, List, Optional
import pandas as pd
from .core_data import Order, Trade, Bar, PortfolioSnapshot
class ExecutionSimulator:
    """Simulate order execution and keep track of cash and positions.

    Accounting is done on a full-notional cash basis: buying deducts the
    traded value from cash, selling (including opening a short) credits it.
    Realised PnL is therefore already contained in the cash flows and is only
    *reported* on each trade, never added to cash a second time.

    Attributes:
        initial_capital: Capital the simulator was constructed with.
        cash: Current cash balance.
        positions: Signed position per symbol (negative = short).
        average_costs: Average entry price per symbol.
        slippage_rate: Proportional slippage applied to every fill.
        commission_rate: Commission as a fraction of traded value.
        trade_log: Trades produced by ``process_pending_orders``.
        pending_orders: Orders waiting to be matched, keyed by order id.
        indicator_dict: Indicator snapshot captured at the latest opening
            trade; a copy is attached to the following closing trade.
    """

    # Absolute price offset added to (buys) or subtracted from (sells) the bar
    # open before the proportional slippage is applied to market orders.
    # NOTE(review): presumably one minimum price tick - confirm for the
    # instruments being traded.
    MARKET_ORDER_PRICE_OFFSET = 1

    # Maps every tradable order direction onto the side actually executed.
    _EXECUTION_SIDE = {
        "BUY": "BUY",
        "CLOSE_SHORT": "BUY",
        "SELL": "SELL",
        "CLOSE_LONG": "SELL",
    }

    def __init__(
        self,
        initial_capital: float,
        slippage_rate: float = 0.0001,
        commission_rate: float = 0.0002,
        initial_positions: Optional[Dict[str, int]] = None,
        initial_average_costs: Optional[Dict[str, float]] = None,
    ):
        """Create a simulator.

        Args:
            initial_capital: Starting cash.
            slippage_rate: Proportional slippage per fill.
            commission_rate: Commission as a fraction of traded value.
            initial_positions: Optional starting positions per symbol.
            initial_average_costs: Optional average cost per symbol for the
                starting positions.
        """
        self.initial_capital = initial_capital
        self.cash = initial_capital
        # Copy the caller's dicts so later fills never mutate them
        # (consistent with what reset() does).
        self.positions: Dict[str, int] = (
            dict(initial_positions) if initial_positions is not None else {}
        )
        self.average_costs: Dict[str, float] = (
            dict(initial_average_costs)
            if initial_average_costs is not None
            else {}
        )
        # Positions supplied without costs: warn and fall back to 0.0.
        if initial_positions and not initial_average_costs:
            print(
                f"[{datetime.now()}] 警告: 提供了初始持仓但未提供初始平均成本这些持仓的成本默认为0.0。"
            )
            for symbol in self.positions:
                self.average_costs.setdefault(symbol, 0.0)
        self.slippage_rate = slippage_rate
        self.commission_rate = commission_rate
        self.trade_log: List["Trade"] = []
        self.pending_orders: Dict[str, "Order"] = {}
        self._current_time: Optional[datetime] = None
        self.indicator_dict = {}
        print(
            f"模拟器初始化:初始资金={self.initial_capital:.2f}, 滑点率={self.slippage_rate}, 佣金率={self.commission_rate}"
        )
        if self.positions:
            print(f"初始持仓:{self.positions}")
            # Printed so the starting costs can be checked by eye.
            print(f"初始平均成本:{self.average_costs}")

    def update_time(self, current_time: datetime):
        """Record the simulator's current time."""
        self._current_time = current_time

    def get_current_time(self) -> Optional[datetime]:
        """Return the last time set via ``update_time`` (``None`` if unset)."""
        return self._current_time

    def _calculate_fill_price(self, order: "Order", current_bar: "Bar") -> float:
        """Compute the fill price of ``order`` against ``current_bar``.

        Matching rules:
        - Market orders fill at the bar's open, shifted by
          ``MARKET_ORDER_PRICE_OFFSET`` and the proportional slippage
          (buys pay more, sells receive less).
        - Limit orders fill when the bar's low (buy side) or high (sell side)
          touches the limit price; the fill is the limit price adjusted by
          the proportional slippage.

        Returns:
            The fill price, or ``-1.0`` when the order does not fill.
        """
        side = self._EXECUTION_SIDE.get(order.direction)
        fill_price = -1.0  # default: not filled

        if order.price_type == "MARKET":
            base_price = current_bar.open
            if side == "BUY":
                fill_price = (base_price + self.MARKET_ORDER_PRICE_OFFSET) * (
                    1 + self.slippage_rate
                )
            elif side == "SELL":
                fill_price = (base_price - self.MARKET_ORDER_PRICE_OFFSET) * (
                    1 - self.slippage_rate
                )
            else:
                # Unknown direction; should not happen in practice.
                fill_price = base_price
        elif order.price_type == "LIMIT" and order.limit_price is not None:
            limit_price = order.limit_price
            # "Touch" is inclusive: a bar whose extreme equals the limit fills.
            if side == "BUY" and current_bar.low <= limit_price:
                fill_price = limit_price * (1 + self.slippage_rate)
            elif side == "SELL" and current_bar.high >= limit_price:
                fill_price = limit_price * (1 - self.slippage_rate)

        if fill_price <= 0:
            return -1.0
        return fill_price

    def send_order_to_pending(self, order: "Order") -> Optional["Order"]:
        """Queue ``order`` for matching.

        No matching happens here; ``process_pending_orders`` does that.

        Returns:
            The order, or ``None`` if an order with the same id is already
            pending.
        """
        if order.id in self.pending_orders:
            return None
        self.pending_orders[order.id] = order
        return order

    def process_pending_orders(
        self, current_bar: "Bar", indicator_dict: Dict[str, float]
    ):
        """Try to match every pending order for ``current_bar.symbol``.

        Filled trades are appended to ``trade_log``. An opening trade stores a
        snapshot of ``indicator_dict``; a (pure) closing trade gets a copy of
        the stored snapshot attached as ``trade.indicator_dict``.
        """
        # Iterate over a snapshot of the ids: execution removes entries.
        for order_id in list(self.pending_orders):
            order = self.pending_orders.get(order_id)
            if order is None or order.symbol != current_bar.symbol:
                continue
            trade = self._execute_single_order(order, current_bar)
            if not trade:
                continue
            self.trade_log.append(trade)
            if trade.is_open_trade:
                # Snapshot, so later caller-side mutation cannot alter it.
                self.indicator_dict = dict(indicator_dict) if indicator_dict else {}
            elif trade.is_close_trade:
                trade.indicator_dict = self.indicator_dict.copy()

    def _apply_fill(
        self,
        symbol: str,
        side: str,
        volume: int,
        fill_price: float,
        commission: float,
    ) -> Optional[float]:
        """Settle a fill against cash, positions and average costs.

        Args:
            symbol: Traded symbol.
            side: Executed side, ``"BUY"`` or ``"SELL"``.
            volume: Filled volume (positive).
            fill_price: Price of the fill.
            commission: Commission charged for the fill.

        Returns:
            The realised PnL of the closed portion (``0.0`` for pure opening
            fills), or ``None`` when cash is insufficient; in that case no
            state is changed.

        Raises:
            ValueError: If ``side`` is neither ``"BUY"`` nor ``"SELL"``.
        """
        if side not in ("BUY", "SELL"):
            raise ValueError(f"unsupported execution side: {side!r}")

        trade_value = volume * fill_price
        cash = self.cash
        # Work on copies and rebind at the end, so references handed out
        # earlier keep describing the state before this fill.
        positions = self.positions.copy()
        average_costs = self.average_costs.copy()
        held = positions.get(symbol, 0)
        average_cost = average_costs.get(symbol, 0.0)
        realized_pnl = 0.0

        if side == "BUY":
            if held >= 0:
                # Opening / adding to a long: needs the full notional.
                required_cash = trade_value + commission
                if cash < required_cash:
                    return None
                cash -= required_cash
                new_volume = held + volume
                average_costs[symbol] = (
                    (average_cost * held + trade_value) / new_volume
                    if new_volume > 0
                    else 0.0
                )
                positions[symbol] = new_volume
            else:
                # Covering a short. Only the covered part realises PnL; the
                # PnL itself is already contained in the cash flows.
                closed_volume = min(volume, -held)
                realized_pnl = (average_cost - fill_price) * closed_volume
                cash -= commission
                cash -= trade_value
                remaining = held + volume
                if remaining == 0:
                    positions.pop(symbol, None)
                    average_costs.pop(symbol, None)
                else:
                    positions[symbol] = remaining
                    if remaining > 0:
                        # Reversed into a long opened at this fill.
                        average_costs[symbol] = fill_price
        else:
            if held <= 0:
                # Opening / adding to a short: only commission is required,
                # the sale proceeds are credited to cash.
                if cash < commission:
                    return None
                cash -= commission
                cash += trade_value
                short_volume = abs(held)
                new_volume = short_volume + volume
                average_costs[symbol] = (
                    (average_cost * short_volume + trade_value) / new_volume
                    if new_volume > 0
                    else 0.0
                )
                positions[symbol] = -new_volume
            else:
                # Selling out of a long; see the short-cover branch above.
                closed_volume = min(volume, held)
                realized_pnl = (fill_price - average_cost) * closed_volume
                cash -= commission
                cash += trade_value
                remaining = held - volume
                if remaining == 0:
                    positions.pop(symbol, None)
                    average_costs.pop(symbol, None)
                else:
                    positions[symbol] = remaining
                    if remaining < 0:
                        # Reversed into a short opened at this fill.
                        average_costs[symbol] = fill_price

        self.cash = cash
        self.positions = positions
        self.average_costs = average_costs
        return realized_pnl

    def _execute_single_order(
        self, order: "Order", current_bar: "Bar"
    ) -> Optional["Trade"]:
        """Try to execute one order and update cash/positions.

        Called by ``process_pending_orders`` and
        ``force_close_all_positions_for_symbol``.

        Returns:
            The resulting trade, or ``None`` when the order was a cancel
            request, had an unknown direction, did not fill on this bar
            (it then stays pending) or was rejected for insufficient cash
            (it is then dropped from the pending queue).
        """
        if order.direction == "CANCEL":
            self.cancel_order(order.id)
            return None

        side = self._EXECUTION_SIDE.get(order.direction)
        if side is None:
            # Validate before pricing so unknown orders never linger.
            print(
                f"[{current_bar.datetime}] 模拟器: 收到未知订单方向 {order.direction} for Order ID: {order.id}. 订单未处理。"
            )
            self.pending_orders.pop(order.id, None)
            return None

        fill_price = self._calculate_fill_price(order, current_bar)
        if fill_price <= 0:
            return None  # not filled on this bar; order stays pending

        symbol = order.symbol
        volume = order.volume
        commission = volume * fill_price * self.commission_rate
        position_before = self.positions.get(symbol, 0)

        # A trade closes when it is an explicit close order, or a BUY/SELL
        # against an opposite existing position.
        is_close = (
            order.direction in ("CLOSE_LONG", "CLOSE_SHORT")
            or (order.direction == "BUY" and position_before < 0)
            or (order.direction == "SELL" and position_before > 0)
        )
        # A BUY/SELL opens when it adds to (or starts) a position on its own
        # side, or is large enough to flip the opposite position.
        is_open = False
        if order.direction == "BUY":
            is_open = position_before >= 0 or position_before + volume > 0
        elif order.direction == "SELL":
            is_open = position_before <= 0 or position_before - volume < 0

        # NOTE(review): a CLOSE_* order without an opposite position to close
        # is settled like a plain BUY/SELL and therefore opens a position.
        realized_pnl = self._apply_fill(
            symbol, side, volume, fill_price, commission
        )
        if realized_pnl is None:
            # Insufficient cash: the order is rejected and dropped.
            self.pending_orders.pop(order.id, None)
            return None

        executed_trade = Trade(
            order_id=order.id,
            fill_time=current_bar.datetime,
            symbol=symbol,
            direction=order.direction,
            volume=volume,
            price=fill_price,
            commission=commission,
            cash_after_trade=self.cash,
            positions_after_trade=self.positions.copy(),
            realized_pnl=realized_pnl,
            is_open_trade=is_open,
            is_close_trade=is_close,
        )
        self.pending_orders.pop(order.id, None)
        return executed_trade

    def cancel_order(self, order_id: str) -> bool:
        """Remove a pending order; return whether it was pending."""
        if order_id in self.pending_orders:
            del self.pending_orders[order_id]
            return True
        return False

    def force_close_all_positions_for_symbol(
        self, symbol_to_close: str, closing_bar: "Bar"
    ) -> List["Trade"]:
        """Flatten the position in ``symbol_to_close`` with a market order.

        The order is executed immediately against ``closing_bar`` instead of
        going through the pending queue. The resulting trade is returned to
        the caller; this method does not append it to ``trade_log``.

        Returns:
            A list with the closing trade, or an empty list if there was no
            position or the close failed.
        """
        closed_trades: List["Trade"] = []
        held = self.positions.get(symbol_to_close, 0)
        if held == 0:
            return closed_trades

        close_order = Order(
            id=f"FORCE_CLOSE_{symbol_to_close}_{closing_bar.datetime.strftime('%Y%m%d%H%M%S%f')}",
            symbol=symbol_to_close,
            direction="CLOSE_LONG" if held > 0 else "CLOSE_SHORT",
            volume=abs(held),
            price_type="MARKET",
            limit_price=None,
            submitted_time=closing_bar.datetime,
        )
        # Execute directly so the forced close fills on this bar.
        trade = self._execute_single_order(close_order, closing_bar)
        if trade:
            closed_trades.append(trade)
        else:
            print(
                f"[{closing_bar.datetime}] 警告: 强制平仓 {symbol_to_close} 失败!"
            )
        return closed_trades

    def cancel_all_pending_orders_for_symbol(self, symbol_to_cancel: str) -> int:
        """Cancel every pending order for a symbol; return how many."""
        matching_ids = [
            order_id
            for order_id, order in self.pending_orders.items()
            if order.symbol == symbol_to_cancel
        ]
        return sum(1 for order_id in matching_ids if self.cancel_order(order_id))

    def get_pending_orders(self) -> Dict[str, "Order"]:
        """Return a shallow copy of the pending-order mapping."""
        return self.pending_orders.copy()

    def get_portfolio_value(self, current_bar: "Bar") -> float:
        """Return cash plus the position in ``current_bar.symbol``.

        The position is valued at the bar's open. Positions in other symbols
        cannot be priced from this bar: a warning is printed and they are
        left out of the total.
        """
        total_value = self.cash
        for symbol, quantity in self.positions.items():
            if symbol == current_bar.symbol:
                total_value += quantity * current_bar.open
            else:
                print(
                    f"[{current_bar.datetime}] 警告持仓中存在非当前K线合约 {symbol},无法准确计算其市值。"
                )
        return total_value

    def get_current_positions(self) -> Dict[str, int]:
        """Return a copy of the current positions."""
        return self.positions.copy()

    def get_trade_history(self) -> List["Trade"]:
        """Return a copy of the trade log."""
        return self.trade_log.copy()

    def reset(
        self,
        new_initial_capital: Optional[float] = None,
        new_initial_positions: Optional[Dict[str, int]] = None,
        new_initial_average_costs: Optional[Dict[str, float]] = None,
    ) -> None:
        """Reset all mutable state.

        Args:
            new_initial_capital: Cash to start from; defaults to the
                ``initial_capital`` given at construction.
            new_initial_positions: Optional starting positions.
            new_initial_average_costs: Optional average costs for them.
        """
        print("ExecutionSimulator: 重置状态。")
        self.cash = (
            new_initial_capital
            if new_initial_capital is not None
            else self.initial_capital
        )
        self.positions = (
            new_initial_positions.copy() if new_initial_positions is not None else {}
        )
        self.average_costs = (
            new_initial_average_costs.copy()
            if new_initial_average_costs is not None
            else {}
        )
        if self.positions and not new_initial_average_costs:
            print(
                f"[{datetime.now()}] 警告: 重置时提供了初始持仓但未提供初始平均成本这些持仓的成本默认为0.0。"
            )
            for symbol in self.positions:
                self.average_costs.setdefault(symbol, 0.0)
        self.trade_log = []
        self.pending_orders = {}
        self._current_time = None
        # Drop the indicator snapshot of the previous run as well.
        self.indicator_dict = {}

    def get_average_position_price(self, symbol: str) -> Optional[float]:
        """Return the average cost of an open position, else ``None``."""
        if self.positions.get(symbol, 0) != 0:
            return self.average_costs.get(symbol)
        return None