Files
NewQuant/data/ analysis/MA.ipynb

783 lines
102 KiB
Plaintext
Raw Normal View History

2025-09-24 23:14:14 +08:00
{
"cells": [
{
"metadata": {},
"cell_type": "raw",
"source": "# Data: set file_path in the next cell to the path of your futures CSV; it must contain datetime, open, high, low, close and volume columns.",
"id": "fb1975346060eb6d"
},
{
"cell_type": "code",
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2025-09-24T08:34:05.116565Z",
"start_time": "2025-09-24T08:34:05.113703Z"
}
},
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"import talib as ta # Make sure TA-Lib is installed: pip install TA-Lib\n",
"import statsmodels.api as sm\n",
"\n",
"import warnings\n",
"\n",
"# Silence all warnings (this hides every warning, including deprecation notices).\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"# --- 0. Configure your file path ---\n",
"# The default is the author's local path; set the environment variable\n",
"# MA_NOTEBOOK_DATA_FILE to use another CSV without editing this cell.\n",
"import os\n",
"\n",
"DEFAULT_FILE_PATH = '/mnt/d/PyProject/NewQuant/data/data/KQ_m@CZCE_SA/KQ_m@CZCE_SA_min15.csv'\n",
"file_path = os.environ.get('MA_NOTEBOOK_DATA_FILE', DEFAULT_FILE_PATH)\n",
"\n",
"sns.set(style='whitegrid')\n",
"plt.rcParams['font.sans-serif'] = ['SimHei']  # SimHei so Chinese labels render\n",
"plt.rcParams['axes.unicode_minus'] = False  # keep the minus sign visible with SimHei\n"
],
"outputs": [],
"execution_count": 27
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-09-24T08:34:05.158770Z",
"start_time": "2025-09-24T08:34:05.127669Z"
}
},
"cell_type": "code",
"source": [
"\n",
"# --- 1. Data Loading and Preprocessing ---\n",
"def load_and_preprocess_data(file_path):\n",
"    \"\"\"\n",
"    Read an OHLCV futures CSV indexed by 'datetime' and clean it.\n",
"\n",
"    Steps: parse 'datetime' into the index, sort chronologically, drop every row\n",
"    that has a missing value in any column, then verify that 'open', 'high',\n",
"    'low', 'close' and 'volume' are all present.\n",
"\n",
"    Args:\n",
"        file_path: path (or file-like object) passed straight to ``pd.read_csv``.\n",
"\n",
"    Returns:\n",
"        The cleaned DataFrame, or None when the file is missing or any step fails;\n",
"        failures are printed rather than raised (including the missing-column check).\n",
"    \"\"\"\n",
"    required_columns = ['open', 'high', 'low', 'close', 'volume']\n",
"    try:\n",
"        frame = pd.read_csv(file_path, parse_dates=['datetime'], index_col='datetime').sort_index()\n",
"\n",
"        rows_before = len(frame)\n",
"        frame = frame.dropna()\n",
"        removed = rows_before - len(frame)\n",
"        if removed > 0:\n",
"            print(f\"Warning: Missing values found in data, deleted {removed} rows.\")\n",
"\n",
"        missing = [col for col in required_columns if col not in frame.columns]\n",
"        if missing:\n",
"            # Caught by the generic handler below, so it surfaces as a printed message.\n",
"            raise ValueError(f\"CSV file is missing required columns. Please ensure it contains: {required_columns}\")\n",
"\n",
"        print(f\"Successfully loaded {len(frame)} rows of data.\")\n",
"        print(\"First 5 rows of data:\")\n",
"        print(frame.head())\n",
"        return frame\n",
"    except FileNotFoundError:\n",
"        print(f\"Error: File '{file_path}' not found. Please check the path.\")\n",
"        return None\n",
"    except Exception as e:\n",
"        print(f\"Error during data loading or preprocessing: {e}\")\n",
"        return None\n",
"\n",
"\n",
"ANALYSIS_START = '2024-01-01'  # keep only bars on or after this date\n",
"\n",
"df_raw = load_and_preprocess_data(file_path)\n",
"if df_raw is None:\n",
"    # The loader prints the reason and returns None on failure; stop here with a clear\n",
"    # message instead of a confusing TypeError on the filtering line below.\n",
"    raise RuntimeError(f\"Could not load data from '{file_path}' (see the message above).\")\n",
"df_raw = df_raw[df_raw.index >= ANALYSIS_START]"
],
"id": "1638e05ca7ef1ac8",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Successfully loaded 25662 rows of data.\n",
"First 5 rows of data:\n",
" open high low close volume open_oi \\\n",
"datetime \n",
"2020-12-31 14:45:00 1607.0 1611.0 1603.0 1611.0 19480.0 148833.0 \n",
"2021-01-04 09:00:00 1610.0 1636.0 1601.0 1620.0 55486.0 146448.0 \n",
"2021-01-04 09:15:00 1620.0 1620.0 1601.0 1604.0 30314.0 153373.0 \n",
"2021-01-04 09:30:00 1604.0 1606.0 1590.0 1595.0 30803.0 157091.0 \n",
"2021-01-04 09:45:00 1595.0 1601.0 1594.0 1600.0 10031.0 158730.0 \n",
"\n",
" close_oi underlying_symbol \n",
"datetime \n",
"2020-12-31 14:45:00 146448.0 CZCE.SA105 \n",
"2021-01-04 09:00:00 153373.0 CZCE.SA105 \n",
"2021-01-04 09:15:00 157091.0 CZCE.SA105 \n",
"2021-01-04 09:30:00 158730.0 CZCE.SA105 \n",
"2021-01-04 09:45:00 160031.0 CZCE.SA105 \n"
]
}
],
"execution_count": 28
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-09-24T08:34:05.651602Z",
"start_time": "2025-09-24T08:34:05.175334Z"
}
},
"cell_type": "code",
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import talib\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"\n",
"\n",
"# --- 1. 数据准备与事件识别函数 ---\n",
"\n",
"def analyze_ma_channel_events(df_raw: pd.DataFrame, ma_periods: list,\n",
"                              atr_period: int = 14, band_multiplier: float = 0.5) -> pd.DataFrame:\n",
"    \"\"\"\n",
"    Count bounce vs. penetration events inside a dynamic MA ± k*ATR channel.\n",
"\n",
"    For each MA period the channel is\n",
"    ``SMA(close, period) ± band_multiplier * ATR(high, low, close, atr_period)``.\n",
"    A bar-by-bar state machine runs on closes. Both the previous and the current close\n",
"    are compared with the current bar's band levels:\n",
"\n",
"    * entered from the top: previous close >= upper band and current close < upper band;\n",
"    * entered from the bottom: previous close <= lower band and current close > lower band;\n",
"    * bounce: after entering, a close outside the band on the side it came from;\n",
"    * penetration: after entering, a close outside the band on the opposite side.\n",
"    After a bounce or penetration the state returns to 'outside'.\n",
"\n",
"    Args:\n",
"        df_raw (pd.DataFrame): raw data; must contain 'high', 'low', 'close' columns.\n",
"        ma_periods (list): moving-average periods to analyse.\n",
"        atr_period (int): ATR look-back; default 14 (the previously hard-coded value).\n",
"        band_multiplier (float): channel half-width in ATR units; default 0.5.\n",
"\n",
"    Returns:\n",
"        pd.DataFrame: one row per MA period with event counts and probabilities\n",
"        (both probabilities are 0 when no event occurred).\n",
"\n",
"    Raises:\n",
"        ValueError: if ``ma_periods`` is empty, or if fewer than ``max(ma_periods) + 2``\n",
"            rows remain after the ATR warm-up rows are dropped.\n",
"    \"\"\"\n",
"    if not ma_periods:\n",
"        raise ValueError(\"ma_periods must not be empty\")\n",
"\n",
"    # ATR is computed once; its warm-up rows are NaN and are dropped here.\n",
"    df = df_raw.copy()\n",
"    df['atr'] = talib.ATR(df['high'], df['low'], df['close'], timeperiod=atr_period)\n",
"    df.dropna(inplace=True)\n",
"\n",
"    # Check the length AFTER the ATR warm-up is gone. Checking the raw frame let a\n",
"    # too-short dataset through, which left the longest SMA all-NaN and silently\n",
"    # reported zero events for that period.\n",
"    if df.empty or len(df) < max(ma_periods) + 2:\n",
"        raise ValueError(\"数据量不足以进行分析\")\n",
"\n",
"    # numpy arrays: per-element .iloc access inside the loop is slow.\n",
"    closes = df['close'].to_numpy(dtype=float)\n",
"    atr = df['atr'].to_numpy(dtype=float)\n",
"\n",
"    results = []\n",
"\n",
"    for period in ma_periods:\n",
"        # MA and the dynamic upper/lower bands.\n",
"        ma = np.asarray(talib.SMA(df['close'], timeperiod=period), dtype=float)\n",
"        upper_band = ma + band_multiplier * atr\n",
"        lower_band = ma - band_multiplier * atr\n",
"\n",
"        in_band_state = 'outside'  # outside, inside_from_top, inside_from_bottom\n",
"        bounce_count = 0\n",
"        penetration_count = 0\n",
"\n",
"        for i in range(1, len(closes)):\n",
"            close = closes[i]\n",
"            prev_close = closes[i - 1]\n",
"            ub = upper_band[i]\n",
"            lb = lower_band[i]\n",
"            # During the SMA warm-up ub/lb are NaN. Every comparison is then False,\n",
"            # so those bars never change the state.\n",
"\n",
"            if in_band_state == 'outside':\n",
"                if prev_close >= ub and close < ub:\n",
"                    in_band_state = 'inside_from_top'\n",
"                elif prev_close <= lb and close > lb:\n",
"                    in_band_state = 'inside_from_bottom'\n",
"\n",
"            elif in_band_state == 'inside_from_top':\n",
"                if close > ub:  # left through the upper band: bounce\n",
"                    bounce_count += 1\n",
"                    in_band_state = 'outside'\n",
"                elif close < lb:  # left through the lower band: penetration\n",
"                    penetration_count += 1\n",
"                    in_band_state = 'outside'\n",
"\n",
"            elif in_band_state == 'inside_from_bottom':\n",
"                if close < lb:  # left through the lower band: bounce\n",
"                    bounce_count += 1\n",
"                    in_band_state = 'outside'\n",
"                elif close > ub:  # left through the upper band: penetration\n",
"                    penetration_count += 1\n",
"                    in_band_state = 'outside'\n",
"\n",
"        # --- Probabilities ---\n",
"        total_events = bounce_count + penetration_count\n",
"        if total_events > 0:\n",
"            bounce_prob = bounce_count / total_events\n",
"            penetration_prob = penetration_count / total_events\n",
"        else:\n",
"            bounce_prob = 0\n",
"            penetration_prob = 0\n",
"\n",
"        results.append({\n",
"            'ma_period': period,\n",
"            'total_events': total_events,\n",
"            'bounce_count': bounce_count,\n",
"            'penetration_count': penetration_count,\n",
"            'bounce_probability': bounce_prob,\n",
"            'penetration_probability': penetration_prob\n",
"        })\n",
"\n",
"    return pd.DataFrame(results)\n",
"\n",
"\n",
"# --- 2. 可视化函数 ---\n",
"\n",
"def plot_event_probabilities(results_df: pd.DataFrame, save_path: str = 'ma_channel_event_probability.png'):\n",
"    \"\"\"\n",
"    Draw bounce / penetration probabilities per MA period as grouped bars.\n",
"\n",
"    Args:\n",
"        results_df (pd.DataFrame): output of analyze_ma_channel_events; needs the\n",
"            columns 'ma_period', 'bounce_probability' and 'penetration_probability'.\n",
"        save_path (str | None): PNG file the figure is written to (default\n",
"            'ma_channel_event_probability.png'). Pass None to skip saving.\n",
"    \"\"\"\n",
"    if results_df.empty:\n",
"        print(\"没有可供可视化的数据。\")\n",
"        return\n",
"\n",
"    # Long format for seaborn's grouped bar plot.\n",
"    plot_data = results_df.melt(id_vars='ma_period',\n",
"                                value_vars=['bounce_probability', 'penetration_probability'],\n",
"                                var_name='event_type',\n",
"                                value_name='probability')\n",
"\n",
"    # 'bounce_probability' -> 'Bounce', 'penetration_probability' -> 'Penetration'\n",
"    plot_data['event_type'] = plot_data['event_type'].str.replace('_probability', '').str.capitalize()\n",
"\n",
"    fig, ax = plt.subplots(figsize=(14, 8))\n",
"    sns.barplot(data=plot_data, x='ma_period', y='probability', hue='event_type', ax=ax, palette='viridis')\n",
"\n",
"    ax.set_title('事件概率 vs. MA周期 (MA ± 0.5 ATR Channel)', fontsize=16, fontweight='bold')\n",
"    ax.set_xlabel('移动平均线周期 (MA Period)', fontsize=12)\n",
"    ax.set_ylabel('事件发生概率 (Probability)', fontsize=12)\n",
"    ax.legend(title='事件类型 (Event Type)', fontsize=10)\n",
"    ax.yaxis.set_major_formatter(plt.FuncFormatter('{:.0%}'.format))  # y axis as percentages\n",
"\n",
"    # Value label above every bar.\n",
"    for container in ax.containers:\n",
"        ax.bar_label(container, fmt='{:.1%}', fontsize=9, padding=3)\n",
"\n",
"    plt.tight_layout()\n",
"    # Save before plt.show(): on some backends (e.g. inline) the figure is closed\n",
"    # once shown, and saving afterwards can write an empty image.\n",
"    if save_path is not None:\n",
"        fig.savefig(save_path, dpi=150, bbox_inches='tight')\n",
"    plt.show()\n",
"    if save_path is not None:\n",
"        print(f\"图表已保存为 {save_path}\")\n",
"\n",
"\n",
"# MA periods to compare, from short to long.\n",
"ma_periods_to_analyze = [10, 20, 30, 50, 60, 100, 120]\n",
"\n",
"# Count bounce / penetration events per period, print the table, then plot it.\n",
"event_statistics = analyze_ma_channel_events(df_raw, ma_periods=ma_periods_to_analyze)\n",
"print(\"--- 事件统计结果 ---\")\n",
"print(event_statistics.to_string())\n",
"\n",
"plot_event_probabilities(event_statistics)"
],
"id": "e0a36ae978d73ecc",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--- 事件统计结果 ---\n",
" ma_period total_events bounce_count penetration_count bounce_probability penetration_probability\n",
"0 10 1104 557 547 0.504529 0.495471\n",
"1 20 890 496 394 0.557303 0.442697\n",
"2 30 763 432 331 0.566186 0.433814\n",
"3 50 610 364 246 0.596721 0.403279\n",
"4 60 567 319 248 0.562610 0.437390\n",
"5 100 418 238 180 0.569378 0.430622\n",
"6 120 359 219 140 0.610028 0.389972\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 1400x800 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAABWgAAAMQCAYAAAC60ozSAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjMsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvZiW1igAAAAlwSFlzAAAPYQAAD2EBqD+naQAAzZxJREFUeJzs3Xd8zef///HnyRYRe6tRLUKJETRmBTUqtqKKllIlRqs1an5sWnu1VLW01N4zdpEGofbeM0FCEhKZvz/88v46TUIiiWM87rfb5/bJud7v93Wu6+TNx+d5rvfrMsXGxsYKAAAAAAAAAPDCWVl6AAAAAAAAAADwpiKgBQAAAAAAAAALIaAFAAAAAAAAAAshoAUAAAAAAAAACyGgBQAAAAAAAAALIaAFAAAAAAAAAAshoAUAAAAAAAAACyGgBQAAAAAAAAALIaAFAAAAAAAAAAshoAUA4BXQpUsX9e7dO0V9HDhwQJ9++qkOHDiQ4PHLly/r22+/1dq1a5PUX1hYmFatWqXDhw+naFzJdevWLd2+ffuFvifSxl9//aW6devKx8fHrL1FixZq2rSphUYFAAAAvFg2lh4AAABvulOnTunGjRtycHCQtbV1guecPXtW6dKlk6+vb4LHo6OjFRERofTp06t8+fIJnrNkyRLt37/fCDfDwsKULl0643hYWJjWrFmj/PnzJ2ncDx48UJ8+ffTFF1/I1dVVUVFRCgoKMjvHyspKDx8+1K5du+To6Cgrq//7bjg6Olp2dnZq0KBBkt5PkjZs2KBevXrp448/1vDhw5N8HVLm559/1oQJE577+t69e6tz587x2sPDw3Xx4kXFxsaatQcEBCg8PNx4PW/ePI0ZMybBvtOnT6/9+/c/99gSs2rVKmXPnl2VKlVK9b7TUr9+/bRv3z5t27bN0kMBXqiVK1cqZ86ccnd3t/RQAABINgJaAAAsbNmyZZo3b16Szm3Xrt1Tj5ctW1YLFy6M1x4YGKiNGzeqUqVKqlevniTpp59+0u7du/Xtt9/K3d1dtra2kiQbm6T98yAu3M2QIYMk6eLFi/HC1mzZsumHH37QL7/8IgcHBz169EjXr19Xvnz5ZGdnp9y5c8e75vDhw7K3tzfG86R8+fLJyspKmzZtUvv27WUymcyOx8bGKjIyUjExMSpRokSS5oFns7OzkyR17txZhQoVMjsWEhKiR48eKUuWLGYBvCRduHBBs2fPNq6PExkZqUePHhn32n9/1zY2NrK3t1dERIRiY2Nla2ur6OhotW3bVu+//75x3tChQxUTE5Nq84yzc+dO9e/fX19++eUrF9ACb6o9e/Zoy5Yt+vPPP1W8eHFLDwcAgGQhoAUAwMLat2+vJk2aKF26dAmGktLjYMzJySnRVYwRERGKiIhI9PpZs2YpPDxcPXv2NNp8fX115swZvf3228817v+Ga/b29pKkYcOGqV69evrf//6ngwcPqlKlStq+fbskae/evfr88881YcIEubq6Jtjvxx9//Mz3vn//vj766KNEj2fLlk179uxJ1nxeVVOnTtW0adMkPQ7da9SoYRyrWLGi7t27p7x58ya4orJXr17asGGDxo4dq8aNGyf6HnEBa7Vq1eKt0B4wYICWLl2q/fv3y9nZ2eyYr6+vZs+eHe++PHz4sNq0aWO8/vTTTxN835IlS6p///5Knz69JKl48eKqVauWcXzs2LGKiIhIdNzP4+rVq/rmm2/UvHlz489Lv379tGLFCknSmjVrVKRIEUlSaGio3NzcFBsbqwoVKmj+/Pnx+mvRooWOHDmiP//8U25ubqk6Vkvau3evpk2bplOnTsnR0VHNmjVT9+7dk/wFT0JCQ0NVvnz5BEN3Ozs7HT16NCVDlvTse75o0aLP7OP06dOSpLZt22rfvn1Gu8lkUsaMGeXq6qrPP/88xSs5Q0JCNGXKFG3atElBQUEqWLCg2rZta/Z35LVr11SzZk01adIk0VXmr5OiRY
sm+mdt1KhRun//vrp166bVq1cbXx4CAPAqIKAFAMDC8uXLZ/y8adMmPXjwIN45Dx8+VExMjFkYEMdkMqlJkyaJ9u/v76+FCxcqa9asKl26tCTp7t27Onr0qOrXr6+cOXMmem1UVJQePnyo9OnTJ1p+IU7c6kkHBwc5OzvL1tY23jWRkZGS4q+YfNLGjRvl6Oio9OnTm5VFCA4OVqtWrYz//HdVZmxsrCIiInTv3j3jfd40J0+eNALamzdv6t69e4meGxUVZYTYu3btempAG/d7NJlMioqK0o0bN2RnZyc7OzvjWGRkpIKCgowvCrJkyWJc/9+VtUWKFNGSJUu0Y8cOTZ8+XVOmTNE777xjHP/iiy8UFRWl6dOnK3v27PFq1D7pv32n1IABA1SgQAENGTIkweMnT540AtpTp07FK8/wpMDAQCNU3Llz52sT0K5du1bfffedHBwc1KBBA0VEROiXX37R7du3NWrUqOfu9/jx44qJiVGdOnXirdR+2t8ZSZWUe75Lly7GzwcOHNCBAwfk5ub21N9dw4YNlSdPHj18+FCnT5/Wzp07tXPnTo0YMUItWrR4rrEGBwerdevWOnfunCpVqqTChQvrn3/+0aBBg3TlyhV9++23z9Xv68zW1lYTJkxQgwYNNHbsWI0YMcLSQwIAIMkIaAEAeIlMnjxZ58+fT/R4//7947WlS5fuqQHtkCFDFB4eLicnJ6Nt3bp1ioqK0urVq7V69ep4Y5g8ebJZ2/r161W4cGFJj0OLMWPGyMHBQdLjjZ7WrVtnrDa8ffu2zp8/r5CQEJlMJp07d06LFi3Sp59+agSncattExISEqLhw4erb9++ZqvZxo4dqwsXLigkJES2trbav3+/YmJiVLFiRUmPw8PZs2fL29tbv/zyS6L9v85OnjyZ4M8JOXTokIKDgyU9fjQ4Ojo60RD+yVWRt27dUu3ateOd82QpgHbt2mnAgAGJvrezs7NKlSql48ePS5Ly5Mlj3F9x72cymVSqVClJTw9h/1vmIiV27NghPz8/LVu2LNHP4uTJk2rUqJHx89Ps2rXLCHB37tyZ4o3+XgZ37tzR4MGDZWtrqz/++MMoJVKmTBkNGTJE9evXV5UqVZ6r7yNHjkh6HJKmxSPqSbnnv/76a+PnqVOn6sCBA3r//ffVvXv3RPtt3ry58feQJC1fvlz9+/fXDz/8oCZNmjzXquLRo0fr3Llz6tSpkxHGhoeH65NPPtEvv/yiBg0aqFixYsnu93Xn5OSkoUOH6quvvtJnn31m9sUPAAAvMwJaAABeIunSpZOrq6sWL15s1t6oUSM5OzvHe6yzbdu2Onv2bKL9rVy50igvECcmJkZ//vmncubMafZoeWBgoObOnasqVaoYYUNUVJTCw8OVOXNm47xMmTKpXLlyxmZjuXLlUtGiRY3H0H/44Qf98MMPkqT8+fPrxo0bmjdvnj755BOFhoZK0lMfPT158qQOHTqkFi1aaOTIkfL09NSCBQu0YsUK/fDDD/L09FRAQIB69+6tO3fu6LffflPp0qU1atQoLVy4UC4uLgoICFD27NkTfY/X1YkTJxL8OSE7d+6UJLVs2VKLFi3Sv//+q3LlyiV4blwIam1trYwZM6p3796ys7OTra2tvL295ePjo+HDhysmJkYRERF69913kzXuGzduyNHR0XgdFRWVYPAaF/4/67zn9ccff8jDw+OpwVdyP2MbGxs1atRIy5Ytk7+//1NXrCfV5s2bjXA7ofEFBwdr4sSJCR5///33U/To/cqVK/XgwQN99tlnZnWemzdvrp9++knLli177oD26NGjsrOzS/b9k1TJuedTomnTppo6dapu3Lih8+fPJ6lswpPu3LmjlStXKlOmTGbBsIODg1q3bq2BAwdq/fr1BLSJ+OCDD1SsWDEtWLBAgwcPtvRwAABIEgJaAABeIlFRUbp//76WL19u1n7//n09evQoXvvt27cVHR2dYF///vuvhgwZoiJFisjBwU
E3btyQ9LiG5qVLl9SvXz99/vnnxvnnz5/X3LlzVa5cOXXu3DnRMb7zzjvq37+/FixYoHXr1umDDz5Qx44dde3aNUn
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"图表已保存为 ma_channel_event_probability.png\n"
]
}
],
"execution_count": 29
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-09-24T08:36:29.890785Z",
"start_time": "2025-09-24T08:36:29.734715Z"
}
},
"cell_type": "code",
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import talib\n",
"\n",
"\n",
"# --- 1. Core Backtesting Engine with ATR Stop-Loss ---\n",
"\n",
"def run_backtest(df_raw: pd.DataFrame, ma_period: int, atr_period: int, atr_multiplier: float, stop_loss_multiplier: float):\n",
"    \"\"\"\n",
"    Backtest a channel re-entry strategy on MA ± k*ATR bands with a fixed ATR stop.\n",
"\n",
"    Rules (evaluated on bar closes; PnL in raw price points, no fees or slippage):\n",
"      * Flat -> short: previous close above the previous upper band and current\n",
"        close below the current upper band. Entry at the current close.\n",
"      * Flat -> long: previous close below the previous lower band and current\n",
"        close above the current lower band. Entry at the current close.\n",
"      * Stop-loss: entry -/+ stop_loss_multiplier * ATR(entry bar), fixed at entry;\n",
"        hit when the bar's low (long) / high (short) goes beyond it.\n",
"      * Take-profit: a long exits when the close crosses up through the upper band;\n",
"        a short exits when the close crosses down through the lower band.\n",
"\n",
"    Fill assumptions:\n",
"      * The stop is checked before the take-profit. A bar may both breach the stop\n",
"        intrabar and close beyond the target, and the intrabar order is unknown, so\n",
"        the conservative assumption is that the stop fired.\n",
"      * If the bar opens beyond the stop (a gap), the exit fills at the open,\n",
"        because the stop price itself was never traded.\n",
"\n",
"    Args:\n",
"        df_raw (pd.DataFrame): OHLCV data with 'open', 'high', 'low', 'close' columns.\n",
"        ma_period (int): The period for the moving average.\n",
"        atr_period (int): The period for ATR calculation.\n",
"        atr_multiplier (float): ATR multiplier for the channel half-width.\n",
"        stop_loss_multiplier (float): ATR multiplier for the stop-loss distance.\n",
"\n",
"    Returns:\n",
"        pd.DataFrame: completed trades with columns entry_time, type, entry_price,\n",
"        exit_time, exit_price, pnl. A position still open at the end is dropped.\n",
"    \"\"\"\n",
"    # --- Data Preparation and Indicator Calculation ---\n",
"    df = df_raw.copy()\n",
"    df.dropna(inplace=True)\n",
"\n",
"    df['ma'] = talib.SMA(df['close'], timeperiod=ma_period)\n",
"    df['atr'] = talib.ATR(df['high'], df['low'], df['close'], timeperiod=atr_period)\n",
"\n",
"    df['upper_band'] = df['ma'] + atr_multiplier * df['atr']\n",
"    df['lower_band'] = df['ma'] - atr_multiplier * df['atr']\n",
"\n",
"    df.dropna(inplace=True)  # drop the indicator warm-up rows\n",
"\n",
"    # numpy arrays: per-element .iloc access inside the loop is slow.\n",
"    times = df.index\n",
"    opens = df['open'].to_numpy()\n",
"    highs = df['high'].to_numpy()\n",
"    lows = df['low'].to_numpy()\n",
"    closes = df['close'].to_numpy()\n",
"    uppers = df['upper_band'].to_numpy()\n",
"    lowers = df['lower_band'].to_numpy()\n",
"    atrs = df['atr'].to_numpy()\n",
"\n",
"    # --- Backtesting Loop ---\n",
"    position = 0  # 0: flat, 1: long, -1: short\n",
"    stop_loss_price = 0.0\n",
"    trades = []\n",
"\n",
"    def close_position(i, exit_price):\n",
"        \"\"\"Record the exit of the open trade (trades[-1]) at bar i.\"\"\"\n",
"        trade = trades[-1]\n",
"        trade['exit_time'] = times[i]\n",
"        trade['exit_price'] = exit_price\n",
"        if trade['type'] == 'long':\n",
"            trade['pnl'] = exit_price - trade['entry_price']\n",
"        else:\n",
"            trade['pnl'] = trade['entry_price'] - exit_price\n",
"\n",
"    print(\"Starting backtest...\")\n",
"    for i in range(1, len(df)):\n",
"        prev_close = closes[i - 1]\n",
"        current_close = closes[i]\n",
"\n",
"        if position == 0:\n",
"            # Short: close re-enters the channel through the upper band from above.\n",
"            if prev_close > uppers[i - 1] and current_close < uppers[i]:\n",
"                position = -1\n",
"                stop_loss_price = current_close + stop_loss_multiplier * atrs[i]\n",
"                trades.append({'entry_time': times[i], 'type': 'short', 'entry_price': current_close})\n",
"            # Long: close re-enters the channel through the lower band from below.\n",
"            elif prev_close < lowers[i - 1] and current_close > lowers[i]:\n",
"                position = 1\n",
"                stop_loss_price = current_close - stop_loss_multiplier * atrs[i]\n",
"                trades.append({'entry_time': times[i], 'type': 'long', 'entry_price': current_close})\n",
"\n",
"        elif position == 1:\n",
"            if lows[i] < stop_loss_price:\n",
"                # A gap below the stop fills at the (worse) open.\n",
"                close_position(i, min(opens[i], stop_loss_price))\n",
"                position = 0\n",
"            elif prev_close < uppers[i - 1] and current_close > uppers[i]:\n",
"                close_position(i, current_close)\n",
"                position = 0\n",
"\n",
"        elif position == -1:\n",
"            if highs[i] > stop_loss_price:\n",
"                # A gap above the stop fills at the (worse) open.\n",
"                close_position(i, max(opens[i], stop_loss_price))\n",
"                position = 0\n",
"            elif prev_close > lowers[i - 1] and current_close < lowers[i]:\n",
"                close_position(i, current_close)\n",
"                position = 0\n",
"\n",
"    print(\"Backtest finished.\")\n",
"\n",
"    completed_trades = [t for t in trades if 'exit_price' in t]\n",
"    return pd.DataFrame(completed_trades)\n",
"\n",
"\n",
"# --- 2. Performance Analysis Function (Unchanged) ---\n",
"\n",
"def analyze_performance(trades_df: pd.DataFrame):\n",
"    \"\"\"\n",
"    Print a performance summary for a table of completed trades.\n",
"\n",
"    Winners are trades with pnl > 0. Every other trade (pnl <= 0, breakeven\n",
"    included) counts as a loser. Ratios with a zero denominator print as inf.\n",
"\n",
"    Args:\n",
"        trades_df (pd.DataFrame): one row per trade with a numeric 'pnl' column.\n",
"\n",
"    Returns:\n",
"        None; the summary is printed. An empty frame only prints a notice.\n",
"    \"\"\"\n",
"    if trades_df.empty:\n",
"        print(\"No completed trades to analyze.\")\n",
"        return\n",
"\n",
"    pnl = trades_df['pnl']\n",
"    wins = pnl[pnl > 0]\n",
"    losses = pnl[pnl <= 0]\n",
"    n_trades = len(trades_df)\n",
"\n",
"    win_rate = len(wins) / n_trades\n",
"    avg_profit = wins.mean() if len(wins) > 0 else 0\n",
"    avg_loss = losses.mean() if len(losses) > 0 else 0\n",
"    profit_loss_ratio = float('inf') if avg_loss == 0 else abs(avg_profit / avg_loss)\n",
"\n",
"    gross_loss = abs(losses.sum())\n",
"    profit_factor = wins.sum() / gross_loss if gross_loss > 0 else float('inf')\n",
"\n",
"    summary_lines = [\n",
"        (\"Total Trades :\", f\"{n_trades}\"),\n",
"        (\"Winning Trades :\", f\"{len(wins)}\"),\n",
"        (\"Losing Trades :\", f\"{len(losses)}\"),\n",
"        (\"Win Rate :\", f\"{win_rate:.2%}\"),\n",
"        (\"Avg. Profit :\", f\"{avg_profit:.2f}\"),\n",
"        (\"Avg. Loss :\", f\"{avg_loss:.2f}\"),\n",
"        (\"Profit/Loss Ratio :\", f\"{profit_loss_ratio:.2f}\"),\n",
"        (\"Total PnL :\", f\"{pnl.sum():.2f}\"),\n",
"        (\"Profit Factor :\", f\"{profit_factor:.2f}\"),\n",
"    ]\n",
"\n",
"    print(\"\\n--- Strategy Performance Summary ---\")\n",
"    for label, value in summary_lines:\n",
"        print(f\"{label} {value}\")\n",
"    print(\"----------------------------------\\n\")\n",
"\n",
"\n",
"# --- 3. Main Execution Block ---\n",
"\n",
"# Strategy parameters (multipliers are in ATR units).\n",
"MA_PERIOD = 10\n",
"ATR_PERIOD = 14\n",
"ATR_MULTIPLIER = 0.5  # channel half-width\n",
"STOP_LOSS_MULTIPLIER = 1  # stop-loss distance from entry\n",
"\n",
"backtest_params = dict(\n",
"    ma_period=MA_PERIOD,\n",
"    atr_period=ATR_PERIOD,\n",
"    atr_multiplier=ATR_MULTIPLIER,\n",
"    stop_loss_multiplier=STOP_LOSS_MULTIPLIER,\n",
")\n",
"trades_log = run_backtest(df_raw, **backtest_params)\n",
"\n",
"# First few trades, then the summary statistics.\n",
"print(trades_log.head().to_string())\n",
"analyze_performance(trades_log)"
],
"id": "5b6585671d20a1f",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Starting backtest...\n",
"Backtest finished.\n",
" entry_time type entry_price exit_time exit_price pnl\n",
"0 2024-01-02 21:30:00 long 2020.0 2024-01-02 21:45:00 2033.000000 13.000000\n",
"1 2024-01-03 10:45:00 short 2062.0 2024-01-03 13:45:00 2052.000000 10.000000\n",
"2 2024-01-03 14:00:00 long 2063.0 2024-01-03 14:45:00 2074.000000 11.000000\n",
"3 2024-01-03 21:00:00 short 2064.0 2024-01-03 21:30:00 2080.403936 -16.403936\n",
"4 2024-01-03 22:15:00 short 2079.0 2024-01-04 09:00:00 2062.000000 17.000000\n",
"\n",
"--- Strategy Performance Summary ---\n",
"Total Trades : 1099\n",
"Winning Trades : 581\n",
"Losing Trades : 518\n",
"Win Rate : 52.87%\n",
"Avg. Profit : 10.04\n",
"Avg. Loss : -10.50\n",
"Profit/Loss Ratio : 0.96\n",
"Total PnL : 396.16\n",
"Profit Factor : 1.07\n",
"----------------------------------\n",
"\n"
]
}
],
"execution_count": 33
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-09-24T08:36:56.505148Z",
"start_time": "2025-09-24T08:36:56.387310Z"
}
},
"cell_type": "code",
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import talib\n",
"\n",
"\n",
"# --- 1. 回测引擎核心函数 ---\n",
"\n",
"def run_backtest_mean_reversion(df_raw: pd.DataFrame, ma_period: int, atr_period: int, atr_multiplier: float, stop_loss_multiplier: float):\n",
"    \"\"\"\n",
"    Backtest a mean-reversion strategy on an MA ± k*ATR channel with a fixed ATR stop.\n",
"\n",
"    Rules (evaluated on bar closes; PnL in raw price points, no fees or slippage):\n",
"      * Flat -> short: previous close below the previous upper band and current\n",
"        close above that same previous upper band. Entry at the current close.\n",
"      * Flat -> long: previous close above the previous lower band and current\n",
"        close below that same previous lower band. Entry at the current close.\n",
"      * Stop-loss (checked first on every bar): entry +/- stop_loss_multiplier *\n",
"        ATR(entry bar), fixed at entry; hit when the bar's high (short) / low (long)\n",
"        goes beyond it. If the bar opens beyond the stop (a gap), the exit fills at\n",
"        the open, because the stop price itself was never traded.\n",
"      * Exit back into the channel: a long exits when the close is above the current\n",
"        lower band; a short exits when the close is below the current upper band.\n",
"        The exit is at the close and is not necessarily profitable, since the band moves.\n",
"\n",
"    Args:\n",
"        df_raw (pd.DataFrame): OHLCV data with 'open', 'high', 'low', 'close' columns.\n",
"        ma_period (int): moving-average period.\n",
"        atr_period (int): ATR period.\n",
"        atr_multiplier (float): ATR multiplier for the channel half-width.\n",
"        stop_loss_multiplier (float): ATR multiplier for the stop-loss distance.\n",
"\n",
"    Returns:\n",
"        pd.DataFrame: completed trades with columns entry_time, type, entry_price,\n",
"        stop_loss, exit_time, exit_price, pnl. A position still open at the end is dropped.\n",
"    \"\"\"\n",
"    # --- Data preparation and indicators ---\n",
"    df = df_raw.copy()\n",
"    df.dropna(inplace=True)\n",
"\n",
"    df['ma'] = talib.SMA(df['close'], timeperiod=ma_period)\n",
"    df['atr'] = talib.ATR(df['high'], df['low'], df['close'], timeperiod=atr_period)\n",
"\n",
"    df['upper_band'] = df['ma'] + atr_multiplier * df['atr']\n",
"    df['lower_band'] = df['ma'] - atr_multiplier * df['atr']\n",
"\n",
"    df.dropna(inplace=True)  # drop the indicator warm-up rows\n",
"\n",
"    # numpy arrays: per-element .iloc access inside the loop is slow.\n",
"    times = df.index\n",
"    opens = df['open'].to_numpy()\n",
"    highs = df['high'].to_numpy()\n",
"    lows = df['low'].to_numpy()\n",
"    closes = df['close'].to_numpy()\n",
"    uppers = df['upper_band'].to_numpy()\n",
"    lowers = df['lower_band'].to_numpy()\n",
"    atrs = df['atr'].to_numpy()\n",
"\n",
"    # --- Backtest loop ---\n",
"    position = 0  # 0: flat, 1: long, -1: short\n",
"    stop_loss_price = 0.0\n",
"    trades = []\n",
"\n",
"    def close_position(i, exit_price):\n",
"        \"\"\"Record the exit of the open trade (trades[-1]) at bar i.\"\"\"\n",
"        trade = trades[-1]\n",
"        trade['exit_time'] = times[i]\n",
"        trade['exit_price'] = exit_price\n",
"        if trade['type'] == 'long':\n",
"            trade['pnl'] = exit_price - trade['entry_price']\n",
"        else:\n",
"            trade['pnl'] = trade['entry_price'] - exit_price\n",
"\n",
"    print(\"开始执行均值回归回测...\")\n",
"    for i in range(1, len(df)):\n",
"        prev_close = closes[i - 1]\n",
"        current_close = closes[i]\n",
"\n",
"        if position == 0:\n",
"            # Short: close breaks above the previous bar's upper band from inside.\n",
"            if prev_close < uppers[i - 1] and current_close > uppers[i - 1]:\n",
"                position = -1\n",
"                stop_loss_price = current_close + stop_loss_multiplier * atrs[i]\n",
"                trades.append({'entry_time': times[i], 'type': 'short', 'entry_price': current_close, 'stop_loss': stop_loss_price})\n",
"            # Long: close breaks below the previous bar's lower band from inside.\n",
"            elif prev_close > lowers[i - 1] and current_close < lowers[i - 1]:\n",
"                position = 1\n",
"                stop_loss_price = current_close - stop_loss_multiplier * atrs[i]\n",
"                trades.append({'entry_time': times[i], 'type': 'long', 'entry_price': current_close, 'stop_loss': stop_loss_price})\n",
"\n",
"        elif position == 1:\n",
"            if lows[i] < stop_loss_price:\n",
"                # A gap below the stop fills at the (worse) open.\n",
"                close_position(i, min(opens[i], stop_loss_price))\n",
"                position = 0\n",
"            elif current_close > lowers[i]:\n",
"                close_position(i, current_close)\n",
"                position = 0\n",
"\n",
"        elif position == -1:\n",
"            if highs[i] > stop_loss_price:\n",
"                # A gap above the stop fills at the (worse) open.\n",
"                close_position(i, max(opens[i], stop_loss_price))\n",
"                position = 0\n",
"            elif current_close < uppers[i]:\n",
"                close_position(i, current_close)\n",
"                position = 0\n",
"\n",
"    print(\"回测结束。\")\n",
"\n",
"    completed_trades = [t for t in trades if 'exit_price' in t]\n",
"    return pd.DataFrame(completed_trades)\n",
"\n",
"\n",
"# --- 2. 性能统计函数 (与之前相同) ---\n",
"\n",
"def analyze_performance(trades_df: pd.DataFrame):\n",
"    \"\"\"\n",
"    Print a performance summary (Chinese labels) for a table of completed trades.\n",
"\n",
"    This redefines the analyze_performance from the earlier backtest cell; the\n",
"    definition executed last is the one in effect.\n",
"    Winners are trades with pnl > 0. Every other trade (pnl <= 0, breakeven\n",
"    included) counts as a loser. Ratios with a zero denominator print as inf.\n",
"\n",
"    Args:\n",
"        trades_df (pd.DataFrame): one row per trade with a numeric 'pnl' column.\n",
"\n",
"    Returns:\n",
"        None; the summary is printed. An empty frame only prints a notice.\n",
"    \"\"\"\n",
"    if trades_df.empty:\n",
"        print(\"没有完成的交易记录,无法进行分析。\")\n",
"        return\n",
"\n",
"    pnl = trades_df['pnl']\n",
"    wins = pnl[pnl > 0]\n",
"    losses = pnl[pnl <= 0]\n",
"    n_trades = len(trades_df)\n",
"\n",
"    win_rate = len(wins) / n_trades\n",
"    avg_profit = wins.mean() if len(wins) > 0 else 0\n",
"    avg_loss = losses.mean() if len(losses) > 0 else 0\n",
"    profit_loss_ratio = float('inf') if avg_loss == 0 else abs(avg_profit / avg_loss)\n",
"\n",
"    gross_loss = abs(losses.sum())\n",
"    profit_factor = wins.sum() / gross_loss if gross_loss > 0 else float('inf')\n",
"\n",
"    summary_lines = [\n",
"        (\"总交易次数 :\", f\"{n_trades}\"),\n",
"        (\"盈利交易次数 :\", f\"{len(wins)}\"),\n",
"        (\"亏损交易次数 :\", f\"{len(losses)}\"),\n",
"        (\"胜率 :\", f\"{win_rate:.2%}\"),\n",
"        (\"平均每次盈利 :\", f\"{avg_profit:.2f}\"),\n",
"        (\"平均每次亏损 :\", f\"{avg_loss:.2f}\"),\n",
"        (\"盈亏比 :\", f\"{profit_loss_ratio:.2f}\"),\n",
"        (\"总盈亏 :\", f\"{pnl.sum():.2f}\"),\n",
"        (\"盈利因子 :\", f\"{profit_factor:.2f}\"),\n",
"    ]\n",
"\n",
"    print(\"\\n--- 策略性能摘要 ---\")\n",
"    for label, value in summary_lines:\n",
"        print(f\"{label} {value}\")\n",
"    print(\"----------------------\\n\")\n",
"\n",
"\n",
"# --- 3. Main execution ---\n",
"\n",
"# Strategy parameters (multipliers are in ATR units).\n",
"MA_PERIOD = 120\n",
"ATR_PERIOD = 14\n",
"ATR_MULTIPLIER = 1.0  # channel half-width\n",
"STOP_LOSS_MULTIPLIER = 1  # stop-loss distance from entry\n",
"\n",
"mean_reversion_params = dict(\n",
"    ma_period=MA_PERIOD,\n",
"    atr_period=ATR_PERIOD,\n",
"    atr_multiplier=ATR_MULTIPLIER,\n",
"    stop_loss_multiplier=STOP_LOSS_MULTIPLIER,\n",
")\n",
"trades_log = run_backtest_mean_reversion(df_raw, **mean_reversion_params)\n",
"\n",
"# First few trades, then the summary statistics.\n",
"print(trades_log.head().to_string())\n",
"analyze_performance(trades_log)"
],
"id": "5b6b9340f8783522",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"开始执行均值回归回测...\n",
"回测结束。\n",
" entry_time type entry_price stop_loss exit_time exit_price pnl\n",
"0 2024-01-16 09:00:00 short 1880.0 1891.771265 2024-01-16 09:15:00 1876.000000 4.000000\n",
"1 2024-01-16 09:45:00 short 1882.0 1893.647037 2024-01-16 10:00:00 1893.647037 -11.647037\n",
"2 2024-01-18 09:15:00 short 1882.0 1896.720622 2024-01-18 10:00:00 1875.000000 7.000000\n",
"3 2024-01-18 11:00:00 short 1883.0 1896.979547 2024-01-18 14:15:00 1881.000000 2.000000\n",
"4 2024-01-18 14:30:00 short 1894.0 1907.606606 2024-01-18 14:45:00 1907.606606 -13.606606\n",
"\n",
"--- 策略性能摘要 ---\n",
"总交易次数 : 368\n",
"盈利交易次数 : 207\n",
"亏损交易次数 : 161\n",
"胜率 : 56.25%\n",
"平均每次盈利 : 6.79\n",
"平均每次亏损 : -10.39\n",
"盈亏比 : 0.65\n",
"总盈亏 : -266.73\n",
"盈利因子 : 0.84\n",
"----------------------\n",
"\n"
]
}
],
"execution_count": 36
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}