Files
NewQuant/grid_search_multi_process.ipynb
2025-07-10 15:07:31 +08:00

368 lines
22 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "782ec73f",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from datetime import datetime\n",
"import itertools\n",
"from typing import Dict, Any, List, Tuple, Optional\n",
"import multiprocessing # 导入 multiprocessing 模块\n",
"import math # 保留 math 导入,因为您的策略内部可能需要用到数学函数\n",
"\n",
"# 导入所有必要的模块\n",
"# 请确保这些导入路径与您的项目结构相符\n",
"from src.analysis.grid_search_analyzer import GridSearchAnalyzer\n",
"from src.analysis.result_analyzer import ResultAnalyzer\n",
"from src.common_utils import generate_parameter_range\n",
"from src.data_manager import DataManager\n",
"from src.backtest_engine import BacktestEngine\n",
"# 导入策略类\n",
"from src.strategies.SimpleLimitBuyStrategy import SimpleLimitBuyStrategyShort, SimpleLimitBuyStrategyLong, SimpleLimitBuyStrategy\n",
"\n",
"\n",
"import builtins\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"origin_print = print\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "76f9a2e9",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# --- Single backtest task ---\n",
"# This function runs in its own worker process, so it must be self-contained.\n",
"def run_single_backtest(\n",
"    combination: Tuple[float, float],  # parameter pair for this run\n",
"    common_config: Dict[str, Any],  # settings shared by every run (data path, capital, ...)\n",
") -> Optional[Dict[str, Any]]:\n",
"    \"\"\"\n",
"    Run one backtest for a single parameter combination.\n",
"\n",
"    Args:\n",
"        combination: ``(param1_value, param2_value)`` for this run.\n",
"        common_config: Settings shared by all runs. Keys read up front:\n",
"            ``symbol``, ``data_path``, ``initial_capital``, ``slippage_rate``,\n",
"            ``commission_rate``, ``start_time``, ``end_time``, ``roll_over_mode``,\n",
"            ``param1_name`` and ``param2_name``. ``strategy`` (the strategy class)\n",
"            is read when the engine is built.\n",
"\n",
"    Returns:\n",
"        On success, the strategy parameters merged with the dict returned by\n",
"        ``ResultAnalyzer.calculate_all_metrics()``. If the engine produced no\n",
"        portfolio snapshots, the parameters plus zeroed headline metrics and an\n",
"        ``error`` entry. If the run raised, the parameters plus ``error`` and\n",
"        ``traceback`` entries.\n",
"\n",
"    Raises:\n",
"        KeyError: If one of the keys read up front is missing from\n",
"            ``common_config``. Exceptions raised while constructing the\n",
"            ``DataManager`` also propagate, because it is built outside the\n",
"            ``try`` block.\n",
"    \"\"\"\n",
"    import traceback\n",
"\n",
"    p1_value, p2_value = combination\n",
"\n",
"    # Pull the shared settings out of common_config.\n",
"    symbol = common_config['symbol']\n",
"    data_path = common_config['data_path']\n",
"    initial_capital = common_config['initial_capital']\n",
"    slippage_rate = common_config['slippage_rate']\n",
"    commission_rate = common_config['commission_rate']\n",
"    start_time = common_config['start_time']\n",
"    end_time = common_config['end_time']\n",
"    roll_over_mode = common_config['roll_over_mode']\n",
"    param1_name = common_config['param1_name']\n",
"    param2_name = common_config['param2_name']\n",
"\n",
"    # Every call builds its own DataManager and BacktestEngine, so no data or\n",
"    # simulation state is shared between worker processes.\n",
"    data_manager = DataManager(\n",
"        file_path=data_path,\n",
"        symbol=symbol,\n",
"    )\n",
"\n",
"    strategy_parameters = {\n",
"        'trade_volume': 1,\n",
"        param1_name: p1_value,\n",
"        param2_name: p2_value,\n",
"        'max_position': 10,\n",
"        'enable_log': False,  # strategy-level logging is switched off for grid searches\n",
"    }\n",
"\n",
"    try:\n",
"        engine = BacktestEngine(\n",
"            data_manager=data_manager,\n",
"            strategy_class=common_config['strategy'],\n",
"            strategy_params=strategy_parameters,\n",
"            initial_capital=initial_capital,\n",
"            slippage_rate=slippage_rate,\n",
"            commission_rate=commission_rate,\n",
"            # Honour the caller's setting; this used to be hard-coded to True,\n",
"            # which silently ignored common_config['roll_over_mode'].\n",
"            roll_over_mode=roll_over_mode,\n",
"            start_time=start_time,\n",
"            end_time=end_time,\n",
"        )\n",
"        # The time range was already handed to the engine constructor above.\n",
"        engine.run_backtest()\n",
"\n",
"        results = engine.get_backtest_results()\n",
"        portfolio_snapshots = results[\"portfolio_snapshots\"]\n",
"        trade_history = results[\"trade_history\"]\n",
"        bars = results[\"all_bars\"]\n",
"        initial_capital_result = results[\"initial_capital\"]\n",
"\n",
"        if not portfolio_snapshots:\n",
"            print(f\" 组合 {strategy_parameters} 没有生成投资组合快照,无法进行结果分析。(PID: {multiprocessing.current_process().pid})\")\n",
"            # Return the parameters with zeroed metrics so the combination can still be traced.\n",
"            return {**strategy_parameters, \"total_return\": 0.0, \"annualized_return\": 0.0, \"sharpe_ratio\": 0.0, \"max_drawdown\": 0.0, \"error\": \"No portfolio snapshots\"}\n",
"\n",
"        analyzer = ResultAnalyzer(portfolio_snapshots, trade_history, bars, initial_capital_result)\n",
"        metrics = analyzer.calculate_all_metrics()\n",
"\n",
"        # One flat record: the parameters used plus every metric computed for them.\n",
"        return {**strategy_parameters, **metrics}\n",
"    except Exception as e:\n",
"        error_trace = traceback.format_exc()\n",
"        print(f\" 组合 {strategy_parameters} 运行失败: {e}\\n{error_trace} (PID: {multiprocessing.current_process().pid})\")\n",
"        # Report the failure as data so one bad combination does not abort the whole search.\n",
"        return {**strategy_parameters, \"error\": str(e), \"traceback\": error_trace}\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c0984689",
"metadata": {},
"outputs": [],
"source": [
"\n",
"def slient_print(*args, **kwargs):\n",
"    \"\"\"No-op stand-in for the built-in ``print``.\n",
"\n",
"    Accepts and discards any positional and keyword arguments (``sep``, ``end``,\n",
"    ``file``, ``flush``), so code that calls ``print`` with keywords does not\n",
"    raise ``TypeError`` while ``builtins.print`` is replaced by this function.\n",
"\n",
"    Returns:\n",
"        None, like ``print``.\n",
"    \"\"\"\n",
"    return None\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "239e9ca0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"总计 1681 种参数组合需要回测。\n",
"--- 启动多进程网格搜索,使用 10 个进程 ---\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[4]\u001b[39m\u001b[32m, line 69\u001b[39m\n\u001b[32m 63\u001b[39m args_for_starmap = [\n\u001b[32m 64\u001b[39m (combo, common_config_for_processes) \u001b[38;5;28;01mfor\u001b[39;00m combo \u001b[38;5;129;01min\u001b[39;00m param_combinations\n\u001b[32m 65\u001b[39m ]\n\u001b[32m 67\u001b[39m \u001b[38;5;66;03m# 使用 starmap() 来并行执行 run_single_backtest 函数\u001b[39;00m\n\u001b[32m 68\u001b[39m \u001b[38;5;66;03m# starmap 是阻塞的,会等待所有任务完成并返回结果列表\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m69\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m i, result_entry \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(\u001b[43mpool\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstarmap\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrun_single_backtest\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs_for_starmap\u001b[49m\u001b[43m)\u001b[49m):\n\u001b[32m 70\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m result_entry: \u001b[38;5;66;03m# 确保结果不为空\u001b[39;00m\n\u001b[32m 71\u001b[39m all_results.append(result_entry)\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/quant/lib/python3.12/multiprocessing/pool.py:375\u001b[39m, in \u001b[36mPool.starmap\u001b[39m\u001b[34m(self, func, iterable, chunksize)\u001b[39m\n\u001b[32m 369\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mstarmap\u001b[39m(\u001b[38;5;28mself\u001b[39m, func, iterable, chunksize=\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[32m 370\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m'''\u001b[39;00m\n\u001b[32m 371\u001b[39m \u001b[33;03m Like `map()` method but the elements of the `iterable` are expected to\u001b[39;00m\n\u001b[32m 372\u001b[39m \u001b[33;03m be iterables as well and will be unpacked as arguments. Hence\u001b[39;00m\n\u001b[32m 373\u001b[39m \u001b[33;03m `func` and (a, b) becomes func(a, b).\u001b[39;00m\n\u001b[32m 374\u001b[39m \u001b[33;03m '''\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m375\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_map_async\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43miterable\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstarmapstar\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mchunksize\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/quant/lib/python3.12/multiprocessing/pool.py:768\u001b[39m, in \u001b[36mApplyResult.get\u001b[39m\u001b[34m(self, timeout)\u001b[39m\n\u001b[32m 767\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mget\u001b[39m(\u001b[38;5;28mself\u001b[39m, timeout=\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[32m--> \u001b[39m\u001b[32m768\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mwait\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 769\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m.ready():\n\u001b[32m 770\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTimeoutError\u001b[39;00m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/quant/lib/python3.12/multiprocessing/pool.py:765\u001b[39m, in \u001b[36mApplyResult.wait\u001b[39m\u001b[34m(self, timeout)\u001b[39m\n\u001b[32m 764\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mwait\u001b[39m(\u001b[38;5;28mself\u001b[39m, timeout=\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[32m--> \u001b[39m\u001b[32m765\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_event\u001b[49m\u001b[43m.\u001b[49m\u001b[43mwait\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/quant/lib/python3.12/threading.py:655\u001b[39m, in \u001b[36mEvent.wait\u001b[39m\u001b[34m(self, timeout)\u001b[39m\n\u001b[32m 653\u001b[39m signaled = \u001b[38;5;28mself\u001b[39m._flag\n\u001b[32m 654\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m signaled:\n\u001b[32m--> \u001b[39m\u001b[32m655\u001b[39m signaled = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_cond\u001b[49m\u001b[43m.\u001b[49m\u001b[43mwait\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 656\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m signaled\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/quant/lib/python3.12/threading.py:355\u001b[39m, in \u001b[36mCondition.wait\u001b[39m\u001b[34m(self, timeout)\u001b[39m\n\u001b[32m 353\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m: \u001b[38;5;66;03m# restore state no matter what (e.g., KeyboardInterrupt)\u001b[39;00m\n\u001b[32m 354\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m timeout \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m355\u001b[39m \u001b[43mwaiter\u001b[49m\u001b[43m.\u001b[49m\u001b[43macquire\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 356\u001b[39m gotit = \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[32m 357\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n",
"\u001b[31mKeyboardInterrupt\u001b[39m: "
]
}
],
"source": [
"\n",
"# --- Main execution block ---\n",
"# Entry point of the multi-process grid search. Multiprocessing code has to sit\n",
"# inside the 'if __name__ == \"__main__\":' guard.\n",
"if __name__ == \"__main__\":\n",
"    # --- Global configuration ---\n",
"    data_file_path = \"/mnt/d/PyProject/NewQuant/data/data/KQ_m@CZCE_MA/KQ_m@CZCE_MA_min60.csv\"\n",
"    initial_capital = 100000.0\n",
"    slippage_rate = 0.0000\n",
"    commission_rate = 0.0002\n",
"    global_config = {\n",
"        'symbol': 'KQ_m@CZCE_MA',\n",
"    }\n",
"    # NOTE(review): named (and originally commented) as the rebar \"RB\" tick size,\n",
"    # but the symbol above is CZCE MA and nothing in this cell reads the value.\n",
"    # Confirm the intended contract before relying on it.\n",
"    RB_TICK_SIZE = 1.0\n",
"\n",
"    # --- Parameter grid ---\n",
"    param1_name = \"open_range_factor_1_ago\"\n",
"    param1_values = generate_parameter_range(start=-2, end=2, step=0.1)\n",
"    param2_name = \"open_range_factor_7_ago\"\n",
"    param2_values = generate_parameter_range(start=-2, end=2, step=0.1)\n",
"    optimization_metric = 'sharpe_ratio'\n",
"\n",
"    # Cartesian product of the two parameter ranges.\n",
"    param_combinations = list(itertools.product(param1_values, param2_values))\n",
"    total_combinations = len(param_combinations)\n",
"    print(f\"总计 {total_combinations} 种参数组合需要回测。\")\n",
"\n",
"    all_results: List[Dict[str, Any]] = []\n",
"    grid_results: List[Dict[str, Any]] = []\n",
"\n",
"    # Settings shared by all runs; each task receives this dict next to its parameter pair.\n",
"    common_config_for_processes = {\n",
"        'symbol': global_config['symbol'],\n",
"        'data_path': data_file_path,\n",
"        'initial_capital': initial_capital,\n",
"        'slippage_rate': slippage_rate,\n",
"        'commission_rate': commission_rate,\n",
"        'start_time': datetime(2022, 1, 1),  # backtest start\n",
"        'end_time': datetime(2025, 1, 1),  # backtest end\n",
"        'roll_over_mode': True,  # keep contract roll-over enabled\n",
"        'param1_name': param1_name,\n",
"        'param2_name': param2_name,\n",
"        'optimization_metric': optimization_metric,\n",
"        'strategy': SimpleLimitBuyStrategyLong\n",
"    }\n",
"\n",
"    # Use half of the reported CPU count, but never fewer than one worker.\n",
"    num_processes = max(1, multiprocessing.cpu_count() // 2)\n",
"\n",
"    print(f\"--- 启动多进程网格搜索,使用 {num_processes} 个进程 ---\")\n",
"\n",
"    # starmap expects one argument tuple per task.\n",
"    args_for_starmap = [\n",
"        (combo, common_config_for_processes) for combo in param_combinations\n",
"    ]\n",
"\n",
"    # print is silenced while the pool runs. The try/finally guarantees the real\n",
"    # print is put back even when the run is interrupted (Ctrl+C / kernel\n",
"    # interrupt) or a worker raises; without it the notebook stays mute afterwards.\n",
"    builtins.print = slient_print\n",
"    try:\n",
"        with multiprocessing.Pool(processes=num_processes) as pool:\n",
"            # starmap blocks until every task has finished and returns the list of results.\n",
"            pool_results = pool.starmap(run_single_backtest, args_for_starmap)\n",
"    finally:\n",
"        builtins.print = origin_print\n",
"\n",
"    for result_entry in pool_results:\n",
"        if not result_entry:  # skip empty results\n",
"            continue\n",
"        all_results.append(result_entry)\n",
"        grid_row = {\n",
"            param1_name: result_entry.get(param1_name),\n",
"            param2_name: result_entry.get(param2_name),\n",
"        }\n",
"        if 'error' not in result_entry:\n",
"            grid_row[optimization_metric] = result_entry.get(optimization_metric, 0.0)\n",
"        else:\n",
"            # Failed combinations get -inf so they are easy to recognise in the grid.\n",
"            grid_row[optimization_metric] = float('-inf')\n",
"            grid_row['error_message'] = result_entry['error']\n",
"        grid_results.append(grid_row)\n",
"\n",
"    print(\"\\n--- 网格搜索回测完毕 ---\")\n",
"\n",
"    # --- 5. Post-processing and best-result selection ---\n",
"    if all_results:\n",
"        results_df = pd.DataFrame(all_results)\n",
"        # DataFrame.info() prints its summary itself and returns None, so it is not wrapped in print().\n",
"        results_df.info()\n",
"\n",
"        # Keep rows whose metric is a real number other than -inf. When every\n",
"        # combination raised, the metric column does not exist at all; that case\n",
"        # is treated as \"no successful rows\" instead of failing with KeyError.\n",
"        if optimization_metric in results_df.columns:\n",
"            metric_values = pd.to_numeric(results_df[optimization_metric], errors='coerce')\n",
"            successful_results_df = results_df[\n",
"                metric_values.notna() & (metric_values != float('-inf'))\n",
"            ].copy()  # .copy() avoids SettingWithCopyWarning on the assignment below\n",
"            successful_results_df[optimization_metric] = pd.to_numeric(\n",
"                successful_results_df[optimization_metric], errors='coerce'\n",
"            )\n",
"        else:\n",
"            successful_results_df = results_df.iloc[0:0].copy()\n",
"\n",
"        if not successful_results_df.empty:\n",
"            # Candidates must have more than 200 trades and a positive total return.\n",
"            # The mask is built from the filtered frame itself so that its index\n",
"            # matches the rows being selected; if either column is missing there\n",
"            # is simply no candidate.\n",
"            if {'total_trades', 'total_return'}.issubset(successful_results_df.columns):\n",
"                candidate_mask = (\n",
"                    (successful_results_df['total_trades'] > 200)\n",
"                    & (successful_results_df['total_return'] > 0)\n",
"                )\n",
"                normal_results = successful_results_df[candidate_mask]\n",
"            else:\n",
"                normal_results = successful_results_df.iloc[0:0]\n",
"\n",
"            if len(normal_results) > 0:\n",
"                best_result = normal_results.loc[normal_results[optimization_metric].idxmax()]\n",
"                print(\"\\n--- 最优参数组合 (按夏普比率) ---\")\n",
"                print(best_result)\n",
"            else:\n",
"                print('ERROR!!!!!!!!!!!!!!!!!!!!')\n",
"\n",
"            # Timestamped name for a CSV export; the export itself is currently disabled.\n",
"            output_filename = f\"grid_search_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv\"\n",
"\n",
"            # Pivot the grid: param1 on the rows, param2 on the columns, metric as the values.\n",
"            grid_df = pd.DataFrame(grid_results)\n",
"            grid_df[optimization_metric] = pd.to_numeric(grid_df[optimization_metric], errors='coerce')\n",
"\n",
"            pivot_table = grid_df.pivot_table(\n",
"                index=param1_name, columns=param2_name, values=optimization_metric\n",
"            )\n",
"            print(f\"\\n{optimization_metric} 网格结果 (Pivoted):\")\n",
"            print(pivot_table.to_string())\n",
"        else:\n",
"            print(f\"\\n没有成功的组合结果可供分析或优化指标 '{optimization_metric}' 不在结果中,或所有组合均失败。\")\n",
"    else:\n",
"        print(\"没有可用的回测结果。\")\n",
"    print(\"\\n--- 动态网格搜索完成 ---\")\n",
"\n",
"    # --- 6. Visualisation (relies on GridSearchAnalyzer) ---\n",
"    if grid_results:\n",
"        grid_analyzer = GridSearchAnalyzer(grid_results, optimization_metric)\n",
"        grid_analyzer.find_best_parameters()  # per the original note: finds and prints the best parameters\n",
"        grid_analyzer.plot_heatmap()  # per the original note: draws the heat map\n",
"    else:\n",
"        print(\"\\n没有生成任何网格搜索结果无法进行分析。\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "quant",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}