1、新增傅里叶策略

2、新增策略管理、策略重启功能
This commit is contained in:
2025-11-20 16:10:16 +08:00
parent 2ae9f2db9e
commit 2c917a467a
19 changed files with 3368 additions and 6643 deletions

File diff suppressed because one or more lines are too long

View File

@@ -241,7 +241,7 @@ if __name__ == "__main__":
# symbol='KQ.i@SHFE.bu',
freq="min15",
start_date_str="2021-01-01",
end_date_str="2025-10-28",
end_date_str="2025-11-28",
mode="backtest", # 指定为回测模式
tq_user=TQ_USER_NAME,
tq_pwd=TQ_PASSWORD,

File diff suppressed because one or more lines are too long

View File

@@ -73,8 +73,8 @@ class DualModeKalmanStrategy(Strategy):
# 卡尔曼滤波器状态
self.Q = kalman_process_noise
self.R = kalman_measurement_noise
self.P = 1.0;
self.x_hat = 0.0;
self.P = 1.0
self.x_hat = 0.0
self.kalman_initialized = False
self._atr_history: deque = deque(maxlen=self.atr_lookback)
@@ -105,7 +105,8 @@ class DualModeKalmanStrategy(Strategy):
self._atr_history.append(current_atr)
if current_atr <= 0 or len(self._atr_history) < self.atr_lookback: return
if not self.kalman_initialized: self.x_hat = closes[-1]; self.kalman_initialized = True
if not self.kalman_initialized: self.x_hat = closes[-1]
self.kalman_initialized = True
x_hat_minus = self.x_hat
P_minus = self.P + self.Q
K = P_minus / (P_minus + self.R)

File diff suppressed because one or more lines are too long

View File

@@ -1,180 +0,0 @@
import numpy as np
import talib
from typing import Optional, Dict, Any, List, Tuple, Union
from src.algo.TrendLine import calculate_latest_trendline_values
# 假设这些是你项目中的基础模块
from src.core_data import Bar, Order
from src.indicators.base_indicators import Indicator
from src.indicators.indicators import Empty
from src.strategies.base_strategy import Strategy
class TrendlineBreakoutStrategy(Strategy):
"""
趋势线突破策略 V3 (优化版):
1. 策略逻辑与 V2 相同,但趋势线计算被重构为一个独立的、
高性能的辅助方法。
2. 该方法只计算最新的趋势线值,避免不必要的数组生成。
开仓信号:
- 做多: 上一根收盘价上穿下趋势线
- 做空: 上一根收盘价下穿上趋势线
平仓逻辑:
- 采用 ATR 滑动止损 (Trailing Stop)。
"""
def __init__(
self,
context: Any,
main_symbol: str,
trendline_n: int = 50,
trade_volume: int = 1,
order_direction: Optional[List[str]] = None,
atr_period: int = 14,
atr_multiplier: float = 1.0,
enable_log: bool = True,
indicators: Union[Indicator, List[Indicator]] = None,
):
super().__init__(context, main_symbol, enable_log)
self.main_symbol = main_symbol
self.trendline_n = trendline_n
self.trade_volume = trade_volume
self.order_direction = order_direction or ["BUY", "SELL"]
self.atr_period = atr_period
self.atr_multiplier = atr_multiplier
self.pos_meta: Dict[str, Dict[str, Any]] = {}
if indicators is None:
indicators = [Empty(), Empty()]
self.indicators = indicators
if self.trendline_n < 3:
raise ValueError("trendline_n 必须大于或等于 3")
log_message = (
f"TrendlineBreakoutStrategy (V3 Optimized) 初始化:\n"
f"交易标的={self.main_symbol}, 交易量={self.trade_volume}\n"
f"趋势线周期={self.trendline_n}, ATR周期={self.atr_period}, ATR倍数={self.atr_multiplier}"
)
self.log(log_message)
def _calculate_atr(self, bar_history: List[Bar]) -> Optional[float]:
# (此函数与上一版本完全相同,保持不变)
if len(bar_history) < self.atr_period + 1: return None
highs = np.array([b.high for b in bar_history])
lows = np.array([b.low for b in bar_history])
closes = np.array([b.close for b in bar_history])
atr = talib.ATR(highs, lows, closes, timeperiod=self.atr_period)
return atr[-1] if not np.isnan(atr[-1]) else None
def on_init(self):
super().on_init()
self.pos_meta.clear()
def on_open_bar(self, open_price: float, symbol: str):
bar_history = self.get_bar_history()
min_bars_required = self.trendline_n + 2
if len(bar_history) < min_bars_required:
return
self.cancel_all_pending_orders(symbol)
pos = self.get_current_positions().get(symbol, 0)
# 1. 优先处理平仓逻辑 (逻辑不变)
meta = self.pos_meta.get(symbol)
if meta and pos != 0:
current_atr = self._calculate_atr(bar_history[:-1])
if current_atr:
trailing_stop = meta['trailing_stop']
direction = meta['direction']
last_close = bar_history[-1].close
if direction == "BUY":
new_stop_level = last_close - current_atr * self.atr_multiplier
trailing_stop = max(trailing_stop, new_stop_level)
else: # SELL
new_stop_level = last_close + current_atr * self.atr_multiplier
trailing_stop = min(trailing_stop, new_stop_level)
self.pos_meta[symbol]['trailing_stop'] = trailing_stop
if (direction == "BUY" and open_price <= trailing_stop) or \
(direction == "SELL" and open_price >= trailing_stop):
self.log(f"ATR滑动止损触发: 价格 {open_price:.2f} 触及止损位 {trailing_stop:.2f}")
self.send_market_order("CLOSE_LONG" if direction == "BUY" else "CLOSE_SHORT", abs(pos))
del self.pos_meta[symbol]
return
# 2. 开仓逻辑 (调用优化后的方法)
if pos == 0:
prices_for_trendline = np.array([b.close for b in bar_history[-self.trendline_n - 1:-1]])
# --- 调用新的独立方法 ---
trendline_val_upper, trendline_val_lower = calculate_latest_trendline_values(prices_for_trendline)
if trendline_val_upper is None or trendline_val_lower is None:
return # 无法计算趋势线,跳过
prev_close = bar_history[-2].close
last_close = bar_history[-1].close
current_atr = self._calculate_atr(bar_history[:-1])
if not current_atr:
return
if "BUY" in self.order_direction and last_close > trendline_val_upper and self.indicators[0].is_condition_met(*self.get_indicator_tuple()):
self.log(f"做多信号: Close({last_close:.2f}) 上穿下趋势线({trendline_val_upper:.2f})")
self.send_open_order("BUY", open_price, self.trade_volume, current_atr)
elif "SELL" in self.order_direction and last_close < trendline_val_lower and self.indicators[1].is_condition_met(*self.get_indicator_tuple()):
self.log(f"做空信号: Close({last_close:.2f}) 下穿上趋势线({trendline_val_lower:.2f})")
self.send_open_order("SELL", open_price, self.trade_volume, current_atr)
# is_met = self.indicators[0].is_condition_met(*self.get_indicator_tuple())
# if "BUY" in self.order_direction and last_close > trendline_val_upper and is_met:
# self.log(f"做多信号: Close({last_close:.2f}) 上穿下趋势线({trendline_val_upper:.2f})")
# self.send_open_order("BUY", open_price, self.trade_volume, current_atr)
#
# elif "SELL" in self.order_direction and last_close < trendline_val_lower and is_met:
# self.log(f"做空信号: Close({last_close:.2f}) 下穿上趋势线({trendline_val_lower:.2f})")
# self.send_open_order("SELL", open_price, self.trade_volume, current_atr)
# prob = self.indicators[0].get_latest_value(*self.get_indicator_tuple())
# if "BUY" in self.order_direction and last_close > trendline_val_upper:
# if prob is None or prob[0] < prob[1]:
# self.log(f"做多信号: Close({last_close:.2f}) 上穿下趋势线({trendline_val_upper:.2f})")
# self.send_open_order("BUY", open_price, self.trade_volume, current_atr)
#
# elif "SELL" in self.order_direction and last_close < trendline_val_lower:
# if prob is None or prob[2] < prob[3]:
# self.log(f"做空信号: Close({last_close:.2f}) 下穿上趋势线({trendline_val_lower:.2f})")
# self.send_open_order("SELL", open_price, self.trade_volume, current_atr)
# send_open_order, send_market_order, on_rollover 等方法与上一版本完全相同,保持不变
def send_open_order(self, direction: str, entry_price: float, volume: int, current_atr: float):
if direction == "BUY":
initial_stop = entry_price - current_atr * self.atr_multiplier
else:
initial_stop = entry_price + current_atr * self.atr_multiplier
current_time = self.get_current_time()
order_id = f"{self.symbol}_{direction}_{current_time.strftime('%Y%m%d%H%M%S')}"
order_direction = "BUY" if direction == "BUY" else "SELL"
order = Order(id=order_id, symbol=self.symbol, direction=order_direction, volume=volume, price_type="MARKET",
submitted_time=current_time, offset="OPEN")
self.send_order(order)
self.pos_meta[self.symbol] = {"direction": direction, "volume": volume, "entry_price": entry_price,
"trailing_stop": initial_stop}
self.log(
f"发送开仓订单: {direction} {volume}手 @ Market Price (执行价约 {entry_price:.2f}), 初始ATR止损位: {initial_stop:.2f}")
def send_market_order(self, direction: str, volume: int):
current_time = self.get_current_time()
order_id = f"{self.symbol}_{direction}_{current_time.strftime('%Y%m%d%H%M%S')}"
order = Order(id=order_id, symbol=self.symbol, direction=direction, volume=volume, price_type="MARKET",
submitted_time=current_time, offset="CLOSE")
self.send_order(order)
self.log(f"发送平仓订单: {direction} {volume}手 @ Market Price")
def on_rollover(self, old_symbol: str, new_symbol: str):
super().on_rollover(old_symbol, new_symbol)
self.cancel_all_pending_orders(new_symbol)
self.pos_meta.clear()

File diff suppressed because one or more lines are too long

View File

@@ -122,166 +122,117 @@ class ResultAnalyzer:
def analyze_indicators(self, profit_offset: float = 0.0) -> None:
"""
分析所有平仓交易的指标值与实现盈亏的关系,并绘制累积盈亏曲线
图表将展示指标值区间与对应累积盈亏的关系,帮助找出具有概率优势的指标区间
同时会标记出最大和最小累积盈亏对应的指标值,并优化标注位置以避免重叠。
分析开仓时的指标值与平仓时实现盈亏的关系,并绘制累积盈亏曲线。
简化假设:每笔开仓(is_open_trade=True)都会在后续被一笔平仓(is_close_trade=True)全平
"""
close_trades = [trade for trade in self.trade_history if trade.is_close_trade]
# 1. 分离开仓和平仓交易
open_trades = [t for t in self.trade_history if t.is_open_trade]
close_trades = [t for t in self.trade_history if t.is_close_trade]
if not close_trades:
print(
"没有平仓交易可供分析。请确保 trade_history 中有 is_close_trade 为 True 的交易。"
)
print("没有平仓交易可供分析。请确保策略有平仓行为。")
return
# 2. 配对数量(取较小值)
num_pairs = min(len(open_trades), len(close_trades))
if num_pairs == 0:
print("开仓和平仓交易数量不匹配,无法配对分析。")
return
print(f"将进行 {num_pairs} 组开仓-平仓配对分析")
for indicator in self.indicator_list:
# 假设每个 indicator 对象都有一个 get_name() 方法
indicator_name = indicator.get_name()
# 收集指标的所有值和对应的实现盈亏
# 3. 按配对顺序收集指标值和盈亏
indi_values = []
pnls = []
for trade in close_trades:
# 确保 trade.indicator_dict 中包含当前指标的值
# 并且这个值是可用的非None或NaN
if (
trade.indicator_dict is not None
and indicator_name in trade.indicator_dict
and trade.indicator_dict[indicator_name] is not None
):
# 检查是否为 NaN如果使用 np.nan则需要 isinstance(value, float) and np.isnan(value)
# 为了简化,这里假设非 None 即为有效数值
if not (
isinstance(trade.indicator_dict[indicator_name], float)
and np.isnan(trade.indicator_dict[indicator_name])
):
indi_values.append(trade.indicator_dict[indicator_name])
pnls.append(trade.realized_pnl - profit_offset)
for i in range(num_pairs):
open_trade = open_trades[i]
close_trade = close_trades[i]
if (open_trade.indicator_dict is not None and
indicator_name in open_trade.indicator_dict):
value = open_trade.indicator_dict[indicator_name]
if not (isinstance(value, float) and np.isnan(value)):
indi_values.append(value)
pnls.append(close_trade.realized_pnl - profit_offset)
if not indi_values:
print(f"指标 '{indicator_name}' 没有对应的有效平仓交易数据跳过绘图。")
print(f"指标 '{indicator_name}' 没有有效数据跳过绘图。")
continue
# 将收集到的数据转换为 Pandas DataFrame 进行更便捷的处理
# DataFrame 的结构为:['indicator_value', 'realized_pnl']
df = pd.DataFrame({"indicator_value": indi_values, "realized_pnl": pnls})
# 4. 数据清洗与准备
df = pd.DataFrame({
"indicator_value": indi_values,
"realized_pnl": pnls
})
def remove_extreme(df, col='indicator_value', k=3):
"""IQR 稳健过滤,返回过滤后的 df 和被剔除的行数"""
q1, q3 = df[col].quantile([0.25, 0.75])
iqr = q3 - q1
lower, upper = q1 - k * iqr, q3 + k * iqr
mask = df[col].between(lower, upper)
n_remove = (~mask).sum()
return df[mask].copy(), n_remove
mask = df[col].between(q1 - k * iqr, q3 + k * iqr)
return df[mask].copy(), (~mask).sum()
df, n_drop = remove_extreme(df) # 默认 k=2.5
df, n_drop = remove_extreme(df)
if n_drop:
print(f"指标 '{indicator_name}' 过滤掉 {n_drop} 个极端异常值 "
f"({n_drop / len(df) * 100:.1f}%),剩余 {len(df)} 条样本。")
print(f"指标 '{indicator_name}' 过滤掉 {n_drop} 个极端值,剩余 {len(df)} 条样本。")
if df.empty:
print(f"指标 '{indicator_name}' 过滤后无数据,跳过绘图。")
continue
# 确保数据框不为空
if df.empty:
print(f"指标 '{indicator_name}' 的数据框为空,跳过绘图。")
continue
# 按照指标值进行排序,这是计算累积和的关键步骤
# 5. 计算累积盈亏曲线数据
df = df.sort_values(by="indicator_value").reset_index(drop=True)
# --- 绘制累积收益曲线 ---
plt.figure(figsize=(12, 7)) # 创建一个新的图表
# 获取指标值的范围用于生成X轴的等距点
# 定义 max_val 和 min_val这是修复的关键
min_val = df["indicator_value"].min()
max_val = df["indicator_value"].max()
# 特殊处理:如果所有指标值都相同
# 处理所有值相同的情况
if min_val == max_val:
total_pnl = df["realized_pnl"].sum()
print(
f"指标 '{indicator_name}' 的所有值都相同 ({min_val:.2f}),无法创建区间图,绘制一个点表示总收益。"
)
plt.plot(min_val, total_pnl, "ro", markersize=8) # 绘制一个点
plt.title(
f"{indicator_name} Value vs. Cumulative PnL (All values are {min_val:.2f})"
)
plt.figure(figsize=(12, 7))
plt.plot(min_val, total_pnl, "ro", markersize=8)
plt.title(f"{indicator_name} Value vs. Cumulative PnL (All values are {min_val:.2f})")
plt.xlabel(f"{indicator_name} Value")
plt.ylabel("Cumulative Realized PnL")
plt.grid(True)
plt.text(
min_val,
total_pnl,
f" Total PnL: {total_pnl:.2f}",
ha="center",
va="bottom",
)
plt.text(min_val, total_pnl, f" Total PnL: {total_pnl:.2f}", ha="center", va="bottom")
plt.show()
continue
# 生成X轴上的100个等距点这些点代表了指标值的不同阈值
x_points = np.linspace(min_val, max_val, 100)
y_cumulative_pnl = [df[df["indicator_value"] <= xp]["realized_pnl"].sum() for xp in x_points]
# 计算Y轴的值对于每个X轴点 xpY轴值是所有 'indicator_value' <= xp 的 'realized_pnl' 之和
y_cumulative_pnl = []
for xp in x_points:
# 筛选出指标值小于等于当前x_point的所有交易并求和它们的 realized_pnl
cumulative_pnl = df[df["indicator_value"] <= xp]["realized_pnl"].sum()
y_cumulative_pnl.append(cumulative_pnl)
# 6. 绘图(完整版)
plt.figure(figsize=(12, 7))
plt.plot(x_points, y_cumulative_pnl, marker="o", markersize=3,
label=f"Cumulative PnL for {indicator_name}", alpha=0.8)
# 绘制累积盈亏曲线
plt.plot(
x_points,
y_cumulative_pnl,
marker="o",
linestyle="-",
markersize=3,
label=f"Cumulative PnL for {indicator_name}",
alpha=0.8,
)
# 标记累积盈亏的最大值点
# 标记最大值点
optimal_index = np.argmax(y_cumulative_pnl)
optimal_indi_value = x_points[optimal_index]
max_cumulative_pnl = y_cumulative_pnl[optimal_index]
# 标记累积盈亏的最小值点
min_pnl_index = np.argmin(y_cumulative_pnl[:optimal_index]) if len(y_cumulative_pnl[:optimal_index]) > 0 else 0
# 标记最小值点(在最大值左侧找)
min_pnl_index = np.argmin(y_cumulative_pnl[:optimal_index]) if optimal_index > 0 else 0
min_indi_value_at_pnl = x_points[min_pnl_index]
min_cumulative_pnl = y_cumulative_pnl[min_pnl_index]
# 动态调整标注位置以避免重叠
offset_x = (max_val - min_val) * 0.05 # 水平偏移量
# 动态调整标注位置使用已定义的max_val和min_val
offset_x = (max_val - min_val) * 0.05
# 默认标注为右侧对齐,文本在点的左侧
max_ha = "right"
max_xytext_x = optimal_indi_value - offset_x
min_ha = "right"
min_xytext_x = min_indi_value_at_pnl - offset_x
# 如果最大值点在最小值点右侧,则最大值标注放左侧,最小值标注放右侧
# 这样可以避免标注文本重叠
if optimal_indi_value > min_indi_value_at_pnl:
max_ha = "left"
max_xytext_x = optimal_indi_value + offset_x
min_ha = "right"
min_xytext_x = min_indi_value_at_pnl - offset_x
else: # 如果最大值点在最小值点左侧或重合
max_ha = "right"
max_xytext_x = optimal_indi_value - offset_x
min_ha = "left"
min_xytext_x = min_indi_value_at_pnl + offset_x
max_ha, max_xytext_x = "left", optimal_indi_value + offset_x
min_ha, min_xytext_x = "right", min_indi_value_at_pnl - offset_x
else:
max_ha, max_xytext_x = "right", optimal_indi_value - offset_x
min_ha, min_xytext_x = "left", min_indi_value_at_pnl + offset_x
# 绘制最大值垂直线和标注
plt.axvline(
optimal_indi_value,
color="red",
linestyle="--",
label=f"Max PnL Threshold: {optimal_indi_value:.2f}",
alpha=0.7,
)
# 绘制最大值标注
plt.axvline(optimal_indi_value, color="red", linestyle="--", alpha=0.7)
plt.annotate(
f"Max Cum. PnL: {max_cumulative_pnl:.2f}",
xy=(optimal_indi_value, max_cumulative_pnl),
@@ -289,24 +240,12 @@ class ResultAnalyzer:
arrowprops=dict(facecolor="red", shrink=0.05),
horizontalalignment=max_ha,
verticalalignment="bottom",
color="red",
color="red"
)
# 绘制最小值垂直线和标注
plt.axvline(
min_indi_value_at_pnl,
color="blue",
linestyle=":",
label=f"Min PnL Threshold: {min_indi_value_at_pnl:.2f}",
alpha=0.7,
)
# 垂直偏移最小值标注,避免与曲线重叠
min_text_y_offset = (
-(max_cumulative_pnl - min_cumulative_pnl) * 0.1
if max_cumulative_pnl != min_cumulative_pnl
else -0.05
)
# 绘制最小值标注
plt.axvline(min_indi_value_at_pnl, color="blue", linestyle=":", alpha=0.7)
min_text_y_offset = -(max_cumulative_pnl - min_cumulative_pnl) * 0.1
plt.annotate(
f"Min Cum. PnL: {min_cumulative_pnl:.2f}",
xy=(min_indi_value_at_pnl, min_cumulative_pnl),
@@ -314,7 +253,7 @@ class ResultAnalyzer:
arrowprops=dict(facecolor="blue", shrink=0.05),
horizontalalignment=min_ha,
verticalalignment="top",
color="blue",
color="blue"
)
plt.title(f"{indicator_name} Value vs. Cumulative Realized PnL")
@@ -322,7 +261,7 @@ class ResultAnalyzer:
plt.ylabel("Cumulative Realized PnL")
plt.grid(True)
plt.legend()
plt.tight_layout() # 自动调整图表参数,使之更紧凑
plt.tight_layout()
plt.show()
print("\n所有指标分析图表已生成。")
print("\n所有指标分析成。")

View File

@@ -162,22 +162,22 @@ class BacktestEngine:
self.strategy.on_open_bar(current_bar.open, current_bar.symbol)
current_indicator_dict = {}
# current_indicator_dict = {}
close_array = np.array(self.close_list)
open_array = np.array(self.open_list)
high_array = np.array(self.high_list)
low_array = np.array(self.low_list)
volume_array = np.array(self.volume_list)
for indicator in self.indicators:
current_indicator_dict[indicator.get_name()] = indicator.get_latest_value(
close_array,
open_array,
high_array,
low_array,
volume_array
)
self.simulator.process_pending_orders(current_bar, current_indicator_dict)
# for indicator in self.indicators:
# current_indicator_dict[indicator.get_name()] = indicator.get_latest_value(
# close_array,
# open_array,
# high_array,
# low_array,
# volume_array
# )
self.simulator.process_pending_orders(current_bar)
self.all_bars.append(current_bar)
self.close_list.append(current_bar.close)
@@ -191,7 +191,7 @@ class BacktestEngine:
# self.strategy.on_bar(current_bar)
self.strategy.on_close_bar(current_bar)
self.simulator.process_pending_orders(current_bar, current_indicator_dict)
self.simulator.process_pending_orders(current_bar)
# 8. 记录投资组合快照
@@ -231,6 +231,10 @@ class BacktestEngine:
# 回测结束后,获取所有交易记录
self.trade_history = self.simulator.get_trade_history()
print("\n--- 批量计算指标并赋值到Trade ---")
indicator_df = self._batch_calculate_indicators()
self._assign_indicators_to_trades(indicator_df)
print("--- 回测结束 ---")
print(f"总计处理了 {len(self.all_bars)} 根K线。")
print(f"总计发生了 {len(self.trade_history)} 笔交易。")
@@ -246,6 +250,57 @@ class BacktestEngine:
print(f"最终总净值: {final_portfolio_value:.2f}")
print(f"总收益率: {total_return_percentage:.2f}%")
def _batch_calculate_indicators(self) -> pd.DataFrame:
"""
批量计算所有Bar的指标值返回DataFrameindex为Bar序号
"""
if not self.indicators:
return pd.DataFrame()
# 一次性转换为numpy数组避免重复构建
close_arr = np.array(self.close_list)
open_arr = np.array(self.open_list)
high_arr = np.array(self.high_list)
low_arr = np.array(self.low_list)
volume_arr = np.array(self.volume_list)
indicator_data = {}
for indicator in self.indicators:
# 向量化计算所有历史Bar的指标值
values = indicator.get_values(
close_arr, open_arr, high_arr, low_arr, volume_arr
)
indicator_data[indicator.get_name()] = values
return pd.DataFrame(indicator_data)
def _assign_indicators_to_trades(self, indicator_df: pd.DataFrame):
"""
根据fill_time将指标值反向赋值给Trade对象
"""
if indicator_df.empty or not self.trade_history:
return
# 建立时间戳到Bar索引的映射O(n)复杂度)
time_to_index = {
bar.datetime: i for i, bar in enumerate(self.all_bars)
}
# 遍历所有交易记录
for trade in self.trade_history:
# 仅对开仓交易赋值(保持与原逻辑一致)
if not trade.is_open_trade:
continue
# 查找成交时间对应的Bar索引
bar_idx = time_to_index.get(trade.fill_time) - 1
if bar_idx is None:
print(f"警告: Trade {trade.order_id} 的fill_time {trade.fill_time} 未找到对应Bar")
continue
# 赋值指标字典(保持与原逻辑一致)
trade.indicator_dict = indicator_df.iloc[bar_idx].to_dict()
def get_backtest_results(self) -> Dict[str, Any]:
"""
返回回测结果数据,供结果分析模块使用。

View File

@@ -90,7 +90,7 @@ class ExecutionSimulator:
self.pending_orders[order.id] = order
return order
def process_pending_orders(self, current_bar: Bar, indicator_dict: Dict[str, float]):
def process_pending_orders(self, current_bar: Bar):
order_ids_to_process = list(self.pending_orders.keys())
for order_id in order_ids_to_process:
@@ -154,10 +154,10 @@ class ExecutionSimulator:
trade = self._execute_single_order(order, current_bar)
if trade:
self.trade_log.append(trade)
if trade.is_open_trade:
self.indicator_dict = indicator_dict
elif trade.is_close_trade:
trade.indicator_dict = self.indicator_dict.copy()
# if trade.is_open_trade:
# self.indicator_dict = indicator_dict
# elif trade.is_close_trade:
# trade.indicator_dict = self.indicator_dict.copy()
def _execute_single_order(self, order: Order, current_bar: Bar) -> Optional[Trade]:

View File

@@ -44,7 +44,7 @@ INDICATOR_LIST = [
ADX(240),
BollingerBandwidth(10, nbdev=1.5),
BollingerBandwidth(20, 2.0),
BollingerBandwidth(50, nbdev=2.5),
BollingerBandwidth(50, 2.5),
PriceRangeToVolatilityRatio(3, 5),
PriceRangeToVolatilityRatio(3, 14),
PriceRangeToVolatilityRatio(3, 21),
@@ -54,12 +54,16 @@ INDICATOR_LIST = [
PriceRangeToVolatilityRatio(21, 5),
PriceRangeToVolatilityRatio(21, 14),
PriceRangeToVolatilityRatio(21, 21),
# ImpulseCandleConviction(3, 1),
# RelativeVolumeInWindow(3, 5),
# RelativeVolumeInWindow(3, 14),
# RelativeVolumeInWindow(3, 21),
# RelativeVolumeInWindow(3, 30),
# RelativeVolumeInWindow(3, 40),
# ZScoreATR(7, 100),
# ZScoreATR(14, 100),
ImpulseCandleConviction(3, 1),
ZScoreATR(7, 100),
ZScoreATR(14, 100),
FFTTrendStrength(46, 2, 23),
FFTTrendStrength(46, 1, 23),
AtrVolatility(7),
AtrVolatility(14),
AtrVolatility(21),
AtrVolatility(230),
FFTPhaseShift(),
VolatilitySkew(),
VolatilityTrendRelationship()
]

View File

@@ -4,19 +4,21 @@ from typing import List, Union, Tuple, Optional
import numpy as np
import talib
from numpy.lib._stride_tricks_impl import sliding_window_view
from scipy import stats
from src.indicators.base_indicators import Indicator
class Empty(Indicator, ABC):
def get_values(self, close: np.array, open: np.array, high: np.array, low: np.array, volume: np.array):
return []
def is_condition_met(self,
close: np.array,
open: np.array,
high: np.array,
low: np.array,
volume: np.array):
close: np.array,
open: np.array,
high: np.array,
low: np.array,
volume: np.array):
return True
def get_name(self):
@@ -402,7 +404,7 @@ class PriceRangeToVolatilityRatio(Indicator):
return ratio
def _rolling_max(self,arr: np.array, window: int) -> np.array:
def _rolling_max(self, arr: np.array, window: int) -> np.array:
if len(arr) < window:
return np.full_like(arr, np.nan)
@@ -591,15 +593,17 @@ class ROC_MA(Indicator):
"""
return f"roc_ma_{self.roc_window}_{self.ma_window}"
from numpy.lib.stride_tricks import sliding_window_view
class ZScoreATR(Indicator):
def __init__(
self,
atr_window: int = 14,
z_window: int = 100,
down_bound: float = None,
up_bound: float = None,
self,
atr_window: int = 14,
z_window: int = 100,
down_bound: float = None,
up_bound: float = None,
):
super().__init__(down_bound, up_bound)
self.atr_window = atr_window
@@ -616,7 +620,7 @@ class ZScoreATR(Indicator):
# Step 2: 只对有效区域计算 z-score
start_idx = self.atr_window - 1 # ATR 从这里开始非 NaN
valid_atr = atr[start_idx:] # shape: (n - start_idx,)
valid_atr = atr[start_idx:] # shape: (n - start_idx,)
valid_n = len(valid_atr)
if valid_n < self.z_window:
@@ -627,8 +631,8 @@ class ZScoreATR(Indicator):
windows = sliding_window_view(valid_atr, window_shape=self.z_window)
# Step 4: 向量化计算均值和标准差(沿窗口轴)
means = np.mean(windows, axis=1) # shape: (M,)
stds = np.std(windows, axis=1, ddof=0) # shape: (M,)
means = np.mean(windows, axis=1) # shape: (M,)
stds = np.std(windows, axis=1, ddof=0) # shape: (M,)
# Step 5: 计算 z-score当前值是窗口最后一个元素
current_vals = valid_atr[self.z_window - 1:] # 对齐窗口末尾
@@ -648,3 +652,598 @@ class ZScoreATR(Indicator):
def get_name(self):
return f"z_atr_{self.atr_window}_{self.z_window}"
from scipy.signal import stft
class FFTTrendStrength(Indicator):
"""
傅里叶趋势强度指标 (FFT_TrendStrength)
该指标通过短时傅里叶变换(STFT)计算低频能量占比,量化趋势强度。
低频能量占比越高,趋势越强;当该值在不同波动率环境下变化时,
往往预示策略转折点。
"""
def __init__(
self,
spectral_window: int = 46, # 2天×23根/天
low_freq_days: float = 2.0, # 低频定义下限(天)
bars_per_day: int = 23, # 每日K线数量
down_bound: float = None,
up_bound: float = None,
shift_window: int = 0,
):
"""
初始化 FFT_TrendStrength 指标。
Args:
spectral_window (int): STFT窗口大小(根K线)
low_freq_days (float): 低频定义下限(天)
bars_per_day (int): 每日K线数量
down_bound (float): (可选) 用于条件判断的下轨
up_bound (float): (可选) 用于条件判断的上轨
shift_window (int): (可选) 指标值的时间偏移
"""
super().__init__(down_bound, up_bound)
self.spectral_window = spectral_window
self.low_freq_days = low_freq_days
self.bars_per_day = bars_per_day
self.shift_window = shift_window
def get_values(
self,
close: np.array,
open: np.array,
high: np.array,
low: np.array,
volume: np.array,
) -> np.array:
"""
计算傅里叶趋势强度值。
Args:
close (np.array): 收盘价列表
其他参数保留接口兼容性本指标仅使用close
Returns:
np.array: 趋势强度值列表(0~1)数据不足时为NaN
"""
n = len(close)
trend_strengths = np.full(n, np.nan)
# 验证最小数据要求
min_required = self.spectral_window + 5
if n < min_required:
return trend_strengths
# 频率边界计算
low_freq_bound = 1.0 / self.low_freq_days if self.low_freq_days > 0 else float('inf')
# 为每个时间点计算趋势强度
for i in range(min_required - 1, n):
# 获取窗口内数据
window_data = close[max(0, i - self.spectral_window + 1): i + 1]
# 跳过数据不足的窗口
if len(window_data) < self.spectral_window:
continue
# 价格归一化
window_mean = np.mean(window_data)
window_std = np.std(window_data)
if window_std < 1e-8:
continue
normalized = (window_data - window_mean) / window_std
try:
# STFT计算
f, t, Zxx = stft(
normalized,
fs=self.bars_per_day,
nperseg=self.spectral_window,
noverlap=max(0, self.spectral_window // 2),
boundary=None,
padded=False
)
# 频率过滤
max_freq = self.bars_per_day / 2
valid_mask = (f >= 0) & (f <= max_freq)
if not np.any(valid_mask):
continue
f = f[valid_mask]
Zxx = Zxx[valid_mask, :]
if Zxx.shape[1] == 0:
continue
# 能量计算
current_energy = np.abs(Zxx[:, -1]) ** 2
low_freq_mask = f < low_freq_bound
high_freq_mask = f > 1.0 # 高频: <1天周期
low_energy = np.sum(current_energy[low_freq_mask]) if np.any(low_freq_mask) else 0.0
high_energy = np.sum(current_energy[high_freq_mask]) if np.any(high_freq_mask) else 0.0
total_energy = low_energy + high_energy + 1e-8
trend_strength = low_energy / total_energy
trend_strengths[i] = np.clip(trend_strength, 0.0, 1.0)
except Exception:
continue
# 应用时间偏移
if self.shift_window > 0 and len(trend_strengths) > self.shift_window:
trend_strengths = np.roll(trend_strengths, -self.shift_window)
trend_strengths[-self.shift_window:] = np.nan
return trend_strengths
def get_name(self) -> str:
return f"fft_trend_{self.spectral_window}_{self.low_freq_days}"
class AtrVolatility(Indicator):
"""
波动率环境识别指标 (VolatilityRegime)
该指标识别当前市场处于高波动还是低波动环境,对策略转折点
有强预测能力。在低波动环境下,趋势信号往往失效转为反转。
"""
def __init__(
self,
vol_window: int = 23, # 波动率计算窗口
down_bound: float = None,
up_bound: float = None,
shift_window: int = 0,
):
"""
初始化 VolatilityRegime 指标。
Args:
vol_window (int): ATR波动率计算窗口
high_vol_threshold (float): 高波动阈值(%),高于此值为高波动环境
low_vol_threshold (float): 低波动阈值(%),低于此值为低波动环境
down_bound (float): (可选) 用于条件判断的下轨
up_bound (float): (可选) 用于条件判断的上轨
shift_window (int): (可选) 指标值的时间偏移
"""
super().__init__(down_bound, up_bound)
self.vol_window = vol_window
self.shift_window = shift_window
def get_values(
self,
close: np.array,
open: np.array,
high: np.array,
low: np.array,
volume: np.array,
) -> np.array:
"""
计算波动率环境指标。
返回值含义:
- 1.0: 高波动环境 (趋势策略有效)
- 0.0: 中波动环境 (谨慎)
- -1.0: 低波动环境 (反转策略有效)
Args:
close (np.array): 收盘价列表
high (np.array): 最高价列表
low (np.array): 最低价列表
其他参数保留接口兼容性
Returns:
np.array: 波动率环境标识数据不足时为NaN
"""
n = len(close)
regimes = np.full(n, np.nan)
# 验证最小数据要求
if n < self.vol_window + 1:
return regimes
# 计算ATR
try:
atr = talib.ATR(high, low, close, timeperiod=self.vol_window)
except Exception:
return regimes
# 计算标准化波动率 (%)
volatility = (atr / close) * 100
return volatility
def get_name(self) -> str:
return f"atr_volume_{self.vol_window}"
class FFTPhaseShift(Indicator):
"""
傅里叶相位偏移指标 (FFT_PhaseShift)
该指标检测频域中主导频率的相位偏移,相位突变往往预示市场
趋势的转折点。特别适用于捕捉低波动环境下的价格极端位置。
"""
def __init__(
self,
spectral_window: int = 46, # 2天×23根/天
dominant_freq_bound: float = 0.5, # 主导频率上限(cycles/day)
phase_shift_threshold: float = 1.0, # 相位偏移阈值(弧度)
bars_per_day: int = 23, # 每日K线数量
down_bound: float = None,
up_bound: float = None,
shift_window: int = 0,
):
"""
初始化 FFT_PhaseShift 指标。
Args:
spectral_window (int): STFT窗口大小(根K线)
dominant_freq_bound (float): 主导频率上限(cycles/day)
phase_shift_threshold (float): 相位偏移阈值(弧度)
bars_per_day (int): 每日K线数量
down_bound (float): (可选) 用于条件判断的下轨
up_bound (float): (可选) 用于条件判断的上轨
shift_window (int): (可选) 指标值的时间偏移
"""
super().__init__(down_bound, up_bound)
self.spectral_window = spectral_window
self.dominant_freq_bound = dominant_freq_bound
self.phase_shift_threshold = phase_shift_threshold
self.bars_per_day = bars_per_day
self.shift_window = shift_window
def get_values(
self,
close: np.array,
open: np.array,
high: np.array,
low: np.array,
volume: np.array,
) -> np.array:
"""
计算傅里叶相位偏移值。
返回值含义:
- 1.0: 相位正向偏移(可能预示上涨转折)
- -1.0: 相位负向偏移(可能预示下跌转折)
- 0.0: 无显著相位偏移
Args:
close (np.array): 收盘价列表
其他参数保留接口兼容性本指标仅使用close
Returns:
np.array: 相位偏移标识数据不足时为NaN
"""
n = len(close)
phase_shifts = np.full(n, np.nan)
# 验证最小数据要求
min_required = self.spectral_window + 5
if n < min_required:
return phase_shifts
# 为每个时间点计算相位偏移
prev_phase = None
for i in range(min_required - 1, n):
# 获取窗口内数据
window_data = close[max(0, i - self.spectral_window + 1): i + 1]
if len(window_data) < self.spectral_window:
continue
# 价格归一化
window_mean = np.mean(window_data)
window_std = np.std(window_data)
if window_std < 1e-8:
continue
normalized = (window_data - window_mean) / window_std
try:
# STFT计算
f, t, Zxx = stft(
normalized,
fs=self.bars_per_day,
nperseg=self.spectral_window,
noverlap=max(0, self.spectral_window // 2),
boundary=None,
padded=False
)
# 频率过滤
max_freq = self.bars_per_day / 2
valid_mask = (f >= 0) & (f <= max_freq)
if not np.any(valid_mask):
continue
f = f[valid_mask]
Zxx = Zxx[valid_mask, :]
if Zxx.shape[1] < 2: # 需要至少两个时间点计算相位变化
continue
# 计算相位
phases = np.angle(Zxx[:, -1])
prev_phases = np.angle(Zxx[:, -2])
# 找出主导频率(低频)
low_freq_mask = f < self.dominant_freq_bound
if not np.any(low_freq_mask):
continue
# 计算主导频率的相位差
dominant_idx = np.argmax(np.abs(Zxx[low_freq_mask, -1]))
current_phase = phases[low_freq_mask][dominant_idx]
prev_dominant_phase = prev_phases[low_freq_mask][dominant_idx]
# 计算相位差(考虑2π周期性)
phase_diff = current_phase - prev_dominant_phase
phase_diff = (phase_diff + np.pi) % (2 * np.pi) - np.pi
# 确定相位偏移方向
if np.abs(phase_diff) > self.phase_shift_threshold:
phase_shifts[i] = 1.0 if phase_diff > 0 else -1.0
else:
phase_shifts[i] = 0.0
prev_phase = current_phase
except Exception:
continue
# 应用时间偏移
if self.shift_window > 0 and len(phase_shifts) > self.shift_window:
phase_shifts = np.roll(phase_shifts, -self.shift_window)
phase_shifts[-self.shift_window:] = np.nan
return phase_shifts
def get_name(self) -> str:
return f"fft_phase_{self.spectral_window}_{self.dominant_freq_bound}"
class VolatilitySkew(Indicator):
"""
波动率偏斜指标 (VolatilitySkew)
该指标测量近期波动率分布的偏斜程度,正偏斜表示波动率上升趋势,
负偏斜表示波动率下降趋势。波动率偏斜的变化往往预示策略逻辑
的转折点,特别是在低波动环境向高波动环境转换时。
"""
def __init__(
self,
vol_window: int = 20, # 单期波动率计算窗口
skew_window: int = 60, # 偏斜计算窗口
down_bound: float = None,
up_bound: float = None,
shift_window: int = 0,
):
"""
初始化 VolatilitySkew 指标。
Args:
vol_window (int): ATR波动率计算窗口
skew_window (int): 偏斜计算窗口
positive_threshold (float): 正偏斜阈值
negative_threshold (float): 负偏斜阈值
down_bound (float): (可选) 用于条件判断的下轨
up_bound (float): (可选) 用于条件判断的上轨
shift_window (int): (可选) 指标值的时间偏移
"""
super().__init__(down_bound, up_bound)
self.vol_window = vol_window
self.skew_window = skew_window
self.shift_window = shift_window
def get_values(
self,
close: np.array,
open: np.array,
high: np.array,
low: np.array,
volume: np.array,
) -> np.array:
"""
计算波动率偏斜指标。
返回值含义:
- 1.0: 正偏斜(波动率上升趋势,可能预示高波动环境到来)
- -1.0: 负偏斜(波动率下降趋势,可能预示低波动环境到来)
- 0.0: 无显著偏斜
Args:
close (np.array): 收盘价列表
high (np.array): 最高价列表
low (np.array): 最低价列表
其他参数保留接口兼容性
Returns:
np.array: 波动率偏斜标识数据不足时为NaN
"""
n = len(close)
skews = np.full(n, np.nan)
# 验证最小数据要求
if n < self.vol_window + self.skew_window:
return skews
# 计算ATR
try:
atr = talib.ATR(high, low, close, timeperiod=self.vol_window)
except Exception:
return skews
# 计算标准化波动率 (%)
volatility = (atr / close) * 100
# 计算滚动偏斜
for i in range(self.vol_window + self.skew_window - 1, n):
window_vol = volatility[i - self.skew_window + 1: i + 1]
valid_vol = window_vol[~np.isnan(window_vol)]
if len(valid_vol) < self.skew_window * 0.7: # 要求70%有效数据
continue
# 计算偏斜
skew_value = stats.skew(valid_vol)
skews[i] = skew_value
return skews
def get_name(self) -> str:
return f"vol_skew_{self.vol_window}_{self.skew_window}"
import numpy as np
import talib
from src.indicators.base_indicators import Indicator
class VolatilityTrendRelationship(Indicator):
"""
精准修复版:波动率-趋势关系指标
仅修复NaN问题
1. 保留talib的ATR计算性能和稳定性更优
2. 修复std_val计算中的NaN传播
3. 添加严格的NaN处理确保100%数据有效性
4. 保持原始物理逻辑不变
核心修复点:
- 在计算标准差前过滤NaN值
- 为平滑后的序列提供安全回退值
- 确保所有中间步骤处理NaN
"""
def __init__(
self,
vol_window: int = 20, # 波动率计算窗口
price_lag: int = 3, # 价格自相关滞后
ma_window: int = 5, # 平滑窗口
down_bound: float = None,
up_bound: float = None,
shift_window: int = 0,
):
super().__init__(down_bound, up_bound)
self.vol_window = vol_window
self.price_lag = price_lag
self.ma_window = ma_window
self.shift_window = shift_window
def get_values(
self,
close: np.array,
open: np.array,
high: np.array,
low: np.array,
volume: np.array,
) -> np.array:
n = len(close)
relationship = np.full(n, np.nan)
# 验证最小数据要求
min_required = max(self.vol_window, self.price_lag, self.ma_window) + 5
if n < min_required:
return relationship
# 1. 计算标准化波动率 (使用talib保持性能)
try:
atr = talib.ATR(high, low, close, timeperiod=self.vol_window)
volatility = (atr / close) * 100
except Exception:
return relationship
# 2. 计算波动率变化率 (安全处理除零)
vol_change = np.zeros(n)
for i in range(1, n):
if volatility[i - 1] > 1e-8:
vol_change[i] = (volatility[i] - volatility[i - 1]) / volatility[i - 1]
else:
vol_change[i] = 0.0
# 3. 计算价格自相关 (安全实现)
returns = np.diff(close, prepend=close[0]) / (close + 1e-8)
autocorr = np.zeros(n)
for i in range(self.price_lag, n):
if i < self.price_lag * 2:
continue
window_returns = returns[i - self.price_lag * 2:i + 1]
valid_returns = window_returns[~np.isnan(window_returns)]
if len(valid_returns) < self.price_lag * 1.5:
continue
# 计算自相关
lagged = valid_returns[:-self.price_lag]
current = valid_returns[self.price_lag:]
if len(lagged) == 0 or len(current) == 0:
continue
mean_lagged = np.mean(lagged)
mean_current = np.mean(current)
numerator = np.sum((lagged - mean_lagged) * (current - mean_current))
denom_lagged = np.sum((lagged - mean_lagged) ** 2)
denom_current = np.sum((current - mean_current) ** 2)
if denom_lagged > 1e-8 and denom_current > 1e-8:
autocorr[i] = numerator / np.sqrt(denom_lagged * denom_current)
# 4. 计算核心关系指标
raw_relationship = vol_change * autocorr
# 5. 平滑处理 (处理NaN)
smoothed_relationship = np.full(n, np.nan)
for i in range(self.ma_window - 1, n):
window = raw_relationship[max(0, i - self.ma_window + 1):i + 1]
valid_window = window[~np.isnan(window)]
if len(valid_window) > 0:
smoothed_relationship[i] = np.mean(valid_window)
# 6. 修复关键问题std_val计算
# 获取有效数据范围
valid_mask = ~np.isnan(smoothed_relationship[min_required - 1:])
if np.any(valid_mask):
valid_values = smoothed_relationship[min_required - 1:][valid_mask]
std_val = np.std(valid_values) if len(valid_values) > 1 else 1.0
else:
std_val = 1.0 # 安全回退值
# 确保std_val不为零
std_val = max(std_val, 1e-8)
# 7. 标准化到稳定范围 (-1, 1)
for i in range(n):
if not np.isnan(smoothed_relationship[i]):
relationship[i] = smoothed_relationship[i] / (std_val * 3.0)
else:
relationship[i] = 0.0 # 安全默认值
# 8. 截断到合理范围
relationship = np.clip(relationship, -1.0, 1.0)
# 应用时间偏移
if self.shift_window > 0 and len(relationship) > self.shift_window:
relationship = np.roll(relationship, -self.shift_window)
relationship[-self.shift_window:] = np.nan
return relationship
def get_name(self) -> str:
return f"vol_trend_rel_{self.vol_window}_{self.price_lag}"

File diff suppressed because one or more lines are too long

View File

@@ -1,177 +1,249 @@
import numpy as np
import talib
from collections import deque
from typing import Optional, Any, List, Dict
import bisect
from src.core_data import Bar, Order
from src.indicators.base_indicators import Indicator
from src.indicators.indicators import Empty
from src.strategies.base_strategy import Strategy
class TVDZScoreStrategy(Strategy):
# =============================================================================
# 策略实现 (Dual-Mode Kalman Strategy V4 - 滚动窗口修正版)
# =============================================================================
class DualModeKalmanStrategy(Strategy):
"""
内嵌 TVD (Condat 算法) + Z-Score ATR 的趋势突破策略。
无任何外部依赖(如 pytv纯 NumPy 实现。
V4版本更新:
1. 【根本性修正】修复了V3版本中因错误使用全局历史数据而引入的前瞻性偏差和
路径依赖问题。
2. 【正确实现】现在的数据结构严格、精确地只维护当前滚动窗口(vol_lookback)
内的数据,确保了策略的可重复性和逻辑正确性。
3. 通过bisect库在保持100%滚动窗口精度的前提下,实现了高效的百分位计算,
避免了在每个bar上都进行暴力排序。
"""
def __init__(
self,
context: Any,
main_symbol: str,
enable_log: bool,
trade_volume: int,
tvd_lam: float = 50.0,
atr_window: int = 14,
z_window: int = 100,
vol_threshold: float = -0.5,
entry_threshold_atr: float = 3.0,
stop_atr_multiplier: float = 3.0,
order_direction: Optional[List[str]] = None,
self,
context: Any,
main_symbol: str,
enable_log: bool,
trade_volume: int,
# ... (所有策略参数与V2版本完全相同) ...
strategy_mode: str = 'TREND',
kalman_process_noise: float = 0.01,
kalman_measurement_noise: float = 0.5,
atr_period: int = 20,
vol_lookback: int = 100,
vol_percentile_threshold: float = 25.0,
entry_threshold_atr: float = 2.5,
initial_stop_atr_multiplier: float = 2.0,
structural_stop_atr_multiplier: float = 2.5,
order_direction: Optional[List[str]] = None,
indicators: Optional[List[Indicator]] = None,
):
super().__init__(context, main_symbol, enable_log)
# ... (参数赋值与V2版本完全相同) ...
if order_direction is None: order_direction = ['BUY', 'SELL']
self.strategy_mode = strategy_mode.upper()
self.trade_volume = trade_volume
self.order_direction = order_direction or ["BUY", "SELL"]
self.tvd_lam = tvd_lam
self.atr_window = atr_window
self.z_window = z_window
self.vol_threshold = vol_threshold
self.atr_period = atr_period
self.vol_lookback = vol_lookback
self.vol_percentile_threshold = vol_percentile_threshold
self.entry_threshold_atr = entry_threshold_atr
self.stop_atr_multiplier = stop_atr_multiplier
self.initial_stop_atr_multiplier = initial_stop_atr_multiplier
self.structural_stop_atr_multiplier = structural_stop_atr_multiplier
self.order_direction = order_direction
# --- 【修正后的数据结构】 ---
# 1. 严格限定长度的deque用于维护滚动窗口的原始序列
self._vol_history_queue: deque = deque(maxlen=self.vol_lookback)
# 2. 一个普通list我们将手动维护其有序性并确保其内容与deque完全同步
self._sorted_vol_history: List[float] = []
self.Q = kalman_process_noise
self.R = kalman_measurement_noise
self.P = 1.0
self.x_hat = 0.0
self.kalman_initialized = False
self.position_meta: Dict[str, Any] = self.context.load_state()
self.main_symbol = main_symbol
self.order_id_counter = 0
self.log(f"TVDZScoreStrategy Initialized | λ={tvd_lam}, VolThresh={vol_threshold}")
if indicators is None: indicators = [Empty(), Empty()]
self.indicators = indicators
@staticmethod
def _tvd_condat(y, lam):
"""Condat's O(N) TVD algorithm."""
n = y.size
if n == 0:
return y.copy()
x = y.astype(np.float64)
k = 0
k0 = 0
vmin = x[0] - lam
vmax = x[0] + lam
for i in range(1, n):
if x[i] < vmin:
while k < i:
x[k] = vmin
k += 1
k0 = i
vmin = x[i] - lam
vmax = x[i] + lam
elif x[i] > vmax:
while k < i:
x[k] = vmax
k += 1
k0 = i
vmin = x[i] - lam
vmax = x[i] + lam
else:
vmin = max(vmin, x[i] - lam)
vmax = min(vmax, x[i] + lam)
if vmin > vmax:
k = k0
s = np.sum(x[k0:i+1])
s /= (i - k0 + 1)
x[k0:i+1] = s
k = i + 1
k0 = k
if k0 < n:
vmin = x[k0] - lam
vmax = x[k0] + lam
while k < n:
x[k] = vmin
k += 1
return x
self.log(f"DualModeKalmanStrategy V4 (Corrected Rolling Window) Initialized.")
def _compute_zscore_atr_last(self, high, low, close) -> float:
n = len(close)
min_req = self.atr_window + self.z_window - 1
if n < min_req:
return np.nan
start = max(0, n - (self.z_window + self.atr_window))
seg_h, seg_l, seg_c = high[start:], low[start:], close[start:]
atr_full = talib.ATR(seg_h, seg_l, seg_c, timeperiod=self.atr_window)
atr_valid = atr_full[self.atr_window - 1:]
if len(atr_valid) < self.z_window:
return np.nan
window_atr = atr_valid[-self.z_window:]
mu = np.mean(window_atr)
sigma = np.std(window_atr)
last_atr = window_atr[-1]
return (last_atr - mu) / sigma if sigma > 1e-12 else 0.0
def on_init(self):
super().on_init()
self.cancel_all_pending_orders(self.main_symbol)
self.position_meta = self.context.load_state()
# 初始化时清空数据结构
self._vol_history_queue.clear()
self._sorted_vol_history.clear()
def on_open_bar(self, open_price: float, symbol: str):
self.symbol = symbol
bar_history = self.get_bar_history()
if len(bar_history) < max(100, self.atr_window + self.z_window):
return
# 确保有足够的数据来填满第一个完整的窗口
if len(bar_history) < self.vol_lookback + self.atr_period: return
closes = np.array([b.close for b in bar_history], dtype=np.float64)
highs = np.array([b.high for b in bar_history], dtype=np.float64)
lows = np.array([b.low for b in bar_history], dtype=np.float64)
highs = np.array([b.high for b in bar_history], dtype=float)
lows = np.array([b.low for b in bar_history], dtype=float)
closes = np.array([b.close for b in bar_history], dtype=float)
current_atr = talib.ATR(highs, lows, closes, self.atr_period)[-1]
# === TVD 平滑 ===
tvd_prices = self._tvd_condat(closes, self.tvd_lam)
tvd_price = tvd_prices[-1]
last_close = closes[-1]
if last_close <= 0: return
current_normalized_atr = current_atr / last_close
# === Z-Score ATR ===
current_atr = talib.ATR(highs, lows, closes, timeperiod=self.atr_window)[-1]
if current_atr <= 0:
return
# --- 【核心修正:正确的滚动窗口维护】 ---
# 1. 如果窗口已满deque会自动从左侧弹出一个旧值。我们需要捕捉这个值。
oldest_val = None
if len(self._vol_history_queue) == self.vol_lookback:
oldest_val = self._vol_history_queue[0]
deviation = closes[-1] - tvd_price
deviation_in_atr = deviation / current_atr
# 2. 将新值添加到deque的右侧
self._vol_history_queue.append(current_normalized_atr)
# 3. 更新有序列表使其与deque的状态严格同步
if oldest_val is not None:
# a. 先从有序列表中移除旧值
# 由于浮点数精度问题直接remove可能不安全我们使用bisect查找并移除
# 这是一个O(log N) + O(N)的操作,但远快于完全重排
idx_to_remove = bisect.bisect_left(self._sorted_vol_history, oldest_val)
if idx_to_remove < len(self._sorted_vol_history) and abs(
self._sorted_vol_history[idx_to_remove] - oldest_val) < 1e-9:
self._sorted_vol_history.pop(idx_to_remove)
else:
# 备用方案如果bisect找不到理论上不应该则暴力移除
try:
self._sorted_vol_history.remove(oldest_val)
except ValueError:
pass # 如果值不存在,忽略
# b. 将新值高效地插入到有序列表中
bisect.insort_left(self._sorted_vol_history, current_normalized_atr)
# 检查窗口是否已填满
if len(self._sorted_vol_history) < self.vol_lookback: return
# ... (卡尔曼滤波器计算部分保持不变) ...
if not self.kalman_initialized: self.x_hat = closes[-1]
self.kalman_initialized = True
x_hat_minus = self.x_hat
P_minus = self.P + self.Q
K = P_minus / (P_minus + self.R)
self.x_hat = x_hat_minus + K * (closes[-1] - x_hat_minus)
self.P = (1 - K) * P_minus
kalman_price = self.x_hat
position_volume = self.get_current_positions().get(self.symbol, 0)
# ... (持仓同步逻辑不变) ...
if position_volume != 0:
self.manage_open_position(position_volume, bar_history[-1], current_atr, tvd_price)
self.manage_open_position(position_volume, bar_history[-1], current_atr, kalman_price)
return
# --- 使用精确的滚动窗口百分位阈值 ---
percentile_index = int(self.vol_percentile_threshold / 100.0 * (self.vol_lookback - 1))
vol_threshold = self._sorted_vol_history[percentile_index]
if current_normalized_atr < vol_threshold:
return
self.evaluate_entry_signal(bar_history[-1], kalman_price, current_atr)
def manage_open_position(self, volume: int, current_bar: Bar, current_atr: float, kalman_price: float):
# ... (此部分代码与上一版完全相同,保持不变) ...
meta = self.position_meta.get(self.symbol)
if not meta: return
initial_stop_price = meta['initial_stop_price']
if (volume > 0 and current_bar.low <= initial_stop_price) or \
(volume < 0 and current_bar.high >= initial_stop_price):
self.log(f"Initial Stop Loss hit at {initial_stop_price:.4f}")
self.close_position("CLOSE_LONG" if volume > 0 else "CLOSE_SHORT", abs(volume))
return
if self.strategy_mode == 'TREND':
if volume > 0:
stop_price = max(kalman_price - self.structural_stop_atr_multiplier * current_atr, initial_stop_price)
if current_bar.low <= stop_price:
self.log(f"TREND Mode: Structural Stop hit for LONG at {stop_price:.4f}")
self.close_position("CLOSE_LONG", abs(volume))
else:
stop_price = min(kalman_price + self.structural_stop_atr_multiplier * current_atr, initial_stop_price)
if current_bar.high >= stop_price:
self.log(f"TREND Mode: Structural Stop hit for SHORT at {stop_price:.4f}")
self.close_position("CLOSE_SHORT", abs(volume))
elif self.strategy_mode == 'REVERSION':
if volume > 0 and current_bar.high >= kalman_price:
self.log(f"REVERSION Mode: Take Profit for LONG as price reverts to Kalman line at {kalman_price:.4f}")
self.close_position("CLOSE_LONG", abs(volume))
elif volume < 0 and current_bar.low <= kalman_price:
self.log(f"REVERSION Mode: Take Profit for SHORT as price reverts to Kalman line at {kalman_price:.4f}")
self.close_position("CLOSE_SHORT", abs(volume))
def evaluate_entry_signal(self, current_bar: Bar, kalman_price: float, current_atr: float):
# ... (此部分代码与上一版完全相同,保持不变) ...
deviation = current_bar.close - kalman_price
if current_atr <= 0: return
deviation_in_atr = deviation / current_atr
direction = None
if "BUY" in self.order_direction and deviation_in_atr > self.entry_threshold_atr:
direction = "BUY"
elif "SELL" in self.order_direction and deviation_in_atr < -self.entry_threshold_atr:
direction = "SELL"
if self.strategy_mode == 'TREND':
if "BUY" in self.order_direction and deviation_in_atr > self.entry_threshold_atr:
direction = "BUY"
elif "SELL" in self.order_direction and deviation_in_atr < -self.entry_threshold_atr:
direction = "SELL"
elif self.strategy_mode == 'REVERSION':
if "SELL" in self.order_direction and deviation_in_atr > self.entry_threshold_atr:
direction = "SELL"
elif "BUY" in self.order_direction and deviation_in_atr < -self.entry_threshold_atr:
direction = "BUY"
if direction:
self.log(f"Signal Fired | Dir: {direction}, Dev: {deviation_in_atr:.2f} ATR")
entry_price = closes[-1]
stop_loss = (
entry_price - self.stop_atr_multiplier * current_atr
if direction == "BUY"
else entry_price + self.stop_atr_multiplier * current_atr
)
meta = {"entry_price": entry_price, "stop_loss": stop_loss}
self.log(f"{self.strategy_mode} Mode: Entry Signal {direction}. Deviation: {deviation_in_atr:.2f} ATRs.")
entry_price = current_bar.close
stop_loss_price = entry_price - self.initial_stop_atr_multiplier * current_atr if direction == "BUY" else entry_price + self.initial_stop_atr_multiplier * current_atr
meta = {'entry_price': entry_price, 'initial_stop_price': stop_loss_price, 'direction': direction}
self.send_market_order(direction, self.trade_volume, "OPEN", meta)
self.save_state(self.position_meta)
def manage_open_position(self, volume: int, current_bar: Bar, current_atr: float, tvd_price: float):
meta = self.position_meta.get(self.symbol)
if not meta:
return
stop_loss = meta["stop_loss"]
if (volume > 0 and current_bar.low <= stop_loss) or (volume < 0 and current_bar.high >= stop_loss):
self.log(f"Stop Loss Hit at {stop_loss:.4f}")
self.close_position("CLOSE_LONG" if volume > 0 else "CLOSE_SHORT", abs(volume))
def close_position(self, direction: str, volume: int):
self.send_market_order(direction, volume, offset="CLOSE")
if self.symbol in self.position_meta:
del self.position_meta[self.symbol]
self.position_meta = {}
self.save_state(self.position_meta)
def send_market_order(self, direction: str, volume: int, offset: str, meta: Optional[Dict] = None):
if offset == "OPEN" and meta:
self.position_meta[self.symbol] = meta
if offset == "OPEN" and meta: self.position_meta[self.symbol] = meta
order_id = f"{self.symbol}_{direction}_MARKET_{self.order_id_counter}"
self.order_id_counter += 1
order = Order(id=order_id, symbol=self.symbol, direction=direction, volume=volume,
price_type="MARKET", submitted_time=self.get_current_time(), offset=offset)
order = Order(id=order_id, symbol=self.symbol, direction=direction, volume=volume, price_type="MARKET",
submitted_time=self.get_current_time(), offset=offset)
self.send_order(order)
def send_limit_order(self, limit_price: float, direction: str, volume: int, offset: str,
meta: Optional[Dict] = None):
if offset == "OPEN" and meta: self.position_meta[self.symbol] = meta
order_id = f"{self.symbol}_{direction}_MARKET_{self.order_id_counter}"
self.order_id_counter += 1
order = Order(id=order_id, symbol=self.symbol, direction=direction, volume=volume, price_type="LIMIT",
submitted_time=self.get_current_time(), offset=offset, limit_price=limit_price)
self.send_order(order)
def on_rollover(self, old_symbol: str, new_symbol: str):
super().on_rollover(old_symbol, new_symbol)
self.position_meta = {}
self.log("Rollover: Strategy state reset.")
self.kalman_initialized = False
self._sorted_vol_history.clear()
self.log("Rollover detected. All strategy states have been reset.")

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -1,21 +1,32 @@
import numpy as np
import talib
from datetime import datetime
from typing import Optional, Any, List
from scipy.signal import stft
from datetime import datetime, timedelta
from typing import Optional, Any, List, Dict
from src.core_data import Bar, Order
from src.indicators.base_indicators import Indicator
from src.indicators.indicators import Empty, NormalizedATR, AtrVolatility
from src.strategies.base_strategy import Strategy
# =============================================================================
# 瞬态冲击回调与ATR波幅止盈策略 (V4 - 统一信号与Close价核心逻辑版)
# 策略实现 (SpectralTrendStrategy)
# =============================================================================
class TransientShockATRStrategy(Strategy):
class SpectralTrendStrategy(Strategy):
"""
V4版本更新 (根据专业建议重构):
1. 【核心变更】信号识别逻辑统一:不再区分跳空和大阳线,仅使用 (上一bar.close - 上上bar.close) 的绝对波幅作为市场冲击的唯一衡量标准。
2. 【核心变更】信号计算基于Close价所有信号计算均基于更稳健的收盘价避免High/Low的噪音干扰。
3. 【参数简化】移除了冗余的 signal_gap_pct 参数,使策略更简洁、更易于优化。
频域能量相变策略 - 捕获肥尾趋势
核心哲学:
1. 显式傅里叶变换: 直接分离低频(趋势)、高频(噪音)能量
2. 相变临界点: 仅当低频能量占比 > 阈值时入场
3. 低频交易: 每月仅2-5次信号持仓数日捕获肥尾
4. 完全参数化: 无硬编码,适配任何市场时间结构
参数说明:
- bars_per_day: 市场每日K线数量 (e.g., 23 for 15min US markets)
- low_freq_days: 低频定义下限 (天), 默认2.0
- high_freq_days: 高频定义上限 (天), 默认1.0
"""
def __init__(
@@ -24,197 +35,247 @@ class TransientShockATRStrategy(Strategy):
main_symbol: str,
enable_log: bool,
trade_volume: int,
# --- 【核心参数】---
atr_period: int = 20,
signal_move_atr_mult: float = 2.5, # 【变更】定义“市场冲击”的ATR倍数 (统一信号)
entry_pullback_pct: float = 0.5,
order_expiry_bars: int = 5,
initial_stop_atr_mult: float = 2.0,
exit_signal_atr_mult: float = 2.5,
# --- 【市场结构参数】 ---
bars_per_day: int = 23, # 关键: 适配23根/天的市场
# --- 【频域核心参数】 ---
spectral_window_days: float = 2.0, # STFT窗口大小(天)
low_freq_days: float = 2.0, # 低频下限(天)
high_freq_days: float = 1.0, # 高频上限(天)
trend_strength_threshold: float = 0.8, # 相变临界值
exit_threshold: float = 0.4, # 退出阈值
# --- 【持仓管理】 ---
max_hold_days: int = 10, # 最大持仓天数
# --- 其他 ---
order_direction: Optional[List[str]] = None,
indicators: Optional[List[Indicator]] = None,
model_indicator: Indicator = None,
):
super().__init__(context, main_symbol, enable_log)
if not (atr_period > 0 and signal_move_atr_mult > 0 and entry_pullback_pct > 0 and
initial_stop_atr_mult > 0 and exit_signal_atr_mult > 0):
raise ValueError("所有周期和倍数参数必须大于0")
if order_direction is None:
order_direction = ['BUY', 'SELL']
if indicators is None:
indicators = [Empty(), Empty()] # 保持兼容性
# --- 参数赋值 (完全参数化) ---
self.trade_volume = trade_volume
self.atr_period = atr_period
self.signal_move_atr_mult = signal_move_atr_mult # 【变更】使用新参数
self.entry_pullback_pct = entry_pullback_pct
self.order_expiry_bars = order_expiry_bars
self.initial_stop_atr_mult = initial_stop_atr_mult
self.exit_signal_atr_mult = exit_signal_atr_mult
self.bars_per_day = bars_per_day
self.spectral_window_days = spectral_window_days
self.low_freq_days = low_freq_days
self.high_freq_days = high_freq_days
self.trend_strength_threshold = trend_strength_threshold
self.exit_threshold = exit_threshold
self.max_hold_days = max_hold_days
self.order_direction = order_direction
if model_indicator is None:
model_indicator = Empty()
self.model_indicator = model_indicator
self.pending_order: Optional[dict] = None
self.position_entry_price: float = 0.0
self.initial_stop_price: float = 0.0
self.order_id_counter = 0
# --- 动态计算参数 ---
self.spectral_window = int(self.spectral_window_days * self.bars_per_day)
# 确保窗口大小为偶数 (STFT要求)
self.spectral_window = self.spectral_window if self.spectral_window % 2 == 0 else self.spectral_window + 1
def on_init(self):
"""策略初始化"""
self.log(f"🚀 Strategy On Init: Initializing {self.__class__.__name__} V4...")
self.cancel_all_pending_orders(self.main_symbol)
self.pending_order = None
self.position_entry_price = 0.0
self.initial_stop_price = 0.0
# 频率边界 (cycles/day)
self.low_freq_bound = 1.0 / self.low_freq_days if self.low_freq_days > 0 else float('inf')
self.high_freq_bound = 1.0 / self.high_freq_days if self.high_freq_days > 0 else 0.0
# --- 内部状态变量 ---
self.main_symbol = main_symbol
self.order_id_counter = 0
self.log("✅ Strategy Initialized and State Reset.")
self.indicators = indicators
self.entry_time = None # 入场时间
self.position_direction = None # 'LONG' or 'SHORT'
self.last_trend_strength = 0.0
self.last_dominant_freq = 0.0 # 主导周期(天)
self.log(f"SpectralTrendStrategy Initialized (bars/day={bars_per_day}, window={self.spectral_window} bars)")
def on_open_bar(self, open_price: float, symbol: str):
"""每根K线开盘时被调用"""
self.symbol = symbol
bar_history = self.get_bar_history()
current_time = self.get_current_time()
min_bars = self.atr_period + 5
if len(bar_history) < min_bars:
# 需要足够的数据 (STFT窗口 + 缓冲)
if len(bar_history) < self.spectral_window + 10:
if self.enable_log and len(bar_history) % 50 == 0:
self.log(f"Waiting for {len(bar_history)}/{self.spectral_window + 10} bars")
return
positions = self.get_current_positions()
position_volume = positions.get(self.symbol, 0)
position_volume = self.get_current_positions().get(self.symbol, 0)
highs = np.array([b.high for b in bar_history], dtype=float)
lows = np.array([b.low for b in bar_history], dtype=float)
# 获取历史价格 (使用完整历史)
closes = np.array([b.close for b in bar_history], dtype=float)
current_atr = talib.ATR(highs, lows, closes, timeperiod=self.atr_period)[-1]
# 【核心】计算频域趋势强度 (显式傅里叶)
trend_strength, dominant_freq = self.calculate_trend_strength(closes)
self.last_trend_strength = trend_strength
self.last_dominant_freq = dominant_freq
if not self.trading:
# 检查最大持仓时间 (防止极端事件)
if self.entry_time and (current_time - self.entry_time) >= timedelta(days=self.max_hold_days):
self.log(f"Max hold time reached ({self.max_hold_days} days). Forcing exit.")
self.close_all_positions()
self.entry_time = None
self.position_direction = None
return
if position_volume != 0:
self.manage_position(bar_history[-1], position_volume, current_atr)
return
# 核心逻辑:相变入场/退出
if position_volume == 0:
self.evaluate_entry_signal(open_price, trend_strength, dominant_freq)
else:
self.manage_open_position(position_volume, trend_strength, dominant_freq)
if self.pending_order:
self.manage_pending_order()
if self.pending_order:
return
self.identify_new_signal(bar_history, current_atr)
def manage_position(self, last_bar: Bar, position_volume: int, current_atr: float):
"""管理当前持仓 (逻辑不变)"""
# ... (此部分代码与上一版完全相同,保持不变) ...
if position_volume > 0 and last_bar.low <= self.initial_stop_price:
self.log(f"⬇️ LONG STOP LOSS: Low={last_bar.low:.4f} <= Stop={self.initial_stop_price:.4f}")
self.close_position("CLOSE_LONG", abs(position_volume))
return
elif position_volume < 0 and last_bar.high >= self.initial_stop_price:
self.log(f"⬆️ SHORT STOP LOSS: High={last_bar.high:.4f} >= Stop={self.initial_stop_price:.4f}")
self.close_position("CLOSE_SHORT", abs(position_volume))
return
bar_range = last_bar.high - last_bar.low
exit_threshold = current_atr * self.exit_signal_atr_mult
if position_volume > 0 and last_bar.close < last_bar.open and bar_range > exit_threshold:
self.log(
f"⬇️ LONG VOLATILITY EXIT: Strong Bearish Bar. Range={bar_range:.2f} > Threshold={exit_threshold:.2f}")
self.close_position("CLOSE_LONG", abs(position_volume))
elif position_volume < 0 and last_bar.close > last_bar.open and bar_range > exit_threshold:
self.log(
f"⬆️ SHORT VOLATILITY EXIT: Strong Bullish Bar. Range={bar_range:.2f} > Threshold={exit_threshold:.2f}")
self.close_position("CLOSE_SHORT", abs(position_volume))
def manage_pending_order(self):
"""管理挂单 (逻辑不变)"""
# ... (此部分代码与上一版完全相同,保持不变) ...
if not self.pending_order:
return
self.pending_order['bars_waited'] += 1
if self.pending_order['bars_waited'] >= self.order_expiry_bars:
self.log(f"⌛️ PENDING ORDER EXPIRED: Order for {self.pending_order['direction']} "
f"at {self.pending_order['price']:.4f} cancelled after {self.pending_order['bars_waited']} bars.")
self.cancel_order(self.pending_order['id'])
self.pending_order = None
def identify_new_signal(self, bar_history: List[Bar], current_atr: float):
def calculate_trend_strength(self, prices: np.array) -> (float, float):
"""
核心逻辑重构】
识别新的交易信号。信号源现在统一为 (close - prev_close) 的波幅。
显式傅里叶】计算低频能量占比 (完全参数化)
步骤:
1. 价格归一化 (窗口内)
2. 短时傅里叶变换 (STFT) - 采样率=bars_per_day
3. 动态计算频段边界 (基于bars_per_day)
4. 趋势强度 = 低频能量 / (低频+高频能量)
"""
last_bar = bar_history[-1]
prev_bar = bar_history[-2]
# 1. 验证数据长度
if len(prices) < self.spectral_window:
return 0.0, 0.0
# --- CHANGE 1: 定义统一的信号阈值 ---
signal_threshold = current_atr * self.signal_move_atr_mult
# 2. 价格归一化 (仅使用窗口内数据)
window_data = prices[-self.spectral_window:]
normalized = (window_data - np.mean(window_data)) / (np.std(window_data) + 1e-8)
# --- CHANGE 2: 计算核心的 close-to-close 波幅 ---
move_height = last_bar.close - prev_bar.close
# 3. STFT (采样率=bars_per_day)
try:
# fs: 每天的样本数 (bars_per_day)
f, t, Zxx = stft(
normalized,
fs=self.bars_per_day, # 关键: 适配市场结构
nperseg=self.spectral_window,
noverlap=max(0, self.spectral_window // 2),
boundary=None,
padded=False
)
except Exception as e:
self.log(f"STFT calculation error: {str(e)}")
return 0.0, 0.0
# --- 多头信号: 上涨波幅超过阈值 ---
if move_height > signal_threshold:
self.log(f"💡 Bullish Shock Detected: Move={move_height:.2f} > Threshold={signal_threshold:.2f}")
# 回调计算仍然基于 last_bar.high 作为情绪顶点,但回调深度由更稳健的 move_height 决定
entry_price = last_bar.high - (move_height * self.entry_pullback_pct)
stop_price = entry_price - (current_atr * self.initial_stop_atr_mult)
self.place_limit_order("BUY", entry_price, stop_price)
return
# 4. 过滤无效频率 (STFT返回频率范围: 0 到 fs/2)
valid_mask = (f >= 0) & (f <= self.bars_per_day / 2)
f = f[valid_mask]
Zxx = Zxx[valid_mask, :]
# --- 空头信号: 下跌波幅超过阈值 ---
# 注意: move_height此时为负数
if move_height < -signal_threshold:
down_move_height = abs(move_height)
self.log(f"💡 Bearish Shock Detected: Move={move_height:.2f} < Threshold={-signal_threshold:.2f}")
# 回调计算基于 last_bar.low 作为情绪谷点
entry_price = last_bar.low + (down_move_height * self.entry_pullback_pct)
stop_price = entry_price + (current_atr * self.initial_stop_atr_mult)
self.place_limit_order("SELL", entry_price, stop_price)
if Zxx.size == 0 or Zxx.shape[1] == 0:
return 0.0, 0.0
def place_limit_order(self, direction: str, price: float, stop_price: float):
"""创建、记录并发送一个新的限价挂单"""
# ... (此部分代码与上一版完全相同,保持不变) ...
order_id = self.generate_order_id(direction, "OPEN")
self.pending_order = {
"id": order_id, "symbol": self.symbol, "direction": direction,
"volume": self.trade_volume, "price": price, "stop_price": stop_price,
"bars_waited": 0,
}
self.log(f"🆕 Placing Limit Order: {direction} at {price:.4f} (Stop: {stop_price:.4f}).")
# 5. 计算最新时间点的能量
current_energy = np.abs(Zxx[:, -1]) ** 2
order = Order(
id=order_id, symbol=self.symbol, direction=direction,
volume=self.trade_volume, price_type="LIMIT", limit_price=price,
submitted_time=self.get_current_time(), offset="OPEN"
)
self.send_order(order)
# 6. 动态频段定义 (cycles/day)
# 低频: 周期 > low_freq_days → 频率 < 1/low_freq_days
low_freq_mask = f < self.low_freq_bound
# 高频: 周期 < high_freq_days → 频率 > 1/high_freq_days
high_freq_mask = f > self.high_freq_bound
def on_trade(self, trade):
"""处理成交回报 (逻辑不变)"""
# ... (此部分代码与上一版完全相同,保持不变) ...
if self.pending_order and trade.id == self.pending_order['id']:
self.log(
f"✅ Order Filled: {trade.direction} at {trade.price:.4f}. Stop loss set at {self.pending_order['stop_price']:.4f}")
self.position_entry_price = trade.price
self.initial_stop_price = self.pending_order['stop_price']
self.pending_order = None
# 7. 能量计算
low_energy = np.sum(current_energy[low_freq_mask]) if np.any(low_freq_mask) else 0.0
high_energy = np.sum(current_energy[high_freq_mask]) if np.any(high_freq_mask) else 0.0
total_energy = low_energy + high_energy + 1e-8 # 防除零
def generate_order_id(self, direction: str, offset: str) -> str:
# ... (此部分代码与上一版完全相同,保持不变) ...
self.order_id_counter += 1
return f"{self.symbol}_{direction}_{offset}_{self.order_id_counter}_{int(datetime.now().timestamp())}"
# 8. 趋势强度 = 低频能量占比
trend_strength = low_energy / total_energy
# 9. 计算主导趋势周期 (天)
dominant_freq = 0.0
if np.any(low_freq_mask) and low_energy > 0:
# 找到低频段最大能量对应的频率
low_energies = current_energy[low_freq_mask]
max_idx = np.argmax(low_energies)
dominant_freq = 1.0 / (f[low_freq_mask][max_idx] + 1e-8) # 转换为周期(天)
return trend_strength, dominant_freq
def evaluate_entry_signal(self, open_price: float, trend_strength: float, dominant_freq: float):
"""评估相变入场信号"""
# 仅当趋势强度跨越临界点且有明确周期时入场
if trend_strength > self.trend_strength_threshold and dominant_freq > self.low_freq_days:
direction = None
indicator = self.model_indicator
# 做多信号: 价格在窗口均值上方
closes = np.array([b.close for b in self.get_bar_history()[-self.spectral_window:]], dtype=float)
if "BUY" in self.order_direction and np.mean(closes[-5:]) > np.mean(closes):
direction = "BUY" if indicator.is_condition_met(*self.get_indicator_tuple()) else "SELL"
# 做空信号: 价格在窗口均值下方
elif "SELL" in self.order_direction and np.mean(closes[-5:]) < np.mean(closes):
direction = "SELL" if indicator.is_condition_met(*self.get_indicator_tuple()) else "BUY"
if direction:
self.log(
f"Phase Transition Entry: {direction} | Strength={trend_strength:.2f} | Dominant Period={dominant_freq:.1f}d")
self.send_limit_order(direction, open_price, self.trade_volume, "OPEN")
self.entry_time = self.get_current_time()
self.position_direction = "LONG" if direction == "BUY" else "SHORT"
def manage_open_position(self, volume: int, trend_strength: float, dominant_freq: float):
"""管理持仓:仅当相变逆转时退出"""
# 相变逆转条件: 趋势强度 < 退出阈值
if trend_strength < self.exit_threshold:
direction = "CLOSE_LONG" if volume > 0 else "CLOSE_SHORT"
self.log(f"Phase Transition Exit: {direction} | Strength={trend_strength:.2f} < {self.exit_threshold}")
self.close_position(direction, abs(volume))
self.entry_time = None
self.position_direction = None
# --- 辅助函数区 ---
def close_all_positions(self):
"""强制平仓所有头寸"""
positions = self.get_current_positions()
if self.symbol in positions and positions[self.symbol] != 0:
direction = "CLOSE_LONG" if positions[self.symbol] > 0 else "CLOSE_SHORT"
self.close_position(direction, abs(positions[self.symbol]))
self.log(f"Forced exit of {abs(positions[self.symbol])} contracts")
def close_position(self, direction: str, volume: int):
# ... (此部分代码与上一版完全相同,保持不变) ...
self.send_market_order(direction, volume, offset="CLOSE")
self.position_entry_price = 0.0
self.initial_stop_price = 0.0
def send_market_order(self, direction: str, volume: int, offset: str = "OPEN"):
# ... (此部分代码与上一版完全相同,保持不变) ...
order_id = self.generate_order_id(direction, offset)
self.log(f"➡️ Sending Market Order: {direction} {volume} {self.symbol} ({offset})")
def send_market_order(self, direction: str, volume: int, offset: str):
order_id = f"{self.symbol}_{direction}_MARKET_{self.order_id_counter}"
self.order_id_counter += 1
order = Order(
id=order_id, symbol=self.symbol, direction=direction,
volume=volume, price_type="MARKET",
submitted_time=self.get_current_time(), offset=offset
id=order_id,
symbol=self.symbol,
direction=direction,
volume=volume,
price_type="MARKET",
submitted_time=self.get_current_time(),
offset=offset
)
self.send_order(order)
def cancel_order(self, order_id: str):
# ... (此部分代码与上一版完全相同,保持不变) ...
self.log(f"❌ Sending Cancel Request for Order: {order_id}")
self.context.cancel_order(order_id)
def send_limit_order(self, direction: str, limit_price: float, volume: int, offset: str):
order_id = f"{self.symbol}_{direction}_MARKET_{self.order_id_counter}"
self.order_id_counter += 1
order = Order(
id=order_id,
symbol=self.symbol,
direction=direction,
volume=volume,
price_type="LIMIT",
submitted_time=self.get_current_time(),
offset=offset,
limit_price=limit_price
)
self.send_order(order)
def on_init(self):
super().on_init()
self.cancel_all_pending_orders(self.main_symbol)
self.log("Strategy initialized. Waiting for phase transition signals...")
def on_rollover(self, old_symbol: str, new_symbol: str):
super().on_rollover(old_symbol, new_symbol)
self.log(f"Rollover from {old_symbol} to {new_symbol}. Resetting position state.")
self.entry_time = None
self.position_direction = None
self.last_trend_strength = 0.0

View File

@@ -1,6 +1,7 @@
import numpy as np
import talib
from typing import Optional, Any, List
from scipy.signal import stft
from datetime import datetime, timedelta
from typing import Optional, Any, List, Dict
from src.core_data import Bar, Order
from src.indicators.base_indicators import Indicator
@@ -8,137 +9,346 @@ from src.indicators.indicators import Empty
from src.strategies.base_strategy import Strategy
class SuperTrendStrategy(Strategy):
# =============================================================================
# 策略实现 (VolatilityAdaptiveSpectralStrategy)
# =============================================================================
class SpectralTrendStrategy(Strategy):
"""
SuperTrend 策略
- 基于 ATR 和价格波动构建上下轨
- 价格上穿上轨 → 开多(且多头条件满足)
- 价格下穿下轨 → 开空(且空头条件满足)
- 反向穿越 → 平仓并反手(或仅平仓,支持空仓)
- 标准 SuperTrend 公式:使用 ATR * multiplier
波动率自适应频域趋势策略
核心哲学:
1. 显式傅里叶变换: 分离低频(趋势)、高频(噪音)能量
2. 波动率条件信号: 根据波动率环境动态调整交易方向
- 低波动环境: 趋势策略 (高趋势强度 → 延续)
- 高波动环境: 反转策略 (高趋势强度 → 反转)
3. 无硬编码参数: 所有阈值通过配置参数设定
4. 严格无未来函数: 所有计算使用历史数据
参数说明:
- bars_per_day: 市场每日K线数量
- volatility_lookback: 波动率计算窗口(天)
- low_vol_threshold: 低波动环境阈值(0-1)
- high_vol_threshold: 高波动环境阈值(0-1)
"""
def __init__(
self,
context: Any,
main_symbol: str,
enable_log: bool,
trade_volume: int,
atr_period: int = 10,
atr_multiplier: float = 3.0,
order_direction: Optional[List[str]] = None,
indicators: Optional[List[Indicator]] = None,
self,
context: Any,
main_symbol: str,
enable_log: bool,
trade_volume: int,
# --- 【市场结构参数】 ---
bars_per_day: int = 23, # 适配23根/天的市场
# --- 【频域核心参数】 ---
spectral_window_days: float = 2.0, # STFT窗口大小(天)
low_freq_days: float = 2.0, # 低频下限(天)
high_freq_days: float = 1.0, # 高频上限(天)
trend_strength_threshold: float = 0.8, # 趋势强度阈值
exit_threshold: float = 0.5, # 退出阈值
# --- 【波动率参数】 ---
volatility_lookback_days: float = 5.0, # 波动率计算窗口(天)
low_vol_threshold: float = 0.3, # 低波动环境阈值(0-1)
high_vol_threshold: float = 0.7, # 高波动环境阈值(0-1)
# --- 【持仓管理】 ---
max_hold_days: int = 10, # 最大持仓天数
# --- 其他 ---
order_direction: Optional[List[str]] = None,
indicators: Optional[List[Indicator]] = None,
):
super().__init__(context, main_symbol, enable_log)
if order_direction is None:
order_direction = ["BUY", "SELL"]
order_direction = ['BUY', 'SELL']
if indicators is None:
indicators = [Empty(), Empty()]
indicators = [Empty(), Empty()] # 保持兼容性
# --- 参数赋值 (完全参数化) ---
self.trade_volume = trade_volume
self.atr_period = atr_period
self.atr_multiplier = atr_multiplier
self.bars_per_day = bars_per_day
self.spectral_window_days = spectral_window_days
self.low_freq_days = low_freq_days
self.high_freq_days = high_freq_days
self.trend_strength_threshold = trend_strength_threshold
self.exit_threshold = exit_threshold
self.volatility_lookback_days = volatility_lookback_days
self.low_vol_threshold = low_vol_threshold
self.high_vol_threshold = high_vol_threshold
self.max_hold_days = max_hold_days
self.order_direction = order_direction
self.indicators = indicators
# --- 动态计算参数 ---
self.spectral_window = int(self.spectral_window_days * self.bars_per_day)
self.spectral_window = self.spectral_window if self.spectral_window % 2 == 0 else self.spectral_window + 1
self.volatility_window = int(self.volatility_lookback_days * self.bars_per_day)
# 频率边界 (cycles/day)
self.low_freq_bound = 1.0 / self.low_freq_days if self.low_freq_days > 0 else float('inf')
self.high_freq_bound = 1.0 / self.high_freq_days if self.high_freq_days > 0 else 0.0
# --- 内部状态变量 ---
self.main_symbol = main_symbol
self.order_id_counter = 0
self.min_bars_needed = atr_period + 10
self.log(f"SuperTrendStrategy Initialized | ATR({atr_period}) × {atr_multiplier}")
self.indicators = indicators
self.entry_time = None # 入场时间
self.position_direction = None # 'LONG' or 'SHORT'
self.last_trend_strength = 0.0
self.last_dominant_freq = 0.0 # 主导周期(天)
self.last_volatility = 0.0 # 标准化波动率(0-1)
self.volatility_history = [] # 存储历史波动率
self.log(f"VolatilityAdaptiveSpectralStrategy Initialized (bars/day={bars_per_day}, "
f"window={self.spectral_window} bars, vol_window={self.volatility_window} bars)")
def on_open_bar(self, open_price: float, symbol: str):
"""每根K线开盘时被调用"""
self.symbol = symbol
bar_history = self.get_bar_history()
current_time = self.get_current_time()
if len(bar_history) < self.min_bars_needed or not self.trading:
# 需要足够的数据 (最大窗口 + 缓冲)
min_required = max(self.spectral_window, self.volatility_window) + 10
if len(bar_history) < min_required:
if self.enable_log and len(bar_history) % 50 == 0:
self.log(f"Waiting for {len(bar_history)}/{min_required} bars")
return
position = self.get_current_positions().get(self.symbol, 0)
position_volume = self.get_current_positions().get(self.symbol, 0)
# 提取 OHLC
highs = np.array([b.high for b in bar_history], dtype=float)
lows = np.array([b.low for b in bar_history], dtype=float)
closes = np.array([b.close for b in bar_history], dtype=float)
# 获取必要历史价格 (仅取所需部分)
recent_bars = bar_history[-(max(self.spectral_window, self.volatility_window) + 5):]
closes = np.array([b.close for b in recent_bars], dtype=np.float32)
highs = np.array([b.high for b in recent_bars], dtype=np.float32)
lows = np.array([b.low for b in recent_bars], dtype=np.float32)
# 1. 计算 ATR
atr = talib.ATR(highs, lows, closes, timeperiod=self.atr_period)
# 【核心】计算频域趋势强度 (显式傅里叶)
trend_strength, dominant_freq = self.calculate_trend_strength(closes)
self.last_trend_strength = trend_strength
self.last_dominant_freq = dominant_freq
# 2. 计算基础上下轨
hl2 = (highs + lows) / 2.0
upper_band_basic = hl2 + self.atr_multiplier * atr
lower_band_basic = hl2 - self.atr_multiplier * atr
# 【核心】计算标准化波动率 (0-1范围)
volatility = self.calculate_normalized_volatility(highs, lows, closes)
self.last_volatility = volatility
# 3. 构建 SuperTrend带方向的记忆性逻辑
n = len(closes)
supertrend = np.full(n, np.nan)
direction = np.full(n, 1) # 1 for up, -1 for down
# 检查最大持仓时间 (防止极端事件)
if self.entry_time and (current_time - self.entry_time) >= timedelta(days=self.max_hold_days):
self.log(f"Max hold time reached ({self.max_hold_days} days). Forcing exit.")
self.close_all_positions()
self.entry_time = None
self.position_direction = None
return
# 初始化
supertrend[self.atr_period] = upper_band_basic[self.atr_period]
direction[self.atr_period] = -1 # 初始假设为 downtrend
# 核心逻辑:相变入场/退出
if position_volume == 0:
self.evaluate_entry_signal(open_price, trend_strength, dominant_freq, volatility, recent_bars)
else:
self.manage_open_position(position_volume, trend_strength, volatility)
for i in range(self.atr_period + 1, n):
# 默认继承前值
supertrend[i] = supertrend[i - 1]
direction[i] = direction[i - 1]
def calculate_trend_strength(self, closes: np.array) -> (float, float):
"""
【显式傅里叶】计算低频能量占比 (完全参数化)
"""
if len(closes) < self.spectral_window:
return 0.0, 0.0
# 如果前一根是 downtrend
if direction[i - 1] == -1:
if closes[i] > upper_band_basic[i - 1]:
direction[i] = 1
supertrend[i] = lower_band_basic[i]
else:
supertrend[i] = min(upper_band_basic[i], supertrend[i - 1])
else: # 前一根是 uptrend
if closes[i] < lower_band_basic[i - 1]:
direction[i] = -1
supertrend[i] = upper_band_basic[i]
else:
supertrend[i] = max(lower_band_basic[i], supertrend[i - 1])
# 仅使用窗口内数据
window_data = closes[-self.spectral_window:]
window_mean = np.mean(window_data)
window_std = np.std(window_data)
if window_std < 1e-8:
return 0.0, 0.0
# 获取最新状态
current_direction = direction[-1]
prev_direction = direction[-2] if len(direction) >= 2 else current_direction
normalized = (window_data - window_mean) / window_std
# 4. 确定目标仓位
target_position = 0
if current_direction == 1:
if self.indicators[0].is_condition_met(*self.get_indicator_tuple()):
target_position = self.trade_volume # 做多
elif current_direction == -1:
if self.indicators[1].is_condition_met(*self.get_indicator_tuple()):
target_position = -self.trade_volume # 做空
try:
f, t, Zxx = stft(
normalized,
fs=self.bars_per_day,
nperseg=self.spectral_window,
noverlap=max(0, self.spectral_window // 2),
boundary=None,
padded=False
)
except Exception as e:
self.log(f"STFT calculation error: {str(e)}")
return 0.0, 0.0
# 5. 平仓逻辑:方向翻转即平仓(即使不开反手,也先平)
current_position = position
should_close_long = (current_position > 0) and (current_direction == -1)
should_close_short = (current_position < 0) and (current_direction == 1)
# 过滤无效频率
max_freq = self.bars_per_day / 2
valid_mask = (f >= 0) & (f <= max_freq)
if not np.any(valid_mask):
return 0.0, 0.0
# 6. 执行订单
if should_close_long or should_close_short or (target_position != current_position):
# 先平旧仓
if current_position > 0:
self.close_position("CLOSE_LONG", current_position)
elif current_position < 0:
self.close_position("CLOSE_SHORT", -current_position)
f = f[valid_mask]
Zxx = Zxx[valid_mask, :]
# 再开新仓(如果条件满足)
if target_position > 0:
self.send_market_order("BUY", target_position, "OPEN")
self.log(f"📈 SuperTrend Long | ATR={atr[-1]:.2f}, Dir=+1")
elif target_position < 0:
self.send_market_order("SELL", -target_position, "OPEN")
self.log(f"📉 SuperTrend Short | ATR={atr[-1]:.2f}, Dir=-1")
if Zxx.size == 0 or Zxx.shape[1] == 0:
return 0.0, 0.0
# --- 模板方法 ---
def on_init(self):
super().on_init()
self.cancel_all_pending_orders(self.main_symbol)
# 计算最新时间点的能量
current_energy = np.abs(Zxx[:, -1]) ** 2
# 动态频段定义
low_freq_mask = f < self.low_freq_bound
high_freq_mask = f > self.high_freq_bound
# 能量计算
low_energy = np.sum(current_energy[low_freq_mask]) if np.any(low_freq_mask) else 0.0
high_energy = np.sum(current_energy[high_freq_mask]) if np.any(high_freq_mask) else 0.0
total_energy = low_energy + high_energy + 1e-8
# 趋势强度 = 低频能量占比
trend_strength = low_energy / total_energy
# 计算主导趋势周期 (天)
dominant_freq = 0.0
if np.any(low_freq_mask) and low_energy > 0:
low_energies = current_energy[low_freq_mask]
max_idx = np.argmax(low_energies)
dominant_freq = 1.0 / (f[low_freq_mask][max_idx] + 1e-8)
return float(trend_strength), float(dominant_freq)
def calculate_normalized_volatility(self, highs: np.array, lows: np.array, closes: np.array) -> float:
"""
计算标准化波动率 (0-1范围)
步骤:
1. 计算ATR (真实波幅)
2. 标准化ATR (除以价格)
3. 归一化到0-1范围 (基于历史波动率)
"""
if len(closes) < self.volatility_window + 1:
return 0.5 # 默认中性值
# 1. 计算真实波幅 (TR)
tr1 = highs[-self.volatility_window - 1:] - lows[-self.volatility_window - 1:]
tr2 = np.abs(highs[-self.volatility_window - 1:] - np.roll(closes, 1)[-self.volatility_window - 1:])
tr3 = np.abs(lows[-self.volatility_window - 1:] - np.roll(closes, 1)[-self.volatility_window - 1:])
tr = np.maximum(tr1, np.maximum(tr2, tr3))
# 2. 计算ATR
atr = np.mean(tr[-self.volatility_window:])
# 3. 标准化ATR (除以当前价格)
current_price = closes[-1]
normalized_atr = atr / current_price if current_price > 0 else 0.0
# 4. 归一化到0-1范围 (基于历史波动率)
self.volatility_history.append(normalized_atr)
if len(self.volatility_history) > 1000: # 保留1000个历史值
self.volatility_history.pop(0)
if len(self.volatility_history) < 50: # 需要足够历史数据
return 0.5
# 使用历史50-95百分位进行归一化
low_percentile = np.percentile(self.volatility_history, 50)
high_percentile = np.percentile(self.volatility_history, 95)
if high_percentile - low_percentile < 1e-8:
return 0.5
# 归一化到0-1范围
normalized_vol = (normalized_atr - low_percentile) / (high_percentile - low_percentile + 1e-8)
normalized_vol = max(0.0, min(1.0, normalized_vol)) # 限制在0-1范围内
return normalized_vol
def evaluate_entry_signal(self, open_price: float, trend_strength: float, dominant_freq: float,
volatility: float, recent_bars: List[Bar]):
"""评估波动率条件入场信号"""
# 仅当趋势强度跨越临界点且有明确周期时入场
if trend_strength > self.trend_strength_threshold:
direction = None
trade_type = ""
# 计算价格位置 (短期vs长期均值)
window_closes = np.array([b.close for b in recent_bars[-self.spectral_window:]], dtype=np.float32)
short_avg = np.mean(window_closes[-5:])
long_avg = np.mean(window_closes)
# 添加统计显著性过滤
if abs(short_avg - long_avg) < 0.0005 * long_avg:
return
# 【核心】根据波动率环境决定交易逻辑
if volatility < self.low_vol_threshold:
# 低波动环境: 趋势策略
trade_type = "TREND"
if "BUY" in self.order_direction and short_avg > long_avg:
direction = "BUY"
elif "SELL" in self.order_direction and short_avg < long_avg:
direction = "SELL"
elif volatility > self.high_vol_threshold:
# 高波动环境: 反转策略
trade_type = "REVERSAL"
if "BUY" in self.order_direction and short_avg < long_avg:
direction = "BUY" # 价格低于均值,预期回归
elif "SELL" in self.order_direction and short_avg > long_avg:
direction = "SELL" # 价格高于均值,预期反转
else:
# 中波动环境: 谨慎策略 (需要更强信号)
trade_type = "CAUTIOUS"
if trend_strength > 0.9 and "BUY" in self.order_direction and short_avg > long_avg:
direction = "BUY"
elif trend_strength > 0.9 and "SELL" in self.order_direction and short_avg < long_avg:
direction = "SELL"
if direction:
self.log(
f"Entry: {direction} | Type={trade_type} | Strength={trend_strength:.2f} | "
f"Volatility={volatility:.2f} | ShortAvg={short_avg:.4f} vs LongAvg={long_avg:.4f}"
)
self.send_market_order(direction, self.trade_volume, "OPEN")
self.entry_time = self.get_current_time()
self.position_direction = "LONG" if direction == "BUY" else "SHORT"
def manage_open_position(self, volume: int, trend_strength: float, volatility: float):
"""管理持仓:波动率条件退出"""
# 退出条件1: 趋势强度 < 退出阈值
if trend_strength < self.exit_threshold:
direction = "CLOSE_LONG" if volume > 0 else "CLOSE_SHORT"
self.log(f"Exit (Strength): {direction} | Strength={trend_strength:.2f} < {self.exit_threshold}")
self.close_position(direction, abs(volume))
self.entry_time = None
self.position_direction = None
return
# 退出条件2: 波动率环境突变 (从低波动变为高波动,或反之)
if self.position_direction == "LONG" and volatility > self.high_vol_threshold * 1.2:
# 多头仓位在波动率突增时退出
self.log(
f"Exit (Volatility Spike): CLOSE_LONG | Volatility={volatility:.2f} > {self.high_vol_threshold * 1.2:.2f}")
self.close_position("CLOSE_LONG", abs(volume))
self.entry_time = None
self.position_direction = None
elif self.position_direction == "SHORT" and volatility > self.high_vol_threshold * 1.2:
# 空头仓位在波动率突增时退出
self.log(
f"Exit (Volatility Spike): CLOSE_SHORT | Volatility={volatility:.2f} > {self.high_vol_threshold * 1.2:.2f}")
self.close_position("CLOSE_SHORT", abs(volume))
self.entry_time = None
self.position_direction = None
# --- 辅助函数区 ---
def close_all_positions(self):
"""强制平仓所有头寸"""
positions = self.get_current_positions()
if not positions or self.symbol not in positions or positions[self.symbol] == 0:
return
direction = "CLOSE_LONG" if positions[self.symbol] > 0 else "CLOSE_SHORT"
self.close_position(direction, abs(positions[self.symbol]))
if self.enable_log:
self.log(f"Closed {abs(positions[self.symbol])} contracts")
def close_position(self, direction: str, volume: int):
self.send_market_order(direction, volume, offset="CLOSE")
def send_market_order(self, direction: str, volume: int, offset: str):
order_id = f"{self.symbol}_{direction}_MARKET_{self.order_id_counter}"
order_id = f"{self.symbol}_{direction[-6:]}_{self.order_id_counter}"
self.order_id_counter += 1
order = Order(
id=order_id,
@@ -147,10 +357,21 @@ class SuperTrendStrategy(Strategy):
volume=volume,
price_type="MARKET",
submitted_time=self.get_current_time(),
offset=offset,
offset=offset
)
self.send_order(order)
def on_init(self):
super().on_init()
self.cancel_all_pending_orders(self.main_symbol)
if self.enable_log:
self.log("Strategy initialized. Waiting for volatility-adaptive signals...")
def on_rollover(self, old_symbol: str, new_symbol: str):
super().on_rollover(old_symbol, new_symbol)
self.log("Rollover: SuperTrendStrategy state reset.")
if self.enable_log:
self.log(f"Rollover: {old_symbol} -> {new_symbol}. Resetting state.")
self.entry_time = None
self.position_direction = None
self.last_trend_strength = 0.0
self.volatility_history = [] # 重置波动率历史