1、vp策略-v2
This commit is contained in:
@@ -236,12 +236,12 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# 这种方式适合获取相对较短或中等长度的历史K线数据。
|
# 这种方式适合获取相对较短或中等长度的历史K线数据。
|
||||||
df_if_backtest_daily = collect_and_save_tqsdk_data_stream(
|
df_if_backtest_daily = collect_and_save_tqsdk_data_stream(
|
||||||
symbol="KQ.m@DCE.v",
|
symbol="KQ.m@CZCE.CF",
|
||||||
# symbol='SHFE.rb2510',
|
# symbol='SHFE.rb2510',
|
||||||
# symbol='KQ.i@SHFE.bu',
|
# symbol='KQ.i@SHFE.bu',
|
||||||
freq="min15",
|
freq="min15",
|
||||||
start_date_str="2021-01-01",
|
start_date_str="2021-01-01",
|
||||||
end_date_str="2025-09-20",
|
end_date_str="2025-10-20",
|
||||||
mode="backtest", # 指定为回测模式
|
mode="backtest", # 指定为回测模式
|
||||||
tq_user=TQ_USER_NAME,
|
tq_user=TQ_USER_NAME,
|
||||||
tq_pwd=TQ_PASSWORD,
|
tq_pwd=TQ_PASSWORD,
|
||||||
|
|||||||
@@ -120,7 +120,7 @@ class ResultAnalyzer:
|
|||||||
) # 明确标题
|
) # 明确标题
|
||||||
print("图表绘制完成。")
|
print("图表绘制完成。")
|
||||||
|
|
||||||
def analyze_indicators(self):
|
def analyze_indicators(self, profit_offset: float = 0.0) -> None:
|
||||||
"""
|
"""
|
||||||
分析所有平仓交易的指标值与实现盈亏的关系,并绘制累积盈亏曲线图。
|
分析所有平仓交易的指标值与实现盈亏的关系,并绘制累积盈亏曲线图。
|
||||||
图表将展示指标值区间与对应累积盈亏的关系,帮助找出具有概率优势的指标区间。
|
图表将展示指标值区间与对应累积盈亏的关系,帮助找出具有概率优势的指标区间。
|
||||||
@@ -156,7 +156,7 @@ class ResultAnalyzer:
|
|||||||
and np.isnan(trade.indicator_dict[indicator_name])
|
and np.isnan(trade.indicator_dict[indicator_name])
|
||||||
):
|
):
|
||||||
indi_values.append(trade.indicator_dict[indicator_name])
|
indi_values.append(trade.indicator_dict[indicator_name])
|
||||||
pnls.append(trade.realized_pnl)
|
pnls.append(trade.realized_pnl - profit_offset)
|
||||||
|
|
||||||
if not indi_values:
|
if not indi_values:
|
||||||
print(f"指标 '{indicator_name}' 没有对应的有效平仓交易数据。跳过绘图。")
|
print(f"指标 '{indicator_name}' 没有对应的有效平仓交易数据。跳过绘图。")
|
||||||
|
|||||||
@@ -3,92 +3,93 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, Any, Dict, TYPE_CHECKING
|
from typing import Optional, Any, Dict, TYPE_CHECKING
|
||||||
|
|
||||||
# 使用 TYPE_CHECKING 避免循环导入,只在类型检查时导入 BacktestEngine
|
from .state_repo import StateRepository
|
||||||
if TYPE_CHECKING:
|
|
||||||
from .backtest_engine import BacktestEngine
|
|
||||||
from .execution_simulator import ExecutionSimulator
|
from .execution_simulator import ExecutionSimulator
|
||||||
from .data_manager import DataManager
|
from .data_manager import DataManager
|
||||||
from .core_data import Bar, Order # 确保导入 Order
|
from .core_data import Bar, Order
|
||||||
|
|
||||||
|
|
||||||
class BacktestContext:
|
class BacktestContext:
|
||||||
"""
|
"""
|
||||||
回测上下文,用于连接策略与数据管理器、模拟器。
|
回测上下文,用于连接策略与数据管理器、模拟器和状态持久化。
|
||||||
策略通过此上下文与回测引擎进行交互。
|
策略通过此上下文与回测引擎进行交互。
|
||||||
"""
|
"""
|
||||||
def __init__(self, data_manager: 'DataManager', simulator: 'ExecutionSimulator'):
|
|
||||||
|
def __init__(self,
|
||||||
|
data_manager: 'DataManager',
|
||||||
|
simulator: 'ExecutionSimulator',
|
||||||
|
state_repository: 'StateRepository'): # MODIFIED: 新增参数
|
||||||
"""
|
"""
|
||||||
初始化回测上下文。
|
初始化回测上下文。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data_manager (DataManager): 数据管理器实例。
|
data_manager (DataManager): 数据管理器实例。
|
||||||
simulator (ExecutionSimulator): 交易模拟器实例。
|
simulator (ExecutionSimulator): 交易模拟器实例。
|
||||||
|
state_repository (StateRepository): 状态管理仓储实例,用于持久化策略状态。
|
||||||
"""
|
"""
|
||||||
self._data_manager = data_manager
|
self._data_manager = data_manager
|
||||||
self._simulator = simulator
|
self._simulator = simulator
|
||||||
|
self._state_repository = state_repository # NEW: 存储状态仓储实例
|
||||||
self._current_bar: Optional['Bar'] = None
|
self._current_bar: Optional['Bar'] = None
|
||||||
self._engine: Optional['BacktestEngine'] = None # 添加对引擎的引用
|
self._engine = None
|
||||||
|
|
||||||
|
# --- 新增:状态管理功能 ---
|
||||||
|
|
||||||
|
def save_state(self, state: Dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
保存策略的当前状态。
|
||||||
|
|
||||||
|
策略应在适当的时机(例如,每日结束、策略关闭时)调用此方法
|
||||||
|
来持久化其内部变量。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state (Dict[str, Any]): 包含策略状态的字典。
|
||||||
|
"""
|
||||||
|
self._state_repository.save(state)
|
||||||
|
|
||||||
|
def load_state(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
加载策略的历史状态。
|
||||||
|
|
||||||
|
策略应在初始化时调用此方法来恢复之前的运行状态。
|
||||||
|
如果不存在历史状态,将返回一个空字典。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: 包含策略历史状态的字典。
|
||||||
|
"""
|
||||||
|
return self._state_repository.load()
|
||||||
|
|
||||||
|
# --- 现有功能保持不变 ---
|
||||||
|
|
||||||
def set_current_bar(self, bar: 'Bar'):
|
def set_current_bar(self, bar: 'Bar'):
|
||||||
"""
|
|
||||||
设置当前正在处理的 K 线数据。
|
|
||||||
由 BacktestEngine 调用。
|
|
||||||
"""
|
|
||||||
self._current_bar = bar
|
self._current_bar = bar
|
||||||
|
|
||||||
def get_current_bar(self) -> Optional['Bar']:
|
def get_current_bar(self) -> Optional['Bar']:
|
||||||
"""
|
|
||||||
获取当前正在处理的 K 线数据。
|
|
||||||
策略可以通过此方法获取最新 K 线。
|
|
||||||
"""
|
|
||||||
return self._current_bar
|
return self._current_bar
|
||||||
|
|
||||||
def get_current_time(self) -> datetime:
|
def get_current_time(self) -> datetime:
|
||||||
"""
|
|
||||||
获取当前模拟时间。
|
|
||||||
"""
|
|
||||||
return self._simulator.get_current_time()
|
return self._simulator.get_current_time()
|
||||||
|
|
||||||
def get_current_positions(self) -> Dict[str, int]:
|
def get_current_positions(self) -> Dict[str, int]:
|
||||||
"""
|
|
||||||
获取当前所有持仓。
|
|
||||||
"""
|
|
||||||
return self._simulator.get_current_positions()
|
return self._simulator.get_current_positions()
|
||||||
|
|
||||||
def get_pending_orders(self) -> Dict[str, 'Order']:
|
def get_pending_orders(self) -> Dict[str, 'Order']:
|
||||||
"""
|
|
||||||
获取当前所有待处理(未成交)订单。
|
|
||||||
"""
|
|
||||||
return self._simulator.get_pending_orders()
|
return self._simulator.get_pending_orders()
|
||||||
|
|
||||||
def get_account_cash(self) -> float:
|
def get_account_cash(self) -> float:
|
||||||
"""
|
|
||||||
获取当前可用现金。
|
|
||||||
"""
|
|
||||||
return self._simulator.cash
|
return self._simulator.cash
|
||||||
|
|
||||||
def get_average_position_price(self, symbol: str) -> Optional[float]:
|
def get_average_position_price(self, symbol: str) -> Optional[float]:
|
||||||
"""
|
|
||||||
获取指定合约的平均持仓成本。
|
|
||||||
"""
|
|
||||||
return self._simulator.get_average_position_price(symbol)
|
return self._simulator.get_average_position_price(symbol)
|
||||||
|
|
||||||
def send_order(self, order: 'Order') -> Optional['Order']:
|
def send_order(self, order: 'Order') -> Optional['Order']:
|
||||||
"""
|
|
||||||
策略通过此方法发送订单到模拟器。
|
|
||||||
"""
|
|
||||||
return self._simulator.send_order_to_pending(order)
|
return self._simulator.send_order_to_pending(order)
|
||||||
|
|
||||||
def cancel_order(self, order_id: str) -> bool:
|
def cancel_order(self, order_id: str) -> bool:
|
||||||
"""
|
|
||||||
策略通过此方法取消指定ID的订单。
|
|
||||||
"""
|
|
||||||
return self._simulator.cancel_order(order_id)
|
return self._simulator.cancel_order(order_id)
|
||||||
|
|
||||||
def set_engine(self, engine: 'BacktestEngine'):
|
def set_engine(self, engine: 'BacktestEngine'):
|
||||||
"""
|
|
||||||
设置对 BacktestEngine 实例的引用。
|
|
||||||
由 BacktestEngine 在初始化时调用,用于允许 Context 访问 Engine 的状态。
|
|
||||||
"""
|
|
||||||
self._engine = engine
|
self._engine = engine
|
||||||
|
|
||||||
def get_bar_history(self):
|
def get_bar_history(self):
|
||||||
@@ -99,12 +100,6 @@ class BacktestContext:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def is_rollover_bar(self) -> bool:
|
def is_rollover_bar(self) -> bool:
|
||||||
"""
|
|
||||||
属性:判断当前 K 线是否为换月 K 线(即新合约的第一根 K 线)。
|
|
||||||
用于在换月时禁止策略开仓。
|
|
||||||
"""
|
|
||||||
if self._engine:
|
if self._engine:
|
||||||
return self._engine.is_rollover_bar
|
return self._engine.is_rollover_bar
|
||||||
# 如果没有设置引擎引用,默认不认为是换月 K 线
|
|
||||||
# 这通常发生在测试 Context 本身时,或 Engine 初始化不完整的情况。
|
|
||||||
return False
|
return False
|
||||||
@@ -5,12 +5,14 @@ import numpy as np
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
from src.indicators.base_indicators import Indicator
|
from src.indicators.base_indicators import Indicator
|
||||||
|
from .common_utils import generate_strategy_identifier
|
||||||
|
|
||||||
# 导入所有需要协调的模块
|
# 导入所有需要协调的模块
|
||||||
from .core_data import Bar, Order, Trade, PortfolioSnapshot
|
from .core_data import Bar, Order, Trade, PortfolioSnapshot
|
||||||
from .data_manager import DataManager
|
from .data_manager import DataManager
|
||||||
from .execution_simulator import ExecutionSimulator
|
from .execution_simulator import ExecutionSimulator
|
||||||
from .backtest_context import BacktestContext
|
from .backtest_context import BacktestContext
|
||||||
|
from .state_repo import MemoryStateRepository
|
||||||
from .strategies.base_strategy import Strategy
|
from .strategies.base_strategy import Strategy
|
||||||
|
|
||||||
class BacktestEngine:
|
class BacktestEngine:
|
||||||
@@ -50,7 +52,8 @@ class BacktestEngine:
|
|||||||
commission_rate=commission_rate
|
commission_rate=commission_rate
|
||||||
)
|
)
|
||||||
# 传入引擎自身给 context,以便 context 可以获取引擎的状态(如 is_rollover_bar)
|
# 传入引擎自身给 context,以便 context 可以获取引擎的状态(如 is_rollover_bar)
|
||||||
self.context = BacktestContext(self.data_manager, self.simulator)
|
identifier = generate_strategy_identifier(strategy_class, strategy_params)
|
||||||
|
self.context = BacktestContext(self.data_manager, self.simulator, MemoryStateRepository(identifier))
|
||||||
self.context.set_engine(self) # 建立 Context 到 Engine 的引用
|
self.context.set_engine(self) # 建立 Context 到 Engine 的引用
|
||||||
|
|
||||||
# self.current_segment_symbol = current_segment_symbol # 此行移除或作为内部变量动态管理
|
# self.current_segment_symbol = current_segment_symbol # 此行移除或作为内部变量动态管理
|
||||||
|
|||||||
@@ -212,3 +212,93 @@ def is_bar_pre_close_period(
|
|||||||
# 6. 判断当前系统时间是否在这个窗口内
|
# 6. 判断当前系统时间是否在这个窗口内
|
||||||
# 窗口定义为 [pre_close_window_start_time, final_bar_end_time),即包含开始时间,不包含结束时间
|
# 窗口定义为 [pre_close_window_start_time, final_bar_end_time),即包含开始时间,不包含结束时间
|
||||||
return pre_close_window_start_time <= current_system_time < final_bar_end_time
|
return pre_close_window_start_time <= current_system_time < final_bar_end_time
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
from typing import Type, Any, Dict, Union
|
||||||
|
|
||||||
|
# --- 辅助编码器,用于处理参数中的 “类” 对象 (保持不变) ---
|
||||||
|
class StrategyParamEncoder(json.JSONEncoder):
|
||||||
|
def default(self, o: Any) -> Any:
|
||||||
|
if isinstance(o, type):
|
||||||
|
return f"{o.__module__}.{o.__name__}"
|
||||||
|
return super().default(o)
|
||||||
|
|
||||||
|
# --- 新增:递归净化函数,用于移除实例 ---
|
||||||
|
def _clean_params_for_hashing(data: Any) -> Any:
|
||||||
|
"""
|
||||||
|
递归地“净化”参数数据,将所有非基本类型、非类的实例对象替换为 None。
|
||||||
|
这确保了只有可序列化的配置值会影响最终的哈希结果。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 输入的数据,可以是字典、列表、或任何值。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
一个净化后的版本,其中所有实例对象都被替换为 None。
|
||||||
|
"""
|
||||||
|
# 1. 基本情况:如果数据是基本类型或一个“类”,直接返回
|
||||||
|
if isinstance(data, (str, int, float, bool, type(None), type)):
|
||||||
|
return data
|
||||||
|
|
||||||
|
# 2. 递归情况:处理字典
|
||||||
|
if isinstance(data, dict):
|
||||||
|
# 遍历字典,对每个值进行递归净化
|
||||||
|
return {key: _clean_params_for_hashing(value) for key, value in data.items()}
|
||||||
|
|
||||||
|
# 3. 递归情况:处理列表和元组
|
||||||
|
if isinstance(data, (list, tuple)):
|
||||||
|
# 遍历序列,对每个元素进行递归净化
|
||||||
|
return type(data)([_clean_params_for_hashing(item) for item in data])
|
||||||
|
|
||||||
|
# 4. 最终情况:如果代码运行到这里,说明 `data` 是一个我们不希望
|
||||||
|
# 纳入哈希计算的实例对象。我们用 None 替换它。
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def generate_strategy_identifier(
|
||||||
|
strategy_class: Type,
|
||||||
|
parameters: Dict[str, Any]
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
为策略实例生成一个唯一的、确定性的标识符 (忽略实例版)。
|
||||||
|
|
||||||
|
此版本会主动忽略参数中的实例对象,在计算哈希前将它们替换为 None。
|
||||||
|
这允许策略接收已实例化的组件,同时确保标识符的唯一性仅由
|
||||||
|
可配置的、可序列化的参数决定。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
strategy_class (Type): 策略的类本身。
|
||||||
|
parameters (Dict[str, Any]): 用于初始化该策略的参数字典。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 一个唯一的、适合用作文件名或Redis键的标识符。
|
||||||
|
"""
|
||||||
|
# 1. 获取模块路径和类名
|
||||||
|
module_path = strategy_class.__module__
|
||||||
|
class_name = strategy_class.__name__
|
||||||
|
|
||||||
|
# 2. [核心修改] 首先,净化参数字典,移除实例对象
|
||||||
|
cleaned_parameters = _clean_params_for_hashing(parameters)
|
||||||
|
|
||||||
|
# 3. 对净化后的参数进行稳定序列化
|
||||||
|
# 这里仍然需要自定义的Encoder来处理参数中合法的“类”对象。
|
||||||
|
try:
|
||||||
|
param_string = json.dumps(
|
||||||
|
cleaned_parameters,
|
||||||
|
sort_keys=True,
|
||||||
|
separators=(',', ':'),
|
||||||
|
cls=StrategyParamEncoder
|
||||||
|
)
|
||||||
|
except TypeError as e:
|
||||||
|
# 理论上,经过净化后,这里不应该再出现TypeError,但作为保护性代码保留
|
||||||
|
raise TypeError(
|
||||||
|
f"净化后的策略 '{class_name}' 参数依然无法序列化。请检查参数结构。原始错误: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. 计算参数字符串的哈希值
|
||||||
|
param_hash = hashlib.md5(param_string.encode('utf-8')).hexdigest()
|
||||||
|
|
||||||
|
# 5. 组合成最终的标识符
|
||||||
|
identifier = f"{module_path}.{class_name}_{param_hash}"
|
||||||
|
|
||||||
|
return identifier
|
||||||
134
src/state_repo.py
Normal file
134
src/state_repo.py
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
class StateRepository(ABC):
|
||||||
|
"""
|
||||||
|
状态仓储的抽象基类 (优化版)。
|
||||||
|
每个仓储实例都与一个唯一标识符绑定,专门负责该标识符对应状态的持久化。
|
||||||
|
这种设计简化了API,并使其职责更加单一。
|
||||||
|
"""
|
||||||
|
def __init__(self, identifier: str):
|
||||||
|
"""
|
||||||
|
:param identifier: 此仓储实例绑定的唯一状态标识符。
|
||||||
|
"""
|
||||||
|
self.identifier = identifier
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def save(self, state: Dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
持久化当前的状态。
|
||||||
|
|
||||||
|
:param state: 需要保存的状态数据。
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
加载历史状态。
|
||||||
|
|
||||||
|
:return: 返回保存的状态字典。若无历史状态,则返回空字典。
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryStateRepository(StateRepository):
|
||||||
|
"""
|
||||||
|
一个完全基于内存的状态仓储。
|
||||||
|
状态仅在程序的生命周期内存在。适用于回测和单元测试。
|
||||||
|
"""
|
||||||
|
# 使用一个类级别的字典来模拟全局存储,以便在同一进程中多次创建同名实例时能找回状态
|
||||||
|
_global_storage: Dict[str, Dict[str, Any]] = {}
|
||||||
|
|
||||||
|
def __init__(self, identifier: str):
|
||||||
|
super().__init__(identifier)
|
||||||
|
# 确保该标识符在全局存储中有个位置
|
||||||
|
if self.identifier not in MemoryStateRepository._global_storage:
|
||||||
|
MemoryStateRepository._global_storage[self.identifier] = {}
|
||||||
|
print(f"内存仓储已初始化,管理ID: '{self.identifier}'")
|
||||||
|
|
||||||
|
def save(self, state: Dict[str, Any]) -> None:
|
||||||
|
# 使用 .copy() 存储副本,防止外部修改影响内部状态
|
||||||
|
MemoryStateRepository._global_storage[self.identifier] = state.copy()
|
||||||
|
|
||||||
|
def load(self) -> Dict[str, Any]:
|
||||||
|
# 返回副本,防止调用方意外修改内部状态
|
||||||
|
state = MemoryStateRepository._global_storage.get(self.identifier, {}).copy()
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
class JsonFileStateRepository(StateRepository):
|
||||||
|
"""
|
||||||
|
使用JSON文件作为后端的状态仓储 (简化版)。
|
||||||
|
|
||||||
|
本实现假设在任何时间点,最多只有一个进程会写入与本实例关联的
|
||||||
|
特定文件。因此,它采用直接写入文件的方式,没有包含处理并发
|
||||||
|
写入的原子操作逻辑。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, identifier: str, storage_path: str = './states'):
|
||||||
|
super().__init__(identifier)
|
||||||
|
self.storage_path = storage_path
|
||||||
|
self._file_path = os.path.join(self.storage_path, f"{self.identifier}.json")
|
||||||
|
|
||||||
|
if not os.path.exists(self.storage_path):
|
||||||
|
os.makedirs(self.storage_path)
|
||||||
|
print(f"JSON文件仓储(简化版)已初始化,将管理文件: '{self._file_path}'")
|
||||||
|
|
||||||
|
def save(self, state: Dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
将状态直接写入到本实例绑定的文件中。
|
||||||
|
注意:此操作不是原子的。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with open(self._file_path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(state, f, indent=4, ensure_ascii=False)
|
||||||
|
print(f"(JSON) 状态已直接保存至 '{self._file_path}'")
|
||||||
|
except (IOError, TypeError) as e:
|
||||||
|
print(f"错误:无法将状态写入文件 '{self._file_path}'。错误: {e}")
|
||||||
|
|
||||||
|
def load(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
从本实例绑定的文件中加载状态。
|
||||||
|
"""
|
||||||
|
if not os.path.exists(self._file_path):
|
||||||
|
return {}
|
||||||
|
try:
|
||||||
|
with open(self._file_path, 'r', encoding='utf-8') as f:
|
||||||
|
state = json.load(f)
|
||||||
|
print(f"(JSON) 从 '{self._file_path}' 加载了状态。")
|
||||||
|
return state
|
||||||
|
except (IOError, json.JSONDecodeError) as e:
|
||||||
|
print(f"错误:无法从 '{self._file_path}' 加载状态。文件可能已损坏。错误: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
import redis
|
||||||
|
class RedisStateRepository(StateRepository):
|
||||||
|
DEFAULT_EXPIRATION_SECONDS = 7 * 24 * 60 * 60
|
||||||
|
|
||||||
|
def __init__(self, identifier: str, host='localhost', port=6379, db=0, expiration_sec: int = DEFAULT_EXPIRATION_SECONDS):
|
||||||
|
super().__init__(identifier)
|
||||||
|
try:
|
||||||
|
self.redis_client = redis.Redis(host=host, port=port, db=db, decode_responses=True)
|
||||||
|
self.redis_client.ping()
|
||||||
|
self.expiration_seconds = expiration_sec
|
||||||
|
print(f"Redis仓储已连接,将管理Key: '{self.identifier}'")
|
||||||
|
except redis.exceptions.ConnectionError as e:
|
||||||
|
raise ConnectionError(f"无法连接到Redis服务器 at {host}:{port}。") from e
|
||||||
|
|
||||||
|
def save(self, state: Dict[str, Any]) -> None:
|
||||||
|
serialized_state = json.dumps(state)
|
||||||
|
self.redis_client.set(self.identifier, serialized_state, ex=self.expiration_seconds)
|
||||||
|
print(f"(Redis) Key '{self.identifier}' 的状态已保存。")
|
||||||
|
|
||||||
|
def load(self) -> Dict[str, Any]:
|
||||||
|
serialized_state = self.redis_client.get(self.identifier)
|
||||||
|
if serialized_state is None:
|
||||||
|
return {}
|
||||||
|
return json.loads(serialized_state)
|
||||||
File diff suppressed because one or more lines are too long
@@ -276,13 +276,3 @@ class ValueMigrationStrategy(Strategy):
|
|||||||
)
|
)
|
||||||
self.send_order(order)
|
self.send_order(order)
|
||||||
|
|
||||||
|
|
||||||
def send_market_order(self, direction: str, volume: int, offset: str = "CLOSE"):
|
|
||||||
# ... (与之前版本相同) ...
|
|
||||||
order_id = f"{self.symbol}_{direction}_{offset}_{self.get_current_time().strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
|
|
||||||
self.order_id_counter += 1
|
|
||||||
order = Order(
|
|
||||||
id=order_id, symbol=self.symbol, direction=direction, volume=volume,
|
|
||||||
price_type="MARKET", submitted_time=self.get_current_time(), offset=offset
|
|
||||||
)
|
|
||||||
self.send_order(order)
|
|
||||||
|
|||||||
315
src/strategies/ValueMigrationStrategy/ValueMigrationStrategy2.py
Normal file
315
src/strategies/ValueMigrationStrategy/ValueMigrationStrategy2.py
Normal file
@@ -0,0 +1,315 @@
|
|||||||
|
# =====================================================================================
|
||||||
|
# 以下是新增的 ValueMigrationStrategy 策略代码
|
||||||
|
# =====================================================================================
|
||||||
|
from collections import deque
|
||||||
|
from datetime import timedelta, time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from typing import List, Any, Optional, Dict
|
||||||
|
|
||||||
|
import talib
|
||||||
|
|
||||||
|
from src.core_data import Bar, Order
|
||||||
|
from src.strategies.ValueMigrationStrategy.data_class import ProfileStats, calculate_profile_from_bars
|
||||||
|
from src.strategies.base_strategy import Strategy
|
||||||
|
|
||||||
|
|
||||||
|
# = ===================================================================
|
||||||
|
# 全局辅助函数 (Global Helper Functions)
|
||||||
|
# 将这些函数放在文件顶部,以便所有策略类都能调用
|
||||||
|
# =====================================================================
|
||||||
|
|
||||||
|
def compute_price_volume_distribution(bars: List[Bar], tick_size: float) -> Optional[pd.Series]:
|
||||||
|
"""
|
||||||
|
[全局函数] 从K线数据中计算出原始的价格-成交量分布。
|
||||||
|
"""
|
||||||
|
if not bars:
|
||||||
|
return None
|
||||||
|
|
||||||
|
data = []
|
||||||
|
# 为了性能,我们只处理有限数量的bars,防止内存问题
|
||||||
|
# 在实际应用中,更高效的实现是必要的
|
||||||
|
for bar in bars[-500:]: # 添加一个安全限制
|
||||||
|
price_range = np.arange(bar.low, bar.high + tick_size, tick_size)
|
||||||
|
if len(price_range) == 0 or bar.volume == 0: continue
|
||||||
|
|
||||||
|
# 将成交量近似分布到K线覆盖的每个tick上
|
||||||
|
volume_per_tick = bar.volume / len(price_range)
|
||||||
|
for price in price_range:
|
||||||
|
data.append({'price': price, 'volume': volume_per_tick})
|
||||||
|
|
||||||
|
if not data:
|
||||||
|
return None
|
||||||
|
|
||||||
|
df = pd.DataFrame(data)
|
||||||
|
if df.empty:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return df.groupby('price')['volume'].sum().sort_index()
|
||||||
|
|
||||||
|
|
||||||
|
# 确保在文件顶部导入
|
||||||
|
from scipy.signal import find_peaks
|
||||||
|
|
||||||
|
|
||||||
|
def find_hvns_with_distance(price_volume_dist: pd.Series, distance_in_ticks: int) -> List[float]:
|
||||||
|
"""
|
||||||
|
[全局函数] 使用峰值查找算法,根据峰值间的最小距离来识别HVNs。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
price_volume_dist: 价格-成交量分布序列。
|
||||||
|
distance_in_ticks: 两个HVN之间必须间隔的最小tick数量。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
一个包含所有被识别出的HVN价格的列表。
|
||||||
|
"""
|
||||||
|
if price_volume_dist.empty or len(price_volume_dist) < 3:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# distance参数确保找到的峰值之间至少相隔N个点
|
||||||
|
peaks_indices, _ = find_peaks(price_volume_dist.values, distance=distance_in_ticks)
|
||||||
|
|
||||||
|
if len(peaks_indices) == 0:
|
||||||
|
return [price_volume_dist.idxmax()] # 默认返回POC
|
||||||
|
|
||||||
|
hvn_prices = price_volume_dist.index[peaks_indices].tolist()
|
||||||
|
return hvn_prices
|
||||||
|
|
||||||
|
|
||||||
|
def find_hvns_strict(price_volume_dist: pd.Series, window_radius: int) -> List[float]:
|
||||||
|
"""
|
||||||
|
[全局函数] 使用严格的“滚动窗口最大值”定义来识别HVNs。
|
||||||
|
一个点是HVN,当且仅当它的成交量大于其左右各 `window_radius` 个点的成交量。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
price_volume_dist: 价格-成交量分布序列。
|
||||||
|
window_radius: 定义了检查窗口的半径 (即您所说的 N)。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
一个包含所有被识别出的HVN价格的列表。
|
||||||
|
"""
|
||||||
|
if price_volume_dist.empty or window_radius == 0:
|
||||||
|
return [price_volume_dist.idxmax()] if not price_volume_dist.empty else []
|
||||||
|
|
||||||
|
# 1. 确保价格序列是连续的,用0填充缺失的ticks
|
||||||
|
full_price_range = np.arange(price_volume_dist.index.min(),
|
||||||
|
price_volume_dist.index.max() + price_volume_dist.index.to_series().diff().min(),
|
||||||
|
price_volume_dist.index.to_series().diff().min())
|
||||||
|
continuous_dist = price_volume_dist.reindex(full_price_range, fill_value=0)
|
||||||
|
|
||||||
|
# 2. 计算滚动窗口最大值
|
||||||
|
window_size = 2 * window_radius + 1
|
||||||
|
rolling_max = continuous_dist.rolling(window=window_size, center=True).max()
|
||||||
|
|
||||||
|
# 3. 找到那些自身成交量就等于其窗口最大值的点
|
||||||
|
is_hvn = (continuous_dist == rolling_max) & (continuous_dist > 0)
|
||||||
|
hvn_prices = continuous_dist[is_hvn].index.tolist()
|
||||||
|
|
||||||
|
# 4. 处理平顶山:如果连续多个点都是HVN,只保留中间那个
|
||||||
|
if not hvn_prices:
|
||||||
|
return [price_volume_dist.idxmax()] # 如果找不到,返回POC
|
||||||
|
|
||||||
|
final_hvns = []
|
||||||
|
i = 0
|
||||||
|
while i < len(hvn_prices):
|
||||||
|
# 找到一个连续HVN块
|
||||||
|
j = i
|
||||||
|
while j + 1 < len(hvn_prices) and (hvn_prices[j + 1] - hvn_prices[j]) < (
|
||||||
|
2 * price_volume_dist.index.to_series().diff().min()):
|
||||||
|
j += 1
|
||||||
|
|
||||||
|
# 取这个连续块的中间点
|
||||||
|
middle_index = i + (j - i) // 2
|
||||||
|
final_hvns.append(hvn_prices[middle_index])
|
||||||
|
|
||||||
|
i = j + 1
|
||||||
|
|
||||||
|
return final_hvns
|
||||||
|
|
||||||
|
|
||||||
|
# 确保在文件顶部导入
|
||||||
|
from scipy.signal import find_peaks
|
||||||
|
|
||||||
|
|
||||||
|
# =====================================================================================
|
||||||
|
# 以下是V2版本的、简化了状态管理的 HVNPullbackStrategy 代码
|
||||||
|
# =====================================================================================
|
||||||
|
|
||||||
|
class ValueMigrationStrategy(Strategy):
|
||||||
|
"""
|
||||||
|
一个基于动态HVN突破后回测的量化交易策略。(V2: 简化状态管理)
|
||||||
|
|
||||||
|
V2版本简化了内部状态管理,移除了基于order_id的复杂元数据传递,
|
||||||
|
使用更直接、更健壮的单一状态变量来处理挂单的止盈止损参数,
|
||||||
|
完美适配“单次单持仓”的策略逻辑。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
context: Any,
|
||||||
|
main_symbol: str,
|
||||||
|
enable_log: bool,
|
||||||
|
trade_volume: int,
|
||||||
|
tick_size: float = 1,
|
||||||
|
profile_period: int = 100,
|
||||||
|
recalc_interval: int = 4,
|
||||||
|
hvn_distance_ticks: int = 20,
|
||||||
|
entry_offset_atr: float = 0.0,
|
||||||
|
stop_loss_atr: float = 1.0,
|
||||||
|
take_profit_atr: float = 2.0,
|
||||||
|
atr_period: int = 14,
|
||||||
|
order_direction=None,
|
||||||
|
indicators=[None, None],
|
||||||
|
):
|
||||||
|
super().__init__(context, main_symbol, enable_log)
|
||||||
|
if order_direction is None:
|
||||||
|
order_direction = ['BUY', 'SELL']
|
||||||
|
|
||||||
|
self.trade_volume = trade_volume
|
||||||
|
self.tick_size = tick_size
|
||||||
|
self.profile_period = profile_period
|
||||||
|
self.recalc_interval = recalc_interval
|
||||||
|
self.hvn_distance_ticks = hvn_distance_ticks
|
||||||
|
self.entry_offset_atr = entry_offset_atr
|
||||||
|
self.stop_loss_atr = stop_loss_atr
|
||||||
|
self.take_profit_atr = take_profit_atr
|
||||||
|
self.atr_period = atr_period
|
||||||
|
|
||||||
|
self.order_direction = order_direction
|
||||||
|
self.indicator_long = indicators[0]
|
||||||
|
self.indicator_short = indicators[1]
|
||||||
|
|
||||||
|
self.main_symbol = main_symbol
|
||||||
|
self.order_id_counter = 0
|
||||||
|
|
||||||
|
self._bar_counter = 0
|
||||||
|
self._cached_hvns: List[float] = []
|
||||||
|
|
||||||
|
# --- V2: 简化的状态管理 ---
|
||||||
|
self._pending_sl_price: Optional[float] = None
|
||||||
|
self._pending_tp_price: Optional[float] = None
|
||||||
|
|
||||||
|
def on_open_bar(self, open_price: float, symbol: str):
|
||||||
|
self.symbol = symbol
|
||||||
|
self._bar_counter += 1
|
||||||
|
bar_history = self.get_bar_history()
|
||||||
|
|
||||||
|
required_len = max(self.profile_period, self.atr_period) + 1
|
||||||
|
if len(bar_history) < required_len:
|
||||||
|
return
|
||||||
|
|
||||||
|
# --- 1. 取消所有挂单并重置挂单状态 ---
|
||||||
|
self.cancel_all_pending_orders(self.symbol)
|
||||||
|
# self._pending_sl_price = None
|
||||||
|
# self._pending_tp_price = None
|
||||||
|
|
||||||
|
# --- 2. 管理现有持仓 ---
|
||||||
|
position_volume = self.get_current_positions().get(self.symbol, 0)
|
||||||
|
if position_volume != 0:
|
||||||
|
self.manage_open_position(position_volume, open_price)
|
||||||
|
return
|
||||||
|
|
||||||
|
# --- 3. 周期性地计算HVNs ---
|
||||||
|
if self._bar_counter % self.recalc_interval == 1:
|
||||||
|
profile_bars = bar_history[-self.profile_period:]
|
||||||
|
dist = compute_price_volume_distribution(profile_bars, self.tick_size)
|
||||||
|
if dist is not None and not dist.empty:
|
||||||
|
# self._cached_hvns = find_hvns_with_distance(dist, self.hvn_distance_ticks)
|
||||||
|
self._cached_hvns = find_hvns_strict(dist, self.hvn_distance_ticks)
|
||||||
|
self.log(f"New HVNs identified at: {[f'{p:.2f}' for p in self._cached_hvns]}")
|
||||||
|
|
||||||
|
if not self._cached_hvns: return
|
||||||
|
|
||||||
|
# --- 4. 评估新机会 (挂单逻辑) ---
|
||||||
|
self.evaluate_entry_signal(bar_history)
|
||||||
|
|
||||||
|
def manage_open_position(self, volume: int, current_price: float):
|
||||||
|
"""主动管理已开仓位的止盈止损。"""
|
||||||
|
|
||||||
|
# # [V2 关键逻辑]: 检测是否为新持仓
|
||||||
|
# # 如果这是一个新持仓,并且我们有预设的止盈止损,就将其存入
|
||||||
|
# if self._pending_sl_price is not None and self._pending_tp_price is not None:
|
||||||
|
# meta = {'sl_price': self._pending_sl_price, 'tp_price': self._pending_tp_price}
|
||||||
|
# self.position_meta = meta
|
||||||
|
# self.log(f"新持仓确认。已设置TP/SL: {meta}")
|
||||||
|
# else:
|
||||||
|
# # 这种情况理论上不应发生,但作为保护
|
||||||
|
# self.log("Error: New position detected but no pending TP/SL values found.")
|
||||||
|
# self.close_position("CLOSE_LONG" if volume > 0 else "CLOSE_SHORT", abs(volume))
|
||||||
|
# return
|
||||||
|
|
||||||
|
# [常规逻辑]: 检查止盈止损
|
||||||
|
sl_price = self._pending_sl_price
|
||||||
|
tp_price = self._pending_tp_price
|
||||||
|
|
||||||
|
if volume > 0: # 多头
|
||||||
|
if current_price <= sl_price or current_price >= tp_price:
|
||||||
|
action = "止损" if current_price <= sl_price else "止盈"
|
||||||
|
self.log(f"多头{action}触发 at {current_price:.2f}")
|
||||||
|
self.close_position("CLOSE_LONG", abs(volume))
|
||||||
|
elif volume < 0: # 空头
|
||||||
|
if current_price >= sl_price or current_price <= tp_price:
|
||||||
|
action = "止损" if current_price >= sl_price else "止盈"
|
||||||
|
self.log(f"空头{action}触发 at {current_price:.2f}")
|
||||||
|
self.close_position("CLOSE_SHORT", abs(volume))
|
||||||
|
|
||||||
|
def evaluate_entry_signal(self, bar_history: List[Bar]):
|
||||||
|
prev_close = bar_history[-2].close
|
||||||
|
current_close = bar_history[-1].close
|
||||||
|
|
||||||
|
highs = np.array([b.high for b in bar_history], dtype=float)
|
||||||
|
lows = np.array([b.low for b in bar_history], dtype=float)
|
||||||
|
closes = np.array([b.close for b in bar_history], dtype=float)
|
||||||
|
current_atr = talib.ATR(highs, lows, closes, self.atr_period)[-1]
|
||||||
|
if current_atr < self.tick_size: return
|
||||||
|
|
||||||
|
for hvn in sorted(self._cached_hvns):
|
||||||
|
# (为了简洁,买卖逻辑合并)
|
||||||
|
direction = None
|
||||||
|
if "BUY" in self.order_direction and (prev_close < hvn < current_close):
|
||||||
|
direction = "BUY"
|
||||||
|
pass_filter = self.indicator_long is None or self.indicator_long.is_condition_met(
|
||||||
|
*self.get_indicator_tuple())
|
||||||
|
elif "SELL" in self.order_direction and (prev_close > hvn > current_close):
|
||||||
|
direction = "SELL"
|
||||||
|
pass_filter = self.indicator_short is None or self.indicator_short.is_condition_met(
|
||||||
|
*self.get_indicator_tuple())
|
||||||
|
else:
|
||||||
|
continue # 没有触发穿越
|
||||||
|
|
||||||
|
if direction and pass_filter:
|
||||||
|
offset = self.entry_offset_atr * current_atr
|
||||||
|
limit_price = hvn + offset if direction == "BUY" else hvn - offset
|
||||||
|
|
||||||
|
self.log(f"价格穿越HVN({hvn:.2f}). 在 {limit_price:.2f} 挂限价{direction}单。")
|
||||||
|
self.send_hvn_limit_order(direction, limit_price, current_atr)
|
||||||
|
return # 每次只挂一个单
|
||||||
|
|
||||||
|
def send_hvn_limit_order(self, direction: str, limit_price: float, entry_atr: float):
|
||||||
|
# [V2 关键逻辑]: 直接更新实例变量
|
||||||
|
self._pending_sl_price = limit_price - self.stop_loss_atr * entry_atr if direction == "BUY" else limit_price + self.stop_loss_atr * entry_atr
|
||||||
|
self._pending_tp_price = limit_price + self.take_profit_atr * entry_atr if direction == "BUY" else limit_price - self.take_profit_atr * entry_atr
|
||||||
|
|
||||||
|
order_id = f"{self.symbol}_{direction}_LIMIT_{self.order_id_counter}"
|
||||||
|
self.order_id_counter += 1
|
||||||
|
|
||||||
|
order = Order(
|
||||||
|
id=order_id, symbol=self.symbol, direction=direction, volume=self.trade_volume,
|
||||||
|
price_type="LIMIT", limit_price=limit_price, submitted_time=self.get_current_time(),
|
||||||
|
offset="OPEN"
|
||||||
|
)
|
||||||
|
self.send_order(order)
|
||||||
|
|
||||||
|
def close_position(self, direction: str, volume: int):
|
||||||
|
self.send_market_order(direction, volume)
|
||||||
|
|
||||||
|
|
||||||
|
def send_market_order(self, direction: str, volume: int, offset: str = "CLOSE"):
|
||||||
|
order_id = f"{self.symbol}_{direction}_{offset}_{self.get_current_time().strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
|
||||||
|
self.order_id_counter += 1
|
||||||
|
order = Order(
|
||||||
|
id=order_id, symbol=self.symbol, direction=direction, volume=volume,
|
||||||
|
price_type="MARKET", submitted_time=self.get_current_time(), offset=offset
|
||||||
|
)
|
||||||
|
self.send_order(order)
|
||||||
317
src/strategies/ValueMigrationStrategy/ValueMigrationStrategy3.py
Normal file
317
src/strategies/ValueMigrationStrategy/ValueMigrationStrategy3.py
Normal file
@@ -0,0 +1,317 @@
|
|||||||
|
# =====================================================================================
|
||||||
|
# 以下是新增的 ValueMigrationStrategy 策略代码
|
||||||
|
# =====================================================================================
|
||||||
|
from collections import deque
|
||||||
|
from datetime import timedelta, time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from typing import List, Any, Optional, Dict
|
||||||
|
|
||||||
|
import talib
|
||||||
|
|
||||||
|
from src.core_data import Bar, Order
|
||||||
|
from src.strategies.ValueMigrationStrategy.data_class import ProfileStats, calculate_profile_from_bars
|
||||||
|
from src.strategies.base_strategy import Strategy
|
||||||
|
|
||||||
|
|
||||||
|
# = ===================================================================
|
||||||
|
# 全局辅助函数 (Global Helper Functions)
|
||||||
|
# 将这些函数放在文件顶部,以便所有策略类都能调用
|
||||||
|
# =====================================================================
|
||||||
|
|
||||||
|
def compute_price_volume_distribution(bars: List[Bar], tick_size: float) -> Optional[pd.Series]:
|
||||||
|
"""
|
||||||
|
[全局函数] 从K线数据中计算出原始的价格-成交量分布。
|
||||||
|
"""
|
||||||
|
if not bars:
|
||||||
|
return None
|
||||||
|
|
||||||
|
data = []
|
||||||
|
# 为了性能,我们只处理有限数量的bars,防止内存问题
|
||||||
|
# 在实际应用中,更高效的实现是必要的
|
||||||
|
for bar in bars[-500:]: # 添加一个安全限制
|
||||||
|
price_range = np.arange(bar.low, bar.high + tick_size, tick_size)
|
||||||
|
if len(price_range) == 0 or bar.volume == 0: continue
|
||||||
|
|
||||||
|
# 将成交量近似分布到K线覆盖的每个tick上
|
||||||
|
volume_per_tick = bar.volume / len(price_range)
|
||||||
|
for price in price_range:
|
||||||
|
data.append({'price': price, 'volume': volume_per_tick})
|
||||||
|
|
||||||
|
if not data:
|
||||||
|
return None
|
||||||
|
|
||||||
|
df = pd.DataFrame(data)
|
||||||
|
if df.empty:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return df.groupby('price')['volume'].sum().sort_index()
|
||||||
|
|
||||||
|
|
||||||
|
# 确保在文件顶部导入
|
||||||
|
from scipy.signal import find_peaks
|
||||||
|
|
||||||
|
|
||||||
|
def find_hvns_with_distance(price_volume_dist: pd.Series, distance_in_ticks: int) -> List[float]:
|
||||||
|
"""
|
||||||
|
[全局函数] 使用峰值查找算法,根据峰值间的最小距离来识别HVNs。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
price_volume_dist: 价格-成交量分布序列。
|
||||||
|
distance_in_ticks: 两个HVN之间必须间隔的最小tick数量。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
一个包含所有被识别出的HVN价格的列表。
|
||||||
|
"""
|
||||||
|
if price_volume_dist.empty or len(price_volume_dist) < 3:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# distance参数确保找到的峰值之间至少相隔N个点
|
||||||
|
peaks_indices, _ = find_peaks(price_volume_dist.values, distance=distance_in_ticks)
|
||||||
|
|
||||||
|
if len(peaks_indices) == 0:
|
||||||
|
return [price_volume_dist.idxmax()] # 默认返回POC
|
||||||
|
|
||||||
|
hvn_prices = price_volume_dist.index[peaks_indices].tolist()
|
||||||
|
return hvn_prices
|
||||||
|
|
||||||
|
|
||||||
|
# 确保在文件顶部导入
|
||||||
|
from scipy.signal import find_peaks
|
||||||
|
|
||||||
|
|
||||||
|
# =====================================================================================
|
||||||
|
# 以下是V2版本的、简化了状态管理的 HVNPullbackStrategy 代码
|
||||||
|
# =====================================================================================
|
||||||
|
|
||||||
|
# 引入必要的类型,确保代码清晰
|
||||||
|
from typing import Any, Dict, Optional, List
|
||||||
|
import numpy as np
|
||||||
|
import talib
|
||||||
|
|
||||||
|
|
||||||
|
class ValueMigrationStrategy(Strategy):
|
||||||
|
"""
|
||||||
|
一个基于动态HVN突破后回测的量化交易策略。(V3: 集成上下文状态管理)
|
||||||
|
|
||||||
|
V3版本完全集成BacktestContext的状态管理功能,实现了策略重启后的状态恢复。
|
||||||
|
- 状态被简化为两个核心变量:_pending_sl_price 和 _pending_tp_price。
|
||||||
|
- 在策略初始化时安全地加载状态,并兼容空状态或旧版状态。
|
||||||
|
- 在下单或平仓时立即持久化状态,确保数据一致性。
|
||||||
|
- 增加了逻辑检查,处理重启后可能出现的状态与实际持仓不一致的问题。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
context: Any, # 通常会是 BacktestContext
|
||||||
|
main_symbol: str,
|
||||||
|
enable_log: bool,
|
||||||
|
trade_volume: int,
|
||||||
|
tick_size: float = 1,
|
||||||
|
profile_period: int = 100,
|
||||||
|
recalc_interval: int = 4,
|
||||||
|
hvn_distance_ticks: int = 1,
|
||||||
|
entry_offset_atr: float = 0.0,
|
||||||
|
stop_loss_atr: float = 1.0,
|
||||||
|
take_profit_atr: float = 1.0,
|
||||||
|
atr_period: int = 14,
|
||||||
|
order_direction=None,
|
||||||
|
indicators=[None, None],
|
||||||
|
):
|
||||||
|
super().__init__(context, main_symbol, enable_log)
|
||||||
|
# --- 参数初始化 (保持不变) ---
|
||||||
|
if order_direction is None:
|
||||||
|
order_direction = ['BUY', 'SELL']
|
||||||
|
self.trade_volume = trade_volume
|
||||||
|
self.tick_size = tick_size
|
||||||
|
self.profile_period = profile_period
|
||||||
|
self.recalc_interval = recalc_interval
|
||||||
|
self.hvn_distance_ticks = hvn_distance_ticks
|
||||||
|
self.entry_offset_atr = entry_offset_atr
|
||||||
|
self.stop_loss_atr = stop_loss_atr
|
||||||
|
self.take_profit_atr = take_profit_atr
|
||||||
|
self.atr_period = atr_period
|
||||||
|
self.order_direction = order_direction
|
||||||
|
self.indicator_long = indicators[0]
|
||||||
|
self.indicator_short = indicators[1]
|
||||||
|
self.main_symbol = main_symbol
|
||||||
|
self.order_id_counter = 0
|
||||||
|
self._bar_counter = 0
|
||||||
|
self._cached_hvns: List[float] = []
|
||||||
|
|
||||||
|
# --- 新增: 初始化时加载状态 ---
|
||||||
|
self._pending_sl_price: Optional[float] = None
|
||||||
|
self._pending_tp_price: Optional[float] = None
|
||||||
|
self._load_state_from_context()
|
||||||
|
|
||||||
|
def _get_state_dict(self) -> Dict[str, Any]:
|
||||||
|
"""一个辅助函数,用于生成当前需要保存的状态字典。"""
|
||||||
|
return {
|
||||||
|
"_pending_sl_price": self._pending_sl_price,
|
||||||
|
"_pending_tp_price": self._pending_tp_price,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _load_state_from_context(self):
|
||||||
|
"""
|
||||||
|
[新增] 从上下文中加载状态,并进行健壮性处理。
|
||||||
|
"""
|
||||||
|
loaded_state = self.context.load_state()
|
||||||
|
if not loaded_state:
|
||||||
|
self.log("未找到历史状态,进行全新初始化。")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 使用 .get() 方法安全地读取,即使key不存在或state为空也不会报错。
|
||||||
|
# 这完美解决了“读取的state的key不一样”的问题。
|
||||||
|
self._pending_sl_price = loaded_state.get("_pending_sl_price")
|
||||||
|
self._pending_tp_price = loaded_state.get("_pending_tp_price")
|
||||||
|
|
||||||
|
if self._pending_sl_price is not None:
|
||||||
|
self.log(f"成功从上下文加载状态: SL={self._pending_sl_price}, TP={self._pending_tp_price}")
|
||||||
|
else:
|
||||||
|
self.log("加载的状态为空或格式不兼容,视为全新初始化。")
|
||||||
|
|
||||||
|
def on_open_bar(self, open_price: float, symbol: str):
|
||||||
|
self.symbol = symbol
|
||||||
|
self._bar_counter += 1
|
||||||
|
bar_history = self.get_bar_history()
|
||||||
|
|
||||||
|
required_len = max(self.profile_period, self.atr_period) + 1
|
||||||
|
if len(bar_history) < required_len:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 取消所有挂单,这符合原逻辑,确保每根bar都是新的开始
|
||||||
|
self.cancel_all_pending_orders(self.symbol)
|
||||||
|
|
||||||
|
position_volume = self.get_current_positions().get(self.symbol, 0)
|
||||||
|
|
||||||
|
# --- 新增: 状态一致性检查 ---
|
||||||
|
# 场景:策略重启后,加载了之前的止盈止损状态,但发现实际上并没有持仓
|
||||||
|
# (可能因为上次平仓后、清空状态前程序就关闭了)。
|
||||||
|
# 这种情况下,状态是无效的“幽灵状态”,必须清除。
|
||||||
|
if position_volume == 0 and self._pending_sl_price is not None:
|
||||||
|
self.log("检测到状态与实际持仓不符 (有状态但无持仓),重置本地状态。")
|
||||||
|
self._pending_sl_price = None
|
||||||
|
self._pending_tp_price = None
|
||||||
|
self.context.save_state(self._get_state_dict()) # 立即同步清除后的状态
|
||||||
|
|
||||||
|
# --- 1. 管理现有持仓 (如果存在) ---
|
||||||
|
if position_volume != 0:
|
||||||
|
self.manage_open_position(position_volume, open_price)
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
# 周期性地计算HVNs
|
||||||
|
if self._bar_counter % self.recalc_interval == 1:
|
||||||
|
profile_bars = bar_history[-self.profile_period:]
|
||||||
|
dist = compute_price_volume_distribution(profile_bars, self.tick_size)
|
||||||
|
if dist is not None and not dist.empty:
|
||||||
|
self._cached_hvns = find_hvns_with_distance(dist, self.hvn_distance_ticks)
|
||||||
|
self.log(f"识别到新的高价值节点: {[f'{p:.2f}' for p in self._cached_hvns]}")
|
||||||
|
|
||||||
|
if not self._cached_hvns: return
|
||||||
|
|
||||||
|
# 评估新机会 (挂单逻辑)
|
||||||
|
self.evaluate_entry_signal(bar_history)
|
||||||
|
|
||||||
|
def manage_open_position(self, volume: int, current_price: float):
|
||||||
|
"""
|
||||||
|
[修改] 主动管理已开仓位的止盈止损。
|
||||||
|
不再使用 position_meta,直接依赖实例变量。
|
||||||
|
"""
|
||||||
|
# [关键安全检查]: 如果有持仓,但却没有止盈止损状态,这是一个危险的信号。
|
||||||
|
# 可能是状态文件损坏或逻辑错误。为控制风险,应立即平仓。
|
||||||
|
if self._pending_sl_price is None or self._pending_tp_price is None:
|
||||||
|
self.log("风险警告:存在持仓但无有效的止盈止损价格,立即市价平仓!")
|
||||||
|
self.close_position("CLOSE_LONG" if volume > 0 else "CLOSE_SHORT", abs(volume))
|
||||||
|
return
|
||||||
|
|
||||||
|
sl_price = self._pending_sl_price
|
||||||
|
tp_price = self._pending_tp_price
|
||||||
|
|
||||||
|
# 止盈止损逻辑 (保持不变)
|
||||||
|
if volume > 0: # 多头
|
||||||
|
if current_price <= sl_price or current_price >= tp_price:
|
||||||
|
action = "止损" if current_price <= sl_price else "止盈"
|
||||||
|
self.log(f"多头{action}触发于 {current_price:.2f} (SL: {sl_price}, TP: {tp_price})")
|
||||||
|
self.close_position("CLOSE_LONG", abs(volume))
|
||||||
|
elif volume < 0: # 空头
|
||||||
|
if current_price >= sl_price or current_price <= tp_price:
|
||||||
|
action = "止损" if current_price >= sl_price else "止盈"
|
||||||
|
self.log(f"空头{action}触发于 {current_price:.2f} (SL: {sl_price}, TP: {tp_price})")
|
||||||
|
self.close_position("CLOSE_SHORT", abs(volume))
|
||||||
|
|
||||||
|
def evaluate_entry_signal(self, bar_history: List[Bar]):
|
||||||
|
# [修改] 在挂单前,先重置旧的挂单状态,虽然on_open_bar开头也做了,但这里更保险
|
||||||
|
self._pending_sl_price = None
|
||||||
|
self._pending_tp_price = None
|
||||||
|
|
||||||
|
# ... 原有挂单信号计算逻辑保持不变 ...
|
||||||
|
prev_close = bar_history[-2].close
|
||||||
|
current_close = bar_history[-1].close
|
||||||
|
highs = np.array([b.high for b in bar_history], dtype=float)
|
||||||
|
lows = np.array([b.low for b in bar_history], dtype=float)
|
||||||
|
closes = np.array([b.close for b in bar_history], dtype=float)
|
||||||
|
current_atr = talib.ATR(highs, lows, closes, self.atr_period)[-1]
|
||||||
|
if current_atr < self.tick_size: return
|
||||||
|
|
||||||
|
for hvn in sorted(self._cached_hvns):
|
||||||
|
if "BUY" in self.order_direction and (prev_close < hvn < current_close):
|
||||||
|
direction = "BUY"
|
||||||
|
pass_filter = self.indicator_long is None or self.indicator_long.is_condition_met(
|
||||||
|
*self.get_indicator_tuple())
|
||||||
|
elif "SELL" in self.order_direction and (prev_close > hvn > current_close):
|
||||||
|
direction = "SELL"
|
||||||
|
pass_filter = self.indicator_short is None or self.indicator_short.is_condition_met(
|
||||||
|
*self.get_indicator_tuple())
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if direction and pass_filter:
|
||||||
|
offset = self.entry_offset_atr * current_atr
|
||||||
|
limit_price = hvn + offset if direction == "BUY" else hvn - offset
|
||||||
|
self.log(f"价格穿越HVN({hvn:.2f})。在 {limit_price:.2f} 挂限价{direction}单。")
|
||||||
|
self.send_hvn_limit_order(direction, limit_price, current_atr)
|
||||||
|
return
|
||||||
|
|
||||||
|
def send_hvn_limit_order(self, direction: str, limit_price: float, entry_atr: float):
|
||||||
|
# 1. 设置实例的止盈止损状态
|
||||||
|
self._pending_sl_price = limit_price - self.stop_loss_atr * entry_atr if direction == "BUY" else limit_price + self.stop_loss_atr * entry_atr
|
||||||
|
self._pending_tp_price = limit_price + self.take_profit_atr * entry_atr if direction == "BUY" else limit_price - self.take_profit_atr * entry_atr
|
||||||
|
|
||||||
|
# 2. [新增] 状态已更新,立即通过上下文持久化
|
||||||
|
self.context.save_state(self._get_state_dict())
|
||||||
|
self.log(f"状态已更新并保存: SL={self._pending_sl_price}, TP={self._pending_tp_price}")
|
||||||
|
|
||||||
|
# 3. 发送订单
|
||||||
|
order_id = f"{self.symbol}_{direction}_LIMIT_{self.order_id_counter}"
|
||||||
|
self.order_id_counter += 1
|
||||||
|
order = Order(
|
||||||
|
id=order_id, symbol=self.symbol, direction=direction, volume=self.trade_volume,
|
||||||
|
price_type="LIMIT", limit_price=limit_price, submitted_time=self.get_current_time(),
|
||||||
|
offset="OPEN"
|
||||||
|
)
|
||||||
|
self.send_order(order)
|
||||||
|
|
||||||
|
def close_position(self, direction: str, volume: int):
|
||||||
|
"""[修改] 平仓时,必须清空状态并立即保存。"""
|
||||||
|
# 1. 发送平仓市价单
|
||||||
|
self.send_market_order(direction, volume)
|
||||||
|
|
||||||
|
# 2. 清空本地的止盈止损状态
|
||||||
|
self._pending_sl_price = None
|
||||||
|
self._pending_tp_price = None
|
||||||
|
|
||||||
|
# 3. [新增] 状态已清空,立即通过上下文持久化这个“空状态”
|
||||||
|
self.context.save_state(self._get_state_dict())
|
||||||
|
self.log("持仓已平,相关的止盈止损状态已清空并保存。")
|
||||||
|
|
||||||
|
def send_market_order(self, direction: str, volume: int, offset: str = "CLOSE"):
|
||||||
|
# ... 此辅助函数保持不变 ...
|
||||||
|
order_id = f"{self.symbol}_{direction}_{offset}_{self.get_current_time().strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
|
||||||
|
self.order_id_counter += 1
|
||||||
|
order = Order(
|
||||||
|
id=order_id, symbol=self.symbol, direction=direction, volume=volume,
|
||||||
|
price_type="MARKET", submitted_time=self.get_current_time(), offset=offset
|
||||||
|
)
|
||||||
|
self.send_order(order)
|
||||||
Reference in New Issue
Block a user