169 lines
8.2 KiB
Python
169 lines
8.2 KiB
Python
import numpy as np
|
|
import talib
|
|
from typing import Optional, Dict, Any, List, Tuple, Union
|
|
|
|
from src.algo.TrendLine import calculate_latest_trendline_values
|
|
# 假设这些是你项目中的基础模块
|
|
from src.core_data import Bar, Order
|
|
from src.indicators.base_indicators import Indicator
|
|
from src.indicators.indicators import Empty
|
|
from src.strategies.base_strategy import Strategy
|
|
|
|
|
|
class TrendlineBreakoutStrategy(Strategy):
    """
    Trendline breakout strategy, V3 (optimized).

    1. Same trading logic as V2, but the trendline computation is delegated to
       the standalone helper ``calculate_latest_trendline_values``.
    2. That helper computes only the latest trendline values instead of
       building full trendline arrays.

    Entry signals (evaluated only while flat, on the open of a new bar):
      - Long: the newest close in the bar history is above the upper
        trendline value and ``indicators[0]`` passes.
      - Short: the newest close in the bar history is below the lower
        trendline value and ``indicators[1]`` passes.

    Exit logic:
      - ATR trailing stop. The stop only moves in the position's favour and is
        compared against the open price of each new bar.
    """

    def __init__(
        self,
        context: Any,
        main_symbol: str,
        trendline_n: int = 50,
        trade_volume: int = 1,
        order_direction: Optional[List[str]] = None,
        atr_period: int = 14,
        atr_multiplier: float = 1.0,
        enable_log: bool = True,
        indicators: Optional[Union[Indicator, List[Indicator]]] = None,
    ):
        """
        Args:
            context: Runtime context, forwarded to ``Strategy.__init__``.
            main_symbol: Symbol traded by this strategy.
            trendline_n: Number of closes used to fit the trendlines (>= 3).
            trade_volume: Lots per entry order.
            order_direction: Allowed entry directions; a falsy value means
                ``["BUY", "SELL"]``.
            atr_period: ``timeperiod`` passed to ``talib.ATR``.
            atr_multiplier: Stop distance expressed in multiples of ATR.
            enable_log: Forwarded to ``Strategy.__init__``.
            indicators: Entry filters. ``indicators[0]`` gates long entries and
                ``indicators[1]`` gates short entries, so a list needs two
                entries if both directions are traded. A single indicator is
                used for both directions. Defaults to ``[Empty(), Empty()]``.

        Raises:
            ValueError: If ``trendline_n`` is smaller than 3.
        """
        super().__init__(context, main_symbol, enable_log)
        self.main_symbol = main_symbol
        self.trendline_n = trendline_n
        self.trade_volume = trade_volume
        self.order_direction = order_direction or ["BUY", "SELL"]
        self.atr_period = atr_period
        self.atr_multiplier = atr_multiplier
        # symbol -> {"direction", "volume", "entry_price", "trailing_stop"}
        self.pos_meta: Dict[str, Dict[str, Any]] = {}

        if indicators is None:
            indicators = [Empty(), Empty()]
        elif not isinstance(indicators, (list, tuple)):
            # A bare Indicator cannot be indexed by direction; reuse it for
            # both the long ([0]) and short ([1]) filter.
            indicators = [indicators, indicators]
        self.indicators = indicators

        if self.trendline_n < 3:
            raise ValueError("trendline_n 必须大于或等于 3")

        log_message = (
            f"TrendlineBreakoutStrategy (V3 Optimized) 初始化:\n"
            f"交易标的={self.main_symbol}, 交易量={self.trade_volume}\n"
            f"趋势线周期={self.trendline_n}, ATR周期={self.atr_period}, ATR倍数={self.atr_multiplier}"
        )
        self.log(log_message)

    def _calculate_atr(self, bar_history: List[Bar]) -> Optional[float]:
        """Return the ATR of the last bar in ``bar_history``, or None.

        Returns None when there are fewer than ``atr_period + 1`` bars or when
        talib yields NaN for the last bar.
        """
        if len(bar_history) < self.atr_period + 1:
            return None
        # talib only accepts float64 (double) arrays; force the dtype so that
        # integer-valued prices do not make talib raise.
        highs = np.array([b.high for b in bar_history], dtype=np.float64)
        lows = np.array([b.low for b in bar_history], dtype=np.float64)
        closes = np.array([b.close for b in bar_history], dtype=np.float64)
        atr = talib.ATR(highs, lows, closes, timeperiod=self.atr_period)
        latest = atr[-1]
        return None if np.isnan(latest) else float(latest)

    def on_init(self):
        """Reset per-position stop tracking on (re)initialisation."""
        super().on_init()
        self.pos_meta.clear()

    def _ratchet_trailing_stop(self, meta: Dict[str, Any], bar_history: List[Bar]) -> float:
        """Tighten ``meta['trailing_stop']`` using the latest ATR and return it.

        The candidate stop is ``last_close -/+ ATR * atr_multiplier`` (long/short).
        The stored stop only moves towards price: up for longs, down for shorts.
        If ATR is unavailable (None or 0), the stored stop is returned unchanged.
        ``meta`` is the same dict object held in ``self.pos_meta``, so it is
        updated in place.
        """
        trailing_stop = meta['trailing_stop']
        # ATR excludes the newest bar, mirroring the entry-side calculation.
        current_atr = self._calculate_atr(bar_history[:-1])
        if not current_atr:
            return trailing_stop

        last_close = bar_history[-1].close
        offset = current_atr * self.atr_multiplier
        if meta['direction'] == "BUY":
            trailing_stop = max(trailing_stop, last_close - offset)
        else:  # SELL
            trailing_stop = min(trailing_stop, last_close + offset)
        meta['trailing_stop'] = trailing_stop
        return trailing_stop

    def on_open_bar(self, open_price: float, symbol: str):
        """Handle the open of a new bar: manage exits first, then look for entries.

        1. If a tracked position exists, ratchet its ATR trailing stop and
           close it when ``open_price`` has reached the stop.
        2. If flat, fit the trendlines on the ``trendline_n`` closes preceding
           the newest bar in history. Open a position when the newest close is
           beyond a trendline and that direction's indicator filter passes.

        Args:
            open_price: Open price of the new bar. It is the stop-trigger price
                and the reference entry price for the initial stop.
            symbol: Symbol whose bar just opened.
        """
        bar_history = self.get_bar_history()
        # Warm-up: trendline_n closes for the fit + the newest bar + one spare
        # bar (kept from the original requirement).
        min_bars_required = self.trendline_n + 2
        if len(bar_history) < min_bars_required:
            return

        self.cancel_all_pending_orders(symbol)
        pos = self.get_current_positions().get(symbol, 0)

        # 1. Exit handling takes priority over entries.
        meta = self.pos_meta.get(symbol)
        if meta and pos != 0:
            # The stored stop is enforced even when ATR cannot be refreshed on
            # this bar; only the ratchet step depends on ATR being available.
            trailing_stop = self._ratchet_trailing_stop(meta, bar_history)
            direction = meta['direction']
            if (direction == "BUY" and open_price <= trailing_stop) or \
                    (direction == "SELL" and open_price >= trailing_stop):
                self.log(f"ATR滑动止损触发: 价格 {open_price:.2f} 触及止损位 {trailing_stop:.2f}")
                # NOTE(review): close orders use "CLOSE_LONG"/"CLOSE_SHORT" as the
                # Order direction, while entries use "BUY"/"SELL"; confirm the
                # Order model accepts these values.
                self.send_market_order("CLOSE_LONG" if direction == "BUY" else "CLOSE_SHORT", abs(pos))
                del self.pos_meta[symbol]
                return

        # 2. Entries are only considered while flat. A position without
        # pos_meta gets no stop management and blocks new entries.
        if pos != 0:
            return

        prices_for_trendline = np.array([b.close for b in bar_history[-self.trendline_n - 1:-1]])
        trendline_val_upper, trendline_val_lower = calculate_latest_trendline_values(prices_for_trendline)
        if trendline_val_upper is None or trendline_val_lower is None:
            return  # Trendline could not be computed; skip this bar.

        current_atr = self._calculate_atr(bar_history[:-1])
        if not current_atr:
            return

        last_close = bar_history[-1].close
        # NOTE(review): this is a level check, not a strict crossover: the
        # previous close is not required to be on the other side of the line.
        # After a stop-out the strategy can therefore re-enter on the next bar
        # while price stays beyond the line. Confirm this is intended.
        if ("BUY" in self.order_direction
                and last_close > trendline_val_upper
                and self.indicators[0].is_condition_met(*self.get_indicator_tuple())):
            self.log(f"做多信号: Close({last_close:.2f}) 上穿下趋势线({trendline_val_upper:.2f})")
            self.send_open_order("BUY", open_price, self.trade_volume, current_atr)
        elif ("SELL" in self.order_direction
                and last_close < trendline_val_lower
                and self.indicators[1].is_condition_met(*self.get_indicator_tuple())):
            self.log(f"做空信号: Close({last_close:.2f}) 下穿上趋势线({trendline_val_lower:.2f})")
            self.send_open_order("SELL", open_price, self.trade_volume, current_atr)

    def send_open_order(self, direction: str, entry_price: float, volume: int, current_atr: float):
        """Send a market order that opens a position, and record its initial stop.

        Args:
            direction: "BUY" opens a long; any other value is sent as "SELL".
            entry_price: Reference price for the initial stop (the bar's open).
            volume: Lots to open.
            current_atr: ATR used for the initial stop distance.
        """
        if direction == "BUY":
            initial_stop = entry_price - current_atr * self.atr_multiplier
        else:
            initial_stop = entry_price + current_atr * self.atr_multiplier
        current_time = self.get_current_time()
        order_id = f"{self.symbol}_{direction}_{current_time.strftime('%Y%m%d%H%M%S')}"
        order_direction = "BUY" if direction == "BUY" else "SELL"
        order = Order(id=order_id, symbol=self.symbol, direction=order_direction, volume=volume, price_type="MARKET",
                      submitted_time=current_time, offset="OPEN")
        self.send_order(order)
        # Stop tracking starts right after sending, without waiting for a fill.
        # NOTE(review): pos_meta is keyed by self.symbol here but read with the
        # `symbol` argument in on_open_bar; this assumes both name the same
        # contract. Confirm.
        self.pos_meta[self.symbol] = {"direction": direction, "volume": volume, "entry_price": entry_price,
                                      "trailing_stop": initial_stop}
        self.log(
            f"发送开仓订单: {direction} {volume}手 @ Market Price (执行价约 {entry_price:.2f}), 初始ATR止损位: {initial_stop:.2f}")

    def send_market_order(self, direction: str, volume: int):
        """Send a market order with offset "CLOSE" for ``self.symbol``.

        Args:
            direction: Order direction string, passed through unchanged.
            volume: Lots to close.
        """
        current_time = self.get_current_time()
        order_id = f"{self.symbol}_{direction}_{current_time.strftime('%Y%m%d%H%M%S')}"
        order = Order(id=order_id, symbol=self.symbol, direction=direction, volume=volume, price_type="MARKET",
                      submitted_time=current_time, offset="CLOSE")
        self.send_order(order)
        self.log(f"发送平仓订单: {direction} {volume}手 @ Market Price")

    def on_rollover(self, old_symbol: str, new_symbol: str):
        """Handle a contract rollover: cancel pending orders and drop stop tracking.

        NOTE(review): clearing pos_meta discards any trailing stop. If the base
        class carries an open position over to ``new_symbol``, on_open_bar will
        neither manage a stop for it nor open new trades (pos != 0 without
        meta). Confirm against Strategy.on_rollover.
        """
        super().on_rollover(old_symbol, new_symbol)
        self.cancel_all_pending_orders(new_symbol)
        self.pos_meta.clear()