简单波动率策略,实现+网格搜索
This commit is contained in:
1012
data/ analysis/Volume.ipynb
Normal file
1012
data/ analysis/Volume.ipynb
Normal file
File diff suppressed because one or more lines are too long
14
data/tqsdk/test.py
Normal file
14
data/tqsdk/test.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from tqsdk import TqApi, TqAuth
|
||||
|
||||
api = TqApi(auth=TqAuth("emanresu", "dfgvfgdfgg"))
|
||||
|
||||
# au 品种指数合约
|
||||
ls = api.query_quotes(ins_class="INDEX", product_id="au")
|
||||
print(ls)
|
||||
|
||||
# au 品种主连合约
|
||||
ls = api.query_quotes(ins_class="CONT", product_id="au")
|
||||
print(ls)
|
||||
|
||||
# 关闭api,释放相应资源
|
||||
api.close()
|
||||
@@ -25,7 +25,7 @@ def collect_and_save_tqsdk_data_stream(
|
||||
output_dir: str = "../data",
|
||||
tq_user: str = TQ_USER_NAME,
|
||||
tq_pwd: str = TQ_PASSWORD
|
||||
) -> pd.DataFrame or None:
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
通过 TqSdk 在指定模式下(回测或模拟)运行,监听并收集指定品种、频率、日期范围的K线数据流,
|
||||
并将其保存到本地CSV文件。此函数会模拟 TqSdk 的时间流运行。
|
||||
@@ -190,10 +190,10 @@ if __name__ == "__main__":
|
||||
# 示例1: 在回测模式下获取沪深300指数主连的日线数据 (用于历史回测)
|
||||
# 这种方式适合获取相对较短或中等长度的历史K线数据。
|
||||
df_if_backtest_daily = collect_and_save_tqsdk_data_stream(
|
||||
symbol="SHFE.rb2410",
|
||||
freq="min60",
|
||||
start_date_str="2024-05-01",
|
||||
end_date_str="2024-09-01",
|
||||
symbol="KQ.i@SHFE.rb",
|
||||
freq="day",
|
||||
start_date_str="2023-01-01",
|
||||
end_date_str="2025-05-01",
|
||||
mode="backtest", # 指定为回测模式
|
||||
tq_user=TQ_USER_NAME,
|
||||
tq_pwd=TQ_PASSWORD
|
||||
|
||||
12953
grid_search.ipynb
Normal file
12953
grid_search.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
1332
main.ipynb
1332
main.ipynb
File diff suppressed because one or more lines are too long
140
main_multi.ipynb
140
main_multi.ipynb
File diff suppressed because one or more lines are too long
@@ -124,6 +124,22 @@ def calculate_metrics(
|
||||
"亏损交易次数": losing_count,
|
||||
"平均每次盈利": avg_profit_per_trade,
|
||||
"平均每次亏损": avg_loss_per_trade, # 这个值是负数
|
||||
"InitialCapital": initial_capital,
|
||||
"FinalCapital": final_value,
|
||||
"TotalReturn": total_return,
|
||||
"AnnualizedReturn": annualized_return,
|
||||
"MaxDrawdown": max_drawdown,
|
||||
"SharpeRatio": sharpe_ratio,
|
||||
"CalmarRatio": calmar_ratio,
|
||||
"TotalTrades": len(trades), # All buy and sell trades
|
||||
"TransactionCosts": total_commissions,
|
||||
"TotalRealizedPNL": total_realized_pnl, # New
|
||||
"WinRate": win_rate,
|
||||
"ProfitLossRatio": profit_loss_ratio,
|
||||
"WinningTradesCount": winning_count,
|
||||
"LosingTradesCount": losing_count,
|
||||
"AvgProfitPerTrade": avg_profit_per_trade,
|
||||
"AvgLossPerTrade": avg_loss_per_trade, # This value is negative
|
||||
}
|
||||
|
||||
|
||||
|
||||
90
src/analysis/grid_search_analyzer.py
Normal file
90
src/analysis/grid_search_analyzer.py
Normal file
@@ -0,0 +1,90 @@
|
||||
# src/grid_search_analyzer.py
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from typing import List, Dict, Any, Tuple
|
||||
|
||||
class GridSearchAnalyzer:
|
||||
"""
|
||||
用于分析和可视化网格搜索结果的类。
|
||||
"""
|
||||
def __init__(self, grid_results: List[Dict[str, Any]], optimization_metric: str):
|
||||
"""
|
||||
初始化 GridSearchAnalyzer。
|
||||
|
||||
Args:
|
||||
grid_results (List[Dict[str, Any]]): 包含每个参数组合及其对应优化指标的列表。
|
||||
例如:[{'param1': v1, 'param2': v2, 'metric': m1}, ...]
|
||||
optimization_metric (str): 用于优化的指标名称(例如 'total_return')。
|
||||
"""
|
||||
if not grid_results:
|
||||
raise ValueError("grid_results 列表不能为空。")
|
||||
if optimization_metric not in grid_results[0]:
|
||||
raise ValueError(f"优化指标 '{optimization_metric}' 不在 grid_results 的字典中。")
|
||||
|
||||
self.grid_results = grid_results
|
||||
self.optimization_metric = optimization_metric
|
||||
self.param_names = [k for k in grid_results[0].keys() if k != optimization_metric]
|
||||
|
||||
if len(self.param_names) != 2:
|
||||
raise ValueError("GridSearchAnalyzer 当前只支持分析两个参数的网格搜索结果。")
|
||||
|
||||
self.param1_name = self.param_names[0]
|
||||
self.param2_name = self.param_names[1]
|
||||
|
||||
def find_best_parameters(self) -> Dict[str, Any]:
|
||||
"""
|
||||
找到网格搜索中表现最佳的参数组合和其指标值。
|
||||
"""
|
||||
if not self.grid_results:
|
||||
return {}
|
||||
best_result = max(self.grid_results, key=lambda x: x[self.optimization_metric])
|
||||
print(f"\n--- 最佳参数组合 ---")
|
||||
print(f" {self.param1_name}: {best_result[self.param1_name]}")
|
||||
print(f" {self.param2_name}: {best_result[self.param2_name]}")
|
||||
print(f" {self.optimization_metric}: {best_result[self.optimization_metric]:.4f}")
|
||||
return best_result
|
||||
|
||||
def plot_heatmap(self, title: str = "heatmap"):
|
||||
"""
|
||||
绘制两个参数的热力图。
|
||||
"""
|
||||
if not self.grid_results:
|
||||
print("没有数据用于绘制热力图。")
|
||||
return
|
||||
|
||||
x_values = sorted(list(set(d[self.param1_name] for d in self.grid_results)))
|
||||
y_values = sorted(list(set(d[self.param2_name] for d in self.grid_results)))
|
||||
|
||||
heatmap_matrix = np.zeros((len(y_values), len(x_values)))
|
||||
|
||||
# 填充网格
|
||||
for item in self.grid_results:
|
||||
x_idx = x_values.index(item[self.param1_name])
|
||||
y_idx = y_values.index(item[self.param2_name])
|
||||
heatmap_matrix[y_idx, x_idx] = item[self.optimization_metric]
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 8))
|
||||
im = ax.imshow(heatmap_matrix, cmap='viridis', origin='lower',
|
||||
extent=[x_values[0], x_values[-1], y_values[0], y_values[-1]],
|
||||
aspect='auto')
|
||||
|
||||
ax.set_xticks(x_values)
|
||||
ax.set_yticks(y_values)
|
||||
|
||||
ax.set_xlabel(self.param1_name)
|
||||
ax.set_ylabel(self.param2_name)
|
||||
ax.set_title(f"{title}: {self.optimization_metric} vs. {self.param1_name} & {self.param2_name}")
|
||||
|
||||
cbar = fig.colorbar(im, ax=ax)
|
||||
cbar.set_label(self.optimization_metric)
|
||||
|
||||
# 在每个格子上显示数值
|
||||
for i in range(len(y_values)):
|
||||
for j in range(len(x_values)):
|
||||
text = ax.text(x_values[j], y_values[i], f"{heatmap_matrix[i, j]:.2f}",
|
||||
ha="center", va="center", color="w", fontsize=8)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
@@ -1,10 +1,16 @@
|
||||
# src/analysis/result_analyzer.py
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
# 导入纯函数 (注意相对导入路径的变化)
|
||||
from .analysis_utils import calculate_metrics, plot_equity_and_drawdown_chart, plot_close_price_chart
|
||||
from .analysis_utils import (
|
||||
calculate_metrics,
|
||||
plot_equity_and_drawdown_chart,
|
||||
plot_close_price_chart,
|
||||
)
|
||||
|
||||
# 导入核心数据类 (注意相对导入路径的变化)
|
||||
from ..core_data import PortfolioSnapshot, Trade, Bar
|
||||
|
||||
@@ -14,11 +20,13 @@ class ResultAnalyzer:
|
||||
结果分析器:负责接收回测数据,并提供分析和可视化方法。
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
def __init__(
|
||||
self,
|
||||
portfolio_snapshots: List[PortfolioSnapshot],
|
||||
trade_history: List[Trade],
|
||||
bars: List[Bar],
|
||||
initial_capital: float):
|
||||
initial_capital: float,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
portfolio_snapshots (List[PortfolioSnapshot]): 回测引擎输出的投资组合快照列表。
|
||||
@@ -41,9 +49,7 @@ class ResultAnalyzer:
|
||||
if self._metrics_cache is None:
|
||||
print("正在计算绩效指标...")
|
||||
self._metrics_cache = calculate_metrics(
|
||||
self.portfolio_snapshots,
|
||||
self.trade_history,
|
||||
self.initial_capital
|
||||
self.portfolio_snapshots, self.trade_history, self.initial_capital
|
||||
)
|
||||
print("绩效指标计算完成。")
|
||||
return self._metrics_cache
|
||||
@@ -69,7 +75,8 @@ class ResultAnalyzer:
|
||||
print("\n--- 部分交易明细 (最近5笔) ---")
|
||||
for trade in self.trade_history[-5:]:
|
||||
print(
|
||||
f" {trade.fill_time} | {trade.direction:<10} | {trade.symbol} | Vol: {trade.volume} | Price: {trade.price:.2f} | Commission: {trade.commission:.2f}")
|
||||
f" {trade.fill_time} | {trade.direction:<10} | {trade.symbol} | Vol: {trade.volume} | Price: {trade.price:.2f} | Commission: {trade.commission:.2f}"
|
||||
)
|
||||
else:
|
||||
print("\n没有交易记录。")
|
||||
|
||||
@@ -79,8 +86,11 @@ class ResultAnalyzer:
|
||||
"""
|
||||
print("正在绘制绩效图表...")
|
||||
# plot_performance_chart(self.portfolio_snapshots, self.initial_capital, self.bars)
|
||||
plot_equity_and_drawdown_chart(self.portfolio_snapshots, self.initial_capital,
|
||||
title="Portfolio Equity and Drawdown Curve")
|
||||
plot_equity_and_drawdown_chart(
|
||||
self.portfolio_snapshots,
|
||||
self.initial_capital,
|
||||
title="Portfolio Equity and Drawdown Curve",
|
||||
)
|
||||
|
||||
# 绘制单独的收盘价曲线
|
||||
plot_close_price_chart(self.bars, title="Underlying Asset Close Price")
|
||||
|
||||
@@ -58,6 +58,16 @@ class BacktestContext:
|
||||
# 可以在这里触发策略的on_trade回调(如果策略定义了)
|
||||
return trade
|
||||
|
||||
def cancel_order(self, order_id: str) -> bool:
|
||||
"""
|
||||
策略通过此方法发出交易订单。
|
||||
"""
|
||||
if self._current_bar is None:
|
||||
raise RuntimeError("当前Bar未设置,无法发送订单。")
|
||||
|
||||
return self._simulator.cancel_order(order_id)
|
||||
|
||||
|
||||
def get_current_positions(self) -> Dict[str, int]:
|
||||
"""
|
||||
获取当前模拟器的持仓情况。
|
||||
|
||||
@@ -36,6 +36,7 @@ class BacktestEngine:
|
||||
commission_rate (float): 交易佣金率。
|
||||
"""
|
||||
self.data_manager = data_manager
|
||||
self.initial_capital = initial_capital
|
||||
self.simulator = ExecutionSimulator(
|
||||
initial_capital=initial_capital,
|
||||
slippage_rate=slippage_rate,
|
||||
@@ -76,6 +77,7 @@ class BacktestEngine:
|
||||
|
||||
# 设置当前Bar到Context,供策略访问
|
||||
self.context.set_current_bar(current_bar)
|
||||
self.simulator.update_time(current_time=current_bar.datetime)
|
||||
|
||||
# 更新历史Bar缓存
|
||||
self._history_bars.append(current_bar)
|
||||
@@ -129,7 +131,7 @@ class BacktestEngine:
|
||||
for symbol_held, quantity in positions_to_close.items():
|
||||
if quantity != 0:
|
||||
print(f"[{last_processed_bar.datetime}] 回测结束平仓: 平仓 {symbol_held} ({quantity} 手) @ {last_processed_bar.close:.2f}。")
|
||||
direction = "SELL" if quantity > 0 else "BUY"
|
||||
direction = "CLOSE_LONG" if quantity > 0 else "CLOSE_SELL"
|
||||
volume = abs(quantity)
|
||||
|
||||
# 使用当前合约的最后一根Bar的价格进行平仓
|
||||
@@ -148,6 +150,17 @@ class BacktestEngine:
|
||||
print(f"总计处理了 {len(self.portfolio_snapshots)} 根K线。")
|
||||
print(f"总计发生了 {len(self.trade_history)} 笔交易。")
|
||||
|
||||
final_portfolio_value = 0.0
|
||||
if last_processed_bar:
|
||||
final_portfolio_value = self.simulator.get_portfolio_value(last_processed_bar)
|
||||
else: # 如果数据为空,或者回测根本没跑,则净值为初始资金
|
||||
final_portfolio_value = self.initial_capital
|
||||
|
||||
total_return_percentage = ((final_portfolio_value - self.initial_capital) / self.initial_capital) * 100
|
||||
|
||||
print(f"最终总净值: {final_portfolio_value:.2f}")
|
||||
print(f"总收益率: {total_return_percentage:.2f}%")
|
||||
|
||||
def get_backtest_results(self) -> Dict[str, Any]:
|
||||
"""
|
||||
返回回测结果数据,供结果分析模块使用。
|
||||
|
||||
189
src/backtest_runner.py
Normal file
189
src/backtest_runner.py
Normal file
@@ -0,0 +1,189 @@
|
||||
# src/backtest_runner.py
|
||||
|
||||
import pandas as pd
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional, Type
|
||||
|
||||
# 导入核心组件
|
||||
from .backtest_engine import BacktestEngine
|
||||
from .data_manager import DataManager
|
||||
from .execution_simulator import (
|
||||
ExecutionSimulator,
|
||||
) # 虽然不直接用,但 BacktestEngine 内部会用到
|
||||
from .strategies.base_strategy import Strategy # 用于类型提示
|
||||
from .core_data import Bar, PortfolioSnapshot, Trade, Order # 导入数据结构
|
||||
|
||||
|
||||
def run_multi_segment_backtest(
|
||||
segment_configs: Dict[str, Dict[str, Any]],
|
||||
strategy_class: Type[Strategy],
|
||||
strategy_params: Dict[str, Any],
|
||||
initial_capital: float,
|
||||
slippage_rate: float,
|
||||
commission_rate: float,
|
||||
total_start_date: datetime,
|
||||
total_end_date: datetime,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
运行一个多主力合约的分段回测,并返回合并后的回测结果。
|
||||
|
||||
Args:
|
||||
segment_configs (Dict[str, Dict[str, Any]]): 包含所有合约片段配置的字典。
|
||||
strategy_class (Type[Strategy]): 要回测的策略类。
|
||||
strategy_params (Dict[str, Any]): 传递给策略的参数字典。
|
||||
initial_capital (float): 回测的初始资金。
|
||||
slippage_rate (float): 交易滑点率。
|
||||
commission_rate (float): 交易佣金率。
|
||||
total_start_date (datetime): 总回测的起始日期。
|
||||
total_end_date (datetime): 总回测的结束日期。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 包含合并后的 portfolio_snapshots, trade_history, all_bars
|
||||
和 initial_capital 的字典。
|
||||
"""
|
||||
# --- 收集所有片段的回测结果 ---
|
||||
all_combined_snapshots: List[PortfolioSnapshot] = []
|
||||
all_combined_trades: List[Trade] = []
|
||||
all_combined_bars: List[Bar] = []
|
||||
|
||||
# 用于净值平滑的基准值,从初始资金开始
|
||||
last_segment_adjusted_total_value = initial_capital
|
||||
current_cash = initial_capital # 用于传递给下一个片段的初始资金
|
||||
current_positions: Dict[str, int] = {} # 用于传递给下一个片段的初始持仓
|
||||
|
||||
print("\n--- 开始分段回测跨越多个合约 ---")
|
||||
|
||||
# 获取所有合约的有序键,确保按定义的顺序回测
|
||||
segment_keys = list(segment_configs.keys())
|
||||
|
||||
for i, contract_symbol in enumerate(segment_keys):
|
||||
config = segment_configs[contract_symbol]
|
||||
segment_file = config["file"]
|
||||
segment_start_date = config["start"]
|
||||
segment_end_date = config["end"]
|
||||
|
||||
current_segment_start = max(segment_start_date, total_start_date)
|
||||
current_segment_end = min(segment_end_date, total_end_date)
|
||||
|
||||
if current_segment_start > current_segment_end:
|
||||
print(f"跳过片段 {contract_symbol}: 日期范围超出总回测周期或无效。")
|
||||
continue
|
||||
|
||||
print(
|
||||
f"\n--- 回测片段: {contract_symbol} 从 {current_segment_start.strftime('%Y-%m-%d %H:%M')} 到 {current_segment_end.strftime('%Y-%m-%d %H:%M')} ---"
|
||||
)
|
||||
|
||||
# 1. 初始化 DataManager (每个片段使用不同的 DataManager 实例)
|
||||
data_manager = DataManager(file_path=segment_file, symbol=contract_symbol)
|
||||
|
||||
if data_manager.raw_df.empty:
|
||||
print(
|
||||
f"警告: 未能加载 {contract_symbol} 的数据文件 {segment_file}。跳过此片段。"
|
||||
)
|
||||
continue
|
||||
|
||||
strategy_params["symbol"] = contract_symbol
|
||||
# 2. 创建 BacktestEngine 实例 (针对当前合约片段)
|
||||
engine = BacktestEngine(
|
||||
data_manager=data_manager,
|
||||
strategy_class=strategy_class,
|
||||
strategy_params=strategy_params, # 使用传入的策略参数
|
||||
initial_capital=initial_capital,
|
||||
slippage_rate=slippage_rate,
|
||||
commission_rate=commission_rate,
|
||||
current_segment_symbol=contract_symbol,
|
||||
)
|
||||
|
||||
# 在运行当前片段之前,将上一个片段结束时的资金和持仓注入到当前 BacktestEngine 的 simulator 中
|
||||
engine.get_simulator().reset(
|
||||
new_initial_capital=current_cash, new_initial_positions=current_positions
|
||||
)
|
||||
|
||||
# 3. 运行当前片段回测
|
||||
engine.run_backtest()
|
||||
segment_results = engine.get_backtest_results()
|
||||
|
||||
# 4. 获取当前片段结束后的最新资金和持仓状态,供下个片段使用
|
||||
current_cash = engine.get_simulator().cash
|
||||
current_positions = engine.get_simulator().positions.copy()
|
||||
|
||||
# 5. 拼接结果:净值曲线需要特殊处理
|
||||
segment_snapshots = segment_results["portfolio_snapshots"]
|
||||
segment_trades = segment_results["trade_history"]
|
||||
segment_bars = segment_results["all_bars"]
|
||||
|
||||
if segment_snapshots:
|
||||
first_snapshot_value_of_segment = segment_snapshots[0].total_value
|
||||
|
||||
if first_snapshot_value_of_segment == 0:
|
||||
segment_relative_factor = 1.0
|
||||
else:
|
||||
segment_relative_factor = (
|
||||
last_segment_adjusted_total_value / first_snapshot_value_of_segment
|
||||
)
|
||||
|
||||
for snapshot in segment_snapshots:
|
||||
adjusted_snapshot_value = snapshot.total_value * segment_relative_factor
|
||||
adjusted_snapshot = PortfolioSnapshot(
|
||||
datetime=snapshot.datetime,
|
||||
total_value=adjusted_snapshot_value,
|
||||
cash=snapshot.cash,
|
||||
positions=snapshot.positions.copy(),
|
||||
price_at_snapshot=snapshot.price_at_snapshot.copy(),
|
||||
)
|
||||
all_combined_snapshots.append(adjusted_snapshot)
|
||||
|
||||
last_segment_adjusted_total_value = all_combined_snapshots[-1].total_value
|
||||
else:
|
||||
print(
|
||||
f"警告: 未能为 {contract_symbol} 生成任何快照。`last_segment_adjusted_total_value` 保持不变: {last_segment_adjusted_total_value:.2f}。"
|
||||
)
|
||||
|
||||
all_combined_trades.extend(segment_trades)
|
||||
all_combined_bars.extend(segment_bars)
|
||||
|
||||
print("\n--- 总分段回测完成 ---")
|
||||
|
||||
return {
|
||||
"portfolio_snapshots": all_combined_snapshots,
|
||||
"trade_history": all_combined_trades,
|
||||
"all_bars": all_combined_bars,
|
||||
"initial_capital": initial_capital, # 返回初始资金,方便 ResultAnalyzer 使用
|
||||
}
|
||||
|
||||
|
||||
def run_backtest_for_optimization(
|
||||
global_config: Dict[str, Any],
|
||||
strategy_class: Type[Strategy],
|
||||
current_strategy_params: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
简化后的回测入口,用于优化或参数搜索。
|
||||
它从 global_config 中提取通用回测参数,并结合 current_strategy_params 运行回测。
|
||||
|
||||
Args:
|
||||
global_config (Dict[str, Any]): 包含所有通用全局配置的字典。
|
||||
strategy_class (Type[Strategy]): 要回测的策略类。
|
||||
current_strategy_params (Dict[str, Any]): 当前迭代的策略参数。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 包含合并后的 portfolio_snapshots, trade_history, all_bars
|
||||
和 initial_capital 的字典。
|
||||
"""
|
||||
total_start_date = global_config["total_start_date"]
|
||||
total_end_date = global_config["total_end_date"]
|
||||
initial_capital = global_config["initial_capital"]
|
||||
slippage_rate = global_config["slippage_rate"]
|
||||
commission_rate = global_config["commission_rate"]
|
||||
segment_configs = global_config["segment_configs"]
|
||||
|
||||
return run_multi_segment_backtest(
|
||||
segment_configs=segment_configs,
|
||||
strategy_class=strategy_class,
|
||||
strategy_params=current_strategy_params,
|
||||
initial_capital=initial_capital,
|
||||
slippage_rate=slippage_rate,
|
||||
commission_rate=commission_rate,
|
||||
total_start_date=total_start_date,
|
||||
total_end_date=total_end_date,
|
||||
)
|
||||
36
src/common_utils.py
Normal file
36
src/common_utils.py
Normal file
@@ -0,0 +1,36 @@
|
||||
# src/common_utils.py
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
def generate_parameter_range(start: Union[int, float], end: Union[int, float], step: Union[int, float]) -> List[Union[int, float]]:
|
||||
"""
|
||||
根据开始、结束和步长生成一个参数值的列表。
|
||||
支持整数和浮点数。
|
||||
|
||||
Args:
|
||||
start (Union[int, float]): 参数范围的起始值。
|
||||
end (Union[int, float]): 参数范围的结束值。
|
||||
step (Union[int, float]): 参数的步长。
|
||||
|
||||
Returns:
|
||||
List[Union[int, float]]: 生成的参数值列表。
|
||||
"""
|
||||
if step == 0:
|
||||
raise ValueError("步长不能为0。")
|
||||
if (end - start) * step < 0: # 步长方向与范围不符
|
||||
raise ValueError("步长方向与起始/结束值不一致,请检查步长正负。")
|
||||
|
||||
param_range = []
|
||||
current_value = start
|
||||
while (step > 0 and current_value <= end) or (step < 0 and current_value >= end):
|
||||
param_range.append(current_value)
|
||||
current_value += step
|
||||
# 针对浮点数精度问题进行小幅调整,避免无限循环或过早停止
|
||||
if isinstance(step, float) or isinstance(start, float) or isinstance(end, float):
|
||||
current_value = round(current_value, 10) # 四舍五入到一定小数位
|
||||
|
||||
return param_range
|
||||
|
||||
# 示例:
|
||||
# print(generate_parameter_range(0.99, 1.01, 0.005)) # [0.99, 0.995, 1.0, 1.005, 1.01]
|
||||
# print(generate_parameter_range(5, 20, 5)) # [5, 10, 15, 20]
|
||||
@@ -1,5 +1,5 @@
|
||||
# src/execution_simulator.py (修改部分)
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
import pandas as pd
|
||||
from .core_data import Order, Trade, Bar, PortfolioSnapshot
|
||||
@@ -36,36 +36,112 @@ class ExecutionSimulator:
|
||||
self.commission_rate = commission_rate
|
||||
self.trade_log: List[Trade] = [] # 存储所有成交记录
|
||||
self.pending_orders: Dict[str, Order] = {} # {order_id: Order_object}
|
||||
self._current_time = None
|
||||
|
||||
print(
|
||||
f"模拟器初始化:初始资金={self.initial_capital:.2f}, 滑点率={self.slippage_rate}, 佣金率={self.commission_rate}")
|
||||
if self.positions:
|
||||
print(f"初始持仓:{self.positions}")
|
||||
|
||||
def update_time(self, current_time: datetime):
|
||||
"""
|
||||
更新模拟器的当前时间。
|
||||
这个方法由 BacktestEngine 在遍历 K 线时调用。
|
||||
"""
|
||||
self._current_time = current_time
|
||||
|
||||
# --- 新增的公共方法 ---
|
||||
def get_current_time(self) -> datetime:
|
||||
"""
|
||||
获取模拟器的当前时间。
|
||||
"""
|
||||
if self._current_time is None:
|
||||
# 可以在这里抛出错误或者返回一个默认值,取决于你对未初始化时间的处理
|
||||
# 抛出错误可以帮助你发现问题,例如在模拟器时间未设置时就尝试获取
|
||||
# raise RuntimeError("Simulator time has not been set. Ensure update_time is called.")
|
||||
return None
|
||||
return self._current_time
|
||||
|
||||
def _calculate_fill_price(self, order: Order, current_bar: Bar) -> float:
|
||||
"""
|
||||
内部方法:根据订单类型和滑点计算实际成交价格。
|
||||
简化处理:市价单以当前Bar收盘价为基准,考虑滑点。
|
||||
- 市价单通常以当前K线的开盘价成交(考虑滑点)。
|
||||
- 限价单判断是否触及限价,如果触及,以限价成交(考虑滑点)。
|
||||
"""
|
||||
base_price = current_bar.close # 简化为收盘价成交
|
||||
fill_price = -1.0 # 默认未成交
|
||||
|
||||
# 考虑滑点
|
||||
if order.direction in ["BUY", "CLOSE_SHORT"]: # 买入或平空,价格向上偏离
|
||||
if order.price_type == "MARKET":
|
||||
# 市价单通常以开盘价成交,或者根据你的策略需求选择收盘价
|
||||
# 这里我们仍然使用开盘价作为市价单的基准成交价
|
||||
base_price = current_bar.open
|
||||
|
||||
if order.direction == "BUY" or order.direction == "CLOSE_SHORT": # 买入或平空,价格向上偏离
|
||||
fill_price = base_price * (1 + self.slippage_rate)
|
||||
elif order.direction in ["SELL", "CLOSE_LONG"]: # 卖出或平多,价格向下偏离
|
||||
elif order.direction == "SELL" or order.direction == "CLOSE_LONG": # 卖出或平多,价格向下偏离
|
||||
fill_price = base_price * (1 - self.slippage_rate)
|
||||
else: # 默认情况,无滑点
|
||||
fill_price = base_price
|
||||
else: # 默认情况,理论上不应该到这里,因为方向应该明确
|
||||
fill_price = base_price # 不考虑滑点
|
||||
|
||||
# 如果是限价单且成交价格不满足条件,则可能不成交
|
||||
if order.price_type == "LIMIT" and order.limit_price is not None:
|
||||
# 对于BUY和CLOSE_SHORT,成交价必须 <= 限价
|
||||
if (order.direction == "BUY" or order.direction == "CLOSE_SHORT") and fill_price > order.limit_price:
|
||||
# 市价单只要有价格就会成交,无需额外价格区间判断
|
||||
|
||||
elif order.price_type == "LIMIT" and order.limit_price is not None:
|
||||
limit_price = order.limit_price
|
||||
high = current_bar.high
|
||||
low = current_bar.low
|
||||
open_price = current_bar.open # 也可以在限价单成交时用开盘价作为实际成交价,或限价本身
|
||||
|
||||
if order.direction == "BUY" or order.direction == "CLOSE_SHORT": # 限价买入或限价平空
|
||||
# 买入:如果K线最低价 <= 限价,则限价单可能成交
|
||||
if low <= limit_price:
|
||||
# 成交价通常是限价,但考虑滑点后可能略高(对买方不利)
|
||||
# 或者可以以 open/low 之间的一个价格成交,这里简化为限价+滑点
|
||||
fill_price_candidate = limit_price * (1 + self.slippage_rate)
|
||||
|
||||
# 确保成交价不会比当前K线的最低价还低(如果不是在最低点成交)
|
||||
# 也可以简单就返回 limit_price * (1 + self.slippage_rate)
|
||||
|
||||
# 如果开盘价低于或等于限价,那么通常会以开盘价成交(或者略差)
|
||||
# 否则,如果价格是从上方跌落到限价区,那么会在限价附近成交
|
||||
# 这里简化:如果K线触及限价,则以限价成交(考虑滑点)。
|
||||
# 更精细的模拟会考虑价格穿越顺序 (例如,是否开盘就跳过了限价)
|
||||
|
||||
# 如果开盘价已经低于或等于限价,那么就以开盘价成交(考虑滑点)
|
||||
if open_price <= limit_price:
|
||||
fill_price = open_price * (1 + self.slippage_rate)
|
||||
else: # 价格从高处跌落到限价
|
||||
fill_price = limit_price * (1 + self.slippage_rate)
|
||||
|
||||
# 确保成交价不会超过限价 (虽然加滑点可能会略超,但这是交易成本的一部分)
|
||||
# 这个检查是为了避免逻辑错误,理论上加滑点后应该可以接受比限价略高的价格
|
||||
# 如果你严格要求成交价不能高于限价,则需要移除滑点或者更复杂的逻辑
|
||||
# 这里我们接受加滑点后的价格
|
||||
if fill_price > high and fill_price > limit_price: # 如果计算出来的成交价高于K线最高价,则可能不合理
|
||||
fill_price = -1.0 # 理论上不该发生,除非滑点过大
|
||||
else:
|
||||
return -1.0 # 未触及限价
|
||||
# 对于SELL和CLOSE_LONG,成交价必须 >= 限价
|
||||
elif (order.direction == "SELL" or order.direction == "CLOSE_LONG") and fill_price < order.limit_price:
|
||||
|
||||
elif order.direction == "SELL" or order.direction == "CLOSE_LONG": # 限价卖出或限价平多
|
||||
# 卖出:如果K线最高价 >= 限价,则限价单可能成交
|
||||
if high >= limit_price:
|
||||
# 成交价通常是限价,但考虑滑点后可能略低(对卖方不利)
|
||||
fill_price_candidate = limit_price * (1 - self.slippage_rate)
|
||||
|
||||
# 如果开盘价已经高于或等于限价,那么就以开盘价成交(考虑滑点)
|
||||
if open_price >= limit_price:
|
||||
fill_price = open_price * (1 - self.slippage_rate)
|
||||
else: # 价格从低处上涨到限价
|
||||
fill_price = limit_price * (1 - self.slippage_rate)
|
||||
|
||||
# 确保成交价不会低于限价
|
||||
if fill_price < low and fill_price < limit_price: # 如果计算出来的成交价低于K线最低价,则可能不合理
|
||||
fill_price = -1.0 # 理论上不该发生
|
||||
else:
|
||||
return -1.0 # 未触及限价
|
||||
|
||||
# 最后检查成交价是否有效
|
||||
if fill_price <= 0:
|
||||
return -1.0 # 如果计算出来价格无效,返回未成交
|
||||
|
||||
return fill_price
|
||||
|
||||
def send_order(self, order: Order, current_bar: Bar) -> Optional[Trade]:
|
||||
@@ -107,6 +183,7 @@ class ExecutionSimulator:
|
||||
if fill_price <= 0: # 表示未成交或不满足限价条件
|
||||
if order.price_type == "LIMIT":
|
||||
self.pending_orders[order.id] = order
|
||||
# print(f'撮合失败,order id:{order.id},fill_price:{fill_price}')
|
||||
return None # 未成交,返回None
|
||||
|
||||
# --- 以下是订单成功成交的逻辑 ---
|
||||
@@ -116,78 +193,90 @@ class ExecutionSimulator:
|
||||
current_position = self.positions.get(symbol, 0)
|
||||
current_average_cost = self.average_costs.get(symbol, 0.0)
|
||||
|
||||
if order.direction == "BUY":
|
||||
actual_direction = order.direction
|
||||
if order.direction == "CLOSE_SHORT":
|
||||
actual_direction = "BUY"
|
||||
elif order.direction == "CLOSE_LONG":
|
||||
actual_direction = "SELL"
|
||||
|
||||
is_close_order_intent = (order.direction == "CLOSE_LONG" or
|
||||
order.direction == "CLOSE_SHORT")
|
||||
|
||||
if actual_direction == "BUY": # 处理买入 (开多 / 平空)
|
||||
# 开多仓或平空仓
|
||||
if current_position >= 0: # 当前持有多仓或无仓位 (开多)
|
||||
is_open_trade = True
|
||||
is_open_trade = not is_close_order_intent # 如果是平仓意图,则不是开仓交易
|
||||
# 更新平均成本 (加权平均)
|
||||
new_total_cost = (current_average_cost * current_position) + (fill_price * volume)
|
||||
new_total_volume = current_position + volume
|
||||
self.average_costs[symbol] = new_total_cost / new_total_volume if new_total_volume > 0 else 0.0
|
||||
self.positions[symbol] = new_total_volume
|
||||
else: # 当前持有空仓 (平空)
|
||||
is_close_trade = True
|
||||
is_close_trade = is_close_order_intent # 这是平仓交易
|
||||
# 计算平空盈亏
|
||||
# PnL = (开仓成本 - 平仓价格) * 平仓数量 (注意空头方向)
|
||||
# 简化:假设平空时,直接使用当前的平均开仓成本来计算盈亏
|
||||
# 更精确的FIFO/LIFO需更多逻辑
|
||||
pnl_per_share = current_average_cost - fill_price # (买入平空,成本高于平仓价则盈利)
|
||||
pnl_per_share = current_average_cost - fill_price
|
||||
realized_pnl = pnl_per_share * volume
|
||||
|
||||
# 更新持仓和成本
|
||||
self.positions[symbol] += volume
|
||||
if self.positions[symbol] == 0:
|
||||
if self.positions[symbol] == 0: # 如果全部平仓
|
||||
del self.positions[symbol]
|
||||
if symbol in self.average_costs: del self.average_costs[symbol] # 清理成本
|
||||
elif self.positions[symbol] > 0 and current_position < 0: # 部分平空转为多头,需重新设置成本
|
||||
# 这部分逻辑可以更复杂,这里简化处理,如果转为多头,成本重置为0
|
||||
# 实际应该用剩余的空头成本 + 新开多的成本
|
||||
self.average_costs[symbol] = fill_price # 简单地将剩下的多头仓位成本设为当前价格
|
||||
if symbol in self.average_costs: del self.average_costs[symbol]
|
||||
elif self.positions[symbol] > 0 and current_position < 0: # 部分平空,且空头仓位被买平为多头仓位
|
||||
# 这是从空头转为多头的复杂情况。需要重新计算平均成本
|
||||
# 简单处理:将剩余的多头仓位成本设为当前价格
|
||||
self.average_costs[symbol] = fill_price
|
||||
|
||||
# 资金扣除
|
||||
if self.cash >= trade_value + commission:
|
||||
self.cash -= (trade_value + commission)
|
||||
else:
|
||||
# print(f"[{current_bar.datetime}] 资金不足,无法执行买入 {volume} {symbol}")
|
||||
print(f"[{current_bar.datetime}] 资金不足,无法执行买入 {volume} {symbol}")
|
||||
return None
|
||||
|
||||
|
||||
elif order.direction == "SELL":
|
||||
elif actual_direction == "SELL": # 处理卖出 (开空 / 平多)
|
||||
# 开空仓或平多仓
|
||||
if current_position <= 0: # 当前持有空仓或无仓位 (开空)
|
||||
is_open_trade = True
|
||||
is_open_trade = not is_close_order_intent # 如果是平仓意图,则不是开仓交易
|
||||
# 更新平均成本 (空头成本为负值)
|
||||
new_total_cost = (current_average_cost * current_position) - (fill_price * volume) # 负的持仓乘以负的卖出价
|
||||
new_total_volume = current_position - volume # 空头持仓量更负
|
||||
self.average_costs[symbol] = new_total_cost / new_total_volume if new_total_volume < 0 else 0.0
|
||||
self.positions[symbol] = new_total_volume
|
||||
# 对于空头,平均成本通常是指你卖出开仓的平均价格
|
||||
# 这里需要根据你的空头成本计算方式来调整
|
||||
# 常见的做法是:总卖出价值 / 总卖出数量
|
||||
new_total_value = (current_average_cost * abs(current_position)) + (fill_price * volume)
|
||||
new_total_volume = abs(current_position) + volume
|
||||
self.average_costs[symbol] = new_total_value / new_total_volume if new_total_volume > 0 else 0.0
|
||||
self.positions[symbol] -= volume # 空头数量增加,持仓量变为负更多
|
||||
|
||||
else: # 当前持有多仓 (平多)
|
||||
is_close_trade = True
|
||||
is_close_trade = is_close_order_intent # 这是平仓交易
|
||||
# 计算平多盈亏
|
||||
# PnL = (平仓价格 - 开仓成本) * 平仓数量
|
||||
pnl_per_share = fill_price - current_average_cost # (卖出平多,平仓价高于成本则盈利)
|
||||
pnl_per_share = fill_price - current_average_cost
|
||||
realized_pnl = pnl_per_share * volume
|
||||
|
||||
# 更新持仓和成本
|
||||
self.positions[symbol] -= volume
|
||||
if self.positions[symbol] == 0:
|
||||
if self.positions[symbol] == 0: # 如果全部平仓
|
||||
del self.positions[symbol]
|
||||
if symbol in self.average_costs: del self.average_costs[symbol] # 清理成本
|
||||
elif self.positions[symbol] < 0 and current_position > 0: # 部分平多转为空头,需重新设置成本
|
||||
self.average_costs[symbol] = fill_price # 简单地将剩下的空头仓位成本设为当前价格
|
||||
if symbol in self.average_costs: del self.average_costs[symbol]
|
||||
elif self.positions[symbol] < 0 and current_position > 0: # 部分平多,且多头仓位被卖平为空头仓位
|
||||
# 从多头转为空头的复杂情况
|
||||
self.average_costs[symbol] = fill_price # 简单将剩余空头仓位成本设为当前价格
|
||||
|
||||
# 资金扣除 (佣金) 和增加 (卖出收入)
|
||||
if self.cash >= commission:
|
||||
self.cash -= commission
|
||||
self.cash += trade_value
|
||||
else:
|
||||
# print(f"[{current_bar.datetime}] 资金不足,无法执行卖出 {volume} {symbol}")
|
||||
print(f"[{current_bar.datetime}] 资金不足,无法执行卖出 {volume} {symbol}")
|
||||
return None
|
||||
else: # 既不是 BUY 也不是 SELL,且也不是 CANCEL。这可能是未知的 direction
|
||||
print(
|
||||
f"[{current_bar.datetime}] 模拟器: 收到未知订单方向 {order.direction} for Order ID: {order.id}. 订单未处理。")
|
||||
return None
|
||||
|
||||
# 创建 Trade 对象
|
||||
executed_trade = Trade(
|
||||
order_id=order.id, fill_time=current_bar.datetime, symbol=symbol,
|
||||
direction=order.direction, # 记录原始订单方向 (BUY/SELL)
|
||||
direction=order.direction, # 记录原始订单方向 (BUY/SELL/CLOSE_X)
|
||||
volume=volume, price=fill_price, commission=commission,
|
||||
cash_after_trade=self.cash, positions_after_trade=self.positions.copy(),
|
||||
realized_pnl=realized_pnl, # 填充实现盈亏
|
||||
@@ -200,8 +289,6 @@ class ExecutionSimulator:
|
||||
if order.id in self.pending_orders:
|
||||
del self.pending_orders[order.id]
|
||||
|
||||
# print(f"[{current_bar.datetime}] 成交: {executed_trade.direction} {executed_trade.volume} {executed_trade.symbol} @ {executed_trade.price:.2f}, 佣金: {executed_trade.commission:.2f}, PnL: {executed_trade.realized_pnl:.2f}")
|
||||
|
||||
return executed_trade
|
||||
|
||||
def cancel_order(self, order_id: str) -> bool:
|
||||
@@ -245,7 +332,7 @@ class ExecutionSimulator:
|
||||
quantity = self.positions[symbol_in_position]
|
||||
# 持仓市值 = 数量 * 当前市场价格 (current_bar.close)
|
||||
# 无论多头(quantity > 0)还是空头(quantity < 0),这个计算都是正确的
|
||||
total_value += quantity * current_bar.close
|
||||
total_value += quantity * current_bar.open
|
||||
|
||||
# 您也可以选择在这里打印调试信息
|
||||
# print(f" DEBUG Portfolio Value Calculation: Cash={self.cash:.2f}, "
|
||||
@@ -287,3 +374,14 @@ class ExecutionSimulator:
|
||||
"""
|
||||
print("ExecutionSimulator: 清空交易历史。")
|
||||
self.trade_history = []
|
||||
|
||||
def get_average_position_price(self, symbol: str) -> Optional[float]:
|
||||
"""
|
||||
获取指定合约的平均持仓成本。
|
||||
如果无持仓或无该合约记录,返回 None。
|
||||
"""
|
||||
# 返回 average_costs 字典中对应 symbol 的值
|
||||
# 如果没有持仓或者没有记录,返回 None
|
||||
if symbol in self.positions and self.positions[symbol] != 0:
|
||||
return self.average_costs.get(symbol)
|
||||
return None
|
||||
|
||||
0
src/research/__init__.py
Normal file
0
src/research/__init__.py
Normal file
99766
src/research/grid_search.ipynb
Normal file
99766
src/research/grid_search.ipynb
Normal file
File diff suppressed because one or more lines are too long
183
src/strategies/SimpleLimitBuyStrategy.py
Normal file
183
src/strategies/SimpleLimitBuyStrategy.py
Normal file
@@ -0,0 +1,183 @@
|
||||
# src/strategies/simple_limit_buy_strategy.py
|
||||
|
||||
from .base_strategy import Strategy
|
||||
from ..core_data import Bar, Order
|
||||
from typing import Optional, Dict, Any
|
||||
from collections import deque
|
||||
|
||||
class SimpleLimitBuyStrategy(Strategy):
|
||||
"""
|
||||
一个基于当前K线Open、前1根和前7根K线Range计算优势价格进行限价买入的策略。
|
||||
具备以下特点:
|
||||
- 每根K线开始时取消上一根K线未成交的订单。
|
||||
- 最多只能有一个开仓挂单和一个持仓。
|
||||
- 包含简单的止损和止盈逻辑。
|
||||
"""
|
||||
def __init__(self, simulator: Any, symbol: str, enable_log: bool, trade_volume: int,
|
||||
open_range_factor_1_ago: float,
|
||||
open_range_factor_7_ago: float,
|
||||
max_position: int,
|
||||
stop_loss_points: float = 10, # 新增:止损点数
|
||||
take_profit_points: float = 10): # 新增:止盈点数
|
||||
"""
|
||||
初始化策略。
|
||||
Args:
|
||||
simulator: 模拟器实例。
|
||||
symbol (str): 交易合约代码。
|
||||
trade_volume (int): 单笔交易量。
|
||||
open_range_factor_1_ago (float): 前1根K线Range的权重因子,用于从Open价向下偏移。
|
||||
open_range_factor_7_ago (float): 前7根K线Range的权重因子,用于从Open价向下偏移。
|
||||
max_position (int): 最大持仓量(此处为1,因为只允许一个持仓)。
|
||||
stop_loss_points (float): 止损点数(例如,亏损达到此点数则止损)。
|
||||
take_profit_points (float): 止盈点数(例如,盈利达到此点数则止盈)。
|
||||
"""
|
||||
super().__init__(simulator, symbol, enable_log)
|
||||
self.trade_volume = trade_volume
|
||||
self.open_range_factor_1_ago = open_range_factor_1_ago
|
||||
self.open_range_factor_7_ago = open_range_factor_7_ago
|
||||
self.max_position = max_position # 理论上这里应为1
|
||||
self.stop_loss_points = stop_loss_points
|
||||
self.take_profit_points = take_profit_points
|
||||
|
||||
self.order_id_counter = 0
|
||||
|
||||
self._bar_history: deque[Bar] = deque(maxlen=7)
|
||||
self._last_order_id: Optional[str] = None # 用于跟踪上一根K线发出的订单ID
|
||||
|
||||
self.log(f"策略初始化: symbol={self.symbol}, trade_volume={self.trade_volume}, "
|
||||
f"open_range_factor_1_ago={self.open_range_factor_1_ago}, "
|
||||
f"open_range_factor_7_ago={self.open_range_factor_7_ago}, "
|
||||
f"max_position={self.max_position}, "
|
||||
f"止损点={self.stop_loss_points}, 止盈点={self.take_profit_points}")
|
||||
|
||||
def on_bar(self, bar: Bar, next_bar_open: Optional[float] = None):
|
||||
"""
|
||||
每当新的K线数据到来时调用。
|
||||
Args:
|
||||
bar (Bar): 当前的K线数据对象。
|
||||
next_bar_open (Optional[float]): 下一根K线的开盘价,此处策略未使用。
|
||||
"""
|
||||
current_datetime = bar.datetime # 获取当前K线时间
|
||||
|
||||
# --- 1. 撤销上一根K线未成交的订单 ---
|
||||
# 检查是否记录了上一笔订单ID,并且该订单仍然在待处理列表中
|
||||
if self._last_order_id:
|
||||
pending_orders = self.get_pending_orders()
|
||||
if self._last_order_id in pending_orders:
|
||||
success = self.cancel_order(self._last_order_id) # 直接调用基类的取消方法
|
||||
if success:
|
||||
self.log(f"[{current_datetime}] 策略: 成功撤销上一根K线未成交订单 {self._last_order_id}")
|
||||
else:
|
||||
self.log(f"[{current_datetime}] 策略: 尝试撤销订单 {self._last_order_id} 失败(可能已成交或不存在)")
|
||||
# 无论撤销成功与否,既然我们尝试了撤销,就清除记录
|
||||
self._last_order_id = None
|
||||
# else:
|
||||
# self.log(f"[{current_datetime}] 策略: 无上一根K线未成交订单需要撤销。")
|
||||
|
||||
|
||||
# 2. 更新K线历史
|
||||
self._bar_history.append(bar)
|
||||
trade_volume = self.trade_volume
|
||||
|
||||
# 获取当前持仓和未决订单(在取消之后获取,确保是最新的状态)
|
||||
current_positions = self.get_current_positions()
|
||||
current_pos_volume = current_positions.get(self.symbol, 0)
|
||||
pending_orders_after_cancel = self.get_pending_orders() # 再次获取,此时应已取消旧订单
|
||||
|
||||
# --- 3. 平仓逻辑 (止损/止盈) ---
|
||||
# 只有当有持仓时才考虑平仓
|
||||
if current_pos_volume > 0: # 假设只做多,所以持仓量 > 0
|
||||
|
||||
avg_entry_price = self.get_average_position_price(self.symbol)
|
||||
if avg_entry_price is not None:
|
||||
pnl_per_unit = bar.close - avg_entry_price # 当前浮动盈亏(以收盘价计算)
|
||||
|
||||
# 止盈条件
|
||||
if pnl_per_unit >= self.take_profit_points:
|
||||
self.log(f"[{current_datetime}] 止盈信号 - PnL per unit: {pnl_per_unit:.2f}, 目标: {self.take_profit_points:.2f}")
|
||||
|
||||
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
|
||||
self.order_id_counter += 1
|
||||
|
||||
# 创建一个限价多单
|
||||
order = Order(
|
||||
id=order_id,
|
||||
symbol=self.symbol,
|
||||
direction="CLOSE_LONG",
|
||||
volume=trade_volume,
|
||||
price_type="MARKET",
|
||||
# limit_price=limit_price,
|
||||
submitted_time=bar.datetime
|
||||
)
|
||||
trade = self.send_order(order)
|
||||
return # 平仓后本K线不再进行开仓判断
|
||||
|
||||
# 止损条件
|
||||
elif pnl_per_unit <= -self.stop_loss_points:
|
||||
self.log(f"[{current_datetime}] 止损信号 - PnL per unit: {pnl_per_unit:.2f}, 目标: {-self.stop_loss_points:.2f}")
|
||||
# 发送市价卖出订单平仓,确保立即成交
|
||||
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
|
||||
self.order_id_counter += 1
|
||||
|
||||
# 创建一个限价多单
|
||||
order = Order(
|
||||
id=order_id,
|
||||
symbol=self.symbol,
|
||||
direction="CLOSE_LONG",
|
||||
volume=trade_volume,
|
||||
price_type="MARKET",
|
||||
# limit_price=limit_price,
|
||||
submitted_time=bar.datetime
|
||||
)
|
||||
trade = self.send_order(order)
|
||||
return # 平仓后本K线不再进行开仓判断
|
||||
|
||||
# --- 4. 开仓逻辑 (只考虑做多 BUY 方向) ---
|
||||
# 只有在没有持仓 (current_pos_volume == 0) 且没有待处理订单 (not pending_orders_after_cancel)
|
||||
# 且K线历史足够长时才考虑开仓
|
||||
if current_pos_volume == 0 and \
|
||||
len(self._bar_history) == self._bar_history.maxlen:
|
||||
|
||||
# 获取前1根K线 (倒数第二根) 和前7根K线 (队列中最老的一根)
|
||||
bar_1_ago = self._bar_history[-2]
|
||||
bar_7_ago = self._bar_history[0]
|
||||
|
||||
# 计算历史 K 线的 Range
|
||||
range_1_ago = bar_1_ago.high - bar_1_ago.low
|
||||
range_7_ago = bar_7_ago.high - bar_7_ago.low
|
||||
|
||||
# 根据策略逻辑计算目标买入价格
|
||||
# 目标买入价 = 当前K线Open - (前1根Range * 因子1 + 前7根Range * 因子2)
|
||||
self.log(bar.open ,range_1_ago * self.open_range_factor_1_ago, range_7_ago * self.open_range_factor_7_ago)
|
||||
target_buy_price = bar.open - (range_1_ago * self.open_range_factor_1_ago + range_7_ago * self.open_range_factor_7_ago)
|
||||
|
||||
# 确保目标买入价格有效,例如不能是负数
|
||||
target_buy_price = max(0.01, target_buy_price)
|
||||
|
||||
self.log(f"[{current_datetime}] 开多仓信号 - 当前Open={bar.open:.2f}, "
|
||||
f"前1Range={range_1_ago:.2f}, 前7Range={range_7_ago:.2f}, "
|
||||
f"计算目标买入价={target_buy_price:.2f}")
|
||||
|
||||
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
|
||||
self.order_id_counter += 1
|
||||
|
||||
# 创建一个限价多单
|
||||
order = Order(
|
||||
id=order_id,
|
||||
symbol=self.symbol,
|
||||
direction="BUY",
|
||||
volume=trade_volume,
|
||||
price_type="LIMIT",
|
||||
limit_price=target_buy_price,
|
||||
submitted_time=bar.datetime
|
||||
)
|
||||
new_order = self.send_order(order)
|
||||
# 记录下这个订单的ID,以便在下一根K线开始时进行撤销
|
||||
if new_order:
|
||||
self._last_order_id = new_order.order_id
|
||||
self.log(f"[{current_datetime}] 策略: 发送限价买入订单 {self._last_order_id} @ {target_buy_price:.2f}")
|
||||
else:
|
||||
self.log(f"[{current_datetime}] 策略: 发送订单失败。")
|
||||
|
||||
# else:
|
||||
# self.log(f"[{current_datetime}] 不满足开仓条件:持仓={current_pos_volume}, 待处理订单={len(pending_orders_after_cancel)}, K线历史长度={len(self._bar_history)}")
|
||||
@@ -1,42 +1,31 @@
|
||||
# src/strategies/base_strategy.py
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional, List, TYPE_CHECKING
|
||||
|
||||
# 使用 TYPE_CHECKING 避免循环导入,但保留类型提示
|
||||
from ..backtest_context import BacktestContext # 转发引用 BacktestEngine
|
||||
from ..core_data import Bar, Order, Trade # 导入必要的类型
|
||||
|
||||
# 导入核心数据类
|
||||
from ..core_data import Bar, Order, Trade
|
||||
# 导入回测上下文 (注意相对导入路径的变化)
|
||||
from ..backtest_context import BacktestContext
|
||||
|
||||
class Strategy(ABC):
|
||||
"""
|
||||
策略抽象基类。所有具体策略都应继承此类,并实现 on_bar 方法。
|
||||
所有交易策略的抽象基类。
|
||||
策略通过 context 对象与回测引擎和模拟器进行交互,并提供辅助方法。
|
||||
"""
|
||||
def __init__(self, context: BacktestContext, **parameters: Any):
|
||||
"""
|
||||
初始化策略。
|
||||
|
||||
def __init__(self, context: 'BacktestContext', symbol: str, enable_log: bool = True, **params: Any):
|
||||
"""
|
||||
Args:
|
||||
context (BacktestContext): 回测上下文对象,用于与模拟器和数据管理器交互。
|
||||
**parameters (Any): 策略所需的任何自定义参数。
|
||||
context (BacktestEngine): 回测引擎实例,作为策略的上下文,提供与模拟器等的交互接口。
|
||||
symbol (str): 策略操作的合约Symbol。
|
||||
**params (Any): 其他策略特定参数。
|
||||
"""
|
||||
self.context = context
|
||||
self.parameters = parameters
|
||||
self.symbol = parameters.get('symbol', "DEFAULT_SYMBOL") # 策略操作的品种
|
||||
self.trade_volume = parameters.get('trade_volume', 100) # 每次下单的数量
|
||||
print(f"策略初始化: {self.__class__.__name__},参数: {self.parameters}")
|
||||
|
||||
@abstractmethod
|
||||
def on_bar(self, bar: Bar):
|
||||
"""
|
||||
每当接收到新的Bar数据时调用。
|
||||
具体策略逻辑在此方法中实现。
|
||||
|
||||
Args:
|
||||
bar (Bar): 当前的Bar数据对象。
|
||||
features (Dict[str, float]): 由数据处理模块计算并传入的特征字典。
|
||||
"""
|
||||
pass
|
||||
self.context = context # 存储 context 对象
|
||||
self.symbol = symbol # 策略操作的合约Symbol
|
||||
self.params = params
|
||||
self.enable_log = enable_log
|
||||
|
||||
def on_init(self):
|
||||
"""
|
||||
@@ -45,7 +34,7 @@ class Strategy(ABC):
|
||||
"""
|
||||
print(f"{self.__class__.__name__} 策略初始化回调被调用。")
|
||||
|
||||
def on_trade(self, trade: Trade):
|
||||
def on_trade(self, trade: 'Trade'):
|
||||
"""
|
||||
当模拟器成功执行一笔交易时调用。
|
||||
可用于更新策略内部持仓状态或记录交易。
|
||||
@@ -56,13 +45,92 @@ class Strategy(ABC):
|
||||
# print(f"策略接收到交易: {trade.direction} {trade.volume} {trade.symbol} @ {trade.price:.2f}")
|
||||
pass # 默认不执行任何操作,具体策略可覆盖
|
||||
|
||||
def on_order_status(self, order: Order, status: str):
|
||||
@abstractmethod
|
||||
def on_bar(self, bar: 'Bar'):
|
||||
"""
|
||||
当订单状态更新时调用 (例如,未成交,已提交等)。
|
||||
在简易回测中,可能不会频繁使用。
|
||||
|
||||
每当新的K线数据到来时调用此方法。
|
||||
Args:
|
||||
order (Order): 相关订单对象。
|
||||
status (str): 订单状态(例如 "FILLED", "PENDING", "CANCELLED")。
|
||||
bar (Bar): 当前的K线数据对象。
|
||||
next_bar_open (Optional[float]): 下一根K线的开盘价,如果存在的话。
|
||||
"""
|
||||
pass # 默认不执行任何操作
|
||||
pass
|
||||
|
||||
# --- 新增/修改的辅助方法 ---
|
||||
|
||||
def send_order(self, order: 'Order') -> Optional[Trade]:
|
||||
"""
|
||||
发送订单的辅助方法。
|
||||
会在 BaseStrategy 内部构建 Order 对象,并通过 context 转发给模拟器。
|
||||
"""
|
||||
return self.context.send_order(order)
|
||||
|
||||
def cancel_order(self, order_id: str) -> bool:
|
||||
"""
|
||||
取消指定ID的订单。
|
||||
通过 context 调用模拟器的 cancel_order 方法。
|
||||
"""
|
||||
|
||||
return self.context.cancel_order(order_id)
|
||||
|
||||
def cancel_all_pending_orders(self) -> int:
|
||||
"""
|
||||
取消所有当前策略的未决订单。
|
||||
返回成功取消的订单数量。
|
||||
"""
|
||||
pending_orders = self.get_pending_orders() # 调用 BaseStrategy 自己的 get_pending_orders
|
||||
cancelled_count = 0
|
||||
orders_to_cancel = [order.id for order in pending_orders.values() if order.symbol == self.symbol]
|
||||
for order_id in orders_to_cancel:
|
||||
if self.cancel_order(order_id): # 调用 BaseStrategy 自己的 cancel_order
|
||||
cancelled_count += 1
|
||||
return cancelled_count
|
||||
|
||||
def get_current_positions(self) -> Dict[str, int]:
|
||||
"""
|
||||
获取当前持仓。
|
||||
通过 context 调用模拟器的 get_positions 方法。
|
||||
"""
|
||||
return self.context._simulator.get_current_positions()
|
||||
|
||||
def get_pending_orders(self) -> Dict[str, 'Order']:
|
||||
"""
|
||||
获取当前所有待处理订单的副本。
|
||||
通过 context 调用模拟器的 get_pending_orders 方法。
|
||||
"""
|
||||
return self.context._simulator.get_pending_orders()
|
||||
|
||||
def get_average_position_price(self, symbol: str) -> Optional[float]:
|
||||
"""
|
||||
获取指定合约的平均持仓成本。
|
||||
通过 context 调用模拟器的 get_average_position_price 方法。
|
||||
"""
|
||||
return self.context._simulator.get_average_position_price(symbol)
|
||||
|
||||
# 你可以根据需要在这里添加更多辅助方法,如获取账户净值等
|
||||
def get_account_cash(self) -> float:
|
||||
"""获取当前账户现金余额。"""
|
||||
return self.context._simulator.cash
|
||||
|
||||
|
||||
def log(self, *args: Any, **kwargs: Any):
|
||||
"""
|
||||
统一的日志打印方法。
|
||||
如果 enable_log 为 True,则打印消息到控制台,并包含当前模拟时间。
|
||||
支持传入多个参数,像 print() 函数一样使用。
|
||||
"""
|
||||
if self.enable_log:
|
||||
# 尝试获取当前模拟时间,如果模拟器或时间不可用,则跳过时间前缀
|
||||
try:
|
||||
current_time_str = self.context._simulator.get_current_time().strftime('%Y-%m-%d %H:%M:%S')
|
||||
time_prefix = f"[{current_time_str}] "
|
||||
except AttributeError:
|
||||
# 如果获取不到时间(例如在策略初始化时,模拟器时间还未设置),则不加时间前缀
|
||||
time_prefix = ""
|
||||
|
||||
# 使用 f-string 结合 *args 来构建消息
|
||||
# print() 函数会将 *args 自动用空格分隔,这里我们模仿这个行为
|
||||
message = ' '.join(map(str, args))
|
||||
|
||||
# 你可以将其他 kwargs (如 sep, end, file, flush) 传递给 print,
|
||||
# 但通常日志方法不会频繁使用这些。这里只支持最基础的打印。
|
||||
print(f"{time_prefix}策略 ({self.symbol}): {message}", **kwargs)
|
||||
@@ -19,9 +19,10 @@ class SimpleLimitBuyStrategy(Strategy):
|
||||
|
||||
def __init__(self, context: Any, **parameters: Any): # context 类型提示可以为 BacktestContext
|
||||
super().__init__(context, **parameters)
|
||||
self.trade_volume = 1
|
||||
self.order_id_counter = 0
|
||||
self.limit_price_factor = parameters.get('limit_price_factor', 0.999) # 限价因子
|
||||
self.max_position = parameters.get('max_position', self.trade_volume * 2) # 最大持仓量
|
||||
self.max_position = parameters.get('max_position', self.trade_volume) # 最大持仓量
|
||||
|
||||
self._last_order_id: Optional[str] = None # 跟踪上一根Bar发出的订单ID
|
||||
self._current_long_position: int = 0 # 策略内部维护的当前持仓
|
||||
@@ -57,16 +58,14 @@ class SimpleLimitBuyStrategy(Strategy):
|
||||
每接收到一根Bar时,执行策略逻辑。
|
||||
"""
|
||||
current_portfolio_value = self.context.get_current_portfolio_value(bar)
|
||||
print(f"[{bar.datetime}] Strategy processing Bar. Current close price: {bar.close:.2f}. Current Portfolio Value: {current_portfolio_value:.2f}")
|
||||
# print(f"[{bar.datetime}] Strategy processing Bar. Current close price: {bar.close:.2f}. Current Portfolio Value: {current_portfolio_value:.2f}")
|
||||
|
||||
# 1. 撤销上一根K线未成交的订单
|
||||
if self._last_order_id:
|
||||
# 检查这个订单是否仍然在待处理订单列表中
|
||||
pending_orders = self.context._simulator.get_pending_orders() # 直接访问模拟器,或者通过context提供接口
|
||||
pending_orders = self.get_pending_orders() # 直接访问模拟器,或者通过context提供接口
|
||||
if self._last_order_id in pending_orders:
|
||||
success = self.context.send_order(Order(id=self._last_order_id, symbol=self.symbol,
|
||||
direction="CANCEL", volume=0,
|
||||
price_type="CANCEL")) # 使用一个特殊Order类型表示撤单
|
||||
success = self.cancel_order(self._last_order_id)
|
||||
# 这里发送的“撤单订单”会被simulator的send_order处理,并调用simulator.cancel_order
|
||||
if success: # simulator.send_order返回Trade或None,这里我们用一个特殊处理
|
||||
# Simulator的send_order返回的是Trade,如果实现撤单,最好Simulator的cancel_order返回bool
|
||||
@@ -109,10 +108,10 @@ class SimpleLimitBuyStrategy(Strategy):
|
||||
)
|
||||
|
||||
# 通过上下文发送订单
|
||||
trade = self.context.send_order(order)
|
||||
trade = self.send_order(order)
|
||||
if trade:
|
||||
print(
|
||||
f"[{bar.datetime}] 策略: 发送并立即成交限价买单 {trade.volume} 股 @ {trade.price:.2f} (订单ID: {order.id})")
|
||||
f"[{bar.datetime}] 策略: 发送并立即成交限价买单 {trade.volume} 股 @ {trade.price:.2f}(open:{bar.open}, close:{bar.close}) (订单ID: {order.id})")
|
||||
# 如果立即成交,_last_order_id 仍然保持 None
|
||||
else:
|
||||
# 如果未立即成交,将订单ID记录下来,以便下一根Bar撤销
|
||||
|
||||
Reference in New Issue
Block a user