# NewQuant/futures_trading_strategies/hc/KalmanTrendFollower/KalmanMeanReversion.py

import numpy as np
import talib
from typing import Optional, Any, List, Dict
from src.core_data import Bar, Order
from src.strategies.base_strategy import Strategy
class KalmanMeanReversion(Strategy):
    """Improved Kalman-filter mean-reversion strategy.

    1. Take profit at the slow Kalman line (the estimated mean), i.e. true
       mean reversion.
    2. Strict ATR-based stop loss (guards against a trend breakout).
    3. Time-decay exit: close after ``max_hold_bars`` bars in the trade.
    4. Trend-slope filter. NOTE(review): not implemented in this class -- the
       fast line is updated every bar but never used for decisions; entries
       are gated only by the optional external ``indicator``.
    """

    def __init__(
        self,
        context: Any,
        main_symbol: str,
        enable_log: bool,
        trade_volume: int,
        fast_sensitivity: float = 0.05,  # fast-line sensitivity (Q as a fraction of R)
        slow_sensitivity: float = 0.01,  # slow-line sensitivity (smoother; used as the mean)
        lookback_variance: int = 60,
        atr_period: int = 20,
        entry_threshold: float = 1.2,  # deviation from the slow line, in ATRs, to enter
        stop_loss_multiplier: float = 2.0,  # stop distance from entry, in ATRs
        max_hold_bars: int = 30,  # maximum number of bars to hold a position
        indicator: Any = None,
        warmup_bars: int = 100,
    ):
        """Configure the strategy.

        Args:
            context: Passed through to ``Strategy.__init__``.
            main_symbol: Passed through to ``Strategy.__init__``.
            enable_log: Passed through to ``Strategy.__init__``.
            trade_volume: Default order volume used for entries.
            fast_sensitivity: Fast filter process noise as a fraction of the
                rolling variance (its measurement noise is the variance itself).
            slow_sensitivity: Slow filter process noise as a fraction of the
                rolling variance (its measurement noise is 10x the variance).
            lookback_variance: Number of recent closes used for the variance.
            atr_period: ``talib.ATR`` time period.
            entry_threshold: Enter when ``|close - slow_line| / ATR`` exceeds this.
            stop_loss_multiplier: Stop when price moves this many ATRs past entry.
            max_hold_bars: Close after this many bars in a position.
            indicator: Optional external gate; if given, entries require
                ``indicator.is_condition_met(*self.get_indicator_tuple())``.
            warmup_bars: Minimum bar-history length before any processing
                (raised automatically to cover the variance and ATR windows).
        """
        super().__init__(context, main_symbol, enable_log)
        self.trade_volume = trade_volume
        self.fast_sensitivity = fast_sensitivity
        self.slow_sensitivity = slow_sensitivity
        self.lookback_variance = lookback_variance
        self.atr_period = atr_period
        self.entry_threshold = entry_threshold
        self.stop_loss_multiplier = stop_loss_multiplier
        self.max_hold_bars = max_hold_bars
        self.indicator = indicator
        self.warmup_bars = warmup_bars
        # Filter state: 'x' is the estimate, 'P' its error variance.
        self.kf_fast = {'x': 0.0, 'P': 1.0}
        self.kf_slow = {'x': 0.0, 'P': 1.0}
        self.kalman_initialized = False
        # Position bookkeeping; entry_price == 0.0 means "not yet known".
        self.entry_price = 0.0
        self.entry_bar_count = 0
        self.order_id_counter = 0

    def _update_kalman(self, state: dict, measurement: float, Q: float, R: float) -> float:
        """Run one predict/update step of a scalar Kalman filter in place.

        Uses an identity state transition: predict adds process noise ``Q``
        to the error variance, then the update blends in ``measurement``
        with measurement noise ``R``.

        Args:
            state: Dict with keys 'x' (estimate) and 'P' (error variance);
                mutated in place.
            measurement: New observation.
            Q: Process noise variance.
            R: Measurement noise variance.

        Returns:
            The updated estimate ``state['x']``.
        """
        p_minus = state['P'] + Q
        k_gain = p_minus / (p_minus + R)
        state['x'] = state['x'] + k_gain * (measurement - state['x'])
        state['P'] = (1 - k_gain) * p_minus
        return state['x']

    def on_open_bar(self, open_price: float, symbol: str):
        """Per-bar handler: update the filters, then enter or exit.

        Signals are computed from the bar history (its latest close), not
        from ``open_price``; all orders are MARKET orders.

        Args:
            open_price: Open of the new bar (not used by the signal logic).
            symbol: Instrument to trade.
        """
        bar_history = self.get_bar_history()
        # Need enough bars for the variance window and a valid ATR value.
        min_bars = max(self.warmup_bars, self.lookback_variance, self.atr_period + 1)
        if len(bar_history) < min_bars:
            return
        self.cancel_all_pending_orders()

        closes = np.array([b.close for b in bar_history], dtype=float)
        last_price = closes[-1]

        # 1. Adaptive noise level from the rolling variance of closes.
        rolling_var = np.var(closes[-self.lookback_variance:])
        r_base = rolling_var if rolling_var > 0 else 1.0

        # 2. Seed both filters on the first eligible bar; start trading next bar.
        if not self.kalman_initialized:
            self.kf_fast['x'] = self.kf_slow['x'] = last_price
            self.kalman_initialized = True
            return

        # The fast line tracks price. It is not used for decisions, but its
        # state is kept current. The slow line (small Q, large R) is the mean.
        self._update_kalman(self.kf_fast, last_price, r_base * self.fast_sensitivity, r_base)
        slow_line = self._update_kalman(
            self.kf_slow, last_price, r_base * self.slow_sensitivity, r_base * 10.0
        )

        # 3. ATR for normalising the deviation and sizing the stop.
        highs = np.array([b.high for b in bar_history], dtype=float)
        lows = np.array([b.low for b in bar_history], dtype=float)
        atr = talib.ATR(highs, lows, closes, self.atr_period)[-1]
        # NaN <= 0 is False, so a NaN ATR must be rejected explicitly.
        if not np.isfinite(atr) or atr <= 0:
            return

        # 4. Position logic.
        pos = self.get_current_positions().get(symbol, 0)
        if pos == 0:
            self._try_enter(symbol, last_price, slow_line, atr)
        else:
            self._try_exit(symbol, pos, last_price, slow_line, atr)

    def _try_enter(self, symbol: str, last_price: float, slow_line: float, atr: float) -> None:
        """Open a position when price deviates far enough from the slow line."""
        # Deviation from the mean in ATR units (a Z-score variant).
        diff_in_atr = (last_price - slow_line) / atr
        # Enter only when the external indicator (if any) allows it.
        can_trade = self.indicator is None or self.indicator.is_condition_met(
            *self.get_indicator_tuple()
        )
        if not can_trade:
            return
        if diff_in_atr > self.entry_threshold:
            # Stretched above the mean -> short.
            self.execute_order(symbol, "SELL", "OPEN", last_price)
        elif diff_in_atr < -self.entry_threshold:
            # Stretched below the mean -> long.
            self.execute_order(symbol, "BUY", "OPEN", last_price)

    def _try_exit(self, symbol: str, pos: Any, last_price: float, slow_line: float, atr: float) -> None:
        """Advance the holding counter and close the position if an exit fires."""
        self.entry_bar_count += 1
        is_long = pos > 0
        # Entry price is unknown if this instance did not open the position
        # (e.g. after a restart); without a reference, a short would hit the
        # stop immediately. Use the current price as the reference instead.
        if self.entry_price == 0.0:
            self.entry_price = last_price

        exit_reason = self._exit_reason(is_long, last_price, slow_line, atr)
        if not exit_reason:
            return

        label = "CLOSE_LONG" if is_long else "CLOSE_SHORT"
        self.log(f"EXIT {label}: Price={last_price}, Reason={exit_reason}")
        # Close with the opposite side, using the same BUY/SELL + OPEN/CLOSE
        # convention as entries, and flatten the full held size.
        side = "SELL" if is_long else "BUY"
        self.execute_order(symbol, side, "CLOSE", last_price, volume=abs(pos))

    def _exit_reason(self, is_long: bool, last_price: float, slow_line: float, atr: float) -> str:
        """Return the exit reason for the open position, or "" to keep holding."""
        # A. Mean-reversion take profit: price touched or crossed the slow line.
        if is_long:
            reached_mean = last_price >= slow_line
        else:
            reached_mean = last_price <= slow_line
        if reached_mean:
            return "Take Profit: Reached Mean"

        # B. Stop loss: price moved more than stop_loss_multiplier ATRs
        #    against the position, measured from the entry price.
        stop_distance = self.stop_loss_multiplier * atr
        if is_long:
            stopped = last_price < self.entry_price - stop_distance
        else:
            stopped = last_price > self.entry_price + stop_distance
        if stopped:
            return "Stop Loss: Deviation Too Large"

        # C. Time exit: mean reversion is time-sensitive, so leave once the
        #    holding period runs out.
        if self.entry_bar_count >= self.max_hold_bars:
            return "Time Exit: Holding Too Long"
        return ""

    def execute_order(
        self,
        symbol: str,
        direction: str,
        offset: str,
        price: float,
        volume: Optional[int] = None,
    ) -> None:
        """Build and send a MARKET order, recording entry state on opens.

        Args:
            symbol: Instrument to trade.
            direction: Order side; this class uses "BUY" / "SELL".
            offset: "OPEN" or "CLOSE".
            price: Reference price. On "OPEN" it is stored as ``entry_price``
                for the ATR stop. It is not sent with the order (MARKET).
            volume: Order volume; defaults to ``self.trade_volume``.
        """
        if offset == "OPEN":
            # Reference price for the stop; the actual market fill may differ.
            self.entry_price = price
            self.entry_bar_count = 0
        order_id = f"{symbol}_{direction}_{self.order_id_counter}"
        self.order_id_counter += 1
        order = Order(
            id=order_id,
            symbol=symbol,
            direction=direction,
            volume=self.trade_volume if volume is None else volume,
            price_type="MARKET",
            offset=offset,
        )
        self.send_order(order)