简单波动率策略,实现+网格搜索

This commit is contained in:
2025-06-22 23:03:50 +08:00
parent 355e451aac
commit a81a32ce73
19 changed files with 115435 additions and 748 deletions

1012
data/ analysis/Volume.ipynb Normal file

File diff suppressed because one or more lines are too long

14
data/tqsdk/test.py Normal file
View File

@@ -0,0 +1,14 @@
from tqsdk import TqApi, TqAuth
api = TqApi(auth=TqAuth("emanresu", "dfgvfgdfgg"))
# au 品种指数合约
ls = api.query_quotes(ins_class="INDEX", product_id="au")
print(ls)
# au 品种主连合约
ls = api.query_quotes(ins_class="CONT", product_id="au")
print(ls)
# 关闭api,释放相应资源
api.close()

View File

@@ -25,7 +25,7 @@ def collect_and_save_tqsdk_data_stream(
output_dir: str = "../data",
tq_user: str = TQ_USER_NAME,
tq_pwd: str = TQ_PASSWORD
) -> pd.DataFrame or None:
) -> pd.DataFrame:
"""
通过 TqSdk 在指定模式下回测或模拟运行监听并收集指定品种、频率、日期范围的K线数据流
并将其保存到本地CSV文件。此函数会模拟 TqSdk 的时间流运行。
@@ -190,10 +190,10 @@ if __name__ == "__main__":
# 示例1: 在回测模式下获取沪深300指数主连的日线数据 (用于历史回测)
# 这种方式适合获取相对较短或中等长度的历史K线数据。
df_if_backtest_daily = collect_and_save_tqsdk_data_stream(
symbol="SHFE.rb2410",
freq="min60",
start_date_str="2024-05-01",
end_date_str="2024-09-01",
symbol="KQ.i@SHFE.rb",
freq="day",
start_date_str="2023-01-01",
end_date_str="2025-05-01",
mode="backtest", # 指定为回测模式
tq_user=TQ_USER_NAME,
tq_pwd=TQ_PASSWORD

12953
grid_search.ipynb Normal file

File diff suppressed because it is too large Load Diff

1334
main.ipynb

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -9,7 +9,7 @@ from ..core_data import PortfolioSnapshot, Trade, Bar
def calculate_metrics(
snapshots: List[PortfolioSnapshot], trades: List[Trade], initial_capital: float
snapshots: List[PortfolioSnapshot], trades: List[Trade], initial_capital: float
) -> Dict[str, Any]:
"""
纯函数:根据投资组合快照和交易历史计算关键绩效指标。
@@ -124,11 +124,27 @@ def calculate_metrics(
"亏损交易次数": losing_count,
"平均每次盈利": avg_profit_per_trade,
"平均每次亏损": avg_loss_per_trade, # 这个值是负数
"InitialCapital": initial_capital,
"FinalCapital": final_value,
"TotalReturn": total_return,
"AnnualizedReturn": annualized_return,
"MaxDrawdown": max_drawdown,
"SharpeRatio": sharpe_ratio,
"CalmarRatio": calmar_ratio,
"TotalTrades": len(trades), # All buy and sell trades
"TransactionCosts": total_commissions,
"TotalRealizedPNL": total_realized_pnl, # New
"WinRate": win_rate,
"ProfitLossRatio": profit_loss_ratio,
"WinningTradesCount": winning_count,
"LosingTradesCount": losing_count,
"AvgProfitPerTrade": avg_profit_per_trade,
"AvgLossPerTrade": avg_loss_per_trade, # This value is negative
}
def plot_equity_and_drawdown_chart(snapshots: List[PortfolioSnapshot], initial_capital: float,
title: str = "Portfolio Equity and Drawdown Curve") -> None:
title: str = "Portfolio Equity and Drawdown Curve") -> None:
"""
Plots the portfolio equity curve and drawdown. X-axis points are equally spaced.
@@ -145,7 +161,7 @@ def plot_equity_and_drawdown_chart(snapshots: List[PortfolioSnapshot], initial_c
{'datetime': s.datetime, 'total_value': s.total_value}
for s in snapshots
])
equity_curve = df_equity['total_value'] / initial_capital
rolling_max = equity_curve.cummax()
@@ -203,7 +219,7 @@ def plot_close_price_chart(bars: List[Bar], title: str = "Close Price Chart") ->
])
plt.style.use('seaborn-v0_8-darkgrid')
fig, ax = plt.subplots(1, 1, figsize=(14, 7)) # Single subplot
fig, ax = plt.subplots(1, 1, figsize=(14, 7)) # Single subplot
x_axis_indices = np.arange(len(df_prices))
@@ -228,7 +244,7 @@ def plot_close_price_chart(bars: List[Bar], title: str = "Close Price Chart") ->
# 辅助函数:计算单笔交易的盈亏
def calculate_trade_pnl(
trade: Trade, entry_price: float, exit_price: float, direction: str
trade: Trade, entry_price: float, exit_price: float, direction: str
) -> float:
if direction == "LONG":
pnl = (exit_price - entry_price) * trade.volume

View File

@@ -0,0 +1,90 @@
# src/grid_search_analyzer.py
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Dict, Any, Tuple
class GridSearchAnalyzer:
"""
用于分析和可视化网格搜索结果的类。
"""
def __init__(self, grid_results: List[Dict[str, Any]], optimization_metric: str):
"""
初始化 GridSearchAnalyzer。
Args:
grid_results (List[Dict[str, Any]]): 包含每个参数组合及其对应优化指标的列表。
例如:[{'param1': v1, 'param2': v2, 'metric': m1}, ...]
optimization_metric (str): 用于优化的指标名称(例如 'total_return')。
"""
if not grid_results:
raise ValueError("grid_results 列表不能为空。")
if optimization_metric not in grid_results[0]:
raise ValueError(f"优化指标 '{optimization_metric}' 不在 grid_results 的字典中。")
self.grid_results = grid_results
self.optimization_metric = optimization_metric
self.param_names = [k for k in grid_results[0].keys() if k != optimization_metric]
if len(self.param_names) != 2:
raise ValueError("GridSearchAnalyzer 当前只支持分析两个参数的网格搜索结果。")
self.param1_name = self.param_names[0]
self.param2_name = self.param_names[1]
def find_best_parameters(self) -> Dict[str, Any]:
"""
找到网格搜索中表现最佳的参数组合和其指标值。
"""
if not self.grid_results:
return {}
best_result = max(self.grid_results, key=lambda x: x[self.optimization_metric])
print(f"\n--- 最佳参数组合 ---")
print(f" {self.param1_name}: {best_result[self.param1_name]}")
print(f" {self.param2_name}: {best_result[self.param2_name]}")
print(f" {self.optimization_metric}: {best_result[self.optimization_metric]:.4f}")
return best_result
def plot_heatmap(self, title: str = "heatmap"):
"""
绘制两个参数的热力图。
"""
if not self.grid_results:
print("没有数据用于绘制热力图。")
return
x_values = sorted(list(set(d[self.param1_name] for d in self.grid_results)))
y_values = sorted(list(set(d[self.param2_name] for d in self.grid_results)))
heatmap_matrix = np.zeros((len(y_values), len(x_values)))
# 填充网格
for item in self.grid_results:
x_idx = x_values.index(item[self.param1_name])
y_idx = y_values.index(item[self.param2_name])
heatmap_matrix[y_idx, x_idx] = item[self.optimization_metric]
fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(heatmap_matrix, cmap='viridis', origin='lower',
extent=[x_values[0], x_values[-1], y_values[0], y_values[-1]],
aspect='auto')
ax.set_xticks(x_values)
ax.set_yticks(y_values)
ax.set_xlabel(self.param1_name)
ax.set_ylabel(self.param2_name)
ax.set_title(f"{title}: {self.optimization_metric} vs. {self.param1_name} & {self.param2_name}")
cbar = fig.colorbar(im, ax=ax)
cbar.set_label(self.optimization_metric)
# 在每个格子上显示数值
for i in range(len(y_values)):
for j in range(len(x_values)):
text = ax.text(x_values[j], y_values[i], f"{heatmap_matrix[i, j]:.2f}",
ha="center", va="center", color="w", fontsize=8)
plt.tight_layout()
plt.show()

View File

@@ -1,10 +1,16 @@
# src/analysis/result_analyzer.py
import pandas as pd
import numpy as np
from typing import List, Dict, Any, Optional
# 导入纯函数 (注意相对导入路径的变化)
from .analysis_utils import calculate_metrics, plot_equity_and_drawdown_chart, plot_close_price_chart
from .analysis_utils import (
calculate_metrics,
plot_equity_and_drawdown_chart,
plot_close_price_chart,
)
# 导入核心数据类 (注意相对导入路径的变化)
from ..core_data import PortfolioSnapshot, Trade, Bar
@@ -14,11 +20,13 @@ class ResultAnalyzer:
结果分析器:负责接收回测数据,并提供分析和可视化方法。
"""
def __init__(self,
portfolio_snapshots: List[PortfolioSnapshot],
trade_history: List[Trade],
bars: List[Bar],
initial_capital: float):
def __init__(
self,
portfolio_snapshots: List[PortfolioSnapshot],
trade_history: List[Trade],
bars: List[Bar],
initial_capital: float,
):
"""
Args:
portfolio_snapshots (List[PortfolioSnapshot]): 回测引擎输出的投资组合快照列表。
@@ -41,9 +49,7 @@ class ResultAnalyzer:
if self._metrics_cache is None:
print("正在计算绩效指标...")
self._metrics_cache = calculate_metrics(
self.portfolio_snapshots,
self.trade_history,
self.initial_capital
self.portfolio_snapshots, self.trade_history, self.initial_capital
)
print("绩效指标计算完成。")
return self._metrics_cache
@@ -69,7 +75,8 @@ class ResultAnalyzer:
print("\n--- 部分交易明细 (最近5笔) ---")
for trade in self.trade_history[-5:]:
print(
f" {trade.fill_time} | {trade.direction:<10} | {trade.symbol} | Vol: {trade.volume} | Price: {trade.price:.2f} | Commission: {trade.commission:.2f}")
f" {trade.fill_time} | {trade.direction:<10} | {trade.symbol} | Vol: {trade.volume} | Price: {trade.price:.2f} | Commission: {trade.commission:.2f}"
)
else:
print("\n没有交易记录。")
@@ -79,10 +86,13 @@ class ResultAnalyzer:
"""
print("正在绘制绩效图表...")
# plot_performance_chart(self.portfolio_snapshots, self.initial_capital, self.bars)
plot_equity_and_drawdown_chart(self.portfolio_snapshots, self.initial_capital,
title="Portfolio Equity and Drawdown Curve")
plot_equity_and_drawdown_chart(
self.portfolio_snapshots,
self.initial_capital,
title="Portfolio Equity and Drawdown Curve",
)
# 绘制单独的收盘价曲线
plot_close_price_chart(self.bars, title="Underlying Asset Close Price")
print("图表绘制完成。")
print("图表绘制完成。")

View File

@@ -58,6 +58,16 @@ class BacktestContext:
# 可以在这里触发策略的on_trade回调如果策略定义了
return trade
def cancel_order(self, order_id: str) -> bool:
"""
策略通过此方法发出交易订单。
"""
if self._current_bar is None:
raise RuntimeError("当前Bar未设置无法发送订单。")
return self._simulator.cancel_order(order_id)
def get_current_positions(self) -> Dict[str, int]:
"""
获取当前模拟器的持仓情况。

View File

@@ -36,6 +36,7 @@ class BacktestEngine:
commission_rate (float): 交易佣金率。
"""
self.data_manager = data_manager
self.initial_capital = initial_capital
self.simulator = ExecutionSimulator(
initial_capital=initial_capital,
slippage_rate=slippage_rate,
@@ -76,6 +77,7 @@ class BacktestEngine:
# 设置当前Bar到Context供策略访问
self.context.set_current_bar(current_bar)
self.simulator.update_time(current_time=current_bar.datetime)
# 更新历史Bar缓存
self._history_bars.append(current_bar)
@@ -129,7 +131,7 @@ class BacktestEngine:
for symbol_held, quantity in positions_to_close.items():
if quantity != 0:
print(f"[{last_processed_bar.datetime}] 回测结束平仓: 平仓 {symbol_held} ({quantity} 手) @ {last_processed_bar.close:.2f}")
direction = "SELL" if quantity > 0 else "BUY"
direction = "CLOSE_LONG" if quantity > 0 else "CLOSE_SELL"
volume = abs(quantity)
# 使用当前合约的最后一根Bar的价格进行平仓
@@ -148,6 +150,17 @@ class BacktestEngine:
print(f"总计处理了 {len(self.portfolio_snapshots)} 根K线。")
print(f"总计发生了 {len(self.trade_history)} 笔交易。")
final_portfolio_value = 0.0
if last_processed_bar:
final_portfolio_value = self.simulator.get_portfolio_value(last_processed_bar)
else: # 如果数据为空,或者回测根本没跑,则净值为初始资金
final_portfolio_value = self.initial_capital
total_return_percentage = ((final_portfolio_value - self.initial_capital) / self.initial_capital) * 100
print(f"最终总净值: {final_portfolio_value:.2f}")
print(f"总收益率: {total_return_percentage:.2f}%")
def get_backtest_results(self) -> Dict[str, Any]:
"""
返回回测结果数据,供结果分析模块使用。

189
src/backtest_runner.py Normal file
View File

@@ -0,0 +1,189 @@
# src/backtest_runner.py
import pandas as pd
from datetime import datetime
from typing import List, Dict, Any, Optional, Type
# 导入核心组件
from .backtest_engine import BacktestEngine
from .data_manager import DataManager
from .execution_simulator import (
ExecutionSimulator,
) # 虽然不直接用,但 BacktestEngine 内部会用到
from .strategies.base_strategy import Strategy # 用于类型提示
from .core_data import Bar, PortfolioSnapshot, Trade, Order # 导入数据结构
def run_multi_segment_backtest(
segment_configs: Dict[str, Dict[str, Any]],
strategy_class: Type[Strategy],
strategy_params: Dict[str, Any],
initial_capital: float,
slippage_rate: float,
commission_rate: float,
total_start_date: datetime,
total_end_date: datetime,
) -> Dict[str, Any]:
"""
运行一个多主力合约的分段回测,并返回合并后的回测结果。
Args:
segment_configs (Dict[str, Dict[str, Any]]): 包含所有合约片段配置的字典。
strategy_class (Type[Strategy]): 要回测的策略类。
strategy_params (Dict[str, Any]): 传递给策略的参数字典。
initial_capital (float): 回测的初始资金。
slippage_rate (float): 交易滑点率。
commission_rate (float): 交易佣金率。
total_start_date (datetime): 总回测的起始日期。
total_end_date (datetime): 总回测的结束日期。
Returns:
Dict[str, Any]: 包含合并后的 portfolio_snapshots, trade_history, all_bars
和 initial_capital 的字典。
"""
# --- 收集所有片段的回测结果 ---
all_combined_snapshots: List[PortfolioSnapshot] = []
all_combined_trades: List[Trade] = []
all_combined_bars: List[Bar] = []
# 用于净值平滑的基准值,从初始资金开始
last_segment_adjusted_total_value = initial_capital
current_cash = initial_capital # 用于传递给下一个片段的初始资金
current_positions: Dict[str, int] = {} # 用于传递给下一个片段的初始持仓
print("\n--- 开始分段回测跨越多个合约 ---")
# 获取所有合约的有序键,确保按定义的顺序回测
segment_keys = list(segment_configs.keys())
for i, contract_symbol in enumerate(segment_keys):
config = segment_configs[contract_symbol]
segment_file = config["file"]
segment_start_date = config["start"]
segment_end_date = config["end"]
current_segment_start = max(segment_start_date, total_start_date)
current_segment_end = min(segment_end_date, total_end_date)
if current_segment_start > current_segment_end:
print(f"跳过片段 {contract_symbol}: 日期范围超出总回测周期或无效。")
continue
print(
f"\n--- 回测片段: {contract_symbol}{current_segment_start.strftime('%Y-%m-%d %H:%M')}{current_segment_end.strftime('%Y-%m-%d %H:%M')} ---"
)
# 1. 初始化 DataManager (每个片段使用不同的 DataManager 实例)
data_manager = DataManager(file_path=segment_file, symbol=contract_symbol)
if data_manager.raw_df.empty:
print(
f"警告: 未能加载 {contract_symbol} 的数据文件 {segment_file}。跳过此片段。"
)
continue
strategy_params["symbol"] = contract_symbol
# 2. 创建 BacktestEngine 实例 (针对当前合约片段)
engine = BacktestEngine(
data_manager=data_manager,
strategy_class=strategy_class,
strategy_params=strategy_params, # 使用传入的策略参数
initial_capital=initial_capital,
slippage_rate=slippage_rate,
commission_rate=commission_rate,
current_segment_symbol=contract_symbol,
)
# 在运行当前片段之前,将上一个片段结束时的资金和持仓注入到当前 BacktestEngine 的 simulator 中
engine.get_simulator().reset(
new_initial_capital=current_cash, new_initial_positions=current_positions
)
# 3. 运行当前片段回测
engine.run_backtest()
segment_results = engine.get_backtest_results()
# 4. 获取当前片段结束后的最新资金和持仓状态,供下个片段使用
current_cash = engine.get_simulator().cash
current_positions = engine.get_simulator().positions.copy()
# 5. 拼接结果:净值曲线需要特殊处理
segment_snapshots = segment_results["portfolio_snapshots"]
segment_trades = segment_results["trade_history"]
segment_bars = segment_results["all_bars"]
if segment_snapshots:
first_snapshot_value_of_segment = segment_snapshots[0].total_value
if first_snapshot_value_of_segment == 0:
segment_relative_factor = 1.0
else:
segment_relative_factor = (
last_segment_adjusted_total_value / first_snapshot_value_of_segment
)
for snapshot in segment_snapshots:
adjusted_snapshot_value = snapshot.total_value * segment_relative_factor
adjusted_snapshot = PortfolioSnapshot(
datetime=snapshot.datetime,
total_value=adjusted_snapshot_value,
cash=snapshot.cash,
positions=snapshot.positions.copy(),
price_at_snapshot=snapshot.price_at_snapshot.copy(),
)
all_combined_snapshots.append(adjusted_snapshot)
last_segment_adjusted_total_value = all_combined_snapshots[-1].total_value
else:
print(
f"警告: 未能为 {contract_symbol} 生成任何快照。`last_segment_adjusted_total_value` 保持不变: {last_segment_adjusted_total_value:.2f}"
)
all_combined_trades.extend(segment_trades)
all_combined_bars.extend(segment_bars)
print("\n--- 总分段回测完成 ---")
return {
"portfolio_snapshots": all_combined_snapshots,
"trade_history": all_combined_trades,
"all_bars": all_combined_bars,
"initial_capital": initial_capital, # 返回初始资金,方便 ResultAnalyzer 使用
}
def run_backtest_for_optimization(
global_config: Dict[str, Any],
strategy_class: Type[Strategy],
current_strategy_params: Dict[str, Any],
) -> Dict[str, Any]:
"""
简化后的回测入口,用于优化或参数搜索。
它从 global_config 中提取通用回测参数,并结合 current_strategy_params 运行回测。
Args:
global_config (Dict[str, Any]): 包含所有通用全局配置的字典。
strategy_class (Type[Strategy]): 要回测的策略类。
current_strategy_params (Dict[str, Any]): 当前迭代的策略参数。
Returns:
Dict[str, Any]: 包含合并后的 portfolio_snapshots, trade_history, all_bars
和 initial_capital 的字典。
"""
total_start_date = global_config["total_start_date"]
total_end_date = global_config["total_end_date"]
initial_capital = global_config["initial_capital"]
slippage_rate = global_config["slippage_rate"]
commission_rate = global_config["commission_rate"]
segment_configs = global_config["segment_configs"]
return run_multi_segment_backtest(
segment_configs=segment_configs,
strategy_class=strategy_class,
strategy_params=current_strategy_params,
initial_capital=initial_capital,
slippage_rate=slippage_rate,
commission_rate=commission_rate,
total_start_date=total_start_date,
total_end_date=total_end_date,
)

36
src/common_utils.py Normal file
View File

@@ -0,0 +1,36 @@
# src/common_utils.py
from typing import List, Union
def generate_parameter_range(start: Union[int, float], end: Union[int, float], step: Union[int, float]) -> List[Union[int, float]]:
"""
根据开始、结束和步长生成一个参数值的列表。
支持整数和浮点数。
Args:
start (Union[int, float]): 参数范围的起始值。
end (Union[int, float]): 参数范围的结束值。
step (Union[int, float]): 参数的步长。
Returns:
List[Union[int, float]]: 生成的参数值列表。
"""
if step == 0:
raise ValueError("步长不能为0。")
if (end - start) * step < 0: # 步长方向与范围不符
raise ValueError("步长方向与起始/结束值不一致,请检查步长正负。")
param_range = []
current_value = start
while (step > 0 and current_value <= end) or (step < 0 and current_value >= end):
param_range.append(current_value)
current_value += step
# 针对浮点数精度问题进行小幅调整,避免无限循环或过早停止
if isinstance(step, float) or isinstance(start, float) or isinstance(end, float):
current_value = round(current_value, 10) # 四舍五入到一定小数位
return param_range
# 示例:
# print(generate_parameter_range(0.99, 1.01, 0.005)) # [0.99, 0.995, 1.0, 1.005, 1.01]
# print(generate_parameter_range(5, 20, 5)) # [5, 10, 15, 20]

View File

@@ -1,5 +1,5 @@
# src/execution_simulator.py (修改部分)
from datetime import datetime
from typing import Dict, List, Optional
import pandas as pd
from .core_data import Order, Trade, Bar, PortfolioSnapshot
@@ -36,35 +36,111 @@ class ExecutionSimulator:
self.commission_rate = commission_rate
self.trade_log: List[Trade] = [] # 存储所有成交记录
self.pending_orders: Dict[str, Order] = {} # {order_id: Order_object}
self._current_time = None
print(
f"模拟器初始化:初始资金={self.initial_capital:.2f}, 滑点率={self.slippage_rate}, 佣金率={self.commission_rate}")
if self.positions:
print(f"初始持仓:{self.positions}")
def update_time(self, current_time: datetime):
"""
更新模拟器的当前时间。
这个方法由 BacktestEngine 在遍历 K 线时调用。
"""
self._current_time = current_time
# --- 新增的公共方法 ---
def get_current_time(self) -> datetime:
"""
获取模拟器的当前时间。
"""
if self._current_time is None:
# 可以在这里抛出错误或者返回一个默认值,取决于你对未初始化时间的处理
# 抛出错误可以帮助你发现问题,例如在模拟器时间未设置时就尝试获取
# raise RuntimeError("Simulator time has not been set. Ensure update_time is called.")
return None
return self._current_time
def _calculate_fill_price(self, order: Order, current_bar: Bar) -> float:
"""
内部方法:根据订单类型和滑点计算实际成交价格。
简化处理:市价单以当前Bar收盘价为基准考虑滑点。
- 市价单通常以当前K线的开盘价成交考虑滑点
- 限价单判断是否触及限价,如果触及,以限价成交(考虑滑点)。
"""
base_price = current_bar.close # 简化为收盘价成交
fill_price = -1.0 # 默认未成交
# 考虑滑点
if order.direction in ["BUY", "CLOSE_SHORT"]: # 买入或平空,价格向上偏离
fill_price = base_price * (1 + self.slippage_rate)
elif order.direction in ["SELL", "CLOSE_LONG"]: # 卖出或平多,价格向下偏离
fill_price = base_price * (1 - self.slippage_rate)
else: # 默认情况,无滑点
fill_price = base_price
if order.price_type == "MARKET":
# 市价单通常以开盘价成交,或者根据你的策略需求选择收盘价
# 这里我们仍然使用开盘价作为市价单的基准成交价
base_price = current_bar.open
# 如果是限价单且成交价格不满足条件,则可能不成交
if order.price_type == "LIMIT" and order.limit_price is not None:
# 对于BUY和CLOSE_SHORT成交价必须 <= 限价
if (order.direction == "BUY" or order.direction == "CLOSE_SHORT") and fill_price > order.limit_price:
return -1.0 # 未触及限价
# 对于SELL和CLOSE_LONG成交价必须 >= 限价
elif (order.direction == "SELL" or order.direction == "CLOSE_LONG") and fill_price < order.limit_price:
return -1.0 # 未触及限价
if order.direction == "BUY" or order.direction == "CLOSE_SHORT": # 买入或平空,价格向上偏离
fill_price = base_price * (1 + self.slippage_rate)
elif order.direction == "SELL" or order.direction == "CLOSE_LONG": # 卖出或平多,价格向下偏离
fill_price = base_price * (1 - self.slippage_rate)
else: # 默认情况,理论上不应该到这里,因为方向应该明确
fill_price = base_price # 不考虑滑点
# 市价单只要有价格就会成交,无需额外价格区间判断
elif order.price_type == "LIMIT" and order.limit_price is not None:
limit_price = order.limit_price
high = current_bar.high
low = current_bar.low
open_price = current_bar.open # 也可以在限价单成交时用开盘价作为实际成交价,或限价本身
if order.direction == "BUY" or order.direction == "CLOSE_SHORT": # 限价买入或限价平空
# 买入如果K线最低价 <= 限价,则限价单可能成交
if low <= limit_price:
# 成交价通常是限价,但考虑滑点后可能略高(对买方不利)
# 或者可以以 open/low 之间的一个价格成交,这里简化为限价+滑点
fill_price_candidate = limit_price * (1 + self.slippage_rate)
# 确保成交价不会比当前K线的最低价还低如果不是在最低点成交
# 也可以简单就返回 limit_price * (1 + self.slippage_rate)
# 如果开盘价低于或等于限价,那么通常会以开盘价成交(或者略差)
# 否则,如果价格是从上方跌落到限价区,那么会在限价附近成交
# 这里简化如果K线触及限价则以限价成交考虑滑点
# 更精细的模拟会考虑价格穿越顺序 (例如,是否开盘就跳过了限价)
# 如果开盘价已经低于或等于限价,那么就以开盘价成交(考虑滑点)
if open_price <= limit_price:
fill_price = open_price * (1 + self.slippage_rate)
else: # 价格从高处跌落到限价
fill_price = limit_price * (1 + self.slippage_rate)
# 确保成交价不会超过限价 (虽然加滑点可能会略超,但这是交易成本的一部分)
# 这个检查是为了避免逻辑错误,理论上加滑点后应该可以接受比限价略高的价格
# 如果你严格要求成交价不能高于限价,则需要移除滑点或者更复杂的逻辑
# 这里我们接受加滑点后的价格
if fill_price > high and fill_price > limit_price: # 如果计算出来的成交价高于K线最高价则可能不合理
fill_price = -1.0 # 理论上不该发生,除非滑点过大
else:
return -1.0 # 未触及限价
elif order.direction == "SELL" or order.direction == "CLOSE_LONG": # 限价卖出或限价平多
# 卖出如果K线最高价 >= 限价,则限价单可能成交
if high >= limit_price:
# 成交价通常是限价,但考虑滑点后可能略低(对卖方不利)
fill_price_candidate = limit_price * (1 - self.slippage_rate)
# 如果开盘价已经高于或等于限价,那么就以开盘价成交(考虑滑点)
if open_price >= limit_price:
fill_price = open_price * (1 - self.slippage_rate)
else: # 价格从低处上涨到限价
fill_price = limit_price * (1 - self.slippage_rate)
# 确保成交价不会低于限价
if fill_price < low and fill_price < limit_price: # 如果计算出来的成交价低于K线最低价则可能不合理
fill_price = -1.0 # 理论上不该发生
else:
return -1.0 # 未触及限价
# 最后检查成交价是否有效
if fill_price <= 0:
return -1.0 # 如果计算出来价格无效,返回未成交
return fill_price
@@ -107,6 +183,7 @@ class ExecutionSimulator:
if fill_price <= 0: # 表示未成交或不满足限价条件
if order.price_type == "LIMIT":
self.pending_orders[order.id] = order
# print(f'撮合失败,order id:{order.id},fill_price:{fill_price}')
return None # 未成交返回None
# --- 以下是订单成功成交的逻辑 ---
@@ -116,78 +193,90 @@ class ExecutionSimulator:
current_position = self.positions.get(symbol, 0)
current_average_cost = self.average_costs.get(symbol, 0.0)
if order.direction == "BUY":
actual_direction = order.direction
if order.direction == "CLOSE_SHORT":
actual_direction = "BUY"
elif order.direction == "CLOSE_LONG":
actual_direction = "SELL"
is_close_order_intent = (order.direction == "CLOSE_LONG" or
order.direction == "CLOSE_SHORT")
if actual_direction == "BUY": # 处理买入 (开多 / 平空)
# 开多仓或平空仓
if current_position >= 0: # 当前持有多仓或无仓位 (开多)
is_open_trade = True
is_open_trade = not is_close_order_intent # 如果是平仓意图,则不是开仓交易
# 更新平均成本 (加权平均)
new_total_cost = (current_average_cost * current_position) + (fill_price * volume)
new_total_volume = current_position + volume
self.average_costs[symbol] = new_total_cost / new_total_volume if new_total_volume > 0 else 0.0
self.positions[symbol] = new_total_volume
else: # 当前持有空仓 (平空)
is_close_trade = True
is_close_trade = is_close_order_intent # 这是平仓交易
# 计算平空盈亏
# PnL = (开仓成本 - 平仓价格) * 平仓数量 (注意空头方向)
# 简化:假设平空时,直接使用当前的平均开仓成本来计算盈亏
# 更精确的FIFO/LIFO需更多逻辑
pnl_per_share = current_average_cost - fill_price # (买入平空,成本高于平仓价则盈利)
pnl_per_share = current_average_cost - fill_price
realized_pnl = pnl_per_share * volume
# 更新持仓和成本
self.positions[symbol] += volume
if self.positions[symbol] == 0:
if self.positions[symbol] == 0: # 如果全部平仓
del self.positions[symbol]
if symbol in self.average_costs: del self.average_costs[symbol] # 清理成本
elif self.positions[symbol] > 0 and current_position < 0: # 部分平空转为多头,需重新设置成本
# 这部分逻辑可以更复杂这里简化处理如果转为多头成本重置为0
# 实际应该用剩余的空头成本 + 新开多的成本
self.average_costs[symbol] = fill_price # 简单地将剩下的多头仓位成本设为当前价格
if symbol in self.average_costs: del self.average_costs[symbol]
elif self.positions[symbol] > 0 and current_position < 0: # 部分平空,且空头仓位被买平为多头仓位
# 这是从空头转为多头的复杂情况。需要重新计算平均成本
# 简单处理:将剩余的多头仓位成本设为当前价格
self.average_costs[symbol] = fill_price
# 资金扣除
if self.cash >= trade_value + commission:
self.cash -= (trade_value + commission)
else:
# print(f"[{current_bar.datetime}] 资金不足,无法执行买入 {volume} {symbol}")
print(f"[{current_bar.datetime}] 资金不足,无法执行买入 {volume} {symbol}")
return None
elif order.direction == "SELL":
elif actual_direction == "SELL": # 处理卖出 (开空 / 平多)
# 开空仓或平多仓
if current_position <= 0: # 当前持有空仓或无仓位 (开空)
is_open_trade = True
is_open_trade = not is_close_order_intent # 如果是平仓意图,则不是开仓交易
# 更新平均成本 (空头成本为负值)
new_total_cost = (current_average_cost * current_position) - (fill_price * volume) # 负的持仓乘以负的卖出价
new_total_volume = current_position - volume # 空头持仓量更负
self.average_costs[symbol] = new_total_cost / new_total_volume if new_total_volume < 0 else 0.0
self.positions[symbol] = new_total_volume
# 对于空头,平均成本通常是指你卖出开仓的平均价格
# 这里需要根据你的空头成本计算方式来调整
# 常见的做法是:总卖出价值 / 总卖出数量
new_total_value = (current_average_cost * abs(current_position)) + (fill_price * volume)
new_total_volume = abs(current_position) + volume
self.average_costs[symbol] = new_total_value / new_total_volume if new_total_volume > 0 else 0.0
self.positions[symbol] -= volume # 空头数量增加,持仓量变为负更多
else: # 当前持有多仓 (平多)
is_close_trade = True
is_close_trade = is_close_order_intent # 这是平仓交易
# 计算平多盈亏
# PnL = (平仓价格 - 开仓成本) * 平仓数量
pnl_per_share = fill_price - current_average_cost # (卖出平多,平仓价高于成本则盈利)
pnl_per_share = fill_price - current_average_cost
realized_pnl = pnl_per_share * volume
# 更新持仓和成本
self.positions[symbol] -= volume
if self.positions[symbol] == 0:
if self.positions[symbol] == 0: # 如果全部平仓
del self.positions[symbol]
if symbol in self.average_costs: del self.average_costs[symbol] # 清理成本
elif self.positions[symbol] < 0 and current_position > 0: # 部分平多转为空头,需重新设置成本
self.average_costs[symbol] = fill_price # 简单地将剩下的空头仓位成本设为当前价格
if symbol in self.average_costs: del self.average_costs[symbol]
elif self.positions[symbol] < 0 and current_position > 0: # 部分平多,且多头仓位被卖平为空头仓位
# 从多头转为空头的复杂情况
self.average_costs[symbol] = fill_price # 简单将剩余空头仓位成本设为当前价格
# 资金扣除 (佣金) 和增加 (卖出收入)
if self.cash >= commission:
self.cash -= commission
self.cash += trade_value
else:
# print(f"[{current_bar.datetime}] 资金不足,无法执行卖出 {volume} {symbol}")
print(f"[{current_bar.datetime}] 资金不足,无法执行卖出 {volume} {symbol}")
return None
else: # 既不是 BUY 也不是 SELL且也不是 CANCEL。这可能是未知的 direction
print(
f"[{current_bar.datetime}] 模拟器: 收到未知订单方向 {order.direction} for Order ID: {order.id}. 订单未处理。")
return None
# 创建 Trade 对象
# 创建 Trade 对象
executed_trade = Trade(
order_id=order.id, fill_time=current_bar.datetime, symbol=symbol,
direction=order.direction, # 记录原始订单方向 (BUY/SELL)
direction=order.direction, # 记录原始订单方向 (BUY/SELL/CLOSE_X)
volume=volume, price=fill_price, commission=commission,
cash_after_trade=self.cash, positions_after_trade=self.positions.copy(),
realized_pnl=realized_pnl, # 填充实现盈亏
@@ -200,8 +289,6 @@ class ExecutionSimulator:
if order.id in self.pending_orders:
del self.pending_orders[order.id]
# print(f"[{current_bar.datetime}] 成交: {executed_trade.direction} {executed_trade.volume} {executed_trade.symbol} @ {executed_trade.price:.2f}, 佣金: {executed_trade.commission:.2f}, PnL: {executed_trade.realized_pnl:.2f}")
return executed_trade
def cancel_order(self, order_id: str) -> bool:
@@ -245,7 +332,7 @@ class ExecutionSimulator:
quantity = self.positions[symbol_in_position]
# 持仓市值 = 数量 * 当前市场价格 (current_bar.close)
# 无论多头(quantity > 0)还是空头(quantity < 0),这个计算都是正确的
total_value += quantity * current_bar.close
total_value += quantity * current_bar.open
# 您也可以选择在这里打印调试信息
# print(f" DEBUG Portfolio Value Calculation: Cash={self.cash:.2f}, "
@@ -287,3 +374,14 @@ class ExecutionSimulator:
"""
print("ExecutionSimulator: 清空交易历史。")
self.trade_history = []
def get_average_position_price(self, symbol: str) -> Optional[float]:
"""
获取指定合约的平均持仓成本。
如果无持仓或无该合约记录,返回 None。
"""
# 返回 average_costs 字典中对应 symbol 的值
# 如果没有持仓或者没有记录,返回 None
if symbol in self.positions and self.positions[symbol] != 0:
return self.average_costs.get(symbol)
return None

0
src/research/__init__.py Normal file
View File

99766
src/research/grid_search.ipynb Normal file

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,183 @@
# src/strategies/simple_limit_buy_strategy.py
from .base_strategy import Strategy
from ..core_data import Bar, Order
from typing import Optional, Dict, Any
from collections import deque
class SimpleLimitBuyStrategy(Strategy):
"""
一个基于当前K线Open、前1根和前7根K线Range计算优势价格进行限价买入的策略。
具备以下特点:
- 每根K线开始时取消上一根K线未成交的订单。
- 最多只能有一个开仓挂单和一个持仓。
- 包含简单的止损和止盈逻辑。
"""
def __init__(self, simulator: Any, symbol: str, enable_log: bool, trade_volume: int,
open_range_factor_1_ago: float,
open_range_factor_7_ago: float,
max_position: int,
stop_loss_points: float = 10, # 新增:止损点数
take_profit_points: float = 10): # 新增:止盈点数
"""
初始化策略。
Args:
simulator: 模拟器实例。
symbol (str): 交易合约代码。
trade_volume (int): 单笔交易量。
open_range_factor_1_ago (float): 前1根K线Range的权重因子用于从Open价向下偏移。
open_range_factor_7_ago (float): 前7根K线Range的权重因子用于从Open价向下偏移。
max_position (int): 最大持仓量此处为1因为只允许一个持仓
stop_loss_points (float): 止损点数(例如,亏损达到此点数则止损)。
take_profit_points (float): 止盈点数(例如,盈利达到此点数则止盈)。
"""
super().__init__(simulator, symbol, enable_log)
self.trade_volume = trade_volume
self.open_range_factor_1_ago = open_range_factor_1_ago
self.open_range_factor_7_ago = open_range_factor_7_ago
self.max_position = max_position # 理论上这里应为1
self.stop_loss_points = stop_loss_points
self.take_profit_points = take_profit_points
self.order_id_counter = 0
self._bar_history: deque[Bar] = deque(maxlen=7)
self._last_order_id: Optional[str] = None # 用于跟踪上一根K线发出的订单ID
self.log(f"策略初始化: symbol={self.symbol}, trade_volume={self.trade_volume}, "
f"open_range_factor_1_ago={self.open_range_factor_1_ago}, "
f"open_range_factor_7_ago={self.open_range_factor_7_ago}, "
f"max_position={self.max_position}, "
f"止损点={self.stop_loss_points}, 止盈点={self.take_profit_points}")
def on_bar(self, bar: Bar, next_bar_open: Optional[float] = None):
"""
每当新的K线数据到来时调用。
Args:
bar (Bar): 当前的K线数据对象。
next_bar_open (Optional[float]): 下一根K线的开盘价此处策略未使用。
"""
current_datetime = bar.datetime # 获取当前K线时间
# --- 1. 撤销上一根K线未成交的订单 ---
# 检查是否记录了上一笔订单ID并且该订单仍然在待处理列表中
if self._last_order_id:
pending_orders = self.get_pending_orders()
if self._last_order_id in pending_orders:
success = self.cancel_order(self._last_order_id) # 直接调用基类的取消方法
if success:
self.log(f"[{current_datetime}] 策略: 成功撤销上一根K线未成交订单 {self._last_order_id}")
else:
self.log(f"[{current_datetime}] 策略: 尝试撤销订单 {self._last_order_id} 失败(可能已成交或不存在)")
# 无论撤销成功与否,既然我们尝试了撤销,就清除记录
self._last_order_id = None
# else:
# self.log(f"[{current_datetime}] 策略: 无上一根K线未成交订单需要撤销。")
# 2. 更新K线历史
self._bar_history.append(bar)
trade_volume = self.trade_volume
# 获取当前持仓和未决订单(在取消之后获取,确保是最新的状态)
current_positions = self.get_current_positions()
current_pos_volume = current_positions.get(self.symbol, 0)
pending_orders_after_cancel = self.get_pending_orders() # 再次获取,此时应已取消旧订单
# --- 3. 平仓逻辑 (止损/止盈) ---
# 只有当有持仓时才考虑平仓
if current_pos_volume > 0: # 假设只做多,所以持仓量 > 0
avg_entry_price = self.get_average_position_price(self.symbol)
if avg_entry_price is not None:
pnl_per_unit = bar.close - avg_entry_price # 当前浮动盈亏(以收盘价计算)
# 止盈条件
if pnl_per_unit >= self.take_profit_points:
self.log(f"[{current_datetime}] 止盈信号 - PnL per unit: {pnl_per_unit:.2f}, 目标: {self.take_profit_points:.2f}")
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
self.order_id_counter += 1
# 创建一个限价多单
order = Order(
id=order_id,
symbol=self.symbol,
direction="CLOSE_LONG",
volume=trade_volume,
price_type="MARKET",
# limit_price=limit_price,
submitted_time=bar.datetime
)
trade = self.send_order(order)
return # 平仓后本K线不再进行开仓判断
# 止损条件
elif pnl_per_unit <= -self.stop_loss_points:
self.log(f"[{current_datetime}] 止损信号 - PnL per unit: {pnl_per_unit:.2f}, 目标: {-self.stop_loss_points:.2f}")
# 发送市价卖出订单平仓,确保立即成交
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
self.order_id_counter += 1
# 创建一个限价多单
order = Order(
id=order_id,
symbol=self.symbol,
direction="CLOSE_LONG",
volume=trade_volume,
price_type="MARKET",
# limit_price=limit_price,
submitted_time=bar.datetime
)
trade = self.send_order(order)
return # 平仓后本K线不再进行开仓判断
# --- 4. 开仓逻辑 (只考虑做多 BUY 方向) ---
# 只有在没有持仓 (current_pos_volume == 0) 且没有待处理订单 (not pending_orders_after_cancel)
# 且K线历史足够长时才考虑开仓
if current_pos_volume == 0 and \
len(self._bar_history) == self._bar_history.maxlen:
# 获取前1根K线 (倒数第二根) 和前7根K线 (队列中最老的一根)
bar_1_ago = self._bar_history[-2]
bar_7_ago = self._bar_history[0]
# 计算历史 K 线的 Range
range_1_ago = bar_1_ago.high - bar_1_ago.low
range_7_ago = bar_7_ago.high - bar_7_ago.low
# 根据策略逻辑计算目标买入价格
# 目标买入价 = 当前K线Open - (前1根Range * 因子1 + 前7根Range * 因子2)
self.log(bar.open ,range_1_ago * self.open_range_factor_1_ago, range_7_ago * self.open_range_factor_7_ago)
target_buy_price = bar.open - (range_1_ago * self.open_range_factor_1_ago + range_7_ago * self.open_range_factor_7_ago)
# 确保目标买入价格有效,例如不能是负数
target_buy_price = max(0.01, target_buy_price)
self.log(f"[{current_datetime}] 开多仓信号 - 当前Open={bar.open:.2f}, "
f"前1Range={range_1_ago:.2f}, 前7Range={range_7_ago:.2f}, "
f"计算目标买入价={target_buy_price:.2f}")
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
self.order_id_counter += 1
# 创建一个限价多单
order = Order(
id=order_id,
symbol=self.symbol,
direction="BUY",
volume=trade_volume,
price_type="LIMIT",
limit_price=target_buy_price,
submitted_time=bar.datetime
)
new_order = self.send_order(order)
# 记录下这个订单的ID以便在下一根K线开始时进行撤销
if new_order:
self._last_order_id = new_order.order_id
self.log(f"[{current_datetime}] 策略: 发送限价买入订单 {self._last_order_id} @ {target_buy_price:.2f}")
else:
self.log(f"[{current_datetime}] 策略: 发送订单失败。")
# else:
# self.log(f"[{current_datetime}] 不满足开仓条件:持仓={current_pos_volume}, 待处理订单={len(pending_orders_after_cancel)}, K线历史长度={len(self._bar_history)}")

View File

@@ -1,42 +1,31 @@
# src/strategies/base_strategy.py
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
from datetime import datetime
from typing import Dict, Any, Optional, List, TYPE_CHECKING
# 使用 TYPE_CHECKING 避免循环导入,但保留类型提示
from ..backtest_context import BacktestContext # 转发引用 BacktestEngine
from ..core_data import Bar, Order, Trade # 导入必要的类型
# 导入核心数据类
from ..core_data import Bar, Order, Trade
# 导入回测上下文 (注意相对导入路径的变化)
from ..backtest_context import BacktestContext
class Strategy(ABC):
"""
策略抽象基类。所有具体策略都应继承此类,并实现 on_bar 方法。
所有交易策略抽象基类。
策略通过 context 对象与回测引擎和模拟器进行交互,并提供辅助方法。
"""
def __init__(self, context: BacktestContext, **parameters: Any):
"""
初始化策略。
def __init__(self, context: 'BacktestContext', symbol: str, enable_log: bool = True, **params: Any):
"""
Args:
context (BacktestContext): 回测上下文对象,用于与模拟器和数据管理器交互
**parameters (Any): 策略所需的任何自定义参数
context (BacktestEngine): 回测引擎实例,作为策略的上下文,提供与模拟器等的交互接口
symbol (str): 策略操作的合约Symbol
**params (Any): 其他策略特定参数。
"""
self.context = context
self.parameters = parameters
self.symbol = parameters.get('symbol', "DEFAULT_SYMBOL") # 策略操作的品种
self.trade_volume = parameters.get('trade_volume', 100) # 每次下单的数量
print(f"策略初始化: {self.__class__.__name__},参数: {self.parameters}")
@abstractmethod
def on_bar(self, bar: Bar):
"""
每当接收到新的Bar数据时调用。
具体策略逻辑在此方法中实现。
Args:
bar (Bar): 当前的Bar数据对象。
features (Dict[str, float]): 由数据处理模块计算并传入的特征字典。
"""
pass
self.context = context # 存储 context 对象
self.symbol = symbol # 策略操作的合约Symbol
self.params = params
self.enable_log = enable_log
def on_init(self):
"""
@@ -45,7 +34,7 @@ class Strategy(ABC):
"""
print(f"{self.__class__.__name__} 策略初始化回调被调用。")
def on_trade(self, trade: Trade):
def on_trade(self, trade: 'Trade'):
"""
当模拟器成功执行一笔交易时调用。
可用于更新策略内部持仓状态或记录交易。
@@ -54,15 +43,94 @@ class Strategy(ABC):
trade (Trade): 已完成的交易记录。
"""
# print(f"策略接收到交易: {trade.direction} {trade.volume} {trade.symbol} @ {trade.price:.2f}")
pass # 默认不执行任何操作,具体策略可覆盖
pass # 默认不执行任何操作,具体策略可覆盖
def on_order_status(self, order: Order, status: str):
@abstractmethod
def on_bar(self, bar: 'Bar'):
"""
当订单状态更新时调用 (例如,未成交,已提交等)
在简易回测中,可能不会频繁使用。
每当新的K线数据到来时调用此方法
Args:
order (Order): 相关订单对象。
status (str): 订单状态(例如 "FILLED", "PENDING", "CANCELLED"
bar (Bar): 当前的K线数据对象。
next_bar_open (Optional[float]): 下一根K线的开盘价如果存在的话
"""
pass # 默认不执行任何操作
pass
# --- 新增/修改的辅助方法 ---
def send_order(self, order: 'Order') -> Optional[Trade]:
"""
发送订单的辅助方法。
会在 BaseStrategy 内部构建 Order 对象,并通过 context 转发给模拟器。
"""
return self.context.send_order(order)
def cancel_order(self, order_id: str) -> bool:
"""
取消指定ID的订单。
通过 context 调用模拟器的 cancel_order 方法。
"""
return self.context.cancel_order(order_id)
def cancel_all_pending_orders(self) -> int:
"""
取消所有当前策略的未决订单。
返回成功取消的订单数量。
"""
pending_orders = self.get_pending_orders() # 调用 BaseStrategy 自己的 get_pending_orders
cancelled_count = 0
orders_to_cancel = [order.id for order in pending_orders.values() if order.symbol == self.symbol]
for order_id in orders_to_cancel:
if self.cancel_order(order_id): # 调用 BaseStrategy 自己的 cancel_order
cancelled_count += 1
return cancelled_count
def get_current_positions(self) -> Dict[str, int]:
"""
获取当前持仓。
通过 context 调用模拟器的 get_positions 方法。
"""
return self.context._simulator.get_current_positions()
def get_pending_orders(self) -> Dict[str, 'Order']:
"""
获取当前所有待处理订单的副本。
通过 context 调用模拟器的 get_pending_orders 方法。
"""
return self.context._simulator.get_pending_orders()
def get_average_position_price(self, symbol: str) -> Optional[float]:
"""
获取指定合约的平均持仓成本。
通过 context 调用模拟器的 get_average_position_price 方法。
"""
return self.context._simulator.get_average_position_price(symbol)
# 你可以根据需要在这里添加更多辅助方法,如获取账户净值等
def get_account_cash(self) -> float:
"""获取当前账户现金余额。"""
return self.context._simulator.cash
def log(self, *args: Any, **kwargs: Any):
"""
统一的日志打印方法。
如果 enable_log 为 True则打印消息到控制台并包含当前模拟时间。
支持传入多个参数,像 print() 函数一样使用。
"""
if self.enable_log:
# 尝试获取当前模拟时间,如果模拟器或时间不可用,则跳过时间前缀
try:
current_time_str = self.context._simulator.get_current_time().strftime('%Y-%m-%d %H:%M:%S')
time_prefix = f"[{current_time_str}] "
except AttributeError:
# 如果获取不到时间(例如在策略初始化时,模拟器时间还未设置),则不加时间前缀
time_prefix = ""
# 使用 f-string 结合 *args 来构建消息
# print() 函数会将 *args 自动用空格分隔,这里我们模仿这个行为
message = ' '.join(map(str, args))
# 你可以将其他 kwargs (如 sep, end, file, flush) 传递给 print
# 但通常日志方法不会频繁使用这些。这里只支持最基础的打印。
print(f"{time_prefix}策略 ({self.symbol}): {message}", **kwargs)

View File

@@ -19,9 +19,10 @@ class SimpleLimitBuyStrategy(Strategy):
def __init__(self, context: Any, **parameters: Any): # context 类型提示可以为 BacktestContext
super().__init__(context, **parameters)
self.trade_volume = 1
self.order_id_counter = 0
self.limit_price_factor = parameters.get('limit_price_factor', 0.999) # 限价因子
self.max_position = parameters.get('max_position', self.trade_volume * 2) # 最大持仓量
self.max_position = parameters.get('max_position', self.trade_volume) # 最大持仓量
self._last_order_id: Optional[str] = None # 跟踪上一根Bar发出的订单ID
self._current_long_position: int = 0 # 策略内部维护的当前持仓
@@ -57,16 +58,14 @@ class SimpleLimitBuyStrategy(Strategy):
每接收到一根Bar时执行策略逻辑。
"""
current_portfolio_value = self.context.get_current_portfolio_value(bar)
print(f"[{bar.datetime}] Strategy processing Bar. Current close price: {bar.close:.2f}. Current Portfolio Value: {current_portfolio_value:.2f}")
# print(f"[{bar.datetime}] Strategy processing Bar. Current close price: {bar.close:.2f}. Current Portfolio Value: {current_portfolio_value:.2f}")
# 1. 撤销上一根K线未成交的订单
if self._last_order_id:
# 检查这个订单是否仍然在待处理订单列表中
pending_orders = self.context._simulator.get_pending_orders() # 直接访问模拟器或者通过context提供接口
pending_orders = self.get_pending_orders() # 直接访问模拟器或者通过context提供接口
if self._last_order_id in pending_orders:
success = self.context.send_order(Order(id=self._last_order_id, symbol=self.symbol,
direction="CANCEL", volume=0,
price_type="CANCEL")) # 使用一个特殊Order类型表示撤单
success = self.cancel_order(self._last_order_id)
# 这里发送的“撤单订单”会被simulator的send_order处理并调用simulator.cancel_order
if success: # simulator.send_order返回Trade或None这里我们用一个特殊处理
# Simulator的send_order返回的是Trade如果实现撤单最好Simulator的cancel_order返回bool
@@ -109,10 +108,10 @@ class SimpleLimitBuyStrategy(Strategy):
)
# 通过上下文发送订单
trade = self.context.send_order(order)
trade = self.send_order(order)
if trade:
print(
f"[{bar.datetime}] 策略: 发送并立即成交限价买单 {trade.volume} 股 @ {trade.price:.2f} (订单ID: {order.id})")
f"[{bar.datetime}] 策略: 发送并立即成交限价买单 {trade.volume} 股 @ {trade.price:.2f}(open:{bar.open}, close:{bar.close}) (订单ID: {order.id})")
# 如果立即成交_last_order_id 仍然保持 None
else:
# 如果未立即成交将订单ID记录下来以便下一根Bar撤销