# File: NewQuant/futures_trading_strategies/hc/TrendlineBreakoutStrategy/TrendlineBreakoutStrategy.py
# Last modified: 2025-10-05 00:09:59 +08:00 (169 lines, 8.2 KiB, Python)

import numpy as np
import talib
from typing import Optional, Dict, Any, List, Tuple, Union
from src.algo.TrendLine import calculate_latest_trendline_values
# 假设这些是你项目中的基础模块
from src.core_data import Bar, Order
from src.indicators.base_indicators import Indicator
from src.indicators.indicators import Empty
from src.strategies.base_strategy import Strategy
class TrendlineBreakoutStrategy(Strategy):
    """
    Trendline breakout strategy V3 (optimized).

    1. Trading logic is the same as V2, but the trendline calculation is
       delegated to the standalone helper ``calculate_latest_trendline_values``.
    2. That helper returns only the latest upper/lower trendline values, so no
       full trendline arrays are built on every bar.

    Entry (only when flat, evaluated when a new bar opens):
        - Long:  the close of ``bar_history[-1]`` is above the latest upper
          trendline value and ``indicators[0]`` confirms.
        - Short: the close of ``bar_history[-1]`` is below the latest lower
          trendline value and ``indicators[1]`` confirms.
        The trendlines are fitted on the ``trendline_n`` closes preceding
        ``bar_history[-1]``. This is a level comparison, not a strict
        crossover test: the close before ``bar_history[-1]`` is not checked.

    Exit:
        - ATR trailing stop. The stop starts at entry -/+ ATR * atr_multiplier.
          On every bar it is ratcheted towards last_close -/+ ATR * multiplier
          and is never loosened. The position is closed at market when the new
          bar's open price reaches it.
    """

    def __init__(
        self,
        context: Any,
        main_symbol: str,
        trendline_n: int = 50,
        trade_volume: int = 1,
        order_direction: Optional[List[str]] = None,
        atr_period: int = 14,
        atr_multiplier: float = 1.0,
        enable_log: bool = True,
        indicators: Optional[Union[Indicator, List[Indicator]]] = None,
    ):
        """
        Args:
            context: engine context, forwarded to ``Strategy.__init__``.
            main_symbol: traded symbol.
            trendline_n: number of closes used to fit the trendlines (>= 3).
            trade_volume: lots per entry.
            order_direction: allowed entry sides, subset of ["BUY", "SELL"];
                defaults to both.
            atr_period: ATR look-back period.
            atr_multiplier: ATR multiple used for the initial and trailing stop.
            enable_log: forwarded to ``Strategy.__init__``.
            indicators: confirmation filters. A list is indexed as
                ``[long_filter, short_filter]``. A single indicator is applied
                to both sides. ``None`` means no filtering (``Empty()``).

        Raises:
            ValueError: if ``trendline_n`` < 3.
        """
        super().__init__(context, main_symbol, enable_log)
        self.main_symbol = main_symbol
        self.trendline_n = trendline_n
        self.trade_volume = trade_volume
        self.order_direction = order_direction or ["BUY", "SELL"]
        self.atr_period = atr_period
        self.atr_multiplier = atr_multiplier
        # symbol -> {"direction", "volume", "entry_price", "trailing_stop"}
        self.pos_meta: Dict[str, Dict[str, Any]] = {}

        if indicators is None:
            indicators = [Empty(), Empty()]
        elif not isinstance(indicators, (list, tuple)):
            # A single indicator was passed: use it as the filter for both sides
            # (otherwise indexing it below would fail).
            indicators = [indicators, indicators]
        self.indicators = indicators

        if self.trendline_n < 3:
            raise ValueError("trendline_n 必须大于或等于 3")

        log_message = (
            f"TrendlineBreakoutStrategy (V3 Optimized) 初始化:\n"
            f"交易标的={self.main_symbol}, 交易量={self.trade_volume}\n"
            f"趋势线周期={self.trendline_n}, ATR周期={self.atr_period}, ATR倍数={self.atr_multiplier}"
        )
        self.log(log_message)

    def _calculate_atr(self, bar_history: List[Bar]) -> Optional[float]:
        """Return the latest ATR(``atr_period``) over ``bar_history``.

        Returns None when there are fewer than ``atr_period + 1`` bars or when
        the latest ATR value is NaN.
        """
        if len(bar_history) < self.atr_period + 1:
            return None
        # TA-Lib only accepts float64 ("double") arrays; force the dtype so
        # bars carrying integer prices do not raise.
        highs = np.array([b.high for b in bar_history], dtype=np.float64)
        lows = np.array([b.low for b in bar_history], dtype=np.float64)
        closes = np.array([b.close for b in bar_history], dtype=np.float64)
        atr = talib.ATR(highs, lows, closes, timeperiod=self.atr_period)
        latest = atr[-1]
        return None if np.isnan(latest) else latest

    def on_init(self):
        """Reset per-position stop bookkeeping on (re)initialisation."""
        super().on_init()
        self.pos_meta.clear()

    def on_open_bar(self, open_price: float, symbol: str):
        """Handle the open of a new bar: manage the trailing stop, else look for entries.

        Args:
            open_price: open price of the new bar. It is the stop-trigger price
                and the estimated fill price for new entries.
            symbol: symbol whose bar just opened.
        """
        bar_history = self.get_bar_history()
        min_bars_required = self.trendline_n + 2
        if len(bar_history) < min_bars_required:
            return

        self.cancel_all_pending_orders(symbol)
        pos = self.get_current_positions().get(symbol, 0)

        # 1. Exit logic takes priority; never open while holding a position.
        if pos != 0:
            meta = self.pos_meta.get(symbol)
            # NOTE(review): a non-zero position without meta (e.g. after
            # on_rollover clears pos_meta) is left unmanaged here — confirm
            # that the rollover logic flattens or re-registers it.
            if meta:
                self._manage_trailing_stop(symbol, pos, open_price, bar_history, meta)
            return

        # 2. Entry logic (flat only).
        self._try_open_position(open_price, bar_history)

    def _manage_trailing_stop(
        self,
        symbol: str,
        pos: int,
        open_price: float,
        bar_history: List[Bar],
        meta: Dict[str, Any],
    ) -> None:
        """Ratchet the ATR trailing stop and close the position if the open hits it."""
        direction = meta["direction"]
        trailing_stop = meta["trailing_stop"]

        current_atr = self._calculate_atr(bar_history[:-1])
        if current_atr:
            last_close = bar_history[-1].close
            offset = current_atr * self.atr_multiplier
            if direction == "BUY":
                trailing_stop = max(trailing_stop, last_close - offset)
            else:  # SELL
                trailing_stop = min(trailing_stop, last_close + offset)
            meta["trailing_stop"] = trailing_stop

        # Check the stop even when ATR is unavailable, so the last known stop
        # still protects the position.
        if direction == "BUY":
            stop_hit = open_price <= trailing_stop
        else:
            stop_hit = open_price >= trailing_stop
        if stop_hit:
            self.log(f"ATR滑动止损触发: 价格 {open_price:.2f} 触及止损位 {trailing_stop:.2f}")
            self.send_market_order("CLOSE_LONG" if direction == "BUY" else "CLOSE_SHORT", abs(pos))
            self.pos_meta.pop(symbol, None)

    def _try_open_position(self, open_price: float, bar_history: List[Bar]) -> None:
        """Open a position when the last close is beyond a trendline and the filter agrees."""
        # The trendline_n closes ending at bar_history[-2].
        prices_for_trendline = np.array([b.close for b in bar_history[-self.trendline_n - 1:-1]])
        trendline_val_upper, trendline_val_lower = calculate_latest_trendline_values(prices_for_trendline)
        if trendline_val_upper is None or trendline_val_lower is None:
            return  # trendline could not be computed; skip this bar

        current_atr = self._calculate_atr(bar_history[:-1])
        if not current_atr:
            # No ATR yet. A zero ATR is also rejected, because it would put the
            # initial stop exactly at the entry price.
            return

        last_close = bar_history[-1].close
        # Condition order is preserved so indicator filters are only
        # evaluated when the price condition already holds.
        if "BUY" in self.order_direction and last_close > trendline_val_upper and self.indicators[0].is_condition_met(*self.get_indicator_tuple()):
            self.log(f"做多信号: Close({last_close:.2f}) 上穿下趋势线({trendline_val_upper:.2f})")
            self.send_open_order("BUY", open_price, self.trade_volume, current_atr)
        elif "SELL" in self.order_direction and last_close < trendline_val_lower and self.indicators[1].is_condition_met(*self.get_indicator_tuple()):
            self.log(f"做空信号: Close({last_close:.2f}) 下穿上趋势线({trendline_val_lower:.2f})")
            self.send_open_order("SELL", open_price, self.trade_volume, current_atr)

    def _build_order_id(self, direction: str) -> Tuple[str, Any]:
        """Return ``(order_id, current_time)``; the id is symbol_direction_timestamp."""
        current_time = self.get_current_time()
        order_id = f"{self.symbol}_{direction}_{current_time.strftime('%Y%m%d%H%M%S')}"
        return order_id, current_time

    def send_open_order(self, direction: str, entry_price: float, volume: int, current_atr: float):
        """Send a market OPEN order and record the initial ATR stop in ``pos_meta``.

        Args:
            direction: "BUY" for long; any other value is sent as "SELL".
            entry_price: estimated fill price (the bar open) used for the stop.
            volume: lots to open.
            current_atr: ATR used to place the initial stop.
        """
        if direction == "BUY":
            initial_stop = entry_price - current_atr * self.atr_multiplier
        else:
            initial_stop = entry_price + current_atr * self.atr_multiplier

        order_id, current_time = self._build_order_id(direction)
        order_direction = "BUY" if direction == "BUY" else "SELL"
        order = Order(id=order_id, symbol=self.symbol, direction=order_direction, volume=volume, price_type="MARKET",
                      submitted_time=current_time, offset="OPEN")
        self.send_order(order)
        self.pos_meta[self.symbol] = {"direction": direction, "volume": volume, "entry_price": entry_price,
                                      "trailing_stop": initial_stop}
        self.log(
            f"发送开仓订单: {direction} {volume}手 @ Market Price (执行价约 {entry_price:.2f}), 初始ATR止损位: {initial_stop:.2f}")

    def send_market_order(self, direction: str, volume: int):
        """Send a market CLOSE order (``direction`` e.g. "CLOSE_LONG" / "CLOSE_SHORT")."""
        order_id, current_time = self._build_order_id(direction)
        order = Order(id=order_id, symbol=self.symbol, direction=direction, volume=volume, price_type="MARKET",
                      submitted_time=current_time, offset="CLOSE")
        self.send_order(order)
        self.log(f"发送平仓订单: {direction} {volume}手 @ Market Price")

    def on_rollover(self, old_symbol: str, new_symbol: str):
        """On contract rollover, cancel pending orders on the new symbol and drop all stop state."""
        super().on_rollover(old_symbol, new_symbol)
        self.cancel_all_pending_orders(new_symbol)
        self.pos_meta.clear()