import numpy as np
import talib
from typing import Optional, Dict, Any, List, Tuple, Union

from src.algo.TrendLine import calculate_latest_trendline_values
# Project base modules
from src.core_data import Bar, Order
from src.indicators.base_indicators import Indicator
from src.indicators.indicators import Empty
from src.strategies.base_strategy import Strategy


class TrendlineBreakoutStrategy(Strategy):
    """Trendline breakout strategy V3 (optimized).

    1. Same trading logic as V2, but the trendline computation is delegated to
       ``calculate_latest_trendline_values``, which returns only the latest
       upper/lower trendline values instead of building full arrays.

    Entry signals (evaluated only when flat), using trendlines fitted on the
    ``trendline_n`` closes that precede the last bar in the history:
        - Long:  last close > upper trendline value, allowed by
                 ``order_direction`` and ``indicators[0].is_condition_met(...)``.
        - Short: last close < lower trendline value, allowed by
                 ``order_direction`` and ``indicators[1].is_condition_met(...)``.

    Exit logic:
        - ATR trailing stop. The stop only moves in the position's favour and
          fires when a new bar's open price touches it.
    """

    def __init__(
        self,
        context: Any,
        main_symbol: str,
        trendline_n: int = 50,
        trade_volume: int = 1,
        order_direction: Optional[List[str]] = None,
        atr_period: int = 14,
        atr_multiplier: float = 1.0,
        enable_log: bool = True,
        indicators: Optional[Union[Indicator, List[Indicator]]] = None,
    ):
        """Configure the strategy.

        Args:
            context: Framework context, forwarded to ``Strategy.__init__``.
            main_symbol: Symbol traded by this strategy.
            trendline_n: Number of closes used to fit the trendlines (>= 3).
            trade_volume: Lots per entry order.
            order_direction: Allowed entry directions; a falsy value falls
                back to ``["BUY", "SELL"]``.
            atr_period: ``timeperiod`` passed to ``talib.ATR``.
            atr_multiplier: ATR multiple used as the stop distance.
            enable_log: Forwarded to ``Strategy.__init__``.
            indicators: Entry filters; element 0 gates longs, element 1 gates
                shorts. Defaults to ``[Empty(), Empty()]``.

        Raises:
            ValueError: If ``trendline_n`` is smaller than 3.
        """
        super().__init__(context, main_symbol, enable_log)
        self.main_symbol = main_symbol
        self.trendline_n = trendline_n
        self.trade_volume = trade_volume
        self.order_direction = order_direction or ["BUY", "SELL"]
        self.atr_period = atr_period
        self.atr_multiplier = atr_multiplier
        # symbol -> {"direction", "volume", "entry_price", "trailing_stop"}
        self.pos_meta: Dict[str, Dict[str, Any]] = {}

        if indicators is None:
            indicators = [Empty(), Empty()]
        # NOTE(review): a bare Indicator (allowed by the annotation) is stored
        # as-is, but on_open_bar indexes it with [0]/[1] -- confirm callers
        # always pass a two-element list.
        self.indicators = indicators

        if self.trendline_n < 3:
            raise ValueError("trendline_n 必须大于或等于 3")

        log_message = (
            f"TrendlineBreakoutStrategy (V3 Optimized) 初始化:\n"
            f"交易标的={self.main_symbol}, 交易量={self.trade_volume}\n"
            f"趋势线周期={self.trendline_n}, ATR周期={self.atr_period}, ATR倍数={self.atr_multiplier}"
        )
        self.log(log_message)

    def _calculate_atr(self, bar_history: List[Bar]) -> Optional[float]:
        """Return the latest ``talib.ATR`` value over ``bar_history``.

        Returns None when there are fewer than ``atr_period + 1`` bars or when
        the latest ATR value is NaN.
        """
        if len(bar_history) < self.atr_period + 1:
            return None

        # talib only accepts double arrays; force float64 so integer-priced
        # bars do not make talib raise.
        highs = np.array([b.high for b in bar_history], dtype=float)
        lows = np.array([b.low for b in bar_history], dtype=float)
        closes = np.array([b.close for b in bar_history], dtype=float)

        atr = talib.ATR(highs, lows, closes, timeperiod=self.atr_period)
        latest = atr[-1]
        return None if np.isnan(latest) else latest

    def on_init(self):
        """Run base initialization and drop any stored position metadata."""
        super().on_init()
        self.pos_meta.clear()

    def on_open_bar(self, open_price: float, symbol: str):
        """On a new bar open, manage the trailing stop or look for an entry.

        Args:
            open_price: Open price of the new bar. It is the stop-trigger price
                and the approximate entry price passed to ``send_open_order``.
            symbol: Symbol whose bar opened.
        """
        bar_history = self.get_bar_history()
        # trendline_n closes before the last bar, plus the last bar itself,
        # plus one more bar of margin.
        min_bars_required = self.trendline_n + 2
        if len(bar_history) < min_bars_required:
            return

        self.cancel_all_pending_orders(symbol)

        pos = self.get_current_positions().get(symbol, 0)

        # 1. Exit management takes priority over entries.
        meta = self.pos_meta.get(symbol)
        if meta and pos != 0:
            direction = meta['direction']
            trailing_stop = meta['trailing_stop']

            # Ratchet the stop only when ATR is available for this bar.
            current_atr = self._calculate_atr(bar_history[:-1])
            if current_atr:
                last_close = bar_history[-1].close
                if direction == "BUY":
                    new_stop_level = last_close - current_atr * self.atr_multiplier
                    trailing_stop = max(trailing_stop, new_stop_level)
                else:  # SELL
                    new_stop_level = last_close + current_atr * self.atr_multiplier
                    trailing_stop = min(trailing_stop, new_stop_level)
                meta['trailing_stop'] = trailing_stop

            # Always test the stop, even if ATR could not be computed this bar;
            # otherwise an existing stop would be silently skipped.
            stop_hit = (direction == "BUY" and open_price <= trailing_stop) or \
                       (direction == "SELL" and open_price >= trailing_stop)
            if stop_hit:
                self.log(f"ATR滑动止损触发: 价格 {open_price:.2f} 触及止损位 {trailing_stop:.2f}")
                self.send_market_order("CLOSE_LONG" if direction == "BUY" else "CLOSE_SHORT", abs(pos))
                del self.pos_meta[symbol]
                return

        # Entries are only considered when flat.
        if pos != 0:
            return

        # 2. Entry logic: fit trendlines on the trendline_n closes preceding
        # the last bar, then compare the last close against them.
        prices_for_trendline = np.array([b.close for b in bar_history[-self.trendline_n - 1:-1]])
        trendline_val_upper, trendline_val_lower = calculate_latest_trendline_values(prices_for_trendline)
        if trendline_val_upper is None or trendline_val_lower is None:
            return  # trendline could not be computed; skip this bar

        last_close = bar_history[-1].close
        current_atr = self._calculate_atr(bar_history[:-1])
        if not current_atr:
            return  # no ATR -> no initial stop can be placed

        if "BUY" in self.order_direction and last_close > trendline_val_upper and self.indicators[0].is_condition_met(*self.get_indicator_tuple()):
            self.log(f"做多信号: Close({last_close:.2f}) 上穿下趋势线({trendline_val_upper:.2f})")
            self.send_open_order("BUY", open_price, self.trade_volume, current_atr)
        elif "SELL" in self.order_direction and last_close < trendline_val_lower and self.indicators[1].is_condition_met(*self.get_indicator_tuple()):
            self.log(f"做空信号: Close({last_close:.2f}) 下穿上趋势线({trendline_val_lower:.2f})")
            self.send_open_order("SELL", open_price, self.trade_volume, current_atr)

    def send_open_order(self, direction: str, entry_price: float, volume: int, current_atr: float):
        """Send a market order that opens a position and record its initial stop.

        Args:
            direction: "BUY" opens long; any other value is sent as "SELL".
            entry_price: Expected fill price; the initial stop is placed
                ``current_atr * atr_multiplier`` away from it.
            volume: Lots to trade.
            current_atr: ATR value used for the initial stop distance.
        """
        if direction == "BUY":
            initial_stop = entry_price - current_atr * self.atr_multiplier
        else:
            initial_stop = entry_price + current_atr * self.atr_multiplier

        current_time = self.get_current_time()
        order_id = f"{self.symbol}_{direction}_{current_time.strftime('%Y%m%d%H%M%S')}"
        order_direction = "BUY" if direction == "BUY" else "SELL"
        order = Order(id=order_id, symbol=self.symbol, direction=order_direction, volume=volume,
                      price_type="MARKET", submitted_time=current_time, offset="OPEN")
        self.send_order(order)

        # NOTE(review): metadata is keyed by self.symbol here, but on_open_bar
        # looks it up by its `symbol` argument -- assumed identical; verify.
        self.pos_meta[self.symbol] = {"direction": direction, "volume": volume, "entry_price": entry_price,
                                      "trailing_stop": initial_stop}
        self.log(
            f"发送开仓订单: {direction} {volume}手 @ Market Price (执行价约 {entry_price:.2f}), 初始ATR止损位: {initial_stop:.2f}")

    def send_market_order(self, direction: str, volume: int):
        """Send a closing market order.

        Args:
            direction: Order direction passed straight to ``Order``
                (this class uses "CLOSE_LONG" / "CLOSE_SHORT").
            volume: Lots to close.
        """
        current_time = self.get_current_time()
        order_id = f"{self.symbol}_{direction}_{current_time.strftime('%Y%m%d%H%M%S')}"
        order = Order(id=order_id, symbol=self.symbol, direction=direction, volume=volume, price_type="MARKET",
                      submitted_time=current_time, offset="CLOSE")
        self.send_order(order)
        self.log(f"发送平仓订单: {direction} {volume}手 @ Market Price")

    def on_rollover(self, old_symbol: str, new_symbol: str):
        """On contract rollover, cancel pending orders on the new symbol and clear stops."""
        super().on_rollover(old_symbol, new_symbol)
        self.cancel_all_pending_orders(new_symbol)
        self.pos_meta.clear()