1、卡尔曼策略

This commit is contained in:
2025-11-07 16:26:00 +08:00
parent 9358dba814
commit 2eec6452ee
42 changed files with 4680 additions and 23709 deletions

File diff suppressed because one or more lines are too long

View File

@@ -1,181 +1,157 @@
import numpy as np
import pandas as pd
import talib
from typing import Optional, Dict, Any, List
from datetime import datetime
from typing import Optional, Any, List, Dict
from src.core_data import Bar, Order
from src.indicators.base_indicators import Indicator
from src.indicators.indicators import Empty
from src.strategies.base_strategy import Strategy
class RsiStrategy(Strategy):
# =============================================================================
# 策略实现 (SimpleRSIStrategy)
# =============================================================================
class SimpleRSIStrategy(Strategy):
"""
反转策略RSI 2-24 → PCA → 模型预测 → 极端值反向开仓
开仓:下一根 Open 价挂限价单
平仓:满足以下任一条件后市价平仓
1. 价格触及固定价差止损线
2. 持有满 holding_bars 根 K 线
一个最简单的RSI交易策略。
核心哲学:
1. 动量跟踪: 遵循RSI指标反映的市场动量。RSI高于60被视为强势市场
适合做多低于40被视为弱势市场适合做空。
2. 信号持续则持仓: 与传统的“超买/超卖”反转逻辑不同,本策略认为只要
市场维持强势RSI > 60或弱势RSI < 40就应该继续持有仓位
直到出现明确的反转信号。
3. 简单的风险控制: 不设置固定的止盈目标,让利润自由发展。只设置一个
基于ATR的动态止损用于管理最基本的下行风险。
"""
def __init__(
self,
context: Any,
main_symbol: str,
model: Any,
pca: Any,
scaler: Any,
lower_bound: float,
upper_bound: float,
trade_volume: int = 1,
enable_log: bool,
trade_volume: int,
# --- 【RSI 指标参数】 ---
rsi_period: int = 14,
rsi_threshold: int = 70,
# --- 【风险管理】 ---
atr_period: int = 14,
stop_loss_atr_multiplier: float = 1.5,
# --- 其他 ---
order_direction: Optional[List[str]] = None,
holding_bars: int = 5,
# --- MODIFICATION START ---
stop_loss_points: Optional[float] = 5, # 止损点数, e.g., 50.0 for 50个价格点
# --- MODIFICATION END ---
enable_log: bool = False,
use_talib: bool = True,
indicators: Optional[List[Indicator]] = None,
):
super().__init__(context, main_symbol, enable_log)
self.main_symbol = main_symbol
if order_direction is None: order_direction = ['BUY', 'SELL']
if indicators is None: indicators = [Empty(), Empty()] # 保持与模板的兼容性
# --- 参数赋值 ---
self.trade_volume = trade_volume
self.model = model
self.pca = pca
self.scaler = scaler
self.lower_bound = lower_bound
self.upper_bound = upper_bound
self.order_direction = order_direction or ["BUY", "SELL"]
self.holding_bars = holding_bars
# --- MODIFICATION START ---
self.stop_loss_points = stop_loss_points
# --- MODIFICATION END ---
self.use_talib = use_talib
self.rsi_period = rsi_period
self.rsi_threshold = rsi_threshold
self.rsi_upper_threshold = self.rsi_threshold
self.rsi_lower_threshold = 100 - self.rsi_threshold
self.atr_period = atr_period
self.stop_loss_atr_multiplier = stop_loss_atr_multiplier
self.order_direction = order_direction
self.close_cache: List[float] = []
self.cache_size = 500
self.pos_meta: Dict[str, Dict[str, Any]] = {}
# --- 内部状态变量 ---
self.main_symbol = main_symbol
self.order_id_counter = 0
self.indicators = indicators # 保持与模板的兼容性
# --- MODIFICATION START ---
log_message = (
f"RsiPcaReversalStrategy 初始化:\n"
f"交易量={self.trade_volume}, lower={self.lower_bound}, upper={self.upper_bound}\n"
f"时间出场={self.holding_bars} bars\n"
f"固定价差止损={self.stop_loss_points if self.stop_loss_points else 'N/A'} points"
)
self.log(log_message)
# --- MODIFICATION END ---
# ... (工具函数保持不变) ...
def update_close_cache(self, bar_history: List[Bar]) -> None:
self.close_cache = [b.close for b in bar_history[-self.cache_size:]]
def calc_rsi_vector(self) -> np.ndarray:
close = np.array(self.close_cache, dtype=float)
rsi_vec = []
for i in range(2, 25):
if self.use_talib:
# talib 版本(最快)
rsi = talib.RSI(close, timeperiod=i)[-1]
else:
# 原滚动均值版本(与旧代码逻辑完全一致)
gain = np.where(np.diff(close) > 0, np.diff(close), 0)
loss = np.where(np.diff(close) < 0, -np.diff(close), 0)
avg_gain = pd.Series(gain).rolling(window=i, min_periods=i).mean().iloc[-1]
avg_loss = pd.Series(loss).rolling(window=i, min_periods=i).mean().iloc[-1]
rs = avg_gain / avg_loss if avg_loss != 0 else 100
rsi = 100 - (100 / (1 + rs))
rsi_vec.append(rsi)
return np.array(rsi_vec)
def predict_ret5(self, rsi_vec: np.ndarray) -> float:
vec_std = self.scaler.transform(rsi_vec.reshape(1, -1))
vec_pca = self.pca.transform(vec_std)
return self.model.predict(vec_pca)[0]
def on_init(self):
super().on_init()
self.pos_meta.clear()
self.log(f"SimpleRSIStrategy Initialized")
def on_open_bar(self, open_price: float, symbol: str):
"""每根K线开盘时被调用"""
self.symbol = symbol
bar_history = self.get_bar_history()
if len(bar_history) < 30:
# 需要足够的数据来计算指标
if len(bar_history) < self.rsi_period + 5:
return
self.cancel_all_pending_orders(symbol)
pos = self.get_current_positions().get(symbol, 0)
position_volume = self.get_current_positions().get(self.symbol, 0)
# 1. 更新缓存 & 计算特征
self.update_close_cache(bar_history)
rsi_vec = self.calc_rsi_vector()
if np.isnan(rsi_vec).any():
return
pred = self.predict_ret5(rsi_vec)
if not self.trading: return
# 2. 平仓逻辑
meta = self.pos_meta.get(symbol)
if meta and pos:
exit_reason = None
# 计算最新的RSI和ATR值
closes = np.array([b.close for b in bar_history], dtype=float)
highs = np.array([b.high for b in bar_history], dtype=float)
lows = np.array([b.low for b in bar_history], dtype=float)
# --- MODIFICATION START ---
# 2a. 检查固定价差止损
if self.stop_loss_points is not None:
current_price = open_price # 使用当前bar的开盘价作为检查价格
entry_price = meta['entry_price']
direction = meta['direction']
current_rsi = talib.RSI(closes, self.rsi_period)[-1]
current_atr = talib.ATR(highs, lows, closes, self.atr_period)[-1]
if direction == "BUY" and current_price <= entry_price - self.stop_loss_points:
exit_reason = f"Stop Loss Hit ({entry_price - self.stop_loss_points})"
elif direction == "SELL" and current_price >= entry_price + self.stop_loss_points:
exit_reason = f"Stop Loss Hit ({entry_price + self.stop_loss_points})"
# --- MODIFICATION END ---
current_bar = bar_history[-1]
# 2b. 检查时间出
if not exit_reason and len(bar_history) >= meta['expiry_idx']:
exit_reason = "Time Expiry"
# 核心逻辑:管理仓位或评估入
if position_volume != 0:
self.manage_open_position(position_volume, current_bar, current_rsi, current_atr)
else:
self.evaluate_entry_signal(open_price, current_rsi, current_atr)
if exit_reason:
self.log(f"平仓信号触发: {exit_reason}")
self.send_market_order(
"CLOSE_LONG" if meta['direction'] == "BUY" else "CLOSE_SHORT",
meta['volume'],
)
del self.pos_meta[symbol]
return
def manage_open_position(self, volume: int, current_bar: Bar, current_rsi: float, current_atr: float):
"""管理持仓:检查止损和反向信号平仓"""
# 3. 开仓逻辑 (保持不变)
if pos == 0:
is_long = volume > 0
# stop_loss_price = meta['stop_loss_price']
# # 1. 检查止损
# if (is_long and current_bar.low <= stop_loss_price) or \
# (not is_long and current_bar.high >= stop_loss_price):
# self.log(f"ATR Stop Loss Hit at {stop_loss_price:.4f}")
# self.close_position("CLOSE_LONG" if is_long else "CLOSE_SHORT", abs(volume))
# return
# 2. 检查反向信号平仓
if is_long and current_rsi < self.rsi_lower_threshold:
self.log(f"RSI dropped below {self.rsi_lower_threshold}, closing long position.")
self.close_position("CLOSE_LONG", abs(volume))
elif not is_long and current_rsi > self.rsi_upper_threshold:
self.log(f"RSI rose above {self.rsi_upper_threshold}, closing short position.")
self.close_position("CLOSE_SHORT", abs(volume))
def evaluate_entry_signal(self, open_price: float, current_rsi: float, current_atr: float):
"""评估入场信号"""
direction = None
# 做多信号
if "BUY" in self.order_direction and current_rsi > self.rsi_upper_threshold and self.indicators[
0].is_condition_met(*self.get_indicator_tuple()):
direction = "BUY"
# 做空信号
elif "SELL" in self.order_direction and current_rsi < self.rsi_lower_threshold and self.indicators[
1].is_condition_met(*self.get_indicator_tuple()):
direction = "SELL"
if direction and current_atr > 0:
entry_price = open_price
if pred < self.lower_bound and "SELL" in self.order_direction:
self.send_limit_order("SELL", entry_price, self.trade_volume, bar_history[-1].datetime)
elif pred > self.upper_bound and "BUY" in self.order_direction:
self.send_limit_order("BUY", entry_price, self.trade_volume, bar_history[-1].datetime)
def send_open_order(self, direction: str, limit_price: float, volume: int, entry_dt: Any):
# (此函数逻辑已在上个版本中更新记录entry_price保持不变)
order_id = f"{self.symbol}_{direction}_{entry_dt.strftime('%Y%m%d%H%M%S')}"
order = Order(
id=order_id, symbol=self.symbol, direction=direction, volume=volume,
price_type="MARKET", submitted_time=entry_dt, offset="OPEN",
)
self.send_order(order)
self.pos_meta[self.symbol] = {
"direction": direction,
"volume": volume,
"expiry_idx": len(self.get_bar_history()) + self.holding_bars,
"entry_price": limit_price
}
self.log(f"开仓信号: {direction} at {limit_price}")
self.log(f"RSI Signal: {direction}. RSI={current_rsi:.2f}. ")
def send_market_order(self, direction: str, volume: int):
# ... (此函数保持不变) ...
order_id = f"{self.symbol}_{direction}_{self.get_current_time().strftime('%Y%m%d%H%M%S')}"
order = Order(
id=order_id, symbol=self.symbol, direction=direction, volume=volume,
price_type="MARKET", submitted_time=self.get_current_time(), offset="CLOSE",
)
self.send_market_order(direction, self.trade_volume, "OPEN")
# --- 辅助函数区 (与模板保持一致) ---
def on_init(self):
super().on_init()
self.cancel_all_pending_orders(self.main_symbol)
def close_position(self, direction: str, volume: int):
self.send_market_order(direction, volume, offset="CLOSE")
def send_market_order(self, direction: str, volume: int, offset: str):
order_id = f"{self.symbol}_{direction}_MARKET_{self.order_id_counter}"
self.order_id_counter += 1
order = Order(id=order_id, symbol=self.symbol, direction=direction, volume=volume, price_type="MARKET",
submitted_time=self.get_current_time(), offset=offset)
self.send_order(order)
def on_rollover(self, old_symbol: str, new_symbol: str):
# ... (此函数保持不变) ...
super().on_rollover(old_symbol, new_symbol)
self.cancel_all_pending_orders(new_symbol)
self.pos_meta.clear()
# 主力合约换月时,清空所有状态,避免旧数据干扰
self.log("Rollover detected. All strategy states have been reset.")

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -1,13 +1,11 @@
import numpy as np
import pandas as pd
from typing import Optional, Dict, Any, List, Union
from typing import Optional, Dict, Any, List
# 假设这些是你项目中的模块
from src.core_data import Bar, Order
from src.indicators.base_indicators import Indicator
from src.indicators.indicators import Empty
from src.strategies.base_strategy import Strategy
from src.algo.TrendLine import calculate_latest_trendline_values
from src.algo.TrendLine import calculate_latest_trendline_values, calculate_latest_trendline_values_v2
class TrendlineHawkesStrategy(Strategy):
@@ -31,7 +29,6 @@ class TrendlineHawkesStrategy(Strategy):
hawkes_entry_percent: float = 0.95,
hawkes_exit_percent: float = 0.50,
enable_log: bool = True,
indicators: Union[Indicator, List[Indicator]] = None,
):
super().__init__(context, main_symbol, enable_log)
# ... 参数赋值与V3完全相同 ...
@@ -45,9 +42,6 @@ class TrendlineHawkesStrategy(Strategy):
self.hawkes_entry_percent = hawkes_entry_percent
self.hawkes_exit_percent = hawkes_exit_percent
self.pos_meta: Dict[str, Dict[str, Any]] = {}
if indicators is None:
indicators = [Empty(), Empty()]
self.indicators = indicators
# --- 【核心修改】状态缓存重构 ---
# 只缓存上一个时间点的霍克斯强度值 (未缩放)
@@ -146,13 +140,13 @@ class TrendlineHawkesStrategy(Strategy):
if pos == 0:
close_prices = np.array([b.close for b in bar_history])
prices_for_trendline = close_prices[-self.trendline_n - 1:-1]
trend_upper, trend_lower = calculate_latest_trendline_values(prices_for_trendline)
trend_upper, trend_lower = calculate_latest_trendline_values_v2(prices_for_trendline)
if trend_upper is not None and trend_lower is not None:
prev_close = bar_history[-2].close
last_close = bar_history[-1].close
upper_break_event = last_close > trend_upper and prev_close < trend_upper and self.indicators[0].is_condition_met(*self.get_indicator_tuple())
lower_break_event = last_close < trend_lower and prev_close > trend_lower and self.indicators[1].is_condition_met(*self.get_indicator_tuple())
upper_break_event = last_close > trend_upper and prev_close < trend_upper
lower_break_event = last_close < trend_lower and prev_close > trend_lower
hawkes_confirmation = latest_hawkes_value > latest_hawkes_upper
if hawkes_confirmation and (upper_break_event or lower_break_event):

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -49,6 +49,51 @@ def compute_price_volume_distribution(bars: List[Bar], tick_size: float) -> Opti
return df.groupby('price')['volume'].sum().sort_index()
# =====================================================================================
# 以下是最终的、以性能为首要考量的、超高速VP计算模块
# =====================================================================================
def compute_fast_volume_profile(bars: List[Bar], tick_size: float) -> Optional[pd.Series]:
"""
[全局核心函数] 使用“离散重心累加”法超高速地从K线数据中构建成交量剖面图。
该方法将每根K线的全部成交量一次性地归于其加权中心价
(HLC/3)所对齐的tick上。这在保持核心逻辑精确性的同时
实现了计算速度的数量级提升。
Args:
bars: 用于计算的K线历史数据。
tick_size: 合约的最小变动价位。
Returns:
一个代表成交量剖面图的Pandas Series或None。
"""
if not bars:
return None
# 使用字典进行累加这是Python中最快的操作之一
volume_dict = {}
for bar in bars:
if bar.volume == 0:
continue
# 1. 计算K线的加权中心价
center_price = (bar.high + bar.low + bar.close) / 3
# 2. 将中心价对齐到最近的tick
aligned_price = round(center_price / tick_size) * tick_size
# 3. 将该Bar的全部成交量一次性累加到对齐后的价格点上
volume_dict[aligned_price] = volume_dict.get(aligned_price, 0) + bar.volume
if not volume_dict:
return None
# 将最终结果转换为一个有序的Pandas Series
return pd.Series(volume_dict).sort_index()
# 确保在文件顶部导入
from scipy.signal import find_peaks
@@ -152,7 +197,7 @@ class ValueMigrationStrategy(Strategy):
enable_log: bool,
trade_volume: int,
tick_size: float = 1,
profile_period: int = 100,
profile_period: int = 60,
recalc_interval: int = 4,
hvn_distance_ticks: int = 20,
entry_offset_atr: float = 0.0,
@@ -214,6 +259,7 @@ class ValueMigrationStrategy(Strategy):
if self._bar_counter % self.recalc_interval == 1:
profile_bars = bar_history[-self.profile_period:]
dist = compute_price_volume_distribution(profile_bars, self.tick_size)
# dist = compute_fast_volume_profile(profile_bars, self.tick_size)
if dist is not None and not dist.empty:
# self._cached_hvns = find_hvns_with_distance(dist, self.hvn_distance_ticks)
self._cached_hvns = find_hvns_strict(dist, self.hvn_distance_ticks)
@@ -222,7 +268,7 @@ class ValueMigrationStrategy(Strategy):
if not self._cached_hvns: return
# --- 4. 评估新机会 (挂单逻辑) ---
self.evaluate_entry_signal(bar_history)
self.evaluate_entry_signal(bar_history, open_price)
def manage_open_position(self, volume: int, current_price: float):
"""主动管理已开仓位的止盈止损。"""
@@ -254,7 +300,7 @@ class ValueMigrationStrategy(Strategy):
self.log(f"空头{action}触发 at {current_price:.2f}")
self.close_position("CLOSE_SHORT", abs(volume))
def evaluate_entry_signal(self, bar_history: List[Bar]):
def evaluate_entry_signal(self, bar_history: List[Bar], current_price: float):
prev_close = bar_history[-2].close
current_close = bar_history[-1].close
@@ -267,20 +313,37 @@ class ValueMigrationStrategy(Strategy):
for hvn in sorted(self._cached_hvns):
# (为了简洁,买卖逻辑合并)
direction = None
if "BUY" in self.order_direction and (prev_close < hvn < current_close):
# if "BUY" in self.order_direction and (prev_close < hvn < current_close):
# direction = "BUY"
# pass_filter = self.indicator_long is None or self.indicator_long.is_condition_met(
# *self.get_indicator_tuple())
# elif "SELL" in self.order_direction and (prev_close > hvn > current_close):
# direction = "SELL"
# pass_filter = self.indicator_short is None or self.indicator_short.is_condition_met(
# *self.get_indicator_tuple())
if "BUY" in self.order_direction and (current_close > hvn and bar_history[-1].low == hvn):
direction = "BUY"
pass_filter = self.indicator_long is None or self.indicator_long.is_condition_met(
*self.get_indicator_tuple())
elif "SELL" in self.order_direction and (prev_close > hvn > current_close):
elif "SELL" in self.order_direction and (current_close < hvn and bar_history[-1].high == hvn):
direction = "SELL"
pass_filter = self.indicator_short is None or self.indicator_short.is_condition_met(
*self.get_indicator_tuple())
# if "BUY" in self.order_direction and (current_close > hvn and bar_history[-1].low == hvn):
# direction = "SELL"
# pass_filter = self.indicator_long is None or self.indicator_long.is_condition_met(
# *self.get_indicator_tuple())
# elif "SELL" in self.order_direction and (current_close < hvn and bar_history[-1].high == hvn):
# direction = "BUY"
# pass_filter = self.indicator_short is None or self.indicator_short.is_condition_met(
# *self.get_indicator_tuple())
else:
continue # 没有触发穿越
if direction and pass_filter:
offset = self.entry_offset_atr * current_atr
limit_price = hvn + offset if direction == "BUY" else hvn - offset
# limit_price = hvn + offset if direction == "BUY" else hvn - offset
limit_price = current_price
self.log(f"价格穿越HVN({hvn:.2f}). 在 {limit_price:.2f} 挂限价{direction}单。")
self.send_hvn_limit_order(direction, limit_price, current_atr)
@@ -296,8 +359,8 @@ class ValueMigrationStrategy(Strategy):
order = Order(
id=order_id, symbol=self.symbol, direction=direction, volume=self.trade_volume,
price_type="LIMIT", limit_price=limit_price, submitted_time=self.get_current_time(),
offset="OPEN"
price_type="LIMIT", submitted_time=self.get_current_time(),
offset="OPEN", limit_price=limit_price
)
self.send_order(order)

View File

@@ -1,5 +1,5 @@
# =====================================================================================
# 以下是新增的 ValueMigrationStrategy 策略代码
# 以下是新增的 ValueMigrationStrategy 策略代码 (已按新需求修改)
# =====================================================================================
from collections import deque
from datetime import timedelta, time
@@ -49,6 +49,51 @@ def compute_price_volume_distribution(bars: List[Bar], tick_size: float) -> Opti
return df.groupby('price')['volume'].sum().sort_index()
# =====================================================================================
# 以下是最终的、以性能为首要考量的、超高速VP计算模块
# =====================================================================================
def compute_fast_volume_profile(bars: List[Bar], tick_size: float) -> Optional[pd.Series]:
"""
[全局核心函数] 使用“离散重心累加”法超高速地从K线数据中构建成交量剖面图。
该方法将每根K线的全部成交量一次性地归于其加权中心价
(HLC/3)所对齐的tick上。这在保持核心逻辑精确性的同时
实现了计算速度的数量级提升。
Args:
bars: 用于计算的K线历史数据。
tick_size: 合约的最小变动价位。
Returns:
一个代表成交量剖面图的Pandas Series或None。
"""
if not bars:
return None
# 使用字典进行累加这是Python中最快的操作之一
volume_dict = {}
for bar in bars:
if bar.volume == 0:
continue
# 1. 计算K线的加权中心价
center_price = (bar.high + bar.low + bar.close) / 3
# 2. 将中心价对齐到最近的tick
aligned_price = round(center_price / tick_size) * tick_size
# 3. 将该Bar的全部成交量一次性累加到对齐后的价格点上
volume_dict[aligned_price] = volume_dict.get(aligned_price, 0) + bar.volume
if not volume_dict:
return None
# 将最终结果转换为一个有序的Pandas Series
return pd.Series(volume_dict).sort_index()
# 确保在文件顶部导入
from scipy.signal import find_peaks
@@ -77,99 +122,114 @@ def find_hvns_with_distance(price_volume_dist: pd.Series, distance_in_ticks: int
return hvn_prices
# 确保在文件顶部导入
from scipy.signal import find_peaks
def find_hvns_strict(price_volume_dist: pd.Series, window_radius: int) -> List[float]:
"""
[全局函数] 使用严格的“滚动窗口最大值”定义来识别HVNs。
一个点是HVN当且仅当它的成交量大于其左右各 `window_radius` 个点的成交量。
Args:
price_volume_dist: 价格-成交量分布序列。
window_radius: 定义了检查窗口的半径 (即您所说的 N)。
Returns:
一个包含所有被识别出的HVN价格的列表。
"""
if price_volume_dist.empty or window_radius == 0:
return [price_volume_dist.idxmax()] if not price_volume_dist.empty else []
# 1. 确保价格序列是连续的用0填充缺失的ticks
full_price_range = np.arange(price_volume_dist.index.min(),
price_volume_dist.index.max() + price_volume_dist.index.to_series().diff().min(),
price_volume_dist.index.to_series().diff().min())
continuous_dist = price_volume_dist.reindex(full_price_range, fill_value=0)
# 2. 计算滚动窗口最大值
window_size = 2 * window_radius + 1
rolling_max = continuous_dist.rolling(window=window_size, center=True).max()
# 3. 找到那些自身成交量就等于其窗口最大值的点
is_hvn = (continuous_dist == rolling_max) & (continuous_dist > 0)
hvn_prices = continuous_dist[is_hvn].index.tolist()
# 4. 处理平顶山如果连续多个点都是HVN只保留中间那个
if not hvn_prices:
return [price_volume_dist.idxmax()] # 如果找不到返回POC
final_hvns = []
i = 0
while i < len(hvn_prices):
# 找到一个连续HVN块
j = i
while j + 1 < len(hvn_prices) and (hvn_prices[j + 1] - hvn_prices[j]) < (
2 * price_volume_dist.index.to_series().diff().min()):
j += 1
# 取这个连续块的中间点
middle_index = i + (j - i) // 2
final_hvns.append(hvn_prices[middle_index])
i = j + 1
return final_hvns
# =====================================================================================
# 以下是V2版本的、简化了状态管理的 HVNPullbackStrategy 代码
# =====================================================================================
# 引入必要的类型,确保代码清晰
from typing import Any, Dict, Optional, List
import numpy as np
import talib
class ValueMigrationStrategy(Strategy):
"""
一个基于动态HVN突破后回测的量化交易策略。(V3: 集成上下文状态管理)
一个基于动态HVN突破后回测的量化交易策略。(V2: 简化状态管理)
V3版本完全集成BacktestContext的状态管理功能实现了策略重启后的状态恢复。
- 状态被简化为两个核心变量_pending_sl_price 和 _pending_tp_price。
- 在策略初始化时安全地加载状态,并兼容空状态或旧版状态
- 在下单或平仓时立即持久化状态,确保数据一致性。
- 增加了逻辑检查,处理重启后可能出现的状态与实际持仓不一致的问题。
V2版本简化了内部状态管理移除了基于order_id的复杂元数据传递
使用更直接、更健壮的单一状态变量来处理挂单的止盈止损参数,
完美适配“单次单持仓”的策略逻辑
"""
def __init__(
self,
context: Any, # 通常会是 BacktestContext
context: Any,
main_symbol: str,
enable_log: bool,
trade_volume: int,
tick_size: float = 1,
profile_period: int = 100,
profile_period: int = 72,
recalc_interval: int = 4,
hvn_distance_ticks: int = 1,
entry_offset_atr: float = 0.0,
stop_loss_atr: float = 1.0,
take_profit_atr: float = 1.0,
hvn_distance_ticks: int = 20,
entry_offset_atr: float = 0.0, # 注意:此参数在新逻辑中不再用于决定挂单价格
stop_loss_atr: float = 0.5,
take_profit_atr: float = 2.0,
atr_period: int = 14,
order_direction=None,
indicators=[None, None],
):
super().__init__(context, main_symbol, enable_log)
# --- 参数初始化 (保持不变) ---
if order_direction is None:
order_direction = ['BUY', 'SELL']
self.trade_volume = trade_volume
self.tick_size = tick_size
self.profile_period = profile_period
self.recalc_interval = recalc_interval
self.hvn_distance_ticks = hvn_distance_ticks
self.entry_offset_atr = entry_offset_atr
self.entry_offset_atr = entry_offset_atr # 保留此参数以兼容旧配置,但进场逻辑已不使用
self.stop_loss_atr = stop_loss_atr
self.take_profit_atr = take_profit_atr
self.atr_period = atr_period
self.order_direction = order_direction
self.indicator_long = indicators[0]
self.indicator_short = indicators[1]
self.main_symbol = main_symbol
self.order_id_counter = 0
self._bar_counter = 0
self._cached_hvns: List[float] = []
# --- 新增: 初始化时加载状态 ---
# --- V2: 简化的状态管理 ---
self._pending_sl_price: Optional[float] = None
self._pending_tp_price: Optional[float] = None
self._load_state_from_context()
def _get_state_dict(self) -> Dict[str, Any]:
"""一个辅助函数,用于生成当前需要保存的状态字典。"""
return {
"_pending_sl_price": self._pending_sl_price,
"_pending_tp_price": self._pending_tp_price,
}
def _load_state_from_context(self):
"""
[新增] 从上下文中加载状态,并进行健壮性处理。
"""
loaded_state = self.context.load_state()
if not loaded_state:
self.log("未找到历史状态,进行全新初始化。")
return
# 使用 .get() 方法安全地读取即使key不存在或state为空也不会报错。
# 这完美解决了“读取的state的key不一样”的问题。
self._pending_sl_price = loaded_state.get("_pending_sl_price")
self._pending_tp_price = loaded_state.get("_pending_tp_price")
if self._pending_sl_price is not None:
self.log(f"成功从上下文加载状态: SL={self._pending_sl_price}, TP={self._pending_tp_price}")
else:
self.log("加载的状态为空或格式不兼容,视为全新初始化。")
def on_open_bar(self, open_price: float, symbol: str):
self.symbol = symbol
@@ -180,134 +240,134 @@ class ValueMigrationStrategy(Strategy):
if len(bar_history) < required_len:
return
# 取消所有挂单这符合原逻辑确保每根bar都是新的开始
# --- 1. 取消所有挂单并重置挂单状态 ---
self.cancel_all_pending_orders(self.symbol)
# --- 2. 管理现有持仓 ---
position_volume = self.get_current_positions().get(self.symbol, 0)
# --- 新增: 状态一致性检查 ---
# 场景:策略重启后,加载了之前的止盈止损状态,但发现实际上并没有持仓
# (可能因为上次平仓后、清空状态前程序就关闭了)。
# 这种情况下,状态是无效的“幽灵状态”,必须清除。
if position_volume == 0 and self._pending_sl_price is not None:
self.log("检测到状态与实际持仓不符 (有状态但无持仓),重置本地状态。")
self._pending_sl_price = None
self._pending_tp_price = None
self.context.save_state(self._get_state_dict()) # 立即同步清除后的状态
# --- 1. 管理现有持仓 (如果存在) ---
if position_volume != 0:
self.manage_open_position(position_volume, open_price)
return
# 周期性地计算HVNs
# --- 3. 周期性地计算HVNs ---
if self._bar_counter % self.recalc_interval == 1:
profile_bars = bar_history[-self.profile_period:]
dist = compute_price_volume_distribution(profile_bars, self.tick_size)
if dist is not None and not dist.empty:
self._cached_hvns = find_hvns_with_distance(dist, self.hvn_distance_ticks)
self.log(f"识别到新的高价值节点: {[f'{p:.2f}' for p in self._cached_hvns]}")
self._cached_hvns = find_hvns_strict(dist, self.hvn_distance_ticks)
self.log(f"New HVNs identified at: {[f'{p:.2f}' for p in self._cached_hvns]}")
if not self._cached_hvns: return
# 评估新机会 (挂单逻辑)
self.evaluate_entry_signal(bar_history)
# --- 4. 评估新机会 (挂单逻辑) ---
self.evaluate_entry_signal(bar_history, open_price)
def manage_open_position(self, volume: int, current_price: float):
"""
[修改] 主动管理已开仓位的止盈止损。
不再使用 position_meta直接依赖实例变量。
"""
# [关键安全检查]: 如果有持仓,但却没有止盈止损状态,这是一个危险的信号。
# 可能是状态文件损坏或逻辑错误。为控制风险,应立即平仓。
if self._pending_sl_price is None or self._pending_tp_price is None:
self.log("风险警告:存在持仓但无有效的止盈止损价格,立即市价平仓!")
self.close_position("CLOSE_LONG" if volume > 0 else "CLOSE_SHORT", abs(volume))
return
"""主动管理已开仓位的止盈止损。"""
sl_price = self._pending_sl_price
tp_price = self._pending_tp_price
# 止盈止损逻辑 (保持不变)
if sl_price is None or tp_price is None:
self.log("错误:持仓存在但找不到对应的止盈止损价格。")
self.close_position("CLOSE_LONG" if volume > 0 else "CLOSE_SHORT", abs(volume))
return
if volume > 0: # 多头
if current_price <= sl_price or current_price >= tp_price:
action = "止损" if current_price <= sl_price else "止盈"
self.log(f"多头{action}触发 {current_price:.2f} (SL: {sl_price}, TP: {tp_price})")
self.log(f"多头{action}触发 at {current_price:.2f}")
self.close_position("CLOSE_LONG", abs(volume))
elif volume < 0: # 空头
if current_price >= sl_price or current_price <= tp_price:
action = "止损" if current_price >= sl_price else "止盈"
self.log(f"空头{action}触发 {current_price:.2f} (SL: {sl_price}, TP: {tp_price})")
self.log(f"空头{action}触发 at {current_price:.2f}")
self.close_position("CLOSE_SHORT", abs(volume))
def evaluate_entry_signal(self, bar_history: List[Bar]):
# [修改] 在挂单前先重置旧的挂单状态虽然on_open_bar开头也做了但这里更保险
self._pending_sl_price = None
self._pending_tp_price = None
# <--- 修改开始: 这是根据您的新需求完全重写的函数 --->
def evaluate_entry_signal(self, bar_history: List[Bar], current_price: float):
"""
新版进场逻辑:
寻找一个“穿越+确认”的模式。
1. K线[-2]的收盘价穿越了某个HVN。
2. K线[-1]的收盘价没有反转而是保持在HVN的同一侧。
3. 如果满足条件则在HVN价格处挂一个限价单。
"""
# 新逻辑需要至少3根K线来判断“穿越+确认”
if len(bar_history) < 3:
return
# ... 原有挂单信号计算逻辑保持不变 ...
prev_close = bar_history[-2].close
current_close = bar_history[-1].close
# 获取最近的三根K线收盘价用于逻辑判断
prev_prev_close = bar_history[-3].close # K线 n-2 (穿越前的K线)
prev_close = bar_history[-2].close # K线 n-1 (穿越K线)
current_close = bar_history[-1].close # K线 n (确认K线)
# ATR计算仍然需要用于后续的止盈止损设置
highs = np.array([b.high for b in bar_history], dtype=float)
lows = np.array([b.low for b in bar_history], dtype=float)
closes = np.array([b.close for b in bar_history], dtype=float)
current_atr = talib.ATR(highs, lows, closes, self.atr_period)[-1]
if current_atr < self.tick_size: return
for hvn in sorted(self._cached_hvns):
if "BUY" in self.order_direction and (prev_close < hvn < current_close):
direction = "BUY"
for hvn in sorted(self._cached_hvns, reverse=True): # 从高到低检查HVN
direction = None
pass_filter = False
# --- 检查买入信号: 向上穿越HVN且下一根K线收盘价仍在HVN之上 ---
if ("BUY" in self.order_direction and
bar_history[-1].low == hvn and # K线[-3]在HVN之下
current_price > hvn): # K线[-1]保持在HVN之上 (确认)
direction = "SELL"
pass_filter = self.indicator_long is None or self.indicator_long.is_condition_met(
*self.get_indicator_tuple())
elif "SELL" in self.order_direction and (prev_close > hvn > current_close):
direction = "SELL"
if pass_filter:
self.log(f"买入信号确认: 价格在K线[-2]向上穿越HVN({hvn:.2f})并在K线[-1]保持在上方。")
# --- 检查卖出信号: 向下穿越HVN且下一根K线收盘价仍在HVN之下 ---
elif ("SELL" in self.order_direction and
bar_history[-1].high == hvn and # K线[-3]在HVN之下
current_price < hvn): # K线[-1]保持在HVN之下 (确认)
direction = "BUY"
pass_filter = self.indicator_short is None or self.indicator_short.is_condition_met(
*self.get_indicator_tuple())
else:
continue
if pass_filter:
self.log(f"卖出信号确认: 价格在K线[-2]向下穿越HVN({hvn:.2f})并在K线[-1]保持在下方。")
# 如果找到了符合条件的信号并通过了过滤器
if direction and pass_filter:
offset = self.entry_offset_atr * current_atr
limit_price = hvn + offset if direction == "BUY" else hvn - offset
self.log(f"价格穿越HVN({hvn:.2f})。在 {limit_price:.2f} 挂限价{direction}单。")
# 新逻辑: 直接在HVN价格挂单不再使用ATR偏移
limit_price = hvn + (1 if direction == "BUY" else -1)
self.log(f"信号有效,准备在HVN({limit_price:.2f})处挂限价{direction}单。")
self.send_hvn_limit_order(direction, limit_price, current_atr)
return
return # 每次只处理一个信号挂一个单然后等待下根K线
def send_hvn_limit_order(self, direction: str, limit_price: float, entry_atr: float):
# 1. 设置实例的止盈止损状态
# [V2 关键逻辑]: 直接更新实例变量,为即将到来的持仓预设止盈止损
self._pending_sl_price = limit_price - self.stop_loss_atr * entry_atr if direction == "BUY" else limit_price + self.stop_loss_atr * entry_atr
self._pending_tp_price = limit_price + self.take_profit_atr * entry_atr if direction == "BUY" else limit_price - self.take_profit_atr * entry_atr
# 2. [新增] 状态已更新,立即通过上下文持久化
self.context.save_state(self._get_state_dict())
self.log(f"状态已更新并保存: SL={self._pending_sl_price}, TP={self._pending_tp_price}")
# 3. 发送订单
order_id = f"{self.symbol}_{direction}_LIMIT_{self.order_id_counter}"
self.order_id_counter += 1
order = Order(
id=order_id, symbol=self.symbol, direction=direction, volume=self.trade_volume,
price_type="LIMIT", limit_price=limit_price, submitted_time=self.get_current_time(),
price_type="MARKET", submitted_time=self.get_current_time(),
offset="OPEN"
)
self.send_order(order)
self.log(f"已发送挂单: {direction} @ {limit_price:.2f}, 预设SL={self._pending_sl_price:.2f}, TP={self._pending_tp_price:.2f}")
def close_position(self, direction: str, volume: int):
"""[修改] 平仓时,必须清空状态并立即保存。"""
# 1. 发送平仓市价单
self.send_market_order(direction, volume)
# 2. 清空本地的止盈止损状态
# 清理止盈止损状态
self._pending_sl_price = None
self._pending_tp_price = None
self.log("仓位已平,重置止盈止损状态。")
# 3. [新增] 状态已清空,立即通过上下文持久化这个“空状态”
self.context.save_state(self._get_state_dict())
self.log("持仓已平,相关的止盈止损状态已清空并保存。")
def send_market_order(self, direction: str, volume: int, offset: str = "CLOSE"):
# ... 此辅助函数保持不变 ...
order_id = f"{self.symbol}_{direction}_{offset}_{self.get_current_time().strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
self.order_id_counter += 1
order = Order(

View File

@@ -33,6 +33,7 @@ class Strategy(ABC):
"""
self.context = context # 存储 context 对象
self.symbol = symbol # 策略操作的合约Symbol
self.main_symbol = symbol
self.params = params
self.enable_log = enable_log
self.trading = False
@@ -187,7 +188,7 @@ class Strategy(ABC):
# 你可以将其他 kwargs (如 sep, end, file, flush) 传递给 print
# 但通常日志方法不会频繁使用这些。这里只支持最基础的打印。
print(f"{time_prefix}策略 ({self.symbol}): {message}", **kwargs)
print(f"{time_prefix}策略 ({self.symbol}): {message}")
def on_rollover(self, old_symbol: str, new_symbol: str):
"""
@@ -219,13 +220,21 @@ class Strategy(ABC):
return self._indicator_cache
# 数据有变化,重新创建数组并更新缓存
close = np.array(close_data)
open_price = np.array(self.get_price_history("open"))
high = np.array(self.get_price_history("high"))
low = np.array(self.get_price_history("low"))
volume = np.array(self.get_price_history("volume"))
close = np.array(close_data[-1000:])
open_price = np.array(self.get_price_history("open")[-1000:])
high = np.array(self.get_price_history("high")[-1000:])
low = np.array(self.get_price_history("low")[-1000:])
volume = np.array(self.get_price_history("volume")[-1000:])
self._indicator_cache = (close, open_price, high, low, volume)
self._cache_length = current_length
return self._indicator_cache
def save_state(self, state: Any) -> None:
if self.trading:
self.context.save_state(state)
def load_state(self) -> None:
if self.trading:
self.context.load_state()