Files
NewQuant/src/analysis/analysis_utils.py

280 lines
9.3 KiB
Python
Raw Normal View History

2025-06-18 10:25:05 +08:00
# src/analysis/analysis_utils.py (修改 calculate_metrics 函数)
import matplotlib
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Dict, Any
from ..core_data import PortfolioSnapshot, Trade, Bar
# English metric key -> Chinese alias. Every metrics dict is emitted under both
# spellings so callers may use either set of keys.
_METRIC_KEY_ZH = {
    "initial_capital": "初始资金",
    "final_capital": "最终资金",
    "total_return": "总收益率",
    "annualized_return": "年化收益率",
    "max_drawdown": "最大回撤",
    "sharpe_ratio": "夏普比率",
    "calmar_ratio": "卡玛比率",
    "total_trades": "总交易次数",
    "transaction_costs": "交易成本",
    "total_realized_pnl": "总实现盈亏",
    "win_rate": "胜率",
    "profit_loss_ratio": "盈亏比",
    "winning_trades_count": "盈利交易次数",
    "losing_trades_count": "亏损交易次数",
    "avg_profit_per_trade": "平均每次盈利",
    "avg_loss_per_trade": "平均每次亏损",
}


def _with_chinese_aliases(metrics: Dict[str, Any]) -> Dict[str, Any]:
    """Return ``metrics`` preceded by a copy of the same values under Chinese keys."""
    aliased = {_METRIC_KEY_ZH[key]: value for key, value in metrics.items()}
    aliased.update(metrics)
    return aliased


def _trade_statistics(trades: List[Trade]) -> Dict[str, Any]:
    """Summarise commissions and closed-trade P&L (depends only on ``trades``)."""
    # Only closing trades carry a realized P&L worth scoring.
    realized = [t.realized_pnl for t in trades if t.is_close_trade]
    wins = [pnl for pnl in realized if pnl > 0]
    losses = [pnl for pnl in realized if pnl < 0]
    # Break-even closes (pnl == 0) count as neither a win nor a loss.
    decided = len(wins) + len(losses)

    avg_profit = sum(wins) / len(wins) if wins else 0.0
    avg_loss = sum(losses) / len(losses) if losses else 0.0  # <= 0 by construction

    # Profit/loss ratio = average win / |average loss|. With no losing trades
    # the ratio is unbounded only if there was at least one win; with no
    # decided trades at all it is reported as 0.0.
    if avg_loss != 0:
        profit_loss_ratio = abs(avg_profit / avg_loss)
    elif avg_profit > 0:
        profit_loss_ratio = float("inf")
    else:
        profit_loss_ratio = 0.0

    return {
        "total_trades": len(trades),  # every trade record, opening or closing
        "transaction_costs": sum(t.commission for t in trades),
        "total_realized_pnl": sum(realized),
        "win_rate": len(wins) / decided if decided > 0 else 0.0,
        "profit_loss_ratio": profit_loss_ratio,
        "winning_trades_count": len(wins),
        "losing_trades_count": len(losses),
        "avg_profit_per_trade": avg_profit,
        "avg_loss_per_trade": avg_loss,  # negative, or 0.0 when nothing lost
    }


def calculate_metrics(
    snapshots: List[PortfolioSnapshot], trades: List[Trade], initial_capital: float
) -> Dict[str, Any]:
    """
    Pure function: compute key performance metrics from portfolio snapshots
    and trade history.

    Every metric is returned twice, under an English key (e.g. "total_return")
    and a Chinese key (e.g. "总收益率"). Both key sets are present even when
    ``snapshots`` is empty.

    Args:
        snapshots (List[PortfolioSnapshot]): Portfolio snapshots; each must
            expose ``datetime`` and ``total_value``. Drawdown is computed in
            list order, so they are expected to be chronological.
        trades (List[Trade]): Trade records; each must expose ``commission``,
            ``realized_pnl`` and ``is_close_trade``.
        initial_capital (float): Starting capital used for the total return.

    Returns:
        Dict[str, Any]: Metric name -> value. Notes:
            - annualized_return compounds the total return over the
              calendar-day span of the snapshots (365 days per year). It is
              -1.0 if equity ends at or below zero.
            - sharpe_ratio is the mean/std of per-snapshot returns scaled by
              sqrt(252). NOTE(review): that factor assumes one snapshot per
              trading day; intraday snapshots would need a different factor.
            - calmar_ratio is inf when there was no drawdown.
    """
    trade_stats = _trade_statistics(trades)

    if not snapshots:
        # No equity history: portfolio metrics are neutral, but trade
        # statistics are still meaningful and the key set is unchanged.
        return _with_chinese_aliases(
            {
                "initial_capital": initial_capital,
                "final_capital": initial_capital,
                "total_return": 0.0,
                "annualized_return": 0.0,
                "max_drawdown": 0.0,
                "sharpe_ratio": 0.0,
                "calmar_ratio": 0.0,
                **trade_stats,
            }
        )

    df_values = pd.DataFrame(
        [{"datetime": s.datetime, "total_value": s.total_value} for s in snapshots]
    ).set_index("datetime")
    equity = df_values["total_value"]
    returns = equity.pct_change().fillna(0)

    final_value = equity.iloc[-1]
    total_return = (final_value / initial_capital) - 1

    # `.days` is a calendar-day count, so annualize over 365 calendar days
    # (pairing it with 252 trading days understates the annual rate).
    total_days = (df_values.index.max() - df_values.index.min()).days
    growth = 1 + total_return
    if total_days <= 0:
        annualized_return = 0.0
    elif growth <= 0:
        # Equity wiped out (or negative): a fractional power of a
        # non-positive base is undefined, so report a total loss.
        annualized_return = -1.0
    else:
        annualized_return = growth ** (365 / total_days) - 1

    rolling_max = equity.cummax()
    drawdown = (rolling_max - equity) / rolling_max
    max_drawdown = drawdown.max()

    volatility = returns.std()  # NaN for a single snapshot -> Sharpe 0.0
    if volatility > 0:
        sharpe_ratio = np.sqrt(252) * (returns.mean() / volatility)
    else:
        sharpe_ratio = 0.0

    calmar_ratio = (
        annualized_return / max_drawdown if max_drawdown > 0 else float("inf")
    )

    return _with_chinese_aliases(
        {
            "initial_capital": initial_capital,
            "final_capital": final_value,
            "total_return": total_return,
            "annualized_return": annualized_return,
            "max_drawdown": max_drawdown,
            "sharpe_ratio": sharpe_ratio,
            "calmar_ratio": calmar_ratio,
            **trade_stats,
        }
    )
2025-06-29 12:03:43 +08:00
def plot_equity_and_drawdown_chart(
    snapshots: List[PortfolioSnapshot],
    initial_capital: float,
    title: str = "Portfolio Equity and Drawdown Curve",
) -> None:
    """
    Plot the portfolio equity curve (top) and its drawdown rate (bottom).

    Each snapshot is drawn at an equally spaced x position, so gaps in time
    (nights, weekends) are not shown. Up to 10 date labels are placed on
    the bottom axis.

    Args:
        snapshots (List[PortfolioSnapshot]): Snapshots exposing ``datetime``
            and ``total_value``, in plotting order.
        initial_capital (float): Capital used to normalise the equity curve
            (1.0 == starting capital).
        title (str): Title of the chart.
    """
    if not snapshots:
        print("No portfolio snapshots available to plot equity and drawdown.")
        return

    df_equity = pd.DataFrame(
        [{"datetime": s.datetime, "total_value": s.total_value} for s in snapshots]
    )
    equity_curve = df_equity["total_value"] / initial_capital
    rolling_max = equity_curve.cummax()
    drawdown = (rolling_max - equity_curve) / rolling_max

    plt.style.use("seaborn-v0_8-darkgrid")
    _, (ax1, ax2) = plt.subplots(
        2, 1, figsize=(14, 10), sharex=True, gridspec_kw={"height_ratios": [3, 1]}
    )
    x_axis_indices = np.arange(len(df_equity))

    # Equity curve
    ax1.plot(
        x_axis_indices, equity_curve, label="Equity Curve", color="blue", linewidth=1.5
    )
    ax1.set_ylabel("Equity", fontsize=12)
    ax1.legend(loc="upper left")
    ax1.grid(True)
    ax1.set_title(title, fontsize=16)

    # Drawdown curve
    ax2.fill_between(x_axis_indices, 0, drawdown, color="red", alpha=0.3)
    ax2.plot(
        x_axis_indices,
        drawdown,
        color="red",
        linewidth=1.0,
        linestyle="--",
        label="Drawdown",
    )
    ax2.set_ylabel("Drawdown Rate", fontsize=12)
    ax2.set_xlabel("Data Point Index (Date Labels Below)", fontsize=12)
    ax2.set_title("Portfolio Drawdown Curve", fontsize=14)
    ax2.legend(loc="upper left")
    ax2.grid(True)
    ax2.set_ylim(0, max(drawdown.max() * 1.1, 0.05))

    # Date ticks. Never request more ticks than points, otherwise the integer
    # positions repeat. The labels go on the bottom axis: with sharex=True,
    # plt.subplots hides the top axis' tick labels, so styling (rotation)
    # applied there never reaches the labels that are actually displayed.
    num_ticks = min(10, len(df_equity))
    tick_positions = np.linspace(0, len(df_equity) - 1, num_ticks, dtype=int)
    tick_labels = [
        df_equity["datetime"].iloc[i].strftime("%Y-%m-%d %H:%M")
        for i in tick_positions
    ]
    ax2.set_xticks(tick_positions)
    ax2.set_xticklabels(tick_labels, rotation=45, ha="right")

    plt.tight_layout()
    plt.show()
def plot_close_price_chart(bars: List[Bar], title: str = "Close Price Chart") -> None:
    """
    Plot the underlying asset's close price.

    Each bar is drawn at an equally spaced x position, so gaps in time are
    not shown. Up to 10 date labels are placed along the x axis.

    Args:
        bars (List[Bar]): Processed bars exposing ``datetime`` and ``close``,
            in plotting order.
        title (str): Title of the chart.
    """
    if not bars:
        print("No bar data available to plot close price.")
        return

    df_prices = pd.DataFrame(
        [{"datetime": b.datetime, "close_price": b.close} for b in bars]
    )

    plt.style.use("seaborn-v0_8-darkgrid")
    _, ax = plt.subplots(1, 1, figsize=(14, 7))  # single subplot

    x_axis_indices = np.arange(len(df_prices))
    ax.plot(
        x_axis_indices,
        df_prices["close_price"],
        label="Close Price",
        color="orange",
        linewidth=1.5,
    )
    ax.set_ylabel("Price", fontsize=12)
    ax.set_xlabel("Data Point Index (Date Labels Below)", fontsize=12)
    ax.set_title(title, fontsize=16)
    ax.legend(loc="upper left")
    ax.grid(True)

    # Date ticks: never request more ticks than there are points, otherwise
    # the integer positions (and their labels) repeat.
    num_ticks = min(10, len(df_prices))
    tick_positions = np.linspace(0, len(df_prices) - 1, num_ticks, dtype=int)
    tick_labels = [
        df_prices["datetime"].iloc[i].strftime("%Y-%m-%d %H:%M")
        for i in tick_positions
    ]
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_labels, rotation=45, ha="right")

    plt.tight_layout()
    plt.show()
# Helper: profit and loss of a single trade.
def calculate_trade_pnl(
    trade: Trade, entry_price: float, exit_price: float, direction: str
) -> float:
    """
    Return the P&L of ``trade`` between an entry and an exit price.

    Args:
        trade (Trade): Trade whose ``volume`` scales the per-unit price move.
            It is only read for a recognised direction.
        entry_price (float): Price at which the position was opened.
        exit_price (float): Price at which the position was closed.
        direction (str): "LONG" or "SHORT"; anything else yields 0.0.

    Returns:
        float: (exit - entry) * volume for LONG, (entry - exit) * volume for
        SHORT, otherwise 0.0.
    """
    if direction == "LONG":
        return (exit_price - entry_price) * trade.volume
    if direction == "SHORT":
        return (entry_price - exit_price) * trade.volume
    # Unknown direction: no P&L is attributed.
    return 0.0