简单波动率策略,实现+网格搜索

This commit is contained in:
2025-06-22 23:03:50 +08:00
parent 355e451aac
commit a81a32ce73
19 changed files with 115435 additions and 748 deletions

View File

@@ -9,7 +9,7 @@ from ..core_data import PortfolioSnapshot, Trade, Bar
def calculate_metrics(
snapshots: List[PortfolioSnapshot], trades: List[Trade], initial_capital: float
snapshots: List[PortfolioSnapshot], trades: List[Trade], initial_capital: float
) -> Dict[str, Any]:
"""
纯函数:根据投资组合快照和交易历史计算关键绩效指标。
@@ -124,11 +124,27 @@ def calculate_metrics(
"亏损交易次数": losing_count,
"平均每次盈利": avg_profit_per_trade,
"平均每次亏损": avg_loss_per_trade, # 这个值是负数
"InitialCapital": initial_capital,
"FinalCapital": final_value,
"TotalReturn": total_return,
"AnnualizedReturn": annualized_return,
"MaxDrawdown": max_drawdown,
"SharpeRatio": sharpe_ratio,
"CalmarRatio": calmar_ratio,
"TotalTrades": len(trades), # All buy and sell trades
"TransactionCosts": total_commissions,
"TotalRealizedPNL": total_realized_pnl, # New
"WinRate": win_rate,
"ProfitLossRatio": profit_loss_ratio,
"WinningTradesCount": winning_count,
"LosingTradesCount": losing_count,
"AvgProfitPerTrade": avg_profit_per_trade,
"AvgLossPerTrade": avg_loss_per_trade, # This value is negative
}
def plot_equity_and_drawdown_chart(snapshots: List[PortfolioSnapshot], initial_capital: float,
title: str = "Portfolio Equity and Drawdown Curve") -> None:
title: str = "Portfolio Equity and Drawdown Curve") -> None:
"""
Plots the portfolio equity curve and drawdown. X-axis points are equally spaced.
@@ -145,7 +161,7 @@ def plot_equity_and_drawdown_chart(snapshots: List[PortfolioSnapshot], initial_c
{'datetime': s.datetime, 'total_value': s.total_value}
for s in snapshots
])
equity_curve = df_equity['total_value'] / initial_capital
rolling_max = equity_curve.cummax()
@@ -203,7 +219,7 @@ def plot_close_price_chart(bars: List[Bar], title: str = "Close Price Chart") ->
])
plt.style.use('seaborn-v0_8-darkgrid')
fig, ax = plt.subplots(1, 1, figsize=(14, 7)) # Single subplot
fig, ax = plt.subplots(1, 1, figsize=(14, 7)) # Single subplot
x_axis_indices = np.arange(len(df_prices))
@@ -228,7 +244,7 @@ def plot_close_price_chart(bars: List[Bar], title: str = "Close Price Chart") ->
# 辅助函数:计算单笔交易的盈亏
def calculate_trade_pnl(
trade: Trade, entry_price: float, exit_price: float, direction: str
trade: Trade, entry_price: float, exit_price: float, direction: str
) -> float:
if direction == "LONG":
pnl = (exit_price - entry_price) * trade.volume

View File

@@ -0,0 +1,90 @@
# src/grid_search_analyzer.py
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Dict, Any, Tuple
class GridSearchAnalyzer:
"""
用于分析和可视化网格搜索结果的类。
"""
def __init__(self, grid_results: List[Dict[str, Any]], optimization_metric: str):
"""
初始化 GridSearchAnalyzer。
Args:
grid_results (List[Dict[str, Any]]): 包含每个参数组合及其对应优化指标的列表。
例如:[{'param1': v1, 'param2': v2, 'metric': m1}, ...]
optimization_metric (str): 用于优化的指标名称(例如 'total_return')。
"""
if not grid_results:
raise ValueError("grid_results 列表不能为空。")
if optimization_metric not in grid_results[0]:
raise ValueError(f"优化指标 '{optimization_metric}' 不在 grid_results 的字典中。")
self.grid_results = grid_results
self.optimization_metric = optimization_metric
self.param_names = [k for k in grid_results[0].keys() if k != optimization_metric]
if len(self.param_names) != 2:
raise ValueError("GridSearchAnalyzer 当前只支持分析两个参数的网格搜索结果。")
self.param1_name = self.param_names[0]
self.param2_name = self.param_names[1]
def find_best_parameters(self) -> Dict[str, Any]:
"""
找到网格搜索中表现最佳的参数组合和其指标值。
"""
if not self.grid_results:
return {}
best_result = max(self.grid_results, key=lambda x: x[self.optimization_metric])
print(f"\n--- 最佳参数组合 ---")
print(f" {self.param1_name}: {best_result[self.param1_name]}")
print(f" {self.param2_name}: {best_result[self.param2_name]}")
print(f" {self.optimization_metric}: {best_result[self.optimization_metric]:.4f}")
return best_result
def plot_heatmap(self, title: str = "heatmap"):
"""
绘制两个参数的热力图。
"""
if not self.grid_results:
print("没有数据用于绘制热力图。")
return
x_values = sorted(list(set(d[self.param1_name] for d in self.grid_results)))
y_values = sorted(list(set(d[self.param2_name] for d in self.grid_results)))
heatmap_matrix = np.zeros((len(y_values), len(x_values)))
# 填充网格
for item in self.grid_results:
x_idx = x_values.index(item[self.param1_name])
y_idx = y_values.index(item[self.param2_name])
heatmap_matrix[y_idx, x_idx] = item[self.optimization_metric]
fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(heatmap_matrix, cmap='viridis', origin='lower',
extent=[x_values[0], x_values[-1], y_values[0], y_values[-1]],
aspect='auto')
ax.set_xticks(x_values)
ax.set_yticks(y_values)
ax.set_xlabel(self.param1_name)
ax.set_ylabel(self.param2_name)
ax.set_title(f"{title}: {self.optimization_metric} vs. {self.param1_name} & {self.param2_name}")
cbar = fig.colorbar(im, ax=ax)
cbar.set_label(self.optimization_metric)
# 在每个格子上显示数值
for i in range(len(y_values)):
for j in range(len(x_values)):
text = ax.text(x_values[j], y_values[i], f"{heatmap_matrix[i, j]:.2f}",
ha="center", va="center", color="w", fontsize=8)
plt.tight_layout()
plt.show()

View File

@@ -1,10 +1,16 @@
# src/analysis/result_analyzer.py
import pandas as pd
import numpy as np
from typing import List, Dict, Any, Optional
# 导入纯函数 (注意相对导入路径的变化)
from .analysis_utils import calculate_metrics, plot_equity_and_drawdown_chart, plot_close_price_chart
from .analysis_utils import (
calculate_metrics,
plot_equity_and_drawdown_chart,
plot_close_price_chart,
)
# 导入核心数据类 (注意相对导入路径的变化)
from ..core_data import PortfolioSnapshot, Trade, Bar
@@ -14,11 +20,13 @@ class ResultAnalyzer:
结果分析器:负责接收回测数据,并提供分析和可视化方法。
"""
def __init__(self,
portfolio_snapshots: List[PortfolioSnapshot],
trade_history: List[Trade],
bars: List[Bar],
initial_capital: float):
def __init__(
self,
portfolio_snapshots: List[PortfolioSnapshot],
trade_history: List[Trade],
bars: List[Bar],
initial_capital: float,
):
"""
Args:
portfolio_snapshots (List[PortfolioSnapshot]): 回测引擎输出的投资组合快照列表。
@@ -41,9 +49,7 @@ class ResultAnalyzer:
if self._metrics_cache is None:
print("正在计算绩效指标...")
self._metrics_cache = calculate_metrics(
self.portfolio_snapshots,
self.trade_history,
self.initial_capital
self.portfolio_snapshots, self.trade_history, self.initial_capital
)
print("绩效指标计算完成。")
return self._metrics_cache
@@ -69,7 +75,8 @@ class ResultAnalyzer:
print("\n--- 部分交易明细 (最近5笔) ---")
for trade in self.trade_history[-5:]:
print(
f" {trade.fill_time} | {trade.direction:<10} | {trade.symbol} | Vol: {trade.volume} | Price: {trade.price:.2f} | Commission: {trade.commission:.2f}")
f" {trade.fill_time} | {trade.direction:<10} | {trade.symbol} | Vol: {trade.volume} | Price: {trade.price:.2f} | Commission: {trade.commission:.2f}"
)
else:
print("\n没有交易记录。")
@@ -79,10 +86,13 @@ class ResultAnalyzer:
"""
print("正在绘制绩效图表...")
# plot_performance_chart(self.portfolio_snapshots, self.initial_capital, self.bars)
plot_equity_and_drawdown_chart(self.portfolio_snapshots, self.initial_capital,
title="Portfolio Equity and Drawdown Curve")
plot_equity_and_drawdown_chart(
self.portfolio_snapshots,
self.initial_capital,
title="Portfolio Equity and Drawdown Curve",
)
# 绘制单独的收盘价曲线
plot_close_price_chart(self.bars, title="Underlying Asset Close Price")
print("图表绘制完成。")
print("图表绘制完成。")