# src/execution_simulator.py
from datetime import datetime
from typing import Dict, List, Optional

import pandas as pd

from .core_data import Order, Trade, Bar, PortfolioSnapshot


class ExecutionSimulator:
    """Simulate order execution and track account cash, positions and average costs.

    Orders are queued with ``send_order_to_pending`` and matched bar by bar in
    ``process_pending_orders``. Positions are signed quantities per symbol
    (positive = long, negative = short).

    Cash model: every buy debits ``volume * fill_price + commission`` and every
    sell credits ``volume * fill_price - commission``, for longs and shorts
    alike (opening a short credits the sale proceeds). Realized profit and loss
    therefore reaches cash through the price difference alone; the
    ``realized_pnl`` reported on each ``Trade`` is informational and is never
    added to cash a second time.
    """

    def __init__(
        self,
        initial_capital: float,
        slippage_rate: float = 0.0001,
        commission_rate: float = 0.0002,
        initial_positions: Optional[Dict[str, int]] = None,
        initial_average_costs: Optional[Dict[str, float]] = None,
        market_price_offset: float = 1.0,
    ):
        """Create a simulator with the given starting cash and optional holdings.

        Args:
            initial_capital: Starting cash balance.
            slippage_rate: Fractional slippage applied to every fill price.
            commission_rate: Commission charged as a fraction of traded value.
            initial_positions: Optional symbol -> signed quantity map. The dict
                is copied; the caller's object is never mutated.
            initial_average_costs: Optional symbol -> average entry price map.
                If positions are given without any costs, a warning is printed
                and every position's cost defaults to 0.0.
            market_price_offset: Absolute price offset added to (buys) or
                subtracted from (sells) the bar's open price for market orders,
                before percentage slippage is applied. Defaults to 1.0, the
                value that was previously hard-coded.
        """
        self.initial_capital = initial_capital
        self.cash = initial_capital
        self._load_holdings(initial_positions, initial_average_costs)
        self.slippage_rate = slippage_rate
        self.commission_rate = commission_rate
        self.market_price_offset = market_price_offset
        self.trade_log: List[Trade] = []
        self.pending_orders: Dict[str, Order] = {}
        self._current_time: Optional[datetime] = None
        # Indicator snapshot taken at the most recent opening trade; attached
        # to the next pure closing trade in process_pending_orders.
        self.indicator_dict: Dict[str, float] = {}
        print(
            f"模拟器初始化:初始资金={self.initial_capital:.2f}, 滑点率={self.slippage_rate}, 佣金率={self.commission_rate}"
        )
        if self.positions:
            print(f"初始持仓:{self.positions}")
            print(f"初始平均成本:{self.average_costs}")

    def _load_holdings(
        self,
        positions: Optional[Dict[str, int]],
        average_costs: Optional[Dict[str, float]],
        warning_context: str = "",
    ) -> None:
        """Install copies of the starting positions and average costs.

        The caller's dicts are copied, never mutated. If positions are given
        without average costs, a warning is printed and each position's cost
        defaults to 0.0. ``warning_context`` is inserted into that warning text
        (``"重置时"`` when called from ``reset``).
        """
        self.positions: Dict[str, int] = (
            dict(positions) if positions is not None else {}
        )
        self.average_costs: Dict[str, float] = (
            dict(average_costs) if average_costs is not None else {}
        )
        if self.positions and not average_costs:
            print(
                f"[{datetime.now()}] 警告: {warning_context}提供了初始持仓但未提供初始平均成本,这些持仓的成本默认为0.0。"
            )
            for symbol in self.positions:
                self.average_costs.setdefault(symbol, 0.0)

    def update_time(self, current_time: datetime):
        """Record the simulator's current (simulated) time."""
        self._current_time = current_time

    def get_current_time(self) -> Optional[datetime]:
        """Return the last time passed to ``update_time``, or None if never set."""
        return self._current_time

    def _calculate_fill_price(self, order: Order, current_bar: Bar) -> float:
        """Return the fill price for ``order`` on ``current_bar``, or -1.0 if it does not fill.

        Matching rules:
        - MARKET: based on the bar's open. Buys (BUY / CLOSE_SHORT) pay
          ``(open + market_price_offset) * (1 + slippage_rate)``. Sells
          (SELL / CLOSE_LONG) receive
          ``(open - market_price_offset) * (1 - slippage_rate)``.
        - LIMIT: a buy fills when the bar's low is at or below the limit price;
          a sell fills when the bar's high is at or above it. The fill is the
          limit price adjusted by slippage against the trader.
        Any other price type, or a LIMIT order without ``limit_price``, does
        not fill.
        """
        is_buy = order.direction in ("BUY", "CLOSE_SHORT")
        is_sell = order.direction in ("SELL", "CLOSE_LONG")
        fill_price = -1.0

        if order.price_type == "MARKET":
            base_price = current_bar.open
            if is_buy:
                fill_price = (base_price + self.market_price_offset) * (
                    1 + self.slippage_rate
                )
            elif is_sell:
                fill_price = (base_price - self.market_price_offset) * (
                    1 - self.slippage_rate
                )
            else:
                # Unknown direction; rejected later by _execute_single_order.
                fill_price = base_price
        elif order.price_type == "LIMIT" and order.limit_price is not None:
            limit_price = order.limit_price
            # "Touching" the limit counts as a fill, hence <= / >=.
            if is_buy and current_bar.low <= limit_price:
                fill_price = limit_price * (1 + self.slippage_rate)
            elif is_sell and current_bar.high >= limit_price:
                fill_price = limit_price * (1 - self.slippage_rate)

        # Non-positive prices (e.g. offset larger than the open) are treated as no fill.
        return fill_price if fill_price > 0 else -1.0

    def send_order_to_pending(self, order: Order) -> Optional[Order]:
        """Queue ``order`` for matching; return it, or None if its id is already queued.

        No matching happens here; ``process_pending_orders`` does that.
        """
        if order.id in self.pending_orders:
            return None
        self.pending_orders[order.id] = order
        return order

    def process_pending_orders(
        self, current_bar: Bar, indicator_dict: Dict[str, float]
    ):
        """Try to match every pending order for ``current_bar.symbol``; call once per bar.

        Filled trades are appended to ``trade_log``. When a trade opens (or adds
        to) a position, a copy of ``indicator_dict`` is remembered. When a pure
        closing trade fills, that remembered snapshot is attached to it as
        ``trade.indicator_dict``. A reversing trade is handled as an open trade.
        """
        # Iterate over a snapshot of ids: executing an order removes it from the dict.
        for order_id in list(self.pending_orders):
            order = self.pending_orders.get(order_id)
            if order is None or order.symbol != current_bar.symbol:
                continue
            trade = self._execute_single_order(order, current_bar)
            if not trade:
                continue
            self.trade_log.append(trade)
            if trade.is_open_trade:
                # Copy so later mutation of the caller's dict cannot alter the snapshot.
                self.indicator_dict = (
                    dict(indicator_dict) if indicator_dict is not None else {}
                )
            elif trade.is_close_trade:
                trade.indicator_dict = self.indicator_dict.copy()

    def _execute_single_order(
        self, order: Order, current_bar: Bar
    ) -> Optional[Trade]:
        """Try to fill one order on ``current_bar`` and book its cash/position effects.

        Called by ``process_pending_orders`` and
        ``force_close_all_positions_for_symbol``.

        Returns:
            The resulting ``Trade``, or None when nothing filled. An order whose
            price was not reached stays in ``pending_orders``. Orders that are
            invalid (unknown direction, non-positive volume, close order without
            a matching position) or unaffordable are removed from it.

        CLOSE_LONG / CLOSE_SHORT orders only ever reduce an existing position:
        their volume is clamped to the held size, so they never flip sides.
        Plain BUY / SELL orders may close and reverse a position in one trade;
        ``realized_pnl`` then covers only the closed portion.
        """
        if order.direction == "CANCEL":
            # Removes this order's own id from the queue.
            # NOTE(review): confirm a CANCEL order is not meant to target another order id.
            self.cancel_order(order.id)
            return None

        # side = +1 for buys (BUY / CLOSE_SHORT), -1 for sells (SELL / CLOSE_LONG).
        if order.direction in ("BUY", "CLOSE_SHORT"):
            side = 1
        elif order.direction in ("SELL", "CLOSE_LONG"):
            side = -1
        else:
            print(
                f"[{current_bar.datetime}] 模拟器: 收到未知订单方向 {order.direction} for Order ID: {order.id}. 订单未处理。"
            )
            self.pending_orders.pop(order.id, None)
            return None

        symbol = order.symbol
        current_position = self.positions.get(symbol, 0)
        current_average_cost = self.average_costs.get(symbol, 0.0)
        is_explicit_close = order.direction in ("CLOSE_LONG", "CLOSE_SHORT")

        volume = order.volume
        if is_explicit_close:
            # Quantity held on the side this order closes (long for CLOSE_LONG,
            # short for CLOSE_SHORT); <= 0 means there is nothing to close.
            closable = (
                current_position
                if order.direction == "CLOSE_LONG"
                else -current_position
            )
            if closable <= 0:
                self.pending_orders.pop(order.id, None)
                return None
            volume = min(volume, closable)
        if volume <= 0:
            self.pending_orders.pop(order.id, None)
            return None

        fill_price = self._calculate_fill_price(order, current_bar)
        if fill_price <= 0:
            return None  # price not reached on this bar; order stays pending

        trade_value = volume * fill_price
        commission = trade_value * self.commission_rate
        new_position = current_position + side * volume

        # Close: reduces an existing position on the opposite side of this trade.
        is_close = is_explicit_close or side * current_position < 0
        # Open: a BUY/SELL that leaves a position on this trade's side (includes reversals).
        is_open = not is_explicit_close and side * new_position > 0

        if not is_close:
            # Opening/adding: longs must afford value + commission; shorts only the commission.
            required_cash = trade_value + commission if side > 0 else commission
            if self.cash < required_cash:
                self.pending_orders.pop(order.id, None)
                return None

        positions = dict(self.positions)
        average_costs = dict(self.average_costs)
        realized_pnl = 0.0

        if is_close:
            closed_volume = min(volume, abs(current_position))
            if current_position > 0:
                realized_pnl = (fill_price - current_average_cost) * closed_volume
            else:
                realized_pnl = (current_average_cost - fill_price) * closed_volume
            if new_position == 0:
                positions.pop(symbol, None)
                average_costs.pop(symbol, None)
            else:
                positions[symbol] = new_position
                if (new_position > 0) != (current_position > 0):
                    # Reversed through flat: the remainder was opened at this fill.
                    average_costs[symbol] = fill_price
        else:
            held_volume = abs(current_position)
            total_volume = held_volume + volume
            average_costs[symbol] = (
                current_average_cost * held_volume + fill_price * volume
            ) / total_volume
            positions[symbol] = new_position

        # Buys debit and sells credit the gross value; commission is always paid.
        # Realized PnL is already contained in these cash flows.
        self.cash = self.cash - side * trade_value - commission
        self.positions = positions
        self.average_costs = average_costs

        executed_trade = Trade(
            order_id=order.id,
            fill_time=current_bar.datetime,
            symbol=symbol,
            direction=order.direction,
            volume=volume,
            price=fill_price,
            commission=commission,
            cash_after_trade=self.cash,
            positions_after_trade=self.positions.copy(),
            realized_pnl=realized_pnl,
            is_open_trade=is_open,
            is_close_trade=is_close,
        )
        self.pending_orders.pop(order.id, None)
        return executed_trade

    def cancel_order(self, order_id: str) -> bool:
        """Remove a pending order; return True if it was pending."""
        return self.pending_orders.pop(order_id, None) is not None

    def force_close_all_positions_for_symbol(
        self, symbol_to_close: str, closing_bar: Bar
    ) -> List[Trade]:
        """Immediately close the whole position in ``symbol_to_close`` at market.

        The closing order bypasses the pending queue and is executed against
        ``closing_bar``. Returns the resulting trades (empty if there was no
        position or the close failed, in which case a warning is printed).
        """
        closed_trades: List[Trade] = []
        volume_to_close = self.positions.get(symbol_to_close, 0)
        if volume_to_close == 0:
            return closed_trades

        direction = "CLOSE_LONG" if volume_to_close > 0 else "CLOSE_SHORT"
        rollover_order = Order(
            id=f"FORCE_CLOSE_{symbol_to_close}_{closing_bar.datetime.strftime('%Y%m%d%H%M%S%f')}",
            symbol=symbol_to_close,
            direction=direction,
            volume=abs(volume_to_close),
            price_type="MARKET",
            limit_price=None,
            submitted_time=closing_bar.datetime,
        )
        # Execute directly (not via the queue) so the forced close fills on this bar.
        trade = self._execute_single_order(rollover_order, closing_bar)
        if trade:
            closed_trades.append(trade)
        else:
            print(
                f"[{closing_bar.datetime}] 警告: 强制平仓 {symbol_to_close} 失败!"
            )
        return closed_trades

    def cancel_all_pending_orders_for_symbol(self, symbol_to_cancel: str) -> int:
        """Cancel every pending order for ``symbol_to_cancel``; return how many were cancelled."""
        order_ids_to_cancel = [
            order_id
            for order_id, order in self.pending_orders.items()
            if order.symbol == symbol_to_cancel
        ]
        return sum(1 for order_id in order_ids_to_cancel if self.cancel_order(order_id))

    def get_pending_orders(self) -> Dict[str, Order]:
        """Return a shallow copy of the pending-order map."""
        return self.pending_orders.copy()

    def get_portfolio_value(self, current_bar: Bar) -> float:
        """Return cash plus positions marked at ``current_bar.open``.

        Only the position in ``current_bar.symbol`` can be priced. Positions in
        other symbols are left out of the total and a warning is printed.
        """
        total_value = self.cash
        for symbol, quantity in self.positions.items():
            if symbol == current_bar.symbol:
                total_value += quantity * current_bar.open
            else:
                print(
                    f"[{current_bar.datetime}] 警告:持仓中存在非当前K线合约 {symbol},无法准确计算其市值。"
                )
        return total_value

    def get_current_positions(self) -> Dict[str, int]:
        """Return a copy of the signed position map."""
        return self.positions.copy()

    def get_trade_history(self) -> List[Trade]:
        """Return a shallow copy of the trade log."""
        return self.trade_log.copy()

    def reset(
        self,
        new_initial_capital: Optional[float] = None,
        new_initial_positions: Optional[Dict[str, int]] = None,
        new_initial_average_costs: Optional[Dict[str, float]] = None,
    ) -> None:
        """Reset cash, holdings, trade log, pending orders and time for a new run.

        Args:
            new_initial_capital: New starting cash. When given, it also replaces
                ``initial_capital``; otherwise the existing ``initial_capital``
                is reused.
            new_initial_positions: Starting positions for the new run (copied).
                None means start flat; positions passed to ``__init__`` are not
                restored.
            new_initial_average_costs: Starting average costs (copied). Missing
                costs default to 0.0, with a warning.
        """
        print("ExecutionSimulator: 重置状态。")
        if new_initial_capital is not None:
            # Keep initial_capital consistent with the capital this run starts with.
            self.initial_capital = new_initial_capital
        self.cash = self.initial_capital
        self._load_holdings(
            new_initial_positions, new_initial_average_costs, "重置时"
        )
        self.trade_log = []
        self.pending_orders = {}
        self._current_time = None
        # Drop the previous run's indicator snapshot so it cannot attach to new close trades.
        self.indicator_dict = {}

    def get_average_position_price(self, symbol: str) -> Optional[float]:
        """Return the average entry price of an open position in ``symbol``, else None."""
        if self.positions.get(symbol, 0) != 0:
            return self.average_costs.get(symbol)
        return None