1、卡尔曼策略
This commit is contained in:
@@ -27,8 +27,8 @@ class GridSearchAnalyzer:
|
||||
self.optimization_metric = optimization_metric
|
||||
self.param_names = [k for k in grid_results[0].keys() if k != optimization_metric]
|
||||
|
||||
if len(self.param_names) != 2:
|
||||
raise ValueError("GridSearchAnalyzer 当前只支持分析两个参数的网格搜索结果。")
|
||||
# if len(self.param_names) != 2:
|
||||
# raise ValueError("GridSearchAnalyzer 当前只支持分析两个参数的网格搜索结果。")
|
||||
|
||||
self.param1_name = self.param_names[0]
|
||||
self.param2_name = self.param_names[1]
|
||||
|
||||
@@ -34,32 +34,6 @@ class BacktestContext:
|
||||
self._current_bar: Optional['Bar'] = None
|
||||
self._engine = None
|
||||
|
||||
# --- 新增:状态管理功能 ---
|
||||
|
||||
def save_state(self, state: Dict[str, Any]) -> None:
|
||||
"""
|
||||
保存策略的当前状态。
|
||||
|
||||
策略应在适当的时机(例如,每日结束、策略关闭时)调用此方法
|
||||
来持久化其内部变量。
|
||||
|
||||
Args:
|
||||
state (Dict[str, Any]): 包含策略状态的字典。
|
||||
"""
|
||||
self._state_repository.save(state)
|
||||
|
||||
def load_state(self) -> Dict[str, Any]:
|
||||
"""
|
||||
加载策略的历史状态。
|
||||
|
||||
策略应在初始化时调用此方法来恢复之前的运行状态。
|
||||
如果不存在历史状态,将返回一个空字典。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 包含策略历史状态的字典。
|
||||
"""
|
||||
return self._state_repository.load()
|
||||
|
||||
# --- 现有功能保持不变 ---
|
||||
|
||||
def set_current_bar(self, bar: 'Bar'):
|
||||
@@ -102,4 +76,30 @@ class BacktestContext:
|
||||
def is_rollover_bar(self) -> bool:
|
||||
if self._engine:
|
||||
return self._engine.is_rollover_bar
|
||||
return False
|
||||
return False
|
||||
|
||||
# --- 新增:状态管理功能 ---
|
||||
|
||||
def save_state(self, state: Dict[str, Any]) -> None:
|
||||
"""
|
||||
保存策略的当前状态。
|
||||
|
||||
策略应在适当的时机(例如,每日结束、策略关闭时)调用此方法
|
||||
来持久化其内部变量。
|
||||
|
||||
Args:
|
||||
state (Dict[str, Any]): 包含策略状态的字典。
|
||||
"""
|
||||
self._state_repository.save(state)
|
||||
|
||||
def load_state(self) -> Dict[str, Any]:
|
||||
"""
|
||||
加载策略的历史状态。
|
||||
|
||||
策略应在初始化时调用此方法来恢复之前的运行状态。
|
||||
如果不存在历史状态,将返回一个空字典。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 包含策略历史状态的字典。
|
||||
"""
|
||||
return self._state_repository.load()
|
||||
|
||||
@@ -64,9 +64,9 @@ class ExecutionSimulator:
|
||||
base_price_for_market_order = current_bar.open
|
||||
|
||||
if order.price_type == "MARKET":
|
||||
if order.direction in ["BUY"]:
|
||||
if order.direction in ["BUY", 'CLOSE_SHORT']:
|
||||
fill_price = (base_price_for_market_order + 1) * (1 + self.slippage_rate)
|
||||
elif order.direction in ["SELL"]:
|
||||
elif order.direction in ["SELL", 'CLOSE_LONG']:
|
||||
fill_price = (base_price_for_market_order - 1) * (1 - self.slippage_rate)
|
||||
else:
|
||||
fill_price = base_price_for_market_order
|
||||
@@ -74,10 +74,10 @@ class ExecutionSimulator:
|
||||
elif order.price_type == "LIMIT" and order.limit_price is not None:
|
||||
limit_price = order.limit_price
|
||||
if order.direction in ["BUY", "CLOSE_SHORT"]:
|
||||
if current_bar.low <= limit_price: # MODIFIED: 使用<=更符合常规逻辑
|
||||
if current_bar.low < limit_price: # MODIFIED: 使用<=更符合常规逻辑
|
||||
fill_price = limit_price * (1 + self.slippage_rate)
|
||||
elif order.direction in ["SELL", "CLOSE_LONG"]:
|
||||
if current_bar.high >= limit_price: # MODIFIED: 使用>=更符合常规逻辑
|
||||
if current_bar.high > limit_price: # MODIFIED: 使用>=更符合常规逻辑
|
||||
fill_price = limit_price * (1 - self.slippage_rate)
|
||||
|
||||
if fill_price <= 0:
|
||||
@@ -98,7 +98,6 @@ class ExecutionSimulator:
|
||||
continue
|
||||
|
||||
order = self.pending_orders[order_id]
|
||||
|
||||
if order.symbol != current_bar.symbol:
|
||||
continue
|
||||
|
||||
|
||||
@@ -44,5 +44,44 @@ class Indicator(ABC):
|
||||
@abstractmethod
|
||||
def get_name(self):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
class CompositeIndicator(Indicator):
|
||||
def __init__(self, indicators: List[Indicator], down_bound=None, up_bound=None, shift_window=0):
|
||||
# 聚合指标通常不使用自身的 bound 和 shift_window,但为兼容基类保留
|
||||
super().__init__(down_bound=down_bound, up_bound=up_bound, shift_window=shift_window)
|
||||
if not indicators:
|
||||
raise ValueError("At least one indicator is required.")
|
||||
self.indicators = indicators
|
||||
|
||||
def get_values(self, close: np.array, open: np.array, high: np.array, low: np.array, volume: np.array):
|
||||
# 聚合指标本身不产生数值序列,返回空数组或 None
|
||||
# 但为保持类型一致,返回一个长度匹配的 dummy array(如全1)
|
||||
# 或者更合理:返回与输入等长的布尔数组(表示每时刻是否所有条件满足)
|
||||
# 这里选择后者,增强实用性
|
||||
n = len(close)
|
||||
result = np.ones(n, dtype=bool)
|
||||
for ind in self.indicators:
|
||||
# 获取每个子指标的 condition 满足情况(需自定义辅助方法)
|
||||
# 但原 Indicator 没有提供 per-timestamp condition,所以简化处理:
|
||||
# 我们只关心最新值,因此 get_values 对 Composite 意义不大
|
||||
pass
|
||||
# 保守起见:返回 None 或抛出 NotImplementedError
|
||||
# 但为避免破坏调用链,返回一个 dummy array
|
||||
return np.full(n, np.nan)
|
||||
|
||||
def is_condition_met(self,
|
||||
close: np.array,
|
||||
open: np.array,
|
||||
high: np.array,
|
||||
low: np.array,
|
||||
volume: np.array):
|
||||
# 关键逻辑:所有子 indicator 的 is_condition_met 必须为 True
|
||||
for indicator in self.indicators:
|
||||
if not indicator.is_condition_met(close, open, high, low, volume):
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_name(self):
|
||||
return '.'.join([indicator.get_name() for indicator in self.indicators])
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from curses import window
|
||||
|
||||
from src.indicators.indicators import *
|
||||
|
||||
INDICATOR_LIST = [
|
||||
@@ -24,7 +22,7 @@ INDICATOR_LIST = [
|
||||
# DifferencedVolumeIndicator(shift_window=20),
|
||||
StochasticOscillator(fastk_period=14, slowd_period=3, slowk_period=3),
|
||||
StochasticOscillator(5, 3, 3),
|
||||
StochasticOscillator(fastk_period=21, slowd_period=5, slowk_period=5),
|
||||
StochasticOscillator(21, 5, 5),
|
||||
RateOfChange(5),
|
||||
RateOfChange(10),
|
||||
RateOfChange(15),
|
||||
@@ -45,7 +43,7 @@ INDICATOR_LIST = [
|
||||
ADX(120),
|
||||
ADX(240),
|
||||
BollingerBandwidth(10, nbdev=1.5),
|
||||
BollingerBandwidth(20, nbdev=2.0),
|
||||
BollingerBandwidth(20, 2.0),
|
||||
BollingerBandwidth(50, nbdev=2.5),
|
||||
PriceRangeToVolatilityRatio(3, 5),
|
||||
PriceRangeToVolatilityRatio(3, 14),
|
||||
@@ -62,4 +60,6 @@ INDICATOR_LIST = [
|
||||
# RelativeVolumeInWindow(3, 21),
|
||||
# RelativeVolumeInWindow(3, 30),
|
||||
# RelativeVolumeInWindow(3, 40),
|
||||
# ZScoreATR(7, 100),
|
||||
# ZScoreATR(14, 100),
|
||||
]
|
||||
|
||||
@@ -589,4 +589,62 @@ class ROC_MA(Indicator):
|
||||
"""
|
||||
返回指标的唯一名称,用于标识和调试。
|
||||
"""
|
||||
return f"roc_ma_{self.roc_window}_{self.ma_window}"
|
||||
return f"roc_ma_{self.roc_window}_{self.ma_window}"
|
||||
|
||||
from numpy.lib.stride_tricks import sliding_window_view
|
||||
|
||||
class ZScoreATR(Indicator):
|
||||
def __init__(
|
||||
self,
|
||||
atr_window: int = 14,
|
||||
z_window: int = 100,
|
||||
down_bound: float = None,
|
||||
up_bound: float = None,
|
||||
):
|
||||
super().__init__(down_bound, up_bound)
|
||||
self.atr_window = atr_window
|
||||
self.z_window = z_window
|
||||
|
||||
def get_values(self, close, open, high, low, volume) -> np.ndarray:
|
||||
n = len(close)
|
||||
min_len = self.atr_window + self.z_window
|
||||
if n < min_len:
|
||||
return np.full(n, np.nan, dtype=np.float64)
|
||||
|
||||
# Step 1: 计算 ATR (NumPy array)
|
||||
atr = talib.ATR(high, low, close, timeperiod=self.atr_window) # shape: (n,)
|
||||
|
||||
# Step 2: 只对有效区域计算 z-score
|
||||
start_idx = self.atr_window - 1 # ATR 从这里开始非 NaN
|
||||
valid_atr = atr[start_idx:] # shape: (n - start_idx,)
|
||||
valid_n = len(valid_atr)
|
||||
|
||||
if valid_n < self.z_window:
|
||||
return np.full(n, np.nan, dtype=np.float64)
|
||||
|
||||
# Step 3: 使用 sliding_window_view 构造滚动窗口(无数据复制)
|
||||
# windows: shape = (valid_n - z_window + 1, z_window)
|
||||
windows = sliding_window_view(valid_atr, window_shape=self.z_window)
|
||||
|
||||
# Step 4: 向量化计算均值和标准差(沿窗口轴)
|
||||
means = np.mean(windows, axis=1) # shape: (M,)
|
||||
stds = np.std(windows, axis=1, ddof=0) # shape: (M,)
|
||||
|
||||
# Step 5: 计算 z-score(当前值是窗口最后一个元素)
|
||||
current_vals = valid_atr[self.z_window - 1:] # 对齐窗口末尾
|
||||
zscores_valid = np.empty_like(valid_atr)
|
||||
zscores_valid[:self.z_window - 1] = np.nan
|
||||
|
||||
# 安全除法:避免除零
|
||||
with np.errstate(divide='ignore', invalid='ignore'):
|
||||
z = (current_vals - means) / stds
|
||||
zscores_valid[self.z_window - 1:] = np.where(stds > 1e-12, z, 0.0)
|
||||
|
||||
# Step 6: 拼回完整长度(前面 ATR 无效部分为 NaN)
|
||||
result = np.full(n, np.nan, dtype=np.float64)
|
||||
result[start_idx:] = zscores_valid
|
||||
|
||||
return result
|
||||
|
||||
def get_name(self):
|
||||
return f"z_atr_{self.atr_window}_{self.z_window}"
|
||||
File diff suppressed because one or more lines are too long
@@ -1,181 +1,157 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import talib
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime
|
||||
from typing import Optional, Any, List, Dict
|
||||
|
||||
from src.core_data import Bar, Order
|
||||
from src.indicators.base_indicators import Indicator
|
||||
from src.indicators.indicators import Empty
|
||||
from src.strategies.base_strategy import Strategy
|
||||
|
||||
|
||||
class RsiStrategy(Strategy):
|
||||
# =============================================================================
|
||||
# 策略实现 (SimpleRSIStrategy)
|
||||
# =============================================================================
|
||||
|
||||
class SimpleRSIStrategy(Strategy):
|
||||
"""
|
||||
反转策略:RSI 2-24 → PCA → 模型预测 → 极端值反向开仓
|
||||
开仓:下一根 Open 价挂限价单
|
||||
平仓:满足以下任一条件后市价平仓
|
||||
1. 价格触及固定价差止损线
|
||||
2. 持有满 holding_bars 根 K 线
|
||||
一个最简单的RSI交易策略。
|
||||
|
||||
核心哲学:
|
||||
1. 动量跟踪: 遵循RSI指标反映的市场动量。RSI高于60被视为强势市场,
|
||||
适合做多;低于40被视为弱势市场,适合做空。
|
||||
2. 信号持续则持仓: 与传统的“超买/超卖”反转逻辑不同,本策略认为只要
|
||||
市场维持强势(RSI > 60)或弱势(RSI < 40),就应该继续持有仓位,
|
||||
直到出现明确的反转信号。
|
||||
3. 简单的风险控制: 不设置固定的止盈目标,让利润自由发展。只设置一个
|
||||
基于ATR的动态止损,用于管理最基本的下行风险。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: Any,
|
||||
main_symbol: str,
|
||||
model: Any,
|
||||
pca: Any,
|
||||
scaler: Any,
|
||||
lower_bound: float,
|
||||
upper_bound: float,
|
||||
trade_volume: int = 1,
|
||||
enable_log: bool,
|
||||
trade_volume: int,
|
||||
# --- 【RSI 指标参数】 ---
|
||||
rsi_period: int = 14,
|
||||
rsi_threshold: int = 70,
|
||||
# --- 【风险管理】 ---
|
||||
atr_period: int = 14,
|
||||
stop_loss_atr_multiplier: float = 1.5,
|
||||
# --- 其他 ---
|
||||
order_direction: Optional[List[str]] = None,
|
||||
holding_bars: int = 5,
|
||||
# --- MODIFICATION START ---
|
||||
stop_loss_points: Optional[float] = 5, # 止损点数, e.g., 50.0 for 50个价格点
|
||||
# --- MODIFICATION END ---
|
||||
enable_log: bool = False,
|
||||
use_talib: bool = True,
|
||||
indicators: Optional[List[Indicator]] = None,
|
||||
):
|
||||
super().__init__(context, main_symbol, enable_log)
|
||||
self.main_symbol = main_symbol
|
||||
if order_direction is None: order_direction = ['BUY', 'SELL']
|
||||
if indicators is None: indicators = [Empty(), Empty()] # 保持与模板的兼容性
|
||||
|
||||
# --- 参数赋值 ---
|
||||
self.trade_volume = trade_volume
|
||||
self.model = model
|
||||
self.pca = pca
|
||||
self.scaler = scaler
|
||||
self.lower_bound = lower_bound
|
||||
self.upper_bound = upper_bound
|
||||
self.order_direction = order_direction or ["BUY", "SELL"]
|
||||
self.holding_bars = holding_bars
|
||||
# --- MODIFICATION START ---
|
||||
self.stop_loss_points = stop_loss_points
|
||||
# --- MODIFICATION END ---
|
||||
self.use_talib = use_talib
|
||||
self.rsi_period = rsi_period
|
||||
self.rsi_threshold = rsi_threshold
|
||||
self.rsi_upper_threshold = self.rsi_threshold
|
||||
self.rsi_lower_threshold = 100 - self.rsi_threshold
|
||||
self.atr_period = atr_period
|
||||
self.stop_loss_atr_multiplier = stop_loss_atr_multiplier
|
||||
self.order_direction = order_direction
|
||||
|
||||
self.close_cache: List[float] = []
|
||||
self.cache_size = 500
|
||||
self.pos_meta: Dict[str, Dict[str, Any]] = {}
|
||||
# --- 内部状态变量 ---
|
||||
self.main_symbol = main_symbol
|
||||
self.order_id_counter = 0
|
||||
self.indicators = indicators # 保持与模板的兼容性
|
||||
|
||||
# --- MODIFICATION START ---
|
||||
log_message = (
|
||||
f"RsiPcaReversalStrategy 初始化:\n"
|
||||
f"交易量={self.trade_volume}, lower={self.lower_bound}, upper={self.upper_bound}\n"
|
||||
f"时间出场={self.holding_bars} bars\n"
|
||||
f"固定价差止损={self.stop_loss_points if self.stop_loss_points else 'N/A'} points"
|
||||
)
|
||||
self.log(log_message)
|
||||
# --- MODIFICATION END ---
|
||||
|
||||
# ... (工具函数保持不变) ...
|
||||
def update_close_cache(self, bar_history: List[Bar]) -> None:
|
||||
self.close_cache = [b.close for b in bar_history[-self.cache_size:]]
|
||||
|
||||
def calc_rsi_vector(self) -> np.ndarray:
|
||||
close = np.array(self.close_cache, dtype=float)
|
||||
rsi_vec = []
|
||||
for i in range(2, 25):
|
||||
if self.use_talib:
|
||||
# talib 版本(最快)
|
||||
rsi = talib.RSI(close, timeperiod=i)[-1]
|
||||
else:
|
||||
# 原滚动均值版本(与旧代码逻辑完全一致)
|
||||
gain = np.where(np.diff(close) > 0, np.diff(close), 0)
|
||||
loss = np.where(np.diff(close) < 0, -np.diff(close), 0)
|
||||
avg_gain = pd.Series(gain).rolling(window=i, min_periods=i).mean().iloc[-1]
|
||||
avg_loss = pd.Series(loss).rolling(window=i, min_periods=i).mean().iloc[-1]
|
||||
rs = avg_gain / avg_loss if avg_loss != 0 else 100
|
||||
rsi = 100 - (100 / (1 + rs))
|
||||
rsi_vec.append(rsi)
|
||||
return np.array(rsi_vec)
|
||||
|
||||
def predict_ret5(self, rsi_vec: np.ndarray) -> float:
|
||||
vec_std = self.scaler.transform(rsi_vec.reshape(1, -1))
|
||||
vec_pca = self.pca.transform(vec_std)
|
||||
return self.model.predict(vec_pca)[0]
|
||||
|
||||
def on_init(self):
|
||||
super().on_init()
|
||||
self.pos_meta.clear()
|
||||
self.log(f"SimpleRSIStrategy Initialized")
|
||||
|
||||
def on_open_bar(self, open_price: float, symbol: str):
|
||||
"""每根K线开盘时被调用"""
|
||||
self.symbol = symbol
|
||||
bar_history = self.get_bar_history()
|
||||
if len(bar_history) < 30:
|
||||
|
||||
# 需要足够的数据来计算指标
|
||||
if len(bar_history) < self.rsi_period + 5:
|
||||
return
|
||||
|
||||
self.cancel_all_pending_orders(symbol)
|
||||
pos = self.get_current_positions().get(symbol, 0)
|
||||
position_volume = self.get_current_positions().get(self.symbol, 0)
|
||||
|
||||
# 1. 更新缓存 & 计算特征
|
||||
self.update_close_cache(bar_history)
|
||||
rsi_vec = self.calc_rsi_vector()
|
||||
if np.isnan(rsi_vec).any():
|
||||
return
|
||||
pred = self.predict_ret5(rsi_vec)
|
||||
if not self.trading: return
|
||||
|
||||
# 2. 平仓逻辑
|
||||
meta = self.pos_meta.get(symbol)
|
||||
if meta and pos:
|
||||
exit_reason = None
|
||||
# 计算最新的RSI和ATR值
|
||||
closes = np.array([b.close for b in bar_history], dtype=float)
|
||||
highs = np.array([b.high for b in bar_history], dtype=float)
|
||||
lows = np.array([b.low for b in bar_history], dtype=float)
|
||||
|
||||
# --- MODIFICATION START ---
|
||||
# 2a. 检查固定价差止损
|
||||
if self.stop_loss_points is not None:
|
||||
current_price = open_price # 使用当前bar的开盘价作为检查价格
|
||||
entry_price = meta['entry_price']
|
||||
direction = meta['direction']
|
||||
current_rsi = talib.RSI(closes, self.rsi_period)[-1]
|
||||
current_atr = talib.ATR(highs, lows, closes, self.atr_period)[-1]
|
||||
|
||||
if direction == "BUY" and current_price <= entry_price - self.stop_loss_points:
|
||||
exit_reason = f"Stop Loss Hit ({entry_price - self.stop_loss_points})"
|
||||
elif direction == "SELL" and current_price >= entry_price + self.stop_loss_points:
|
||||
exit_reason = f"Stop Loss Hit ({entry_price + self.stop_loss_points})"
|
||||
# --- MODIFICATION END ---
|
||||
current_bar = bar_history[-1]
|
||||
|
||||
# 2b. 检查时间出场
|
||||
if not exit_reason and len(bar_history) >= meta['expiry_idx']:
|
||||
exit_reason = "Time Expiry"
|
||||
# 核心逻辑:管理仓位或评估入场
|
||||
if position_volume != 0:
|
||||
self.manage_open_position(position_volume, current_bar, current_rsi, current_atr)
|
||||
else:
|
||||
self.evaluate_entry_signal(open_price, current_rsi, current_atr)
|
||||
|
||||
if exit_reason:
|
||||
self.log(f"平仓信号触发: {exit_reason}")
|
||||
self.send_market_order(
|
||||
"CLOSE_LONG" if meta['direction'] == "BUY" else "CLOSE_SHORT",
|
||||
meta['volume'],
|
||||
)
|
||||
del self.pos_meta[symbol]
|
||||
return
|
||||
def manage_open_position(self, volume: int, current_bar: Bar, current_rsi: float, current_atr: float):
|
||||
"""管理持仓:检查止损和反向信号平仓"""
|
||||
|
||||
# 3. 开仓逻辑 (保持不变)
|
||||
if pos == 0:
|
||||
is_long = volume > 0
|
||||
# stop_loss_price = meta['stop_loss_price']
|
||||
|
||||
# # 1. 检查止损
|
||||
# if (is_long and current_bar.low <= stop_loss_price) or \
|
||||
# (not is_long and current_bar.high >= stop_loss_price):
|
||||
# self.log(f"ATR Stop Loss Hit at {stop_loss_price:.4f}")
|
||||
# self.close_position("CLOSE_LONG" if is_long else "CLOSE_SHORT", abs(volume))
|
||||
# return
|
||||
|
||||
# 2. 检查反向信号平仓
|
||||
if is_long and current_rsi < self.rsi_lower_threshold:
|
||||
self.log(f"RSI dropped below {self.rsi_lower_threshold}, closing long position.")
|
||||
self.close_position("CLOSE_LONG", abs(volume))
|
||||
elif not is_long and current_rsi > self.rsi_upper_threshold:
|
||||
self.log(f"RSI rose above {self.rsi_upper_threshold}, closing short position.")
|
||||
self.close_position("CLOSE_SHORT", abs(volume))
|
||||
|
||||
def evaluate_entry_signal(self, open_price: float, current_rsi: float, current_atr: float):
|
||||
"""评估入场信号"""
|
||||
direction = None
|
||||
|
||||
# 做多信号
|
||||
if "BUY" in self.order_direction and current_rsi > self.rsi_upper_threshold and self.indicators[
|
||||
0].is_condition_met(*self.get_indicator_tuple()):
|
||||
direction = "BUY"
|
||||
# 做空信号
|
||||
elif "SELL" in self.order_direction and current_rsi < self.rsi_lower_threshold and self.indicators[
|
||||
1].is_condition_met(*self.get_indicator_tuple()):
|
||||
direction = "SELL"
|
||||
|
||||
if direction and current_atr > 0:
|
||||
entry_price = open_price
|
||||
if pred < self.lower_bound and "SELL" in self.order_direction:
|
||||
self.send_limit_order("SELL", entry_price, self.trade_volume, bar_history[-1].datetime)
|
||||
elif pred > self.upper_bound and "BUY" in self.order_direction:
|
||||
self.send_limit_order("BUY", entry_price, self.trade_volume, bar_history[-1].datetime)
|
||||
|
||||
def send_open_order(self, direction: str, limit_price: float, volume: int, entry_dt: Any):
|
||||
# (此函数逻辑已在上个版本中更新,记录entry_price,保持不变)
|
||||
order_id = f"{self.symbol}_{direction}_{entry_dt.strftime('%Y%m%d%H%M%S')}"
|
||||
order = Order(
|
||||
id=order_id, symbol=self.symbol, direction=direction, volume=volume,
|
||||
price_type="MARKET", submitted_time=entry_dt, offset="OPEN",
|
||||
)
|
||||
self.send_order(order)
|
||||
|
||||
self.pos_meta[self.symbol] = {
|
||||
"direction": direction,
|
||||
"volume": volume,
|
||||
"expiry_idx": len(self.get_bar_history()) + self.holding_bars,
|
||||
"entry_price": limit_price
|
||||
}
|
||||
self.log(f"开仓信号: {direction} at {limit_price}")
|
||||
self.log(f"RSI Signal: {direction}. RSI={current_rsi:.2f}. ")
|
||||
|
||||
def send_market_order(self, direction: str, volume: int):
|
||||
# ... (此函数保持不变) ...
|
||||
order_id = f"{self.symbol}_{direction}_{self.get_current_time().strftime('%Y%m%d%H%M%S')}"
|
||||
order = Order(
|
||||
id=order_id, symbol=self.symbol, direction=direction, volume=volume,
|
||||
price_type="MARKET", submitted_time=self.get_current_time(), offset="CLOSE",
|
||||
)
|
||||
self.send_market_order(direction, self.trade_volume, "OPEN")
|
||||
|
||||
# --- 辅助函数区 (与模板保持一致) ---
|
||||
def on_init(self):
|
||||
super().on_init()
|
||||
self.cancel_all_pending_orders(self.main_symbol)
|
||||
|
||||
def close_position(self, direction: str, volume: int):
|
||||
self.send_market_order(direction, volume, offset="CLOSE")
|
||||
|
||||
|
||||
def send_market_order(self, direction: str, volume: int, offset: str):
|
||||
order_id = f"{self.symbol}_{direction}_MARKET_{self.order_id_counter}"
|
||||
self.order_id_counter += 1
|
||||
order = Order(id=order_id, symbol=self.symbol, direction=direction, volume=volume, price_type="MARKET",
|
||||
submitted_time=self.get_current_time(), offset=offset)
|
||||
self.send_order(order)
|
||||
|
||||
def on_rollover(self, old_symbol: str, new_symbol: str):
|
||||
# ... (此函数保持不变) ...
|
||||
super().on_rollover(old_symbol, new_symbol)
|
||||
self.cancel_all_pending_orders(new_symbol)
|
||||
self.pos_meta.clear()
|
||||
# 主力合约换月时,清空所有状态,避免旧数据干扰
|
||||
self.log("Rollover detected. All strategy states have been reset.")
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -1,13 +1,11 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Optional, Dict, Any, List, Union
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
# 假设这些是你项目中的模块
|
||||
from src.core_data import Bar, Order
|
||||
from src.indicators.base_indicators import Indicator
|
||||
from src.indicators.indicators import Empty
|
||||
from src.strategies.base_strategy import Strategy
|
||||
from src.algo.TrendLine import calculate_latest_trendline_values
|
||||
from src.algo.TrendLine import calculate_latest_trendline_values, calculate_latest_trendline_values_v2
|
||||
|
||||
|
||||
class TrendlineHawkesStrategy(Strategy):
|
||||
@@ -31,7 +29,6 @@ class TrendlineHawkesStrategy(Strategy):
|
||||
hawkes_entry_percent: float = 0.95,
|
||||
hawkes_exit_percent: float = 0.50,
|
||||
enable_log: bool = True,
|
||||
indicators: Union[Indicator, List[Indicator]] = None,
|
||||
):
|
||||
super().__init__(context, main_symbol, enable_log)
|
||||
# ... 参数赋值与V3完全相同 ...
|
||||
@@ -45,9 +42,6 @@ class TrendlineHawkesStrategy(Strategy):
|
||||
self.hawkes_entry_percent = hawkes_entry_percent
|
||||
self.hawkes_exit_percent = hawkes_exit_percent
|
||||
self.pos_meta: Dict[str, Dict[str, Any]] = {}
|
||||
if indicators is None:
|
||||
indicators = [Empty(), Empty()]
|
||||
self.indicators = indicators
|
||||
|
||||
# --- 【核心修改】状态缓存重构 ---
|
||||
# 只缓存上一个时间点的霍克斯强度值 (未缩放)
|
||||
@@ -146,13 +140,13 @@ class TrendlineHawkesStrategy(Strategy):
|
||||
if pos == 0:
|
||||
close_prices = np.array([b.close for b in bar_history])
|
||||
prices_for_trendline = close_prices[-self.trendline_n - 1:-1]
|
||||
trend_upper, trend_lower = calculate_latest_trendline_values(prices_for_trendline)
|
||||
trend_upper, trend_lower = calculate_latest_trendline_values_v2(prices_for_trendline)
|
||||
|
||||
if trend_upper is not None and trend_lower is not None:
|
||||
prev_close = bar_history[-2].close
|
||||
last_close = bar_history[-1].close
|
||||
upper_break_event = last_close > trend_upper and prev_close < trend_upper and self.indicators[0].is_condition_met(*self.get_indicator_tuple())
|
||||
lower_break_event = last_close < trend_lower and prev_close > trend_lower and self.indicators[1].is_condition_met(*self.get_indicator_tuple())
|
||||
upper_break_event = last_close > trend_upper and prev_close < trend_upper
|
||||
lower_break_event = last_close < trend_lower and prev_close > trend_lower
|
||||
hawkes_confirmation = latest_hawkes_value > latest_hawkes_upper
|
||||
|
||||
if hawkes_confirmation and (upper_break_event or lower_break_event):
|
||||
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -49,6 +49,51 @@ def compute_price_volume_distribution(bars: List[Bar], tick_size: float) -> Opti
|
||||
return df.groupby('price')['volume'].sum().sort_index()
|
||||
|
||||
|
||||
# =====================================================================================
|
||||
# 以下是最终的、以性能为首要考量的、超高速VP计算模块
|
||||
# =====================================================================================
|
||||
|
||||
def compute_fast_volume_profile(bars: List[Bar], tick_size: float) -> Optional[pd.Series]:
|
||||
"""
|
||||
[全局核心函数] 使用“离散重心累加”法,超高速地从K线数据中构建成交量剖面图。
|
||||
|
||||
该方法将每根K线的全部成交量,一次性地归于其加权中心价
|
||||
(HLC/3)所对齐的tick上。这在保持核心逻辑精确性的同时,
|
||||
实现了计算速度的数量级提升。
|
||||
|
||||
Args:
|
||||
bars: 用于计算的K线历史数据。
|
||||
tick_size: 合约的最小变动价位。
|
||||
|
||||
Returns:
|
||||
一个代表成交量剖面图的Pandas Series,或None。
|
||||
"""
|
||||
if not bars:
|
||||
return None
|
||||
|
||||
# 使用字典进行累加,这是Python中最快的操作之一
|
||||
volume_dict = {}
|
||||
|
||||
for bar in bars:
|
||||
if bar.volume == 0:
|
||||
continue
|
||||
|
||||
# 1. 计算K线的加权中心价
|
||||
center_price = (bar.high + bar.low + bar.close) / 3
|
||||
|
||||
# 2. 将中心价对齐到最近的tick
|
||||
aligned_price = round(center_price / tick_size) * tick_size
|
||||
|
||||
# 3. 将该Bar的全部成交量,一次性累加到对齐后的价格点上
|
||||
volume_dict[aligned_price] = volume_dict.get(aligned_price, 0) + bar.volume
|
||||
|
||||
if not volume_dict:
|
||||
return None
|
||||
|
||||
# 将最终结果转换为一个有序的Pandas Series
|
||||
return pd.Series(volume_dict).sort_index()
|
||||
|
||||
|
||||
# 确保在文件顶部导入
|
||||
from scipy.signal import find_peaks
|
||||
|
||||
@@ -152,7 +197,7 @@ class ValueMigrationStrategy(Strategy):
|
||||
enable_log: bool,
|
||||
trade_volume: int,
|
||||
tick_size: float = 1,
|
||||
profile_period: int = 100,
|
||||
profile_period: int = 60,
|
||||
recalc_interval: int = 4,
|
||||
hvn_distance_ticks: int = 20,
|
||||
entry_offset_atr: float = 0.0,
|
||||
@@ -214,6 +259,7 @@ class ValueMigrationStrategy(Strategy):
|
||||
if self._bar_counter % self.recalc_interval == 1:
|
||||
profile_bars = bar_history[-self.profile_period:]
|
||||
dist = compute_price_volume_distribution(profile_bars, self.tick_size)
|
||||
# dist = compute_fast_volume_profile(profile_bars, self.tick_size)
|
||||
if dist is not None and not dist.empty:
|
||||
# self._cached_hvns = find_hvns_with_distance(dist, self.hvn_distance_ticks)
|
||||
self._cached_hvns = find_hvns_strict(dist, self.hvn_distance_ticks)
|
||||
@@ -222,7 +268,7 @@ class ValueMigrationStrategy(Strategy):
|
||||
if not self._cached_hvns: return
|
||||
|
||||
# --- 4. 评估新机会 (挂单逻辑) ---
|
||||
self.evaluate_entry_signal(bar_history)
|
||||
self.evaluate_entry_signal(bar_history, open_price)
|
||||
|
||||
def manage_open_position(self, volume: int, current_price: float):
|
||||
"""主动管理已开仓位的止盈止损。"""
|
||||
@@ -254,7 +300,7 @@ class ValueMigrationStrategy(Strategy):
|
||||
self.log(f"空头{action}触发 at {current_price:.2f}")
|
||||
self.close_position("CLOSE_SHORT", abs(volume))
|
||||
|
||||
def evaluate_entry_signal(self, bar_history: List[Bar]):
|
||||
def evaluate_entry_signal(self, bar_history: List[Bar], current_price: float):
|
||||
prev_close = bar_history[-2].close
|
||||
current_close = bar_history[-1].close
|
||||
|
||||
@@ -267,20 +313,37 @@ class ValueMigrationStrategy(Strategy):
|
||||
for hvn in sorted(self._cached_hvns):
|
||||
# (为了简洁,买卖逻辑合并)
|
||||
direction = None
|
||||
if "BUY" in self.order_direction and (prev_close < hvn < current_close):
|
||||
# if "BUY" in self.order_direction and (prev_close < hvn < current_close):
|
||||
# direction = "BUY"
|
||||
# pass_filter = self.indicator_long is None or self.indicator_long.is_condition_met(
|
||||
# *self.get_indicator_tuple())
|
||||
# elif "SELL" in self.order_direction and (prev_close > hvn > current_close):
|
||||
# direction = "SELL"
|
||||
# pass_filter = self.indicator_short is None or self.indicator_short.is_condition_met(
|
||||
# *self.get_indicator_tuple())
|
||||
if "BUY" in self.order_direction and (current_close > hvn and bar_history[-1].low == hvn):
|
||||
direction = "BUY"
|
||||
pass_filter = self.indicator_long is None or self.indicator_long.is_condition_met(
|
||||
*self.get_indicator_tuple())
|
||||
elif "SELL" in self.order_direction and (prev_close > hvn > current_close):
|
||||
elif "SELL" in self.order_direction and (current_close < hvn and bar_history[-1].high == hvn):
|
||||
direction = "SELL"
|
||||
pass_filter = self.indicator_short is None or self.indicator_short.is_condition_met(
|
||||
*self.get_indicator_tuple())
|
||||
# if "BUY" in self.order_direction and (current_close > hvn and bar_history[-1].low == hvn):
|
||||
# direction = "SELL"
|
||||
# pass_filter = self.indicator_long is None or self.indicator_long.is_condition_met(
|
||||
# *self.get_indicator_tuple())
|
||||
# elif "SELL" in self.order_direction and (current_close < hvn and bar_history[-1].high == hvn):
|
||||
# direction = "BUY"
|
||||
# pass_filter = self.indicator_short is None or self.indicator_short.is_condition_met(
|
||||
# *self.get_indicator_tuple())
|
||||
else:
|
||||
continue # 没有触发穿越
|
||||
|
||||
if direction and pass_filter:
|
||||
offset = self.entry_offset_atr * current_atr
|
||||
limit_price = hvn + offset if direction == "BUY" else hvn - offset
|
||||
# limit_price = hvn + offset if direction == "BUY" else hvn - offset
|
||||
limit_price = current_price
|
||||
|
||||
self.log(f"价格穿越HVN({hvn:.2f}). 在 {limit_price:.2f} 挂限价{direction}单。")
|
||||
self.send_hvn_limit_order(direction, limit_price, current_atr)
|
||||
@@ -296,8 +359,8 @@ class ValueMigrationStrategy(Strategy):
|
||||
|
||||
order = Order(
|
||||
id=order_id, symbol=self.symbol, direction=direction, volume=self.trade_volume,
|
||||
price_type="LIMIT", limit_price=limit_price, submitted_time=self.get_current_time(),
|
||||
offset="OPEN"
|
||||
price_type="LIMIT", submitted_time=self.get_current_time(),
|
||||
offset="OPEN", limit_price=limit_price
|
||||
)
|
||||
self.send_order(order)
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# =====================================================================================
|
||||
# 以下是新增的 ValueMigrationStrategy 策略代码
|
||||
# 以下是新增的 ValueMigrationStrategy 策略代码 (已按新需求修改)
|
||||
# =====================================================================================
|
||||
from collections import deque
|
||||
from datetime import timedelta, time
|
||||
@@ -49,6 +49,51 @@ def compute_price_volume_distribution(bars: List[Bar], tick_size: float) -> Opti
|
||||
return df.groupby('price')['volume'].sum().sort_index()
|
||||
|
||||
|
||||
# =====================================================================================
|
||||
# 以下是最终的、以性能为首要考量的、超高速VP计算模块
|
||||
# =====================================================================================
|
||||
|
||||
def compute_fast_volume_profile(bars: List[Bar], tick_size: float) -> Optional[pd.Series]:
|
||||
"""
|
||||
[全局核心函数] 使用“离散重心累加”法,超高速地从K线数据中构建成交量剖面图。
|
||||
|
||||
该方法将每根K线的全部成交量,一次性地归于其加权中心价
|
||||
(HLC/3)所对齐的tick上。这在保持核心逻辑精确性的同时,
|
||||
实现了计算速度的数量级提升。
|
||||
|
||||
Args:
|
||||
bars: 用于计算的K线历史数据。
|
||||
tick_size: 合约的最小变动价位。
|
||||
|
||||
Returns:
|
||||
一个代表成交量剖面图的Pandas Series,或None。
|
||||
"""
|
||||
if not bars:
|
||||
return None
|
||||
|
||||
# 使用字典进行累加,这是Python中最快的操作之一
|
||||
volume_dict = {}
|
||||
|
||||
for bar in bars:
|
||||
if bar.volume == 0:
|
||||
continue
|
||||
|
||||
# 1. 计算K线的加权中心价
|
||||
center_price = (bar.high + bar.low + bar.close) / 3
|
||||
|
||||
# 2. 将中心价对齐到最近的tick
|
||||
aligned_price = round(center_price / tick_size) * tick_size
|
||||
|
||||
# 3. 将该Bar的全部成交量,一次性累加到对齐后的价格点上
|
||||
volume_dict[aligned_price] = volume_dict.get(aligned_price, 0) + bar.volume
|
||||
|
||||
if not volume_dict:
|
||||
return None
|
||||
|
||||
# 将最终结果转换为一个有序的Pandas Series
|
||||
return pd.Series(volume_dict).sort_index()
|
||||
|
||||
|
||||
# 确保在文件顶部导入
|
||||
from scipy.signal import find_peaks
|
||||
|
||||
@@ -77,99 +122,114 @@ def find_hvns_with_distance(price_volume_dist: pd.Series, distance_in_ticks: int
|
||||
return hvn_prices
|
||||
|
||||
|
||||
# 确保在文件顶部导入
|
||||
from scipy.signal import find_peaks
|
||||
def find_hvns_strict(price_volume_dist: pd.Series, window_radius: int) -> List[float]:
|
||||
"""
|
||||
[全局函数] 使用严格的“滚动窗口最大值”定义来识别HVNs。
|
||||
一个点是HVN,当且仅当它的成交量大于其左右各 `window_radius` 个点的成交量。
|
||||
|
||||
Args:
|
||||
price_volume_dist: 价格-成交量分布序列。
|
||||
window_radius: 定义了检查窗口的半径 (即您所说的 N)。
|
||||
|
||||
Returns:
|
||||
一个包含所有被识别出的HVN价格的列表。
|
||||
"""
|
||||
if price_volume_dist.empty or window_radius == 0:
|
||||
return [price_volume_dist.idxmax()] if not price_volume_dist.empty else []
|
||||
|
||||
# 1. 确保价格序列是连续的,用0填充缺失的ticks
|
||||
full_price_range = np.arange(price_volume_dist.index.min(),
|
||||
price_volume_dist.index.max() + price_volume_dist.index.to_series().diff().min(),
|
||||
price_volume_dist.index.to_series().diff().min())
|
||||
continuous_dist = price_volume_dist.reindex(full_price_range, fill_value=0)
|
||||
|
||||
# 2. 计算滚动窗口最大值
|
||||
window_size = 2 * window_radius + 1
|
||||
rolling_max = continuous_dist.rolling(window=window_size, center=True).max()
|
||||
|
||||
# 3. 找到那些自身成交量就等于其窗口最大值的点
|
||||
is_hvn = (continuous_dist == rolling_max) & (continuous_dist > 0)
|
||||
hvn_prices = continuous_dist[is_hvn].index.tolist()
|
||||
|
||||
# 4. 处理平顶山:如果连续多个点都是HVN,只保留中间那个
|
||||
if not hvn_prices:
|
||||
return [price_volume_dist.idxmax()] # 如果找不到,返回POC
|
||||
|
||||
final_hvns = []
|
||||
i = 0
|
||||
while i < len(hvn_prices):
|
||||
# 找到一个连续HVN块
|
||||
j = i
|
||||
while j + 1 < len(hvn_prices) and (hvn_prices[j + 1] - hvn_prices[j]) < (
|
||||
2 * price_volume_dist.index.to_series().diff().min()):
|
||||
j += 1
|
||||
|
||||
# 取这个连续块的中间点
|
||||
middle_index = i + (j - i) // 2
|
||||
final_hvns.append(hvn_prices[middle_index])
|
||||
|
||||
i = j + 1
|
||||
|
||||
return final_hvns
|
||||
|
||||
|
||||
# =====================================================================================
|
||||
# 以下是V2版本的、简化了状态管理的 HVNPullbackStrategy 代码
|
||||
# =====================================================================================
|
||||
|
||||
# 引入必要的类型,确保代码清晰
|
||||
from typing import Any, Dict, Optional, List
|
||||
import numpy as np
|
||||
import talib
|
||||
|
||||
|
||||
class ValueMigrationStrategy(Strategy):
|
||||
"""
|
||||
一个基于动态HVN突破后回测的量化交易策略。(V3: 集成上下文状态管理)
|
||||
一个基于动态HVN突破后回测的量化交易策略。(V2: 简化状态管理)
|
||||
|
||||
V3版本完全集成BacktestContext的状态管理功能,实现了策略重启后的状态恢复。
|
||||
- 状态被简化为两个核心变量:_pending_sl_price 和 _pending_tp_price。
|
||||
- 在策略初始化时安全地加载状态,并兼容空状态或旧版状态。
|
||||
- 在下单或平仓时立即持久化状态,确保数据一致性。
|
||||
- 增加了逻辑检查,处理重启后可能出现的状态与实际持仓不一致的问题。
|
||||
V2版本简化了内部状态管理,移除了基于order_id的复杂元数据传递,
|
||||
使用更直接、更健壮的单一状态变量来处理挂单的止盈止损参数,
|
||||
完美适配“单次单持仓”的策略逻辑。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: Any, # 通常会是 BacktestContext
|
||||
context: Any,
|
||||
main_symbol: str,
|
||||
enable_log: bool,
|
||||
trade_volume: int,
|
||||
tick_size: float = 1,
|
||||
profile_period: int = 100,
|
||||
profile_period: int = 72,
|
||||
recalc_interval: int = 4,
|
||||
hvn_distance_ticks: int = 1,
|
||||
entry_offset_atr: float = 0.0,
|
||||
stop_loss_atr: float = 1.0,
|
||||
take_profit_atr: float = 1.0,
|
||||
hvn_distance_ticks: int = 20,
|
||||
entry_offset_atr: float = 0.0, # 注意:此参数在新逻辑中不再用于决定挂单价格
|
||||
stop_loss_atr: float = 0.5,
|
||||
take_profit_atr: float = 2.0,
|
||||
atr_period: int = 14,
|
||||
order_direction=None,
|
||||
indicators=[None, None],
|
||||
):
|
||||
super().__init__(context, main_symbol, enable_log)
|
||||
# --- 参数初始化 (保持不变) ---
|
||||
if order_direction is None:
|
||||
order_direction = ['BUY', 'SELL']
|
||||
|
||||
self.trade_volume = trade_volume
|
||||
self.tick_size = tick_size
|
||||
self.profile_period = profile_period
|
||||
self.recalc_interval = recalc_interval
|
||||
self.hvn_distance_ticks = hvn_distance_ticks
|
||||
self.entry_offset_atr = entry_offset_atr
|
||||
self.entry_offset_atr = entry_offset_atr # 保留此参数以兼容旧配置,但进场逻辑已不使用
|
||||
self.stop_loss_atr = stop_loss_atr
|
||||
self.take_profit_atr = take_profit_atr
|
||||
self.atr_period = atr_period
|
||||
|
||||
self.order_direction = order_direction
|
||||
self.indicator_long = indicators[0]
|
||||
self.indicator_short = indicators[1]
|
||||
|
||||
self.main_symbol = main_symbol
|
||||
self.order_id_counter = 0
|
||||
|
||||
self._bar_counter = 0
|
||||
self._cached_hvns: List[float] = []
|
||||
|
||||
# --- 新增: 初始化时加载状态 ---
|
||||
# --- V2: 简化的状态管理 ---
|
||||
self._pending_sl_price: Optional[float] = None
|
||||
self._pending_tp_price: Optional[float] = None
|
||||
self._load_state_from_context()
|
||||
|
||||
def _get_state_dict(self) -> Dict[str, Any]:
|
||||
"""一个辅助函数,用于生成当前需要保存的状态字典。"""
|
||||
return {
|
||||
"_pending_sl_price": self._pending_sl_price,
|
||||
"_pending_tp_price": self._pending_tp_price,
|
||||
}
|
||||
|
||||
def _load_state_from_context(self):
|
||||
"""
|
||||
[新增] 从上下文中加载状态,并进行健壮性处理。
|
||||
"""
|
||||
loaded_state = self.context.load_state()
|
||||
if not loaded_state:
|
||||
self.log("未找到历史状态,进行全新初始化。")
|
||||
return
|
||||
|
||||
# 使用 .get() 方法安全地读取,即使key不存在或state为空也不会报错。
|
||||
# 这完美解决了“读取的state的key不一样”的问题。
|
||||
self._pending_sl_price = loaded_state.get("_pending_sl_price")
|
||||
self._pending_tp_price = loaded_state.get("_pending_tp_price")
|
||||
|
||||
if self._pending_sl_price is not None:
|
||||
self.log(f"成功从上下文加载状态: SL={self._pending_sl_price}, TP={self._pending_tp_price}")
|
||||
else:
|
||||
self.log("加载的状态为空或格式不兼容,视为全新初始化。")
|
||||
|
||||
def on_open_bar(self, open_price: float, symbol: str):
|
||||
self.symbol = symbol
|
||||
@@ -180,134 +240,134 @@ class ValueMigrationStrategy(Strategy):
|
||||
if len(bar_history) < required_len:
|
||||
return
|
||||
|
||||
# 取消所有挂单,这符合原逻辑,确保每根bar都是新的开始
|
||||
# --- 1. 取消所有挂单并重置挂单状态 ---
|
||||
self.cancel_all_pending_orders(self.symbol)
|
||||
|
||||
# --- 2. 管理现有持仓 ---
|
||||
position_volume = self.get_current_positions().get(self.symbol, 0)
|
||||
|
||||
# --- 新增: 状态一致性检查 ---
|
||||
# 场景:策略重启后,加载了之前的止盈止损状态,但发现实际上并没有持仓
|
||||
# (可能因为上次平仓后、清空状态前程序就关闭了)。
|
||||
# 这种情况下,状态是无效的“幽灵状态”,必须清除。
|
||||
if position_volume == 0 and self._pending_sl_price is not None:
|
||||
self.log("检测到状态与实际持仓不符 (有状态但无持仓),重置本地状态。")
|
||||
self._pending_sl_price = None
|
||||
self._pending_tp_price = None
|
||||
self.context.save_state(self._get_state_dict()) # 立即同步清除后的状态
|
||||
|
||||
# --- 1. 管理现有持仓 (如果存在) ---
|
||||
if position_volume != 0:
|
||||
self.manage_open_position(position_volume, open_price)
|
||||
return
|
||||
|
||||
|
||||
# 周期性地计算HVNs
|
||||
# --- 3. 周期性地计算HVNs ---
|
||||
if self._bar_counter % self.recalc_interval == 1:
|
||||
profile_bars = bar_history[-self.profile_period:]
|
||||
dist = compute_price_volume_distribution(profile_bars, self.tick_size)
|
||||
if dist is not None and not dist.empty:
|
||||
self._cached_hvns = find_hvns_with_distance(dist, self.hvn_distance_ticks)
|
||||
self.log(f"识别到新的高价值节点: {[f'{p:.2f}' for p in self._cached_hvns]}")
|
||||
self._cached_hvns = find_hvns_strict(dist, self.hvn_distance_ticks)
|
||||
self.log(f"New HVNs identified at: {[f'{p:.2f}' for p in self._cached_hvns]}")
|
||||
|
||||
if not self._cached_hvns: return
|
||||
|
||||
# 评估新机会 (挂单逻辑)
|
||||
self.evaluate_entry_signal(bar_history)
|
||||
# --- 4. 评估新机会 (挂单逻辑) ---
|
||||
self.evaluate_entry_signal(bar_history, open_price)
|
||||
|
||||
def manage_open_position(self, volume: int, current_price: float):
|
||||
"""
|
||||
[修改] 主动管理已开仓位的止盈止损。
|
||||
不再使用 position_meta,直接依赖实例变量。
|
||||
"""
|
||||
# [关键安全检查]: 如果有持仓,但却没有止盈止损状态,这是一个危险的信号。
|
||||
# 可能是状态文件损坏或逻辑错误。为控制风险,应立即平仓。
|
||||
if self._pending_sl_price is None or self._pending_tp_price is None:
|
||||
self.log("风险警告:存在持仓但无有效的止盈止损价格,立即市价平仓!")
|
||||
self.close_position("CLOSE_LONG" if volume > 0 else "CLOSE_SHORT", abs(volume))
|
||||
return
|
||||
|
||||
"""主动管理已开仓位的止盈止损。"""
|
||||
sl_price = self._pending_sl_price
|
||||
tp_price = self._pending_tp_price
|
||||
|
||||
# 止盈止损逻辑 (保持不变)
|
||||
if sl_price is None or tp_price is None:
|
||||
self.log("错误:持仓存在但找不到对应的止盈止损价格。")
|
||||
self.close_position("CLOSE_LONG" if volume > 0 else "CLOSE_SHORT", abs(volume))
|
||||
return
|
||||
|
||||
if volume > 0: # 多头
|
||||
if current_price <= sl_price or current_price >= tp_price:
|
||||
action = "止损" if current_price <= sl_price else "止盈"
|
||||
self.log(f"多头{action}触发于 {current_price:.2f} (SL: {sl_price}, TP: {tp_price})")
|
||||
self.log(f"多头{action}触发 at {current_price:.2f}")
|
||||
self.close_position("CLOSE_LONG", abs(volume))
|
||||
elif volume < 0: # 空头
|
||||
if current_price >= sl_price or current_price <= tp_price:
|
||||
action = "止损" if current_price >= sl_price else "止盈"
|
||||
self.log(f"空头{action}触发于 {current_price:.2f} (SL: {sl_price}, TP: {tp_price})")
|
||||
self.log(f"空头{action}触发 at {current_price:.2f}")
|
||||
self.close_position("CLOSE_SHORT", abs(volume))
|
||||
|
||||
def evaluate_entry_signal(self, bar_history: List[Bar]):
|
||||
# [修改] 在挂单前,先重置旧的挂单状态,虽然on_open_bar开头也做了,但这里更保险
|
||||
self._pending_sl_price = None
|
||||
self._pending_tp_price = None
|
||||
# <--- 修改开始: 这是根据您的新需求完全重写的函数 --->
|
||||
def evaluate_entry_signal(self, bar_history: List[Bar], current_price: float):
|
||||
"""
|
||||
新版进场逻辑:
|
||||
寻找一个“穿越+确认”的模式。
|
||||
1. K线[-2]的收盘价穿越了某个HVN。
|
||||
2. K线[-1]的收盘价没有反转,而是保持在HVN的同一侧。
|
||||
3. 如果满足条件,则在HVN价格处挂一个限价单。
|
||||
"""
|
||||
# 新逻辑需要至少3根K线来判断“穿越+确认”
|
||||
if len(bar_history) < 3:
|
||||
return
|
||||
|
||||
# ... 原有挂单信号计算逻辑保持不变 ...
|
||||
prev_close = bar_history[-2].close
|
||||
current_close = bar_history[-1].close
|
||||
# 获取最近的三根K线收盘价,用于逻辑判断
|
||||
prev_prev_close = bar_history[-3].close # K线 n-2 (穿越前的K线)
|
||||
prev_close = bar_history[-2].close # K线 n-1 (穿越K线)
|
||||
current_close = bar_history[-1].close # K线 n (确认K线)
|
||||
|
||||
# ATR计算仍然需要,用于后续的止盈止损设置
|
||||
highs = np.array([b.high for b in bar_history], dtype=float)
|
||||
lows = np.array([b.low for b in bar_history], dtype=float)
|
||||
closes = np.array([b.close for b in bar_history], dtype=float)
|
||||
current_atr = talib.ATR(highs, lows, closes, self.atr_period)[-1]
|
||||
if current_atr < self.tick_size: return
|
||||
|
||||
for hvn in sorted(self._cached_hvns):
|
||||
if "BUY" in self.order_direction and (prev_close < hvn < current_close):
|
||||
direction = "BUY"
|
||||
for hvn in sorted(self._cached_hvns, reverse=True): # 从高到低检查HVN
|
||||
direction = None
|
||||
pass_filter = False
|
||||
|
||||
# --- 检查买入信号: 向上穿越HVN,且下一根K线收盘价仍在HVN之上 ---
|
||||
if ("BUY" in self.order_direction and
|
||||
bar_history[-1].low == hvn and # K线[-3]在HVN之下
|
||||
current_price > hvn): # K线[-1]保持在HVN之上 (确认)
|
||||
|
||||
direction = "SELL"
|
||||
pass_filter = self.indicator_long is None or self.indicator_long.is_condition_met(
|
||||
*self.get_indicator_tuple())
|
||||
elif "SELL" in self.order_direction and (prev_close > hvn > current_close):
|
||||
direction = "SELL"
|
||||
if pass_filter:
|
||||
self.log(f"买入信号确认: 价格在K线[-2]向上穿越HVN({hvn:.2f}),并在K线[-1]保持在上方。")
|
||||
|
||||
# --- 检查卖出信号: 向下穿越HVN,且下一根K线收盘价仍在HVN之下 ---
|
||||
elif ("SELL" in self.order_direction and
|
||||
bar_history[-1].high == hvn and # K线[-3]在HVN之下
|
||||
current_price < hvn): # K线[-1]保持在HVN之下 (确认)
|
||||
|
||||
direction = "BUY"
|
||||
pass_filter = self.indicator_short is None or self.indicator_short.is_condition_met(
|
||||
*self.get_indicator_tuple())
|
||||
else:
|
||||
continue
|
||||
if pass_filter:
|
||||
self.log(f"卖出信号确认: 价格在K线[-2]向下穿越HVN({hvn:.2f}),并在K线[-1]保持在下方。")
|
||||
|
||||
# 如果找到了符合条件的信号并通过了过滤器
|
||||
if direction and pass_filter:
|
||||
offset = self.entry_offset_atr * current_atr
|
||||
limit_price = hvn + offset if direction == "BUY" else hvn - offset
|
||||
self.log(f"价格穿越HVN({hvn:.2f})。在 {limit_price:.2f} 挂限价{direction}单。")
|
||||
# 新逻辑: 直接在HVN价格挂单,不再使用ATR偏移
|
||||
limit_price = hvn + (1 if direction == "BUY" else -1)
|
||||
self.log(f"信号有效,准备在HVN({limit_price:.2f})处挂限价{direction}单。")
|
||||
self.send_hvn_limit_order(direction, limit_price, current_atr)
|
||||
return
|
||||
return # 每次只处理一个信号,挂一个单,然后等待下根K线
|
||||
|
||||
def send_hvn_limit_order(self, direction: str, limit_price: float, entry_atr: float):
|
||||
# 1. 设置实例的止盈止损状态
|
||||
# [V2 关键逻辑]: 直接更新实例变量,为即将到来的持仓预设止盈止损
|
||||
self._pending_sl_price = limit_price - self.stop_loss_atr * entry_atr if direction == "BUY" else limit_price + self.stop_loss_atr * entry_atr
|
||||
self._pending_tp_price = limit_price + self.take_profit_atr * entry_atr if direction == "BUY" else limit_price - self.take_profit_atr * entry_atr
|
||||
|
||||
# 2. [新增] 状态已更新,立即通过上下文持久化
|
||||
self.context.save_state(self._get_state_dict())
|
||||
self.log(f"状态已更新并保存: SL={self._pending_sl_price}, TP={self._pending_tp_price}")
|
||||
|
||||
# 3. 发送订单
|
||||
order_id = f"{self.symbol}_{direction}_LIMIT_{self.order_id_counter}"
|
||||
self.order_id_counter += 1
|
||||
|
||||
order = Order(
|
||||
id=order_id, symbol=self.symbol, direction=direction, volume=self.trade_volume,
|
||||
price_type="LIMIT", limit_price=limit_price, submitted_time=self.get_current_time(),
|
||||
price_type="MARKET", submitted_time=self.get_current_time(),
|
||||
offset="OPEN"
|
||||
)
|
||||
self.send_order(order)
|
||||
self.log(f"已发送挂单: {direction} @ {limit_price:.2f}, 预设SL={self._pending_sl_price:.2f}, TP={self._pending_tp_price:.2f}")
|
||||
|
||||
|
||||
def close_position(self, direction: str, volume: int):
|
||||
"""[修改] 平仓时,必须清空状态并立即保存。"""
|
||||
# 1. 发送平仓市价单
|
||||
self.send_market_order(direction, volume)
|
||||
|
||||
# 2. 清空本地的止盈止损状态
|
||||
# 清理止盈止损状态
|
||||
self._pending_sl_price = None
|
||||
self._pending_tp_price = None
|
||||
self.log("仓位已平,重置止盈止损状态。")
|
||||
|
||||
# 3. [新增] 状态已清空,立即通过上下文持久化这个“空状态”
|
||||
self.context.save_state(self._get_state_dict())
|
||||
self.log("持仓已平,相关的止盈止损状态已清空并保存。")
|
||||
|
||||
def send_market_order(self, direction: str, volume: int, offset: str = "CLOSE"):
|
||||
# ... 此辅助函数保持不变 ...
|
||||
order_id = f"{self.symbol}_{direction}_{offset}_{self.get_current_time().strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
|
||||
self.order_id_counter += 1
|
||||
order = Order(
|
||||
|
||||
@@ -33,6 +33,7 @@ class Strategy(ABC):
|
||||
"""
|
||||
self.context = context # 存储 context 对象
|
||||
self.symbol = symbol # 策略操作的合约Symbol
|
||||
self.main_symbol = symbol
|
||||
self.params = params
|
||||
self.enable_log = enable_log
|
||||
self.trading = False
|
||||
@@ -187,7 +188,7 @@ class Strategy(ABC):
|
||||
|
||||
# 你可以将其他 kwargs (如 sep, end, file, flush) 传递给 print,
|
||||
# 但通常日志方法不会频繁使用这些。这里只支持最基础的打印。
|
||||
print(f"{time_prefix}策略 ({self.symbol}): {message}", **kwargs)
|
||||
print(f"{time_prefix}策略 ({self.symbol}): {message}")
|
||||
|
||||
def on_rollover(self, old_symbol: str, new_symbol: str):
|
||||
"""
|
||||
@@ -219,13 +220,21 @@ class Strategy(ABC):
|
||||
return self._indicator_cache
|
||||
|
||||
# 数据有变化,重新创建数组并更新缓存
|
||||
close = np.array(close_data)
|
||||
open_price = np.array(self.get_price_history("open"))
|
||||
high = np.array(self.get_price_history("high"))
|
||||
low = np.array(self.get_price_history("low"))
|
||||
volume = np.array(self.get_price_history("volume"))
|
||||
close = np.array(close_data[-1000:])
|
||||
open_price = np.array(self.get_price_history("open")[-1000:])
|
||||
high = np.array(self.get_price_history("high")[-1000:])
|
||||
low = np.array(self.get_price_history("low")[-1000:])
|
||||
volume = np.array(self.get_price_history("volume")[-1000:])
|
||||
|
||||
self._indicator_cache = (close, open_price, high, low, volume)
|
||||
self._cache_length = current_length
|
||||
|
||||
return self._indicator_cache
|
||||
|
||||
def save_state(self, state: Any) -> None:
|
||||
if self.trading:
|
||||
self.context.save_state(state)
|
||||
|
||||
def load_state(self) -> None:
|
||||
if self.trading:
|
||||
self.context.load_state()
|
||||
|
||||
@@ -23,7 +23,8 @@ class TqsdkContext:
|
||||
Tqsdk 回测上下文,适配原有 BacktestContext 接口。
|
||||
策略通过此上下文与 Tqsdk 进行交互。
|
||||
"""
|
||||
def __init__(self, api: TqApi):
|
||||
def __init__(self, api: TqApi,
|
||||
state_repository: 'StateRepository'):
|
||||
"""
|
||||
初始化 Tqsdk 回测上下文。
|
||||
|
||||
@@ -33,6 +34,7 @@ class TqsdkContext:
|
||||
self._api = api
|
||||
self._current_bar: Optional[Bar] = None
|
||||
self._engine: Optional['TqsdkEngine'] = None # 添加对引擎的引用,用于访问其状态或触发事件
|
||||
self._state_repository = state_repository # NEW: 存储状态仓储实例
|
||||
|
||||
# 用于缓存 Tqsdk 的 K 线序列,避免每次都 get_kline_serial
|
||||
self._kline_serial: Dict[str, object] = {}
|
||||
@@ -202,4 +204,28 @@ class TqsdkContext:
|
||||
return self._engine.get_bar_history()
|
||||
|
||||
def get_price_history(self, key: str):
|
||||
return self._engine.get_price_history(key)
|
||||
return self._engine.get_price_history(key)
|
||||
|
||||
def save_state(self, state: Dict[str, Any]) -> None:
|
||||
"""
|
||||
保存策略的当前状态。
|
||||
|
||||
策略应在适当的时机(例如,每日结束、策略关闭时)调用此方法
|
||||
来持久化其内部变量。
|
||||
|
||||
Args:
|
||||
state (Dict[str, Any]): 包含策略状态的字典。
|
||||
"""
|
||||
self._state_repository.save(state)
|
||||
|
||||
def load_state(self) -> Dict[str, Any]:
|
||||
"""
|
||||
加载策略的历史状态。
|
||||
|
||||
策略应在初始化时调用此方法来恢复之前的运行状态。
|
||||
如果不存在历史状态,将返回一个空字典。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 包含策略历史状态的字典。
|
||||
"""
|
||||
return self._state_repository.load()
|
||||
@@ -1,11 +1,13 @@
|
||||
# filename: tqsdk_engine.py
|
||||
|
||||
import asyncio
|
||||
import traceback
|
||||
from datetime import date, datetime, timedelta
|
||||
from typing import Literal, Type, Dict, Any, List, Optional
|
||||
import pandas as pd
|
||||
import uuid
|
||||
|
||||
from src.common_utils import generate_strategy_identifier
|
||||
# 导入你提供的 core_data 中的类型
|
||||
from src.core_data import Bar, Order, Trade, PortfolioSnapshot
|
||||
|
||||
@@ -22,6 +24,7 @@ from tqsdk import (
|
||||
BacktestFinished,
|
||||
)
|
||||
|
||||
from src.state_repo import MemoryStateRepository
|
||||
# 导入 TqsdkContext 和 BaseStrategy
|
||||
from src.tqsdk_context import TqsdkContext
|
||||
from src.strategies.base_strategy import Strategy # 假设你的策略基类在此路径
|
||||
@@ -36,15 +39,15 @@ class TqsdkEngine:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
strategy_class: Type[Strategy],
|
||||
strategy_params: Dict[str, Any],
|
||||
api: TqApi,
|
||||
roll_over_mode: bool = False, # 是否开启换月模式检测
|
||||
symbol: str = None,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None,
|
||||
duration_seconds: int = 1,
|
||||
self,
|
||||
strategy_class: Type[Strategy],
|
||||
strategy_params: Dict[str, Any],
|
||||
api: TqApi,
|
||||
roll_over_mode: bool = False, # 是否开启换月模式检测
|
||||
symbol: str = None,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None,
|
||||
duration_seconds: int = 1,
|
||||
):
|
||||
"""
|
||||
初始化 Tqsdk 回测引擎。
|
||||
@@ -83,7 +86,8 @@ class TqsdkEngine:
|
||||
# )
|
||||
|
||||
# 初始化上下文
|
||||
self._context: TqsdkContext = TqsdkContext(api=self._api)
|
||||
identifier = generate_strategy_identifier(strategy_class, strategy_params)
|
||||
self._context: TqsdkContext = TqsdkContext(api=self._api, state_repository=MemoryStateRepository(identifier))
|
||||
# 实例化策略,并将上下文传递给它
|
||||
self._strategy: Strategy = self.strategy_class(
|
||||
context=self._context, **self.strategy_params
|
||||
@@ -113,6 +117,8 @@ class TqsdkEngine:
|
||||
if roll_over_mode:
|
||||
self.quote = api.get_quote(self.symbol)
|
||||
|
||||
self.target_pos_dict = {}
|
||||
|
||||
print("TqsdkEngine: 初始化完成。")
|
||||
|
||||
@property
|
||||
@@ -127,7 +133,6 @@ class TqsdkEngine:
|
||||
异步处理 Context 中排队的订单和取消请求。
|
||||
"""
|
||||
# 处理订单
|
||||
print(self._context.order_queue)
|
||||
while self._context.order_queue:
|
||||
order_to_send: Order = self._context.order_queue.popleft()
|
||||
print(f"Engine: 处理订单请求: {order_to_send}")
|
||||
@@ -155,7 +160,23 @@ class TqsdkEngine:
|
||||
if "SHFE" in order_to_send.symbol:
|
||||
tqsdk_offset = "OPEN"
|
||||
|
||||
try:
|
||||
if "CLOSE" in order_to_send.direction:
|
||||
current_positions = self._context.get_current_positions()
|
||||
current_pos_volume = current_positions.get(order_to_send.symbol, 0)
|
||||
|
||||
target_volume = None
|
||||
if order_to_send.direction == 'CLOSE_LONG':
|
||||
target_volume = current_pos_volume - order_to_send.volume
|
||||
elif order_to_send.direction == 'CLOSE_SHORT':
|
||||
target_volume = current_pos_volume + order_to_send.volume
|
||||
|
||||
if target_volume is not None:
|
||||
if order_to_send.symbol not in self.target_pos_dict:
|
||||
self.target_pos_dict[order_to_send.symbol] = TargetPosTask(self._api, order_to_send.symbol)
|
||||
|
||||
self.target_pos_dict[order_to_send.symbol].set_target_volume(target_volume)
|
||||
else:
|
||||
# try:
|
||||
tq_order = self._api.insert_order(
|
||||
symbol=order_to_send.symbol,
|
||||
direction=tqsdk_direction,
|
||||
@@ -181,42 +202,10 @@ class TqsdkEngine:
|
||||
tq_order.insert_date_time, unit="ns", utc=True
|
||||
)
|
||||
|
||||
# 等待订单状态更新(成交/撤销/报错)
|
||||
# 在 Tqsdk 中,订单和成交是独立的,通常在 wait_update() 循环中通过 api.is_changing() 检查
|
||||
# 这里为了模拟同步处理,直接等待订单状态最终确定
|
||||
# 注意:实际回测中,不应在这里长时间阻塞,而应在主循环中持续 wait_update
|
||||
# 为了简化适配,这里模拟即时处理,但可能与真实异步行为有差异。
|
||||
# 更健壮的方式是在主循环中通过订单状态回调更新
|
||||
# 这里我们假设订单会很快更新状态,或者在下一个 wait_update() 周期中被检测到
|
||||
self._api.wait_update() # 等待一次更新
|
||||
|
||||
# # 检查最终订单状态和成交
|
||||
# if tq_order.status == "FINISHED":
|
||||
# # 查找对应的成交记录
|
||||
# for trade_id, tq_trade in self._api.get_trade().items():
|
||||
# if tq_trade.order_id == tq_order.order_id and tq_trade.volume > 0: # 确保是实际成交
|
||||
# # 创建 core_data.Trade 对象
|
||||
# trade = Trade(
|
||||
# order_id=tq_trade.order_id,
|
||||
# fill_time=tafunc.get_datetime_from_timestamp(tq_trade.trade_date_time) if tq_trade.trade_date_time else datetime.now(),
|
||||
# symbol=order_to_send.symbol, # 使用 Context 中的 symbol
|
||||
# direction=tq_trade.direction, # 实际成交方向
|
||||
# volume=tq_trade.volume,
|
||||
# price=tq_trade.price,
|
||||
# commission=tq_trade.commission,
|
||||
# cash_after_trade=self._api.get_account().available,
|
||||
# positions_after_trade=self._context.get_current_positions(),
|
||||
# realized_pnl=tq_trade.realized_pnl, # Tqsdk TqTrade 对象有 realized_pnl
|
||||
# is_open_trade=tq_trade.offset == "OPEN",
|
||||
# is_close_trade=tq_trade.offset in ["CLOSE", "CLOSETODAY", "CLOSEYESTERDAY"]
|
||||
# )
|
||||
# self.trade_history.append(trade)
|
||||
# print(f"Engine: 成交记录: {trade}")
|
||||
# break # 找到成交就跳出
|
||||
# order_to_send.status = tq_order.status # 更新最终状态
|
||||
except Exception as e:
|
||||
print(f"Engine: 发送订单 {order_to_send.id} 失败: {e}")
|
||||
# order_to_send.status = "ERROR"
|
||||
#
|
||||
# except Exception as e:
|
||||
# print(f"Engine: 发送订单 {order_to_send.id} 失败: {e}")
|
||||
|
||||
# 处理取消请求
|
||||
while self._context.cancel_queue:
|
||||
@@ -254,7 +243,7 @@ class TqsdkEngine:
|
||||
price = quote.last_price
|
||||
current_prices[symbol] = price
|
||||
total_market_value += (
|
||||
price * qty * quote.volume_multiple
|
||||
price * qty * quote.volume_multiple
|
||||
) # volume_multiple 乘数
|
||||
else:
|
||||
# 如果没有最新价格,使用最近的K线收盘价作为估算
|
||||
@@ -267,11 +256,11 @@ class TqsdkEngine:
|
||||
price = last_kline.close
|
||||
current_prices[symbol] = price
|
||||
total_market_value += (
|
||||
price * qty * self._api.get_instrument(symbol).volume_multiple
|
||||
price * qty * self._api.get_instrument(symbol).volume_multiple
|
||||
) # 使用 instrument 的乘数
|
||||
|
||||
total_value = (
|
||||
account.available + account.frozen_margin + total_market_value
|
||||
account.available + account.frozen_margin + total_market_value
|
||||
) # Tqsdk 的 balance 已包含持仓市值和冻结资金
|
||||
# Tqsdk 的 total_profit/balance 已经包含了所有盈亏和资金
|
||||
|
||||
@@ -344,8 +333,8 @@ class TqsdkEngine:
|
||||
self._api.wait_update()
|
||||
|
||||
if self.roll_over_mode and (
|
||||
self._api.is_changing(self.quote, "underlying_symbol")
|
||||
or self._last_underlying_symbol != self.quote.underlying_symbol
|
||||
self._api.is_changing(self.quote, "underlying_symbol")
|
||||
or self._last_underlying_symbol != self.quote.underlying_symbol
|
||||
):
|
||||
self._last_underlying_symbol = self.quote.underlying_symbol
|
||||
|
||||
@@ -355,9 +344,9 @@ class TqsdkEngine:
|
||||
now_dt = now_dt.tz_convert(BEIJING_TZ)
|
||||
|
||||
if (
|
||||
self.now is not None
|
||||
and self.now.hour != 13
|
||||
and now_dt.hour == 13
|
||||
self.now is not None
|
||||
and self.now.hour != 13
|
||||
and now_dt.hour == 13
|
||||
):
|
||||
self.main()
|
||||
|
||||
@@ -400,8 +389,8 @@ class TqsdkEngine:
|
||||
)
|
||||
|
||||
if (
|
||||
self.last_processed_bar is None
|
||||
or self.last_processed_bar.datetime != kline_dt
|
||||
self.last_processed_bar is None
|
||||
or self.last_processed_bar.datetime != kline_dt
|
||||
):
|
||||
|
||||
# 设置当前 Bar 到 Context
|
||||
@@ -410,9 +399,9 @@ class TqsdkEngine:
|
||||
# Tqsdk 的 is_changing 用于判断数据是否有变化,对于回测遍历 K 线,每次迭代都算作新 Bar
|
||||
# 如果 kline_row.datetime 与上次不同,则认为是新 Bar
|
||||
if (
|
||||
self.roll_over_mode
|
||||
and self.last_processed_bar is not None
|
||||
and self._last_underlying_symbol != self.last_processed_bar.symbol
|
||||
self.roll_over_mode
|
||||
and self.last_processed_bar is not None
|
||||
and self._last_underlying_symbol != self.last_processed_bar.symbol
|
||||
):
|
||||
self._is_rollover_bar = True
|
||||
print(
|
||||
@@ -439,6 +428,8 @@ class TqsdkEngine:
|
||||
# 记录投资组合快照
|
||||
self._record_portfolio_snapshot(current_bar.datetime)
|
||||
else:
|
||||
if current_bar.volume == 0:
|
||||
return
|
||||
self.all_bars.append(current_bar)
|
||||
|
||||
self.close_list.append(current_bar.close)
|
||||
|
||||
@@ -13,6 +13,8 @@ from tqsdk import TqApi, TqAccount, tafunc
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from src.state_repo import StateRepository
|
||||
|
||||
# 使用 TYPE_CHECKING 避免循环导入,只在类型检查时导入 TqsdkEngine
|
||||
if TYPE_CHECKING:
|
||||
from src.tqsdk_engine import TqsdkEngine # 假设 TqsdkEngine 在 tqsdk_engine.py 中
|
||||
@@ -23,7 +25,7 @@ class TqsdkContext:
|
||||
Tqsdk 回测上下文,适配原有 BacktestContext 接口。
|
||||
策略通过此上下文与 Tqsdk 进行交互。
|
||||
"""
|
||||
def __init__(self, api: TqApi):
|
||||
def __init__(self, api: TqApi, state_repository: StateRepository):
|
||||
"""
|
||||
初始化 Tqsdk 回测上下文。
|
||||
|
||||
@@ -33,7 +35,8 @@ class TqsdkContext:
|
||||
self._api = api
|
||||
self._current_bar: Optional[Bar] = None
|
||||
self._engine: Optional['TqsdkEngine'] = None # 添加对引擎的引用,用于访问其状态或触发事件
|
||||
|
||||
self._state_repository = state_repository # NEW: 存储状态仓储实例
|
||||
|
||||
# 用于缓存 Tqsdk 的 K 线序列,避免每次都 get_kline_serial
|
||||
self._kline_serial: Dict[str, object] = {}
|
||||
|
||||
@@ -202,4 +205,28 @@ class TqsdkContext:
|
||||
return self._engine.get_bar_history()
|
||||
|
||||
def get_price_history(self, key: str):
|
||||
return self._engine.get_price_history(key)
|
||||
return self._engine.get_price_history(key)
|
||||
|
||||
def save_state(self, state: Dict[str, Any]) -> None:
|
||||
"""
|
||||
保存策略的当前状态。
|
||||
|
||||
策略应在适当的时机(例如,每日结束、策略关闭时)调用此方法
|
||||
来持久化其内部变量。
|
||||
|
||||
Args:
|
||||
state (Dict[str, Any]): 包含策略状态的字典。
|
||||
"""
|
||||
self._state_repository.save(state)
|
||||
|
||||
def load_state(self) -> Dict[str, Any]:
|
||||
"""
|
||||
加载策略的历史状态。
|
||||
|
||||
策略应在初始化时调用此方法来恢复之前的运行状态。
|
||||
如果不存在历史状态,将返回一个空字典。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 包含策略历史状态的字典。
|
||||
"""
|
||||
return self._state_repository.load()
|
||||
@@ -7,7 +7,7 @@ import pandas as pd
|
||||
import time
|
||||
|
||||
# 导入你提供的 core_data 中的类型
|
||||
from src.common_utils import is_bar_pre_close_period, is_futures_trading_time
|
||||
from src.common_utils import is_bar_pre_close_period, is_futures_trading_time, generate_strategy_identifier
|
||||
from src.core_data import Bar, Order, Trade, PortfolioSnapshot
|
||||
|
||||
# 导入 Tqsdk 的核心类型
|
||||
@@ -23,6 +23,7 @@ from tqsdk import (
|
||||
BacktestFinished,
|
||||
)
|
||||
|
||||
from src.state_repo import JsonFileStateRepository
|
||||
# 导入 TqsdkContext 和 BaseStrategy
|
||||
from src.tqsdk_real_context import TqsdkContext
|
||||
from src.strategies.base_strategy import Strategy # 假设你的策略基类在此路径
|
||||
@@ -87,8 +88,8 @@ class TqsdkEngine:
|
||||
# )
|
||||
|
||||
# 初始化上下文
|
||||
self._context: TqsdkContext = TqsdkContext(api=self._api)
|
||||
# 实例化策略,并将上下文传递给它
|
||||
identifier = generate_strategy_identifier(strategy_class, strategy_params)
|
||||
self._context: TqsdkContext = TqsdkContext(api=self._api, state_repository=JsonFileStateRepository(identifier)) # 实例化策略,并将上下文传递给它
|
||||
self._strategy: Strategy = self.strategy_class(
|
||||
context=self._context, **self.strategy_params
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user