from typing import Optional, Any, List, Dict

import numpy as np
import talib

from src.core_data import Bar, Order
from src.strategies.base_strategy import Strategy


class KalmanMeanReversion(Strategy):
    """Improved Kalman-filter mean-reversion strategy.

    1. Uses the slow Kalman line as the take-profit target (true mean reversion).
    2. Strict ATR-based stop loss around the entry price (guards against trend
       breakouts).
    3. Time-based exit once a position has been held for ``max_hold_bars`` bars.
    4. Optional external entry filter via ``indicator.is_condition_met``.
       (No trend-slope filter is implemented in this class.)
    """

    def __init__(
        self,
        context: Any,
        main_symbol: str,
        enable_log: bool,
        trade_volume: int,
        fast_sensitivity: float = 0.05,  # fast filter: process noise Q as a fraction of r_base
        slow_sensitivity: float = 0.01,  # slow filter: smaller Q -> smoother line, used as the "mean"
        lookback_variance: int = 60,
        atr_period: int = 20,
        entry_threshold: float = 1.2,  # deviation from the slow line, in ATRs, required to enter
        stop_loss_multiplier: float = 2.0,  # stop distance from entry price, in ATRs
        max_hold_bars: int = 30,  # maximum number of bars to hold a position
        indicator: Any = None,
    ):
        """Store strategy parameters and reset the filter/position state.

        Args:
            context: Passed through to the ``Strategy`` base class.
            main_symbol: Passed through to the ``Strategy`` base class.
            enable_log: Passed through to the ``Strategy`` base class.
            trade_volume: Volume used for every order sent by ``execute_order``.
            fast_sensitivity: Fast filter uses ``Q = r_base * fast_sensitivity``
                and ``R = r_base``.
            slow_sensitivity: Slow filter uses ``Q = r_base * slow_sensitivity``
                and ``R = r_base * 10``.
            lookback_variance: Number of recent closes whose variance gives ``r_base``.
            atr_period: Period passed to ``talib.ATR``.
            entry_threshold: Open when ``|price - slow_line| / ATR`` exceeds this.
            stop_loss_multiplier: Close when price moves this many ATRs against
                the recorded entry price.
            max_hold_bars: Close after this many bars in a position.
            indicator: Optional gate; if not None, entries require
                ``indicator.is_condition_met(*self.get_indicator_tuple())``.
        """
        super().__init__(context, main_symbol, enable_log)
        self.trade_volume = trade_volume
        self.fast_sensitivity = fast_sensitivity
        self.slow_sensitivity = slow_sensitivity
        self.lookback_variance = lookback_variance
        self.atr_period = atr_period
        self.entry_threshold = entry_threshold
        self.stop_loss_multiplier = stop_loss_multiplier
        self.max_hold_bars = max_hold_bars
        self.indicator = indicator

        # Scalar Kalman filter states: 'x' = estimate, 'P' = estimate variance.
        self.kf_fast = {'x': 0.0, 'P': 1.0}
        self.kf_slow = {'x': 0.0, 'P': 1.0}
        self.kalman_initialized = False

        # Position bookkeeping. entry_price == 0.0 means "no entry recorded yet".
        self.entry_price = 0.0
        self.entry_bar_count = 0
        self.order_id_counter = 0

    def _update_kalman(self, state: dict, measurement: float, Q: float, R: float) -> float:
        """Run one predict/update step of a scalar Kalman filter, in place.

        Args:
            state: Dict with keys 'x' (estimate) and 'P' (estimate variance);
                both are overwritten.
            measurement: New observation.
            Q: Process-noise variance added in the predict step.
            R: Measurement-noise variance.

        Returns:
            The updated estimate ``state['x']``.
        """
        p_minus = state['P'] + Q
        k_gain = p_minus / (p_minus + R)
        state['x'] = state['x'] + k_gain * (measurement - state['x'])
        state['P'] = (1 - k_gain) * p_minus
        return state['x']

    def on_open_bar(self, open_price: float, symbol: str):
        """Evaluate signals on each bar open and manage entries/exits.

        All decisions use the latest close in the bar history; ``open_price``
        is not used by this implementation.
        """
        bar_history = self.get_bar_history()
        # Need enough bars for the variance window and for talib.ATR to produce
        # a defined value (its first ``atr_period`` outputs are NaN).
        min_bars = max(100, self.lookback_variance, self.atr_period + 1)
        if len(bar_history) < min_bars:
            return

        self.cancel_all_pending_orders()

        closes = np.array([b.close for b in bar_history], dtype=float)
        last_price = closes[-1]

        # 1. Adaptive noise scale from the rolling variance of recent closes.
        rolling_var = np.var(closes[-self.lookback_variance:])
        r_base = rolling_var if rolling_var > 0 else 1.0

        # 2. Initialise both filters on the first eligible bar, then update.
        if not self.kalman_initialized:
            self.kf_fast['x'] = self.kf_slow['x'] = last_price
            self.kalman_initialized = True
            return

        # The fast filter is kept up to date, but its value is not used in any
        # decision below; the slow line acts as the equilibrium ("mean").
        self._update_kalman(self.kf_fast, last_price, r_base * self.fast_sensitivity, r_base)
        slow_line = self._update_kalman(
            self.kf_slow, last_price, r_base * self.slow_sensitivity, r_base * 10.0
        )

        # 3. ATR and the price's deviation from the slow line, in ATR units.
        highs = np.array([b.high for b in bar_history], dtype=float)
        lows = np.array([b.low for b in bar_history], dtype=float)
        atr = talib.ATR(highs, lows, closes, self.atr_period)[-1]
        # Treat an undefined (NaN) ATR like a non-positive one: NaN would make
        # every threshold comparison False and silently disable the stop loss.
        if not np.isfinite(atr) or atr <= 0:
            return

        diff_in_atr = (last_price - slow_line) / atr

        # 4. Position logic.
        pos = self.get_current_positions().get(symbol, 0)
        if pos == 0:
            self._try_enter(symbol, last_price, diff_in_atr)
        else:
            self._manage_exit(symbol, pos, last_price, slow_line, atr)

    def _try_enter(self, symbol: str, last_price: float, diff_in_atr: float) -> None:
        """Open a fading position when price is stretched far enough from the mean."""
        # Only enter when the optional external indicator allows it.
        can_trade = self.indicator is None or self.indicator.is_condition_met(*self.get_indicator_tuple())
        if not can_trade:
            return
        if diff_in_atr > self.entry_threshold:
            # Far above the slow line (overbought): go short.
            self.execute_order(symbol, "SELL", "OPEN", last_price)
        elif diff_in_atr < -self.entry_threshold:
            # Far below the slow line (oversold): go long.
            self.execute_order(symbol, "BUY", "OPEN", last_price)

    def _manage_exit(self, symbol: str, pos: float, last_price: float, slow_line: float, atr: float) -> None:
        """Count holding time and send a close order if any exit rule fires."""
        self.entry_bar_count += 1
        is_long = pos > 0

        if self.entry_price == 0.0:
            # No entry was recorded by this instance (e.g. position carried over
            # from a restart). Without a reference price a short would be stopped
            # out immediately (price > 0 + k*ATR) and a long's stop could never
            # fire, so adopt the current price as the reference.
            self.entry_price = last_price

        exit_reason = self._exit_reason(is_long, last_price, slow_line, atr)
        if exit_reason:
            direction = "CLOSE_LONG" if is_long else "CLOSE_SHORT"
            self.log(f"EXIT {direction}: Price={last_price}, Reason={exit_reason}")
            self.execute_order(symbol, direction, "CLOSE", last_price)

    def _exit_reason(self, is_long: bool, last_price: float, slow_line: float, atr: float) -> str:
        """Return why the open position should be closed, or "" to keep holding.

        Rules are checked in priority order:
          A. take profit when price touches/crosses the slow line;
          B. stop loss when price is ``stop_loss_multiplier`` ATRs beyond entry;
          C. time exit once ``entry_bar_count`` reaches ``max_hold_bars``
             (mean reversion is time-sensitive).
        """
        reached_mean = last_price >= slow_line if is_long else last_price <= slow_line
        if reached_mean:
            return "Take Profit: Reached Mean"

        stop_distance = self.stop_loss_multiplier * atr
        stopped_out = (
            last_price < self.entry_price - stop_distance
            if is_long
            else last_price > self.entry_price + stop_distance
        )
        if stopped_out:
            return "Stop Loss: Deviation Too Large"

        if self.entry_bar_count >= self.max_hold_bars:
            return "Time Exit: Holding Too Long"
        return ""

    def execute_order(self, symbol, direction, offset, price):
        """Send a market order and update entry bookkeeping.

        On ``offset == "OPEN"`` the given price is recorded as ``entry_price``
        and the holding counter is reset. ``price`` is not forwarded to the
        order itself, which is sent with ``price_type="MARKET"``.
        NOTE(review): opens use "BUY"/"SELL" while closes use
        "CLOSE_LONG"/"CLOSE_SHORT" as the direction; confirm ``Order``
        accepts both vocabularies.
        """
        if offset == "OPEN":
            self.entry_price = price
            self.entry_bar_count = 0

        order_id = f"{symbol}_{direction}_{self.order_id_counter}"
        self.order_id_counter += 1

        order = Order(
            id=order_id,
            symbol=symbol,
            direction=direction,
            volume=self.trade_volume,
            price_type="MARKET",
            offset=offset,
        )
        self.send_order(order)