191 lines
7.7 KiB
Python
191 lines
7.7 KiB
Python
import numpy as np
|
||
import pandas as pd
|
||
import talib
|
||
from collections import deque
|
||
from typing import Optional, Any, List, Dict
|
||
from src.core_data import Bar, Order
|
||
from src.indicators.base_indicators import Indicator
|
||
from src.strategies.base_strategy import Strategy
|
||
import numpy as np
|
||
import talib
|
||
from typing import Optional, Any, List, Dict
|
||
from src.core_data import Bar, Order
|
||
from src.strategies.base_strategy import Strategy
|
||
|
||
|
||
class KalmanTrendFollower(Strategy):
    """Trend follower built on a fast and a slow one-dimensional Kalman filter.

    Both filters smooth the close price. The distance between the fast and
    the slow estimate, expressed in ATR units, is the entry signal. Open
    positions are protected by the tighter of two stops: an ATR trailing
    ("chandelier") stop anchored at the most favourable price seen while the
    position was held, and a structural stop placed at an ATR distance
    beyond the slow line.

    The helpers ``get_bar_history``, ``cancel_all_pending_orders``,
    ``get_current_positions``, ``get_indicator_tuple``, ``send_order``,
    ``save_state`` and ``log`` are not defined in this class; presumably
    they are inherited from ``Strategy`` -- verify against the base class.
    """

    # Number of bars that must be available before the strategy does anything.
    MIN_HISTORY_BARS = 100
    # The slow filter's measurement noise is this multiple of the rolling
    # close-price variance, which makes the slow line smoother.
    SLOW_NOISE_MULTIPLIER = 5.0

    def __init__(
        self,
        context: Any,
        main_symbol: str,
        enable_log: bool,
        trade_volume: int,
        # --- Low sensitivities on purpose: filter out intraday noise ---
        fast_sensitivity: float = 0.05,    # lowered so the fast line is steadier too
        slow_sensitivity: float = 0.005,   # very low: slow line only tracks the major trend
        lookback_variance: int = 60,       # longer window for a more stable noise estimate
        # --- Trend-following parameters ---
        atr_period: int = 23,
        entry_threshold: float = 0.5,      # fast/slow gap must exceed 0.5 ATR to enter
        trailing_stop_multiplier: float = 4.0,    # 4 ATR trailing stop leaves room for the trend
        structural_stop_multiplier: float = 1.0,  # ATR distance beyond the slow line that forces an exit
        indicator: Optional[Indicator] = None,
    ):
        """Store the parameters and reset the filter and position state.

        Args:
            context: Passed through to ``Strategy.__init__``.
            main_symbol: Passed through to ``Strategy.__init__``.
            enable_log: Passed through to ``Strategy.__init__``.
            trade_volume: Volume used for every entry order.
            fast_sensitivity: Process noise of the fast filter, as a fraction
                of the rolling close-price variance.
            slow_sensitivity: Process noise of the slow filter, as a fraction
                of the rolling close-price variance.
            lookback_variance: Number of most recent closes used for the
                rolling variance.
            atr_period: Period handed to ``talib.ATR``.
            entry_threshold: Minimum ``|fast - slow| / ATR`` needed to enter.
            trailing_stop_multiplier: ATR multiple of the trailing stop.
            structural_stop_multiplier: ATR multiple of the slow-line stop.
            indicator: Optional entry filter. When given, entries are only
                considered if ``indicator.is_condition_met(...)`` is truthy.
        """
        super().__init__(context, main_symbol, enable_log)
        self.trade_volume = trade_volume

        # Parameters
        self.fast_sensitivity = fast_sensitivity
        self.slow_sensitivity = slow_sensitivity
        self.lookback_variance = lookback_variance
        self.atr_period = atr_period
        self.entry_threshold = entry_threshold
        self.trailing_stop_multiplier = trailing_stop_multiplier
        self.structural_stop_multiplier = structural_stop_multiplier
        self.indicator = indicator

        # Filter state: 'x' is the estimate, 'P' the estimate variance.
        self.kf_fast = {'x': 0.0, 'P': 1.0}
        self.kf_slow = {'x': 0.0, 'P': 1.0}
        self.kalman_initialized = False

        # Per-symbol metadata of the current (or pending) position.
        self.position_meta: Dict[str, Any] = {}
        self.order_id_counter = 0

    def _update_kalman(self, state: dict, measurement: float, Q: float, R: float) -> float:
        """Advance a scalar Kalman filter by one step, mutating ``state``.

        Args:
            state: Dict with keys ``'x'`` (estimate) and ``'P'`` (variance).
            measurement: New observation.
            Q: Process-noise variance added in the predict step.
            R: Measurement-noise variance.

        Returns:
            The updated estimate ``state['x']``.
        """
        predicted_var = state['P'] + Q
        gain = predicted_var / (predicted_var + R)
        state['x'] = state['x'] + gain * (measurement - state['x'])
        state['P'] = (1 - gain) * predicted_var
        return state['x']

    def on_open_bar(self, open_price: float, symbol: str):
        """Update the filters and either manage the open position or look for an entry.

        Signals are computed from the closes in the bar history; ``open_price``
        is only used as the limit price of a new entry order.

        Args:
            open_price: Limit price for a new entry order.
            symbol: Symbol whose position is looked up and traded.
        """
        bar_history = self.get_bar_history()
        if len(bar_history) < self.MIN_HISTORY_BARS:
            return

        self.cancel_all_pending_orders()

        closes = np.array([b.close for b in bar_history], dtype=float)
        last_price = closes[-1]

        # 1. Derive the filter noise terms from the rolling close variance.
        #    A non-positive (or NaN) variance falls back to 1.0.
        rolling_var = np.var(closes[-self.lookback_variance:])
        r_base = rolling_var if rolling_var > 0 else 1.0

        # The first eligible bar only seeds both filters; no trading happens.
        if not self.kalman_initialized:
            self.kf_fast['x'] = self.kf_slow['x'] = last_price
            self.kalman_initialized = True
            return

        fast_line = self._update_kalman(
            self.kf_fast, last_price, r_base * self.fast_sensitivity, r_base
        )
        slow_line = self._update_kalman(
            self.kf_slow,
            last_price,
            r_base * self.slow_sensitivity,
            r_base * self.SLOW_NOISE_MULTIPLIER,
        )

        # 2. ATR and the fast/slow divergence measured in ATR units.
        highs = np.array([b.high for b in bar_history], dtype=float)
        lows = np.array([b.low for b in bar_history], dtype=float)
        atr = talib.ATR(highs, lows, closes, self.atr_period)[-1]

        diff = fast_line - slow_line
        # A NaN, infinite or non-positive ATR yields a zero signal (no entry).
        atr_usable = bool(np.isfinite(atr)) and atr > 0
        diff_in_atr = diff / atr if atr_usable else 0

        # 3. Position handling.
        pos = self.get_current_positions().get(symbol, 0)

        if pos != 0:
            # --- Exit logic: protect the gains of a running trend ---
            self.manage_exit(symbol, pos, last_price, atr, slow_line)
            return

        # Flat from here on. Exit management is deliberately NOT run for a
        # flat book: leftover position metadata would otherwise be treated as
        # a short position and trigger a zero-volume close order.
        if self.indicator is not None and not self.indicator.is_condition_met(
            *self.get_indicator_tuple()
        ):
            return

        # --- Entry logic: the lines must have clearly diverged ---
        if diff_in_atr > self.entry_threshold:
            self.open_trade(symbol, "BUY", open_price, atr, slow_line)
        elif diff_in_atr < -self.entry_threshold:
            self.open_trade(symbol, "SELL", open_price, atr, slow_line)

    def open_trade(self, symbol, direction, price, atr, slow_line):
        """Submit a limit entry order and record the position metadata.

        Args:
            symbol: Symbol to trade.
            direction: ``"BUY"`` or ``"SELL"`` as passed by ``on_open_bar``.
            price: Limit price; also the initial entry and extreme price.
            atr: ATR at entry, stored as ``'initial_atr'`` and logged.
            slow_line: Currently unused; kept for interface compatibility.
        """
        # 'extreme_price' tracks the highest (long) or lowest (short) price
        # reached while the position is held; it anchors the trailing stop.
        meta = {
            'entry_price': price,
            'extreme_price': price,
            'direction': direction,
            'initial_atr': atr
        }
        self.send_limit_order(symbol, direction, price, self.trade_volume, "OPEN", meta)
        self.log(f"TREND ENTRY {direction}: Price={price}, ATR={atr:.2f}")

    def manage_exit(self, symbol, pos, price, atr, slow_line):
        """Update the trailing extreme and close the position if a stop is hit.

        Args:
            symbol: Symbol of the position.
            pos: Signed position size; positive is treated as long, negative
                as short. Zero is ignored.
            price: Price compared against the stops.
            atr: Current ATR used to size both stops.
            slow_line: Current slow Kalman estimate.
        """
        # Nothing to manage (and nothing to close) for a flat position.
        if pos == 0:
            return

        meta = self.position_meta.get(symbol)
        if not meta:
            return

        is_long = pos > 0
        trail_distance = self.trailing_stop_multiplier * atr
        structural_distance = self.structural_stop_multiplier * atr

        # NOTE(review): the updated 'extreme_price' is only persisted the next
        # time save_state() runs (on the next order) -- confirm that is intended.
        if is_long:
            meta['extreme_price'] = max(meta['extreme_price'], price)
            # Chandelier stop: N ATR below the highest price seen.
            chandelier_stop = meta['extreme_price'] - trail_distance
            # Structural stop: a set ATR distance below the slow trend line.
            structural_stop = slow_line - structural_distance
            # The higher (tighter) of the two stops protects the long.
            should_close = price < max(chandelier_stop, structural_stop)
        else:
            meta['extreme_price'] = min(meta['extreme_price'], price)
            # Mirror image for shorts: the lower (tighter) stop applies.
            chandelier_stop = meta['extreme_price'] + trail_distance
            structural_stop = slow_line + structural_distance
            should_close = price > min(chandelier_stop, structural_stop)

        if not should_close:
            return

        exit_reason = "Trailing/Structural Break"
        direction = "CLOSE_LONG" if is_long else "CLOSE_SHORT"
        self.log(f"EXIT {direction}: Price={price}, Reason={exit_reason}")
        self.close_position(symbol, direction, abs(pos))

    # --- Low-level order helpers ---
    # position_meta is persisted as a whole, so 'extreme_price' travels with it.

    def _submit_order(self, symbol, direction, volume, price_type, offset, **order_fields):
        """Build an ``Order`` with the next sequential id and send it.

        Extra keyword arguments (e.g. ``limit_price``) are forwarded to the
        ``Order`` constructor unchanged.
        """
        order_id = f"{symbol}_{direction}_{self.order_id_counter}"
        self.order_id_counter += 1
        order = Order(id=order_id, symbol=symbol, direction=direction,
                      volume=volume, price_type=price_type, offset=offset,
                      **order_fields)
        self.send_order(order)

    def send_market_order(self, symbol, direction, volume, offset, meta=None):
        """Send a market order; for ``offset == "OPEN"`` record ``meta`` first.

        The position metadata is saved via ``save_state`` after sending.
        """
        if offset == "OPEN":
            self.position_meta[symbol] = meta

        self._submit_order(symbol, direction, volume, "MARKET", offset)
        self.save_state(self.position_meta)

    def send_limit_order(self, symbol, direction, price, volume, offset, meta=None):
        """Send a limit order at ``price``; for ``offset == "OPEN"`` record ``meta`` first.

        The position metadata is saved via ``save_state`` after sending.
        """
        # NOTE(review): meta is stored when the order is submitted, not when
        # it fills, so it can outlive an entry order that never executes.
        if offset == "OPEN":
            self.position_meta[symbol] = meta

        self._submit_order(symbol, direction, volume, "LIMIT", offset, limit_price=price)
        self.save_state(self.position_meta)

    def close_position(self, symbol, direction, volume):
        """Send a market ``"CLOSE"`` order, then drop and persist the metadata."""
        self._submit_order(symbol, direction, volume, "MARKET", "CLOSE")
        self.position_meta.pop(symbol, None)
        self.save_state(self.position_meta)