# Python source file (160 lines, 6.1 KiB)
from typing import Optional, Any, List, Dict

import numpy as np
import talib

from src.core_data import Bar, Order
from src.strategies.base_strategy import Strategy


class KalmanMeanReversion(Strategy):
    """Improved Kalman-filter mean-reversion strategy.

    Design notes carried over from the original author:

    1. The slow Kalman line is the take-profit target (true mean reversion).
    2. A strict ATR-based stop loss guards against trend breakouts.
    3. A time-based exit closes positions that have been held too long.
    4. Trend-slope filtering.

    NOTE(review): no trend-slope filter is visible in this class; the fast
    line is updated but never consulted. Presumably the optional
    ``indicator`` is meant to play that role -- confirm.
    """

    # Number of bars of history required before the strategy acts at all.
    MIN_HISTORY_BARS = 100
    # The slow filter's measurement noise is this multiple of the base noise.
    SLOW_NOISE_SCALE = 10.0

    def __init__(
        self,
        context: Any,
        main_symbol: str,
        enable_log: bool,
        trade_volume: int,
        fast_sensitivity: float = 0.05,  # fast-line sensitivity
        slow_sensitivity: float = 0.01,  # slow-line sensitivity (smoother; used as the mean)
        lookback_variance: int = 60,
        atr_period: int = 20,
        entry_threshold: float = 1.2,  # deviation from the slow line, in ATRs, required to enter
        stop_loss_multiplier: float = 2.0,  # stop-loss distance as a multiple of ATR
        max_hold_bars: int = 30,  # maximum number of bars to hold a position
        indicator: Any = None,
    ):
        """Store the strategy parameters and reset filter/position state.

        Args:
            context: Forwarded unchanged to ``Strategy.__init__``.
            main_symbol: Forwarded unchanged to ``Strategy.__init__``.
            enable_log: Forwarded unchanged to ``Strategy.__init__``.
            trade_volume: Volume placed on every order this strategy sends.
            fast_sensitivity: Process noise of the fast filter, as a fraction
                of the rolling close variance.
            slow_sensitivity: Process noise of the slow filter, as a fraction
                of the rolling close variance.
            lookback_variance: Number of most recent closes used for the
                rolling variance.
            atr_period: Period passed to ``talib.ATR``.
            entry_threshold: Minimum distance between the last close and the
                slow line, measured in ATRs, that triggers an entry.
            stop_loss_multiplier: Adverse move from the recorded entry price,
                measured in ATRs, that triggers the stop loss.
            max_hold_bars: Number of bars after which a position is closed
                regardless of price.
            indicator: Optional gate object. When given, its
                ``is_condition_met(*self.get_indicator_tuple())`` must be
                truthy for an entry to be taken.
        """
        super().__init__(context, main_symbol, enable_log)
        self.trade_volume = trade_volume
        self.fast_sensitivity = fast_sensitivity
        self.slow_sensitivity = slow_sensitivity
        self.lookback_variance = lookback_variance
        self.atr_period = atr_period
        self.entry_threshold = entry_threshold
        self.stop_loss_multiplier = stop_loss_multiplier
        self.max_hold_bars = max_hold_bars

        self.indicator = indicator

        # Filter state: 'x' is the estimate, 'P' the estimate variance.
        self.kf_fast = {'x': 0.0, 'P': 1.0}
        self.kf_slow = {'x': 0.0, 'P': 1.0}
        self.kalman_initialized = False

        # Bookkeeping for the currently open position.
        self.entry_price = 0.0
        self.entry_bar_count = 0
        self.order_id_counter = 0

    def _update_kalman(self, state: dict, measurement: float, Q: float, R: float) -> float:
        """Run one scalar Kalman predict/update step and return the estimate.

        ``state`` is mutated in place: ``state['x']`` receives the new
        estimate and ``state['P']`` the new estimate variance.

        Args:
            state: Dict with keys ``'x'`` (estimate) and ``'P'`` (variance).
            measurement: The newly observed value.
            Q: Process-noise variance added during prediction.
            R: Measurement-noise variance.

        Returns:
            The updated estimate ``state['x']``.
        """
        # Predict: the estimate carries over unchanged, its variance grows by Q.
        p_minus = state['P'] + Q
        # Update: blend the measurement in proportionally to the Kalman gain.
        k_gain = p_minus / (p_minus + R)
        state['x'] = state['x'] + k_gain * (measurement - state['x'])
        state['P'] = (1 - k_gain) * p_minus
        return state['x']

    def _exit_reason(self, is_long: bool, last_price: float, slow_line: float, atr: float) -> Optional[str]:
        """Return the reason the open position should be closed, or ``None``.

        Rules are checked in priority order: take profit at the slow line,
        then the ATR stop loss, then the maximum holding time.
        """
        # A. Mean-reversion take profit: price touched or crossed the slow line.
        if is_long and last_price >= slow_line:
            return "Take Profit: Reached Mean"
        if not is_long and last_price <= slow_line:
            return "Take Profit: Reached Mean"

        # B. Stop loss: price moved further against the recorded entry price.
        stop_distance = self.stop_loss_multiplier * atr
        if is_long and last_price < self.entry_price - stop_distance:
            return "Stop Loss: Deviation Too Large"
        if not is_long and last_price > self.entry_price + stop_distance:
            return "Stop Loss: Deviation Too Large"

        # C. Time exit: mean reversion is time-sensitive, so leave once the
        #    position has been held for too many bars.
        if self.entry_bar_count >= self.max_hold_bars:
            return "Time Exit: Holding Too Long"

        return None

    def on_open_bar(self, open_price: float, symbol: str) -> None:
        """Update the filters and decide whether to enter or exit on this bar.

        Does nothing until ``MIN_HISTORY_BARS`` bars are available. The first
        call after that only seeds both filters with the last close. Decisions
        use the last close from the bar history; ``open_price`` is not used
        here.

        Args:
            open_price: Opening price of the new bar (unused by this method).
            symbol: Symbol whose position is inspected and traded.
        """
        bar_history = self.get_bar_history()
        if len(bar_history) < self.MIN_HISTORY_BARS:
            return

        self.cancel_all_pending_orders()
        closes = np.array([b.close for b in bar_history], dtype=float)
        last_price = closes[-1]

        # 1. Adaptive noise level from the rolling variance of closes;
        #    falls back to 1.0 when the variance is not positive.
        rolling_var = np.var(closes[-self.lookback_variance:])
        r_base = rolling_var if rolling_var > 0 else 1.0

        # 2. Seed the filters on first use, otherwise update them.
        if not self.kalman_initialized:
            self.kf_fast['x'] = self.kf_slow['x'] = last_price
            self.kalman_initialized = True
            return

        # The fast line tracks price; the slow line is the equilibrium level.
        # The fast estimate is kept up to date in self.kf_fast even though
        # the trading rules below only read the slow line.
        self._update_kalman(self.kf_fast, last_price, r_base * self.fast_sensitivity, r_base)
        slow_line = self._update_kalman(
            self.kf_slow,
            last_price,
            r_base * self.slow_sensitivity,
            r_base * self.SLOW_NOISE_SCALE,
        )

        # 3. ATR and the deviation from the slow line.
        highs = np.array([b.high for b in bar_history], dtype=float)
        lows = np.array([b.low for b in bar_history], dtype=float)
        atr = talib.ATR(highs, lows, closes, self.atr_period)[-1]

        # A non-positive ATR cannot be used as a divisor or stop distance.
        if atr <= 0:
            return

        # Distance of price from the slow line in ATR units (a z-score variant).
        diff_in_atr = (last_price - slow_line) / atr

        # 4. Position logic.
        pos = self.get_current_positions().get(symbol, 0)

        if pos == 0:
            # --- Entry ---
            # Enter only when the external indicator (if any) allows it and
            # the deviation is large enough.
            can_trade = self.indicator is None or self.indicator.is_condition_met(*self.get_indicator_tuple())
            if not can_trade:
                return

            if diff_in_atr > self.entry_threshold:
                # Overbought: go short.
                self.execute_order(symbol, "SELL", "OPEN", last_price)
            elif diff_in_atr < -self.entry_threshold:
                # Oversold: go long.
                self.execute_order(symbol, "BUY", "OPEN", last_price)
            return

        # --- Exit ---
        self.entry_bar_count += 1
        is_long = pos > 0
        exit_reason = self._exit_reason(is_long, last_price, slow_line, atr)

        if exit_reason is not None:
            direction = "CLOSE_LONG" if is_long else "CLOSE_SHORT"
            self.log(f"EXIT {direction}: Price={last_price}, Reason={exit_reason}")
            self.execute_order(symbol, direction, "CLOSE", last_price)

    def execute_order(self, symbol: str, direction: str, offset: str, price: float) -> None:
        """Build a market order, send it, and update the position bookkeeping.

        When ``offset`` is ``"OPEN"`` the given ``price`` is recorded as the
        entry price and the holding-bar counter is reset. ``price`` is not
        placed on the order itself, which is sent with ``price_type="MARKET"``.

        NOTE(review): the recorded entry price is whatever the caller passes
        (the last close in ``on_open_bar``), not an actual fill price.

        Args:
            symbol: Symbol to trade.
            direction: Direction string placed on the order.
            offset: Offset string placed on the order; ``"OPEN"`` also resets
                the entry bookkeeping.
            price: Reference price recorded as the entry price on opens.
        """
        if offset == "OPEN":
            self.entry_price = price
            self.entry_bar_count = 0

        # Order ids are made unique by a per-strategy running counter.
        order_id = f"{symbol}_{direction}_{self.order_id_counter}"
        self.order_id_counter += 1

        order = Order(
            id=order_id,
            symbol=symbol,
            direction=direction,
            volume=self.trade_volume,
            price_type="MARKET",
            offset=offset,
        )
        self.send_order(order)