1、新增傅里叶策略
2、新增策略管理、策略重启功能
This commit is contained in:
File diff suppressed because one or more lines are too long
@@ -241,7 +241,7 @@ if __name__ == "__main__":
|
||||
# symbol='KQ.i@SHFE.bu',
|
||||
freq="min15",
|
||||
start_date_str="2021-01-01",
|
||||
end_date_str="2025-10-28",
|
||||
end_date_str="2025-11-28",
|
||||
mode="backtest", # 指定为回测模式
|
||||
tq_user=TQ_USER_NAME,
|
||||
tq_pwd=TQ_PASSWORD,
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -73,8 +73,8 @@ class DualModeKalmanStrategy(Strategy):
|
||||
# 卡尔曼滤波器状态
|
||||
self.Q = kalman_process_noise
|
||||
self.R = kalman_measurement_noise
|
||||
self.P = 1.0;
|
||||
self.x_hat = 0.0;
|
||||
self.P = 1.0
|
||||
self.x_hat = 0.0
|
||||
self.kalman_initialized = False
|
||||
|
||||
self._atr_history: deque = deque(maxlen=self.atr_lookback)
|
||||
@@ -105,7 +105,8 @@ class DualModeKalmanStrategy(Strategy):
|
||||
self._atr_history.append(current_atr)
|
||||
if current_atr <= 0 or len(self._atr_history) < self.atr_lookback: return
|
||||
|
||||
if not self.kalman_initialized: self.x_hat = closes[-1]; self.kalman_initialized = True
|
||||
if not self.kalman_initialized: self.x_hat = closes[-1]
|
||||
self.kalman_initialized = True
|
||||
x_hat_minus = self.x_hat
|
||||
P_minus = self.P + self.Q
|
||||
K = P_minus / (P_minus + self.R)
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -1,180 +0,0 @@
|
||||
import numpy as np
|
||||
import talib
|
||||
from typing import Optional, Dict, Any, List, Tuple, Union
|
||||
|
||||
from src.algo.TrendLine import calculate_latest_trendline_values
|
||||
# 假设这些是你项目中的基础模块
|
||||
from src.core_data import Bar, Order
|
||||
from src.indicators.base_indicators import Indicator
|
||||
from src.indicators.indicators import Empty
|
||||
from src.strategies.base_strategy import Strategy
|
||||
|
||||
|
||||
class TrendlineBreakoutStrategy(Strategy):
|
||||
"""
|
||||
趋势线突破策略 V3 (优化版):
|
||||
1. 策略逻辑与 V2 相同,但趋势线计算被重构为一个独立的、
|
||||
高性能的辅助方法。
|
||||
2. 该方法只计算最新的趋势线值,避免不必要的数组生成。
|
||||
|
||||
开仓信号:
|
||||
- 做多: 上一根收盘价上穿下趋势线
|
||||
- 做空: 上一根收盘价下穿上趋势线
|
||||
|
||||
平仓逻辑:
|
||||
- 采用 ATR 滑动止损 (Trailing Stop)。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: Any,
|
||||
main_symbol: str,
|
||||
trendline_n: int = 50,
|
||||
trade_volume: int = 1,
|
||||
order_direction: Optional[List[str]] = None,
|
||||
atr_period: int = 14,
|
||||
atr_multiplier: float = 1.0,
|
||||
enable_log: bool = True,
|
||||
indicators: Union[Indicator, List[Indicator]] = None,
|
||||
):
|
||||
super().__init__(context, main_symbol, enable_log)
|
||||
self.main_symbol = main_symbol
|
||||
self.trendline_n = trendline_n
|
||||
self.trade_volume = trade_volume
|
||||
self.order_direction = order_direction or ["BUY", "SELL"]
|
||||
self.atr_period = atr_period
|
||||
self.atr_multiplier = atr_multiplier
|
||||
self.pos_meta: Dict[str, Dict[str, Any]] = {}
|
||||
if indicators is None:
|
||||
indicators = [Empty(), Empty()]
|
||||
self.indicators = indicators
|
||||
|
||||
if self.trendline_n < 3:
|
||||
raise ValueError("trendline_n 必须大于或等于 3")
|
||||
|
||||
log_message = (
|
||||
f"TrendlineBreakoutStrategy (V3 Optimized) 初始化:\n"
|
||||
f"交易标的={self.main_symbol}, 交易量={self.trade_volume}\n"
|
||||
f"趋势线周期={self.trendline_n}, ATR周期={self.atr_period}, ATR倍数={self.atr_multiplier}"
|
||||
)
|
||||
self.log(log_message)
|
||||
|
||||
|
||||
def _calculate_atr(self, bar_history: List[Bar]) -> Optional[float]:
|
||||
# (此函数与上一版本完全相同,保持不变)
|
||||
if len(bar_history) < self.atr_period + 1: return None
|
||||
highs = np.array([b.high for b in bar_history])
|
||||
lows = np.array([b.low for b in bar_history])
|
||||
closes = np.array([b.close for b in bar_history])
|
||||
atr = talib.ATR(highs, lows, closes, timeperiod=self.atr_period)
|
||||
return atr[-1] if not np.isnan(atr[-1]) else None
|
||||
|
||||
def on_init(self):
|
||||
super().on_init()
|
||||
self.pos_meta.clear()
|
||||
|
||||
def on_open_bar(self, open_price: float, symbol: str):
|
||||
bar_history = self.get_bar_history()
|
||||
min_bars_required = self.trendline_n + 2
|
||||
if len(bar_history) < min_bars_required:
|
||||
return
|
||||
|
||||
self.cancel_all_pending_orders(symbol)
|
||||
pos = self.get_current_positions().get(symbol, 0)
|
||||
|
||||
# 1. 优先处理平仓逻辑 (逻辑不变)
|
||||
meta = self.pos_meta.get(symbol)
|
||||
if meta and pos != 0:
|
||||
current_atr = self._calculate_atr(bar_history[:-1])
|
||||
if current_atr:
|
||||
trailing_stop = meta['trailing_stop']
|
||||
direction = meta['direction']
|
||||
last_close = bar_history[-1].close
|
||||
if direction == "BUY":
|
||||
new_stop_level = last_close - current_atr * self.atr_multiplier
|
||||
trailing_stop = max(trailing_stop, new_stop_level)
|
||||
else: # SELL
|
||||
new_stop_level = last_close + current_atr * self.atr_multiplier
|
||||
trailing_stop = min(trailing_stop, new_stop_level)
|
||||
self.pos_meta[symbol]['trailing_stop'] = trailing_stop
|
||||
if (direction == "BUY" and open_price <= trailing_stop) or \
|
||||
(direction == "SELL" and open_price >= trailing_stop):
|
||||
self.log(f"ATR滑动止损触发: 价格 {open_price:.2f} 触及止损位 {trailing_stop:.2f}")
|
||||
self.send_market_order("CLOSE_LONG" if direction == "BUY" else "CLOSE_SHORT", abs(pos))
|
||||
del self.pos_meta[symbol]
|
||||
return
|
||||
|
||||
# 2. 开仓逻辑 (调用优化后的方法)
|
||||
if pos == 0:
|
||||
prices_for_trendline = np.array([b.close for b in bar_history[-self.trendline_n - 1:-1]])
|
||||
|
||||
# --- 调用新的独立方法 ---
|
||||
trendline_val_upper, trendline_val_lower = calculate_latest_trendline_values(prices_for_trendline)
|
||||
|
||||
if trendline_val_upper is None or trendline_val_lower is None:
|
||||
return # 无法计算趋势线,跳过
|
||||
|
||||
prev_close = bar_history[-2].close
|
||||
last_close = bar_history[-1].close
|
||||
|
||||
current_atr = self._calculate_atr(bar_history[:-1])
|
||||
if not current_atr:
|
||||
return
|
||||
|
||||
if "BUY" in self.order_direction and last_close > trendline_val_upper and self.indicators[0].is_condition_met(*self.get_indicator_tuple()):
|
||||
self.log(f"做多信号: Close({last_close:.2f}) 上穿下趋势线({trendline_val_upper:.2f})")
|
||||
self.send_open_order("BUY", open_price, self.trade_volume, current_atr)
|
||||
|
||||
elif "SELL" in self.order_direction and last_close < trendline_val_lower and self.indicators[1].is_condition_met(*self.get_indicator_tuple()):
|
||||
self.log(f"做空信号: Close({last_close:.2f}) 下穿上趋势线({trendline_val_lower:.2f})")
|
||||
self.send_open_order("SELL", open_price, self.trade_volume, current_atr)
|
||||
|
||||
# is_met = self.indicators[0].is_condition_met(*self.get_indicator_tuple())
|
||||
# if "BUY" in self.order_direction and last_close > trendline_val_upper and is_met:
|
||||
# self.log(f"做多信号: Close({last_close:.2f}) 上穿下趋势线({trendline_val_upper:.2f})")
|
||||
# self.send_open_order("BUY", open_price, self.trade_volume, current_atr)
|
||||
#
|
||||
# elif "SELL" in self.order_direction and last_close < trendline_val_lower and is_met:
|
||||
# self.log(f"做空信号: Close({last_close:.2f}) 下穿上趋势线({trendline_val_lower:.2f})")
|
||||
# self.send_open_order("SELL", open_price, self.trade_volume, current_atr)
|
||||
|
||||
# prob = self.indicators[0].get_latest_value(*self.get_indicator_tuple())
|
||||
# if "BUY" in self.order_direction and last_close > trendline_val_upper:
|
||||
# if prob is None or prob[0] < prob[1]:
|
||||
# self.log(f"做多信号: Close({last_close:.2f}) 上穿下趋势线({trendline_val_upper:.2f})")
|
||||
# self.send_open_order("BUY", open_price, self.trade_volume, current_atr)
|
||||
#
|
||||
# elif "SELL" in self.order_direction and last_close < trendline_val_lower:
|
||||
# if prob is None or prob[2] < prob[3]:
|
||||
# self.log(f"做空信号: Close({last_close:.2f}) 下穿上趋势线({trendline_val_lower:.2f})")
|
||||
# self.send_open_order("SELL", open_price, self.trade_volume, current_atr)
|
||||
|
||||
# send_open_order, send_market_order, on_rollover 等方法与上一版本完全相同,保持不变
|
||||
def send_open_order(self, direction: str, entry_price: float, volume: int, current_atr: float):
|
||||
if direction == "BUY":
|
||||
initial_stop = entry_price - current_atr * self.atr_multiplier
|
||||
else:
|
||||
initial_stop = entry_price + current_atr * self.atr_multiplier
|
||||
current_time = self.get_current_time()
|
||||
order_id = f"{self.symbol}_{direction}_{current_time.strftime('%Y%m%d%H%M%S')}"
|
||||
order_direction = "BUY" if direction == "BUY" else "SELL"
|
||||
order = Order(id=order_id, symbol=self.symbol, direction=order_direction, volume=volume, price_type="MARKET",
|
||||
submitted_time=current_time, offset="OPEN")
|
||||
self.send_order(order)
|
||||
self.pos_meta[self.symbol] = {"direction": direction, "volume": volume, "entry_price": entry_price,
|
||||
"trailing_stop": initial_stop}
|
||||
self.log(
|
||||
f"发送开仓订单: {direction} {volume}手 @ Market Price (执行价约 {entry_price:.2f}), 初始ATR止损位: {initial_stop:.2f}")
|
||||
|
||||
def send_market_order(self, direction: str, volume: int):
|
||||
current_time = self.get_current_time()
|
||||
order_id = f"{self.symbol}_{direction}_{current_time.strftime('%Y%m%d%H%M%S')}"
|
||||
order = Order(id=order_id, symbol=self.symbol, direction=direction, volume=volume, price_type="MARKET",
|
||||
submitted_time=current_time, offset="CLOSE")
|
||||
self.send_order(order)
|
||||
self.log(f"发送平仓订单: {direction} {volume}手 @ Market Price")
|
||||
|
||||
def on_rollover(self, old_symbol: str, new_symbol: str):
|
||||
super().on_rollover(old_symbol, new_symbol)
|
||||
self.cancel_all_pending_orders(new_symbol)
|
||||
self.pos_meta.clear()
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
@@ -122,166 +122,117 @@ class ResultAnalyzer:
|
||||
|
||||
def analyze_indicators(self, profit_offset: float = 0.0) -> None:
|
||||
"""
|
||||
分析所有平仓交易的指标值与实现盈亏的关系,并绘制累积盈亏曲线图。
|
||||
图表将展示指标值区间与对应累积盈亏的关系,帮助找出具有概率优势的指标区间。
|
||||
同时会标记出最大和最小累积盈亏对应的指标值,并优化标注位置以避免重叠。
|
||||
分析开仓时的指标值与平仓时实现盈亏的关系,并绘制累积盈亏曲线。
|
||||
简化假设:每笔开仓(is_open_trade=True)都会在后续被一笔平仓(is_close_trade=True)全平。
|
||||
"""
|
||||
close_trades = [trade for trade in self.trade_history if trade.is_close_trade]
|
||||
|
||||
# 1. 分离开仓和平仓交易
|
||||
open_trades = [t for t in self.trade_history if t.is_open_trade]
|
||||
close_trades = [t for t in self.trade_history if t.is_close_trade]
|
||||
|
||||
if not close_trades:
|
||||
print(
|
||||
"没有平仓交易可供分析。请确保 trade_history 中有 is_close_trade 为 True 的交易。"
|
||||
)
|
||||
print("没有平仓交易可供分析。请确保策略有平仓行为。")
|
||||
return
|
||||
|
||||
# 2. 配对数量(取较小值)
|
||||
num_pairs = min(len(open_trades), len(close_trades))
|
||||
if num_pairs == 0:
|
||||
print("开仓和平仓交易数量不匹配,无法配对分析。")
|
||||
return
|
||||
|
||||
print(f"将进行 {num_pairs} 组开仓-平仓配对分析")
|
||||
|
||||
for indicator in self.indicator_list:
|
||||
# 假设每个 indicator 对象都有一个 get_name() 方法
|
||||
indicator_name = indicator.get_name()
|
||||
|
||||
# 收集该指标的所有值和对应的实现盈亏
|
||||
# 3. 按配对顺序收集指标值和盈亏
|
||||
indi_values = []
|
||||
pnls = []
|
||||
for trade in close_trades:
|
||||
# 确保 trade.indicator_dict 中包含当前指标的值
|
||||
# 并且这个值是可用的(非None或NaN)
|
||||
if (
|
||||
trade.indicator_dict is not None
|
||||
and indicator_name in trade.indicator_dict
|
||||
and trade.indicator_dict[indicator_name] is not None
|
||||
):
|
||||
# 检查是否为 NaN,如果使用 np.nan,则需要 isinstance(value, float) and np.isnan(value)
|
||||
# 为了简化,这里假设非 None 即为有效数值
|
||||
if not (
|
||||
isinstance(trade.indicator_dict[indicator_name], float)
|
||||
and np.isnan(trade.indicator_dict[indicator_name])
|
||||
):
|
||||
indi_values.append(trade.indicator_dict[indicator_name])
|
||||
pnls.append(trade.realized_pnl - profit_offset)
|
||||
|
||||
for i in range(num_pairs):
|
||||
open_trade = open_trades[i]
|
||||
close_trade = close_trades[i]
|
||||
|
||||
if (open_trade.indicator_dict is not None and
|
||||
indicator_name in open_trade.indicator_dict):
|
||||
|
||||
value = open_trade.indicator_dict[indicator_name]
|
||||
if not (isinstance(value, float) and np.isnan(value)):
|
||||
indi_values.append(value)
|
||||
pnls.append(close_trade.realized_pnl - profit_offset)
|
||||
|
||||
if not indi_values:
|
||||
print(f"指标 '{indicator_name}' 没有对应的有效平仓交易数据。跳过绘图。")
|
||||
print(f"指标 '{indicator_name}' 没有有效数据,跳过绘图。")
|
||||
continue
|
||||
|
||||
# 将收集到的数据转换为 Pandas DataFrame 进行更便捷的处理
|
||||
# DataFrame 的结构为:['indicator_value', 'realized_pnl']
|
||||
df = pd.DataFrame({"indicator_value": indi_values, "realized_pnl": pnls})
|
||||
# 4. 数据清洗与准备
|
||||
df = pd.DataFrame({
|
||||
"indicator_value": indi_values,
|
||||
"realized_pnl": pnls
|
||||
})
|
||||
|
||||
def remove_extreme(df, col='indicator_value', k=3):
|
||||
"""IQR 稳健过滤,返回过滤后的 df 和被剔除的行数"""
|
||||
q1, q3 = df[col].quantile([0.25, 0.75])
|
||||
iqr = q3 - q1
|
||||
lower, upper = q1 - k * iqr, q3 + k * iqr
|
||||
mask = df[col].between(lower, upper)
|
||||
n_remove = (~mask).sum()
|
||||
return df[mask].copy(), n_remove
|
||||
mask = df[col].between(q1 - k * iqr, q3 + k * iqr)
|
||||
return df[mask].copy(), (~mask).sum()
|
||||
|
||||
df, n_drop = remove_extreme(df) # 默认 k=2.5
|
||||
df, n_drop = remove_extreme(df)
|
||||
if n_drop:
|
||||
print(f"指标 '{indicator_name}' 过滤掉 {n_drop} 个极端异常值 "
|
||||
f"({n_drop / len(df) * 100:.1f}%),剩余 {len(df)} 条样本。")
|
||||
print(f"指标 '{indicator_name}' 过滤掉 {n_drop} 个极端值,剩余 {len(df)} 条样本。")
|
||||
if df.empty:
|
||||
print(f"指标 '{indicator_name}' 过滤后无数据,跳过绘图。")
|
||||
continue
|
||||
|
||||
# 确保数据框不为空
|
||||
if df.empty:
|
||||
print(f"指标 '{indicator_name}' 的数据框为空,跳过绘图。")
|
||||
continue
|
||||
|
||||
# 按照指标值进行排序,这是计算累积和的关键步骤
|
||||
# 5. 计算累积盈亏曲线数据
|
||||
df = df.sort_values(by="indicator_value").reset_index(drop=True)
|
||||
|
||||
# --- 绘制累积收益曲线 ---
|
||||
plt.figure(figsize=(12, 7)) # 创建一个新的图表
|
||||
|
||||
# 获取指标值的范围,用于生成X轴的等距点
|
||||
# 定义 max_val 和 min_val(这是修复的关键)
|
||||
min_val = df["indicator_value"].min()
|
||||
max_val = df["indicator_value"].max()
|
||||
|
||||
# 特殊处理:如果所有指标值都相同
|
||||
# 处理所有值相同的情况
|
||||
if min_val == max_val:
|
||||
total_pnl = df["realized_pnl"].sum()
|
||||
print(
|
||||
f"指标 '{indicator_name}' 的所有值都相同 ({min_val:.2f}),无法创建区间图,绘制一个点表示总收益。"
|
||||
)
|
||||
plt.plot(min_val, total_pnl, "ro", markersize=8) # 绘制一个点
|
||||
plt.title(
|
||||
f"{indicator_name} Value vs. Cumulative PnL (All values are {min_val:.2f})"
|
||||
)
|
||||
plt.figure(figsize=(12, 7))
|
||||
plt.plot(min_val, total_pnl, "ro", markersize=8)
|
||||
plt.title(f"{indicator_name} Value vs. Cumulative PnL (All values are {min_val:.2f})")
|
||||
plt.xlabel(f"{indicator_name} Value")
|
||||
plt.ylabel("Cumulative Realized PnL")
|
||||
plt.grid(True)
|
||||
plt.text(
|
||||
min_val,
|
||||
total_pnl,
|
||||
f" Total PnL: {total_pnl:.2f}",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
)
|
||||
plt.text(min_val, total_pnl, f" Total PnL: {total_pnl:.2f}", ha="center", va="bottom")
|
||||
plt.show()
|
||||
continue
|
||||
|
||||
# 生成X轴上的100个等距点,这些点代表了指标值的不同阈值
|
||||
x_points = np.linspace(min_val, max_val, 100)
|
||||
y_cumulative_pnl = [df[df["indicator_value"] <= xp]["realized_pnl"].sum() for xp in x_points]
|
||||
|
||||
# 计算Y轴的值:对于每个X轴点 xp,Y轴值是所有 'indicator_value' <= xp 的 'realized_pnl' 之和
|
||||
y_cumulative_pnl = []
|
||||
for xp in x_points:
|
||||
# 筛选出指标值小于等于当前x_point的所有交易,并求和它们的 realized_pnl
|
||||
cumulative_pnl = df[df["indicator_value"] <= xp]["realized_pnl"].sum()
|
||||
y_cumulative_pnl.append(cumulative_pnl)
|
||||
# 6. 绘图(完整版)
|
||||
plt.figure(figsize=(12, 7))
|
||||
plt.plot(x_points, y_cumulative_pnl, marker="o", markersize=3,
|
||||
label=f"Cumulative PnL for {indicator_name}", alpha=0.8)
|
||||
|
||||
# 绘制累积盈亏曲线
|
||||
plt.plot(
|
||||
x_points,
|
||||
y_cumulative_pnl,
|
||||
marker="o",
|
||||
linestyle="-",
|
||||
markersize=3,
|
||||
label=f"Cumulative PnL for {indicator_name}",
|
||||
alpha=0.8,
|
||||
)
|
||||
|
||||
# 标记累积盈亏的最大值点
|
||||
# 标记最大值点
|
||||
optimal_index = np.argmax(y_cumulative_pnl)
|
||||
optimal_indi_value = x_points[optimal_index]
|
||||
max_cumulative_pnl = y_cumulative_pnl[optimal_index]
|
||||
|
||||
# 标记累积盈亏的最小值点
|
||||
min_pnl_index = np.argmin(y_cumulative_pnl[:optimal_index]) if len(y_cumulative_pnl[:optimal_index]) > 0 else 0
|
||||
# 标记最小值点(在最大值左侧找)
|
||||
min_pnl_index = np.argmin(y_cumulative_pnl[:optimal_index]) if optimal_index > 0 else 0
|
||||
min_indi_value_at_pnl = x_points[min_pnl_index]
|
||||
min_cumulative_pnl = y_cumulative_pnl[min_pnl_index]
|
||||
|
||||
# 动态调整标注位置以避免重叠
|
||||
offset_x = (max_val - min_val) * 0.05 # 水平偏移量
|
||||
# 动态调整标注位置(使用已定义的max_val和min_val)
|
||||
offset_x = (max_val - min_val) * 0.05
|
||||
|
||||
# 默认标注为右侧对齐,文本在点的左侧
|
||||
max_ha = "right"
|
||||
max_xytext_x = optimal_indi_value - offset_x
|
||||
min_ha = "right"
|
||||
min_xytext_x = min_indi_value_at_pnl - offset_x
|
||||
|
||||
# 如果最大值点在最小值点右侧,则最大值标注放左侧,最小值标注放右侧
|
||||
# 这样可以避免标注文本重叠
|
||||
if optimal_indi_value > min_indi_value_at_pnl:
|
||||
max_ha = "left"
|
||||
max_xytext_x = optimal_indi_value + offset_x
|
||||
min_ha = "right"
|
||||
min_xytext_x = min_indi_value_at_pnl - offset_x
|
||||
else: # 如果最大值点在最小值点左侧或重合
|
||||
max_ha = "right"
|
||||
max_xytext_x = optimal_indi_value - offset_x
|
||||
min_ha = "left"
|
||||
min_xytext_x = min_indi_value_at_pnl + offset_x
|
||||
max_ha, max_xytext_x = "left", optimal_indi_value + offset_x
|
||||
min_ha, min_xytext_x = "right", min_indi_value_at_pnl - offset_x
|
||||
else:
|
||||
max_ha, max_xytext_x = "right", optimal_indi_value - offset_x
|
||||
min_ha, min_xytext_x = "left", min_indi_value_at_pnl + offset_x
|
||||
|
||||
# 绘制最大值垂直线和标注
|
||||
plt.axvline(
|
||||
optimal_indi_value,
|
||||
color="red",
|
||||
linestyle="--",
|
||||
label=f"Max PnL Threshold: {optimal_indi_value:.2f}",
|
||||
alpha=0.7,
|
||||
)
|
||||
# 绘制最大值标注
|
||||
plt.axvline(optimal_indi_value, color="red", linestyle="--", alpha=0.7)
|
||||
plt.annotate(
|
||||
f"Max Cum. PnL: {max_cumulative_pnl:.2f}",
|
||||
xy=(optimal_indi_value, max_cumulative_pnl),
|
||||
@@ -289,24 +240,12 @@ class ResultAnalyzer:
|
||||
arrowprops=dict(facecolor="red", shrink=0.05),
|
||||
horizontalalignment=max_ha,
|
||||
verticalalignment="bottom",
|
||||
color="red",
|
||||
color="red"
|
||||
)
|
||||
|
||||
# 绘制最小值垂直线和标注
|
||||
plt.axvline(
|
||||
min_indi_value_at_pnl,
|
||||
color="blue",
|
||||
linestyle=":",
|
||||
label=f"Min PnL Threshold: {min_indi_value_at_pnl:.2f}",
|
||||
alpha=0.7,
|
||||
)
|
||||
|
||||
# 垂直偏移最小值标注,避免与曲线重叠
|
||||
min_text_y_offset = (
|
||||
-(max_cumulative_pnl - min_cumulative_pnl) * 0.1
|
||||
if max_cumulative_pnl != min_cumulative_pnl
|
||||
else -0.05
|
||||
)
|
||||
# 绘制最小值标注
|
||||
plt.axvline(min_indi_value_at_pnl, color="blue", linestyle=":", alpha=0.7)
|
||||
min_text_y_offset = -(max_cumulative_pnl - min_cumulative_pnl) * 0.1
|
||||
plt.annotate(
|
||||
f"Min Cum. PnL: {min_cumulative_pnl:.2f}",
|
||||
xy=(min_indi_value_at_pnl, min_cumulative_pnl),
|
||||
@@ -314,7 +253,7 @@ class ResultAnalyzer:
|
||||
arrowprops=dict(facecolor="blue", shrink=0.05),
|
||||
horizontalalignment=min_ha,
|
||||
verticalalignment="top",
|
||||
color="blue",
|
||||
color="blue"
|
||||
)
|
||||
|
||||
plt.title(f"{indicator_name} Value vs. Cumulative Realized PnL")
|
||||
@@ -322,7 +261,7 @@ class ResultAnalyzer:
|
||||
plt.ylabel("Cumulative Realized PnL")
|
||||
plt.grid(True)
|
||||
plt.legend()
|
||||
plt.tight_layout() # 自动调整图表参数,使之更紧凑
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
print("\n所有指标的分析图表已生成。")
|
||||
print("\n所有指标分析完成。")
|
||||
|
||||
@@ -162,22 +162,22 @@ class BacktestEngine:
|
||||
|
||||
self.strategy.on_open_bar(current_bar.open, current_bar.symbol)
|
||||
|
||||
current_indicator_dict = {}
|
||||
# current_indicator_dict = {}
|
||||
close_array = np.array(self.close_list)
|
||||
open_array = np.array(self.open_list)
|
||||
high_array = np.array(self.high_list)
|
||||
low_array = np.array(self.low_list)
|
||||
volume_array = np.array(self.volume_list)
|
||||
|
||||
for indicator in self.indicators:
|
||||
current_indicator_dict[indicator.get_name()] = indicator.get_latest_value(
|
||||
close_array,
|
||||
open_array,
|
||||
high_array,
|
||||
low_array,
|
||||
volume_array
|
||||
)
|
||||
self.simulator.process_pending_orders(current_bar, current_indicator_dict)
|
||||
# for indicator in self.indicators:
|
||||
# current_indicator_dict[indicator.get_name()] = indicator.get_latest_value(
|
||||
# close_array,
|
||||
# open_array,
|
||||
# high_array,
|
||||
# low_array,
|
||||
# volume_array
|
||||
# )
|
||||
self.simulator.process_pending_orders(current_bar)
|
||||
|
||||
self.all_bars.append(current_bar)
|
||||
self.close_list.append(current_bar.close)
|
||||
@@ -191,7 +191,7 @@ class BacktestEngine:
|
||||
# self.strategy.on_bar(current_bar)
|
||||
|
||||
self.strategy.on_close_bar(current_bar)
|
||||
self.simulator.process_pending_orders(current_bar, current_indicator_dict)
|
||||
self.simulator.process_pending_orders(current_bar)
|
||||
|
||||
|
||||
# 8. 记录投资组合快照
|
||||
@@ -231,6 +231,10 @@ class BacktestEngine:
|
||||
# 回测结束后,获取所有交易记录
|
||||
self.trade_history = self.simulator.get_trade_history()
|
||||
|
||||
print("\n--- 批量计算指标并赋值到Trade ---")
|
||||
indicator_df = self._batch_calculate_indicators()
|
||||
self._assign_indicators_to_trades(indicator_df)
|
||||
|
||||
print("--- 回测结束 ---")
|
||||
print(f"总计处理了 {len(self.all_bars)} 根K线。")
|
||||
print(f"总计发生了 {len(self.trade_history)} 笔交易。")
|
||||
@@ -246,6 +250,57 @@ class BacktestEngine:
|
||||
print(f"最终总净值: {final_portfolio_value:.2f}")
|
||||
print(f"总收益率: {total_return_percentage:.2f}%")
|
||||
|
||||
def _batch_calculate_indicators(self) -> pd.DataFrame:
|
||||
"""
|
||||
批量计算所有Bar的指标值,返回DataFrame(index为Bar序号)
|
||||
"""
|
||||
if not self.indicators:
|
||||
return pd.DataFrame()
|
||||
|
||||
# 一次性转换为numpy数组(避免重复构建)
|
||||
close_arr = np.array(self.close_list)
|
||||
open_arr = np.array(self.open_list)
|
||||
high_arr = np.array(self.high_list)
|
||||
low_arr = np.array(self.low_list)
|
||||
volume_arr = np.array(self.volume_list)
|
||||
|
||||
indicator_data = {}
|
||||
for indicator in self.indicators:
|
||||
# 向量化计算所有历史Bar的指标值
|
||||
values = indicator.get_values(
|
||||
close_arr, open_arr, high_arr, low_arr, volume_arr
|
||||
)
|
||||
indicator_data[indicator.get_name()] = values
|
||||
|
||||
return pd.DataFrame(indicator_data)
|
||||
|
||||
def _assign_indicators_to_trades(self, indicator_df: pd.DataFrame):
|
||||
"""
|
||||
根据fill_time将指标值反向赋值给Trade对象
|
||||
"""
|
||||
if indicator_df.empty or not self.trade_history:
|
||||
return
|
||||
|
||||
# 建立时间戳到Bar索引的映射(O(n)复杂度)
|
||||
time_to_index = {
|
||||
bar.datetime: i for i, bar in enumerate(self.all_bars)
|
||||
}
|
||||
|
||||
# 遍历所有交易记录
|
||||
for trade in self.trade_history:
|
||||
# 仅对开仓交易赋值(保持与原逻辑一致)
|
||||
if not trade.is_open_trade:
|
||||
continue
|
||||
|
||||
# 查找成交时间对应的Bar索引
|
||||
bar_idx = time_to_index.get(trade.fill_time) - 1
|
||||
if bar_idx is None:
|
||||
print(f"警告: Trade {trade.order_id} 的fill_time {trade.fill_time} 未找到对应Bar")
|
||||
continue
|
||||
|
||||
# 赋值指标字典(保持与原逻辑一致)
|
||||
trade.indicator_dict = indicator_df.iloc[bar_idx].to_dict()
|
||||
|
||||
def get_backtest_results(self) -> Dict[str, Any]:
|
||||
"""
|
||||
返回回测结果数据,供结果分析模块使用。
|
||||
|
||||
@@ -90,7 +90,7 @@ class ExecutionSimulator:
|
||||
self.pending_orders[order.id] = order
|
||||
return order
|
||||
|
||||
def process_pending_orders(self, current_bar: Bar, indicator_dict: Dict[str, float]):
|
||||
def process_pending_orders(self, current_bar: Bar):
|
||||
order_ids_to_process = list(self.pending_orders.keys())
|
||||
|
||||
for order_id in order_ids_to_process:
|
||||
@@ -154,10 +154,10 @@ class ExecutionSimulator:
|
||||
trade = self._execute_single_order(order, current_bar)
|
||||
if trade:
|
||||
self.trade_log.append(trade)
|
||||
if trade.is_open_trade:
|
||||
self.indicator_dict = indicator_dict
|
||||
elif trade.is_close_trade:
|
||||
trade.indicator_dict = self.indicator_dict.copy()
|
||||
# if trade.is_open_trade:
|
||||
# self.indicator_dict = indicator_dict
|
||||
# elif trade.is_close_trade:
|
||||
# trade.indicator_dict = self.indicator_dict.copy()
|
||||
|
||||
|
||||
def _execute_single_order(self, order: Order, current_bar: Bar) -> Optional[Trade]:
|
||||
|
||||
@@ -44,7 +44,7 @@ INDICATOR_LIST = [
|
||||
ADX(240),
|
||||
BollingerBandwidth(10, nbdev=1.5),
|
||||
BollingerBandwidth(20, 2.0),
|
||||
BollingerBandwidth(50, nbdev=2.5),
|
||||
BollingerBandwidth(50, 2.5),
|
||||
PriceRangeToVolatilityRatio(3, 5),
|
||||
PriceRangeToVolatilityRatio(3, 14),
|
||||
PriceRangeToVolatilityRatio(3, 21),
|
||||
@@ -54,12 +54,16 @@ INDICATOR_LIST = [
|
||||
PriceRangeToVolatilityRatio(21, 5),
|
||||
PriceRangeToVolatilityRatio(21, 14),
|
||||
PriceRangeToVolatilityRatio(21, 21),
|
||||
# ImpulseCandleConviction(3, 1),
|
||||
# RelativeVolumeInWindow(3, 5),
|
||||
# RelativeVolumeInWindow(3, 14),
|
||||
# RelativeVolumeInWindow(3, 21),
|
||||
# RelativeVolumeInWindow(3, 30),
|
||||
# RelativeVolumeInWindow(3, 40),
|
||||
# ZScoreATR(7, 100),
|
||||
# ZScoreATR(14, 100),
|
||||
ImpulseCandleConviction(3, 1),
|
||||
ZScoreATR(7, 100),
|
||||
ZScoreATR(14, 100),
|
||||
FFTTrendStrength(46, 2, 23),
|
||||
FFTTrendStrength(46, 1, 23),
|
||||
AtrVolatility(7),
|
||||
AtrVolatility(14),
|
||||
AtrVolatility(21),
|
||||
AtrVolatility(230),
|
||||
FFTPhaseShift(),
|
||||
VolatilitySkew(),
|
||||
VolatilityTrendRelationship()
|
||||
]
|
||||
|
||||
@@ -4,19 +4,21 @@ from typing import List, Union, Tuple, Optional
|
||||
import numpy as np
|
||||
import talib
|
||||
from numpy.lib._stride_tricks_impl import sliding_window_view
|
||||
from scipy import stats
|
||||
|
||||
from src.indicators.base_indicators import Indicator
|
||||
|
||||
|
||||
class Empty(Indicator, ABC):
|
||||
def get_values(self, close: np.array, open: np.array, high: np.array, low: np.array, volume: np.array):
|
||||
return []
|
||||
|
||||
def is_condition_met(self,
|
||||
close: np.array,
|
||||
open: np.array,
|
||||
high: np.array,
|
||||
low: np.array,
|
||||
volume: np.array):
|
||||
close: np.array,
|
||||
open: np.array,
|
||||
high: np.array,
|
||||
low: np.array,
|
||||
volume: np.array):
|
||||
return True
|
||||
|
||||
def get_name(self):
|
||||
@@ -402,7 +404,7 @@ class PriceRangeToVolatilityRatio(Indicator):
|
||||
|
||||
return ratio
|
||||
|
||||
def _rolling_max(self,arr: np.array, window: int) -> np.array:
|
||||
def _rolling_max(self, arr: np.array, window: int) -> np.array:
|
||||
if len(arr) < window:
|
||||
return np.full_like(arr, np.nan)
|
||||
|
||||
@@ -591,15 +593,17 @@ class ROC_MA(Indicator):
|
||||
"""
|
||||
return f"roc_ma_{self.roc_window}_{self.ma_window}"
|
||||
|
||||
|
||||
from numpy.lib.stride_tricks import sliding_window_view
|
||||
|
||||
|
||||
class ZScoreATR(Indicator):
|
||||
def __init__(
|
||||
self,
|
||||
atr_window: int = 14,
|
||||
z_window: int = 100,
|
||||
down_bound: float = None,
|
||||
up_bound: float = None,
|
||||
self,
|
||||
atr_window: int = 14,
|
||||
z_window: int = 100,
|
||||
down_bound: float = None,
|
||||
up_bound: float = None,
|
||||
):
|
||||
super().__init__(down_bound, up_bound)
|
||||
self.atr_window = atr_window
|
||||
@@ -616,7 +620,7 @@ class ZScoreATR(Indicator):
|
||||
|
||||
# Step 2: 只对有效区域计算 z-score
|
||||
start_idx = self.atr_window - 1 # ATR 从这里开始非 NaN
|
||||
valid_atr = atr[start_idx:] # shape: (n - start_idx,)
|
||||
valid_atr = atr[start_idx:] # shape: (n - start_idx,)
|
||||
valid_n = len(valid_atr)
|
||||
|
||||
if valid_n < self.z_window:
|
||||
@@ -627,8 +631,8 @@ class ZScoreATR(Indicator):
|
||||
windows = sliding_window_view(valid_atr, window_shape=self.z_window)
|
||||
|
||||
# Step 4: 向量化计算均值和标准差(沿窗口轴)
|
||||
means = np.mean(windows, axis=1) # shape: (M,)
|
||||
stds = np.std(windows, axis=1, ddof=0) # shape: (M,)
|
||||
means = np.mean(windows, axis=1) # shape: (M,)
|
||||
stds = np.std(windows, axis=1, ddof=0) # shape: (M,)
|
||||
|
||||
# Step 5: 计算 z-score(当前值是窗口最后一个元素)
|
||||
current_vals = valid_atr[self.z_window - 1:] # 对齐窗口末尾
|
||||
@@ -648,3 +652,598 @@ class ZScoreATR(Indicator):
|
||||
|
||||
def get_name(self):
|
||||
return f"z_atr_{self.atr_window}_{self.z_window}"
|
||||
|
||||
|
||||
from scipy.signal import stft
|
||||
|
||||
|
||||
class FFTTrendStrength(Indicator):
|
||||
"""
|
||||
傅里叶趋势强度指标 (FFT_TrendStrength)
|
||||
|
||||
该指标通过短时傅里叶变换(STFT)计算低频能量占比,量化趋势强度。
|
||||
低频能量占比越高,趋势越强;当该值在不同波动率环境下变化时,
|
||||
往往预示策略转折点。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
spectral_window: int = 46, # 2天×23根/天
|
||||
low_freq_days: float = 2.0, # 低频定义下限(天)
|
||||
bars_per_day: int = 23, # 每日K线数量
|
||||
down_bound: float = None,
|
||||
up_bound: float = None,
|
||||
shift_window: int = 0,
|
||||
):
|
||||
"""
|
||||
初始化 FFT_TrendStrength 指标。
|
||||
|
||||
Args:
|
||||
spectral_window (int): STFT窗口大小(根K线)
|
||||
low_freq_days (float): 低频定义下限(天)
|
||||
bars_per_day (int): 每日K线数量
|
||||
down_bound (float): (可选) 用于条件判断的下轨
|
||||
up_bound (float): (可选) 用于条件判断的上轨
|
||||
shift_window (int): (可选) 指标值的时间偏移
|
||||
"""
|
||||
super().__init__(down_bound, up_bound)
|
||||
self.spectral_window = spectral_window
|
||||
self.low_freq_days = low_freq_days
|
||||
self.bars_per_day = bars_per_day
|
||||
self.shift_window = shift_window
|
||||
|
||||
def get_values(
|
||||
self,
|
||||
close: np.array,
|
||||
open: np.array,
|
||||
high: np.array,
|
||||
low: np.array,
|
||||
volume: np.array,
|
||||
) -> np.array:
|
||||
"""
|
||||
计算傅里叶趋势强度值。
|
||||
|
||||
Args:
|
||||
close (np.array): 收盘价列表
|
||||
其他参数保留接口兼容性,本指标仅使用close
|
||||
|
||||
Returns:
|
||||
np.array: 趋势强度值列表(0~1),数据不足时为NaN
|
||||
"""
|
||||
n = len(close)
|
||||
trend_strengths = np.full(n, np.nan)
|
||||
|
||||
# 验证最小数据要求
|
||||
min_required = self.spectral_window + 5
|
||||
if n < min_required:
|
||||
return trend_strengths
|
||||
|
||||
# 频率边界计算
|
||||
low_freq_bound = 1.0 / self.low_freq_days if self.low_freq_days > 0 else float('inf')
|
||||
|
||||
# 为每个时间点计算趋势强度
|
||||
for i in range(min_required - 1, n):
|
||||
# 获取窗口内数据
|
||||
window_data = close[max(0, i - self.spectral_window + 1): i + 1]
|
||||
|
||||
# 跳过数据不足的窗口
|
||||
if len(window_data) < self.spectral_window:
|
||||
continue
|
||||
|
||||
# 价格归一化
|
||||
window_mean = np.mean(window_data)
|
||||
window_std = np.std(window_data)
|
||||
if window_std < 1e-8:
|
||||
continue
|
||||
|
||||
normalized = (window_data - window_mean) / window_std
|
||||
|
||||
try:
|
||||
# STFT计算
|
||||
f, t, Zxx = stft(
|
||||
normalized,
|
||||
fs=self.bars_per_day,
|
||||
nperseg=self.spectral_window,
|
||||
noverlap=max(0, self.spectral_window // 2),
|
||||
boundary=None,
|
||||
padded=False
|
||||
)
|
||||
|
||||
# 频率过滤
|
||||
max_freq = self.bars_per_day / 2
|
||||
valid_mask = (f >= 0) & (f <= max_freq)
|
||||
if not np.any(valid_mask):
|
||||
continue
|
||||
|
||||
f = f[valid_mask]
|
||||
Zxx = Zxx[valid_mask, :]
|
||||
|
||||
if Zxx.shape[1] == 0:
|
||||
continue
|
||||
|
||||
# 能量计算
|
||||
current_energy = np.abs(Zxx[:, -1]) ** 2
|
||||
low_freq_mask = f < low_freq_bound
|
||||
high_freq_mask = f > 1.0 # 高频: <1天周期
|
||||
|
||||
low_energy = np.sum(current_energy[low_freq_mask]) if np.any(low_freq_mask) else 0.0
|
||||
high_energy = np.sum(current_energy[high_freq_mask]) if np.any(high_freq_mask) else 0.0
|
||||
total_energy = low_energy + high_energy + 1e-8
|
||||
|
||||
trend_strength = low_energy / total_energy
|
||||
trend_strengths[i] = np.clip(trend_strength, 0.0, 1.0)
|
||||
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 应用时间偏移
|
||||
if self.shift_window > 0 and len(trend_strengths) > self.shift_window:
|
||||
trend_strengths = np.roll(trend_strengths, -self.shift_window)
|
||||
trend_strengths[-self.shift_window:] = np.nan
|
||||
|
||||
return trend_strengths
|
||||
|
||||
def get_name(self) -> str:
|
||||
return f"fft_trend_{self.spectral_window}_{self.low_freq_days}"
|
||||
|
||||
|
||||
class AtrVolatility(Indicator):
|
||||
"""
|
||||
波动率环境识别指标 (VolatilityRegime)
|
||||
|
||||
该指标识别当前市场处于高波动还是低波动环境,对策略转折点
|
||||
有强预测能力。在低波动环境下,趋势信号往往失效转为反转。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vol_window: int = 23, # 波动率计算窗口
|
||||
down_bound: float = None,
|
||||
up_bound: float = None,
|
||||
shift_window: int = 0,
|
||||
):
|
||||
"""
|
||||
初始化 VolatilityRegime 指标。
|
||||
|
||||
Args:
|
||||
vol_window (int): ATR波动率计算窗口
|
||||
high_vol_threshold (float): 高波动阈值(%),高于此值为高波动环境
|
||||
low_vol_threshold (float): 低波动阈值(%),低于此值为低波动环境
|
||||
down_bound (float): (可选) 用于条件判断的下轨
|
||||
up_bound (float): (可选) 用于条件判断的上轨
|
||||
shift_window (int): (可选) 指标值的时间偏移
|
||||
"""
|
||||
super().__init__(down_bound, up_bound)
|
||||
self.vol_window = vol_window
|
||||
self.shift_window = shift_window
|
||||
|
||||
def get_values(
|
||||
self,
|
||||
close: np.array,
|
||||
open: np.array,
|
||||
high: np.array,
|
||||
low: np.array,
|
||||
volume: np.array,
|
||||
) -> np.array:
|
||||
"""
|
||||
计算波动率环境指标。
|
||||
|
||||
返回值含义:
|
||||
- 1.0: 高波动环境 (趋势策略有效)
|
||||
- 0.0: 中波动环境 (谨慎)
|
||||
- -1.0: 低波动环境 (反转策略有效)
|
||||
|
||||
Args:
|
||||
close (np.array): 收盘价列表
|
||||
high (np.array): 最高价列表
|
||||
low (np.array): 最低价列表
|
||||
其他参数保留接口兼容性
|
||||
|
||||
Returns:
|
||||
np.array: 波动率环境标识,数据不足时为NaN
|
||||
"""
|
||||
n = len(close)
|
||||
regimes = np.full(n, np.nan)
|
||||
|
||||
# 验证最小数据要求
|
||||
if n < self.vol_window + 1:
|
||||
return regimes
|
||||
|
||||
# 计算ATR
|
||||
try:
|
||||
atr = talib.ATR(high, low, close, timeperiod=self.vol_window)
|
||||
except Exception:
|
||||
return regimes
|
||||
|
||||
# 计算标准化波动率 (%)
|
||||
volatility = (atr / close) * 100
|
||||
|
||||
return volatility
|
||||
|
||||
def get_name(self) -> str:
|
||||
return f"atr_volume_{self.vol_window}"
|
||||
|
||||
|
||||
class FFTPhaseShift(Indicator):
|
||||
"""
|
||||
傅里叶相位偏移指标 (FFT_PhaseShift)
|
||||
|
||||
该指标检测频域中主导频率的相位偏移,相位突变往往预示市场
|
||||
趋势的转折点。特别适用于捕捉低波动环境下的价格极端位置。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
spectral_window: int = 46, # 2天×23根/天
|
||||
dominant_freq_bound: float = 0.5, # 主导频率上限(cycles/day)
|
||||
phase_shift_threshold: float = 1.0, # 相位偏移阈值(弧度)
|
||||
bars_per_day: int = 23, # 每日K线数量
|
||||
down_bound: float = None,
|
||||
up_bound: float = None,
|
||||
shift_window: int = 0,
|
||||
):
|
||||
"""
|
||||
初始化 FFT_PhaseShift 指标。
|
||||
|
||||
Args:
|
||||
spectral_window (int): STFT窗口大小(根K线)
|
||||
dominant_freq_bound (float): 主导频率上限(cycles/day)
|
||||
phase_shift_threshold (float): 相位偏移阈值(弧度)
|
||||
bars_per_day (int): 每日K线数量
|
||||
down_bound (float): (可选) 用于条件判断的下轨
|
||||
up_bound (float): (可选) 用于条件判断的上轨
|
||||
shift_window (int): (可选) 指标值的时间偏移
|
||||
"""
|
||||
super().__init__(down_bound, up_bound)
|
||||
self.spectral_window = spectral_window
|
||||
self.dominant_freq_bound = dominant_freq_bound
|
||||
self.phase_shift_threshold = phase_shift_threshold
|
||||
self.bars_per_day = bars_per_day
|
||||
self.shift_window = shift_window
|
||||
|
||||
def get_values(
|
||||
self,
|
||||
close: np.array,
|
||||
open: np.array,
|
||||
high: np.array,
|
||||
low: np.array,
|
||||
volume: np.array,
|
||||
) -> np.array:
|
||||
"""
|
||||
计算傅里叶相位偏移值。
|
||||
|
||||
返回值含义:
|
||||
- 1.0: 相位正向偏移(可能预示上涨转折)
|
||||
- -1.0: 相位负向偏移(可能预示下跌转折)
|
||||
- 0.0: 无显著相位偏移
|
||||
|
||||
Args:
|
||||
close (np.array): 收盘价列表
|
||||
其他参数保留接口兼容性,本指标仅使用close
|
||||
|
||||
Returns:
|
||||
np.array: 相位偏移标识,数据不足时为NaN
|
||||
"""
|
||||
n = len(close)
|
||||
phase_shifts = np.full(n, np.nan)
|
||||
|
||||
# 验证最小数据要求
|
||||
min_required = self.spectral_window + 5
|
||||
if n < min_required:
|
||||
return phase_shifts
|
||||
|
||||
# 为每个时间点计算相位偏移
|
||||
prev_phase = None
|
||||
|
||||
for i in range(min_required - 1, n):
|
||||
# 获取窗口内数据
|
||||
window_data = close[max(0, i - self.spectral_window + 1): i + 1]
|
||||
|
||||
if len(window_data) < self.spectral_window:
|
||||
continue
|
||||
|
||||
# 价格归一化
|
||||
window_mean = np.mean(window_data)
|
||||
window_std = np.std(window_data)
|
||||
if window_std < 1e-8:
|
||||
continue
|
||||
|
||||
normalized = (window_data - window_mean) / window_std
|
||||
|
||||
try:
|
||||
# STFT计算
|
||||
f, t, Zxx = stft(
|
||||
normalized,
|
||||
fs=self.bars_per_day,
|
||||
nperseg=self.spectral_window,
|
||||
noverlap=max(0, self.spectral_window // 2),
|
||||
boundary=None,
|
||||
padded=False
|
||||
)
|
||||
|
||||
# 频率过滤
|
||||
max_freq = self.bars_per_day / 2
|
||||
valid_mask = (f >= 0) & (f <= max_freq)
|
||||
if not np.any(valid_mask):
|
||||
continue
|
||||
|
||||
f = f[valid_mask]
|
||||
Zxx = Zxx[valid_mask, :]
|
||||
|
||||
if Zxx.shape[1] < 2: # 需要至少两个时间点计算相位变化
|
||||
continue
|
||||
|
||||
# 计算相位
|
||||
phases = np.angle(Zxx[:, -1])
|
||||
prev_phases = np.angle(Zxx[:, -2])
|
||||
|
||||
# 找出主导频率(低频)
|
||||
low_freq_mask = f < self.dominant_freq_bound
|
||||
if not np.any(low_freq_mask):
|
||||
continue
|
||||
|
||||
# 计算主导频率的相位差
|
||||
dominant_idx = np.argmax(np.abs(Zxx[low_freq_mask, -1]))
|
||||
current_phase = phases[low_freq_mask][dominant_idx]
|
||||
prev_dominant_phase = prev_phases[low_freq_mask][dominant_idx]
|
||||
|
||||
# 计算相位差(考虑2π周期性)
|
||||
phase_diff = current_phase - prev_dominant_phase
|
||||
phase_diff = (phase_diff + np.pi) % (2 * np.pi) - np.pi
|
||||
|
||||
# 确定相位偏移方向
|
||||
if np.abs(phase_diff) > self.phase_shift_threshold:
|
||||
phase_shifts[i] = 1.0 if phase_diff > 0 else -1.0
|
||||
else:
|
||||
phase_shifts[i] = 0.0
|
||||
|
||||
prev_phase = current_phase
|
||||
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 应用时间偏移
|
||||
if self.shift_window > 0 and len(phase_shifts) > self.shift_window:
|
||||
phase_shifts = np.roll(phase_shifts, -self.shift_window)
|
||||
phase_shifts[-self.shift_window:] = np.nan
|
||||
|
||||
return phase_shifts
|
||||
|
||||
def get_name(self) -> str:
|
||||
return f"fft_phase_{self.spectral_window}_{self.dominant_freq_bound}"
|
||||
|
||||
|
||||
class VolatilitySkew(Indicator):
|
||||
"""
|
||||
波动率偏斜指标 (VolatilitySkew)
|
||||
|
||||
该指标测量近期波动率分布的偏斜程度,正偏斜表示波动率上升趋势,
|
||||
负偏斜表示波动率下降趋势。波动率偏斜的变化往往预示策略逻辑
|
||||
的转折点,特别是在低波动环境向高波动环境转换时。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vol_window: int = 20, # 单期波动率计算窗口
|
||||
skew_window: int = 60, # 偏斜计算窗口
|
||||
down_bound: float = None,
|
||||
up_bound: float = None,
|
||||
shift_window: int = 0,
|
||||
):
|
||||
"""
|
||||
初始化 VolatilitySkew 指标。
|
||||
|
||||
Args:
|
||||
vol_window (int): ATR波动率计算窗口
|
||||
skew_window (int): 偏斜计算窗口
|
||||
positive_threshold (float): 正偏斜阈值
|
||||
negative_threshold (float): 负偏斜阈值
|
||||
down_bound (float): (可选) 用于条件判断的下轨
|
||||
up_bound (float): (可选) 用于条件判断的上轨
|
||||
shift_window (int): (可选) 指标值的时间偏移
|
||||
"""
|
||||
super().__init__(down_bound, up_bound)
|
||||
self.vol_window = vol_window
|
||||
self.skew_window = skew_window
|
||||
self.shift_window = shift_window
|
||||
|
||||
def get_values(
|
||||
self,
|
||||
close: np.array,
|
||||
open: np.array,
|
||||
high: np.array,
|
||||
low: np.array,
|
||||
volume: np.array,
|
||||
) -> np.array:
|
||||
"""
|
||||
计算波动率偏斜指标。
|
||||
|
||||
返回值含义:
|
||||
- 1.0: 正偏斜(波动率上升趋势,可能预示高波动环境到来)
|
||||
- -1.0: 负偏斜(波动率下降趋势,可能预示低波动环境到来)
|
||||
- 0.0: 无显著偏斜
|
||||
|
||||
Args:
|
||||
close (np.array): 收盘价列表
|
||||
high (np.array): 最高价列表
|
||||
low (np.array): 最低价列表
|
||||
其他参数保留接口兼容性
|
||||
|
||||
Returns:
|
||||
np.array: 波动率偏斜标识,数据不足时为NaN
|
||||
"""
|
||||
n = len(close)
|
||||
skews = np.full(n, np.nan)
|
||||
|
||||
# 验证最小数据要求
|
||||
if n < self.vol_window + self.skew_window:
|
||||
return skews
|
||||
|
||||
# 计算ATR
|
||||
try:
|
||||
atr = talib.ATR(high, low, close, timeperiod=self.vol_window)
|
||||
except Exception:
|
||||
return skews
|
||||
|
||||
# 计算标准化波动率 (%)
|
||||
volatility = (atr / close) * 100
|
||||
|
||||
# 计算滚动偏斜
|
||||
for i in range(self.vol_window + self.skew_window - 1, n):
|
||||
window_vol = volatility[i - self.skew_window + 1: i + 1]
|
||||
valid_vol = window_vol[~np.isnan(window_vol)]
|
||||
|
||||
if len(valid_vol) < self.skew_window * 0.7: # 要求70%有效数据
|
||||
continue
|
||||
|
||||
# 计算偏斜
|
||||
skew_value = stats.skew(valid_vol)
|
||||
|
||||
skews[i] = skew_value
|
||||
|
||||
return skews
|
||||
|
||||
def get_name(self) -> str:
|
||||
return f"vol_skew_{self.vol_window}_{self.skew_window}"
|
||||
|
||||
|
||||
import numpy as np
|
||||
import talib
|
||||
from src.indicators.base_indicators import Indicator
|
||||
|
||||
|
||||
class VolatilityTrendRelationship(Indicator):
|
||||
"""
|
||||
精准修复版:波动率-趋势关系指标
|
||||
|
||||
仅修复NaN问题:
|
||||
1. 保留talib的ATR计算(性能和稳定性更优)
|
||||
2. 修复std_val计算中的NaN传播
|
||||
3. 添加严格的NaN处理,确保100%数据有效性
|
||||
4. 保持原始物理逻辑不变
|
||||
|
||||
核心修复点:
|
||||
- 在计算标准差前过滤NaN值
|
||||
- 为平滑后的序列提供安全回退值
|
||||
- 确保所有中间步骤处理NaN
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vol_window: int = 20, # 波动率计算窗口
|
||||
price_lag: int = 3, # 价格自相关滞后
|
||||
ma_window: int = 5, # 平滑窗口
|
||||
down_bound: float = None,
|
||||
up_bound: float = None,
|
||||
shift_window: int = 0,
|
||||
):
|
||||
super().__init__(down_bound, up_bound)
|
||||
self.vol_window = vol_window
|
||||
self.price_lag = price_lag
|
||||
self.ma_window = ma_window
|
||||
self.shift_window = shift_window
|
||||
|
||||
def get_values(
|
||||
self,
|
||||
close: np.array,
|
||||
open: np.array,
|
||||
high: np.array,
|
||||
low: np.array,
|
||||
volume: np.array,
|
||||
) -> np.array:
|
||||
n = len(close)
|
||||
relationship = np.full(n, np.nan)
|
||||
|
||||
# 验证最小数据要求
|
||||
min_required = max(self.vol_window, self.price_lag, self.ma_window) + 5
|
||||
if n < min_required:
|
||||
return relationship
|
||||
|
||||
# 1. 计算标准化波动率 (使用talib,保持性能)
|
||||
try:
|
||||
atr = talib.ATR(high, low, close, timeperiod=self.vol_window)
|
||||
volatility = (atr / close) * 100
|
||||
except Exception:
|
||||
return relationship
|
||||
|
||||
# 2. 计算波动率变化率 (安全处理除零)
|
||||
vol_change = np.zeros(n)
|
||||
for i in range(1, n):
|
||||
if volatility[i - 1] > 1e-8:
|
||||
vol_change[i] = (volatility[i] - volatility[i - 1]) / volatility[i - 1]
|
||||
else:
|
||||
vol_change[i] = 0.0
|
||||
|
||||
# 3. 计算价格自相关 (安全实现)
|
||||
returns = np.diff(close, prepend=close[0]) / (close + 1e-8)
|
||||
autocorr = np.zeros(n)
|
||||
|
||||
for i in range(self.price_lag, n):
|
||||
if i < self.price_lag * 2:
|
||||
continue
|
||||
|
||||
window_returns = returns[i - self.price_lag * 2:i + 1]
|
||||
valid_returns = window_returns[~np.isnan(window_returns)]
|
||||
|
||||
if len(valid_returns) < self.price_lag * 1.5:
|
||||
continue
|
||||
|
||||
# 计算自相关
|
||||
lagged = valid_returns[:-self.price_lag]
|
||||
current = valid_returns[self.price_lag:]
|
||||
|
||||
if len(lagged) == 0 or len(current) == 0:
|
||||
continue
|
||||
|
||||
mean_lagged = np.mean(lagged)
|
||||
mean_current = np.mean(current)
|
||||
|
||||
numerator = np.sum((lagged - mean_lagged) * (current - mean_current))
|
||||
denom_lagged = np.sum((lagged - mean_lagged) ** 2)
|
||||
denom_current = np.sum((current - mean_current) ** 2)
|
||||
|
||||
if denom_lagged > 1e-8 and denom_current > 1e-8:
|
||||
autocorr[i] = numerator / np.sqrt(denom_lagged * denom_current)
|
||||
|
||||
# 4. 计算核心关系指标
|
||||
raw_relationship = vol_change * autocorr
|
||||
|
||||
# 5. 平滑处理 (处理NaN)
|
||||
smoothed_relationship = np.full(n, np.nan)
|
||||
for i in range(self.ma_window - 1, n):
|
||||
window = raw_relationship[max(0, i - self.ma_window + 1):i + 1]
|
||||
valid_window = window[~np.isnan(window)]
|
||||
if len(valid_window) > 0:
|
||||
smoothed_relationship[i] = np.mean(valid_window)
|
||||
|
||||
# 6. 修复关键问题:std_val计算
|
||||
# 获取有效数据范围
|
||||
valid_mask = ~np.isnan(smoothed_relationship[min_required - 1:])
|
||||
if np.any(valid_mask):
|
||||
valid_values = smoothed_relationship[min_required - 1:][valid_mask]
|
||||
std_val = np.std(valid_values) if len(valid_values) > 1 else 1.0
|
||||
else:
|
||||
std_val = 1.0 # 安全回退值
|
||||
|
||||
# 确保std_val不为零
|
||||
std_val = max(std_val, 1e-8)
|
||||
|
||||
# 7. 标准化到稳定范围 (-1, 1)
|
||||
for i in range(n):
|
||||
if not np.isnan(smoothed_relationship[i]):
|
||||
relationship[i] = smoothed_relationship[i] / (std_val * 3.0)
|
||||
else:
|
||||
relationship[i] = 0.0 # 安全默认值
|
||||
|
||||
# 8. 截断到合理范围
|
||||
relationship = np.clip(relationship, -1.0, 1.0)
|
||||
|
||||
# 应用时间偏移
|
||||
if self.shift_window > 0 and len(relationship) > self.shift_window:
|
||||
relationship = np.roll(relationship, -self.shift_window)
|
||||
relationship[-self.shift_window:] = np.nan
|
||||
|
||||
return relationship
|
||||
|
||||
def get_name(self) -> str:
|
||||
return f"vol_trend_rel_{self.vol_window}_{self.price_lag}"
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -1,177 +1,249 @@
|
||||
import numpy as np
|
||||
import talib
|
||||
from collections import deque
|
||||
from typing import Optional, Any, List, Dict
|
||||
import bisect
|
||||
|
||||
from src.core_data import Bar, Order
|
||||
from src.indicators.base_indicators import Indicator
|
||||
from src.indicators.indicators import Empty
|
||||
from src.strategies.base_strategy import Strategy
|
||||
|
||||
|
||||
class TVDZScoreStrategy(Strategy):
|
||||
# =============================================================================
|
||||
# 策略实现 (Dual-Mode Kalman Strategy V4 - 滚动窗口修正版)
|
||||
# =============================================================================
|
||||
|
||||
class DualModeKalmanStrategy(Strategy):
|
||||
"""
|
||||
内嵌 TVD (Condat 算法) + Z-Score ATR 的趋势突破策略。
|
||||
无任何外部依赖(如 pytv),纯 NumPy 实现。
|
||||
V4版本更新:
|
||||
1. 【根本性修正】修复了V3版本中因错误使用全局历史数据而引入的前瞻性偏差和
|
||||
路径依赖问题。
|
||||
2. 【正确实现】现在的数据结构严格、精确地只维护当前滚动窗口(vol_lookback)
|
||||
内的数据,确保了策略的可重复性和逻辑正确性。
|
||||
3. 通过bisect库,在保持100%滚动窗口精度的前提下,实现了高效的百分位计算,
|
||||
避免了在每个bar上都进行暴力排序。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: Any,
|
||||
main_symbol: str,
|
||||
enable_log: bool,
|
||||
trade_volume: int,
|
||||
tvd_lam: float = 50.0,
|
||||
atr_window: int = 14,
|
||||
z_window: int = 100,
|
||||
vol_threshold: float = -0.5,
|
||||
entry_threshold_atr: float = 3.0,
|
||||
stop_atr_multiplier: float = 3.0,
|
||||
order_direction: Optional[List[str]] = None,
|
||||
self,
|
||||
context: Any,
|
||||
main_symbol: str,
|
||||
enable_log: bool,
|
||||
trade_volume: int,
|
||||
# ... (所有策略参数与V2版本完全相同) ...
|
||||
strategy_mode: str = 'TREND',
|
||||
kalman_process_noise: float = 0.01,
|
||||
kalman_measurement_noise: float = 0.5,
|
||||
atr_period: int = 20,
|
||||
vol_lookback: int = 100,
|
||||
vol_percentile_threshold: float = 25.0,
|
||||
entry_threshold_atr: float = 2.5,
|
||||
initial_stop_atr_multiplier: float = 2.0,
|
||||
structural_stop_atr_multiplier: float = 2.5,
|
||||
order_direction: Optional[List[str]] = None,
|
||||
indicators: Optional[List[Indicator]] = None,
|
||||
):
|
||||
super().__init__(context, main_symbol, enable_log)
|
||||
# ... (参数赋值与V2版本完全相同) ...
|
||||
if order_direction is None: order_direction = ['BUY', 'SELL']
|
||||
self.strategy_mode = strategy_mode.upper()
|
||||
self.trade_volume = trade_volume
|
||||
self.order_direction = order_direction or ["BUY", "SELL"]
|
||||
self.tvd_lam = tvd_lam
|
||||
self.atr_window = atr_window
|
||||
self.z_window = z_window
|
||||
self.vol_threshold = vol_threshold
|
||||
self.atr_period = atr_period
|
||||
self.vol_lookback = vol_lookback
|
||||
self.vol_percentile_threshold = vol_percentile_threshold
|
||||
self.entry_threshold_atr = entry_threshold_atr
|
||||
self.stop_atr_multiplier = stop_atr_multiplier
|
||||
self.initial_stop_atr_multiplier = initial_stop_atr_multiplier
|
||||
self.structural_stop_atr_multiplier = structural_stop_atr_multiplier
|
||||
self.order_direction = order_direction
|
||||
|
||||
# --- 【修正后的数据结构】 ---
|
||||
# 1. 严格限定长度的deque,用于维护滚动窗口的原始序列
|
||||
self._vol_history_queue: deque = deque(maxlen=self.vol_lookback)
|
||||
# 2. 一个普通list,我们将手动维护其有序性,并确保其内容与deque完全同步
|
||||
self._sorted_vol_history: List[float] = []
|
||||
|
||||
self.Q = kalman_process_noise
|
||||
self.R = kalman_measurement_noise
|
||||
self.P = 1.0
|
||||
self.x_hat = 0.0
|
||||
self.kalman_initialized = False
|
||||
|
||||
self.position_meta: Dict[str, Any] = self.context.load_state()
|
||||
self.main_symbol = main_symbol
|
||||
self.order_id_counter = 0
|
||||
|
||||
self.log(f"TVDZScoreStrategy Initialized | λ={tvd_lam}, VolThresh={vol_threshold}")
|
||||
if indicators is None: indicators = [Empty(), Empty()]
|
||||
self.indicators = indicators
|
||||
|
||||
@staticmethod
|
||||
def _tvd_condat(y, lam):
|
||||
"""Condat's O(N) TVD algorithm."""
|
||||
n = y.size
|
||||
if n == 0:
|
||||
return y.copy()
|
||||
x = y.astype(np.float64)
|
||||
k = 0
|
||||
k0 = 0
|
||||
vmin = x[0] - lam
|
||||
vmax = x[0] + lam
|
||||
for i in range(1, n):
|
||||
if x[i] < vmin:
|
||||
while k < i:
|
||||
x[k] = vmin
|
||||
k += 1
|
||||
k0 = i
|
||||
vmin = x[i] - lam
|
||||
vmax = x[i] + lam
|
||||
elif x[i] > vmax:
|
||||
while k < i:
|
||||
x[k] = vmax
|
||||
k += 1
|
||||
k0 = i
|
||||
vmin = x[i] - lam
|
||||
vmax = x[i] + lam
|
||||
else:
|
||||
vmin = max(vmin, x[i] - lam)
|
||||
vmax = min(vmax, x[i] + lam)
|
||||
if vmin > vmax:
|
||||
k = k0
|
||||
s = np.sum(x[k0:i+1])
|
||||
s /= (i - k0 + 1)
|
||||
x[k0:i+1] = s
|
||||
k = i + 1
|
||||
k0 = k
|
||||
if k0 < n:
|
||||
vmin = x[k0] - lam
|
||||
vmax = x[k0] + lam
|
||||
while k < n:
|
||||
x[k] = vmin
|
||||
k += 1
|
||||
return x
|
||||
self.log(f"DualModeKalmanStrategy V4 (Corrected Rolling Window) Initialized.")
|
||||
|
||||
def _compute_zscore_atr_last(self, high, low, close) -> float:
|
||||
n = len(close)
|
||||
min_req = self.atr_window + self.z_window - 1
|
||||
if n < min_req:
|
||||
return np.nan
|
||||
start = max(0, n - (self.z_window + self.atr_window))
|
||||
seg_h, seg_l, seg_c = high[start:], low[start:], close[start:]
|
||||
atr_full = talib.ATR(seg_h, seg_l, seg_c, timeperiod=self.atr_window)
|
||||
atr_valid = atr_full[self.atr_window - 1:]
|
||||
if len(atr_valid) < self.z_window:
|
||||
return np.nan
|
||||
window_atr = atr_valid[-self.z_window:]
|
||||
mu = np.mean(window_atr)
|
||||
sigma = np.std(window_atr)
|
||||
last_atr = window_atr[-1]
|
||||
return (last_atr - mu) / sigma if sigma > 1e-12 else 0.0
|
||||
def on_init(self):
|
||||
super().on_init()
|
||||
self.cancel_all_pending_orders(self.main_symbol)
|
||||
self.position_meta = self.context.load_state()
|
||||
# 初始化时清空数据结构
|
||||
self._vol_history_queue.clear()
|
||||
self._sorted_vol_history.clear()
|
||||
|
||||
def on_open_bar(self, open_price: float, symbol: str):
|
||||
self.symbol = symbol
|
||||
bar_history = self.get_bar_history()
|
||||
if len(bar_history) < max(100, self.atr_window + self.z_window):
|
||||
return
|
||||
# 确保有足够的数据来填满第一个完整的窗口
|
||||
if len(bar_history) < self.vol_lookback + self.atr_period: return
|
||||
|
||||
closes = np.array([b.close for b in bar_history], dtype=np.float64)
|
||||
highs = np.array([b.high for b in bar_history], dtype=np.float64)
|
||||
lows = np.array([b.low for b in bar_history], dtype=np.float64)
|
||||
highs = np.array([b.high for b in bar_history], dtype=float)
|
||||
lows = np.array([b.low for b in bar_history], dtype=float)
|
||||
closes = np.array([b.close for b in bar_history], dtype=float)
|
||||
current_atr = talib.ATR(highs, lows, closes, self.atr_period)[-1]
|
||||
|
||||
# === TVD 平滑 ===
|
||||
tvd_prices = self._tvd_condat(closes, self.tvd_lam)
|
||||
tvd_price = tvd_prices[-1]
|
||||
last_close = closes[-1]
|
||||
if last_close <= 0: return
|
||||
current_normalized_atr = current_atr / last_close
|
||||
|
||||
# === Z-Score ATR ===
|
||||
current_atr = talib.ATR(highs, lows, closes, timeperiod=self.atr_window)[-1]
|
||||
if current_atr <= 0:
|
||||
return
|
||||
# --- 【核心修正:正确的滚动窗口维护】 ---
|
||||
# 1. 如果窗口已满,deque会自动从左侧弹出一个旧值。我们需要捕捉这个值。
|
||||
oldest_val = None
|
||||
if len(self._vol_history_queue) == self.vol_lookback:
|
||||
oldest_val = self._vol_history_queue[0]
|
||||
|
||||
deviation = closes[-1] - tvd_price
|
||||
deviation_in_atr = deviation / current_atr
|
||||
# 2. 将新值添加到deque的右侧
|
||||
self._vol_history_queue.append(current_normalized_atr)
|
||||
|
||||
# 3. 更新有序列表,使其与deque的状态严格同步
|
||||
if oldest_val is not None:
|
||||
# a. 先从有序列表中移除旧值
|
||||
# 由于浮点数精度问题,直接remove可能不安全,我们使用bisect查找并移除
|
||||
# 这是一个O(log N) + O(N)的操作,但远快于完全重排
|
||||
idx_to_remove = bisect.bisect_left(self._sorted_vol_history, oldest_val)
|
||||
if idx_to_remove < len(self._sorted_vol_history) and abs(
|
||||
self._sorted_vol_history[idx_to_remove] - oldest_val) < 1e-9:
|
||||
self._sorted_vol_history.pop(idx_to_remove)
|
||||
else:
|
||||
# 备用方案,如果bisect找不到(理论上不应该),则暴力移除
|
||||
try:
|
||||
self._sorted_vol_history.remove(oldest_val)
|
||||
except ValueError:
|
||||
pass # 如果值不存在,忽略
|
||||
|
||||
# b. 将新值高效地插入到有序列表中
|
||||
bisect.insort_left(self._sorted_vol_history, current_normalized_atr)
|
||||
|
||||
# 检查窗口是否已填满
|
||||
if len(self._sorted_vol_history) < self.vol_lookback: return
|
||||
|
||||
# ... (卡尔曼滤波器计算部分保持不变) ...
|
||||
if not self.kalman_initialized: self.x_hat = closes[-1]
|
||||
self.kalman_initialized = True
|
||||
x_hat_minus = self.x_hat
|
||||
P_minus = self.P + self.Q
|
||||
K = P_minus / (P_minus + self.R)
|
||||
self.x_hat = x_hat_minus + K * (closes[-1] - x_hat_minus)
|
||||
self.P = (1 - K) * P_minus
|
||||
kalman_price = self.x_hat
|
||||
|
||||
position_volume = self.get_current_positions().get(self.symbol, 0)
|
||||
# ... (持仓同步逻辑不变) ...
|
||||
|
||||
if position_volume != 0:
|
||||
self.manage_open_position(position_volume, bar_history[-1], current_atr, tvd_price)
|
||||
self.manage_open_position(position_volume, bar_history[-1], current_atr, kalman_price)
|
||||
return
|
||||
|
||||
# --- 使用精确的滚动窗口百分位阈值 ---
|
||||
percentile_index = int(self.vol_percentile_threshold / 100.0 * (self.vol_lookback - 1))
|
||||
vol_threshold = self._sorted_vol_history[percentile_index]
|
||||
|
||||
if current_normalized_atr < vol_threshold:
|
||||
return
|
||||
|
||||
self.evaluate_entry_signal(bar_history[-1], kalman_price, current_atr)
|
||||
|
||||
def manage_open_position(self, volume: int, current_bar: Bar, current_atr: float, kalman_price: float):
|
||||
# ... (此部分代码与上一版完全相同,保持不变) ...
|
||||
meta = self.position_meta.get(self.symbol)
|
||||
if not meta: return
|
||||
|
||||
initial_stop_price = meta['initial_stop_price']
|
||||
if (volume > 0 and current_bar.low <= initial_stop_price) or \
|
||||
(volume < 0 and current_bar.high >= initial_stop_price):
|
||||
self.log(f"Initial Stop Loss hit at {initial_stop_price:.4f}")
|
||||
self.close_position("CLOSE_LONG" if volume > 0 else "CLOSE_SHORT", abs(volume))
|
||||
return
|
||||
|
||||
if self.strategy_mode == 'TREND':
|
||||
if volume > 0:
|
||||
stop_price = max(kalman_price - self.structural_stop_atr_multiplier * current_atr, initial_stop_price)
|
||||
if current_bar.low <= stop_price:
|
||||
self.log(f"TREND Mode: Structural Stop hit for LONG at {stop_price:.4f}")
|
||||
self.close_position("CLOSE_LONG", abs(volume))
|
||||
else:
|
||||
stop_price = min(kalman_price + self.structural_stop_atr_multiplier * current_atr, initial_stop_price)
|
||||
if current_bar.high >= stop_price:
|
||||
self.log(f"TREND Mode: Structural Stop hit for SHORT at {stop_price:.4f}")
|
||||
self.close_position("CLOSE_SHORT", abs(volume))
|
||||
|
||||
elif self.strategy_mode == 'REVERSION':
|
||||
if volume > 0 and current_bar.high >= kalman_price:
|
||||
self.log(f"REVERSION Mode: Take Profit for LONG as price reverts to Kalman line at {kalman_price:.4f}")
|
||||
self.close_position("CLOSE_LONG", abs(volume))
|
||||
elif volume < 0 and current_bar.low <= kalman_price:
|
||||
self.log(f"REVERSION Mode: Take Profit for SHORT as price reverts to Kalman line at {kalman_price:.4f}")
|
||||
self.close_position("CLOSE_SHORT", abs(volume))
|
||||
|
||||
def evaluate_entry_signal(self, current_bar: Bar, kalman_price: float, current_atr: float):
|
||||
# ... (此部分代码与上一版完全相同,保持不变) ...
|
||||
deviation = current_bar.close - kalman_price
|
||||
if current_atr <= 0: return
|
||||
deviation_in_atr = deviation / current_atr
|
||||
|
||||
direction = None
|
||||
if "BUY" in self.order_direction and deviation_in_atr > self.entry_threshold_atr:
|
||||
direction = "BUY"
|
||||
elif "SELL" in self.order_direction and deviation_in_atr < -self.entry_threshold_atr:
|
||||
direction = "SELL"
|
||||
|
||||
if self.strategy_mode == 'TREND':
|
||||
if "BUY" in self.order_direction and deviation_in_atr > self.entry_threshold_atr:
|
||||
direction = "BUY"
|
||||
elif "SELL" in self.order_direction and deviation_in_atr < -self.entry_threshold_atr:
|
||||
direction = "SELL"
|
||||
|
||||
elif self.strategy_mode == 'REVERSION':
|
||||
if "SELL" in self.order_direction and deviation_in_atr > self.entry_threshold_atr:
|
||||
direction = "SELL"
|
||||
elif "BUY" in self.order_direction and deviation_in_atr < -self.entry_threshold_atr:
|
||||
direction = "BUY"
|
||||
|
||||
if direction:
|
||||
self.log(f"Signal Fired | Dir: {direction}, Dev: {deviation_in_atr:.2f} ATR")
|
||||
entry_price = closes[-1]
|
||||
stop_loss = (
|
||||
entry_price - self.stop_atr_multiplier * current_atr
|
||||
if direction == "BUY"
|
||||
else entry_price + self.stop_atr_multiplier * current_atr
|
||||
)
|
||||
meta = {"entry_price": entry_price, "stop_loss": stop_loss}
|
||||
self.log(f"{self.strategy_mode} Mode: Entry Signal {direction}. Deviation: {deviation_in_atr:.2f} ATRs.")
|
||||
entry_price = current_bar.close
|
||||
stop_loss_price = entry_price - self.initial_stop_atr_multiplier * current_atr if direction == "BUY" else entry_price + self.initial_stop_atr_multiplier * current_atr
|
||||
meta = {'entry_price': entry_price, 'initial_stop_price': stop_loss_price, 'direction': direction}
|
||||
self.send_market_order(direction, self.trade_volume, "OPEN", meta)
|
||||
self.save_state(self.position_meta)
|
||||
|
||||
def manage_open_position(self, volume: int, current_bar: Bar, current_atr: float, tvd_price: float):
|
||||
meta = self.position_meta.get(self.symbol)
|
||||
if not meta:
|
||||
return
|
||||
stop_loss = meta["stop_loss"]
|
||||
if (volume > 0 and current_bar.low <= stop_loss) or (volume < 0 and current_bar.high >= stop_loss):
|
||||
self.log(f"Stop Loss Hit at {stop_loss:.4f}")
|
||||
self.close_position("CLOSE_LONG" if volume > 0 else "CLOSE_SHORT", abs(volume))
|
||||
|
||||
def close_position(self, direction: str, volume: int):
|
||||
self.send_market_order(direction, volume, offset="CLOSE")
|
||||
if self.symbol in self.position_meta:
|
||||
del self.position_meta[self.symbol]
|
||||
self.position_meta = {}
|
||||
self.save_state(self.position_meta)
|
||||
|
||||
def send_market_order(self, direction: str, volume: int, offset: str, meta: Optional[Dict] = None):
|
||||
if offset == "OPEN" and meta:
|
||||
self.position_meta[self.symbol] = meta
|
||||
if offset == "OPEN" and meta: self.position_meta[self.symbol] = meta
|
||||
order_id = f"{self.symbol}_{direction}_MARKET_{self.order_id_counter}"
|
||||
self.order_id_counter += 1
|
||||
order = Order(id=order_id, symbol=self.symbol, direction=direction, volume=volume,
|
||||
price_type="MARKET", submitted_time=self.get_current_time(), offset=offset)
|
||||
order = Order(id=order_id, symbol=self.symbol, direction=direction, volume=volume, price_type="MARKET",
|
||||
submitted_time=self.get_current_time(), offset=offset)
|
||||
self.send_order(order)
|
||||
|
||||
def send_limit_order(self, limit_price: float, direction: str, volume: int, offset: str,
|
||||
meta: Optional[Dict] = None):
|
||||
if offset == "OPEN" and meta: self.position_meta[self.symbol] = meta
|
||||
order_id = f"{self.symbol}_{direction}_MARKET_{self.order_id_counter}"
|
||||
self.order_id_counter += 1
|
||||
order = Order(id=order_id, symbol=self.symbol, direction=direction, volume=volume, price_type="LIMIT",
|
||||
submitted_time=self.get_current_time(), offset=offset, limit_price=limit_price)
|
||||
self.send_order(order)
|
||||
|
||||
def on_rollover(self, old_symbol: str, new_symbol: str):
|
||||
super().on_rollover(old_symbol, new_symbol)
|
||||
self.position_meta = {}
|
||||
self.log("Rollover: Strategy state reset.")
|
||||
self.kalman_initialized = False
|
||||
self._sorted_vol_history.clear()
|
||||
self.log("Rollover detected. All strategy states have been reset.")
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -1,21 +1,32 @@
|
||||
import numpy as np
|
||||
import talib
|
||||
from datetime import datetime
|
||||
from typing import Optional, Any, List
|
||||
from scipy.signal import stft
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Any, List, Dict
|
||||
|
||||
from src.core_data import Bar, Order
|
||||
from src.indicators.base_indicators import Indicator
|
||||
from src.indicators.indicators import Empty, NormalizedATR, AtrVolatility
|
||||
from src.strategies.base_strategy import Strategy
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 瞬态冲击回调与ATR波幅止盈策略 (V4 - 统一信号与Close价核心逻辑版)
|
||||
# 策略实现 (SpectralTrendStrategy)
|
||||
# =============================================================================
|
||||
class TransientShockATRStrategy(Strategy):
|
||||
|
||||
class SpectralTrendStrategy(Strategy):
|
||||
"""
|
||||
V4版本更新 (根据专业建议重构):
|
||||
1. 【核心变更】信号识别逻辑统一:不再区分跳空和大阳线,仅使用 (上一bar.close - 上上bar.close) 的绝对波幅作为市场冲击的唯一衡量标准。
|
||||
2. 【核心变更】信号计算基于Close价:所有信号计算均基于更稳健的收盘价,避免High/Low的噪音干扰。
|
||||
3. 【参数简化】移除了冗余的 signal_gap_pct 参数,使策略更简洁、更易于优化。
|
||||
频域能量相变策略 - 捕获肥尾趋势
|
||||
|
||||
核心哲学:
|
||||
1. 显式傅里叶变换: 直接分离低频(趋势)、高频(噪音)能量
|
||||
2. 相变临界点: 仅当低频能量占比 > 阈值时入场
|
||||
3. 低频交易: 每月仅2-5次信号,持仓数日捕获肥尾
|
||||
4. 完全参数化: 无硬编码,适配任何市场时间结构
|
||||
|
||||
参数说明:
|
||||
- bars_per_day: 市场每日K线数量 (e.g., 23 for 15min US markets)
|
||||
- low_freq_days: 低频定义下限 (天), 默认2.0
|
||||
- high_freq_days: 高频定义上限 (天), 默认1.0
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -24,197 +35,247 @@ class TransientShockATRStrategy(Strategy):
|
||||
main_symbol: str,
|
||||
enable_log: bool,
|
||||
trade_volume: int,
|
||||
# --- 【核心参数】---
|
||||
atr_period: int = 20,
|
||||
signal_move_atr_mult: float = 2.5, # 【变更】定义“市场冲击”的ATR倍数 (统一信号)
|
||||
entry_pullback_pct: float = 0.5,
|
||||
order_expiry_bars: int = 5,
|
||||
initial_stop_atr_mult: float = 2.0,
|
||||
exit_signal_atr_mult: float = 2.5,
|
||||
# --- 【市场结构参数】 ---
|
||||
bars_per_day: int = 23, # 关键: 适配23根/天的市场
|
||||
# --- 【频域核心参数】 ---
|
||||
spectral_window_days: float = 2.0, # STFT窗口大小(天)
|
||||
low_freq_days: float = 2.0, # 低频下限(天)
|
||||
high_freq_days: float = 1.0, # 高频上限(天)
|
||||
trend_strength_threshold: float = 0.8, # 相变临界值
|
||||
exit_threshold: float = 0.4, # 退出阈值
|
||||
# --- 【持仓管理】 ---
|
||||
max_hold_days: int = 10, # 最大持仓天数
|
||||
# --- 其他 ---
|
||||
order_direction: Optional[List[str]] = None,
|
||||
indicators: Optional[List[Indicator]] = None,
|
||||
model_indicator: Indicator = None,
|
||||
):
|
||||
super().__init__(context, main_symbol, enable_log)
|
||||
if not (atr_period > 0 and signal_move_atr_mult > 0 and entry_pullback_pct > 0 and
|
||||
initial_stop_atr_mult > 0 and exit_signal_atr_mult > 0):
|
||||
raise ValueError("所有周期和倍数参数必须大于0")
|
||||
if order_direction is None:
|
||||
order_direction = ['BUY', 'SELL']
|
||||
if indicators is None:
|
||||
indicators = [Empty(), Empty()] # 保持兼容性
|
||||
|
||||
# --- 参数赋值 (完全参数化) ---
|
||||
self.trade_volume = trade_volume
|
||||
self.atr_period = atr_period
|
||||
self.signal_move_atr_mult = signal_move_atr_mult # 【变更】使用新参数
|
||||
self.entry_pullback_pct = entry_pullback_pct
|
||||
self.order_expiry_bars = order_expiry_bars
|
||||
self.initial_stop_atr_mult = initial_stop_atr_mult
|
||||
self.exit_signal_atr_mult = exit_signal_atr_mult
|
||||
self.bars_per_day = bars_per_day
|
||||
self.spectral_window_days = spectral_window_days
|
||||
self.low_freq_days = low_freq_days
|
||||
self.high_freq_days = high_freq_days
|
||||
self.trend_strength_threshold = trend_strength_threshold
|
||||
self.exit_threshold = exit_threshold
|
||||
self.max_hold_days = max_hold_days
|
||||
self.order_direction = order_direction
|
||||
if model_indicator is None:
|
||||
model_indicator = Empty()
|
||||
self.model_indicator = model_indicator
|
||||
|
||||
self.pending_order: Optional[dict] = None
|
||||
self.position_entry_price: float = 0.0
|
||||
self.initial_stop_price: float = 0.0
|
||||
self.order_id_counter = 0
|
||||
# --- 动态计算参数 ---
|
||||
self.spectral_window = int(self.spectral_window_days * self.bars_per_day)
|
||||
# 确保窗口大小为偶数 (STFT要求)
|
||||
self.spectral_window = self.spectral_window if self.spectral_window % 2 == 0 else self.spectral_window + 1
|
||||
|
||||
def on_init(self):
|
||||
"""策略初始化"""
|
||||
self.log(f"🚀 Strategy On Init: Initializing {self.__class__.__name__} V4...")
|
||||
self.cancel_all_pending_orders(self.main_symbol)
|
||||
self.pending_order = None
|
||||
self.position_entry_price = 0.0
|
||||
self.initial_stop_price = 0.0
|
||||
# 频率边界 (cycles/day)
|
||||
self.low_freq_bound = 1.0 / self.low_freq_days if self.low_freq_days > 0 else float('inf')
|
||||
self.high_freq_bound = 1.0 / self.high_freq_days if self.high_freq_days > 0 else 0.0
|
||||
|
||||
# --- 内部状态变量 ---
|
||||
self.main_symbol = main_symbol
|
||||
self.order_id_counter = 0
|
||||
self.log("✅ Strategy Initialized and State Reset.")
|
||||
self.indicators = indicators
|
||||
self.entry_time = None # 入场时间
|
||||
self.position_direction = None # 'LONG' or 'SHORT'
|
||||
self.last_trend_strength = 0.0
|
||||
self.last_dominant_freq = 0.0 # 主导周期(天)
|
||||
|
||||
self.log(f"SpectralTrendStrategy Initialized (bars/day={bars_per_day}, window={self.spectral_window} bars)")
|
||||
|
||||
def on_open_bar(self, open_price: float, symbol: str):
|
||||
"""每根K线开盘时被调用"""
|
||||
self.symbol = symbol
|
||||
bar_history = self.get_bar_history()
|
||||
current_time = self.get_current_time()
|
||||
|
||||
min_bars = self.atr_period + 5
|
||||
if len(bar_history) < min_bars:
|
||||
# 需要足够的数据 (STFT窗口 + 缓冲)
|
||||
if len(bar_history) < self.spectral_window + 10:
|
||||
if self.enable_log and len(bar_history) % 50 == 0:
|
||||
self.log(f"Waiting for {len(bar_history)}/{self.spectral_window + 10} bars")
|
||||
return
|
||||
|
||||
positions = self.get_current_positions()
|
||||
position_volume = positions.get(self.symbol, 0)
|
||||
position_volume = self.get_current_positions().get(self.symbol, 0)
|
||||
|
||||
highs = np.array([b.high for b in bar_history], dtype=float)
|
||||
lows = np.array([b.low for b in bar_history], dtype=float)
|
||||
# 获取历史价格 (使用完整历史)
|
||||
closes = np.array([b.close for b in bar_history], dtype=float)
|
||||
|
||||
current_atr = talib.ATR(highs, lows, closes, timeperiod=self.atr_period)[-1]
|
||||
# 【核心】计算频域趋势强度 (显式傅里叶)
|
||||
trend_strength, dominant_freq = self.calculate_trend_strength(closes)
|
||||
self.last_trend_strength = trend_strength
|
||||
self.last_dominant_freq = dominant_freq
|
||||
|
||||
if not self.trading:
|
||||
# 检查最大持仓时间 (防止极端事件)
|
||||
if self.entry_time and (current_time - self.entry_time) >= timedelta(days=self.max_hold_days):
|
||||
self.log(f"Max hold time reached ({self.max_hold_days} days). Forcing exit.")
|
||||
self.close_all_positions()
|
||||
self.entry_time = None
|
||||
self.position_direction = None
|
||||
return
|
||||
|
||||
if position_volume != 0:
|
||||
self.manage_position(bar_history[-1], position_volume, current_atr)
|
||||
return
|
||||
# 核心逻辑:相变入场/退出
|
||||
if position_volume == 0:
|
||||
self.evaluate_entry_signal(open_price, trend_strength, dominant_freq)
|
||||
else:
|
||||
self.manage_open_position(position_volume, trend_strength, dominant_freq)
|
||||
|
||||
if self.pending_order:
|
||||
self.manage_pending_order()
|
||||
if self.pending_order:
|
||||
return
|
||||
|
||||
self.identify_new_signal(bar_history, current_atr)
|
||||
|
||||
def manage_position(self, last_bar: Bar, position_volume: int, current_atr: float):
|
||||
"""管理当前持仓 (逻辑不变)"""
|
||||
# ... (此部分代码与上一版完全相同,保持不变) ...
|
||||
if position_volume > 0 and last_bar.low <= self.initial_stop_price:
|
||||
self.log(f"⬇️ LONG STOP LOSS: Low={last_bar.low:.4f} <= Stop={self.initial_stop_price:.4f}")
|
||||
self.close_position("CLOSE_LONG", abs(position_volume))
|
||||
return
|
||||
elif position_volume < 0 and last_bar.high >= self.initial_stop_price:
|
||||
self.log(f"⬆️ SHORT STOP LOSS: High={last_bar.high:.4f} >= Stop={self.initial_stop_price:.4f}")
|
||||
self.close_position("CLOSE_SHORT", abs(position_volume))
|
||||
return
|
||||
|
||||
bar_range = last_bar.high - last_bar.low
|
||||
exit_threshold = current_atr * self.exit_signal_atr_mult
|
||||
|
||||
if position_volume > 0 and last_bar.close < last_bar.open and bar_range > exit_threshold:
|
||||
self.log(
|
||||
f"⬇️ LONG VOLATILITY EXIT: Strong Bearish Bar. Range={bar_range:.2f} > Threshold={exit_threshold:.2f}")
|
||||
self.close_position("CLOSE_LONG", abs(position_volume))
|
||||
elif position_volume < 0 and last_bar.close > last_bar.open and bar_range > exit_threshold:
|
||||
self.log(
|
||||
f"⬆️ SHORT VOLATILITY EXIT: Strong Bullish Bar. Range={bar_range:.2f} > Threshold={exit_threshold:.2f}")
|
||||
self.close_position("CLOSE_SHORT", abs(position_volume))
|
||||
|
||||
def manage_pending_order(self):
|
||||
"""管理挂单 (逻辑不变)"""
|
||||
# ... (此部分代码与上一版完全相同,保持不变) ...
|
||||
if not self.pending_order:
|
||||
return
|
||||
|
||||
self.pending_order['bars_waited'] += 1
|
||||
|
||||
if self.pending_order['bars_waited'] >= self.order_expiry_bars:
|
||||
self.log(f"⌛️ PENDING ORDER EXPIRED: Order for {self.pending_order['direction']} "
|
||||
f"at {self.pending_order['price']:.4f} cancelled after {self.pending_order['bars_waited']} bars.")
|
||||
self.cancel_order(self.pending_order['id'])
|
||||
self.pending_order = None
|
||||
|
||||
def identify_new_signal(self, bar_history: List[Bar], current_atr: float):
|
||||
def calculate_trend_strength(self, prices: np.array) -> (float, float):
|
||||
"""
|
||||
【核心逻辑重构】
|
||||
识别新的交易信号。信号源现在统一为 (close - prev_close) 的波幅。
|
||||
【显式傅里叶】计算低频能量占比 (完全参数化)
|
||||
|
||||
步骤:
|
||||
1. 价格归一化 (窗口内)
|
||||
2. 短时傅里叶变换 (STFT) - 采样率=bars_per_day
|
||||
3. 动态计算频段边界 (基于bars_per_day)
|
||||
4. 趋势强度 = 低频能量 / (低频+高频能量)
|
||||
"""
|
||||
last_bar = bar_history[-1]
|
||||
prev_bar = bar_history[-2]
|
||||
# 1. 验证数据长度
|
||||
if len(prices) < self.spectral_window:
|
||||
return 0.0, 0.0
|
||||
|
||||
# --- CHANGE 1: 定义统一的信号阈值 ---
|
||||
signal_threshold = current_atr * self.signal_move_atr_mult
|
||||
# 2. 价格归一化 (仅使用窗口内数据)
|
||||
window_data = prices[-self.spectral_window:]
|
||||
normalized = (window_data - np.mean(window_data)) / (np.std(window_data) + 1e-8)
|
||||
|
||||
# --- CHANGE 2: 计算核心的 close-to-close 波幅 ---
|
||||
move_height = last_bar.close - prev_bar.close
|
||||
# 3. STFT (采样率=bars_per_day)
|
||||
try:
|
||||
# fs: 每天的样本数 (bars_per_day)
|
||||
f, t, Zxx = stft(
|
||||
normalized,
|
||||
fs=self.bars_per_day, # 关键: 适配市场结构
|
||||
nperseg=self.spectral_window,
|
||||
noverlap=max(0, self.spectral_window // 2),
|
||||
boundary=None,
|
||||
padded=False
|
||||
)
|
||||
except Exception as e:
|
||||
self.log(f"STFT calculation error: {str(e)}")
|
||||
return 0.0, 0.0
|
||||
|
||||
# --- 多头信号: 上涨波幅超过阈值 ---
|
||||
if move_height > signal_threshold:
|
||||
self.log(f"💡 Bullish Shock Detected: Move={move_height:.2f} > Threshold={signal_threshold:.2f}")
|
||||
# 回调计算仍然基于 last_bar.high 作为情绪顶点,但回调深度由更稳健的 move_height 决定
|
||||
entry_price = last_bar.high - (move_height * self.entry_pullback_pct)
|
||||
stop_price = entry_price - (current_atr * self.initial_stop_atr_mult)
|
||||
self.place_limit_order("BUY", entry_price, stop_price)
|
||||
return
|
||||
# 4. 过滤无效频率 (STFT返回频率范围: 0 到 fs/2)
|
||||
valid_mask = (f >= 0) & (f <= self.bars_per_day / 2)
|
||||
f = f[valid_mask]
|
||||
Zxx = Zxx[valid_mask, :]
|
||||
|
||||
# --- 空头信号: 下跌波幅超过阈值 ---
|
||||
# 注意: move_height此时为负数
|
||||
if move_height < -signal_threshold:
|
||||
down_move_height = abs(move_height)
|
||||
self.log(f"💡 Bearish Shock Detected: Move={move_height:.2f} < Threshold={-signal_threshold:.2f}")
|
||||
# 回调计算基于 last_bar.low 作为情绪谷点
|
||||
entry_price = last_bar.low + (down_move_height * self.entry_pullback_pct)
|
||||
stop_price = entry_price + (current_atr * self.initial_stop_atr_mult)
|
||||
self.place_limit_order("SELL", entry_price, stop_price)
|
||||
if Zxx.size == 0 or Zxx.shape[1] == 0:
|
||||
return 0.0, 0.0
|
||||
|
||||
def place_limit_order(self, direction: str, price: float, stop_price: float):
|
||||
"""创建、记录并发送一个新的限价挂单"""
|
||||
# ... (此部分代码与上一版完全相同,保持不变) ...
|
||||
order_id = self.generate_order_id(direction, "OPEN")
|
||||
self.pending_order = {
|
||||
"id": order_id, "symbol": self.symbol, "direction": direction,
|
||||
"volume": self.trade_volume, "price": price, "stop_price": stop_price,
|
||||
"bars_waited": 0,
|
||||
}
|
||||
self.log(f"🆕 Placing Limit Order: {direction} at {price:.4f} (Stop: {stop_price:.4f}).")
|
||||
# 5. 计算最新时间点的能量
|
||||
current_energy = np.abs(Zxx[:, -1]) ** 2
|
||||
|
||||
order = Order(
|
||||
id=order_id, symbol=self.symbol, direction=direction,
|
||||
volume=self.trade_volume, price_type="LIMIT", limit_price=price,
|
||||
submitted_time=self.get_current_time(), offset="OPEN"
|
||||
)
|
||||
self.send_order(order)
|
||||
# 6. 动态频段定义 (cycles/day)
|
||||
# 低频: 周期 > low_freq_days → 频率 < 1/low_freq_days
|
||||
low_freq_mask = f < self.low_freq_bound
|
||||
# 高频: 周期 < high_freq_days → 频率 > 1/high_freq_days
|
||||
high_freq_mask = f > self.high_freq_bound
|
||||
|
||||
def on_trade(self, trade):
|
||||
"""处理成交回报 (逻辑不变)"""
|
||||
# ... (此部分代码与上一版完全相同,保持不变) ...
|
||||
if self.pending_order and trade.id == self.pending_order['id']:
|
||||
self.log(
|
||||
f"✅ Order Filled: {trade.direction} at {trade.price:.4f}. Stop loss set at {self.pending_order['stop_price']:.4f}")
|
||||
self.position_entry_price = trade.price
|
||||
self.initial_stop_price = self.pending_order['stop_price']
|
||||
self.pending_order = None
|
||||
# 7. 能量计算
|
||||
low_energy = np.sum(current_energy[low_freq_mask]) if np.any(low_freq_mask) else 0.0
|
||||
high_energy = np.sum(current_energy[high_freq_mask]) if np.any(high_freq_mask) else 0.0
|
||||
total_energy = low_energy + high_energy + 1e-8 # 防除零
|
||||
|
||||
def generate_order_id(self, direction: str, offset: str) -> str:
|
||||
# ... (此部分代码与上一版完全相同,保持不变) ...
|
||||
self.order_id_counter += 1
|
||||
return f"{self.symbol}_{direction}_{offset}_{self.order_id_counter}_{int(datetime.now().timestamp())}"
|
||||
# 8. 趋势强度 = 低频能量占比
|
||||
trend_strength = low_energy / total_energy
|
||||
|
||||
# 9. 计算主导趋势周期 (天)
|
||||
dominant_freq = 0.0
|
||||
if np.any(low_freq_mask) and low_energy > 0:
|
||||
# 找到低频段最大能量对应的频率
|
||||
low_energies = current_energy[low_freq_mask]
|
||||
max_idx = np.argmax(low_energies)
|
||||
dominant_freq = 1.0 / (f[low_freq_mask][max_idx] + 1e-8) # 转换为周期(天)
|
||||
|
||||
return trend_strength, dominant_freq
|
||||
|
||||
def evaluate_entry_signal(self, open_price: float, trend_strength: float, dominant_freq: float):
|
||||
"""评估相变入场信号"""
|
||||
# 仅当趋势强度跨越临界点且有明确周期时入场
|
||||
if trend_strength > self.trend_strength_threshold and dominant_freq > self.low_freq_days:
|
||||
direction = None
|
||||
|
||||
indicator = self.model_indicator
|
||||
|
||||
# 做多信号: 价格在窗口均值上方
|
||||
closes = np.array([b.close for b in self.get_bar_history()[-self.spectral_window:]], dtype=float)
|
||||
if "BUY" in self.order_direction and np.mean(closes[-5:]) > np.mean(closes):
|
||||
direction = "BUY" if indicator.is_condition_met(*self.get_indicator_tuple()) else "SELL"
|
||||
# 做空信号: 价格在窗口均值下方
|
||||
elif "SELL" in self.order_direction and np.mean(closes[-5:]) < np.mean(closes):
|
||||
direction = "SELL" if indicator.is_condition_met(*self.get_indicator_tuple()) else "BUY"
|
||||
|
||||
if direction:
|
||||
self.log(
|
||||
f"Phase Transition Entry: {direction} | Strength={trend_strength:.2f} | Dominant Period={dominant_freq:.1f}d")
|
||||
self.send_limit_order(direction, open_price, self.trade_volume, "OPEN")
|
||||
self.entry_time = self.get_current_time()
|
||||
self.position_direction = "LONG" if direction == "BUY" else "SHORT"
|
||||
|
||||
def manage_open_position(self, volume: int, trend_strength: float, dominant_freq: float):
|
||||
"""管理持仓:仅当相变逆转时退出"""
|
||||
# 相变逆转条件: 趋势强度 < 退出阈值
|
||||
if trend_strength < self.exit_threshold:
|
||||
direction = "CLOSE_LONG" if volume > 0 else "CLOSE_SHORT"
|
||||
self.log(f"Phase Transition Exit: {direction} | Strength={trend_strength:.2f} < {self.exit_threshold}")
|
||||
self.close_position(direction, abs(volume))
|
||||
self.entry_time = None
|
||||
self.position_direction = None
|
||||
|
||||
# --- 辅助函数区 ---
|
||||
def close_all_positions(self):
|
||||
"""强制平仓所有头寸"""
|
||||
positions = self.get_current_positions()
|
||||
if self.symbol in positions and positions[self.symbol] != 0:
|
||||
direction = "CLOSE_LONG" if positions[self.symbol] > 0 else "CLOSE_SHORT"
|
||||
self.close_position(direction, abs(positions[self.symbol]))
|
||||
self.log(f"Forced exit of {abs(positions[self.symbol])} contracts")
|
||||
|
||||
def close_position(self, direction: str, volume: int):
|
||||
# ... (此部分代码与上一版完全相同,保持不变) ...
|
||||
self.send_market_order(direction, volume, offset="CLOSE")
|
||||
self.position_entry_price = 0.0
|
||||
self.initial_stop_price = 0.0
|
||||
|
||||
def send_market_order(self, direction: str, volume: int, offset: str = "OPEN"):
|
||||
# ... (此部分代码与上一版完全相同,保持不变) ...
|
||||
order_id = self.generate_order_id(direction, offset)
|
||||
self.log(f"➡️ Sending Market Order: {direction} {volume} {self.symbol} ({offset})")
|
||||
|
||||
def send_market_order(self, direction: str, volume: int, offset: str):
|
||||
order_id = f"{self.symbol}_{direction}_MARKET_{self.order_id_counter}"
|
||||
self.order_id_counter += 1
|
||||
order = Order(
|
||||
id=order_id, symbol=self.symbol, direction=direction,
|
||||
volume=volume, price_type="MARKET",
|
||||
submitted_time=self.get_current_time(), offset=offset
|
||||
id=order_id,
|
||||
symbol=self.symbol,
|
||||
direction=direction,
|
||||
volume=volume,
|
||||
price_type="MARKET",
|
||||
submitted_time=self.get_current_time(),
|
||||
offset=offset
|
||||
)
|
||||
self.send_order(order)
|
||||
|
||||
def cancel_order(self, order_id: str):
|
||||
# ... (此部分代码与上一版完全相同,保持不变) ...
|
||||
self.log(f"❌ Sending Cancel Request for Order: {order_id}")
|
||||
self.context.cancel_order(order_id)
|
||||
def send_limit_order(self, direction: str, limit_price: float, volume: int, offset: str):
|
||||
order_id = f"{self.symbol}_{direction}_MARKET_{self.order_id_counter}"
|
||||
self.order_id_counter += 1
|
||||
order = Order(
|
||||
id=order_id,
|
||||
symbol=self.symbol,
|
||||
direction=direction,
|
||||
volume=volume,
|
||||
price_type="LIMIT",
|
||||
submitted_time=self.get_current_time(),
|
||||
offset=offset,
|
||||
limit_price=limit_price
|
||||
)
|
||||
self.send_order(order)
|
||||
|
||||
def on_init(self):
|
||||
super().on_init()
|
||||
self.cancel_all_pending_orders(self.main_symbol)
|
||||
self.log("Strategy initialized. Waiting for phase transition signals...")
|
||||
|
||||
def on_rollover(self, old_symbol: str, new_symbol: str):
|
||||
super().on_rollover(old_symbol, new_symbol)
|
||||
self.log(f"Rollover from {old_symbol} to {new_symbol}. Resetting position state.")
|
||||
self.entry_time = None
|
||||
self.position_direction = None
|
||||
self.last_trend_strength = 0.0
|
||||
@@ -1,6 +1,7 @@
|
||||
import numpy as np
|
||||
import talib
|
||||
from typing import Optional, Any, List
|
||||
from scipy.signal import stft
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Any, List, Dict
|
||||
|
||||
from src.core_data import Bar, Order
|
||||
from src.indicators.base_indicators import Indicator
|
||||
@@ -8,137 +9,346 @@ from src.indicators.indicators import Empty
|
||||
from src.strategies.base_strategy import Strategy
|
||||
|
||||
|
||||
class SuperTrendStrategy(Strategy):
|
||||
# =============================================================================
|
||||
# 策略实现 (VolatilityAdaptiveSpectralStrategy)
|
||||
# =============================================================================
|
||||
|
||||
class SpectralTrendStrategy(Strategy):
|
||||
"""
|
||||
SuperTrend 策略:
|
||||
- 基于 ATR 和价格波动构建上下轨
|
||||
- 价格上穿上轨 → 开多(且多头条件满足)
|
||||
- 价格下穿下轨 → 开空(且空头条件满足)
|
||||
- 反向穿越 → 平仓并反手(或仅平仓,支持空仓)
|
||||
- 标准 SuperTrend 公式:使用 ATR * multiplier
|
||||
波动率自适应频域趋势策略
|
||||
|
||||
核心哲学:
|
||||
1. 显式傅里叶变换: 分离低频(趋势)、高频(噪音)能量
|
||||
2. 波动率条件信号: 根据波动率环境动态调整交易方向
|
||||
- 低波动环境: 趋势策略 (高趋势强度 → 延续)
|
||||
- 高波动环境: 反转策略 (高趋势强度 → 反转)
|
||||
3. 无硬编码参数: 所有阈值通过配置参数设定
|
||||
4. 严格无未来函数: 所有计算使用历史数据
|
||||
|
||||
参数说明:
|
||||
- bars_per_day: 市场每日K线数量
|
||||
- volatility_lookback: 波动率计算窗口(天)
|
||||
- low_vol_threshold: 低波动环境阈值(0-1)
|
||||
- high_vol_threshold: 高波动环境阈值(0-1)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: Any,
|
||||
main_symbol: str,
|
||||
enable_log: bool,
|
||||
trade_volume: int,
|
||||
atr_period: int = 10,
|
||||
atr_multiplier: float = 3.0,
|
||||
order_direction: Optional[List[str]] = None,
|
||||
indicators: Optional[List[Indicator]] = None,
|
||||
self,
|
||||
context: Any,
|
||||
main_symbol: str,
|
||||
enable_log: bool,
|
||||
trade_volume: int,
|
||||
# --- 【市场结构参数】 ---
|
||||
bars_per_day: int = 23, # 适配23根/天的市场
|
||||
# --- 【频域核心参数】 ---
|
||||
spectral_window_days: float = 2.0, # STFT窗口大小(天)
|
||||
low_freq_days: float = 2.0, # 低频下限(天)
|
||||
high_freq_days: float = 1.0, # 高频上限(天)
|
||||
trend_strength_threshold: float = 0.8, # 趋势强度阈值
|
||||
exit_threshold: float = 0.5, # 退出阈值
|
||||
# --- 【波动率参数】 ---
|
||||
volatility_lookback_days: float = 5.0, # 波动率计算窗口(天)
|
||||
low_vol_threshold: float = 0.3, # 低波动环境阈值(0-1)
|
||||
high_vol_threshold: float = 0.7, # 高波动环境阈值(0-1)
|
||||
# --- 【持仓管理】 ---
|
||||
max_hold_days: int = 10, # 最大持仓天数
|
||||
# --- 其他 ---
|
||||
order_direction: Optional[List[str]] = None,
|
||||
indicators: Optional[List[Indicator]] = None,
|
||||
):
|
||||
super().__init__(context, main_symbol, enable_log)
|
||||
if order_direction is None:
|
||||
order_direction = ["BUY", "SELL"]
|
||||
order_direction = ['BUY', 'SELL']
|
||||
if indicators is None:
|
||||
indicators = [Empty(), Empty()]
|
||||
indicators = [Empty(), Empty()] # 保持兼容性
|
||||
|
||||
# --- 参数赋值 (完全参数化) ---
|
||||
self.trade_volume = trade_volume
|
||||
self.atr_period = atr_period
|
||||
self.atr_multiplier = atr_multiplier
|
||||
self.bars_per_day = bars_per_day
|
||||
self.spectral_window_days = spectral_window_days
|
||||
self.low_freq_days = low_freq_days
|
||||
self.high_freq_days = high_freq_days
|
||||
self.trend_strength_threshold = trend_strength_threshold
|
||||
self.exit_threshold = exit_threshold
|
||||
self.volatility_lookback_days = volatility_lookback_days
|
||||
self.low_vol_threshold = low_vol_threshold
|
||||
self.high_vol_threshold = high_vol_threshold
|
||||
self.max_hold_days = max_hold_days
|
||||
self.order_direction = order_direction
|
||||
self.indicators = indicators
|
||||
|
||||
# --- 动态计算参数 ---
|
||||
self.spectral_window = int(self.spectral_window_days * self.bars_per_day)
|
||||
self.spectral_window = self.spectral_window if self.spectral_window % 2 == 0 else self.spectral_window + 1
|
||||
self.volatility_window = int(self.volatility_lookback_days * self.bars_per_day)
|
||||
|
||||
# 频率边界 (cycles/day)
|
||||
self.low_freq_bound = 1.0 / self.low_freq_days if self.low_freq_days > 0 else float('inf')
|
||||
self.high_freq_bound = 1.0 / self.high_freq_days if self.high_freq_days > 0 else 0.0
|
||||
|
||||
# --- 内部状态变量 ---
|
||||
self.main_symbol = main_symbol
|
||||
self.order_id_counter = 0
|
||||
self.min_bars_needed = atr_period + 10
|
||||
self.log(f"SuperTrendStrategy Initialized | ATR({atr_period}) × {atr_multiplier}")
|
||||
self.indicators = indicators
|
||||
self.entry_time = None # 入场时间
|
||||
self.position_direction = None # 'LONG' or 'SHORT'
|
||||
self.last_trend_strength = 0.0
|
||||
self.last_dominant_freq = 0.0 # 主导周期(天)
|
||||
self.last_volatility = 0.0 # 标准化波动率(0-1)
|
||||
self.volatility_history = [] # 存储历史波动率
|
||||
|
||||
self.log(f"VolatilityAdaptiveSpectralStrategy Initialized (bars/day={bars_per_day}, "
|
||||
f"window={self.spectral_window} bars, vol_window={self.volatility_window} bars)")
|
||||
|
||||
def on_open_bar(self, open_price: float, symbol: str):
|
||||
"""每根K线开盘时被调用"""
|
||||
self.symbol = symbol
|
||||
bar_history = self.get_bar_history()
|
||||
current_time = self.get_current_time()
|
||||
|
||||
if len(bar_history) < self.min_bars_needed or not self.trading:
|
||||
# 需要足够的数据 (最大窗口 + 缓冲)
|
||||
min_required = max(self.spectral_window, self.volatility_window) + 10
|
||||
if len(bar_history) < min_required:
|
||||
if self.enable_log and len(bar_history) % 50 == 0:
|
||||
self.log(f"Waiting for {len(bar_history)}/{min_required} bars")
|
||||
return
|
||||
|
||||
position = self.get_current_positions().get(self.symbol, 0)
|
||||
position_volume = self.get_current_positions().get(self.symbol, 0)
|
||||
|
||||
# 提取 OHLC
|
||||
highs = np.array([b.high for b in bar_history], dtype=float)
|
||||
lows = np.array([b.low for b in bar_history], dtype=float)
|
||||
closes = np.array([b.close for b in bar_history], dtype=float)
|
||||
# 获取必要历史价格 (仅取所需部分)
|
||||
recent_bars = bar_history[-(max(self.spectral_window, self.volatility_window) + 5):]
|
||||
closes = np.array([b.close for b in recent_bars], dtype=np.float32)
|
||||
highs = np.array([b.high for b in recent_bars], dtype=np.float32)
|
||||
lows = np.array([b.low for b in recent_bars], dtype=np.float32)
|
||||
|
||||
# 1. 计算 ATR
|
||||
atr = talib.ATR(highs, lows, closes, timeperiod=self.atr_period)
|
||||
# 【核心】计算频域趋势强度 (显式傅里叶)
|
||||
trend_strength, dominant_freq = self.calculate_trend_strength(closes)
|
||||
self.last_trend_strength = trend_strength
|
||||
self.last_dominant_freq = dominant_freq
|
||||
|
||||
# 2. 计算基础上下轨
|
||||
hl2 = (highs + lows) / 2.0
|
||||
upper_band_basic = hl2 + self.atr_multiplier * atr
|
||||
lower_band_basic = hl2 - self.atr_multiplier * atr
|
||||
# 【核心】计算标准化波动率 (0-1范围)
|
||||
volatility = self.calculate_normalized_volatility(highs, lows, closes)
|
||||
self.last_volatility = volatility
|
||||
|
||||
# 3. 构建 SuperTrend(带方向的记忆性逻辑)
|
||||
n = len(closes)
|
||||
supertrend = np.full(n, np.nan)
|
||||
direction = np.full(n, 1) # 1 for up, -1 for down
|
||||
# 检查最大持仓时间 (防止极端事件)
|
||||
if self.entry_time and (current_time - self.entry_time) >= timedelta(days=self.max_hold_days):
|
||||
self.log(f"Max hold time reached ({self.max_hold_days} days). Forcing exit.")
|
||||
self.close_all_positions()
|
||||
self.entry_time = None
|
||||
self.position_direction = None
|
||||
return
|
||||
|
||||
# 初始化
|
||||
supertrend[self.atr_period] = upper_band_basic[self.atr_period]
|
||||
direction[self.atr_period] = -1 # 初始假设为 downtrend
|
||||
# 核心逻辑:相变入场/退出
|
||||
if position_volume == 0:
|
||||
self.evaluate_entry_signal(open_price, trend_strength, dominant_freq, volatility, recent_bars)
|
||||
else:
|
||||
self.manage_open_position(position_volume, trend_strength, volatility)
|
||||
|
||||
for i in range(self.atr_period + 1, n):
|
||||
# 默认继承前值
|
||||
supertrend[i] = supertrend[i - 1]
|
||||
direction[i] = direction[i - 1]
|
||||
def calculate_trend_strength(self, closes: np.array) -> (float, float):
|
||||
"""
|
||||
【显式傅里叶】计算低频能量占比 (完全参数化)
|
||||
"""
|
||||
if len(closes) < self.spectral_window:
|
||||
return 0.0, 0.0
|
||||
|
||||
# 如果前一根是 downtrend
|
||||
if direction[i - 1] == -1:
|
||||
if closes[i] > upper_band_basic[i - 1]:
|
||||
direction[i] = 1
|
||||
supertrend[i] = lower_band_basic[i]
|
||||
else:
|
||||
supertrend[i] = min(upper_band_basic[i], supertrend[i - 1])
|
||||
else: # 前一根是 uptrend
|
||||
if closes[i] < lower_band_basic[i - 1]:
|
||||
direction[i] = -1
|
||||
supertrend[i] = upper_band_basic[i]
|
||||
else:
|
||||
supertrend[i] = max(lower_band_basic[i], supertrend[i - 1])
|
||||
# 仅使用窗口内数据
|
||||
window_data = closes[-self.spectral_window:]
|
||||
window_mean = np.mean(window_data)
|
||||
window_std = np.std(window_data)
|
||||
if window_std < 1e-8:
|
||||
return 0.0, 0.0
|
||||
|
||||
# 获取最新状态
|
||||
current_direction = direction[-1]
|
||||
prev_direction = direction[-2] if len(direction) >= 2 else current_direction
|
||||
normalized = (window_data - window_mean) / window_std
|
||||
|
||||
# 4. 确定目标仓位
|
||||
target_position = 0
|
||||
if current_direction == 1:
|
||||
if self.indicators[0].is_condition_met(*self.get_indicator_tuple()):
|
||||
target_position = self.trade_volume # 做多
|
||||
elif current_direction == -1:
|
||||
if self.indicators[1].is_condition_met(*self.get_indicator_tuple()):
|
||||
target_position = -self.trade_volume # 做空
|
||||
try:
|
||||
f, t, Zxx = stft(
|
||||
normalized,
|
||||
fs=self.bars_per_day,
|
||||
nperseg=self.spectral_window,
|
||||
noverlap=max(0, self.spectral_window // 2),
|
||||
boundary=None,
|
||||
padded=False
|
||||
)
|
||||
except Exception as e:
|
||||
self.log(f"STFT calculation error: {str(e)}")
|
||||
return 0.0, 0.0
|
||||
|
||||
# 5. 平仓逻辑:方向翻转即平仓(即使不开反手,也先平)
|
||||
current_position = position
|
||||
should_close_long = (current_position > 0) and (current_direction == -1)
|
||||
should_close_short = (current_position < 0) and (current_direction == 1)
|
||||
# 过滤无效频率
|
||||
max_freq = self.bars_per_day / 2
|
||||
valid_mask = (f >= 0) & (f <= max_freq)
|
||||
if not np.any(valid_mask):
|
||||
return 0.0, 0.0
|
||||
|
||||
# 6. 执行订单
|
||||
if should_close_long or should_close_short or (target_position != current_position):
|
||||
# 先平旧仓
|
||||
if current_position > 0:
|
||||
self.close_position("CLOSE_LONG", current_position)
|
||||
elif current_position < 0:
|
||||
self.close_position("CLOSE_SHORT", -current_position)
|
||||
f = f[valid_mask]
|
||||
Zxx = Zxx[valid_mask, :]
|
||||
|
||||
# 再开新仓(如果条件满足)
|
||||
if target_position > 0:
|
||||
self.send_market_order("BUY", target_position, "OPEN")
|
||||
self.log(f"📈 SuperTrend Long | ATR={atr[-1]:.2f}, Dir=+1")
|
||||
elif target_position < 0:
|
||||
self.send_market_order("SELL", -target_position, "OPEN")
|
||||
self.log(f"📉 SuperTrend Short | ATR={atr[-1]:.2f}, Dir=-1")
|
||||
if Zxx.size == 0 or Zxx.shape[1] == 0:
|
||||
return 0.0, 0.0
|
||||
|
||||
# --- 模板方法 ---
|
||||
def on_init(self):
|
||||
super().on_init()
|
||||
self.cancel_all_pending_orders(self.main_symbol)
|
||||
# 计算最新时间点的能量
|
||||
current_energy = np.abs(Zxx[:, -1]) ** 2
|
||||
|
||||
# 动态频段定义
|
||||
low_freq_mask = f < self.low_freq_bound
|
||||
high_freq_mask = f > self.high_freq_bound
|
||||
|
||||
# 能量计算
|
||||
low_energy = np.sum(current_energy[low_freq_mask]) if np.any(low_freq_mask) else 0.0
|
||||
high_energy = np.sum(current_energy[high_freq_mask]) if np.any(high_freq_mask) else 0.0
|
||||
total_energy = low_energy + high_energy + 1e-8
|
||||
|
||||
# 趋势强度 = 低频能量占比
|
||||
trend_strength = low_energy / total_energy
|
||||
|
||||
# 计算主导趋势周期 (天)
|
||||
dominant_freq = 0.0
|
||||
if np.any(low_freq_mask) and low_energy > 0:
|
||||
low_energies = current_energy[low_freq_mask]
|
||||
max_idx = np.argmax(low_energies)
|
||||
dominant_freq = 1.0 / (f[low_freq_mask][max_idx] + 1e-8)
|
||||
|
||||
return float(trend_strength), float(dominant_freq)
|
||||
|
||||
def calculate_normalized_volatility(self, highs: np.array, lows: np.array, closes: np.array) -> float:
|
||||
"""
|
||||
计算标准化波动率 (0-1范围)
|
||||
|
||||
步骤:
|
||||
1. 计算ATR (真实波幅)
|
||||
2. 标准化ATR (除以价格)
|
||||
3. 归一化到0-1范围 (基于历史波动率)
|
||||
"""
|
||||
if len(closes) < self.volatility_window + 1:
|
||||
return 0.5 # 默认中性值
|
||||
|
||||
# 1. 计算真实波幅 (TR)
|
||||
tr1 = highs[-self.volatility_window - 1:] - lows[-self.volatility_window - 1:]
|
||||
tr2 = np.abs(highs[-self.volatility_window - 1:] - np.roll(closes, 1)[-self.volatility_window - 1:])
|
||||
tr3 = np.abs(lows[-self.volatility_window - 1:] - np.roll(closes, 1)[-self.volatility_window - 1:])
|
||||
tr = np.maximum(tr1, np.maximum(tr2, tr3))
|
||||
|
||||
# 2. 计算ATR
|
||||
atr = np.mean(tr[-self.volatility_window:])
|
||||
|
||||
# 3. 标准化ATR (除以当前价格)
|
||||
current_price = closes[-1]
|
||||
normalized_atr = atr / current_price if current_price > 0 else 0.0
|
||||
|
||||
# 4. 归一化到0-1范围 (基于历史波动率)
|
||||
self.volatility_history.append(normalized_atr)
|
||||
if len(self.volatility_history) > 1000: # 保留1000个历史值
|
||||
self.volatility_history.pop(0)
|
||||
|
||||
if len(self.volatility_history) < 50: # 需要足够历史数据
|
||||
return 0.5
|
||||
|
||||
# 使用历史50-95百分位进行归一化
|
||||
low_percentile = np.percentile(self.volatility_history, 50)
|
||||
high_percentile = np.percentile(self.volatility_history, 95)
|
||||
|
||||
if high_percentile - low_percentile < 1e-8:
|
||||
return 0.5
|
||||
|
||||
# 归一化到0-1范围
|
||||
normalized_vol = (normalized_atr - low_percentile) / (high_percentile - low_percentile + 1e-8)
|
||||
normalized_vol = max(0.0, min(1.0, normalized_vol)) # 限制在0-1范围内
|
||||
|
||||
return normalized_vol
|
||||
|
||||
def evaluate_entry_signal(self, open_price: float, trend_strength: float, dominant_freq: float,
|
||||
volatility: float, recent_bars: List[Bar]):
|
||||
"""评估波动率条件入场信号"""
|
||||
# 仅当趋势强度跨越临界点且有明确周期时入场
|
||||
if trend_strength > self.trend_strength_threshold:
|
||||
direction = None
|
||||
trade_type = ""
|
||||
|
||||
# 计算价格位置 (短期vs长期均值)
|
||||
window_closes = np.array([b.close for b in recent_bars[-self.spectral_window:]], dtype=np.float32)
|
||||
short_avg = np.mean(window_closes[-5:])
|
||||
long_avg = np.mean(window_closes)
|
||||
|
||||
# 添加统计显著性过滤
|
||||
if abs(short_avg - long_avg) < 0.0005 * long_avg:
|
||||
return
|
||||
|
||||
# 【核心】根据波动率环境决定交易逻辑
|
||||
if volatility < self.low_vol_threshold:
|
||||
# 低波动环境: 趋势策略
|
||||
trade_type = "TREND"
|
||||
if "BUY" in self.order_direction and short_avg > long_avg:
|
||||
direction = "BUY"
|
||||
elif "SELL" in self.order_direction and short_avg < long_avg:
|
||||
direction = "SELL"
|
||||
|
||||
elif volatility > self.high_vol_threshold:
|
||||
# 高波动环境: 反转策略
|
||||
trade_type = "REVERSAL"
|
||||
if "BUY" in self.order_direction and short_avg < long_avg:
|
||||
direction = "BUY" # 价格低于均值,预期回归
|
||||
elif "SELL" in self.order_direction and short_avg > long_avg:
|
||||
direction = "SELL" # 价格高于均值,预期反转
|
||||
|
||||
else:
|
||||
# 中波动环境: 谨慎策略 (需要更强信号)
|
||||
trade_type = "CAUTIOUS"
|
||||
if trend_strength > 0.9 and "BUY" in self.order_direction and short_avg > long_avg:
|
||||
direction = "BUY"
|
||||
elif trend_strength > 0.9 and "SELL" in self.order_direction and short_avg < long_avg:
|
||||
direction = "SELL"
|
||||
|
||||
if direction:
|
||||
self.log(
|
||||
f"Entry: {direction} | Type={trade_type} | Strength={trend_strength:.2f} | "
|
||||
f"Volatility={volatility:.2f} | ShortAvg={short_avg:.4f} vs LongAvg={long_avg:.4f}"
|
||||
)
|
||||
self.send_market_order(direction, self.trade_volume, "OPEN")
|
||||
self.entry_time = self.get_current_time()
|
||||
self.position_direction = "LONG" if direction == "BUY" else "SHORT"
|
||||
|
||||
def manage_open_position(self, volume: int, trend_strength: float, volatility: float):
|
||||
"""管理持仓:波动率条件退出"""
|
||||
# 退出条件1: 趋势强度 < 退出阈值
|
||||
if trend_strength < self.exit_threshold:
|
||||
direction = "CLOSE_LONG" if volume > 0 else "CLOSE_SHORT"
|
||||
self.log(f"Exit (Strength): {direction} | Strength={trend_strength:.2f} < {self.exit_threshold}")
|
||||
self.close_position(direction, abs(volume))
|
||||
self.entry_time = None
|
||||
self.position_direction = None
|
||||
return
|
||||
|
||||
# 退出条件2: 波动率环境突变 (从低波动变为高波动,或反之)
|
||||
if self.position_direction == "LONG" and volatility > self.high_vol_threshold * 1.2:
|
||||
# 多头仓位在波动率突增时退出
|
||||
self.log(
|
||||
f"Exit (Volatility Spike): CLOSE_LONG | Volatility={volatility:.2f} > {self.high_vol_threshold * 1.2:.2f}")
|
||||
self.close_position("CLOSE_LONG", abs(volume))
|
||||
self.entry_time = None
|
||||
self.position_direction = None
|
||||
elif self.position_direction == "SHORT" and volatility > self.high_vol_threshold * 1.2:
|
||||
# 空头仓位在波动率突增时退出
|
||||
self.log(
|
||||
f"Exit (Volatility Spike): CLOSE_SHORT | Volatility={volatility:.2f} > {self.high_vol_threshold * 1.2:.2f}")
|
||||
self.close_position("CLOSE_SHORT", abs(volume))
|
||||
self.entry_time = None
|
||||
self.position_direction = None
|
||||
|
||||
# --- 辅助函数区 ---
|
||||
def close_all_positions(self):
|
||||
"""强制平仓所有头寸"""
|
||||
positions = self.get_current_positions()
|
||||
if not positions or self.symbol not in positions or positions[self.symbol] == 0:
|
||||
return
|
||||
|
||||
direction = "CLOSE_LONG" if positions[self.symbol] > 0 else "CLOSE_SHORT"
|
||||
self.close_position(direction, abs(positions[self.symbol]))
|
||||
if self.enable_log:
|
||||
self.log(f"Closed {abs(positions[self.symbol])} contracts")
|
||||
|
||||
def close_position(self, direction: str, volume: int):
|
||||
self.send_market_order(direction, volume, offset="CLOSE")
|
||||
|
||||
def send_market_order(self, direction: str, volume: int, offset: str):
|
||||
order_id = f"{self.symbol}_{direction}_MARKET_{self.order_id_counter}"
|
||||
order_id = f"{self.symbol}_{direction[-6:]}_{self.order_id_counter}"
|
||||
self.order_id_counter += 1
|
||||
order = Order(
|
||||
id=order_id,
|
||||
@@ -147,10 +357,21 @@ class SuperTrendStrategy(Strategy):
|
||||
volume=volume,
|
||||
price_type="MARKET",
|
||||
submitted_time=self.get_current_time(),
|
||||
offset=offset,
|
||||
offset=offset
|
||||
)
|
||||
self.send_order(order)
|
||||
|
||||
def on_init(self):
|
||||
super().on_init()
|
||||
self.cancel_all_pending_orders(self.main_symbol)
|
||||
if self.enable_log:
|
||||
self.log("Strategy initialized. Waiting for volatility-adaptive signals...")
|
||||
|
||||
def on_rollover(self, old_symbol: str, new_symbol: str):
|
||||
super().on_rollover(old_symbol, new_symbol)
|
||||
self.log("Rollover: SuperTrendStrategy state reset.")
|
||||
if self.enable_log:
|
||||
self.log(f"Rollover: {old_symbol} -> {new_symbol}. Resetting state.")
|
||||
self.entry_time = None
|
||||
self.position_direction = None
|
||||
self.last_trend_strength = 0.0
|
||||
self.volatility_history = [] # 重置波动率历史
|
||||
Reference in New Issue
Block a user