简单波动率策略,实现+网格搜索
This commit is contained in:
@@ -9,7 +9,7 @@ from ..core_data import PortfolioSnapshot, Trade, Bar
|
||||
|
||||
|
||||
def calculate_metrics(
|
||||
snapshots: List[PortfolioSnapshot], trades: List[Trade], initial_capital: float
|
||||
snapshots: List[PortfolioSnapshot], trades: List[Trade], initial_capital: float
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
纯函数:根据投资组合快照和交易历史计算关键绩效指标。
|
||||
@@ -124,11 +124,27 @@ def calculate_metrics(
|
||||
"亏损交易次数": losing_count,
|
||||
"平均每次盈利": avg_profit_per_trade,
|
||||
"平均每次亏损": avg_loss_per_trade, # 这个值是负数
|
||||
"InitialCapital": initial_capital,
|
||||
"FinalCapital": final_value,
|
||||
"TotalReturn": total_return,
|
||||
"AnnualizedReturn": annualized_return,
|
||||
"MaxDrawdown": max_drawdown,
|
||||
"SharpeRatio": sharpe_ratio,
|
||||
"CalmarRatio": calmar_ratio,
|
||||
"TotalTrades": len(trades), # All buy and sell trades
|
||||
"TransactionCosts": total_commissions,
|
||||
"TotalRealizedPNL": total_realized_pnl, # New
|
||||
"WinRate": win_rate,
|
||||
"ProfitLossRatio": profit_loss_ratio,
|
||||
"WinningTradesCount": winning_count,
|
||||
"LosingTradesCount": losing_count,
|
||||
"AvgProfitPerTrade": avg_profit_per_trade,
|
||||
"AvgLossPerTrade": avg_loss_per_trade, # This value is negative
|
||||
}
|
||||
|
||||
|
||||
def plot_equity_and_drawdown_chart(snapshots: List[PortfolioSnapshot], initial_capital: float,
|
||||
title: str = "Portfolio Equity and Drawdown Curve") -> None:
|
||||
title: str = "Portfolio Equity and Drawdown Curve") -> None:
|
||||
"""
|
||||
Plots the portfolio equity curve and drawdown. X-axis points are equally spaced.
|
||||
|
||||
@@ -145,7 +161,7 @@ def plot_equity_and_drawdown_chart(snapshots: List[PortfolioSnapshot], initial_c
|
||||
{'datetime': s.datetime, 'total_value': s.total_value}
|
||||
for s in snapshots
|
||||
])
|
||||
|
||||
|
||||
equity_curve = df_equity['total_value'] / initial_capital
|
||||
|
||||
rolling_max = equity_curve.cummax()
|
||||
@@ -203,7 +219,7 @@ def plot_close_price_chart(bars: List[Bar], title: str = "Close Price Chart") ->
|
||||
])
|
||||
|
||||
plt.style.use('seaborn-v0_8-darkgrid')
|
||||
fig, ax = plt.subplots(1, 1, figsize=(14, 7)) # Single subplot
|
||||
fig, ax = plt.subplots(1, 1, figsize=(14, 7)) # Single subplot
|
||||
|
||||
x_axis_indices = np.arange(len(df_prices))
|
||||
|
||||
@@ -228,7 +244,7 @@ def plot_close_price_chart(bars: List[Bar], title: str = "Close Price Chart") ->
|
||||
|
||||
# 辅助函数:计算单笔交易的盈亏
|
||||
def calculate_trade_pnl(
|
||||
trade: Trade, entry_price: float, exit_price: float, direction: str
|
||||
trade: Trade, entry_price: float, exit_price: float, direction: str
|
||||
) -> float:
|
||||
if direction == "LONG":
|
||||
pnl = (exit_price - entry_price) * trade.volume
|
||||
|
||||
90
src/analysis/grid_search_analyzer.py
Normal file
90
src/analysis/grid_search_analyzer.py
Normal file
@@ -0,0 +1,90 @@
|
||||
# src/grid_search_analyzer.py
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from typing import List, Dict, Any, Tuple
|
||||
|
||||
class GridSearchAnalyzer:
|
||||
"""
|
||||
用于分析和可视化网格搜索结果的类。
|
||||
"""
|
||||
def __init__(self, grid_results: List[Dict[str, Any]], optimization_metric: str):
|
||||
"""
|
||||
初始化 GridSearchAnalyzer。
|
||||
|
||||
Args:
|
||||
grid_results (List[Dict[str, Any]]): 包含每个参数组合及其对应优化指标的列表。
|
||||
例如:[{'param1': v1, 'param2': v2, 'metric': m1}, ...]
|
||||
optimization_metric (str): 用于优化的指标名称(例如 'total_return')。
|
||||
"""
|
||||
if not grid_results:
|
||||
raise ValueError("grid_results 列表不能为空。")
|
||||
if optimization_metric not in grid_results[0]:
|
||||
raise ValueError(f"优化指标 '{optimization_metric}' 不在 grid_results 的字典中。")
|
||||
|
||||
self.grid_results = grid_results
|
||||
self.optimization_metric = optimization_metric
|
||||
self.param_names = [k for k in grid_results[0].keys() if k != optimization_metric]
|
||||
|
||||
if len(self.param_names) != 2:
|
||||
raise ValueError("GridSearchAnalyzer 当前只支持分析两个参数的网格搜索结果。")
|
||||
|
||||
self.param1_name = self.param_names[0]
|
||||
self.param2_name = self.param_names[1]
|
||||
|
||||
def find_best_parameters(self) -> Dict[str, Any]:
|
||||
"""
|
||||
找到网格搜索中表现最佳的参数组合和其指标值。
|
||||
"""
|
||||
if not self.grid_results:
|
||||
return {}
|
||||
best_result = max(self.grid_results, key=lambda x: x[self.optimization_metric])
|
||||
print(f"\n--- 最佳参数组合 ---")
|
||||
print(f" {self.param1_name}: {best_result[self.param1_name]}")
|
||||
print(f" {self.param2_name}: {best_result[self.param2_name]}")
|
||||
print(f" {self.optimization_metric}: {best_result[self.optimization_metric]:.4f}")
|
||||
return best_result
|
||||
|
||||
def plot_heatmap(self, title: str = "heatmap"):
|
||||
"""
|
||||
绘制两个参数的热力图。
|
||||
"""
|
||||
if not self.grid_results:
|
||||
print("没有数据用于绘制热力图。")
|
||||
return
|
||||
|
||||
x_values = sorted(list(set(d[self.param1_name] for d in self.grid_results)))
|
||||
y_values = sorted(list(set(d[self.param2_name] for d in self.grid_results)))
|
||||
|
||||
heatmap_matrix = np.zeros((len(y_values), len(x_values)))
|
||||
|
||||
# 填充网格
|
||||
for item in self.grid_results:
|
||||
x_idx = x_values.index(item[self.param1_name])
|
||||
y_idx = y_values.index(item[self.param2_name])
|
||||
heatmap_matrix[y_idx, x_idx] = item[self.optimization_metric]
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 8))
|
||||
im = ax.imshow(heatmap_matrix, cmap='viridis', origin='lower',
|
||||
extent=[x_values[0], x_values[-1], y_values[0], y_values[-1]],
|
||||
aspect='auto')
|
||||
|
||||
ax.set_xticks(x_values)
|
||||
ax.set_yticks(y_values)
|
||||
|
||||
ax.set_xlabel(self.param1_name)
|
||||
ax.set_ylabel(self.param2_name)
|
||||
ax.set_title(f"{title}: {self.optimization_metric} vs. {self.param1_name} & {self.param2_name}")
|
||||
|
||||
cbar = fig.colorbar(im, ax=ax)
|
||||
cbar.set_label(self.optimization_metric)
|
||||
|
||||
# 在每个格子上显示数值
|
||||
for i in range(len(y_values)):
|
||||
for j in range(len(x_values)):
|
||||
text = ax.text(x_values[j], y_values[i], f"{heatmap_matrix[i, j]:.2f}",
|
||||
ha="center", va="center", color="w", fontsize=8)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
@@ -1,10 +1,16 @@
|
||||
# src/analysis/result_analyzer.py
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
# 导入纯函数 (注意相对导入路径的变化)
|
||||
from .analysis_utils import calculate_metrics, plot_equity_and_drawdown_chart, plot_close_price_chart
|
||||
from .analysis_utils import (
|
||||
calculate_metrics,
|
||||
plot_equity_and_drawdown_chart,
|
||||
plot_close_price_chart,
|
||||
)
|
||||
|
||||
# 导入核心数据类 (注意相对导入路径的变化)
|
||||
from ..core_data import PortfolioSnapshot, Trade, Bar
|
||||
|
||||
@@ -14,11 +20,13 @@ class ResultAnalyzer:
|
||||
结果分析器:负责接收回测数据,并提供分析和可视化方法。
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
portfolio_snapshots: List[PortfolioSnapshot],
|
||||
trade_history: List[Trade],
|
||||
bars: List[Bar],
|
||||
initial_capital: float):
|
||||
def __init__(
|
||||
self,
|
||||
portfolio_snapshots: List[PortfolioSnapshot],
|
||||
trade_history: List[Trade],
|
||||
bars: List[Bar],
|
||||
initial_capital: float,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
portfolio_snapshots (List[PortfolioSnapshot]): 回测引擎输出的投资组合快照列表。
|
||||
@@ -41,9 +49,7 @@ class ResultAnalyzer:
|
||||
if self._metrics_cache is None:
|
||||
print("正在计算绩效指标...")
|
||||
self._metrics_cache = calculate_metrics(
|
||||
self.portfolio_snapshots,
|
||||
self.trade_history,
|
||||
self.initial_capital
|
||||
self.portfolio_snapshots, self.trade_history, self.initial_capital
|
||||
)
|
||||
print("绩效指标计算完成。")
|
||||
return self._metrics_cache
|
||||
@@ -69,7 +75,8 @@ class ResultAnalyzer:
|
||||
print("\n--- 部分交易明细 (最近5笔) ---")
|
||||
for trade in self.trade_history[-5:]:
|
||||
print(
|
||||
f" {trade.fill_time} | {trade.direction:<10} | {trade.symbol} | Vol: {trade.volume} | Price: {trade.price:.2f} | Commission: {trade.commission:.2f}")
|
||||
f" {trade.fill_time} | {trade.direction:<10} | {trade.symbol} | Vol: {trade.volume} | Price: {trade.price:.2f} | Commission: {trade.commission:.2f}"
|
||||
)
|
||||
else:
|
||||
print("\n没有交易记录。")
|
||||
|
||||
@@ -79,10 +86,13 @@ class ResultAnalyzer:
|
||||
"""
|
||||
print("正在绘制绩效图表...")
|
||||
# plot_performance_chart(self.portfolio_snapshots, self.initial_capital, self.bars)
|
||||
plot_equity_and_drawdown_chart(self.portfolio_snapshots, self.initial_capital,
|
||||
title="Portfolio Equity and Drawdown Curve")
|
||||
|
||||
plot_equity_and_drawdown_chart(
|
||||
self.portfolio_snapshots,
|
||||
self.initial_capital,
|
||||
title="Portfolio Equity and Drawdown Curve",
|
||||
)
|
||||
|
||||
# 绘制单独的收盘价曲线
|
||||
plot_close_price_chart(self.bars, title="Underlying Asset Close Price")
|
||||
|
||||
print("图表绘制完成。")
|
||||
|
||||
print("图表绘制完成。")
|
||||
|
||||
Reference in New Issue
Block a user