# 289 lines · 12 KiB · Python
from datetime import datetime, timedelta
from typing import Optional, Any, List, Dict, Tuple

import numpy as np
import talib
from scipy.signal import stft

from src.core_data import Bar, Order
from src.indicators.base_indicators import Indicator
from src.indicators.indicators import Empty, ADX
from src.strategies.base_strategy import Strategy


class SpectralTrendStrategy(Strategy):
    """
    Spectral energy phase-transition strategy - minimalist edition with a
    dynamic ATR stop-loss.

    On every bar open the most recent ``spectral_window`` closes are z-scored
    and transformed with a single-segment STFT whose sampling rate is
    ``bars_per_day`` (so frequencies are expressed in cycles per day). The
    share of the latest spectral energy that falls in the low-frequency band,
    relative to low + high band energy, is the "trend strength"; the linear
    slope of the z-scored window gives the direction.

    Entry: trend strength above ``trend_strength_threshold`` and a slope past
    ``slope_threshold`` in an allowed direction (subject to indicator filters).
    Exit: holding-indicator veto, an ATR stop around the average entry price,
    or trend strength dropping below ``exit_threshold``.
    """

    def __init__(
        self,
        context: Any,
        main_symbol: str,
        enable_log: bool,
        trade_volume: int,
        # --- Market parameters ---
        bars_per_day: int = 23,
        # --- Strategy parameters ---
        spectral_window_days: float = 2.0,
        low_freq_days: float = 2.0,
        high_freq_days: float = 1.0,
        trend_strength_threshold: float = 0.2,
        exit_threshold: float = 0.1,
        slope_threshold: float = 0.0,
        max_hold_days: int = 10,
        # --- Risk parameters ---
        stop_loss_atr_multiplier: float = 2.0,  # stop distance as a multiple of the current ATR
        stop_loss_atr_period: int = 14,  # ATR lookback period, in bars
        # --- Other ---
        order_direction: Optional[List[str]] = None,
        indicators: Optional[Indicator] = None,
        model_indicator: Optional[Indicator] = None,
        holding_indicators: Optional[Indicator] = None,
        reverse: bool = False,
    ):
        """
        Args:
            context: Forwarded to ``Strategy.__init__``.
            main_symbol: Symbol whose pending orders are cancelled on each bar.
            enable_log: Forwarded to ``Strategy.__init__``.
            trade_volume: Volume of each opening limit order.
            bars_per_day: Bars per trading day; used as the STFT sampling rate
                and to convert day-based windows into bar counts.
            spectral_window_days: Analysis window length in days. Converted to
                bars and rounded up to an even count.
            low_freq_days: Frequencies below ``1 / low_freq_days`` cycles/day
                count as low-band (trend) energy. ``<= 0`` makes every
                frequency low-band.
            high_freq_days: Frequencies above ``1 / high_freq_days`` cycles/day
                count as high-band (noise) energy. ``<= 0`` makes every
                positive frequency high-band.
            trend_strength_threshold: Entry requires trend strength strictly
                above this value.
            exit_threshold: An open position is closed when trend strength is
                strictly below this value.
            slope_threshold: Long entries need slope > threshold, short entries
                need slope < -threshold.
            max_hold_days: Stored only; the time-based exit is currently
                disabled (see ``on_open_bar``).
            stop_loss_atr_multiplier: Stop distance from the average entry
                price, in multiples of the current ATR.
            stop_loss_atr_period: ATR ``timeperiod``.
            order_direction: Allowed raw entry directions; defaults to
                ``['BUY', 'SELL']``.
            indicators: Entry gate; no entry when its condition is not met.
            model_indicator: When its condition is not met, the entry direction
                is inverted.
            holding_indicators: An open position is closed when its condition
                is not met.
            reverse: Invert the final entry direction.
        """
        super().__init__(context, main_symbol, enable_log)
        if order_direction is None:
            order_direction = ['BUY', 'SELL']

        self.trade_volume = trade_volume
        self.bars_per_day = bars_per_day
        self.spectral_window_days = spectral_window_days
        self.low_freq_days = low_freq_days
        self.high_freq_days = high_freq_days
        self.trend_strength_threshold = trend_strength_threshold
        self.exit_threshold = exit_threshold
        self.slope_threshold = slope_threshold
        self.max_hold_days = max_hold_days

        # --- Risk parameters ---
        self.sl_atr_multiplier = stop_loss_atr_multiplier
        self.sl_atr_period = stop_loss_atr_period
        # No stop price is stored: the stop is recomputed on every bar from the
        # average entry price and the current ATR.

        self.order_direction = order_direction
        self.model_indicator = model_indicator or Empty()
        self.indicators = indicators or Empty()
        self.holding_indicators = holding_indicators or Empty()
        self.reverse = reverse

        # STFT window in bars, forced to an even length.
        self.spectral_window = int(self.spectral_window_days * self.bars_per_day)
        if self.spectral_window % 2 != 0:
            self.spectral_window += 1

        # Band edges in cycles per day (STFT runs with fs = bars_per_day).
        self.low_freq_bound = 1.0 / self.low_freq_days if self.low_freq_days > 0 else float('inf')
        self.high_freq_bound = 1.0 / self.high_freq_days if self.high_freq_days > 0 else 0.0

        self.order_id_counter = 0
        self.entry_time = None
        self.position_direction = None

        self.log(
            f"SpectralTrendStrategy Init. Window: {self.spectral_window}, Dynamic ATR SL: {self.sl_atr_multiplier}x")

    def on_open_bar(self, open_price: float, symbol: str):
        """
        Per-bar entry point: compute the market state, then either look for an
        entry (flat) or manage the open position.

        Args:
            open_price: Open of the current bar; used as the entry limit price
                and as the current price for stop checks.
            symbol: Symbol of the bar; stored on ``self.symbol`` and used for
                the position lookup and order ids.
        """
        self.symbol = symbol
        bar_history = self.get_bar_history()

        # NOTE(review): pending orders are cancelled by main_symbol while the
        # position lookup and new orders use ``symbol`` - confirm these match.
        self.cancel_all_pending_orders(self.main_symbol)

        # Need enough history for both the STFT window and the ATR.
        required_len = max(self.spectral_window, self.sl_atr_period + 5)
        if len(bar_history) < required_len:
            return

        # NOTE: the max-holding-time exit is disabled, so ``max_hold_days`` is
        # currently unused. Re-enabling it would look like:
        # if self.entry_time and (self.get_current_time() - self.entry_time) >= timedelta(days=self.max_hold_days):
        #     self.close_all_positions(reason="MaxHoldDays")
        #     return

        closes = np.array([b.close for b in bar_history[-self.spectral_window:]], dtype=float)
        # ATR is refreshed on every bar so the stop distance tracks volatility.
        current_atr = self._latest_atr(bar_history)

        trend_strength, trend_slope = self.calculate_market_state(closes)

        position_volume = self.get_current_positions().get(self.symbol, 0)

        if not self.trading:
            return
        if position_volume == 0:
            self.evaluate_entry_signal(open_price, trend_strength, trend_slope)
        else:
            self.manage_open_position(position_volume, trend_strength, trend_slope, open_price, current_atr)

    def _latest_atr(self, bar_history) -> float:
        """
        Return the most recent ATR over the trailing ``sl_atr_period + 10`` bars.

        Returns 0.0 when talib raises (the error is logged) or the value is not
        finite; the stop-loss logic only acts on an ATR strictly above zero.
        """
        recent = bar_history[-(self.sl_atr_period + 10):]
        highs = np.array([b.high for b in recent], dtype=float)
        lows = np.array([b.low for b in recent], dtype=float)
        closes = np.array([b.close for b in recent], dtype=float)
        try:
            atr_values = talib.ATR(highs, lows, closes, timeperiod=self.sl_atr_period)
            current_atr = float(atr_values[-1])
        except Exception as e:
            self.log(f"ATR Calculation Error: {e}")
            return 0.0
        return current_atr if np.isfinite(current_atr) else 0.0

    def calculate_market_state(self, prices: np.ndarray) -> Tuple[float, float]:
        """
        Measure trend strength and slope over the last ``spectral_window`` prices.

        The window is z-scored and passed through a single-segment STFT
        (``fs = bars_per_day``). The power of the latest STFT column is split
        into a low band (``f < low_freq_bound``) and a high band
        (``f > high_freq_bound``); frequencies between the bounds are ignored.

        Args:
            prices: 1-D prices, most recent last.

        Returns:
            ``(trend_strength, slope)``. ``trend_strength`` is
            ``low_energy / (low_energy + high_energy)`` (in ``[0, 1]``) and
            ``slope`` is the least-squares slope of the z-scored window per
            bar. ``(0.0, 0.0)`` when there is too little data, the window has
            non-finite values, or the STFT fails.
        """
        if len(prices) < self.spectral_window:
            return 0.0, 0.0

        window_data = np.asarray(prices[-self.spectral_window:], dtype=float)

        # NaN/inf would poison every statistic below and can make np.polyfit
        # raise; treat such a window as "no signal" like the other fallbacks.
        if not np.all(np.isfinite(window_data)):
            self.log("Market State Error: non-finite prices in window")
            return 0.0, 0.0

        # z-score so spectrum and slope are scale-free; epsilon guards a flat window.
        normalized = (window_data - np.mean(window_data)) / (np.std(window_data) + 1e-8)

        try:
            f, _, Zxx = stft(
                normalized,
                fs=self.bars_per_day,
                nperseg=self.spectral_window,
                noverlap=max(0, self.spectral_window // 2),
                boundary=None,
                padded=False,
            )
        except Exception as e:
            self.log(f"STFT Calculation Error: {e}")
            return 0.0, 0.0

        # Keep only frequencies in [0, Nyquist].
        valid_mask = (f >= 0) & (f <= self.bars_per_day / 2)
        f = f[valid_mask]
        Zxx = Zxx[valid_mask, :]
        if Zxx.size == 0 or Zxx.shape[1] == 0:
            return 0.0, 0.0

        # Power spectrum of the most recent STFT segment.
        current_energy = np.abs(Zxx[:, -1]) ** 2

        # Summing an empty selection yields 0.0, so empty bands need no special case.
        low_energy = float(current_energy[f < self.low_freq_bound].sum())
        high_energy = float(current_energy[f > self.high_freq_bound].sum())
        trend_strength = low_energy / (low_energy + high_energy + 1e-8)

        slope, _intercept = np.polyfit(np.arange(len(normalized)), normalized, 1)
        return trend_strength, float(slope)

    def evaluate_entry_signal(self, open_price: float, trend_strength: float, trend_slope: float):
        """
        Entry logic: choose a direction and send an opening limit order.

        No stop price is computed here; stops are evaluated on every bar in
        ``manage_open_position``.

        Args:
            open_price: Limit price for the opening order.
            trend_strength: Low-band energy share from ``calculate_market_state``.
            trend_slope: Slope of the z-scored window.
        """
        # Written as `not >` so a NaN strength never passes the gate.
        if not trend_strength > self.trend_strength_threshold:
            return

        direction = None
        if "BUY" in self.order_direction and trend_slope > self.slope_threshold:
            direction = "BUY"
        elif "SELL" in self.order_direction and trend_slope < -self.slope_threshold:
            direction = "SELL"
        if not direction:
            return

        if not self.indicators.is_condition_met(*self.get_indicator_tuple()):
            return

        # A failed model filter inverts the trade rather than blocking it.
        if not self.model_indicator.is_condition_met(*self.get_indicator_tuple()):
            direction = "SELL" if direction == "BUY" else "BUY"
        if self.reverse:
            direction = "SELL" if direction == "BUY" else "BUY"
        # NOTE(review): the flips above are not re-checked against
        # order_direction, so a BUY-only config can still send a SELL.

        self.log(f"Signal: {direction} | Strength={trend_strength:.2f} | Slope={trend_slope:.4f}")

        self.send_limit_order(direction, open_price, self.trade_volume, "OPEN")
        # Recorded at order submission, not at fill.
        self.entry_time = self.get_current_time()
        self.position_direction = "LONG" if direction == "BUY" else "SHORT"

    def manage_open_position(self, volume: int, trend_strength: float, trend_slope: float, current_price: float,
                             current_atr: float):
        """
        Exit logic for an open position; sends at most one close order per call.

        Checks, in order, and stops at the first that fires:
          1. holding-indicator veto,
          2. dynamic ATR stop around the average entry price
             (long: avg - N*ATR, short: avg + N*ATR),
          3. trend strength below ``exit_threshold``.

        Args:
            volume: Signed position size (> 0 long, < 0 short). 0 is a no-op.
            trend_strength: Low-band energy share from ``calculate_market_state``.
            trend_slope: Unused; kept for interface compatibility.
            current_price: Price compared against the stop.
            current_atr: Latest ATR; the stop is skipped unless it is > 0.
        """
        self.log(f'trend_strength: {trend_strength:.2f}')
        if volume == 0:
            return

        avg_entry_price = self.get_average_position_price(self.symbol)
        close_dir = "CLOSE_LONG" if volume > 0 else "CLOSE_SHORT"

        # --- 1. Holding-indicator veto ---
        if not self.holding_indicators.is_condition_met(*self.get_indicator_tuple()):
            self.log(f"Exit (Holding Indicator): {close_dir} | Strength={trend_strength:.2f}")
            self._exit_position(close_dir, volume)
            # Return so the checks below cannot send a second close for the same volume.
            return

        # --- 2. Dynamic ATR stop (needs a valid ATR and average entry price) ---
        if current_atr > 0 and avg_entry_price is not None and avg_entry_price > 0:
            sl_distance = current_atr * self.sl_atr_multiplier
            if volume > 0:
                stop_price = avg_entry_price - sl_distance
                hit = current_price <= stop_price
            else:
                stop_price = avg_entry_price + sl_distance
                hit = current_price >= stop_price

            if hit:
                self.log(
                    f"ATR STOP LOSS: {close_dir} | Current={current_price:.2f} | AvgEntry={avg_entry_price:.2f} | ATR={current_atr:.2f} | StopPrice={stop_price:.2f}")
                self._exit_position(close_dir, volume)
                return

        # --- 3. Trend-strength exit (original energy logic) ---
        if trend_strength < self.exit_threshold:
            self.log(f"Exit (Signal): {close_dir} | Strength={trend_strength:.2f} < {self.exit_threshold}")
            self._exit_position(close_dir, volume)

    # --- Trading helpers ---
    def _exit_position(self, direction: str, volume: int):
        """Send a market close for ``abs(volume)`` and clear the entry bookkeeping."""
        self.close_position(direction, abs(volume))
        self.entry_time = None
        self.position_direction = None

    def close_all_positions(self, reason=""):
        """Market-close any non-zero position in ``self.symbol``, logging ``reason``."""
        volume = self.get_current_positions().get(self.symbol, 0)
        if volume != 0:
            exit_dir = "CLOSE_LONG" if volume > 0 else "CLOSE_SHORT"
            self.log(f"Close All ({reason}): {exit_dir}")
            self._exit_position(exit_dir, volume)

    def close_position(self, direction: str, volume: int):
        """Close ``volume`` via a market order with offset ``"CLOSE"``."""
        self.send_market_order(direction, volume, offset="CLOSE")

    def _next_order_id(self, direction: str, kind: str) -> str:
        """Build ``{symbol}_{direction}_{kind}_{n}`` and advance the counter."""
        order_id = f"{self.symbol}_{direction}_{kind}_{self.order_id_counter}"
        self.order_id_counter += 1
        return order_id

    def send_market_order(self, direction: str, volume: int, offset: str):
        """Build a MARKET ``Order`` for ``self.symbol`` and submit it via ``send_order``."""
        order = Order(
            id=self._next_order_id(direction, "MKT"), symbol=self.symbol, direction=direction, volume=volume,
            price_type="MARKET", submitted_time=self.get_current_time(), offset=offset
        )
        self.send_order(order)

    def send_limit_order(self, direction: str, limit_price: float, volume: int, offset: str):
        """Build a LIMIT ``Order`` at ``limit_price`` for ``self.symbol`` and submit it."""
        order = Order(
            id=self._next_order_id(direction, "LMT"), symbol=self.symbol, direction=direction, volume=volume,
            price_type="LIMIT", submitted_time=self.get_current_time(), offset=offset,
            limit_price=limit_price
        )
        self.send_order(order)
|