Files
NewQuant/futures_trading_strategies/MA/ValueMigrationStrategy/ValueMigrationStrategy2.py

369 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# =====================================================================================
# 以下是新增的 ValueMigrationStrategy 策略代码
# =====================================================================================
from collections import deque
from datetime import timedelta, time
import numpy as np
import pandas as pd
from typing import List, Any, Optional, Dict
import talib
from src.core_data import Bar, Order
from src.strategies.ValueMigrationStrategy.data_class import ProfileStats, calculate_profile_from_bars
from src.strategies.base_strategy import Strategy
# = ===================================================================
# 全局辅助函数 (Global Helper Functions)
# 将这些函数放在文件顶部,以便所有策略类都能调用
# =====================================================================
def compute_price_volume_distribution(bars: List[Bar], tick_size: float) -> Optional[pd.Series]:
"""
[全局函数] 从K线数据中计算出原始的价格-成交量分布。
"""
if not bars:
return None
data = []
# 为了性能我们只处理有限数量的bars防止内存问题
# 在实际应用中,更高效的实现是必要的
for bar in bars[-500:]: # 添加一个安全限制
price_range = np.arange(bar.low, bar.high + tick_size, tick_size)
if len(price_range) == 0 or bar.volume == 0: continue
# 将成交量近似分布到K线覆盖的每个tick上
volume_per_tick = bar.volume / len(price_range)
for price in price_range:
data.append({'price': price, 'volume': volume_per_tick})
if not data:
return None
df = pd.DataFrame(data)
if df.empty:
return None
return df.groupby('price')['volume'].sum().sort_index()
# =====================================================================================
# 以下是最终的、以性能为首要考量的、超高速VP计算模块
# =====================================================================================
def compute_fast_volume_profile(bars: List[Bar], tick_size: float) -> Optional[pd.Series]:
"""
[全局核心函数] 使用“离散重心累加”法超高速地从K线数据中构建成交量剖面图。
该方法将每根K线的全部成交量一次性地归于其加权中心价
(HLC/3)所对齐的tick上。这在保持核心逻辑精确性的同时
实现了计算速度的数量级提升。
Args:
bars: 用于计算的K线历史数据。
tick_size: 合约的最小变动价位。
Returns:
一个代表成交量剖面图的Pandas Series或None。
"""
if not bars:
return None
# 使用字典进行累加这是Python中最快的操作之一
volume_dict = {}
for bar in bars:
if bar.volume == 0:
continue
# 1. 计算K线的加权中心价
center_price = (bar.high + bar.low + bar.close) / 3
# 2. 将中心价对齐到最近的tick
aligned_price = round(center_price / tick_size) * tick_size
# 3. 将该Bar的全部成交量一次性累加到对齐后的价格点上
volume_dict[aligned_price] = volume_dict.get(aligned_price, 0) + bar.volume
if not volume_dict:
return None
# 将最终结果转换为一个有序的Pandas Series
return pd.Series(volume_dict).sort_index()
# 确保在文件顶部导入
from scipy.signal import find_peaks
def find_hvns_with_distance(price_volume_dist: pd.Series, distance_in_ticks: int) -> List[float]:
"""
[全局函数] 使用峰值查找算法根据峰值间的最小距离来识别HVNs。
Args:
price_volume_dist: 价格-成交量分布序列。
distance_in_ticks: 两个HVN之间必须间隔的最小tick数量。
Returns:
一个包含所有被识别出的HVN价格的列表。
"""
if price_volume_dist.empty or len(price_volume_dist) < 3:
return []
# distance参数确保找到的峰值之间至少相隔N个点
peaks_indices, _ = find_peaks(price_volume_dist.values, distance=distance_in_ticks)
if len(peaks_indices) == 0:
return [price_volume_dist.idxmax()] # 默认返回POC
hvn_prices = price_volume_dist.index[peaks_indices].tolist()
return hvn_prices
def find_hvns_strict(price_volume_dist: pd.Series, window_radius: int) -> List[float]:
"""
[全局函数] 使用严格的“滚动窗口最大值”定义来识别HVNs。
一个点是HVN当且仅当它的成交量大于其左右各 `window_radius` 个点的成交量。
Args:
price_volume_dist: 价格-成交量分布序列。
window_radius: 定义了检查窗口的半径 (即您所说的 N)。
Returns:
一个包含所有被识别出的HVN价格的列表。
"""
if price_volume_dist.empty or window_radius == 0:
return [price_volume_dist.idxmax()] if not price_volume_dist.empty else []
# 1. 确保价格序列是连续的用0填充缺失的ticks
full_price_range = np.arange(price_volume_dist.index.min(),
price_volume_dist.index.max() + price_volume_dist.index.to_series().diff().min(),
price_volume_dist.index.to_series().diff().min())
continuous_dist = price_volume_dist.reindex(full_price_range, fill_value=0)
# 2. 计算滚动窗口最大值
window_size = 2 * window_radius + 1
rolling_max = continuous_dist.rolling(window=window_size, center=True).max()
# 3. 找到那些自身成交量就等于其窗口最大值的点
is_hvn = (continuous_dist == rolling_max) & (continuous_dist > 0)
hvn_prices = continuous_dist[is_hvn].index.tolist()
# 4. 处理平顶山如果连续多个点都是HVN只保留中间那个
if not hvn_prices:
return [price_volume_dist.idxmax()] # 如果找不到返回POC
final_hvns = []
i = 0
while i < len(hvn_prices):
# 找到一个连续HVN块
j = i
while j + 1 < len(hvn_prices) and (hvn_prices[j + 1] - hvn_prices[j]) < (
2 * price_volume_dist.index.to_series().diff().min()):
j += 1
# 取这个连续块的中间点
middle_index = i + (j - i) // 2
final_hvns.append(hvn_prices[middle_index])
i = j + 1
return final_hvns
# 确保在文件顶部导入
from scipy.signal import find_peaks
# =====================================================================================
# 以下是V2版本的、简化了状态管理的 HVNPullbackStrategy 代码
# =====================================================================================
class ValueMigrationStrategy(Strategy):
"""
一个基于动态HVN突破后回测的量化交易策略。(V2: 简化状态管理)
V2版本简化了内部状态管理移除了基于order_id的复杂元数据传递
使用更直接、更健壮的单一状态变量来处理挂单的止盈止损参数,
完美适配“单次单持仓”的策略逻辑。
"""
def __init__(
self,
context: Any,
main_symbol: str,
enable_log: bool,
trade_volume: int,
tick_size: float = 1,
profile_period: int = 100,
recalc_interval: int = 4,
hvn_distance_ticks: int = 20,
entry_offset_atr: float = 0.0,
stop_loss_atr: float = 1.0,
take_profit_atr: float = 2.0,
atr_period: int = 14,
order_direction=None,
indicators=[None, None],
):
super().__init__(context, main_symbol, enable_log)
if order_direction is None:
order_direction = ['BUY', 'SELL']
self.trade_volume = trade_volume
self.tick_size = tick_size
self.profile_period = profile_period
self.recalc_interval = recalc_interval
self.hvn_distance_ticks = hvn_distance_ticks
self.entry_offset_atr = entry_offset_atr
self.stop_loss_atr = stop_loss_atr
self.take_profit_atr = take_profit_atr
self.atr_period = atr_period
self.order_direction = order_direction
self.indicator_long = indicators[0]
self.indicator_short = indicators[1]
self.main_symbol = main_symbol
self.order_id_counter = 0
self._bar_counter = 0
self._cached_hvns: List[float] = []
# --- V2: 简化的状态管理 ---
self._pending_sl_price: Optional[float] = None
self._pending_tp_price: Optional[float] = None
def on_open_bar(self, open_price: float, symbol: str):
self.symbol = symbol
self._bar_counter += 1
bar_history = self.get_bar_history()
required_len = max(self.profile_period, self.atr_period) + 1
if len(bar_history) < required_len:
return
# --- 1. 取消所有挂单并重置挂单状态 ---
self.cancel_all_pending_orders(self.symbol)
# self._pending_sl_price = None
# self._pending_tp_price = None
# --- 2. 管理现有持仓 ---
position_volume = self.get_current_positions().get(self.symbol, 0)
if position_volume != 0:
self.manage_open_position(position_volume, open_price)
return
# --- 3. 周期性地计算HVNs ---
if self._bar_counter % self.recalc_interval == 1:
profile_bars = bar_history[-self.profile_period:]
dist = compute_price_volume_distribution(profile_bars, self.tick_size)
# dist = compute_fast_volume_profile(profile_bars, self.tick_size)
if dist is not None and not dist.empty:
# self._cached_hvns = find_hvns_with_distance(dist, self.hvn_distance_ticks)
self._cached_hvns = find_hvns_strict(dist, self.hvn_distance_ticks)
self.log(f"New HVNs identified at: {[f'{p:.2f}' for p in self._cached_hvns]}")
if not self._cached_hvns: return
# --- 4. 评估新机会 (挂单逻辑) ---
self.evaluate_entry_signal(bar_history)
def manage_open_position(self, volume: int, current_price: float):
"""主动管理已开仓位的止盈止损。"""
# # [V2 关键逻辑]: 检测是否为新持仓
# # 如果这是一个新持仓,并且我们有预设的止盈止损,就将其存入
# if self._pending_sl_price is not None and self._pending_tp_price is not None:
# meta = {'sl_price': self._pending_sl_price, 'tp_price': self._pending_tp_price}
# self.position_meta = meta
# self.log(f"新持仓确认。已设置TP/SL: {meta}")
# else:
# # 这种情况理论上不应发生,但作为保护
# self.log("Error: New position detected but no pending TP/SL values found.")
# self.close_position("CLOSE_LONG" if volume > 0 else "CLOSE_SHORT", abs(volume))
# return
# [常规逻辑]: 检查止盈止损
sl_price = self._pending_sl_price
tp_price = self._pending_tp_price
if volume > 0: # 多头
if current_price <= sl_price or current_price >= tp_price:
action = "止损" if current_price <= sl_price else "止盈"
self.log(f"多头{action}触发 at {current_price:.2f}")
self.close_position("CLOSE_LONG", abs(volume), current_price)
elif volume < 0: # 空头
if current_price >= sl_price or current_price <= tp_price:
action = "止损" if current_price >= sl_price else "止盈"
self.log(f"空头{action}触发 at {current_price:.2f}")
self.close_position("CLOSE_SHORT", abs(volume), current_price)
def evaluate_entry_signal(self, bar_history: List[Bar]):
prev_close = bar_history[-2].close
current_close = bar_history[-1].close
highs = np.array([b.high for b in bar_history], dtype=float)
lows = np.array([b.low for b in bar_history], dtype=float)
closes = np.array([b.close for b in bar_history], dtype=float)
current_atr = talib.ATR(highs, lows, closes, self.atr_period)[-1]
if current_atr < self.tick_size: return
for hvn in sorted(self._cached_hvns):
# (为了简洁,买卖逻辑合并)
direction = None
if "BUY" in self.order_direction and (prev_close < hvn < current_close):
direction = "SELL"
pass_filter = self.indicator_long is None or self.indicator_long.is_condition_met(
*self.get_indicator_tuple())
elif "SELL" in self.order_direction and (prev_close > hvn > current_close):
direction = "BUY"
pass_filter = self.indicator_short is None or self.indicator_short.is_condition_met(
*self.get_indicator_tuple())
else:
continue # 没有触发穿越
if direction and pass_filter:
offset = self.entry_offset_atr * current_atr
limit_price = hvn + offset if direction == "BUY" else hvn - offset
self.log(f"价格穿越HVN({hvn:.2f}). 在 {limit_price:.2f} 挂限价{direction}单。")
# self.send_hvn_limit_order(direction, limit_price + 1 if direction == 'BUY' else -1, current_atr)
self.send_hvn_limit_order(direction, limit_price, current_atr)
return # 每次只挂一个单
def send_hvn_limit_order(self, direction: str, limit_price: float, entry_atr: float):
print(limit_price, self.get_current_time())
# [V2 关键逻辑]: 直接更新实例变量
self._pending_sl_price = limit_price - self.stop_loss_atr * entry_atr if direction == "BUY" else limit_price + self.stop_loss_atr * entry_atr
self._pending_tp_price = limit_price + self.take_profit_atr * entry_atr if direction == "BUY" else limit_price - self.take_profit_atr * entry_atr
order_id = f"{self.symbol}_{direction}_LIMIT_{self.order_id_counter}"
self.order_id_counter += 1
# order = Order(
# id=order_id, symbol=self.symbol, direction=direction, volume=self.trade_volume,
# price_type="LIMIT", limit_price=limit_price, submitted_time=self.get_current_time(),
# offset="OPEN"
# )
order = Order(
id=order_id, symbol=self.symbol, direction=direction, volume=self.trade_volume,
price_type="STOP", stop_price=limit_price, submitted_time=self.get_current_time(),
offset="OPEN"
)
self.send_order(order)
def close_position(self, direction: str, volume: int, current_price: float):
self.send_market_order(direction, volume, current_price)
def send_market_order(self, direction: str, volume: int, current_price: float, offset: str = "CLOSE"):
order_id = f"{self.symbol}_{direction}_{offset}_{self.get_current_time().strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
self.order_id_counter += 1
order = Order(
id=order_id, symbol=self.symbol, direction=direction, volume=volume,
price_type="MARKET", submitted_time=self.get_current_time(), offset=offset, limit_price=current_price
)
self.send_order(order)