简单波动率策略,实现+网格搜索
This commit is contained in:
1012
data/ analysis/Volume.ipynb
Normal file
1012
data/ analysis/Volume.ipynb
Normal file
File diff suppressed because one or more lines are too long
14
data/tqsdk/test.py
Normal file
14
data/tqsdk/test.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
from tqsdk import TqApi, TqAuth
|
||||||
|
|
||||||
|
api = TqApi(auth=TqAuth("emanresu", "dfgvfgdfgg"))
|
||||||
|
|
||||||
|
# au 品种指数合约
|
||||||
|
ls = api.query_quotes(ins_class="INDEX", product_id="au")
|
||||||
|
print(ls)
|
||||||
|
|
||||||
|
# au 品种主连合约
|
||||||
|
ls = api.query_quotes(ins_class="CONT", product_id="au")
|
||||||
|
print(ls)
|
||||||
|
|
||||||
|
# 关闭api,释放相应资源
|
||||||
|
api.close()
|
||||||
@@ -25,7 +25,7 @@ def collect_and_save_tqsdk_data_stream(
|
|||||||
output_dir: str = "../data",
|
output_dir: str = "../data",
|
||||||
tq_user: str = TQ_USER_NAME,
|
tq_user: str = TQ_USER_NAME,
|
||||||
tq_pwd: str = TQ_PASSWORD
|
tq_pwd: str = TQ_PASSWORD
|
||||||
) -> pd.DataFrame or None:
|
) -> pd.DataFrame:
|
||||||
"""
|
"""
|
||||||
通过 TqSdk 在指定模式下(回测或模拟)运行,监听并收集指定品种、频率、日期范围的K线数据流,
|
通过 TqSdk 在指定模式下(回测或模拟)运行,监听并收集指定品种、频率、日期范围的K线数据流,
|
||||||
并将其保存到本地CSV文件。此函数会模拟 TqSdk 的时间流运行。
|
并将其保存到本地CSV文件。此函数会模拟 TqSdk 的时间流运行。
|
||||||
@@ -190,10 +190,10 @@ if __name__ == "__main__":
|
|||||||
# 示例1: 在回测模式下获取沪深300指数主连的日线数据 (用于历史回测)
|
# 示例1: 在回测模式下获取沪深300指数主连的日线数据 (用于历史回测)
|
||||||
# 这种方式适合获取相对较短或中等长度的历史K线数据。
|
# 这种方式适合获取相对较短或中等长度的历史K线数据。
|
||||||
df_if_backtest_daily = collect_and_save_tqsdk_data_stream(
|
df_if_backtest_daily = collect_and_save_tqsdk_data_stream(
|
||||||
symbol="SHFE.rb2410",
|
symbol="KQ.i@SHFE.rb",
|
||||||
freq="min60",
|
freq="day",
|
||||||
start_date_str="2024-05-01",
|
start_date_str="2023-01-01",
|
||||||
end_date_str="2024-09-01",
|
end_date_str="2025-05-01",
|
||||||
mode="backtest", # 指定为回测模式
|
mode="backtest", # 指定为回测模式
|
||||||
tq_user=TQ_USER_NAME,
|
tq_user=TQ_USER_NAME,
|
||||||
tq_pwd=TQ_PASSWORD
|
tq_pwd=TQ_PASSWORD
|
||||||
|
|||||||
12953
grid_search.ipynb
Normal file
12953
grid_search.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
1334
main.ipynb
1334
main.ipynb
File diff suppressed because one or more lines are too long
144
main_multi.ipynb
144
main_multi.ipynb
File diff suppressed because one or more lines are too long
@@ -9,7 +9,7 @@ from ..core_data import PortfolioSnapshot, Trade, Bar
|
|||||||
|
|
||||||
|
|
||||||
def calculate_metrics(
|
def calculate_metrics(
|
||||||
snapshots: List[PortfolioSnapshot], trades: List[Trade], initial_capital: float
|
snapshots: List[PortfolioSnapshot], trades: List[Trade], initial_capital: float
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
纯函数:根据投资组合快照和交易历史计算关键绩效指标。
|
纯函数:根据投资组合快照和交易历史计算关键绩效指标。
|
||||||
@@ -124,11 +124,27 @@ def calculate_metrics(
|
|||||||
"亏损交易次数": losing_count,
|
"亏损交易次数": losing_count,
|
||||||
"平均每次盈利": avg_profit_per_trade,
|
"平均每次盈利": avg_profit_per_trade,
|
||||||
"平均每次亏损": avg_loss_per_trade, # 这个值是负数
|
"平均每次亏损": avg_loss_per_trade, # 这个值是负数
|
||||||
|
"InitialCapital": initial_capital,
|
||||||
|
"FinalCapital": final_value,
|
||||||
|
"TotalReturn": total_return,
|
||||||
|
"AnnualizedReturn": annualized_return,
|
||||||
|
"MaxDrawdown": max_drawdown,
|
||||||
|
"SharpeRatio": sharpe_ratio,
|
||||||
|
"CalmarRatio": calmar_ratio,
|
||||||
|
"TotalTrades": len(trades), # All buy and sell trades
|
||||||
|
"TransactionCosts": total_commissions,
|
||||||
|
"TotalRealizedPNL": total_realized_pnl, # New
|
||||||
|
"WinRate": win_rate,
|
||||||
|
"ProfitLossRatio": profit_loss_ratio,
|
||||||
|
"WinningTradesCount": winning_count,
|
||||||
|
"LosingTradesCount": losing_count,
|
||||||
|
"AvgProfitPerTrade": avg_profit_per_trade,
|
||||||
|
"AvgLossPerTrade": avg_loss_per_trade, # This value is negative
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def plot_equity_and_drawdown_chart(snapshots: List[PortfolioSnapshot], initial_capital: float,
|
def plot_equity_and_drawdown_chart(snapshots: List[PortfolioSnapshot], initial_capital: float,
|
||||||
title: str = "Portfolio Equity and Drawdown Curve") -> None:
|
title: str = "Portfolio Equity and Drawdown Curve") -> None:
|
||||||
"""
|
"""
|
||||||
Plots the portfolio equity curve and drawdown. X-axis points are equally spaced.
|
Plots the portfolio equity curve and drawdown. X-axis points are equally spaced.
|
||||||
|
|
||||||
@@ -145,7 +161,7 @@ def plot_equity_and_drawdown_chart(snapshots: List[PortfolioSnapshot], initial_c
|
|||||||
{'datetime': s.datetime, 'total_value': s.total_value}
|
{'datetime': s.datetime, 'total_value': s.total_value}
|
||||||
for s in snapshots
|
for s in snapshots
|
||||||
])
|
])
|
||||||
|
|
||||||
equity_curve = df_equity['total_value'] / initial_capital
|
equity_curve = df_equity['total_value'] / initial_capital
|
||||||
|
|
||||||
rolling_max = equity_curve.cummax()
|
rolling_max = equity_curve.cummax()
|
||||||
@@ -203,7 +219,7 @@ def plot_close_price_chart(bars: List[Bar], title: str = "Close Price Chart") ->
|
|||||||
])
|
])
|
||||||
|
|
||||||
plt.style.use('seaborn-v0_8-darkgrid')
|
plt.style.use('seaborn-v0_8-darkgrid')
|
||||||
fig, ax = plt.subplots(1, 1, figsize=(14, 7)) # Single subplot
|
fig, ax = plt.subplots(1, 1, figsize=(14, 7)) # Single subplot
|
||||||
|
|
||||||
x_axis_indices = np.arange(len(df_prices))
|
x_axis_indices = np.arange(len(df_prices))
|
||||||
|
|
||||||
@@ -228,7 +244,7 @@ def plot_close_price_chart(bars: List[Bar], title: str = "Close Price Chart") ->
|
|||||||
|
|
||||||
# 辅助函数:计算单笔交易的盈亏
|
# 辅助函数:计算单笔交易的盈亏
|
||||||
def calculate_trade_pnl(
|
def calculate_trade_pnl(
|
||||||
trade: Trade, entry_price: float, exit_price: float, direction: str
|
trade: Trade, entry_price: float, exit_price: float, direction: str
|
||||||
) -> float:
|
) -> float:
|
||||||
if direction == "LONG":
|
if direction == "LONG":
|
||||||
pnl = (exit_price - entry_price) * trade.volume
|
pnl = (exit_price - entry_price) * trade.volume
|
||||||
|
|||||||
90
src/analysis/grid_search_analyzer.py
Normal file
90
src/analysis/grid_search_analyzer.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
# src/grid_search_analyzer.py
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from typing import List, Dict, Any, Tuple
|
||||||
|
|
||||||
|
class GridSearchAnalyzer:
|
||||||
|
"""
|
||||||
|
用于分析和可视化网格搜索结果的类。
|
||||||
|
"""
|
||||||
|
def __init__(self, grid_results: List[Dict[str, Any]], optimization_metric: str):
|
||||||
|
"""
|
||||||
|
初始化 GridSearchAnalyzer。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
grid_results (List[Dict[str, Any]]): 包含每个参数组合及其对应优化指标的列表。
|
||||||
|
例如:[{'param1': v1, 'param2': v2, 'metric': m1}, ...]
|
||||||
|
optimization_metric (str): 用于优化的指标名称(例如 'total_return')。
|
||||||
|
"""
|
||||||
|
if not grid_results:
|
||||||
|
raise ValueError("grid_results 列表不能为空。")
|
||||||
|
if optimization_metric not in grid_results[0]:
|
||||||
|
raise ValueError(f"优化指标 '{optimization_metric}' 不在 grid_results 的字典中。")
|
||||||
|
|
||||||
|
self.grid_results = grid_results
|
||||||
|
self.optimization_metric = optimization_metric
|
||||||
|
self.param_names = [k for k in grid_results[0].keys() if k != optimization_metric]
|
||||||
|
|
||||||
|
if len(self.param_names) != 2:
|
||||||
|
raise ValueError("GridSearchAnalyzer 当前只支持分析两个参数的网格搜索结果。")
|
||||||
|
|
||||||
|
self.param1_name = self.param_names[0]
|
||||||
|
self.param2_name = self.param_names[1]
|
||||||
|
|
||||||
|
def find_best_parameters(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
找到网格搜索中表现最佳的参数组合和其指标值。
|
||||||
|
"""
|
||||||
|
if not self.grid_results:
|
||||||
|
return {}
|
||||||
|
best_result = max(self.grid_results, key=lambda x: x[self.optimization_metric])
|
||||||
|
print(f"\n--- 最佳参数组合 ---")
|
||||||
|
print(f" {self.param1_name}: {best_result[self.param1_name]}")
|
||||||
|
print(f" {self.param2_name}: {best_result[self.param2_name]}")
|
||||||
|
print(f" {self.optimization_metric}: {best_result[self.optimization_metric]:.4f}")
|
||||||
|
return best_result
|
||||||
|
|
||||||
|
def plot_heatmap(self, title: str = "heatmap"):
|
||||||
|
"""
|
||||||
|
绘制两个参数的热力图。
|
||||||
|
"""
|
||||||
|
if not self.grid_results:
|
||||||
|
print("没有数据用于绘制热力图。")
|
||||||
|
return
|
||||||
|
|
||||||
|
x_values = sorted(list(set(d[self.param1_name] for d in self.grid_results)))
|
||||||
|
y_values = sorted(list(set(d[self.param2_name] for d in self.grid_results)))
|
||||||
|
|
||||||
|
heatmap_matrix = np.zeros((len(y_values), len(x_values)))
|
||||||
|
|
||||||
|
# 填充网格
|
||||||
|
for item in self.grid_results:
|
||||||
|
x_idx = x_values.index(item[self.param1_name])
|
||||||
|
y_idx = y_values.index(item[self.param2_name])
|
||||||
|
heatmap_matrix[y_idx, x_idx] = item[self.optimization_metric]
|
||||||
|
|
||||||
|
fig, ax = plt.subplots(figsize=(10, 8))
|
||||||
|
im = ax.imshow(heatmap_matrix, cmap='viridis', origin='lower',
|
||||||
|
extent=[x_values[0], x_values[-1], y_values[0], y_values[-1]],
|
||||||
|
aspect='auto')
|
||||||
|
|
||||||
|
ax.set_xticks(x_values)
|
||||||
|
ax.set_yticks(y_values)
|
||||||
|
|
||||||
|
ax.set_xlabel(self.param1_name)
|
||||||
|
ax.set_ylabel(self.param2_name)
|
||||||
|
ax.set_title(f"{title}: {self.optimization_metric} vs. {self.param1_name} & {self.param2_name}")
|
||||||
|
|
||||||
|
cbar = fig.colorbar(im, ax=ax)
|
||||||
|
cbar.set_label(self.optimization_metric)
|
||||||
|
|
||||||
|
# 在每个格子上显示数值
|
||||||
|
for i in range(len(y_values)):
|
||||||
|
for j in range(len(x_values)):
|
||||||
|
text = ax.text(x_values[j], y_values[i], f"{heatmap_matrix[i, j]:.2f}",
|
||||||
|
ha="center", va="center", color="w", fontsize=8)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.show()
|
||||||
@@ -1,10 +1,16 @@
|
|||||||
# src/analysis/result_analyzer.py
|
# src/analysis/result_analyzer.py
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
|
|
||||||
# 导入纯函数 (注意相对导入路径的变化)
|
# 导入纯函数 (注意相对导入路径的变化)
|
||||||
from .analysis_utils import calculate_metrics, plot_equity_and_drawdown_chart, plot_close_price_chart
|
from .analysis_utils import (
|
||||||
|
calculate_metrics,
|
||||||
|
plot_equity_and_drawdown_chart,
|
||||||
|
plot_close_price_chart,
|
||||||
|
)
|
||||||
|
|
||||||
# 导入核心数据类 (注意相对导入路径的变化)
|
# 导入核心数据类 (注意相对导入路径的变化)
|
||||||
from ..core_data import PortfolioSnapshot, Trade, Bar
|
from ..core_data import PortfolioSnapshot, Trade, Bar
|
||||||
|
|
||||||
@@ -14,11 +20,13 @@ class ResultAnalyzer:
|
|||||||
结果分析器:负责接收回测数据,并提供分析和可视化方法。
|
结果分析器:负责接收回测数据,并提供分析和可视化方法。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
portfolio_snapshots: List[PortfolioSnapshot],
|
self,
|
||||||
trade_history: List[Trade],
|
portfolio_snapshots: List[PortfolioSnapshot],
|
||||||
bars: List[Bar],
|
trade_history: List[Trade],
|
||||||
initial_capital: float):
|
bars: List[Bar],
|
||||||
|
initial_capital: float,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
portfolio_snapshots (List[PortfolioSnapshot]): 回测引擎输出的投资组合快照列表。
|
portfolio_snapshots (List[PortfolioSnapshot]): 回测引擎输出的投资组合快照列表。
|
||||||
@@ -41,9 +49,7 @@ class ResultAnalyzer:
|
|||||||
if self._metrics_cache is None:
|
if self._metrics_cache is None:
|
||||||
print("正在计算绩效指标...")
|
print("正在计算绩效指标...")
|
||||||
self._metrics_cache = calculate_metrics(
|
self._metrics_cache = calculate_metrics(
|
||||||
self.portfolio_snapshots,
|
self.portfolio_snapshots, self.trade_history, self.initial_capital
|
||||||
self.trade_history,
|
|
||||||
self.initial_capital
|
|
||||||
)
|
)
|
||||||
print("绩效指标计算完成。")
|
print("绩效指标计算完成。")
|
||||||
return self._metrics_cache
|
return self._metrics_cache
|
||||||
@@ -69,7 +75,8 @@ class ResultAnalyzer:
|
|||||||
print("\n--- 部分交易明细 (最近5笔) ---")
|
print("\n--- 部分交易明细 (最近5笔) ---")
|
||||||
for trade in self.trade_history[-5:]:
|
for trade in self.trade_history[-5:]:
|
||||||
print(
|
print(
|
||||||
f" {trade.fill_time} | {trade.direction:<10} | {trade.symbol} | Vol: {trade.volume} | Price: {trade.price:.2f} | Commission: {trade.commission:.2f}")
|
f" {trade.fill_time} | {trade.direction:<10} | {trade.symbol} | Vol: {trade.volume} | Price: {trade.price:.2f} | Commission: {trade.commission:.2f}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print("\n没有交易记录。")
|
print("\n没有交易记录。")
|
||||||
|
|
||||||
@@ -79,10 +86,13 @@ class ResultAnalyzer:
|
|||||||
"""
|
"""
|
||||||
print("正在绘制绩效图表...")
|
print("正在绘制绩效图表...")
|
||||||
# plot_performance_chart(self.portfolio_snapshots, self.initial_capital, self.bars)
|
# plot_performance_chart(self.portfolio_snapshots, self.initial_capital, self.bars)
|
||||||
plot_equity_and_drawdown_chart(self.portfolio_snapshots, self.initial_capital,
|
plot_equity_and_drawdown_chart(
|
||||||
title="Portfolio Equity and Drawdown Curve")
|
self.portfolio_snapshots,
|
||||||
|
self.initial_capital,
|
||||||
|
title="Portfolio Equity and Drawdown Curve",
|
||||||
|
)
|
||||||
|
|
||||||
# 绘制单独的收盘价曲线
|
# 绘制单独的收盘价曲线
|
||||||
plot_close_price_chart(self.bars, title="Underlying Asset Close Price")
|
plot_close_price_chart(self.bars, title="Underlying Asset Close Price")
|
||||||
|
|
||||||
print("图表绘制完成。")
|
print("图表绘制完成。")
|
||||||
|
|||||||
@@ -58,6 +58,16 @@ class BacktestContext:
|
|||||||
# 可以在这里触发策略的on_trade回调(如果策略定义了)
|
# 可以在这里触发策略的on_trade回调(如果策略定义了)
|
||||||
return trade
|
return trade
|
||||||
|
|
||||||
|
def cancel_order(self, order_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
策略通过此方法发出交易订单。
|
||||||
|
"""
|
||||||
|
if self._current_bar is None:
|
||||||
|
raise RuntimeError("当前Bar未设置,无法发送订单。")
|
||||||
|
|
||||||
|
return self._simulator.cancel_order(order_id)
|
||||||
|
|
||||||
|
|
||||||
def get_current_positions(self) -> Dict[str, int]:
|
def get_current_positions(self) -> Dict[str, int]:
|
||||||
"""
|
"""
|
||||||
获取当前模拟器的持仓情况。
|
获取当前模拟器的持仓情况。
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ class BacktestEngine:
|
|||||||
commission_rate (float): 交易佣金率。
|
commission_rate (float): 交易佣金率。
|
||||||
"""
|
"""
|
||||||
self.data_manager = data_manager
|
self.data_manager = data_manager
|
||||||
|
self.initial_capital = initial_capital
|
||||||
self.simulator = ExecutionSimulator(
|
self.simulator = ExecutionSimulator(
|
||||||
initial_capital=initial_capital,
|
initial_capital=initial_capital,
|
||||||
slippage_rate=slippage_rate,
|
slippage_rate=slippage_rate,
|
||||||
@@ -76,6 +77,7 @@ class BacktestEngine:
|
|||||||
|
|
||||||
# 设置当前Bar到Context,供策略访问
|
# 设置当前Bar到Context,供策略访问
|
||||||
self.context.set_current_bar(current_bar)
|
self.context.set_current_bar(current_bar)
|
||||||
|
self.simulator.update_time(current_time=current_bar.datetime)
|
||||||
|
|
||||||
# 更新历史Bar缓存
|
# 更新历史Bar缓存
|
||||||
self._history_bars.append(current_bar)
|
self._history_bars.append(current_bar)
|
||||||
@@ -129,7 +131,7 @@ class BacktestEngine:
|
|||||||
for symbol_held, quantity in positions_to_close.items():
|
for symbol_held, quantity in positions_to_close.items():
|
||||||
if quantity != 0:
|
if quantity != 0:
|
||||||
print(f"[{last_processed_bar.datetime}] 回测结束平仓: 平仓 {symbol_held} ({quantity} 手) @ {last_processed_bar.close:.2f}。")
|
print(f"[{last_processed_bar.datetime}] 回测结束平仓: 平仓 {symbol_held} ({quantity} 手) @ {last_processed_bar.close:.2f}。")
|
||||||
direction = "SELL" if quantity > 0 else "BUY"
|
direction = "CLOSE_LONG" if quantity > 0 else "CLOSE_SELL"
|
||||||
volume = abs(quantity)
|
volume = abs(quantity)
|
||||||
|
|
||||||
# 使用当前合约的最后一根Bar的价格进行平仓
|
# 使用当前合约的最后一根Bar的价格进行平仓
|
||||||
@@ -148,6 +150,17 @@ class BacktestEngine:
|
|||||||
print(f"总计处理了 {len(self.portfolio_snapshots)} 根K线。")
|
print(f"总计处理了 {len(self.portfolio_snapshots)} 根K线。")
|
||||||
print(f"总计发生了 {len(self.trade_history)} 笔交易。")
|
print(f"总计发生了 {len(self.trade_history)} 笔交易。")
|
||||||
|
|
||||||
|
final_portfolio_value = 0.0
|
||||||
|
if last_processed_bar:
|
||||||
|
final_portfolio_value = self.simulator.get_portfolio_value(last_processed_bar)
|
||||||
|
else: # 如果数据为空,或者回测根本没跑,则净值为初始资金
|
||||||
|
final_portfolio_value = self.initial_capital
|
||||||
|
|
||||||
|
total_return_percentage = ((final_portfolio_value - self.initial_capital) / self.initial_capital) * 100
|
||||||
|
|
||||||
|
print(f"最终总净值: {final_portfolio_value:.2f}")
|
||||||
|
print(f"总收益率: {total_return_percentage:.2f}%")
|
||||||
|
|
||||||
def get_backtest_results(self) -> Dict[str, Any]:
|
def get_backtest_results(self) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
返回回测结果数据,供结果分析模块使用。
|
返回回测结果数据,供结果分析模块使用。
|
||||||
|
|||||||
189
src/backtest_runner.py
Normal file
189
src/backtest_runner.py
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
# src/backtest_runner.py
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Dict, Any, Optional, Type
|
||||||
|
|
||||||
|
# 导入核心组件
|
||||||
|
from .backtest_engine import BacktestEngine
|
||||||
|
from .data_manager import DataManager
|
||||||
|
from .execution_simulator import (
|
||||||
|
ExecutionSimulator,
|
||||||
|
) # 虽然不直接用,但 BacktestEngine 内部会用到
|
||||||
|
from .strategies.base_strategy import Strategy # 用于类型提示
|
||||||
|
from .core_data import Bar, PortfolioSnapshot, Trade, Order # 导入数据结构
|
||||||
|
|
||||||
|
|
||||||
|
def run_multi_segment_backtest(
|
||||||
|
segment_configs: Dict[str, Dict[str, Any]],
|
||||||
|
strategy_class: Type[Strategy],
|
||||||
|
strategy_params: Dict[str, Any],
|
||||||
|
initial_capital: float,
|
||||||
|
slippage_rate: float,
|
||||||
|
commission_rate: float,
|
||||||
|
total_start_date: datetime,
|
||||||
|
total_end_date: datetime,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
运行一个多主力合约的分段回测,并返回合并后的回测结果。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
segment_configs (Dict[str, Dict[str, Any]]): 包含所有合约片段配置的字典。
|
||||||
|
strategy_class (Type[Strategy]): 要回测的策略类。
|
||||||
|
strategy_params (Dict[str, Any]): 传递给策略的参数字典。
|
||||||
|
initial_capital (float): 回测的初始资金。
|
||||||
|
slippage_rate (float): 交易滑点率。
|
||||||
|
commission_rate (float): 交易佣金率。
|
||||||
|
total_start_date (datetime): 总回测的起始日期。
|
||||||
|
total_end_date (datetime): 总回测的结束日期。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: 包含合并后的 portfolio_snapshots, trade_history, all_bars
|
||||||
|
和 initial_capital 的字典。
|
||||||
|
"""
|
||||||
|
# --- 收集所有片段的回测结果 ---
|
||||||
|
all_combined_snapshots: List[PortfolioSnapshot] = []
|
||||||
|
all_combined_trades: List[Trade] = []
|
||||||
|
all_combined_bars: List[Bar] = []
|
||||||
|
|
||||||
|
# 用于净值平滑的基准值,从初始资金开始
|
||||||
|
last_segment_adjusted_total_value = initial_capital
|
||||||
|
current_cash = initial_capital # 用于传递给下一个片段的初始资金
|
||||||
|
current_positions: Dict[str, int] = {} # 用于传递给下一个片段的初始持仓
|
||||||
|
|
||||||
|
print("\n--- 开始分段回测跨越多个合约 ---")
|
||||||
|
|
||||||
|
# 获取所有合约的有序键,确保按定义的顺序回测
|
||||||
|
segment_keys = list(segment_configs.keys())
|
||||||
|
|
||||||
|
for i, contract_symbol in enumerate(segment_keys):
|
||||||
|
config = segment_configs[contract_symbol]
|
||||||
|
segment_file = config["file"]
|
||||||
|
segment_start_date = config["start"]
|
||||||
|
segment_end_date = config["end"]
|
||||||
|
|
||||||
|
current_segment_start = max(segment_start_date, total_start_date)
|
||||||
|
current_segment_end = min(segment_end_date, total_end_date)
|
||||||
|
|
||||||
|
if current_segment_start > current_segment_end:
|
||||||
|
print(f"跳过片段 {contract_symbol}: 日期范围超出总回测周期或无效。")
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"\n--- 回测片段: {contract_symbol} 从 {current_segment_start.strftime('%Y-%m-%d %H:%M')} 到 {current_segment_end.strftime('%Y-%m-%d %H:%M')} ---"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 1. 初始化 DataManager (每个片段使用不同的 DataManager 实例)
|
||||||
|
data_manager = DataManager(file_path=segment_file, symbol=contract_symbol)
|
||||||
|
|
||||||
|
if data_manager.raw_df.empty:
|
||||||
|
print(
|
||||||
|
f"警告: 未能加载 {contract_symbol} 的数据文件 {segment_file}。跳过此片段。"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
strategy_params["symbol"] = contract_symbol
|
||||||
|
# 2. 创建 BacktestEngine 实例 (针对当前合约片段)
|
||||||
|
engine = BacktestEngine(
|
||||||
|
data_manager=data_manager,
|
||||||
|
strategy_class=strategy_class,
|
||||||
|
strategy_params=strategy_params, # 使用传入的策略参数
|
||||||
|
initial_capital=initial_capital,
|
||||||
|
slippage_rate=slippage_rate,
|
||||||
|
commission_rate=commission_rate,
|
||||||
|
current_segment_symbol=contract_symbol,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 在运行当前片段之前,将上一个片段结束时的资金和持仓注入到当前 BacktestEngine 的 simulator 中
|
||||||
|
engine.get_simulator().reset(
|
||||||
|
new_initial_capital=current_cash, new_initial_positions=current_positions
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. 运行当前片段回测
|
||||||
|
engine.run_backtest()
|
||||||
|
segment_results = engine.get_backtest_results()
|
||||||
|
|
||||||
|
# 4. 获取当前片段结束后的最新资金和持仓状态,供下个片段使用
|
||||||
|
current_cash = engine.get_simulator().cash
|
||||||
|
current_positions = engine.get_simulator().positions.copy()
|
||||||
|
|
||||||
|
# 5. 拼接结果:净值曲线需要特殊处理
|
||||||
|
segment_snapshots = segment_results["portfolio_snapshots"]
|
||||||
|
segment_trades = segment_results["trade_history"]
|
||||||
|
segment_bars = segment_results["all_bars"]
|
||||||
|
|
||||||
|
if segment_snapshots:
|
||||||
|
first_snapshot_value_of_segment = segment_snapshots[0].total_value
|
||||||
|
|
||||||
|
if first_snapshot_value_of_segment == 0:
|
||||||
|
segment_relative_factor = 1.0
|
||||||
|
else:
|
||||||
|
segment_relative_factor = (
|
||||||
|
last_segment_adjusted_total_value / first_snapshot_value_of_segment
|
||||||
|
)
|
||||||
|
|
||||||
|
for snapshot in segment_snapshots:
|
||||||
|
adjusted_snapshot_value = snapshot.total_value * segment_relative_factor
|
||||||
|
adjusted_snapshot = PortfolioSnapshot(
|
||||||
|
datetime=snapshot.datetime,
|
||||||
|
total_value=adjusted_snapshot_value,
|
||||||
|
cash=snapshot.cash,
|
||||||
|
positions=snapshot.positions.copy(),
|
||||||
|
price_at_snapshot=snapshot.price_at_snapshot.copy(),
|
||||||
|
)
|
||||||
|
all_combined_snapshots.append(adjusted_snapshot)
|
||||||
|
|
||||||
|
last_segment_adjusted_total_value = all_combined_snapshots[-1].total_value
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"警告: 未能为 {contract_symbol} 生成任何快照。`last_segment_adjusted_total_value` 保持不变: {last_segment_adjusted_total_value:.2f}。"
|
||||||
|
)
|
||||||
|
|
||||||
|
all_combined_trades.extend(segment_trades)
|
||||||
|
all_combined_bars.extend(segment_bars)
|
||||||
|
|
||||||
|
print("\n--- 总分段回测完成 ---")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"portfolio_snapshots": all_combined_snapshots,
|
||||||
|
"trade_history": all_combined_trades,
|
||||||
|
"all_bars": all_combined_bars,
|
||||||
|
"initial_capital": initial_capital, # 返回初始资金,方便 ResultAnalyzer 使用
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def run_backtest_for_optimization(
|
||||||
|
global_config: Dict[str, Any],
|
||||||
|
strategy_class: Type[Strategy],
|
||||||
|
current_strategy_params: Dict[str, Any],
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
简化后的回测入口,用于优化或参数搜索。
|
||||||
|
它从 global_config 中提取通用回测参数,并结合 current_strategy_params 运行回测。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
global_config (Dict[str, Any]): 包含所有通用全局配置的字典。
|
||||||
|
strategy_class (Type[Strategy]): 要回测的策略类。
|
||||||
|
current_strategy_params (Dict[str, Any]): 当前迭代的策略参数。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: 包含合并后的 portfolio_snapshots, trade_history, all_bars
|
||||||
|
和 initial_capital 的字典。
|
||||||
|
"""
|
||||||
|
total_start_date = global_config["total_start_date"]
|
||||||
|
total_end_date = global_config["total_end_date"]
|
||||||
|
initial_capital = global_config["initial_capital"]
|
||||||
|
slippage_rate = global_config["slippage_rate"]
|
||||||
|
commission_rate = global_config["commission_rate"]
|
||||||
|
segment_configs = global_config["segment_configs"]
|
||||||
|
|
||||||
|
return run_multi_segment_backtest(
|
||||||
|
segment_configs=segment_configs,
|
||||||
|
strategy_class=strategy_class,
|
||||||
|
strategy_params=current_strategy_params,
|
||||||
|
initial_capital=initial_capital,
|
||||||
|
slippage_rate=slippage_rate,
|
||||||
|
commission_rate=commission_rate,
|
||||||
|
total_start_date=total_start_date,
|
||||||
|
total_end_date=total_end_date,
|
||||||
|
)
|
||||||
36
src/common_utils.py
Normal file
36
src/common_utils.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
# src/common_utils.py
|
||||||
|
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
def generate_parameter_range(start: Union[int, float], end: Union[int, float], step: Union[int, float]) -> List[Union[int, float]]:
|
||||||
|
"""
|
||||||
|
根据开始、结束和步长生成一个参数值的列表。
|
||||||
|
支持整数和浮点数。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start (Union[int, float]): 参数范围的起始值。
|
||||||
|
end (Union[int, float]): 参数范围的结束值。
|
||||||
|
step (Union[int, float]): 参数的步长。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Union[int, float]]: 生成的参数值列表。
|
||||||
|
"""
|
||||||
|
if step == 0:
|
||||||
|
raise ValueError("步长不能为0。")
|
||||||
|
if (end - start) * step < 0: # 步长方向与范围不符
|
||||||
|
raise ValueError("步长方向与起始/结束值不一致,请检查步长正负。")
|
||||||
|
|
||||||
|
param_range = []
|
||||||
|
current_value = start
|
||||||
|
while (step > 0 and current_value <= end) or (step < 0 and current_value >= end):
|
||||||
|
param_range.append(current_value)
|
||||||
|
current_value += step
|
||||||
|
# 针对浮点数精度问题进行小幅调整,避免无限循环或过早停止
|
||||||
|
if isinstance(step, float) or isinstance(start, float) or isinstance(end, float):
|
||||||
|
current_value = round(current_value, 10) # 四舍五入到一定小数位
|
||||||
|
|
||||||
|
return param_range
|
||||||
|
|
||||||
|
# 示例:
|
||||||
|
# print(generate_parameter_range(0.99, 1.01, 0.005)) # [0.99, 0.995, 1.0, 1.005, 1.01]
|
||||||
|
# print(generate_parameter_range(5, 20, 5)) # [5, 10, 15, 20]
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
# src/execution_simulator.py (修改部分)
|
# src/execution_simulator.py (修改部分)
|
||||||
|
from datetime import datetime
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from .core_data import Order, Trade, Bar, PortfolioSnapshot
|
from .core_data import Order, Trade, Bar, PortfolioSnapshot
|
||||||
@@ -36,35 +36,111 @@ class ExecutionSimulator:
|
|||||||
self.commission_rate = commission_rate
|
self.commission_rate = commission_rate
|
||||||
self.trade_log: List[Trade] = [] # 存储所有成交记录
|
self.trade_log: List[Trade] = [] # 存储所有成交记录
|
||||||
self.pending_orders: Dict[str, Order] = {} # {order_id: Order_object}
|
self.pending_orders: Dict[str, Order] = {} # {order_id: Order_object}
|
||||||
|
self._current_time = None
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"模拟器初始化:初始资金={self.initial_capital:.2f}, 滑点率={self.slippage_rate}, 佣金率={self.commission_rate}")
|
f"模拟器初始化:初始资金={self.initial_capital:.2f}, 滑点率={self.slippage_rate}, 佣金率={self.commission_rate}")
|
||||||
if self.positions:
|
if self.positions:
|
||||||
print(f"初始持仓:{self.positions}")
|
print(f"初始持仓:{self.positions}")
|
||||||
|
|
||||||
|
def update_time(self, current_time: datetime):
|
||||||
|
"""
|
||||||
|
更新模拟器的当前时间。
|
||||||
|
这个方法由 BacktestEngine 在遍历 K 线时调用。
|
||||||
|
"""
|
||||||
|
self._current_time = current_time
|
||||||
|
|
||||||
|
# --- 新增的公共方法 ---
|
||||||
|
def get_current_time(self) -> datetime:
|
||||||
|
"""
|
||||||
|
获取模拟器的当前时间。
|
||||||
|
"""
|
||||||
|
if self._current_time is None:
|
||||||
|
# 可以在这里抛出错误或者返回一个默认值,取决于你对未初始化时间的处理
|
||||||
|
# 抛出错误可以帮助你发现问题,例如在模拟器时间未设置时就尝试获取
|
||||||
|
# raise RuntimeError("Simulator time has not been set. Ensure update_time is called.")
|
||||||
|
return None
|
||||||
|
return self._current_time
|
||||||
|
|
||||||
def _calculate_fill_price(self, order: Order, current_bar: Bar) -> float:
|
def _calculate_fill_price(self, order: Order, current_bar: Bar) -> float:
|
||||||
"""
|
"""
|
||||||
内部方法:根据订单类型和滑点计算实际成交价格。
|
内部方法:根据订单类型和滑点计算实际成交价格。
|
||||||
简化处理:市价单以当前Bar收盘价为基准,考虑滑点。
|
- 市价单通常以当前K线的开盘价成交(考虑滑点)。
|
||||||
|
- 限价单判断是否触及限价,如果触及,以限价成交(考虑滑点)。
|
||||||
"""
|
"""
|
||||||
base_price = current_bar.close # 简化为收盘价成交
|
fill_price = -1.0 # 默认未成交
|
||||||
|
|
||||||
# 考虑滑点
|
if order.price_type == "MARKET":
|
||||||
if order.direction in ["BUY", "CLOSE_SHORT"]: # 买入或平空,价格向上偏离
|
# 市价单通常以开盘价成交,或者根据你的策略需求选择收盘价
|
||||||
fill_price = base_price * (1 + self.slippage_rate)
|
# 这里我们仍然使用开盘价作为市价单的基准成交价
|
||||||
elif order.direction in ["SELL", "CLOSE_LONG"]: # 卖出或平多,价格向下偏离
|
base_price = current_bar.open
|
||||||
fill_price = base_price * (1 - self.slippage_rate)
|
|
||||||
else: # 默认情况,无滑点
|
|
||||||
fill_price = base_price
|
|
||||||
|
|
||||||
# 如果是限价单且成交价格不满足条件,则可能不成交
|
if order.direction == "BUY" or order.direction == "CLOSE_SHORT": # 买入或平空,价格向上偏离
|
||||||
if order.price_type == "LIMIT" and order.limit_price is not None:
|
fill_price = base_price * (1 + self.slippage_rate)
|
||||||
# 对于BUY和CLOSE_SHORT,成交价必须 <= 限价
|
elif order.direction == "SELL" or order.direction == "CLOSE_LONG": # 卖出或平多,价格向下偏离
|
||||||
if (order.direction == "BUY" or order.direction == "CLOSE_SHORT") and fill_price > order.limit_price:
|
fill_price = base_price * (1 - self.slippage_rate)
|
||||||
return -1.0 # 未触及限价
|
else: # 默认情况,理论上不应该到这里,因为方向应该明确
|
||||||
# 对于SELL和CLOSE_LONG,成交价必须 >= 限价
|
fill_price = base_price # 不考虑滑点
|
||||||
elif (order.direction == "SELL" or order.direction == "CLOSE_LONG") and fill_price < order.limit_price:
|
|
||||||
return -1.0 # 未触及限价
|
# 市价单只要有价格就会成交,无需额外价格区间判断
|
||||||
|
|
||||||
|
elif order.price_type == "LIMIT" and order.limit_price is not None:
|
||||||
|
limit_price = order.limit_price
|
||||||
|
high = current_bar.high
|
||||||
|
low = current_bar.low
|
||||||
|
open_price = current_bar.open # 也可以在限价单成交时用开盘价作为实际成交价,或限价本身
|
||||||
|
|
||||||
|
if order.direction == "BUY" or order.direction == "CLOSE_SHORT": # 限价买入或限价平空
|
||||||
|
# 买入:如果K线最低价 <= 限价,则限价单可能成交
|
||||||
|
if low <= limit_price:
|
||||||
|
# 成交价通常是限价,但考虑滑点后可能略高(对买方不利)
|
||||||
|
# 或者可以以 open/low 之间的一个价格成交,这里简化为限价+滑点
|
||||||
|
fill_price_candidate = limit_price * (1 + self.slippage_rate)
|
||||||
|
|
||||||
|
# 确保成交价不会比当前K线的最低价还低(如果不是在最低点成交)
|
||||||
|
# 也可以简单就返回 limit_price * (1 + self.slippage_rate)
|
||||||
|
|
||||||
|
# 如果开盘价低于或等于限价,那么通常会以开盘价成交(或者略差)
|
||||||
|
# 否则,如果价格是从上方跌落到限价区,那么会在限价附近成交
|
||||||
|
# 这里简化:如果K线触及限价,则以限价成交(考虑滑点)。
|
||||||
|
# 更精细的模拟会考虑价格穿越顺序 (例如,是否开盘就跳过了限价)
|
||||||
|
|
||||||
|
# 如果开盘价已经低于或等于限价,那么就以开盘价成交(考虑滑点)
|
||||||
|
if open_price <= limit_price:
|
||||||
|
fill_price = open_price * (1 + self.slippage_rate)
|
||||||
|
else: # 价格从高处跌落到限价
|
||||||
|
fill_price = limit_price * (1 + self.slippage_rate)
|
||||||
|
|
||||||
|
# 确保成交价不会超过限价 (虽然加滑点可能会略超,但这是交易成本的一部分)
|
||||||
|
# 这个检查是为了避免逻辑错误,理论上加滑点后应该可以接受比限价略高的价格
|
||||||
|
# 如果你严格要求成交价不能高于限价,则需要移除滑点或者更复杂的逻辑
|
||||||
|
# 这里我们接受加滑点后的价格
|
||||||
|
if fill_price > high and fill_price > limit_price: # 如果计算出来的成交价高于K线最高价,则可能不合理
|
||||||
|
fill_price = -1.0 # 理论上不该发生,除非滑点过大
|
||||||
|
else:
|
||||||
|
return -1.0 # 未触及限价
|
||||||
|
|
||||||
|
elif order.direction == "SELL" or order.direction == "CLOSE_LONG": # 限价卖出或限价平多
|
||||||
|
# 卖出:如果K线最高价 >= 限价,则限价单可能成交
|
||||||
|
if high >= limit_price:
|
||||||
|
# 成交价通常是限价,但考虑滑点后可能略低(对卖方不利)
|
||||||
|
fill_price_candidate = limit_price * (1 - self.slippage_rate)
|
||||||
|
|
||||||
|
# 如果开盘价已经高于或等于限价,那么就以开盘价成交(考虑滑点)
|
||||||
|
if open_price >= limit_price:
|
||||||
|
fill_price = open_price * (1 - self.slippage_rate)
|
||||||
|
else: # 价格从低处上涨到限价
|
||||||
|
fill_price = limit_price * (1 - self.slippage_rate)
|
||||||
|
|
||||||
|
# 确保成交价不会低于限价
|
||||||
|
if fill_price < low and fill_price < limit_price: # 如果计算出来的成交价低于K线最低价,则可能不合理
|
||||||
|
fill_price = -1.0 # 理论上不该发生
|
||||||
|
else:
|
||||||
|
return -1.0 # 未触及限价
|
||||||
|
|
||||||
|
# 最后检查成交价是否有效
|
||||||
|
if fill_price <= 0:
|
||||||
|
return -1.0 # 如果计算出来价格无效,返回未成交
|
||||||
|
|
||||||
return fill_price
|
return fill_price
|
||||||
|
|
||||||
@@ -107,6 +183,7 @@ class ExecutionSimulator:
|
|||||||
if fill_price <= 0: # 表示未成交或不满足限价条件
|
if fill_price <= 0: # 表示未成交或不满足限价条件
|
||||||
if order.price_type == "LIMIT":
|
if order.price_type == "LIMIT":
|
||||||
self.pending_orders[order.id] = order
|
self.pending_orders[order.id] = order
|
||||||
|
# print(f'撮合失败,order id:{order.id},fill_price:{fill_price}')
|
||||||
return None # 未成交,返回None
|
return None # 未成交,返回None
|
||||||
|
|
||||||
# --- 以下是订单成功成交的逻辑 ---
|
# --- 以下是订单成功成交的逻辑 ---
|
||||||
@@ -116,78 +193,90 @@ class ExecutionSimulator:
|
|||||||
current_position = self.positions.get(symbol, 0)
|
current_position = self.positions.get(symbol, 0)
|
||||||
current_average_cost = self.average_costs.get(symbol, 0.0)
|
current_average_cost = self.average_costs.get(symbol, 0.0)
|
||||||
|
|
||||||
if order.direction == "BUY":
|
actual_direction = order.direction
|
||||||
|
if order.direction == "CLOSE_SHORT":
|
||||||
|
actual_direction = "BUY"
|
||||||
|
elif order.direction == "CLOSE_LONG":
|
||||||
|
actual_direction = "SELL"
|
||||||
|
|
||||||
|
is_close_order_intent = (order.direction == "CLOSE_LONG" or
|
||||||
|
order.direction == "CLOSE_SHORT")
|
||||||
|
|
||||||
|
if actual_direction == "BUY": # 处理买入 (开多 / 平空)
|
||||||
# 开多仓或平空仓
|
# 开多仓或平空仓
|
||||||
if current_position >= 0: # 当前持有多仓或无仓位 (开多)
|
if current_position >= 0: # 当前持有多仓或无仓位 (开多)
|
||||||
is_open_trade = True
|
is_open_trade = not is_close_order_intent # 如果是平仓意图,则不是开仓交易
|
||||||
# 更新平均成本 (加权平均)
|
# 更新平均成本 (加权平均)
|
||||||
new_total_cost = (current_average_cost * current_position) + (fill_price * volume)
|
new_total_cost = (current_average_cost * current_position) + (fill_price * volume)
|
||||||
new_total_volume = current_position + volume
|
new_total_volume = current_position + volume
|
||||||
self.average_costs[symbol] = new_total_cost / new_total_volume if new_total_volume > 0 else 0.0
|
self.average_costs[symbol] = new_total_cost / new_total_volume if new_total_volume > 0 else 0.0
|
||||||
self.positions[symbol] = new_total_volume
|
self.positions[symbol] = new_total_volume
|
||||||
else: # 当前持有空仓 (平空)
|
else: # 当前持有空仓 (平空)
|
||||||
is_close_trade = True
|
is_close_trade = is_close_order_intent # 这是平仓交易
|
||||||
# 计算平空盈亏
|
# 计算平空盈亏
|
||||||
# PnL = (开仓成本 - 平仓价格) * 平仓数量 (注意空头方向)
|
pnl_per_share = current_average_cost - fill_price
|
||||||
# 简化:假设平空时,直接使用当前的平均开仓成本来计算盈亏
|
|
||||||
# 更精确的FIFO/LIFO需更多逻辑
|
|
||||||
pnl_per_share = current_average_cost - fill_price # (买入平空,成本高于平仓价则盈利)
|
|
||||||
realized_pnl = pnl_per_share * volume
|
realized_pnl = pnl_per_share * volume
|
||||||
|
|
||||||
# 更新持仓和成本
|
# 更新持仓和成本
|
||||||
self.positions[symbol] += volume
|
self.positions[symbol] += volume
|
||||||
if self.positions[symbol] == 0:
|
if self.positions[symbol] == 0: # 如果全部平仓
|
||||||
del self.positions[symbol]
|
del self.positions[symbol]
|
||||||
if symbol in self.average_costs: del self.average_costs[symbol] # 清理成本
|
if symbol in self.average_costs: del self.average_costs[symbol]
|
||||||
elif self.positions[symbol] > 0 and current_position < 0: # 部分平空转为多头,需重新设置成本
|
elif self.positions[symbol] > 0 and current_position < 0: # 部分平空,且空头仓位被买平为多头仓位
|
||||||
# 这部分逻辑可以更复杂,这里简化处理,如果转为多头,成本重置为0
|
# 这是从空头转为多头的复杂情况。需要重新计算平均成本
|
||||||
# 实际应该用剩余的空头成本 + 新开多的成本
|
# 简单处理:将剩余的多头仓位成本设为当前价格
|
||||||
self.average_costs[symbol] = fill_price # 简单地将剩下的多头仓位成本设为当前价格
|
self.average_costs[symbol] = fill_price
|
||||||
|
|
||||||
# 资金扣除
|
|
||||||
if self.cash >= trade_value + commission:
|
if self.cash >= trade_value + commission:
|
||||||
self.cash -= (trade_value + commission)
|
self.cash -= (trade_value + commission)
|
||||||
else:
|
else:
|
||||||
# print(f"[{current_bar.datetime}] 资金不足,无法执行买入 {volume} {symbol}")
|
print(f"[{current_bar.datetime}] 资金不足,无法执行买入 {volume} {symbol}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
elif order.direction == "SELL":
|
elif actual_direction == "SELL": # 处理卖出 (开空 / 平多)
|
||||||
# 开空仓或平多仓
|
# 开空仓或平多仓
|
||||||
if current_position <= 0: # 当前持有空仓或无仓位 (开空)
|
if current_position <= 0: # 当前持有空仓或无仓位 (开空)
|
||||||
is_open_trade = True
|
is_open_trade = not is_close_order_intent # 如果是平仓意图,则不是开仓交易
|
||||||
# 更新平均成本 (空头成本为负值)
|
# 更新平均成本 (空头成本为负值)
|
||||||
new_total_cost = (current_average_cost * current_position) - (fill_price * volume) # 负的持仓乘以负的卖出价
|
# 对于空头,平均成本通常是指你卖出开仓的平均价格
|
||||||
new_total_volume = current_position - volume # 空头持仓量更负
|
# 这里需要根据你的空头成本计算方式来调整
|
||||||
self.average_costs[symbol] = new_total_cost / new_total_volume if new_total_volume < 0 else 0.0
|
# 常见的做法是:总卖出价值 / 总卖出数量
|
||||||
self.positions[symbol] = new_total_volume
|
new_total_value = (current_average_cost * abs(current_position)) + (fill_price * volume)
|
||||||
|
new_total_volume = abs(current_position) + volume
|
||||||
|
self.average_costs[symbol] = new_total_value / new_total_volume if new_total_volume > 0 else 0.0
|
||||||
|
self.positions[symbol] -= volume # 空头数量增加,持仓量变为负更多
|
||||||
|
|
||||||
else: # 当前持有多仓 (平多)
|
else: # 当前持有多仓 (平多)
|
||||||
is_close_trade = True
|
is_close_trade = is_close_order_intent # 这是平仓交易
|
||||||
# 计算平多盈亏
|
# 计算平多盈亏
|
||||||
# PnL = (平仓价格 - 开仓成本) * 平仓数量
|
pnl_per_share = fill_price - current_average_cost
|
||||||
pnl_per_share = fill_price - current_average_cost # (卖出平多,平仓价高于成本则盈利)
|
|
||||||
realized_pnl = pnl_per_share * volume
|
realized_pnl = pnl_per_share * volume
|
||||||
|
|
||||||
# 更新持仓和成本
|
# 更新持仓和成本
|
||||||
self.positions[symbol] -= volume
|
self.positions[symbol] -= volume
|
||||||
if self.positions[symbol] == 0:
|
if self.positions[symbol] == 0: # 如果全部平仓
|
||||||
del self.positions[symbol]
|
del self.positions[symbol]
|
||||||
if symbol in self.average_costs: del self.average_costs[symbol] # 清理成本
|
if symbol in self.average_costs: del self.average_costs[symbol]
|
||||||
elif self.positions[symbol] < 0 and current_position > 0: # 部分平多转为空头,需重新设置成本
|
elif self.positions[symbol] < 0 and current_position > 0: # 部分平多,且多头仓位被卖平为空头仓位
|
||||||
self.average_costs[symbol] = fill_price # 简单地将剩下的空头仓位成本设为当前价格
|
# 从多头转为空头的复杂情况
|
||||||
|
self.average_costs[symbol] = fill_price # 简单将剩余空头仓位成本设为当前价格
|
||||||
|
|
||||||
# 资金扣除 (佣金) 和增加 (卖出收入)
|
|
||||||
if self.cash >= commission:
|
if self.cash >= commission:
|
||||||
self.cash -= commission
|
self.cash -= commission
|
||||||
self.cash += trade_value
|
self.cash += trade_value
|
||||||
else:
|
else:
|
||||||
# print(f"[{current_bar.datetime}] 资金不足,无法执行卖出 {volume} {symbol}")
|
print(f"[{current_bar.datetime}] 资金不足,无法执行卖出 {volume} {symbol}")
|
||||||
return None
|
return None
|
||||||
|
else: # 既不是 BUY 也不是 SELL,且也不是 CANCEL。这可能是未知的 direction
|
||||||
|
print(
|
||||||
|
f"[{current_bar.datetime}] 模拟器: 收到未知订单方向 {order.direction} for Order ID: {order.id}. 订单未处理。")
|
||||||
|
return None
|
||||||
|
|
||||||
# 创建 Trade 对象
|
# 创建 Trade 对象
|
||||||
executed_trade = Trade(
|
executed_trade = Trade(
|
||||||
order_id=order.id, fill_time=current_bar.datetime, symbol=symbol,
|
order_id=order.id, fill_time=current_bar.datetime, symbol=symbol,
|
||||||
direction=order.direction, # 记录原始订单方向 (BUY/SELL)
|
direction=order.direction, # 记录原始订单方向 (BUY/SELL/CLOSE_X)
|
||||||
volume=volume, price=fill_price, commission=commission,
|
volume=volume, price=fill_price, commission=commission,
|
||||||
cash_after_trade=self.cash, positions_after_trade=self.positions.copy(),
|
cash_after_trade=self.cash, positions_after_trade=self.positions.copy(),
|
||||||
realized_pnl=realized_pnl, # 填充实现盈亏
|
realized_pnl=realized_pnl, # 填充实现盈亏
|
||||||
@@ -200,8 +289,6 @@ class ExecutionSimulator:
|
|||||||
if order.id in self.pending_orders:
|
if order.id in self.pending_orders:
|
||||||
del self.pending_orders[order.id]
|
del self.pending_orders[order.id]
|
||||||
|
|
||||||
# print(f"[{current_bar.datetime}] 成交: {executed_trade.direction} {executed_trade.volume} {executed_trade.symbol} @ {executed_trade.price:.2f}, 佣金: {executed_trade.commission:.2f}, PnL: {executed_trade.realized_pnl:.2f}")
|
|
||||||
|
|
||||||
return executed_trade
|
return executed_trade
|
||||||
|
|
||||||
def cancel_order(self, order_id: str) -> bool:
|
def cancel_order(self, order_id: str) -> bool:
|
||||||
@@ -245,7 +332,7 @@ class ExecutionSimulator:
|
|||||||
quantity = self.positions[symbol_in_position]
|
quantity = self.positions[symbol_in_position]
|
||||||
# 持仓市值 = 数量 * 当前市场价格 (current_bar.close)
|
# 持仓市值 = 数量 * 当前市场价格 (current_bar.close)
|
||||||
# 无论多头(quantity > 0)还是空头(quantity < 0),这个计算都是正确的
|
# 无论多头(quantity > 0)还是空头(quantity < 0),这个计算都是正确的
|
||||||
total_value += quantity * current_bar.close
|
total_value += quantity * current_bar.open
|
||||||
|
|
||||||
# 您也可以选择在这里打印调试信息
|
# 您也可以选择在这里打印调试信息
|
||||||
# print(f" DEBUG Portfolio Value Calculation: Cash={self.cash:.2f}, "
|
# print(f" DEBUG Portfolio Value Calculation: Cash={self.cash:.2f}, "
|
||||||
@@ -287,3 +374,14 @@ class ExecutionSimulator:
|
|||||||
"""
|
"""
|
||||||
print("ExecutionSimulator: 清空交易历史。")
|
print("ExecutionSimulator: 清空交易历史。")
|
||||||
self.trade_history = []
|
self.trade_history = []
|
||||||
|
|
||||||
|
def get_average_position_price(self, symbol: str) -> Optional[float]:
|
||||||
|
"""
|
||||||
|
获取指定合约的平均持仓成本。
|
||||||
|
如果无持仓或无该合约记录,返回 None。
|
||||||
|
"""
|
||||||
|
# 返回 average_costs 字典中对应 symbol 的值
|
||||||
|
# 如果没有持仓或者没有记录,返回 None
|
||||||
|
if symbol in self.positions and self.positions[symbol] != 0:
|
||||||
|
return self.average_costs.get(symbol)
|
||||||
|
return None
|
||||||
|
|||||||
0
src/research/__init__.py
Normal file
0
src/research/__init__.py
Normal file
99766
src/research/grid_search.ipynb
Normal file
99766
src/research/grid_search.ipynb
Normal file
File diff suppressed because one or more lines are too long
183
src/strategies/SimpleLimitBuyStrategy.py
Normal file
183
src/strategies/SimpleLimitBuyStrategy.py
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
# src/strategies/simple_limit_buy_strategy.py
|
||||||
|
|
||||||
|
from .base_strategy import Strategy
|
||||||
|
from ..core_data import Bar, Order
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
class SimpleLimitBuyStrategy(Strategy):
|
||||||
|
"""
|
||||||
|
一个基于当前K线Open、前1根和前7根K线Range计算优势价格进行限价买入的策略。
|
||||||
|
具备以下特点:
|
||||||
|
- 每根K线开始时取消上一根K线未成交的订单。
|
||||||
|
- 最多只能有一个开仓挂单和一个持仓。
|
||||||
|
- 包含简单的止损和止盈逻辑。
|
||||||
|
"""
|
||||||
|
def __init__(self, simulator: Any, symbol: str, enable_log: bool, trade_volume: int,
|
||||||
|
open_range_factor_1_ago: float,
|
||||||
|
open_range_factor_7_ago: float,
|
||||||
|
max_position: int,
|
||||||
|
stop_loss_points: float = 10, # 新增:止损点数
|
||||||
|
take_profit_points: float = 10): # 新增:止盈点数
|
||||||
|
"""
|
||||||
|
初始化策略。
|
||||||
|
Args:
|
||||||
|
simulator: 模拟器实例。
|
||||||
|
symbol (str): 交易合约代码。
|
||||||
|
trade_volume (int): 单笔交易量。
|
||||||
|
open_range_factor_1_ago (float): 前1根K线Range的权重因子,用于从Open价向下偏移。
|
||||||
|
open_range_factor_7_ago (float): 前7根K线Range的权重因子,用于从Open价向下偏移。
|
||||||
|
max_position (int): 最大持仓量(此处为1,因为只允许一个持仓)。
|
||||||
|
stop_loss_points (float): 止损点数(例如,亏损达到此点数则止损)。
|
||||||
|
take_profit_points (float): 止盈点数(例如,盈利达到此点数则止盈)。
|
||||||
|
"""
|
||||||
|
super().__init__(simulator, symbol, enable_log)
|
||||||
|
self.trade_volume = trade_volume
|
||||||
|
self.open_range_factor_1_ago = open_range_factor_1_ago
|
||||||
|
self.open_range_factor_7_ago = open_range_factor_7_ago
|
||||||
|
self.max_position = max_position # 理论上这里应为1
|
||||||
|
self.stop_loss_points = stop_loss_points
|
||||||
|
self.take_profit_points = take_profit_points
|
||||||
|
|
||||||
|
self.order_id_counter = 0
|
||||||
|
|
||||||
|
self._bar_history: deque[Bar] = deque(maxlen=7)
|
||||||
|
self._last_order_id: Optional[str] = None # 用于跟踪上一根K线发出的订单ID
|
||||||
|
|
||||||
|
self.log(f"策略初始化: symbol={self.symbol}, trade_volume={self.trade_volume}, "
|
||||||
|
f"open_range_factor_1_ago={self.open_range_factor_1_ago}, "
|
||||||
|
f"open_range_factor_7_ago={self.open_range_factor_7_ago}, "
|
||||||
|
f"max_position={self.max_position}, "
|
||||||
|
f"止损点={self.stop_loss_points}, 止盈点={self.take_profit_points}")
|
||||||
|
|
||||||
|
def on_bar(self, bar: Bar, next_bar_open: Optional[float] = None):
|
||||||
|
"""
|
||||||
|
每当新的K线数据到来时调用。
|
||||||
|
Args:
|
||||||
|
bar (Bar): 当前的K线数据对象。
|
||||||
|
next_bar_open (Optional[float]): 下一根K线的开盘价,此处策略未使用。
|
||||||
|
"""
|
||||||
|
current_datetime = bar.datetime # 获取当前K线时间
|
||||||
|
|
||||||
|
# --- 1. 撤销上一根K线未成交的订单 ---
|
||||||
|
# 检查是否记录了上一笔订单ID,并且该订单仍然在待处理列表中
|
||||||
|
if self._last_order_id:
|
||||||
|
pending_orders = self.get_pending_orders()
|
||||||
|
if self._last_order_id in pending_orders:
|
||||||
|
success = self.cancel_order(self._last_order_id) # 直接调用基类的取消方法
|
||||||
|
if success:
|
||||||
|
self.log(f"[{current_datetime}] 策略: 成功撤销上一根K线未成交订单 {self._last_order_id}")
|
||||||
|
else:
|
||||||
|
self.log(f"[{current_datetime}] 策略: 尝试撤销订单 {self._last_order_id} 失败(可能已成交或不存在)")
|
||||||
|
# 无论撤销成功与否,既然我们尝试了撤销,就清除记录
|
||||||
|
self._last_order_id = None
|
||||||
|
# else:
|
||||||
|
# self.log(f"[{current_datetime}] 策略: 无上一根K线未成交订单需要撤销。")
|
||||||
|
|
||||||
|
|
||||||
|
# 2. 更新K线历史
|
||||||
|
self._bar_history.append(bar)
|
||||||
|
trade_volume = self.trade_volume
|
||||||
|
|
||||||
|
# 获取当前持仓和未决订单(在取消之后获取,确保是最新的状态)
|
||||||
|
current_positions = self.get_current_positions()
|
||||||
|
current_pos_volume = current_positions.get(self.symbol, 0)
|
||||||
|
pending_orders_after_cancel = self.get_pending_orders() # 再次获取,此时应已取消旧订单
|
||||||
|
|
||||||
|
# --- 3. 平仓逻辑 (止损/止盈) ---
|
||||||
|
# 只有当有持仓时才考虑平仓
|
||||||
|
if current_pos_volume > 0: # 假设只做多,所以持仓量 > 0
|
||||||
|
|
||||||
|
avg_entry_price = self.get_average_position_price(self.symbol)
|
||||||
|
if avg_entry_price is not None:
|
||||||
|
pnl_per_unit = bar.close - avg_entry_price # 当前浮动盈亏(以收盘价计算)
|
||||||
|
|
||||||
|
# 止盈条件
|
||||||
|
if pnl_per_unit >= self.take_profit_points:
|
||||||
|
self.log(f"[{current_datetime}] 止盈信号 - PnL per unit: {pnl_per_unit:.2f}, 目标: {self.take_profit_points:.2f}")
|
||||||
|
|
||||||
|
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
|
||||||
|
self.order_id_counter += 1
|
||||||
|
|
||||||
|
# 创建一个限价多单
|
||||||
|
order = Order(
|
||||||
|
id=order_id,
|
||||||
|
symbol=self.symbol,
|
||||||
|
direction="CLOSE_LONG",
|
||||||
|
volume=trade_volume,
|
||||||
|
price_type="MARKET",
|
||||||
|
# limit_price=limit_price,
|
||||||
|
submitted_time=bar.datetime
|
||||||
|
)
|
||||||
|
trade = self.send_order(order)
|
||||||
|
return # 平仓后本K线不再进行开仓判断
|
||||||
|
|
||||||
|
# 止损条件
|
||||||
|
elif pnl_per_unit <= -self.stop_loss_points:
|
||||||
|
self.log(f"[{current_datetime}] 止损信号 - PnL per unit: {pnl_per_unit:.2f}, 目标: {-self.stop_loss_points:.2f}")
|
||||||
|
# 发送市价卖出订单平仓,确保立即成交
|
||||||
|
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
|
||||||
|
self.order_id_counter += 1
|
||||||
|
|
||||||
|
# 创建一个限价多单
|
||||||
|
order = Order(
|
||||||
|
id=order_id,
|
||||||
|
symbol=self.symbol,
|
||||||
|
direction="CLOSE_LONG",
|
||||||
|
volume=trade_volume,
|
||||||
|
price_type="MARKET",
|
||||||
|
# limit_price=limit_price,
|
||||||
|
submitted_time=bar.datetime
|
||||||
|
)
|
||||||
|
trade = self.send_order(order)
|
||||||
|
return # 平仓后本K线不再进行开仓判断
|
||||||
|
|
||||||
|
# --- 4. 开仓逻辑 (只考虑做多 BUY 方向) ---
|
||||||
|
# 只有在没有持仓 (current_pos_volume == 0) 且没有待处理订单 (not pending_orders_after_cancel)
|
||||||
|
# 且K线历史足够长时才考虑开仓
|
||||||
|
if current_pos_volume == 0 and \
|
||||||
|
len(self._bar_history) == self._bar_history.maxlen:
|
||||||
|
|
||||||
|
# 获取前1根K线 (倒数第二根) 和前7根K线 (队列中最老的一根)
|
||||||
|
bar_1_ago = self._bar_history[-2]
|
||||||
|
bar_7_ago = self._bar_history[0]
|
||||||
|
|
||||||
|
# 计算历史 K 线的 Range
|
||||||
|
range_1_ago = bar_1_ago.high - bar_1_ago.low
|
||||||
|
range_7_ago = bar_7_ago.high - bar_7_ago.low
|
||||||
|
|
||||||
|
# 根据策略逻辑计算目标买入价格
|
||||||
|
# 目标买入价 = 当前K线Open - (前1根Range * 因子1 + 前7根Range * 因子2)
|
||||||
|
self.log(bar.open ,range_1_ago * self.open_range_factor_1_ago, range_7_ago * self.open_range_factor_7_ago)
|
||||||
|
target_buy_price = bar.open - (range_1_ago * self.open_range_factor_1_ago + range_7_ago * self.open_range_factor_7_ago)
|
||||||
|
|
||||||
|
# 确保目标买入价格有效,例如不能是负数
|
||||||
|
target_buy_price = max(0.01, target_buy_price)
|
||||||
|
|
||||||
|
self.log(f"[{current_datetime}] 开多仓信号 - 当前Open={bar.open:.2f}, "
|
||||||
|
f"前1Range={range_1_ago:.2f}, 前7Range={range_7_ago:.2f}, "
|
||||||
|
f"计算目标买入价={target_buy_price:.2f}")
|
||||||
|
|
||||||
|
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
|
||||||
|
self.order_id_counter += 1
|
||||||
|
|
||||||
|
# 创建一个限价多单
|
||||||
|
order = Order(
|
||||||
|
id=order_id,
|
||||||
|
symbol=self.symbol,
|
||||||
|
direction="BUY",
|
||||||
|
volume=trade_volume,
|
||||||
|
price_type="LIMIT",
|
||||||
|
limit_price=target_buy_price,
|
||||||
|
submitted_time=bar.datetime
|
||||||
|
)
|
||||||
|
new_order = self.send_order(order)
|
||||||
|
# 记录下这个订单的ID,以便在下一根K线开始时进行撤销
|
||||||
|
if new_order:
|
||||||
|
self._last_order_id = new_order.order_id
|
||||||
|
self.log(f"[{current_datetime}] 策略: 发送限价买入订单 {self._last_order_id} @ {target_buy_price:.2f}")
|
||||||
|
else:
|
||||||
|
self.log(f"[{current_datetime}] 策略: 发送订单失败。")
|
||||||
|
|
||||||
|
# else:
|
||||||
|
# self.log(f"[{current_datetime}] 不满足开仓条件:持仓={current_pos_volume}, 待处理订单={len(pending_orders_after_cancel)}, K线历史长度={len(self._bar_history)}")
|
||||||
@@ -1,42 +1,31 @@
|
|||||||
# src/strategies/base_strategy.py
|
# src/strategies/base_strategy.py
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Dict, Any, Optional
|
from datetime import datetime
|
||||||
|
from typing import Dict, Any, Optional, List, TYPE_CHECKING
|
||||||
|
|
||||||
|
# 使用 TYPE_CHECKING 避免循环导入,但保留类型提示
|
||||||
|
from ..backtest_context import BacktestContext # 转发引用 BacktestEngine
|
||||||
|
from ..core_data import Bar, Order, Trade # 导入必要的类型
|
||||||
|
|
||||||
# 导入核心数据类
|
|
||||||
from ..core_data import Bar, Order, Trade
|
|
||||||
# 导入回测上下文 (注意相对导入路径的变化)
|
|
||||||
from ..backtest_context import BacktestContext
|
|
||||||
|
|
||||||
class Strategy(ABC):
|
class Strategy(ABC):
|
||||||
"""
|
"""
|
||||||
策略抽象基类。所有具体策略都应继承此类,并实现 on_bar 方法。
|
所有交易策略的抽象基类。
|
||||||
|
策略通过 context 对象与回测引擎和模拟器进行交互,并提供辅助方法。
|
||||||
"""
|
"""
|
||||||
def __init__(self, context: BacktestContext, **parameters: Any):
|
|
||||||
"""
|
|
||||||
初始化策略。
|
|
||||||
|
|
||||||
|
def __init__(self, context: 'BacktestContext', symbol: str, enable_log: bool = True, **params: Any):
|
||||||
|
"""
|
||||||
Args:
|
Args:
|
||||||
context (BacktestContext): 回测上下文对象,用于与模拟器和数据管理器交互。
|
context (BacktestEngine): 回测引擎实例,作为策略的上下文,提供与模拟器等的交互接口。
|
||||||
**parameters (Any): 策略所需的任何自定义参数。
|
symbol (str): 策略操作的合约Symbol。
|
||||||
|
**params (Any): 其他策略特定参数。
|
||||||
"""
|
"""
|
||||||
self.context = context
|
self.context = context # 存储 context 对象
|
||||||
self.parameters = parameters
|
self.symbol = symbol # 策略操作的合约Symbol
|
||||||
self.symbol = parameters.get('symbol', "DEFAULT_SYMBOL") # 策略操作的品种
|
self.params = params
|
||||||
self.trade_volume = parameters.get('trade_volume', 100) # 每次下单的数量
|
self.enable_log = enable_log
|
||||||
print(f"策略初始化: {self.__class__.__name__},参数: {self.parameters}")
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def on_bar(self, bar: Bar):
|
|
||||||
"""
|
|
||||||
每当接收到新的Bar数据时调用。
|
|
||||||
具体策略逻辑在此方法中实现。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
bar (Bar): 当前的Bar数据对象。
|
|
||||||
features (Dict[str, float]): 由数据处理模块计算并传入的特征字典。
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_init(self):
|
def on_init(self):
|
||||||
"""
|
"""
|
||||||
@@ -45,7 +34,7 @@ class Strategy(ABC):
|
|||||||
"""
|
"""
|
||||||
print(f"{self.__class__.__name__} 策略初始化回调被调用。")
|
print(f"{self.__class__.__name__} 策略初始化回调被调用。")
|
||||||
|
|
||||||
def on_trade(self, trade: Trade):
|
def on_trade(self, trade: 'Trade'):
|
||||||
"""
|
"""
|
||||||
当模拟器成功执行一笔交易时调用。
|
当模拟器成功执行一笔交易时调用。
|
||||||
可用于更新策略内部持仓状态或记录交易。
|
可用于更新策略内部持仓状态或记录交易。
|
||||||
@@ -54,15 +43,94 @@ class Strategy(ABC):
|
|||||||
trade (Trade): 已完成的交易记录。
|
trade (Trade): 已完成的交易记录。
|
||||||
"""
|
"""
|
||||||
# print(f"策略接收到交易: {trade.direction} {trade.volume} {trade.symbol} @ {trade.price:.2f}")
|
# print(f"策略接收到交易: {trade.direction} {trade.volume} {trade.symbol} @ {trade.price:.2f}")
|
||||||
pass # 默认不执行任何操作,具体策略可覆盖
|
pass # 默认不执行任何操作,具体策略可覆盖
|
||||||
|
|
||||||
def on_order_status(self, order: Order, status: str):
|
@abstractmethod
|
||||||
|
def on_bar(self, bar: 'Bar'):
|
||||||
"""
|
"""
|
||||||
当订单状态更新时调用 (例如,未成交,已提交等)。
|
每当新的K线数据到来时调用此方法。
|
||||||
在简易回测中,可能不会频繁使用。
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
order (Order): 相关订单对象。
|
bar (Bar): 当前的K线数据对象。
|
||||||
status (str): 订单状态(例如 "FILLED", "PENDING", "CANCELLED")。
|
next_bar_open (Optional[float]): 下一根K线的开盘价,如果存在的话。
|
||||||
"""
|
"""
|
||||||
pass # 默认不执行任何操作
|
pass
|
||||||
|
|
||||||
|
# --- 新增/修改的辅助方法 ---
|
||||||
|
|
||||||
|
def send_order(self, order: 'Order') -> Optional[Trade]:
|
||||||
|
"""
|
||||||
|
发送订单的辅助方法。
|
||||||
|
会在 BaseStrategy 内部构建 Order 对象,并通过 context 转发给模拟器。
|
||||||
|
"""
|
||||||
|
return self.context.send_order(order)
|
||||||
|
|
||||||
|
def cancel_order(self, order_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
取消指定ID的订单。
|
||||||
|
通过 context 调用模拟器的 cancel_order 方法。
|
||||||
|
"""
|
||||||
|
|
||||||
|
return self.context.cancel_order(order_id)
|
||||||
|
|
||||||
|
def cancel_all_pending_orders(self) -> int:
|
||||||
|
"""
|
||||||
|
取消所有当前策略的未决订单。
|
||||||
|
返回成功取消的订单数量。
|
||||||
|
"""
|
||||||
|
pending_orders = self.get_pending_orders() # 调用 BaseStrategy 自己的 get_pending_orders
|
||||||
|
cancelled_count = 0
|
||||||
|
orders_to_cancel = [order.id for order in pending_orders.values() if order.symbol == self.symbol]
|
||||||
|
for order_id in orders_to_cancel:
|
||||||
|
if self.cancel_order(order_id): # 调用 BaseStrategy 自己的 cancel_order
|
||||||
|
cancelled_count += 1
|
||||||
|
return cancelled_count
|
||||||
|
|
||||||
|
def get_current_positions(self) -> Dict[str, int]:
|
||||||
|
"""
|
||||||
|
获取当前持仓。
|
||||||
|
通过 context 调用模拟器的 get_positions 方法。
|
||||||
|
"""
|
||||||
|
return self.context._simulator.get_current_positions()
|
||||||
|
|
||||||
|
def get_pending_orders(self) -> Dict[str, 'Order']:
|
||||||
|
"""
|
||||||
|
获取当前所有待处理订单的副本。
|
||||||
|
通过 context 调用模拟器的 get_pending_orders 方法。
|
||||||
|
"""
|
||||||
|
return self.context._simulator.get_pending_orders()
|
||||||
|
|
||||||
|
def get_average_position_price(self, symbol: str) -> Optional[float]:
|
||||||
|
"""
|
||||||
|
获取指定合约的平均持仓成本。
|
||||||
|
通过 context 调用模拟器的 get_average_position_price 方法。
|
||||||
|
"""
|
||||||
|
return self.context._simulator.get_average_position_price(symbol)
|
||||||
|
|
||||||
|
# 你可以根据需要在这里添加更多辅助方法,如获取账户净值等
|
||||||
|
def get_account_cash(self) -> float:
|
||||||
|
"""获取当前账户现金余额。"""
|
||||||
|
return self.context._simulator.cash
|
||||||
|
|
||||||
|
|
||||||
|
def log(self, *args: Any, **kwargs: Any):
|
||||||
|
"""
|
||||||
|
统一的日志打印方法。
|
||||||
|
如果 enable_log 为 True,则打印消息到控制台,并包含当前模拟时间。
|
||||||
|
支持传入多个参数,像 print() 函数一样使用。
|
||||||
|
"""
|
||||||
|
if self.enable_log:
|
||||||
|
# 尝试获取当前模拟时间,如果模拟器或时间不可用,则跳过时间前缀
|
||||||
|
try:
|
||||||
|
current_time_str = self.context._simulator.get_current_time().strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
time_prefix = f"[{current_time_str}] "
|
||||||
|
except AttributeError:
|
||||||
|
# 如果获取不到时间(例如在策略初始化时,模拟器时间还未设置),则不加时间前缀
|
||||||
|
time_prefix = ""
|
||||||
|
|
||||||
|
# 使用 f-string 结合 *args 来构建消息
|
||||||
|
# print() 函数会将 *args 自动用空格分隔,这里我们模仿这个行为
|
||||||
|
message = ' '.join(map(str, args))
|
||||||
|
|
||||||
|
# 你可以将其他 kwargs (如 sep, end, file, flush) 传递给 print,
|
||||||
|
# 但通常日志方法不会频繁使用这些。这里只支持最基础的打印。
|
||||||
|
print(f"{time_prefix}策略 ({self.symbol}): {message}", **kwargs)
|
||||||
@@ -19,9 +19,10 @@ class SimpleLimitBuyStrategy(Strategy):
|
|||||||
|
|
||||||
def __init__(self, context: Any, **parameters: Any): # context 类型提示可以为 BacktestContext
|
def __init__(self, context: Any, **parameters: Any): # context 类型提示可以为 BacktestContext
|
||||||
super().__init__(context, **parameters)
|
super().__init__(context, **parameters)
|
||||||
|
self.trade_volume = 1
|
||||||
self.order_id_counter = 0
|
self.order_id_counter = 0
|
||||||
self.limit_price_factor = parameters.get('limit_price_factor', 0.999) # 限价因子
|
self.limit_price_factor = parameters.get('limit_price_factor', 0.999) # 限价因子
|
||||||
self.max_position = parameters.get('max_position', self.trade_volume * 2) # 最大持仓量
|
self.max_position = parameters.get('max_position', self.trade_volume) # 最大持仓量
|
||||||
|
|
||||||
self._last_order_id: Optional[str] = None # 跟踪上一根Bar发出的订单ID
|
self._last_order_id: Optional[str] = None # 跟踪上一根Bar发出的订单ID
|
||||||
self._current_long_position: int = 0 # 策略内部维护的当前持仓
|
self._current_long_position: int = 0 # 策略内部维护的当前持仓
|
||||||
@@ -57,16 +58,14 @@ class SimpleLimitBuyStrategy(Strategy):
|
|||||||
每接收到一根Bar时,执行策略逻辑。
|
每接收到一根Bar时,执行策略逻辑。
|
||||||
"""
|
"""
|
||||||
current_portfolio_value = self.context.get_current_portfolio_value(bar)
|
current_portfolio_value = self.context.get_current_portfolio_value(bar)
|
||||||
print(f"[{bar.datetime}] Strategy processing Bar. Current close price: {bar.close:.2f}. Current Portfolio Value: {current_portfolio_value:.2f}")
|
# print(f"[{bar.datetime}] Strategy processing Bar. Current close price: {bar.close:.2f}. Current Portfolio Value: {current_portfolio_value:.2f}")
|
||||||
|
|
||||||
# 1. 撤销上一根K线未成交的订单
|
# 1. 撤销上一根K线未成交的订单
|
||||||
if self._last_order_id:
|
if self._last_order_id:
|
||||||
# 检查这个订单是否仍然在待处理订单列表中
|
# 检查这个订单是否仍然在待处理订单列表中
|
||||||
pending_orders = self.context._simulator.get_pending_orders() # 直接访问模拟器,或者通过context提供接口
|
pending_orders = self.get_pending_orders() # 直接访问模拟器,或者通过context提供接口
|
||||||
if self._last_order_id in pending_orders:
|
if self._last_order_id in pending_orders:
|
||||||
success = self.context.send_order(Order(id=self._last_order_id, symbol=self.symbol,
|
success = self.cancel_order(self._last_order_id)
|
||||||
direction="CANCEL", volume=0,
|
|
||||||
price_type="CANCEL")) # 使用一个特殊Order类型表示撤单
|
|
||||||
# 这里发送的“撤单订单”会被simulator的send_order处理,并调用simulator.cancel_order
|
# 这里发送的“撤单订单”会被simulator的send_order处理,并调用simulator.cancel_order
|
||||||
if success: # simulator.send_order返回Trade或None,这里我们用一个特殊处理
|
if success: # simulator.send_order返回Trade或None,这里我们用一个特殊处理
|
||||||
# Simulator的send_order返回的是Trade,如果实现撤单,最好Simulator的cancel_order返回bool
|
# Simulator的send_order返回的是Trade,如果实现撤单,最好Simulator的cancel_order返回bool
|
||||||
@@ -109,10 +108,10 @@ class SimpleLimitBuyStrategy(Strategy):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 通过上下文发送订单
|
# 通过上下文发送订单
|
||||||
trade = self.context.send_order(order)
|
trade = self.send_order(order)
|
||||||
if trade:
|
if trade:
|
||||||
print(
|
print(
|
||||||
f"[{bar.datetime}] 策略: 发送并立即成交限价买单 {trade.volume} 股 @ {trade.price:.2f} (订单ID: {order.id})")
|
f"[{bar.datetime}] 策略: 发送并立即成交限价买单 {trade.volume} 股 @ {trade.price:.2f}(open:{bar.open}, close:{bar.close}) (订单ID: {order.id})")
|
||||||
# 如果立即成交,_last_order_id 仍然保持 None
|
# 如果立即成交,_last_order_id 仍然保持 None
|
||||||
else:
|
else:
|
||||||
# 如果未立即成交,将订单ID记录下来,以便下一根Bar撤销
|
# 如果未立即成交,将订单ID记录下来,以便下一根Bar撤销
|
||||||
|
|||||||
Reference in New Issue
Block a user