卡尔曼策略新增md文件
This commit is contained in:
File diff suppressed because one or more lines are too long
@@ -0,0 +1,169 @@
|
||||
import numpy as np
|
||||
import talib
|
||||
from typing import Optional, Dict, Any, List, Tuple, Union
|
||||
|
||||
from src.algo.TrendLine import calculate_latest_trendline_values
|
||||
# 假设这些是你项目中的基础模块
|
||||
from src.core_data import Bar, Order
|
||||
from src.indicators.base_indicators import Indicator
|
||||
from src.indicators.indicators import Empty
|
||||
from src.strategies.base_strategy import Strategy
|
||||
|
||||
|
||||
class TrendlineBreakoutStrategy(Strategy):
|
||||
"""
|
||||
趋势线突破策略 V3 (优化版):
|
||||
1. 策略逻辑与 V2 相同,但趋势线计算被重构为一个独立的、
|
||||
高性能的辅助方法。
|
||||
2. 该方法只计算最新的趋势线值,避免不必要的数组生成。
|
||||
|
||||
开仓信号:
|
||||
- 做多: 上一根收盘价上穿下趋势线
|
||||
- 做空: 上一根收盘价下穿上趋势线
|
||||
|
||||
平仓逻辑:
|
||||
- 采用 ATR 滑动止损 (Trailing Stop)。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: Any,
|
||||
main_symbol: str,
|
||||
trendline_n: int = 50,
|
||||
trade_volume: int = 1,
|
||||
order_direction: Optional[List[str]] = None,
|
||||
atr_period: int = 14,
|
||||
atr_multiplier: float = 1.0,
|
||||
enable_log: bool = True,
|
||||
indicators: Union[Indicator, List[Indicator]] = None,
|
||||
):
|
||||
super().__init__(context, main_symbol, enable_log)
|
||||
self.main_symbol = main_symbol
|
||||
self.trendline_n = trendline_n
|
||||
self.trade_volume = trade_volume
|
||||
self.order_direction = order_direction or ["BUY", "SELL"]
|
||||
self.atr_period = atr_period
|
||||
self.atr_multiplier = atr_multiplier
|
||||
self.pos_meta: Dict[str, Dict[str, Any]] = {}
|
||||
if indicators is None:
|
||||
indicators = [Empty(), Empty()]
|
||||
self.indicators = indicators
|
||||
|
||||
if self.trendline_n < 3:
|
||||
raise ValueError("trendline_n 必须大于或等于 3")
|
||||
|
||||
log_message = (
|
||||
f"TrendlineBreakoutStrategy (V3 Optimized) 初始化:\n"
|
||||
f"交易标的={self.main_symbol}, 交易量={self.trade_volume}\n"
|
||||
f"趋势线周期={self.trendline_n}, ATR周期={self.atr_period}, ATR倍数={self.atr_multiplier}"
|
||||
)
|
||||
self.log(log_message)
|
||||
|
||||
|
||||
def _calculate_atr(self, bar_history: List[Bar]) -> Optional[float]:
|
||||
# (此函数与上一版本完全相同,保持不变)
|
||||
if len(bar_history) < self.atr_period + 1: return None
|
||||
highs = np.array([b.high for b in bar_history])
|
||||
lows = np.array([b.low for b in bar_history])
|
||||
closes = np.array([b.close for b in bar_history])
|
||||
atr = talib.ATR(highs, lows, closes, timeperiod=self.atr_period)
|
||||
return atr[-1] if not np.isnan(atr[-1]) else None
|
||||
|
||||
def on_init(self):
|
||||
super().on_init()
|
||||
self.pos_meta.clear()
|
||||
|
||||
def on_open_bar(self, open_price: float, symbol: str):
|
||||
bar_history = self.get_bar_history()
|
||||
min_bars_required = self.trendline_n + 2
|
||||
if len(bar_history) < min_bars_required:
|
||||
return
|
||||
|
||||
self.cancel_all_pending_orders(symbol)
|
||||
pos = self.get_current_positions().get(symbol, 0)
|
||||
|
||||
# 1. 优先处理平仓逻辑 (逻辑不变)
|
||||
meta = self.pos_meta.get(symbol)
|
||||
if meta and pos != 0:
|
||||
current_atr = self._calculate_atr(bar_history[:-1])
|
||||
if current_atr:
|
||||
trailing_stop = meta['trailing_stop']
|
||||
direction = meta['direction']
|
||||
last_close = bar_history[-1].close
|
||||
if direction == "BUY":
|
||||
new_stop_level = last_close - current_atr * self.atr_multiplier
|
||||
trailing_stop = max(trailing_stop, new_stop_level)
|
||||
else: # SELL
|
||||
new_stop_level = last_close + current_atr * self.atr_multiplier
|
||||
trailing_stop = min(trailing_stop, new_stop_level)
|
||||
self.pos_meta[symbol]['trailing_stop'] = trailing_stop
|
||||
if (direction == "BUY" and open_price <= trailing_stop) or \
|
||||
(direction == "SELL" and open_price >= trailing_stop):
|
||||
self.log(f"ATR滑动止损触发: 价格 {open_price:.2f} 触及止损位 {trailing_stop:.2f}")
|
||||
self.send_market_order("CLOSE_LONG" if direction == "BUY" else "CLOSE_SHORT", abs(pos))
|
||||
del self.pos_meta[symbol]
|
||||
return
|
||||
|
||||
# 2. 开仓逻辑 (调用优化后的方法)
|
||||
if pos == 0:
|
||||
prices_for_trendline = np.array([b.close for b in bar_history[-self.trendline_n - 1:-1]])
|
||||
|
||||
# --- 调用新的独立方法 ---
|
||||
trendline_val_upper, trendline_val_lower = calculate_latest_trendline_values(prices_for_trendline)
|
||||
|
||||
if trendline_val_upper is None or trendline_val_lower is None:
|
||||
return # 无法计算趋势线,跳过
|
||||
|
||||
prev_close = bar_history[-2].close
|
||||
last_close = bar_history[-1].close
|
||||
|
||||
current_atr = self._calculate_atr(bar_history[:-1])
|
||||
if not current_atr:
|
||||
return
|
||||
|
||||
# if "BUY" in self.order_direction and last_close > trendline_val_upper and self.indicators[0].is_condition_met(*self.get_indicator_tuple()):
|
||||
# self.log(f"做多信号: Close({last_close:.2f}) 上穿下趋势线({trendline_val_upper:.2f})")
|
||||
# self.send_open_order("BUY", open_price, self.trade_volume, current_atr)
|
||||
#
|
||||
# elif "SELL" in self.order_direction and last_close < trendline_val_lower and self.indicators[1].is_condition_met(*self.get_indicator_tuple()):
|
||||
# self.log(f"做空信号: Close({last_close:.2f}) 下穿上趋势线({trendline_val_lower:.2f})")
|
||||
# self.send_open_order("SELL", open_price, self.trade_volume, current_atr)
|
||||
|
||||
if "BUY" in self.order_direction and last_close > trendline_val_upper and self.indicators[0].is_condition_met(*self.get_indicator_tuple()):
|
||||
self.log(f"做多信号: Close({last_close:.2f}) 上穿下趋势线({trendline_val_upper:.2f})")
|
||||
self.send_open_order("BUY", open_price, self.trade_volume, current_atr)
|
||||
|
||||
elif "SELL" in self.order_direction and last_close < trendline_val_lower and self.indicators[1].is_condition_met(*self.get_indicator_tuple()):
|
||||
self.log(f"做空信号: Close({last_close:.2f}) 下穿上趋势线({trendline_val_lower:.2f})")
|
||||
self.send_open_order("SELL", open_price, self.trade_volume, current_atr)
|
||||
|
||||
|
||||
# send_open_order, send_market_order, on_rollover 等方法与上一版本完全相同,保持不变
|
||||
def send_open_order(self, direction: str, entry_price: float, volume: int, current_atr: float):
|
||||
if direction == "BUY":
|
||||
initial_stop = entry_price - current_atr * self.atr_multiplier
|
||||
else:
|
||||
initial_stop = entry_price + current_atr * self.atr_multiplier
|
||||
current_time = self.get_current_time()
|
||||
order_id = f"{self.symbol}_{direction}_{current_time.strftime('%Y%m%d%H%M%S')}"
|
||||
order_direction = "BUY" if direction == "BUY" else "SELL"
|
||||
order = Order(id=order_id, symbol=self.symbol, direction=order_direction, volume=volume, price_type="MARKET",
|
||||
submitted_time=current_time, offset="OPEN")
|
||||
self.send_order(order)
|
||||
self.pos_meta[self.symbol] = {"direction": direction, "volume": volume, "entry_price": entry_price,
|
||||
"trailing_stop": initial_stop}
|
||||
self.log(
|
||||
f"发送开仓订单: {direction} {volume}手 @ Market Price (执行价约 {entry_price:.2f}), 初始ATR止损位: {initial_stop:.2f}")
|
||||
|
||||
def send_market_order(self, direction: str, volume: int):
|
||||
current_time = self.get_current_time()
|
||||
order_id = f"{self.symbol}_{direction}_{current_time.strftime('%Y%m%d%H%M%S')}"
|
||||
order = Order(id=order_id, symbol=self.symbol, direction=direction, volume=volume, price_type="MARKET",
|
||||
submitted_time=current_time, offset="CLOSE")
|
||||
self.send_order(order)
|
||||
self.log(f"发送平仓订单: {direction} {volume}手 @ Market Price")
|
||||
|
||||
def on_rollover(self, old_symbol: str, new_symbol: str):
|
||||
super().on_rollover(old_symbol, new_symbol)
|
||||
self.cancel_all_pending_orders(new_symbol)
|
||||
self.pos_meta.clear()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,226 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"id": "522f09ca7b3fe929",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2025-10-21T13:27:14.609968Z",
|
||||
"start_time": "2025-10-21T13:27:14.180365Z"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"from datetime import datetime\n",
|
||||
"\n",
|
||||
"from src.data_processing import load_raw_data\n",
|
||||
"%load_ext autoreload\n",
|
||||
"%autoreload 2\n",
|
||||
"\n",
|
||||
"import sys\n",
|
||||
"\n",
|
||||
"if '/mnt/d/PyProject/NewQuant/' not in sys.path:\n",
|
||||
" sys.path.append('/mnt/d/PyProject/NewQuant/')"
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": 1
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"id": "4f7e4b438cea750e",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2025-10-21T13:27:14.989537Z",
|
||||
"start_time": "2025-10-21T13:27:14.615874Z"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"from turtle import down\n",
|
||||
"from src.analysis.result_analyzer import ResultAnalyzer\n",
|
||||
"# 导入所有必要的模块\n",
|
||||
"from src.data_manager import DataManager\n",
|
||||
"from src.backtest_engine import BacktestEngine\n",
|
||||
"from src.indicators.indicator_list import INDICATOR_LIST\n",
|
||||
"from src.indicators.indicators import *\n",
|
||||
"\n",
|
||||
"# 导入您自己的 SMC 策略\n",
|
||||
"from src.strategies.TrendlineBreakoutStrategy.TrendlineHawkesStrategyFast import TrendlineHawkesStrategy\n",
|
||||
"\n",
|
||||
"# --- 配置参数 ---\n",
|
||||
"# 获取当前脚本所在目录,假设数据文件在项目根目录下的 data 文件夹内\n",
|
||||
"data_file_path = '/mnt/d/PyProject/NewQuant/data/data/KQ_m@CZCE_MA/KQ_m@CZCE_MA_min15.csv'\n"
|
||||
],
|
||||
"outputs": [],
|
||||
"execution_count": 2
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2025-10-21T13:27:15.060902Z",
|
||||
"start_time": "2025-10-21T13:27:14.996119Z"
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"\n",
|
||||
"initial_capital = 100000.0\n",
|
||||
"slippage_rate = 0.000 # 假设每笔交易0.1%的滑点\n",
|
||||
"commission_rate = 0.0000 # 假设每笔交易0.02%的佣金\n",
|
||||
"\n",
|
||||
"global_config = {\n",
|
||||
" 'symbol': 'KQ_m@CZCE_MA', # 确保与数据文件中的 symbol 匹配\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# 回测时间范围\n",
|
||||
"start_time = datetime(2021, 1, 1)\n",
|
||||
"end_time = datetime(2024, 6, 1)\n",
|
||||
"\n",
|
||||
"start_time = datetime(2024, 1, 1)\n",
|
||||
"end_time = datetime(2025, 8, 1)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"indicators = INDICATOR_LIST\n",
|
||||
"indicators = []\n",
|
||||
"\n",
|
||||
"# 确保 DataManager 能够重置以进行多次回测\n",
|
||||
"# data_manager.reset() # 首次运行不需要重置"
|
||||
],
|
||||
"id": "9ee53c41eaaefabb",
|
||||
"outputs": [],
|
||||
"execution_count": 3
|
||||
},
|
||||
{
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2025-10-21T13:27:21.876001Z",
|
||||
"start_time": "2025-10-21T13:27:15.070471Z"
|
||||
}
|
||||
},
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"from src.indicators.indicators import ROC_MA\n",
|
||||
"\n",
|
||||
"# --- 1. 初始化数据管理器 ---\n",
|
||||
"print(\"初始化数据管理器...\")\n",
|
||||
"data_manager = DataManager(file_path=data_file_path, symbol=global_config['symbol'], start_time=start_time,\n",
|
||||
" end_time=end_time)\n",
|
||||
"\n",
|
||||
"strategy_parameters = {\n",
|
||||
" 'main_symbol': 'MA', # <-- 替换为你的交易品种代码,例如 'GC=F' (黄金期货), 'ZC=F' (玉米期货)\n",
|
||||
" 'trade_volume': 1,\n",
|
||||
" 'trendline_n': 70,\n",
|
||||
" 'hawkes_kappa': 0.1,\n",
|
||||
" 'order_direction': ['SELL', 'BUY'],\n",
|
||||
" 'reverse_logic': True,\n",
|
||||
" # 'indicators': [RateOfChange(10, -2.1, -0.5), ROC_MA(10, 10, -2.7, -0.4)],\n",
|
||||
" 'enable_log': False\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# --- 2. 初始化回测引擎并运行 ---\n",
|
||||
"print(\"\\n初始化回测引擎...\")\n",
|
||||
"engine = BacktestEngine(\n",
|
||||
" data_manager=data_manager,\n",
|
||||
" strategy_class=TrendlineHawkesStrategy, # <--- 更改为您的 SMC 策略类\n",
|
||||
" # current_segment_symbol 参数已从 SMCPureH1LongStrategy 中移除,不需要设置\n",
|
||||
" strategy_params=strategy_parameters,\n",
|
||||
" initial_capital=initial_capital,\n",
|
||||
" slippage_rate=slippage_rate,\n",
|
||||
" commission_rate=commission_rate,\n",
|
||||
" roll_over_mode=True,\n",
|
||||
" start_time=start_time,\n",
|
||||
" end_time=end_time,\n",
|
||||
" indicators=indicators # 如果您的 SMC 策略不使用这些指标,也可以考虑移除\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(\"\\n开始运行回测...\")\n",
|
||||
"engine.run_backtest()\n",
|
||||
"print(\"\\n回测运行完毕。\")\n",
|
||||
"\n",
|
||||
"# --- 3. 获取回测结果 ---\n",
|
||||
"results = engine.get_backtest_results()\n",
|
||||
"portfolio_snapshots = results[\"portfolio_snapshots\"]\n",
|
||||
"trade_history = results[\"trade_history\"]\n",
|
||||
"initial_capital_result = results[\"initial_capital\"]\n",
|
||||
"bars = results[\"all_bars\"]\n",
|
||||
"\n",
|
||||
"# --- 4. 结果分析与可视化 ---\n",
|
||||
"if portfolio_snapshots:\n",
|
||||
" analyzer = ResultAnalyzer(portfolio_snapshots, trade_history, bars, initial_capital_result, INDICATOR_LIST)\n",
|
||||
"\n",
|
||||
" analyzer.generate_report()\n",
|
||||
" analyzer.plot_performance()\n",
|
||||
" metrics = analyzer.calculate_all_metrics()\n",
|
||||
" print(metrics)\n",
|
||||
"\n",
|
||||
" analyzer.analyze_indicators()\n",
|
||||
"else:\n",
|
||||
" print(\"\\n没有生成投资组合快照,无法进行结果分析。\")"
|
||||
],
|
||||
"id": "f903fd2761d446cd",
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"初始化数据管理器...\n",
|
||||
"数据加载成功: /mnt/d/PyProject/NewQuant/data/data/KQ_m@CZCE_MA/KQ_m@CZCE_MA_min15.csv\n",
|
||||
"数据范围从 2020-12-31 14:45:00 到 2025-08-21 14:30:00\n",
|
||||
"总计 25596 条记录。\n",
|
||||
"\n",
|
||||
"初始化回测引擎...\n",
|
||||
"模拟器初始化:初始资金=100000.00, 滑点率=0.0, 佣金率=0.0\n",
|
||||
"内存仓储已初始化,管理ID: 'src.strategies.TrendlineBreakoutStrategy.TrendlineHawkesStrategyFast.TrendlineHawkesStrategy_13b1be9c188912b2ee8ccd9e5fb0718d'\n",
|
||||
"\n",
|
||||
"--- 回测引擎初始化完成 ---\n",
|
||||
" 策略: TrendlineHawkesStrategy\n",
|
||||
" 初始资金: 100000.00\n",
|
||||
" 换月模式: 启用\n",
|
||||
"\n",
|
||||
"开始运行回测...\n",
|
||||
"\n",
|
||||
"--- 回测开始 ---\n",
|
||||
"TrendlineHawkesStrategy 策略初始化回调被调用。\n",
|
||||
"开始将 DataFrame 转换为 Bar 对象流...\n",
|
||||
"首次运行,正在初始化霍克斯状态和滚动窗口...\n",
|
||||
"状态初始化完成。\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"ename": "KeyboardInterrupt",
|
||||
"evalue": "",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001B[31m---------------------------------------------------------------------------\u001B[39m",
|
||||
"\u001B[31mKeyboardInterrupt\u001B[39m Traceback (most recent call last)",
|
||||
"\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[4]\u001B[39m\u001B[32m, line 38\u001B[39m\n\u001B[32m 23\u001B[39m engine = BacktestEngine(\n\u001B[32m 24\u001B[39m data_manager=data_manager,\n\u001B[32m 25\u001B[39m strategy_class=TrendlineHawkesStrategy, \u001B[38;5;66;03m# <--- 更改为您的 SMC 策略类\u001B[39;00m\n\u001B[32m (...)\u001B[39m\u001B[32m 34\u001B[39m indicators=indicators \u001B[38;5;66;03m# 如果您的 SMC 策略不使用这些指标,也可以考虑移除\u001B[39;00m\n\u001B[32m 35\u001B[39m )\n\u001B[32m 37\u001B[39m \u001B[38;5;28mprint\u001B[39m(\u001B[33m\"\u001B[39m\u001B[38;5;130;01m\\n\u001B[39;00m\u001B[33m开始运行回测...\u001B[39m\u001B[33m\"\u001B[39m)\n\u001B[32m---> \u001B[39m\u001B[32m38\u001B[39m \u001B[43mengine\u001B[49m\u001B[43m.\u001B[49m\u001B[43mrun_backtest\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 39\u001B[39m \u001B[38;5;28mprint\u001B[39m(\u001B[33m\"\u001B[39m\u001B[38;5;130;01m\\n\u001B[39;00m\u001B[33m回测运行完毕。\u001B[39m\u001B[33m\"\u001B[39m)\n\u001B[32m 41\u001B[39m \u001B[38;5;66;03m# --- 3. 获取回测结果 ---\u001B[39;00m\n",
|
||||
"\u001B[36mFile \u001B[39m\u001B[32m/mnt/d/PyProject/NewQuant/src/backtest_engine.py:166\u001B[39m, in \u001B[36mBacktestEngine.run_backtest\u001B[39m\u001B[34m(self)\u001B[39m\n\u001B[32m 163\u001B[39m \u001B[38;5;28mself\u001B[39m.strategy.on_open_bar(current_bar.open, current_bar.symbol)\n\u001B[32m 165\u001B[39m current_indicator_dict = {}\n\u001B[32m--> \u001B[39m\u001B[32m166\u001B[39m close_array = \u001B[43mnp\u001B[49m\u001B[43m.\u001B[49m\u001B[43marray\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43mclose_list\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 167\u001B[39m open_array = np.array(\u001B[38;5;28mself\u001B[39m.open_list)\n\u001B[32m 168\u001B[39m high_array = np.array(\u001B[38;5;28mself\u001B[39m.high_list)\n",
|
||||
"\u001B[31mKeyboardInterrupt\u001B[39m: "
|
||||
]
|
||||
}
|
||||
],
|
||||
"execution_count": 4
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "quant",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -0,0 +1,178 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
# 假设这些是你项目中的模块
|
||||
from src.core_data import Bar, Order
|
||||
from src.strategies.base_strategy import Strategy
|
||||
from src.algo.TrendLine import calculate_latest_trendline_values
|
||||
from src.algo.HawksProcess import calculate_hawkes_bands
|
||||
|
||||
|
||||
class TrendlineHawkesStrategy(Strategy):
|
||||
"""
|
||||
趋势线与霍克斯过程双重确认策略 (V2 - 支持逻辑反转):
|
||||
|
||||
入场信号 (双重确认):
|
||||
1. 趋势线事件: 收盘价突破上轨(标准模式做多)或下轨(标准模式做空)。
|
||||
2. 霍克斯确认: 同时,成交量霍克斯强度必须高于其近期高位分位数。
|
||||
|
||||
出场逻辑 (基于霍克斯过程):
|
||||
- 当成交量霍克斯强度从高位回落至近期低位分位数以下时,平仓。
|
||||
|
||||
逻辑反转 (`reverse_logic=True`):
|
||||
- 趋势线突破上轨时,开【空】仓。
|
||||
- 趋势线突破下轨时,开【多】仓。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: Any,
|
||||
main_symbol: str,
|
||||
trade_volume: int = 1,
|
||||
order_direction: Optional[List[str]] = None,
|
||||
# --- 新增: 逻辑反转开关 ---
|
||||
reverse_logic: bool = False,
|
||||
# --- 趋势线参数 ---
|
||||
trendline_n: int = 50,
|
||||
# --- 霍克斯过程参数 ---
|
||||
hawkes_kappa: float = 0.1,
|
||||
hawkes_lookback: int = 50,
|
||||
hawkes_entry_percent: float = 0.95,
|
||||
hawkes_exit_percent: float = 0.50,
|
||||
enable_log: bool = True,
|
||||
):
|
||||
super().__init__(context, main_symbol, enable_log)
|
||||
self.main_symbol = main_symbol
|
||||
self.trade_volume = trade_volume
|
||||
self.order_direction = order_direction or ["BUY", "SELL"]
|
||||
|
||||
# --- 新增 ---
|
||||
self.reverse_logic = reverse_logic
|
||||
|
||||
self.trendline_n = trendline_n
|
||||
self.hawkes_kappa = hawkes_kappa
|
||||
self.hawkes_lookback = hawkes_lookback
|
||||
self.hawkes_entry_percent = hawkes_entry_percent
|
||||
self.hawkes_exit_percent = hawkes_exit_percent
|
||||
|
||||
self.pos_meta: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
if self.trendline_n < 3:
|
||||
raise ValueError("trendline_n 必须大于或等于 3")
|
||||
|
||||
log_message = (
|
||||
f"TrendlineHawkesStrategy 初始化:\n"
|
||||
f"【逻辑模式】: {'反转 (Reversal)' if self.reverse_logic else '标准 (Breakout)'}\n"
|
||||
f"趋势线周期={self.trendline_n}\n"
|
||||
f"霍克斯参数: kappa={self.hawkes_kappa}, lookback={self.hawkes_lookback}, "
|
||||
f"entry_pct={self.hawkes_entry_percent}, exit_pct={self.hawkes_exit_percent}"
|
||||
)
|
||||
self.log(log_message)
|
||||
|
||||
def on_init(self):
|
||||
# (此函数保持不变)
|
||||
super().on_init()
|
||||
self.pos_meta.clear()
|
||||
|
||||
def on_open_bar(self, open_price: float, symbol: str):
|
||||
bar_history = self.get_bar_history()
|
||||
min_bars_required = max(self.trendline_n + 2, self.hawkes_lookback + 2)
|
||||
if len(bar_history) < min_bars_required:
|
||||
return
|
||||
|
||||
self.cancel_all_pending_orders(symbol)
|
||||
pos = self.get_current_positions().get(symbol, 0)
|
||||
|
||||
# --- 数据准备 (与之前相同) ---
|
||||
close_prices = np.array([b.close for b in bar_history])
|
||||
volume_series = pd.Series(
|
||||
[b.volume for b in bar_history],
|
||||
index=pd.to_datetime([b.datetime for b in bar_history])
|
||||
)
|
||||
|
||||
vol_hawkes, hawkes_upper_band, hawkes_lower_band = calculate_hawkes_bands(
|
||||
volume_series, self.hawkes_lookback, self.hawkes_kappa,
|
||||
self.hawkes_entry_percent, self.hawkes_exit_percent
|
||||
)
|
||||
|
||||
latest_hawkes_value = vol_hawkes.iloc[-1]
|
||||
latest_hawkes_upper = hawkes_upper_band.iloc[-1]
|
||||
latest_hawkes_lower = hawkes_lower_band.iloc[-1]
|
||||
|
||||
# 1. 优先处理平仓逻辑 (逻辑保持不变)
|
||||
meta = self.pos_meta.get(symbol)
|
||||
if meta and pos != 0:
|
||||
if latest_hawkes_value < latest_hawkes_lower:
|
||||
self.log(f"霍克斯出场信号: 强度({latest_hawkes_value:.2f}) < 阈值({latest_hawkes_lower:.2f})")
|
||||
self.send_market_order("CLOSE_LONG" if meta['direction'] == "BUY" else "CLOSE_SHORT", abs(pos))
|
||||
del self.pos_meta[symbol]
|
||||
return
|
||||
|
||||
# 2. 开仓逻辑 (加入反转判断)
|
||||
if pos == 0:
|
||||
prices_for_trendline = close_prices[-self.trendline_n - 1:-1]
|
||||
trend_upper, trend_lower = calculate_latest_trendline_values(prices_for_trendline)
|
||||
|
||||
if trend_upper is None or trend_lower is None:
|
||||
return
|
||||
|
||||
prev_close = bar_history[-2].close
|
||||
last_close = bar_history[-1].close
|
||||
|
||||
# --- a) 定义基础的突破【事件】 ---
|
||||
upper_break_event = last_close > trend_upper and prev_close < trend_upper
|
||||
lower_break_event = last_close < trend_lower and prev_close > trend_lower
|
||||
|
||||
# --- b) 定义霍克斯【确认】---
|
||||
hawkes_confirmation = latest_hawkes_value > latest_hawkes_upper
|
||||
|
||||
# 只有当基础事件和霍克斯确认都发生时,才考虑开仓
|
||||
if hawkes_confirmation and (upper_break_event or lower_break_event):
|
||||
|
||||
# --- c) 【核心修改】根据 reverse_logic 决定最终交易方向 ---
|
||||
trade_direction = None
|
||||
|
||||
if upper_break_event: # 价格向上突破上轨
|
||||
# 标准模式:做多 (动量)
|
||||
# 反转模式:做空 (力竭反转)
|
||||
trade_direction = "SELL" if self.reverse_logic else "BUY"
|
||||
|
||||
elif lower_break_event: # 价格向下突破下轨
|
||||
# 标准模式:做空 (动量)
|
||||
# 反转模式:做多 (恐慌探底反转)
|
||||
trade_direction = "BUY" if self.reverse_logic else "SELL"
|
||||
|
||||
# d) 执行交易
|
||||
if trade_direction and trade_direction in self.order_direction:
|
||||
event_type = "向上突破" if upper_break_event else "向下突破"
|
||||
logic_type = "反转" if self.reverse_logic else "标准"
|
||||
self.log(
|
||||
f"{logic_type}模式 {trade_direction} 信号: "
|
||||
f"价格{event_type} & 霍克斯强度({latest_hawkes_value:.2f}) > 阈值({latest_hawkes_upper:.2f})"
|
||||
)
|
||||
self.send_open_order(trade_direction, open_price, self.trade_volume)
|
||||
|
||||
# send_open_order, send_market_order, on_rollover 等方法保持不变
|
||||
def send_open_order(self, direction: str, entry_price: float, volume: int):
|
||||
current_time = self.get_current_time()
|
||||
order_id = f"{self.symbol}_{direction}_{current_time.strftime('%Y%m%d%H%M%S')}"
|
||||
order_direction = "BUY" if direction == "BUY" else "SELL"
|
||||
order = Order(id=order_id, symbol=self.symbol, direction=order_direction, volume=volume, price_type="MARKET",
|
||||
submitted_time=current_time, offset="OPEN")
|
||||
self.send_order(order)
|
||||
self.pos_meta[self.symbol] = {"direction": direction, "volume": volume, "entry_price": entry_price}
|
||||
self.log(f"发送开仓订单: {direction} {volume}手 @ Market Price (执行价约 {entry_price:.2f})")
|
||||
|
||||
def send_market_order(self, direction: str, volume: int):
|
||||
current_time = self.get_current_time()
|
||||
order_id = f"{self.symbol}_{direction}_{current_time.strftime('%Y%m%d%H%M%S')}"
|
||||
order = Order(id=order_id, symbol=self.symbol, direction=direction, volume=volume, price_type="MARKET",
|
||||
submitted_time=current_time, offset="CLOSE")
|
||||
self.send_order(order)
|
||||
self.log(f"发送平仓订单: {direction} {volume}手 @ Market Price")
|
||||
|
||||
def on_rollover(self, old_symbol: str, new_symbol: str):
|
||||
super().on_rollover(old_symbol, new_symbol)
|
||||
self.cancel_all_pending_orders(new_symbol)
|
||||
self.pos_meta.clear()
|
||||
@@ -0,0 +1,188 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
# 假设这些是你项目中的模块
|
||||
from src.core_data import Bar, Order
|
||||
from src.strategies.base_strategy import Strategy
|
||||
from src.algo.TrendLine import calculate_latest_trendline_values
|
||||
|
||||
|
||||
class TrendlineHawkesStrategy(Strategy):
|
||||
"""
|
||||
趋势线与霍克斯过程双重确认策略 (V4 - 终极性能版):
|
||||
- 霍克斯过程和滚动分位数都实现为高效的有状态增量计算。
|
||||
- 使用固定长度的Numpy数组作为滚动窗口,避免Pandas.rolling的开销和不一致性。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: Any,
|
||||
main_symbol: str,
|
||||
# ... 参数与V3完全相同 ...
|
||||
trade_volume: int = 1,
|
||||
order_direction: Optional[List[str]] = None,
|
||||
reverse_logic: bool = False,
|
||||
trendline_n: int = 50,
|
||||
hawkes_kappa: float = 0.1,
|
||||
hawkes_lookback: int = 50,
|
||||
hawkes_entry_percent: float = 0.95,
|
||||
hawkes_exit_percent: float = 0.50,
|
||||
enable_log: bool = True,
|
||||
):
|
||||
super().__init__(context, main_symbol, enable_log)
|
||||
# ... 参数赋值与V3完全相同 ...
|
||||
self.main_symbol = main_symbol
|
||||
self.trade_volume = trade_volume
|
||||
self.order_direction = order_direction or ["BUY", "SELL"]
|
||||
self.reverse_logic = reverse_logic
|
||||
self.trendline_n = trendline_n
|
||||
self.hawkes_kappa = hawkes_kappa
|
||||
self.hawkes_lookback = hawkes_lookback
|
||||
self.hawkes_entry_percent = hawkes_entry_percent
|
||||
self.hawkes_exit_percent = hawkes_exit_percent
|
||||
self.pos_meta: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
# --- 【核心修改】状态缓存重构 ---
|
||||
# 只缓存上一个时间点的霍克斯强度值 (未缩放)
|
||||
self._last_hawkes_unscaled: float = 0.0
|
||||
# 只维护一个固定长度的滚动窗口,用于计算分位数
|
||||
self._hawkes_window: np.ndarray = np.array([], dtype=np.float64)
|
||||
# 衰减因子
|
||||
self._hawkes_alpha = np.exp(-self.hawkes_kappa)
|
||||
|
||||
# ... 日志与V3相同 ...
|
||||
|
||||
def _initialize_state(self, initial_volumes: np.ndarray):
|
||||
"""
|
||||
仅在策略开始时调用一次,用于填充初始的滚动窗口。
|
||||
"""
|
||||
print("首次运行,正在初始化霍克斯状态和滚动窗口...")
|
||||
alpha = self._hawkes_alpha
|
||||
kappa = self.hawkes_kappa
|
||||
|
||||
# 完整计算一次历史强度,只为填充窗口
|
||||
temp_hawkes_history = np.zeros_like(initial_volumes, dtype=np.float64)
|
||||
if len(initial_volumes) > 0:
|
||||
temp_hawkes_history[0] = initial_volumes[0] if not np.isnan(initial_volumes[0]) else 0.0
|
||||
for i in range(1, len(initial_volumes)):
|
||||
temp_hawkes_history[i] = temp_hawkes_history[i - 1] * alpha + (
|
||||
initial_volumes[i] if not np.isnan(initial_volumes[i]) else 0.0)
|
||||
|
||||
# 记录最后一个点的强度值,作为下一次增量计算的起点
|
||||
self._last_hawkes_unscaled = temp_hawkes_history[-1] if len(temp_hawkes_history) > 0 else 0.0
|
||||
|
||||
# 用历史强度值的最后 hawkes_lookback 个点来填充滚动窗口
|
||||
self._hawkes_window = (temp_hawkes_history * kappa)[-self.hawkes_lookback:]
|
||||
print("状态初始化完成。")
|
||||
|
||||
def _update_state_incrementally(self, latest_volume: float):
|
||||
"""
|
||||
【增量计算】在每个新的Bar上调用,更新强度值和滚动窗口。
|
||||
"""
|
||||
# 1. 计算最新的霍克斯强度值 (未缩放)
|
||||
new_hawkes_unscaled = self._last_hawkes_unscaled * self._hawkes_alpha + (
|
||||
latest_volume if not np.isnan(latest_volume) else 0.0)
|
||||
|
||||
# 2. 更新上一个点的状态,为下一次计算做准备
|
||||
self._last_hawkes_unscaled = new_hawkes_unscaled
|
||||
|
||||
# 3. 将新的缩放后的强度值推入滚动窗口
|
||||
new_hawkes_scaled = new_hawkes_unscaled * self.hawkes_kappa
|
||||
|
||||
# np.roll 会高效地将数组元素移动,然后我们将新值放在第一个位置
|
||||
# 这比 append + delete 的效率高得多
|
||||
self._hawkes_window = np.roll(self._hawkes_window, -1)
|
||||
self._hawkes_window[-1] = new_hawkes_scaled
|
||||
|
||||
def on_init(self):
|
||||
super().on_init()
|
||||
self.pos_meta.clear()
|
||||
# 重置状态
|
||||
self._last_hawkes_unscaled = 0.0
|
||||
self._hawkes_window = np.array([], dtype=np.float64)
|
||||
|
||||
def on_open_bar(self, open_price: float, symbol: str):
|
||||
bar_history = self.get_bar_history()
|
||||
min_bars_required = max(self.trendline_n + 2, self.hawkes_lookback + 2)
|
||||
if len(bar_history) < min_bars_required:
|
||||
return
|
||||
|
||||
# --- 【核心修改】霍克斯过程的状态更新 ---
|
||||
# 检查是否是第一次运行
|
||||
if self._hawkes_window.size == 0:
|
||||
initial_volumes = np.array([b.volume for b in bar_history], dtype=float)
|
||||
self._initialize_state(initial_volumes[:-1]) # 用到上一根bar为止的数据初始化
|
||||
|
||||
# 增量更新当前bar的状态
|
||||
self._update_state_incrementally(float(bar_history[-1].volume))
|
||||
|
||||
# --- 后续逻辑使用更新后的状态进行计算 ---
|
||||
self.cancel_all_pending_orders(symbol)
|
||||
pos = self.get_current_positions().get(symbol, 0)
|
||||
|
||||
# 【核心修改】直接在固定长度的窗口上计算分位数
|
||||
# 这比pandas.rolling快几个数量级,且结果稳定
|
||||
latest_hawkes_value = self._hawkes_window[-1]
|
||||
latest_hawkes_upper = np.quantile(self._hawkes_window, self.hawkes_entry_percent)
|
||||
latest_hawkes_lower = np.quantile(self._hawkes_window, self.hawkes_exit_percent)
|
||||
|
||||
# 1. 平仓逻辑 (完全不变)
|
||||
meta = self.pos_meta.get(symbol)
|
||||
if meta and pos != 0:
|
||||
if latest_hawkes_value < latest_hawkes_lower:
|
||||
self.log(f"霍克斯出场信号...") # 日志简化
|
||||
self.send_market_order("CLOSE_LONG" if meta['direction'] == "BUY" else "CLOSE_SHORT", abs(pos))
|
||||
del self.pos_meta[symbol]
|
||||
return
|
||||
|
||||
# 2. 开仓逻辑 (完全不变)
|
||||
if pos == 0:
|
||||
close_prices = np.array([b.close for b in bar_history])
|
||||
prices_for_trendline = close_prices[-self.trendline_n - 1:-1]
|
||||
trend_upper, trend_lower = calculate_latest_trendline_values(prices_for_trendline)
|
||||
|
||||
if trend_upper is not None and trend_lower is not None:
|
||||
prev_close = bar_history[-2].close
|
||||
last_close = bar_history[-1].close
|
||||
upper_break_event = last_close > trend_upper and prev_close < trend_upper
|
||||
lower_break_event = last_close < trend_lower and prev_close > trend_lower
|
||||
hawkes_confirmation = latest_hawkes_value > latest_hawkes_upper
|
||||
|
||||
if hawkes_confirmation and (upper_break_event or lower_break_event):
|
||||
trade_direction = None
|
||||
if upper_break_event:
|
||||
trade_direction = "SELL" if self.reverse_logic else "BUY"
|
||||
elif lower_break_event:
|
||||
trade_direction = "BUY" if self.reverse_logic else "SELL"
|
||||
|
||||
if trade_direction and trade_direction in self.order_direction:
|
||||
self.log(f"开仓信号确认...") # 日志简化
|
||||
self.send_open_order(trade_direction, open_price, self.trade_volume)
|
||||
|
||||
# send_open_order, send_market_order, on_rollover 等方法保持不变
|
||||
# ... (代码省略,与之前版本相同) ...
|
||||
|
||||
# send_open_order, send_market_order, on_rollover 等方法保持不变
|
||||
def send_open_order(self, direction: str, entry_price: float, volume: int):
|
||||
current_time = self.get_current_time()
|
||||
order_id = f"{self.symbol}_{direction}_{current_time.strftime('%Y%m%d%H%M%S')}"
|
||||
order_direction = "BUY" if direction == "BUY" else "SELL"
|
||||
order = Order(id=order_id, symbol=self.symbol, direction=order_direction, volume=volume, price_type="MARKET",
|
||||
submitted_time=current_time, offset="OPEN")
|
||||
self.send_order(order)
|
||||
self.pos_meta[self.symbol] = {"direction": direction, "volume": volume, "entry_price": entry_price}
|
||||
self.log(f"发送开仓订单: {direction} {volume}手 @ Market Price (执行价约 {entry_price:.2f})")
|
||||
|
||||
def send_market_order(self, direction: str, volume: int):
|
||||
current_time = self.get_current_time()
|
||||
order_id = f"{self.symbol}_{direction}_{current_time.strftime('%Y%m%d%H%M%S')}"
|
||||
order = Order(id=order_id, symbol=self.symbol, direction=direction, volume=volume, price_type="MARKET",
|
||||
submitted_time=current_time, offset="CLOSE")
|
||||
self.send_order(order)
|
||||
self.log(f"发送平仓订单: {direction} {volume}手 @ Market Price")
|
||||
|
||||
def on_rollover(self, old_symbol: str, new_symbol: str):
|
||||
super().on_rollover(old_symbol, new_symbol)
|
||||
self.cancel_all_pending_orders(new_symbol)
|
||||
self.pos_meta.clear()
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user