# NewQuant/futures_trading_strategies/hc/KalmanTrendFollower/KalmanTrendFollower.py
# 191 lines, 7.7 KiB, Python
#
# Note: this file contains ambiguous Unicode characters that might be confused
# with other characters. If this is intentional, the warning can be ignored.
import numpy as np
import pandas as pd
import talib
from collections import deque
from typing import Optional, Any, List, Dict
from src.core_data import Bar, Order
from src.indicators.base_indicators import Indicator
from src.strategies.base_strategy import Strategy
import numpy as np
import talib
from typing import Optional, Any, List, Dict
from src.core_data import Bar, Order
from src.strategies.base_strategy import Strategy
class KalmanTrendFollower(Strategy):
    """Trend follower driven by a fast/slow pair of scalar Kalman filters.

    On every bar both filters are fed the latest close. The spread between the
    fast and slow estimates, expressed in ATR units, drives entries: a long is
    opened when it exceeds ``entry_threshold`` and a short when it is below
    ``-entry_threshold``. Open positions are closed when price breaches the
    tighter of a chandelier (trailing) stop and a structural stop anchored on
    the slow filter line.
    """

    # Minimum bar count before the strategy starts doing anything.
    _MIN_WARMUP_BARS = 100
    # Measurement noise of the slow filter relative to the rolling price
    # variance; larger values make the slow line trust new prices less.
    _SLOW_MEASUREMENT_NOISE_SCALE = 5.0

    def __init__(
        self,
        context: Any,
        main_symbol: str,
        enable_log: bool,
        trade_volume: int,
        # --- Strongly reduced sensitivity to filter out intraday noise ---
        fast_sensitivity: float = 0.05,  # Lowered: keeps the fast line steady as well
        slow_sensitivity: float = 0.005,  # Very low: the slow line only tracks the major trend direction
        lookback_variance: int = 60,  # Larger window for a more stable market-noise estimate
        # --- Trend-following parameters ---
        atr_period: int = 23,
        entry_threshold: float = 0.5,  # Fast/slow spread must exceed 0.5 x ATR to consider entry
        trailing_stop_multiplier: float = 4.0,  # Key: 4 x ATR trailing stop leaves room for the trend
        structural_stop_multiplier: float = 1.0,  # ATRs beyond the slow line price must break to exit
        indicator: Optional[Indicator] = None,
    ):
        """Configure the strategy.

        Args:
            context: Passed through to ``Strategy.__init__``.
            main_symbol: Passed through to ``Strategy.__init__``.
            enable_log: Passed through to ``Strategy.__init__``.
            trade_volume: Volume used for every entry order.
            fast_sensitivity: Process-noise factor (times rolling variance) of the fast filter.
            slow_sensitivity: Process-noise factor (times rolling variance) of the slow filter.
            lookback_variance: Number of closes used for the rolling variance.
            atr_period: ATR period passed to ``talib.ATR``.
            entry_threshold: Minimum |fast - slow| / ATR required to enter.
            trailing_stop_multiplier: ATR multiple for the chandelier stop.
            structural_stop_multiplier: ATR multiple beyond the slow line for the structural stop.
            indicator: Optional extra entry filter; entries are only allowed when
                ``indicator.is_condition_met(*self.get_indicator_tuple())`` is truthy.
        """
        super().__init__(context, main_symbol, enable_log)
        self.trade_volume = trade_volume
        # Parameters
        self.fast_sensitivity = fast_sensitivity
        self.slow_sensitivity = slow_sensitivity
        self.lookback_variance = lookback_variance
        self.atr_period = atr_period
        self.entry_threshold = entry_threshold
        self.trailing_stop_multiplier = trailing_stop_multiplier
        self.structural_stop_multiplier = structural_stop_multiplier
        self.indicator = indicator
        # State: 'x' is the filter estimate, 'P' its error variance.
        self.kf_fast = {'x': 0.0, 'P': 1.0}
        self.kf_slow = {'x': 0.0, 'P': 1.0}
        self.kalman_initialized = False
        # Per-symbol exit-tracking info written when an OPEN order is sent.
        self.position_meta: Dict[str, Any] = {}
        self.order_id_counter = 0

    def _update_kalman(self, state: dict, measurement: float, Q: float, R: float) -> float:
        """Run one predict/update step of a scalar Kalman filter in place.

        The state is modelled as a random walk: the prediction keeps ``x`` and
        adds process noise ``Q`` to the error variance, then the measurement
        (with noise variance ``R``) is blended in by the Kalman gain.

        Returns:
            The updated estimate ``state['x']``.
        """
        p_minus = state['P'] + Q
        k_gain = p_minus / (p_minus + R)
        state['x'] = state['x'] + k_gain * (measurement - state['x'])
        state['P'] = (1 - k_gain) * p_minus
        return state['x']

    def on_open_bar(self, open_price: float, symbol: str):
        """Update both filters on the latest close, then enter or manage a position.

        Args:
            open_price: Opening price of the new bar; used as the entry limit price.
            symbol: Instrument being traded.
        """
        bar_history = self.get_bar_history()
        # Enough bars for the warm-up, a full variance window and a finite ATR
        # (talib's ATR output is NaN for its first ``atr_period`` values).
        min_bars = max(self._MIN_WARMUP_BARS, self.atr_period + 1, self.lookback_variance)
        if len(bar_history) < min_bars:
            return
        self.cancel_all_pending_orders()

        closes = np.array([b.close for b in bar_history], dtype=float)
        last_price = closes[-1]

        # 1. Derive the Kalman noise parameters from the recent price variance;
        #    a larger base noise R gives smoother lines.
        rolling_var = np.var(closes[-self.lookback_variance:])
        r_base = rolling_var if rolling_var > 0 else 1.0

        # The first usable bar only seeds both filters at the current price.
        if not self.kalman_initialized:
            self.kf_fast['x'] = self.kf_slow['x'] = last_price
            self.kalman_initialized = True
            return

        fast_line = self._update_kalman(self.kf_fast, last_price, r_base * self.fast_sensitivity, r_base)
        slow_line = self._update_kalman(
            self.kf_slow,
            last_price,
            r_base * self.slow_sensitivity,
            r_base * self._SLOW_MEASUREMENT_NOISE_SCALE,
        )

        # 2. ATR and trend strength.
        highs = np.array([b.high for b in bar_history], dtype=float)
        lows = np.array([b.low for b in bar_history], dtype=float)
        atr = talib.ATR(highs, lows, closes, self.atr_period)[-1]
        if not np.isfinite(atr):
            # With a NaN ATR neither the entry threshold nor the stops are meaningful.
            return
        diff_in_atr = (fast_line - slow_line) / atr if atr > 0 else 0

        # 3. Position management.
        pos = self.get_current_positions().get(symbol, 0)
        if pos == 0:
            # Flat: only look for entries; there is nothing to exit.
            if self.indicator is None or self.indicator.is_condition_met(*self.get_indicator_tuple()):
                # Entry requires a clear divergence between the fast and slow lines.
                if diff_in_atr > self.entry_threshold:
                    self.open_trade(symbol, "BUY", open_price, atr, slow_line)
                elif diff_in_atr < -self.entry_threshold:
                    self.open_trade(symbol, "SELL", open_price, atr, slow_line)
            return

        # In a position: protect fat-tail profits.
        self.manage_exit(symbol, pos, last_price, atr, slow_line)

    def open_trade(self, symbol, direction, price, atr, slow_line):
        """Send a limit entry order and record the meta used for exit tracking.

        Args:
            symbol: Instrument to trade.
            direction: "BUY" or "SELL".
            price: Limit price; also the initial entry/extreme price.
            atr: ATR at entry time.
            slow_line: Currently unused; kept for interface compatibility.
        """
        # 'extreme_price' tracks the highest (long) / lowest (short) price seen
        # while holding, for the chandelier stop.
        meta = {
            'entry_price': price,
            'extreme_price': price,
            'direction': direction,
            'initial_atr': atr,
        }
        self.send_limit_order(symbol, direction, price, self.trade_volume, "OPEN", meta)
        self.log(f"TREND ENTRY {direction}: Price={price}, ATR={atr:.2f}")

    def _ensure_position_meta(self, symbol, is_long, price, atr):
        """Return exit-tracking meta for an open position, rebuilding it if missing or stale."""
        expected_direction = "BUY" if is_long else "SELL"
        meta = self.position_meta.get(symbol)
        if meta and meta.get('direction') == expected_direction:
            return meta
        # No meta (e.g. it was never recorded) or meta left over from an order
        # on the other side: start tracking from the current bar so the
        # position still gets stop protection instead of being ignored.
        meta = {
            'entry_price': price,
            'extreme_price': price,
            'direction': expected_direction,
            'initial_atr': atr,
        }
        self.position_meta[symbol] = meta
        self.save_state(self.position_meta)
        self.log(f"REBUILT EXIT META {expected_direction}: Price={price}")
        return meta

    def manage_exit(self, symbol, pos, price, atr, slow_line):
        """Close ``pos`` when price breaches the tighter of the trailing and structural stops.

        Args:
            symbol: Instrument holding the position.
            pos: Signed position size (positive = long, negative = short); 0 is a no-op.
            price: Latest close, used for stop checks and extreme-price tracking.
            atr: Current ATR.
            slow_line: Current slow Kalman estimate.
        """
        if pos == 0:
            # Nothing to protect; avoids sending a zero-volume close order.
            return
        is_long = pos > 0
        meta = self._ensure_position_meta(symbol, is_long, price, atr)

        # Update the extreme price reached while holding (for the chandelier stop).
        if is_long:
            extreme = max(meta['extreme_price'], price)
        else:
            extreme = min(meta['extreme_price'], price)
        if extreme != meta['extreme_price']:
            meta['extreme_price'] = extreme
            # Persist like the order helpers do after every meta change.
            self.save_state(self.position_meta)

        if is_long:
            # Chandelier stop: N x ATR below the highest price since entry.
            chandelier_stop = extreme - self.trailing_stop_multiplier * atr
            # Structural stop: a set distance below the slow trend line.
            structural_stop = slow_line - self.structural_stop_multiplier * atr
            # Use the tighter (higher) of the two stops.
            should_close = price < max(chandelier_stop, structural_stop)
        else:
            chandelier_stop = extreme + self.trailing_stop_multiplier * atr
            structural_stop = slow_line + self.structural_stop_multiplier * atr
            # Use the tighter (lower) of the two stops.
            should_close = price > min(chandelier_stop, structural_stop)

        if not should_close:
            return
        exit_reason = "Trailing/Structural Break"
        direction = "CLOSE_LONG" if is_long else "CLOSE_SHORT"
        self.log(f"EXIT {direction}: Price={price}, Reason={exit_reason}")
        self.close_position(symbol, direction, abs(pos))

    # --- Low-level order helpers ---
    def _next_order_id(self, symbol, direction):
        """Return a unique order id of the form ``{symbol}_{direction}_{n}`` and advance the counter."""
        order_id = f"{symbol}_{direction}_{self.order_id_counter}"
        self.order_id_counter += 1
        return order_id

    def send_market_order(self, symbol, direction, volume, offset, meta=None):
        """Send a market order; for OPEN orders, record ``meta`` for exit tracking."""
        if offset == "OPEN":
            self.position_meta[symbol] = meta
        order = Order(id=self._next_order_id(symbol, direction), symbol=symbol, direction=direction,
                      volume=volume, price_type="MARKET", offset=offset)
        self.send_order(order)
        self.save_state(self.position_meta)

    def send_limit_order(self, symbol, direction, price, volume, offset, meta=None):
        """Send a limit order at ``price``; for OPEN orders, record ``meta`` for exit tracking."""
        if offset == "OPEN":
            self.position_meta[symbol] = meta
        order = Order(id=self._next_order_id(symbol, direction), symbol=symbol, direction=direction,
                      volume=volume, price_type="LIMIT", offset=offset, limit_price=price)
        self.send_order(order)
        self.save_state(self.position_meta)

    def close_position(self, symbol, direction, volume):
        """Send a market CLOSE order and drop the symbol's exit-tracking meta."""
        order = Order(id=self._next_order_id(symbol, direction), symbol=symbol, direction=direction,
                      volume=volume, price_type="MARKET", offset="CLOSE")
        self.send_order(order)
        self.position_meta.pop(symbol, None)
        self.save_state(self.position_meta)