实现简单单品种回测
This commit is contained in:
4
.gitignore
vendored
Normal file
4
.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
.idea/
|
||||
.vscode/
|
||||
data/data/
|
||||
**/__pycache__/
|
||||
202
data/tqsdk/tq_copy_data.py
Normal file
202
data/tqsdk/tq_copy_data.py
Normal file
@@ -0,0 +1,202 @@
|
||||
import traceback
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tqsdk import TqApi, TqAuth, TqBacktest, TqSim # 确保导入所有需要的回测/模拟类
|
||||
import os
|
||||
import datetime
|
||||
from datetime import date # 导入 datetime.date
|
||||
|
||||
# --- 配置您的天勤账号信息 ---
|
||||
# 请替换为您的实盘账号或模拟账号信息
|
||||
# 如果您没有天勤账号,可以注册并获取测试Token:https://www.shinnytech.com/tqsdk/doc/quickstart/
|
||||
TQ_USER_NAME = "emanresu" # 例如: "123456"
|
||||
TQ_PASSWORD = "dfgvfgdfgg" # 例如: "your_password"
|
||||
|
||||
BEIJING_TZ = 'Asia/Shanghai'
|
||||
|
||||
|
||||
def collect_and_save_tqsdk_data_stream(
|
||||
symbol: str,
|
||||
freq: str,
|
||||
start_date_str: str,
|
||||
end_date_str: str,
|
||||
mode: str = "backtest", # 默认为回测模式,因为获取历史数据通常用于回测
|
||||
output_dir: str = "../data",
|
||||
tq_user: str = TQ_USER_NAME,
|
||||
tq_pwd: str = TQ_PASSWORD
|
||||
) -> pd.DataFrame or None:
|
||||
"""
|
||||
通过 TqSdk 在指定模式下(回测或模拟)运行,监听并收集指定品种、频率、日期范围的K线数据流,
|
||||
并将其保存到本地CSV文件。此函数会模拟 TqSdk 的时间流运行。
|
||||
|
||||
Args:
|
||||
symbol (str): 交易品种代码,例如 "SHFE.rb2405", "KQ.i9999"。
|
||||
freq (str): 数据频率,例如 "1min", "5min", "day"。注意:tick数据量过大不推荐此方法直接收集。
|
||||
start_date_str (str): 数据流开始日期,格式 'YYYY-MM-DD'。
|
||||
end_date_str (str): 数据流结束日期,格式 'YYYY-MM-DD'。
|
||||
mode (str): 运行模式,可选 "sim" (模拟/实盘) 或 "backtest" (回测)。默认为 "backtest"。
|
||||
output_dir (str): 数据保存的根目录,默认为 "./data"。
|
||||
tq_user (str): 天勤量化账号。
|
||||
tq_pwd (str): 天勤量化密码。
|
||||
|
||||
Returns:
|
||||
pd.DataFrame or None: 收集到的K线数据DataFrame,如果获取失败则返回 None。
|
||||
请注意,对于非常大的数据量,直接返回DataFrame可能消耗大量内存。
|
||||
"""
|
||||
if not tq_user or not tq_pwd:
|
||||
print("错误: 请在代码中配置您的天勤量化账号和密码。")
|
||||
return None
|
||||
|
||||
api = None
|
||||
collected_data = [] # 用于收集每一根完整K线的数据
|
||||
|
||||
try:
|
||||
start_dt_data_obj = datetime.datetime.strptime(start_date_str, '%Y-%m-%d')
|
||||
end_dt_data_obj = datetime.datetime.strptime(end_date_str, '%Y-%m-%d')
|
||||
|
||||
if mode == "backtest":
|
||||
backtest_start_date = start_dt_data_obj.date()
|
||||
backtest_end_date = end_dt_data_obj.date()
|
||||
print(f"初始化天勤回测API,回测日期范围:{backtest_start_date} 至 {backtest_end_date}")
|
||||
api = TqApi(
|
||||
backtest=TqBacktest(start_dt=backtest_start_date, end_dt=backtest_end_date),
|
||||
auth=TqAuth(tq_user, tq_pwd)
|
||||
)
|
||||
elif mode == "sim":
|
||||
print("初始化天勤模拟/实盘API")
|
||||
api = TqApi(account=TqSim(), auth=TqAuth(tq_user, tq_pwd))
|
||||
# 如果您有实盘账户,可以使用:
|
||||
# api = TqApi(account=TqAccount(tq_user, tq_pwd), auth=TqAuth(tq_user, tq_pwd))
|
||||
else:
|
||||
print(f"错误: 不支持的模式 '{mode}'。请使用 'sim' 或 'backtest'。")
|
||||
return None
|
||||
|
||||
# K线数据获取的duration_seconds
|
||||
duration_seconds = 0
|
||||
if "min" in freq:
|
||||
duration_seconds = int(freq.replace("min", "")) * 60
|
||||
elif freq == "day":
|
||||
duration_seconds = 24 * 60 * 60
|
||||
elif freq == "week":
|
||||
duration_seconds = 7 * 24 * 60 * 60
|
||||
elif freq == "month":
|
||||
duration_seconds = 30 * 24 * 60 * 60 # 大约一个月
|
||||
else:
|
||||
print(f"错误: 不支持的数据频率 '{freq}'。目前支持 '1min', '5min', 'day', 'week', 'month'。")
|
||||
print("注意:Tick数据量巨大,不建议用此方法直接收集,因为它会耗尽内存。")
|
||||
return None
|
||||
|
||||
# 获取K线序列,这里获取的是指定频率的K线,天勤会根据模式从历史或实时流中推送
|
||||
klines = api.get_kline_serial(symbol, duration_seconds)
|
||||
|
||||
print(f"开始在 '{mode}' 模式下收集 {symbol} 从 {start_date_str} 到 {end_date_str} 的 {freq} 数据...")
|
||||
|
||||
last_kline_datetime = None # 用于跟踪上一根已完成K线的时间
|
||||
|
||||
while api.wait_update():
|
||||
|
||||
# 检查是否有新的完整K线生成,或者当前K线是最后一次更新 (在回测结束时)
|
||||
# TqSdk会在K线完成时发送最后一次更新,或者在回测结束时强制更新
|
||||
if api.is_changing(klines):
|
||||
# 只有当K线序列发生变化时才处理
|
||||
# 关注最新一根 K 线(即 klines.iloc[-1])
|
||||
current_kline = klines.iloc[-2]
|
||||
|
||||
# TqSdk 会在K线结束后,或者回测结束时,确保K线为最终状态。
|
||||
# 判断当前K线是否已经结束 (is_last=True) 并且与上一次保存的K线不同
|
||||
# 或者,在回测模式下,回测结束时,最后一根K线也会被视为“完成”
|
||||
# 判断条件:K线时间戳不是 None 且 大于上一次记录的 K线时间
|
||||
if not pd.isna(current_kline['datetime']) and (last_kline_datetime is None or (
|
||||
last_kline_datetime is not None and current_kline['datetime'] > last_kline_datetime)):
|
||||
# 将datetime (微秒) 转换为可读格式
|
||||
|
||||
# 检查K线的时间戳是否在我们要获取的日期范围内
|
||||
# 注意:get_kline_serial 会获取指定范围前后的一小段数据,我们需要过滤
|
||||
|
||||
kline_dt = pd.to_datetime(current_kline['datetime'], unit='ns', utc=True)
|
||||
kline_dt = kline_dt.tz_convert(BEIJING_TZ).strftime('%Y-%m-%d %H:%M:%S')
|
||||
kline_data_to_save = {
|
||||
'datetime': kline_dt,
|
||||
'open': current_kline['open'],
|
||||
'high': current_kline['high'],
|
||||
'low': current_kline['low'],
|
||||
'close': current_kline['close'],
|
||||
'volume': current_kline['volume'],
|
||||
'open_oi': current_kline['open_oi'],
|
||||
'close_oi': current_kline['close_oi']
|
||||
}
|
||||
|
||||
collected_data.append(kline_data_to_save)
|
||||
last_kline_datetime = current_kline['datetime']
|
||||
# print(f"收集到 K线: {kline_dt}, close: {current_kline['close']}") # 用于调试
|
||||
|
||||
# 在回测模式下,当回测结束时,api.wait_update() 会抛出异常,此时我们可以退出循环
|
||||
if api.is_changing(api.get_account()) or api.is_changing(api.get_position()):
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
# TqBacktest 在数据结束时会抛出 "api已关闭" 或类似的异常,这是正常现象。
|
||||
# 我们在这里捕获并判断是否是正常结束。
|
||||
if "api已关闭" in str(e) or "数据已全部输出" in str(e):
|
||||
print("数据流已结束 (TqSdk API 关闭或数据全部输出)。")
|
||||
else:
|
||||
print(f"数据收集过程中发生错误: {e}")
|
||||
traceback.print_exc()
|
||||
# 无论如何,都尝试处理剩余数据并保存
|
||||
finally:
|
||||
if collected_data:
|
||||
df = pd.DataFrame(collected_data).set_index('datetime')
|
||||
df = df.sort_index() # 确保数据按时间排序
|
||||
|
||||
# 构造保存路径
|
||||
freq_folder = freq.replace("min", "m") if "min" in freq else freq
|
||||
if freq == "day": freq_folder = "daily"
|
||||
if freq == "week": freq_folder = "weekly"
|
||||
if freq == "month": freq_folder = "monthly"
|
||||
|
||||
safe_symbol = symbol.replace('.', '_')
|
||||
|
||||
save_folder = os.path.join(output_dir, safe_symbol)
|
||||
os.makedirs(save_folder, exist_ok=True)
|
||||
|
||||
file_name = f"{safe_symbol}_{freq_folder}_{start_date_str.replace('-', '')}_{end_date_str.replace('-', '')}_{freq}.csv"
|
||||
file_path = os.path.join(save_folder, file_name)
|
||||
|
||||
df.to_csv(file_path, index=True)
|
||||
print(f"数据已成功保存到: {file_path}, 共 {len(df)} 条记录。")
|
||||
|
||||
if api:
|
||||
api.close()
|
||||
return df
|
||||
else:
|
||||
print("没有收集到任何数据。")
|
||||
if api:
|
||||
api.close()
|
||||
return None
|
||||
|
||||
# --- 示例用法 ---
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
current_dir = os.getcwd()
|
||||
print("当前工作目录:", current_dir)
|
||||
|
||||
# !!!重要:请先在这里替换成您的天勤账号和密码!!!
|
||||
# 否则程序无法运行。
|
||||
TQ_USER_NAME = "emanresu" # 例如: "123456"
|
||||
TQ_PASSWORD = "dfgvfgdfgg" # 例如: "your_password"
|
||||
|
||||
# 示例1: 在回测模式下获取沪深300指数主连的日线数据 (用于历史回测)
|
||||
# 这种方式适合获取相对较短或中等长度的历史K线数据。
|
||||
df_if_backtest_daily = collect_and_save_tqsdk_data_stream(
|
||||
symbol="SHFE.rb2501",
|
||||
freq="min1",
|
||||
start_date_str="2024-09-01",
|
||||
end_date_str="2024-12-01",
|
||||
mode="backtest", # 指定为回测模式
|
||||
tq_user=TQ_USER_NAME,
|
||||
tq_pwd=TQ_PASSWORD
|
||||
)
|
||||
if df_if_backtest_daily is not None:
|
||||
print(df_if_backtest_daily.tail())
|
||||
658
main.ipynb
Normal file
658
main.ipynb
Normal file
File diff suppressed because one or more lines are too long
77
main.py
Normal file
77
main.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# main.py
|
||||
|
||||
from src.analysis.result_analyzer import ResultAnalyzer
|
||||
# 导入所有必要的模块
|
||||
from src.data_manager import DataManager
|
||||
from src.backtest_engine import BacktestEngine
|
||||
from src.strategies.simple_limit_buy_strategy import SimpleLimitBuyStrategy
|
||||
|
||||
|
||||
def main():
|
||||
# --- 配置参数 ---
|
||||
# 获取当前脚本所在目录,假设数据文件在项目根目录下的 data 文件夹内
|
||||
data_file_path = '/mnt/d/PyProject/NewQuant/data/data/SHFE_rb2501/SHFE_rb2501_m60_20240901_20241201_min60.csv'
|
||||
|
||||
initial_capital = 100000.0
|
||||
slippage_rate = 0.001 # 假设每笔交易0.1%的滑点
|
||||
commission_rate = 0.0002 # 假设每笔交易0.02%的佣金
|
||||
|
||||
strategy_parameters = {
|
||||
'symbol': "SHFE_rb2501", # 根据您的数据文件中的品种名称调整
|
||||
'trade_volume': 1, # 每次交易1手/股
|
||||
'limit_price_factor': 0.995, # 限价单价格为开盘价的99.5%
|
||||
'max_position': 10 # 最大持仓10手/股
|
||||
}
|
||||
|
||||
# --- 1. 初始化数据管理器 ---
|
||||
print("初始化数据管理器...")
|
||||
data_manager = DataManager(file_path=data_file_path)
|
||||
# 确保 DataManager 能够重置以进行多次回测
|
||||
# data_manager.reset() # 首次运行不需要重置
|
||||
|
||||
# --- 2. 初始化回测引擎并运行 ---
|
||||
print("\n初始化回测引擎...")
|
||||
engine = BacktestEngine(
|
||||
data_manager=data_manager,
|
||||
strategy_class=SimpleLimitBuyStrategy,
|
||||
strategy_params=strategy_parameters,
|
||||
initial_capital=initial_capital,
|
||||
slippage_rate=slippage_rate,
|
||||
commission_rate=commission_rate
|
||||
)
|
||||
|
||||
print("\n开始运行回测...")
|
||||
engine.run_backtest()
|
||||
print("\n回测运行完毕。")
|
||||
|
||||
# --- 3. 获取回测结果 ---
|
||||
results = engine.get_backtest_results()
|
||||
portfolio_snapshots = results["portfolio_snapshots"]
|
||||
trade_history = results["trade_history"]
|
||||
initial_capital_result = results["initial_capital"]
|
||||
|
||||
# --- 4. 结果分析与可视化 ---
|
||||
if portfolio_snapshots:
|
||||
analyzer = ResultAnalyzer(portfolio_snapshots, trade_history, initial_capital_result)
|
||||
|
||||
analyzer.generate_report()
|
||||
analyzer.plot_performance()
|
||||
else:
|
||||
print("\n没有生成投资组合快照,无法进行结果分析。")
|
||||
|
||||
# --- 4. 结果分析与可视化 (待实现) ---
|
||||
# if portfolio_snapshots:
|
||||
# analyzer = ResultAnalyzer(portfolio_snapshots, trade_history, initial_capital_result)
|
||||
# metrics = analyzer.calculate_all_metrics()
|
||||
# print("\n--- 绩效指标 ---")
|
||||
# for key, value in metrics.items():
|
||||
# print(f" {key}: {value:.4f}")
|
||||
#
|
||||
# print("\n--- 绘制绩效图表 ---")
|
||||
# analyzer.plot_performance()
|
||||
# else:
|
||||
# print("\n没有生成投资组合快照,无法进行结果分析。")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
0
src/__init__.py
Normal file
0
src/__init__.py
Normal file
0
src/analysis/__init__.py
Normal file
0
src/analysis/__init__.py
Normal file
239
src/analysis/analysis_utils.py
Normal file
239
src/analysis/analysis_utils.py
Normal file
@@ -0,0 +1,239 @@
|
||||
# src/analysis/analysis_utils.py (修改 calculate_metrics 函数)
|
||||
import matplotlib
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from ..core_data import PortfolioSnapshot, Trade, Bar
|
||||
|
||||
|
||||
def calculate_metrics(
|
||||
snapshots: List[PortfolioSnapshot], trades: List[Trade], initial_capital: float
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
纯函数:根据投资组合快照和交易历史计算关键绩效指标。
|
||||
|
||||
Args:
|
||||
snapshots (List[PortfolioSnapshot]): 投资组合快照列表。
|
||||
trades (List[Trade]): 交易历史记录列表。
|
||||
initial_capital (float): 初始资金。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 包含各种绩效指标的字典。
|
||||
"""
|
||||
if not snapshots:
|
||||
return {
|
||||
"总收益率": 0.0,
|
||||
"年化收益率": 0.0,
|
||||
"最大回撤": 0.0,
|
||||
"夏普比率": 0.0,
|
||||
"卡玛比率": 0.0,
|
||||
"胜率": 0.0,
|
||||
"盈亏比": 0.0,
|
||||
"总交易次数": len(trades),
|
||||
"盈利交易次数": 0,
|
||||
"亏损交易次数": 0,
|
||||
"平均每次盈利": 0.0,
|
||||
"平均每次亏损": 0.0,
|
||||
"交易成本": 0.0,
|
||||
"总实现盈亏": 0.0,
|
||||
}
|
||||
|
||||
df_values = pd.DataFrame(
|
||||
[{"datetime": s.datetime, "total_value": s.total_value} for s in snapshots]
|
||||
).set_index("datetime")
|
||||
|
||||
df_returns = df_values["total_value"].pct_change().fillna(0)
|
||||
|
||||
final_value = df_values["total_value"].iloc[-1]
|
||||
total_return = (final_value / initial_capital) - 1
|
||||
|
||||
total_days = (df_values.index.max() - df_values.index.min()).days
|
||||
if total_days > 0:
|
||||
annualized_return = (1 + total_return) ** (252 / total_days) - 1
|
||||
else:
|
||||
annualized_return = 0.0
|
||||
|
||||
rolling_max = df_values["total_value"].cummax()
|
||||
daily_drawdown = (rolling_max - df_values["total_value"]) / rolling_max
|
||||
max_drawdown = daily_drawdown.max()
|
||||
|
||||
excess_daily_returns = df_returns
|
||||
daily_volatility = excess_daily_returns.std()
|
||||
|
||||
if daily_volatility > 0:
|
||||
sharpe_ratio = np.sqrt(252) * (excess_daily_returns.mean() / daily_volatility)
|
||||
else:
|
||||
sharpe_ratio = 0.0
|
||||
|
||||
if max_drawdown > 0:
|
||||
calmar_ratio = annualized_return / max_drawdown
|
||||
else:
|
||||
calmar_ratio = float("inf")
|
||||
|
||||
total_commissions = sum(t.commission for t in trades)
|
||||
|
||||
# --- 重新计算交易相关指标 ---
|
||||
realized_pnl_trades = [
|
||||
t.realized_pnl for t in trades if t.is_close_trade
|
||||
] # 只关注平仓交易的盈亏
|
||||
|
||||
winning_pnl = [pnl for pnl in realized_pnl_trades if pnl > 0]
|
||||
losing_pnl = [pnl for pnl in realized_pnl_trades if pnl < 0]
|
||||
|
||||
winning_count = len(winning_pnl)
|
||||
losing_count = len(losing_pnl)
|
||||
total_closed_trades = winning_count + losing_count
|
||||
|
||||
total_profit_per_trade = sum(winning_pnl)
|
||||
total_loss_per_trade = sum(losing_pnl) # sum of negative values
|
||||
|
||||
avg_profit_per_trade = (
|
||||
total_profit_per_trade / winning_count if winning_count > 0 else 0.0
|
||||
)
|
||||
avg_loss_per_trade = (
|
||||
total_loss_per_trade / losing_count if losing_count > 0 else 0.0
|
||||
) # 这是负值
|
||||
|
||||
win_rate = winning_count / total_closed_trades if total_closed_trades > 0 else 0.0
|
||||
|
||||
# 盈亏比 = 平均盈利 / 平均亏损的绝对值
|
||||
profit_loss_ratio = (
|
||||
abs(avg_profit_per_trade / avg_loss_per_trade)
|
||||
if avg_loss_per_trade != 0
|
||||
else float("inf")
|
||||
)
|
||||
|
||||
total_realized_pnl = sum(realized_pnl_trades)
|
||||
|
||||
return {
|
||||
"初始资金": initial_capital,
|
||||
"最终资金": final_value,
|
||||
"总收益率": total_return,
|
||||
"年化收益率": annualized_return,
|
||||
"最大回撤": max_drawdown,
|
||||
"夏普比率": sharpe_ratio,
|
||||
"卡玛比率": calmar_ratio,
|
||||
"总交易次数": len(trades), # 所有的买卖交易
|
||||
"交易成本": total_commissions,
|
||||
"总实现盈亏": total_realized_pnl, # 新增
|
||||
"胜率": win_rate,
|
||||
"盈亏比": profit_loss_ratio,
|
||||
"盈利交易次数": winning_count,
|
||||
"亏损交易次数": losing_count,
|
||||
"平均每次盈利": avg_profit_per_trade,
|
||||
"平均每次亏损": avg_loss_per_trade, # 这个值是负数
|
||||
}
|
||||
|
||||
|
||||
def plot_equity_and_drawdown_chart(snapshots: List[PortfolioSnapshot], initial_capital: float,
|
||||
title: str = "Portfolio Equity and Drawdown Curve") -> None:
|
||||
"""
|
||||
Plots the portfolio equity curve and drawdown. X-axis points are equally spaced.
|
||||
|
||||
Args:
|
||||
snapshots (List[PortfolioSnapshot]): List of portfolio snapshots.
|
||||
initial_capital (float): Initial capital.
|
||||
title (str): Title of the chart.
|
||||
"""
|
||||
if not snapshots:
|
||||
print("No portfolio snapshots available to plot equity and drawdown.")
|
||||
return
|
||||
|
||||
df_equity = pd.DataFrame([
|
||||
{'datetime': s.datetime, 'total_value': s.total_value}
|
||||
for s in snapshots
|
||||
])
|
||||
|
||||
equity_curve = df_equity['total_value'] / initial_capital
|
||||
|
||||
rolling_max = equity_curve.cummax()
|
||||
drawdown = (rolling_max - equity_curve) / rolling_max
|
||||
|
||||
plt.style.use('seaborn-v0_8-darkgrid')
|
||||
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10), sharex=True, gridspec_kw={'height_ratios': [3, 1]})
|
||||
|
||||
x_axis_indices = np.arange(len(df_equity))
|
||||
|
||||
# Equity Curve Plot
|
||||
ax1.plot(x_axis_indices, equity_curve, label='Equity Curve', color='blue', linewidth=1.5)
|
||||
ax1.set_ylabel('Equity', fontsize=12)
|
||||
ax1.legend(loc='upper left')
|
||||
ax1.grid(True)
|
||||
ax1.set_title(title, fontsize=16)
|
||||
|
||||
# Drawdown Curve Plot
|
||||
ax2.fill_between(x_axis_indices, 0, drawdown, color='red', alpha=0.3)
|
||||
ax2.plot(x_axis_indices, drawdown, color='red', linewidth=1.0, linestyle='--', label='Drawdown')
|
||||
ax2.set_ylabel('Drawdown Rate', fontsize=12)
|
||||
ax2.set_xlabel('Data Point Index (Date Labels Below)', fontsize=12)
|
||||
ax2.set_title('Portfolio Drawdown Curve', fontsize=14)
|
||||
ax2.legend(loc='upper left')
|
||||
ax2.grid(True)
|
||||
ax2.set_ylim(0, max(drawdown.max() * 1.1, 0.05))
|
||||
|
||||
# Set X-axis ticks to show actual dates at intervals
|
||||
num_ticks = 10
|
||||
if len(df_equity) > 0:
|
||||
tick_positions = np.linspace(0, len(df_equity) - 1, num_ticks, dtype=int)
|
||||
tick_labels = [df_equity['datetime'].iloc[i].strftime('%Y-%m-%d %H:%M') for i in tick_positions]
|
||||
ax1.set_xticks(tick_positions)
|
||||
ax1.set_xticklabels(tick_labels, rotation=45, ha='right')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
|
||||
def plot_close_price_chart(bars: List[Bar], title: str = "Close Price Chart") -> None:
|
||||
"""
|
||||
Plots the underlying asset's close price. X-axis points are equally spaced.
|
||||
|
||||
Args:
|
||||
bars (List[Bar]): List of all processed Bar data.
|
||||
title (str): Title of the chart.
|
||||
"""
|
||||
if not bars:
|
||||
print("No bar data available to plot close price.")
|
||||
return
|
||||
|
||||
df_prices = pd.DataFrame([
|
||||
{'datetime': b.datetime, 'close_price': b.close}
|
||||
for b in bars
|
||||
])
|
||||
|
||||
plt.style.use('seaborn-v0_8-darkgrid')
|
||||
fig, ax = plt.subplots(1, 1, figsize=(14, 7)) # Single subplot
|
||||
|
||||
x_axis_indices = np.arange(len(df_prices))
|
||||
|
||||
ax.plot(x_axis_indices, df_prices['close_price'], label='Close Price', color='orange', linewidth=1.5)
|
||||
ax.set_ylabel('Price', fontsize=12)
|
||||
ax.set_xlabel('Data Point Index (Date Labels Below)', fontsize=12)
|
||||
ax.set_title(title, fontsize=16)
|
||||
ax.legend(loc='upper left')
|
||||
ax.grid(True)
|
||||
|
||||
# Set X-axis ticks to show actual dates at intervals
|
||||
num_ticks = 10
|
||||
if len(df_prices) > 0:
|
||||
tick_positions = np.linspace(0, len(df_prices) - 1, num_ticks, dtype=int)
|
||||
tick_labels = [df_prices['datetime'].iloc[i].strftime('%Y-%m-%d %H:%M') for i in tick_positions]
|
||||
ax.set_xticks(tick_positions)
|
||||
ax.set_xticklabels(tick_labels, rotation=45, ha='right')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
|
||||
# 辅助函数:计算单笔交易的盈亏
|
||||
def calculate_trade_pnl(
|
||||
trade: Trade, entry_price: float, exit_price: float, direction: str
|
||||
) -> float:
|
||||
if direction == "LONG":
|
||||
pnl = (exit_price - entry_price) * trade.volume
|
||||
elif direction == "SHORT":
|
||||
pnl = (entry_price - exit_price) * trade.volume
|
||||
else:
|
||||
pnl = 0.0
|
||||
return pnl
|
||||
88
src/analysis/result_analyzer.py
Normal file
88
src/analysis/result_analyzer.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# src/analysis/result_analyzer.py
|
||||
|
||||
import pandas as pd
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
# 导入纯函数 (注意相对导入路径的变化)
|
||||
from .analysis_utils import calculate_metrics, plot_equity_and_drawdown_chart, plot_close_price_chart
|
||||
# 导入核心数据类 (注意相对导入路径的变化)
|
||||
from ..core_data import PortfolioSnapshot, Trade, Bar
|
||||
|
||||
|
||||
class ResultAnalyzer:
|
||||
"""
|
||||
结果分析器:负责接收回测数据,并提供分析和可视化方法。
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
portfolio_snapshots: List[PortfolioSnapshot],
|
||||
trade_history: List[Trade],
|
||||
bars: List[Bar],
|
||||
initial_capital: float):
|
||||
"""
|
||||
Args:
|
||||
portfolio_snapshots (List[PortfolioSnapshot]): 回测引擎输出的投资组合快照列表。
|
||||
trade_history (List[Trade]): 回测引擎输出的交易历史记录列表。
|
||||
initial_capital (float): 初始资金。
|
||||
"""
|
||||
self.portfolio_snapshots = portfolio_snapshots
|
||||
self.trade_history = trade_history
|
||||
self.initial_capital = initial_capital
|
||||
self.bars = bars
|
||||
self._metrics_cache: Optional[Dict[str, Any]] = None
|
||||
|
||||
print("\n--- 结果分析器初始化完成 ---")
|
||||
|
||||
def calculate_all_metrics(self) -> Dict[str, Any]:
|
||||
"""
|
||||
计算所有关键绩效指标。
|
||||
如果已计算过则返回缓存结果,否则调用纯函数计算。
|
||||
"""
|
||||
if self._metrics_cache is None:
|
||||
print("正在计算绩效指标...")
|
||||
self._metrics_cache = calculate_metrics(
|
||||
self.portfolio_snapshots,
|
||||
self.trade_history,
|
||||
self.initial_capital
|
||||
)
|
||||
print("绩效指标计算完成。")
|
||||
return self._metrics_cache
|
||||
|
||||
def generate_report(self) -> None:
|
||||
"""
|
||||
生成并打印详细的回测报告。
|
||||
"""
|
||||
metrics = self.calculate_all_metrics()
|
||||
|
||||
print("\n--- 回测绩效报告 ---")
|
||||
print(f"{'初始资金':<15}: {metrics['初始资金']:.2f}")
|
||||
print(f"{'最终资金':<15}: {metrics['最终资金']:.2f}")
|
||||
print(f"{'总收益率':<15}: {metrics['总收益率']:.2%}")
|
||||
print(f"{'年化收益率':<15}: {metrics['年化收益率']:.2%}")
|
||||
print(f"{'最大回撤':<15}: {metrics['最大回撤']:.2%}")
|
||||
print(f"{'夏普比率':<15}: {metrics['夏普比率']:.2f}")
|
||||
print(f"{'卡玛比率':<15}: {metrics['卡玛比率']:.2f}")
|
||||
print(f"{'总交易次数':<15}: {metrics['总交易次数']}")
|
||||
print(f"{'交易成本':<15}: {metrics['交易成本']:.2f}")
|
||||
|
||||
if self.trade_history:
|
||||
print("\n--- 部分交易明细 (最近5笔) ---")
|
||||
for trade in self.trade_history[-5:]:
|
||||
print(
|
||||
f" {trade.fill_time} | {trade.direction:<10} | {trade.symbol} | Vol: {trade.volume} | Price: {trade.price:.2f} | Commission: {trade.commission:.2f}")
|
||||
else:
|
||||
print("\n没有交易记录。")
|
||||
|
||||
def plot_performance(self) -> None:
|
||||
"""
|
||||
绘制投资组合净值和回撤曲线。
|
||||
"""
|
||||
print("正在绘制绩效图表...")
|
||||
# plot_performance_chart(self.portfolio_snapshots, self.initial_capital, self.bars)
|
||||
plot_equity_and_drawdown_chart(self.portfolio_snapshots, self.initial_capital,
|
||||
title="Portfolio Equity and Drawdown Curve")
|
||||
|
||||
# 绘制单独的收盘价曲线
|
||||
plot_close_price_chart(self.bars, title="Underlying Asset Close Price")
|
||||
|
||||
print("图表绘制完成。")
|
||||
82
src/backtest_context.py
Normal file
82
src/backtest_context.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# src/backtest_context.py
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
import pandas as pd
|
||||
|
||||
# 导入核心数据类
|
||||
from .core_data import Bar, Order, Trade
|
||||
|
||||
# 导入 DataManager 的相关类 (虽然这里不是类,但 BacktestEngine 会传入)
|
||||
# 为了避免循环导入,这里假设 BacktestEngine 会传入 DataManager 实例
|
||||
# 或者,如果 DataManager 是一个函数集合,那么 BacktestContext 需要直接访问这些函数
|
||||
# 在这里,我们假设 DataManager 是一个 OOP 类,或者有一个接口让 Context 可以获取历史Bar
|
||||
from .data_manager import DataManager # 假设 DataManager 是一个类
|
||||
|
||||
# 导入 ExecutionSimulator
|
||||
from .execution_simulator import ExecutionSimulator
|
||||
|
||||
class BacktestContext:
|
||||
"""
|
||||
回测上下文:作为策略与回测引擎和模拟器之间的桥梁。
|
||||
策略通过Context获取市场数据和发出交易指令。
|
||||
"""
|
||||
def __init__(self, data_manager: DataManager, simulator: ExecutionSimulator):
|
||||
self._data_manager = data_manager # 引用DataManager实例
|
||||
self._simulator = simulator # 引用ExecutionSimulator实例
|
||||
self._current_bar: Optional[Bar] = None # 当前Bar,由引擎设置
|
||||
|
||||
def set_current_bar(self, bar: Bar):
|
||||
"""由回测引擎调用,更新当前Bar。"""
|
||||
self._current_bar = bar
|
||||
|
||||
def get_current_bar(self) -> Bar:
|
||||
"""获取当前正在处理的Bar对象。"""
|
||||
if self._current_bar is None:
|
||||
raise RuntimeError("当前Bar未设置。请确保在策略on_bar调用前已设置。")
|
||||
return self._current_bar
|
||||
|
||||
def get_history_bars(self, num_bars: int) -> List[Bar]:
|
||||
"""
|
||||
获取当前Bar之前的历史Bar列表。
|
||||
通过DataManager获取,确保不包含未来函数。
|
||||
"""
|
||||
# DataManager需要提供一个方法来获取不包含当前bar的历史数据
|
||||
# 这里假设DataManager已经实现了这个功能,通过内部索引管理
|
||||
# 注意:这里的实现需要和 DataManager 内部逻辑匹配,确保严格的未来函数防范
|
||||
if self._current_bar is None:
|
||||
raise RuntimeError("当前Bar未设置,无法获取历史Bar。")
|
||||
return self._data_manager.get_history_bars(self._current_bar.datetime, num_bars)
|
||||
|
||||
def send_order(self, order: Order) -> Optional[Trade]:
|
||||
"""
|
||||
策略通过此方法发出交易订单。
|
||||
"""
|
||||
if self._current_bar is None:
|
||||
raise RuntimeError("当前Bar未设置,无法发送订单。")
|
||||
# 将订单转发给模拟器执行
|
||||
trade = self._simulator.send_order(order, self._current_bar)
|
||||
# 可以在这里触发策略的on_trade回调(如果策略定义了)
|
||||
return trade
|
||||
|
||||
def get_current_positions(self) -> Dict[str, int]:
|
||||
"""
|
||||
获取当前模拟器的持仓情况。
|
||||
"""
|
||||
return self._simulator.get_current_positions()
|
||||
|
||||
def get_current_cash(self) -> float:
|
||||
"""
|
||||
获取当前模拟器的可用资金。
|
||||
"""
|
||||
return self._simulator.cash
|
||||
|
||||
def get_current_portfolio_value(self, current_bar: Bar) -> float:
|
||||
"""
|
||||
获取当前的投资组合总价值(包括现金和持仓市值)。
|
||||
Args:
|
||||
current_bar (Bar): 当前的Bar数据,用于计算持仓市值。
|
||||
Returns:
|
||||
float: 当前的投资组合总价值。
|
||||
"""
|
||||
# 调用底层模拟器的方法来获取投资组合价值
|
||||
return self._simulator.get_portfolio_value(current_bar)
|
||||
138
src/backtest_engine.py
Normal file
138
src/backtest_engine.py
Normal file
@@ -0,0 +1,138 @@
|
||||
# src/backtest_engine.py
|
||||
|
||||
from typing import Type, Dict, Any, List
|
||||
import pandas as pd
|
||||
|
||||
# 导入所有需要协调的模块
|
||||
from .core_data import Bar, Order, Trade, PortfolioSnapshot
|
||||
from .data_manager import DataManager
|
||||
from .execution_simulator import ExecutionSimulator
|
||||
from .backtest_context import BacktestContext
|
||||
from .strategies.base_strategy import Strategy # 导入策略基类
|
||||
|
||||
|
||||
class BacktestEngine:
|
||||
"""
|
||||
回测引擎:协调数据流、策略执行、订单模拟和结果记录。
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
data_manager: DataManager,
|
||||
strategy_class: Type[Strategy],
|
||||
strategy_params: Dict[str, Any],
|
||||
initial_capital: float = 100000.0,
|
||||
slippage_rate: float = 0.0001,
|
||||
commission_rate: float = 0.0002):
|
||||
"""
|
||||
初始化回测引擎。
|
||||
|
||||
Args:
|
||||
data_manager (DataManager): 已经初始化好的数据管理器实例。
|
||||
strategy_class (Type[Strategy]): 策略类(而不是实例),引擎会负责实例化。
|
||||
strategy_params (Dict[str, Any]): 传递给策略的参数字典。
|
||||
initial_capital (float): 初始交易资金。
|
||||
slippage_rate (float): 交易滑点率。
|
||||
commission_rate (float): 交易佣金率。
|
||||
"""
|
||||
self.data_manager = data_manager
|
||||
self.simulator = ExecutionSimulator(
|
||||
initial_capital=initial_capital,
|
||||
slippage_rate=slippage_rate,
|
||||
commission_rate=commission_rate
|
||||
)
|
||||
self.context = BacktestContext(self.data_manager, self.simulator)
|
||||
|
||||
# 实例化策略
|
||||
self.strategy = strategy_class(self.context, **strategy_params)
|
||||
|
||||
self.portfolio_snapshots: List[PortfolioSnapshot] = [] # 存储每天的投资组合快照
|
||||
self.trade_history: List[Trade] = [] # 存储所有成交记录
|
||||
self.all_bars: List[Bar] = []
|
||||
|
||||
# 历史Bar缓存,用于特征计算
|
||||
self._history_bars: List[Bar] = []
|
||||
self._max_history_bars: int = 200 # 例如,只保留最近200根Bar的历史数据,可根据策略需求调整
|
||||
|
||||
print("\n--- 回测引擎初始化完成 ---")
|
||||
print(f" 策略: {strategy_class.__name__}")
|
||||
print(f" 初始资金: {initial_capital:.2f}")
|
||||
|
||||
def run_backtest(self):
|
||||
"""
|
||||
运行整个回测流程。
|
||||
"""
|
||||
print("\n--- 回测开始 ---")
|
||||
|
||||
# 调用策略的初始化方法
|
||||
self.strategy.on_init()
|
||||
|
||||
# 主回测循环
|
||||
while True:
|
||||
current_bar = self.data_manager.get_next_bar()
|
||||
if current_bar is None:
|
||||
break # 没有更多数据,回测结束
|
||||
|
||||
# 设置当前Bar到Context,供策略访问
|
||||
self.context.set_current_bar(current_bar)
|
||||
|
||||
# 更新历史Bar缓存
|
||||
self._history_bars.append(current_bar)
|
||||
if len(self._history_bars) > self._max_history_bars:
|
||||
self._history_bars.pop(0) # 移除最旧的Bar
|
||||
|
||||
# 1. 计算特征 (使用纯函数)
|
||||
# 注意: extract_bar_features 接收的是完整的历史数据,不包含当前Bar
|
||||
# 但为了简单起见,这里传入的是包含当前bar在内的历史数据,但内部函数应确保不使用“未来”数据
|
||||
# 严格来说,应该传入 self._history_bars[:-1]
|
||||
# features = extract_bar_features(current_bar, self._history_bars[:-1]) # 传入当前Bar之前的所有历史Bar
|
||||
|
||||
# 2. 调用策略的 on_bar 方法
|
||||
self.strategy.on_bar(current_bar)
|
||||
|
||||
# 3. 记录投资组合快照
|
||||
current_portfolio_value = self.simulator.get_portfolio_value(current_bar)
|
||||
current_positions = self.simulator.get_current_positions()
|
||||
|
||||
# 创建 PortfolioSnapshot,记录当前Bar的收盘价
|
||||
price_at_snapshot = {
|
||||
current_bar.symbol if hasattr(current_bar, 'symbol') else "DEFAULT_SYMBOL": current_bar.close}
|
||||
|
||||
snapshot = PortfolioSnapshot(
|
||||
datetime=current_bar.datetime,
|
||||
total_value=current_portfolio_value,
|
||||
cash=self.simulator.cash,
|
||||
positions=current_positions,
|
||||
price_at_snapshot=price_at_snapshot
|
||||
)
|
||||
self.portfolio_snapshots.append(snapshot)
|
||||
self.all_bars.append(current_bar)
|
||||
|
||||
# 记录交易历史(从模拟器获取)
|
||||
# 简化处理:每次获取模拟器中的所有交易历史,并更新引擎的trade_history
|
||||
# 更好的做法是模拟器提供一个方法,返回自上次查询以来的新增交易
|
||||
# 这里为了不重复添加,可以在 trade_log 中只添加当前 Bar 生成的交易
|
||||
|
||||
# 在 on_bar 循环的末尾,获取本Bar周期内新产生的交易
|
||||
# 模拟器在每次send_order成功时会将trade添加到其trade_log
|
||||
# 这里可以做一个增量获取,或者简单地在循环结束后统一获取
|
||||
# 目前我们在执行模拟器中已经将成交记录在了 trade_log 中,所以这里不用重复记录,
|
||||
# 而是等到回测结束后再统一获取。
|
||||
pass # 不在此处记录 self.trade_history
|
||||
|
||||
# 回测结束后,获取所有交易记录
|
||||
self.trade_history = self.simulator.get_trade_history()
|
||||
|
||||
print("--- 回测结束 ---")
|
||||
print(f"总计处理了 {len(self.portfolio_snapshots)} 根K线。")
|
||||
print(f"总计发生了 {len(self.trade_history)} 笔交易。")
|
||||
|
||||
def get_backtest_results(self) -> Dict[str, Any]:
|
||||
"""
|
||||
返回回测结果数据,供结果分析模块使用。
|
||||
"""
|
||||
return {
|
||||
"portfolio_snapshots": self.portfolio_snapshots,
|
||||
"trade_history": self.trade_history,
|
||||
"initial_capital": self.simulator.initial_capital, # 或 self.initial_capital
|
||||
"all_bars": self.all_bars
|
||||
}
|
||||
98
src/core_data.py
Normal file
98
src/core_data.py
Normal file
@@ -0,0 +1,98 @@
|
||||
# src/core_data.py
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
import pandas as pd
|
||||
from typing import Dict, Any, List, Optional
|
||||
import uuid # 用于生成唯一订单ID
|
||||
|
||||
|
||||
@dataclass(frozen=True) # frozen=True 使实例变为不可变
|
||||
class Bar:
|
||||
"""
|
||||
K线数据对象,包含期货或股票的 OHLCV 和持仓量信息。
|
||||
"""
|
||||
datetime: pd.Timestamp
|
||||
open: float
|
||||
high: float
|
||||
low: float
|
||||
close: float
|
||||
volume: int
|
||||
open_oi: int # 开盘持仓量 (Open Interest)
|
||||
close_oi: int # 收盘持仓量
|
||||
symbol: str
|
||||
|
||||
def __post_init__(self):
|
||||
"""
|
||||
数据验证(可选):确保持仓量为非负整数。
|
||||
"""
|
||||
if not isinstance(self.volume, int) or self.volume < 0:
|
||||
raise ValueError(f"Volume must be a non-negative integer, got {self.volume}")
|
||||
if not isinstance(self.open_oi, int) or self.open_oi < 0:
|
||||
raise ValueError(f"Open interest must be a non-negative integer, got {self.open_oi}")
|
||||
if not isinstance(self.close_oi, int) or self.close_oi < 0:
|
||||
raise ValueError(f"Close interest must be a non-negative integer, got {self.close_oi}")
|
||||
|
||||
# 验证价格是否合理
|
||||
if not (self.high >= self.open and self.high >= self.close and \
|
||||
self.low <= self.open and self.low <= self.close and \
|
||||
self.high >= self.low):
|
||||
# 仅作警告,如果数据源可靠,可能不需要严格检查
|
||||
# print(f"Warning: Abnormal OHLC for bar at {self.datetime}")
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Order:
|
||||
"""
|
||||
代表一个待执行的交易指令。
|
||||
"""
|
||||
symbol: str # 暂时简化为单品种交易,可扩展
|
||||
direction: str # "BUY", "SELL", "CLOSE_LONG", "CLOSE_SHORT"
|
||||
volume: int # 交易数量
|
||||
id: str = field(default_factory=lambda: str(uuid.uuid4())) # 唯一订单ID
|
||||
price_type: str = "MARKET" # "MARKET", "LIMIT" (简易版默认市价)
|
||||
limit_price: Optional[float] = None # 限价单价格
|
||||
submitted_time: pd.Timestamp = field(default_factory=pd.Timestamp.now) # 订单提交时间
|
||||
|
||||
def __post_init__(self):
|
||||
if self.direction not in ["BUY", "SELL", "CLOSE_LONG", "CLOSE_SHORT", "CANCEL"]:
|
||||
raise ValueError(f"Invalid order direction: {self.direction}")
|
||||
if self.price_type not in ["MARKET", "LIMIT", "CANCEL"]:
|
||||
raise ValueError(f"Invalid price type: {self.price_type}")
|
||||
if not isinstance(self.volume, int) or self.volume < 0:
|
||||
raise ValueError(f"Order volume must be a positive integer, got {self.volume}")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Trade:
|
||||
"""
|
||||
代表一个已完成的成交记录。
|
||||
"""
|
||||
order_id: str
|
||||
fill_time: pd.Timestamp
|
||||
symbol: str
|
||||
direction: str # "BUY", "SELL" (实际成交方向)
|
||||
volume: int
|
||||
price: float
|
||||
commission: float
|
||||
# 记录成交后的账户状态(可选,用于方便调试)
|
||||
cash_after_trade: float
|
||||
positions_after_trade: Dict[str, int]
|
||||
realized_pnl: float = 0.0 # <--- 新增字段:这笔交易带来的实现盈亏
|
||||
is_open_trade: bool = False # <--- 新增字段:是否是开仓交易(用于跟踪成本)
|
||||
is_close_trade: bool = False # <--- 新增字段:是否是平仓交易 (用于计算盈亏)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PortfolioSnapshot:
|
||||
"""
|
||||
在特定时间点记录投资组合的快照。
|
||||
"""
|
||||
datetime: pd.Timestamp
|
||||
total_value: float
|
||||
cash: float
|
||||
positions: Dict[str, int] # {symbol: quantity}
|
||||
price_at_snapshot: Dict[str, float] # {symbol: price},用于计算市值
|
||||
|
||||
# 确保 Pandas Timestamp 的默认时区为None,或者在创建时明确指定
|
||||
# pd.set_option('mode.chained_assignment', None) # 避免SettingWithCopyWarning
|
||||
101
src/data_manager.py
Normal file
101
src/data_manager.py
Normal file
@@ -0,0 +1,101 @@
|
||||
# src/data_manager.py (修改并添加 get_history_bars 方法)
|
||||
|
||||
import pandas as pd
|
||||
from typing import Iterator, List, Dict, Any, Optional
|
||||
import os
|
||||
|
||||
# 导入我们刚刚定义的 Bar 类
|
||||
from .core_data import Bar
|
||||
from .data_processing import load_raw_data, df_to_bar_stream
|
||||
|
||||
|
||||
class DataManager: # DataManager 现在是一个类,以便维护内部索引和提供历史数据
|
||||
"""
|
||||
负责从外部数据源加载数据,并将其转换为统一的、可供回测引擎使用的 Bar 对象。
|
||||
并提供获取历史Bar的能力。
|
||||
"""
|
||||
|
||||
def __init__(self, file_path: str, symbol: str, tz='Asia/Shanghai'):
|
||||
self.file_path = file_path
|
||||
self.tz = tz
|
||||
self.raw_df = load_raw_data(self.file_path) # 调用函数式加载数据
|
||||
# self.bars = list(df_to_bar_stream(self.raw_df)) # 一次性转换所有Bar,方便历史数据查找
|
||||
# 优化:使用内部迭代器和缓存,避免一次性生成所有Bar,但可以按需提供历史
|
||||
self._bar_generator = df_to_bar_stream(self.raw_df, symbol)
|
||||
self._bars_cache: List[Bar] = [] # 缓存已经处理过的Bar
|
||||
self._last_bar_index = -1 # 用于跟踪当前的Bar在原始df中的索引
|
||||
self.symbol = symbol
|
||||
|
||||
def get_next_bar(self) -> Optional[Bar]:
|
||||
"""
|
||||
按顺序返回下一根 Bar 对象。
|
||||
"""
|
||||
try:
|
||||
next_bar = next(self._bar_generator)
|
||||
self._bars_cache.append(next_bar)
|
||||
self._last_bar_index += 1
|
||||
return next_bar
|
||||
except StopIteration:
|
||||
return None # 数据已读完
|
||||
|
||||
def get_history_bars(self, current_bar_datetime: pd.Timestamp, num_bars: int) -> List[Bar]:
|
||||
"""
|
||||
获取当前Bar之前的历史Bar列表。
|
||||
严格防止未来函数:只提供时间戳在 current_bar_datetime 之前的 Bar。
|
||||
|
||||
Args:
|
||||
current_bar_datetime (pd.Timestamp): 当前正在处理的Bar的时间戳。
|
||||
用于确保获取的历史数据不包含当前及未来的Bar。
|
||||
num_bars (int): 需要获取的历史Bar的数量。
|
||||
|
||||
Returns:
|
||||
List[Bar]: 历史Bar对象的列表,按时间升序排列。
|
||||
"""
|
||||
# 查找当前 Bar 在缓存中的位置
|
||||
# 这是一个查找操作,效率可能不高,但对于简易回测可以接受。
|
||||
# 更优方案可以是直接传递当前Bar在DataFrame中的索引。
|
||||
|
||||
# 过滤掉日期时间大于等于 current_bar_datetime 的 Bar
|
||||
# 并只取最新的 num_bars 根
|
||||
|
||||
# 确保历史数据是严格在 current_bar_datetime 之前的
|
||||
# 考虑到DataFrame索引已经排序,可以直接通过切片查找
|
||||
# 找到 current_bar_datetime 在 DataFrame 索引中的位置
|
||||
try:
|
||||
# 获取所有在当前时间点之前的历史Bar
|
||||
# 这里我们依赖raw_df的索引,假设它与_bars_cache是同步的
|
||||
# 这种方式更严格地避免未来函数,因为我们直接从原始DataFrame取历史
|
||||
# find_loc = self.raw_df.index.get_loc(current_bar_datetime)
|
||||
# if isinstance(find_loc, slice): # 处理重复时间戳
|
||||
# find_loc = find_loc.start
|
||||
|
||||
# 使用列表缓存来获取历史数据,避免每次都转换DataFrame
|
||||
# 必须确保 _bars_cache 已经包含了 current_bar_datetime 之前的Bar
|
||||
|
||||
# 从缓存中取出所有在 current_bar_datetime 之前的Bar
|
||||
# 找到第一个大于或等于 current_bar_datetime 的 Bar 的索引
|
||||
end_idx = -1
|
||||
for idx, bar in enumerate(self._bars_cache):
|
||||
if bar.datetime >= current_bar_datetime:
|
||||
end_idx = idx
|
||||
break
|
||||
if end_idx == -1: # 如果当前bar是最后一个,则所有缓存都是历史
|
||||
historical_data = self._bars_cache[:]
|
||||
else:
|
||||
historical_data = self._bars_cache[:end_idx]
|
||||
|
||||
# 从筛选出的历史数据中,取出最近的 num_bars 根
|
||||
return historical_data[-num_bars:]
|
||||
except KeyError:
|
||||
# 如果 current_bar_datetime 不在索引中,可能还没到该时间点
|
||||
return self._bars_cache[-num_bars:] if self._bars_cache else []
|
||||
except Exception as e:
|
||||
print(f"获取历史Bar时发生错误: {e}")
|
||||
return []
|
||||
|
||||
def reset(self):
|
||||
"""重置数据管理器,以便重新运行回测。"""
|
||||
self._bar_generator = df_to_bar_stream(self.raw_df, self.symbol)
|
||||
self._bars_cache = []
|
||||
self._last_bar_index = -1
|
||||
print("DataManager 已重置。")
|
||||
89
src/data_processing.py
Normal file
89
src/data_processing.py
Normal file
@@ -0,0 +1,89 @@
|
||||
# src/data_processing.py
|
||||
|
||||
import pandas as pd
|
||||
from typing import Iterator, List, Dict, Any
|
||||
import os
|
||||
|
||||
# 导入我们刚刚定义的 Bar 类
|
||||
from .core_data import Bar
|
||||
|
||||
def load_raw_data(file_path: str) -> pd.DataFrame:
|
||||
"""
|
||||
从 CSV 文件加载原始数据,并进行初步的数据类型处理。
|
||||
假设 datetime 列已经是北京时间,无需额外时区转换或本地化。
|
||||
|
||||
Args:
|
||||
file_path (str): CSV 文件的完整路径。
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: 包含已处理时间戳的原始数据DataFrame。
|
||||
时间戳作为索引,列包括 'open', 'high', 'low', 'close',
|
||||
'volume', 'open_oi', 'close_oi'。
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: 如果文件不存在。
|
||||
KeyError: 如果CSV中缺少必要的列。
|
||||
"""
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"数据文件未找到: {file_path}")
|
||||
|
||||
# 定义期望的列名,用于检查和选择
|
||||
expected_cols = ['datetime', 'open', 'high', 'low', 'close', 'volume', 'open_oi', 'close_oi']
|
||||
|
||||
try:
|
||||
# 使用 pandas.read_csv 直接解析 datetime 列
|
||||
# 'datetime' 列的格式是 'YYYY-MM-DD HH:MM:SS',pandas 能够很好地自动识别
|
||||
df = pd.read_csv(
|
||||
file_path,
|
||||
index_col='datetime', # 将 datetime 列设置为索引
|
||||
parse_dates=True # 自动解析索引为 datetime 类型
|
||||
)
|
||||
|
||||
# 检查所有必需的列是否存在
|
||||
missing_cols = [col for col in expected_cols[1:] if col not in df.columns]
|
||||
if missing_cols:
|
||||
raise KeyError(f"CSV文件中缺少以下必需列: {', '.join(missing_cols)}")
|
||||
|
||||
# 确保数据按时间排序 (这是回测的基础)
|
||||
df = df.sort_index()
|
||||
|
||||
print(f"数据加载成功: {file_path}")
|
||||
print(f"数据范围从 {df.index.min()} 到 {df.index.max()}")
|
||||
print(f"总计 {len(df)} 条记录。")
|
||||
|
||||
return df[expected_cols[1:]] # 返回包含核心数据的DataFrame
|
||||
except Exception as e:
|
||||
print(f"加载数据时发生错误: {e}")
|
||||
raise
|
||||
|
||||
def df_to_bar_stream(df: pd.DataFrame, symbol: str) -> Iterator[Bar]:
|
||||
"""
|
||||
将 Pandas DataFrame 转换为 Bar 对象的迭代器(数据流)。
|
||||
这符合函数式编程的理念,按需生成,不一次性加载所有Bar到内存。
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): 包含 K 线数据的 DataFrame。
|
||||
|
||||
Yields:
|
||||
Bar: 逐个生成的 Bar 对象。
|
||||
"""
|
||||
print("开始将 DataFrame 转换为 Bar 对象流...")
|
||||
for index, row in df.iterrows():
|
||||
try:
|
||||
bar = Bar(
|
||||
symbol=symbol,
|
||||
datetime=index,
|
||||
open=row['open'],
|
||||
high=row['high'],
|
||||
low=row['low'],
|
||||
close=row['close'],
|
||||
volume=int(row['volume']), # 确保成交量为整数
|
||||
open_oi=int(row['open_oi']), # 确保持仓量为整数
|
||||
close_oi=int(row['close_oi'])
|
||||
)
|
||||
yield bar
|
||||
except (ValueError, TypeError) as e:
|
||||
print(f"警告: 无法为 {index} 时间创建Bar对象,跳过。错误: {e}")
|
||||
continue
|
||||
print("Bar 对象流生成完毕。")
|
||||
|
||||
270
src/execution_simulator.py
Normal file
270
src/execution_simulator.py
Normal file
@@ -0,0 +1,270 @@
|
||||
# src/execution_simulator.py (修改部分)
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
import pandas as pd
|
||||
from .core_data import Order, Trade, Bar, PortfolioSnapshot
|
||||
|
||||
|
||||
class ExecutionSimulator:
|
||||
"""
|
||||
模拟交易执行和管理账户资金、持仓。
|
||||
"""
|
||||
|
||||
def __init__(self, initial_capital: float,
|
||||
slippage_rate: float = 0.0001,
|
||||
commission_rate: float = 0.0002,
|
||||
initial_positions: Optional[Dict[str, int]] = None):
|
||||
"""
|
||||
Args:
|
||||
initial_capital (float): 初始资金。
|
||||
slippage_rate (float): 滑点率(相对于成交价格的百分比)。
|
||||
commission_rate (float): 佣金率(相对于成交金额的百分比)。
|
||||
initial_positions (Optional[Dict[str, int]]): 初始持仓,格式为 {symbol: quantity}。
|
||||
"""
|
||||
self.initial_capital = initial_capital
|
||||
self.cash = initial_capital
|
||||
self.positions: Dict[str, int] = initial_positions if initial_positions is not None else {}
|
||||
# 新增:跟踪持仓的平均成本 {symbol: average_cost}
|
||||
self.average_costs: Dict[str, float] = {}
|
||||
# 如果有初始持仓,需要设置初始成本(简化为0,或在外部配置)
|
||||
if initial_positions:
|
||||
for symbol, qty in initial_positions.items():
|
||||
# 初始持仓成本,如果需要精确,应该从外部传入
|
||||
self.average_costs[symbol] = 0.0 # 简化处理,初始持仓成本为0
|
||||
|
||||
self.slippage_rate = slippage_rate
|
||||
self.commission_rate = commission_rate
|
||||
self.trade_log: List[Trade] = [] # 存储所有成交记录
|
||||
self.pending_orders: Dict[str, Order] = {} # {order_id: Order_object}
|
||||
|
||||
print(
|
||||
f"模拟器初始化:初始资金={self.initial_capital:.2f}, 滑点率={self.slippage_rate}, 佣金率={self.commission_rate}")
|
||||
if self.positions:
|
||||
print(f"初始持仓:{self.positions}")
|
||||
|
||||
def _calculate_fill_price(self, order: Order, current_bar: Bar) -> float:
|
||||
"""
|
||||
内部方法:根据订单类型和滑点计算实际成交价格。
|
||||
简化处理:市价单以当前Bar收盘价为基准,考虑滑点。
|
||||
"""
|
||||
base_price = current_bar.close # 简化为收盘价成交
|
||||
|
||||
# 考虑滑点
|
||||
if order.direction in ["BUY", "CLOSE_SHORT"]: # 买入或平空,价格向上偏离
|
||||
fill_price = base_price * (1 + self.slippage_rate)
|
||||
elif order.direction in ["SELL", "CLOSE_LONG"]: # 卖出或平多,价格向下偏离
|
||||
fill_price = base_price * (1 - self.slippage_rate)
|
||||
else: # 默认情况,无滑点
|
||||
fill_price = base_price
|
||||
|
||||
# 如果是限价单且成交价格不满足条件,则可能不成交
|
||||
if order.price_type == "LIMIT" and order.limit_price is not None:
|
||||
# 对于BUY和CLOSE_SHORT,成交价必须 <= 限价
|
||||
if (order.direction == "BUY" or order.direction == "CLOSE_SHORT") and fill_price > order.limit_price:
|
||||
return -1.0 # 未触及限价
|
||||
# 对于SELL和CLOSE_LONG,成交价必须 >= 限价
|
||||
elif (order.direction == "SELL" or order.direction == "CLOSE_LONG") and fill_price < order.limit_price:
|
||||
return -1.0 # 未触及限价
|
||||
|
||||
return fill_price
|
||||
|
||||
def send_order(self, order: Order, current_bar: Bar) -> Optional[Trade]:
|
||||
"""
|
||||
接收策略发出的订单,并模拟执行。
|
||||
如果订单未立即成交,则加入待处理订单列表。
|
||||
特殊处理:如果 order.direction 是 "CANCEL",则调用 cancel_order。
|
||||
|
||||
Args:
|
||||
order (Order): 待执行的订单对象。
|
||||
current_bar (Bar): 当前的Bar数据,用于确定成交价格。
|
||||
|
||||
Returns:
|
||||
Optional[Trade]: 如果订单成功执行则返回 Trade 对象,否则返回 None。
|
||||
"""
|
||||
# --- 处理撤单指令 ---
|
||||
if order.direction == "CANCEL":
|
||||
success = self.cancel_order(order.id)
|
||||
if success:
|
||||
# print(f"[{current_bar.datetime}] 模拟器: 收到并成功处理撤单指令 for Order ID: {order.id}")
|
||||
pass
|
||||
else:
|
||||
# print(f"[{current_bar.datetime}] 模拟器: 收到撤单指令 for Order ID: {order.id}, 但订单已成交或不存在。")
|
||||
pass
|
||||
return None # 撤单操作不返回Trade
|
||||
|
||||
# --- 正常买卖订单处理 ---
|
||||
symbol = order.symbol
|
||||
volume = order.volume
|
||||
|
||||
# 尝试计算成交价格
|
||||
fill_price = self._calculate_fill_price(order, current_bar)
|
||||
|
||||
executed_trade: Optional[Trade] = None
|
||||
realized_pnl = 0.0 # 初始化实现盈亏
|
||||
is_open_trade = False
|
||||
is_close_trade = False
|
||||
|
||||
if fill_price <= 0: # 表示未成交或不满足限价条件
|
||||
if order.price_type == "LIMIT":
|
||||
self.pending_orders[order.id] = order
|
||||
return None # 未成交,返回None
|
||||
|
||||
# --- 以下是订单成功成交的逻辑 ---
|
||||
trade_value = volume * fill_price
|
||||
commission = trade_value * self.commission_rate
|
||||
|
||||
current_position = self.positions.get(symbol, 0)
|
||||
current_average_cost = self.average_costs.get(symbol, 0.0)
|
||||
|
||||
if order.direction == "BUY":
|
||||
# 开多仓或平空仓
|
||||
if current_position >= 0: # 当前持有多仓或无仓位 (开多)
|
||||
is_open_trade = True
|
||||
# 更新平均成本 (加权平均)
|
||||
new_total_cost = (current_average_cost * current_position) + (fill_price * volume)
|
||||
new_total_volume = current_position + volume
|
||||
self.average_costs[symbol] = new_total_cost / new_total_volume if new_total_volume > 0 else 0.0
|
||||
self.positions[symbol] = new_total_volume
|
||||
else: # 当前持有空仓 (平空)
|
||||
is_close_trade = True
|
||||
# 计算平空盈亏
|
||||
# PnL = (开仓成本 - 平仓价格) * 平仓数量 (注意空头方向)
|
||||
# 简化:假设平空时,直接使用当前的平均开仓成本来计算盈亏
|
||||
# 更精确的FIFO/LIFO需更多逻辑
|
||||
pnl_per_share = current_average_cost - fill_price # (买入平空,成本高于平仓价则盈利)
|
||||
realized_pnl = pnl_per_share * volume
|
||||
|
||||
# 更新持仓和成本
|
||||
self.positions[symbol] += volume
|
||||
if self.positions[symbol] == 0:
|
||||
del self.positions[symbol]
|
||||
if symbol in self.average_costs: del self.average_costs[symbol] # 清理成本
|
||||
elif self.positions[symbol] > 0 and current_position < 0: # 部分平空转为多头,需重新设置成本
|
||||
# 这部分逻辑可以更复杂,这里简化处理,如果转为多头,成本重置为0
|
||||
# 实际应该用剩余的空头成本 + 新开多的成本
|
||||
self.average_costs[symbol] = fill_price # 简单地将剩下的多头仓位成本设为当前价格
|
||||
|
||||
# 资金扣除
|
||||
if self.cash >= trade_value + commission:
|
||||
self.cash -= (trade_value + commission)
|
||||
else:
|
||||
# print(f"[{current_bar.datetime}] 资金不足,无法执行买入 {volume} {symbol}")
|
||||
return None
|
||||
|
||||
|
||||
elif order.direction == "SELL":
|
||||
# 开空仓或平多仓
|
||||
if current_position <= 0: # 当前持有空仓或无仓位 (开空)
|
||||
is_open_trade = True
|
||||
# 更新平均成本 (空头成本为负值)
|
||||
new_total_cost = (current_average_cost * current_position) - (fill_price * volume) # 负的持仓乘以负的卖出价
|
||||
new_total_volume = current_position - volume # 空头持仓量更负
|
||||
self.average_costs[symbol] = new_total_cost / new_total_volume if new_total_volume < 0 else 0.0
|
||||
self.positions[symbol] = new_total_volume
|
||||
else: # 当前持有多仓 (平多)
|
||||
is_close_trade = True
|
||||
# 计算平多盈亏
|
||||
# PnL = (平仓价格 - 开仓成本) * 平仓数量
|
||||
pnl_per_share = fill_price - current_average_cost # (卖出平多,平仓价高于成本则盈利)
|
||||
realized_pnl = pnl_per_share * volume
|
||||
|
||||
# 更新持仓和成本
|
||||
self.positions[symbol] -= volume
|
||||
if self.positions[symbol] == 0:
|
||||
del self.positions[symbol]
|
||||
if symbol in self.average_costs: del self.average_costs[symbol] # 清理成本
|
||||
elif self.positions[symbol] < 0 and current_position > 0: # 部分平多转为空头,需重新设置成本
|
||||
self.average_costs[symbol] = fill_price # 简单地将剩下的空头仓位成本设为当前价格
|
||||
|
||||
# 资金扣除 (佣金) 和增加 (卖出收入)
|
||||
if self.cash >= commission:
|
||||
self.cash -= commission
|
||||
self.cash += trade_value
|
||||
else:
|
||||
# print(f"[{current_bar.datetime}] 资金不足,无法执行卖出 {volume} {symbol}")
|
||||
return None
|
||||
|
||||
# 创建 Trade 对象
|
||||
executed_trade = Trade(
|
||||
order_id=order.id, fill_time=current_bar.datetime, symbol=symbol,
|
||||
direction=order.direction, # 记录原始订单方向 (BUY/SELL)
|
||||
volume=volume, price=fill_price, commission=commission,
|
||||
cash_after_trade=self.cash, positions_after_trade=self.positions.copy(),
|
||||
realized_pnl=realized_pnl, # 填充实现盈亏
|
||||
is_open_trade=is_open_trade,
|
||||
is_close_trade=is_close_trade
|
||||
)
|
||||
self.trade_log.append(executed_trade)
|
||||
|
||||
# 如果订单成交,无论它是市价单还是限价单,都从待处理订单中移除
|
||||
if order.id in self.pending_orders:
|
||||
del self.pending_orders[order.id]
|
||||
|
||||
# print(f"[{current_bar.datetime}] 成交: {executed_trade.direction} {executed_trade.volume} {executed_trade.symbol} @ {executed_trade.price:.2f}, 佣金: {executed_trade.commission:.2f}, PnL: {executed_trade.realized_pnl:.2f}")
|
||||
|
||||
return executed_trade
|
||||
|
||||
def cancel_order(self, order_id: str) -> bool:
|
||||
"""
|
||||
尝试取消一个待处理订单。
|
||||
|
||||
Args:
|
||||
order_id (str): 要取消的订单ID。
|
||||
|
||||
Returns:
|
||||
bool: 如果成功取消则返回 True,否则返回 False(例如,订单不存在或已成交)。
|
||||
"""
|
||||
if order_id in self.pending_orders:
|
||||
# print(f"订单 {order_id} 已成功取消。")
|
||||
del self.pending_orders[order_id]
|
||||
return True
|
||||
# print(f"订单 {order_id} 不存在或已成交,无法取消。")
|
||||
return False
|
||||
|
||||
def get_pending_orders(self) -> Dict[str, Order]:
|
||||
"""
|
||||
获取当前所有待处理订单的副本。
|
||||
"""
|
||||
return self.pending_orders.copy()
|
||||
|
||||
def get_portfolio_value(self, current_bar: Bar) -> float:
|
||||
"""
|
||||
计算当前的投资组合总价值(包括现金和持仓市值)。
|
||||
Args:
|
||||
current_bar (Bar): 当前的Bar数据,用于计算持仓市值。
|
||||
Returns:
|
||||
float: 当前的投资组合总价值。
|
||||
"""
|
||||
total_value = self.cash
|
||||
|
||||
# 在单品种场景下,我们假设 self.positions 最多只包含一个品种
|
||||
# 并且这个品种就是 current_bar.symbol 所代表的品种
|
||||
symbol_in_position = list(self.positions.keys())[0] if self.positions else None
|
||||
|
||||
if symbol_in_position and symbol_in_position == current_bar.symbol:
|
||||
quantity = self.positions[symbol_in_position]
|
||||
# 持仓市值 = 数量 * 当前市场价格 (current_bar.close)
|
||||
# 无论多头(quantity > 0)还是空头(quantity < 0),这个计算都是正确的
|
||||
total_value += quantity * current_bar.close
|
||||
|
||||
# 您也可以选择在这里打印调试信息
|
||||
# print(f" DEBUG Portfolio Value Calculation: Cash={self.cash:.2f}, "
|
||||
# f"Position for {symbol_in_position}: {quantity} @ {current_bar.close:.2f}, "
|
||||
# f"Position Value={quantity * current_bar.close:.2f}, Total Value={total_value:.2f}")
|
||||
|
||||
# 如果没有持仓,或者持仓品种与当前Bar品种不符 (理论上单品种不会发生)
|
||||
# 那么 total_value 依然是 self.cash
|
||||
|
||||
return total_value
|
||||
|
||||
def get_current_positions(self) -> Dict[str, int]:
|
||||
"""
|
||||
返回当前持仓字典的副本。
|
||||
"""
|
||||
return self.positions.copy()
|
||||
|
||||
def get_trade_history(self) -> List[Trade]:
|
||||
"""
|
||||
返回所有成交记录的副本。
|
||||
"""
|
||||
return self.trade_log.copy()
|
||||
0
src/strategies/__init__.py
Normal file
0
src/strategies/__init__.py
Normal file
68
src/strategies/base_strategy.py
Normal file
68
src/strategies/base_strategy.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# src/strategies/base_strategy.py
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
# 导入核心数据类
|
||||
from ..core_data import Bar, Order, Trade
|
||||
# 导入回测上下文 (注意相对导入路径的变化)
|
||||
from ..backtest_context import BacktestContext
|
||||
|
||||
class Strategy(ABC):
|
||||
"""
|
||||
策略抽象基类。所有具体策略都应继承此类,并实现 on_bar 方法。
|
||||
"""
|
||||
def __init__(self, context: BacktestContext, **parameters: Any):
|
||||
"""
|
||||
初始化策略。
|
||||
|
||||
Args:
|
||||
context (BacktestContext): 回测上下文对象,用于与模拟器和数据管理器交互。
|
||||
**parameters (Any): 策略所需的任何自定义参数。
|
||||
"""
|
||||
self.context = context
|
||||
self.parameters = parameters
|
||||
self.symbol = parameters.get('symbol', "DEFAULT_SYMBOL") # 策略操作的品种
|
||||
self.trade_volume = parameters.get('trade_volume', 100) # 每次下单的数量
|
||||
print(f"策略初始化: {self.__class__.__name__},参数: {self.parameters}")
|
||||
|
||||
@abstractmethod
|
||||
def on_bar(self, bar: Bar):
|
||||
"""
|
||||
每当接收到新的Bar数据时调用。
|
||||
具体策略逻辑在此方法中实现。
|
||||
|
||||
Args:
|
||||
bar (Bar): 当前的Bar数据对象。
|
||||
features (Dict[str, float]): 由数据处理模块计算并传入的特征字典。
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_init(self):
|
||||
"""
|
||||
策略初始化时调用(在回测开始前)。
|
||||
可用于设置初始状态或打印信息。
|
||||
"""
|
||||
print(f"{self.__class__.__name__} 策略初始化回调被调用。")
|
||||
|
||||
def on_trade(self, trade: Trade):
|
||||
"""
|
||||
当模拟器成功执行一笔交易时调用。
|
||||
可用于更新策略内部持仓状态或记录交易。
|
||||
|
||||
Args:
|
||||
trade (Trade): 已完成的交易记录。
|
||||
"""
|
||||
# print(f"策略接收到交易: {trade.direction} {trade.volume} {trade.symbol} @ {trade.price:.2f}")
|
||||
pass # 默认不执行任何操作,具体策略可覆盖
|
||||
|
||||
def on_order_status(self, order: Order, status: str):
|
||||
"""
|
||||
当订单状态更新时调用 (例如,未成交,已提交等)。
|
||||
在简易回测中,可能不会频繁使用。
|
||||
|
||||
Args:
|
||||
order (Order): 相关订单对象。
|
||||
status (str): 订单状态(例如 "FILLED", "PENDING", "CANCELLED")。
|
||||
"""
|
||||
pass # 默认不执行任何操作
|
||||
124
src/strategies/simple_limit_buy_strategy.py
Normal file
124
src/strategies/simple_limit_buy_strategy.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# src/strategies/simple_limit_buy_strategy.py (修改部分)
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
import pandas as pd
|
||||
|
||||
# 导入 Strategy 抽象基类
|
||||
from .base_strategy import Strategy
|
||||
# 导入核心数据类
|
||||
from ..core_data import Bar, Order, Trade
|
||||
|
||||
|
||||
class SimpleLimitBuyStrategy(Strategy):
|
||||
"""
|
||||
一个简单的限价买入策略:
|
||||
在每根Bar线上,如果当前没有持仓,且没有待处理的买入订单,则尝试下一个限价多单。
|
||||
如果在当前Bar之前有未成交的买入订单,则撤销该订单。
|
||||
确保在任意时间点,最多只有一笔限价买入订单在市场中。
|
||||
"""
|
||||
|
||||
def __init__(self, context: Any, **parameters: Any): # context 类型提示可以为 BacktestContext
|
||||
super().__init__(context, **parameters)
|
||||
self.order_id_counter = 0
|
||||
self.limit_price_factor = parameters.get('limit_price_factor', 0.999) # 限价因子
|
||||
self.max_position = parameters.get('max_position', self.trade_volume * 2) # 最大持仓量
|
||||
|
||||
self._last_order_id: Optional[str] = None # 跟踪上一根Bar发出的订单ID
|
||||
self._current_long_position: int = 0 # 策略内部维护的当前持仓
|
||||
|
||||
def on_init(self):
|
||||
super().on_init()
|
||||
# 确保初始状态正确
|
||||
self._last_order_id = None
|
||||
self._current_long_position = 0 # 或者从模拟器获取初始持仓
|
||||
|
||||
def on_trade(self, trade: Trade):
|
||||
"""
|
||||
当模拟器成功执行一笔交易时调用。
|
||||
更新策略内部持仓状态。
|
||||
"""
|
||||
super().on_trade(trade) # 调用父类方法
|
||||
# 简单起见,这里假设只交易self.symbol
|
||||
if trade.symbol == self.symbol:
|
||||
if trade.direction == "BUY":
|
||||
self._current_long_position += trade.volume
|
||||
elif trade.direction == "SELL": # 可能是平多或开空
|
||||
self._current_long_position -= trade.volume # 卖出量为正值,所以是减
|
||||
|
||||
# 如果成交的是我们之前提交的订单,清空_last_order_id
|
||||
if self._last_order_id == trade.order_id:
|
||||
self._last_order_id = None
|
||||
|
||||
# 打印当前持仓
|
||||
# print(f"[{trade.fill_time}] 策略内部持仓更新: {self.symbol} -> {self._current_long_position}")
|
||||
|
||||
def on_bar(self, bar: Bar):
|
||||
"""
|
||||
每接收到一根Bar时,执行策略逻辑。
|
||||
"""
|
||||
current_portfolio_value = self.context.get_current_portfolio_value(bar)
|
||||
print(f"[{bar.datetime}] Strategy processing Bar. Current close price: {bar.close:.2f}. Current Portfolio Value: {current_portfolio_value:.2f}")
|
||||
|
||||
# 1. 撤销上一根K线未成交的订单
|
||||
if self._last_order_id:
|
||||
# 检查这个订单是否仍然在待处理订单列表中
|
||||
pending_orders = self.context._simulator.get_pending_orders() # 直接访问模拟器,或者通过context提供接口
|
||||
if self._last_order_id in pending_orders:
|
||||
success = self.context.send_order(Order(id=self._last_order_id, symbol=self.symbol,
|
||||
direction="CANCEL", volume=0,
|
||||
price_type="CANCEL")) # 使用一个特殊Order类型表示撤单
|
||||
# 这里发送的“撤单订单”会被simulator的send_order处理,并调用simulator.cancel_order
|
||||
if success: # simulator.send_order返回Trade或None,这里我们用一个特殊处理
|
||||
# Simulator的send_order返回的是Trade,如果实现撤单,最好Simulator的cancel_order返回bool
|
||||
print(f"[{bar.datetime}] 策略: 成功撤销上一根K线未成交订单 {self._last_order_id}")
|
||||
else:
|
||||
print(f"[{bar.datetime}] 策略: 尝试撤销订单 {self._last_order_id} 失败(可能已成交或不存在)")
|
||||
|
||||
# 无论撤销成功与否,既然我们尝试了撤销,就清除记录
|
||||
self._last_order_id = None
|
||||
else:
|
||||
# 订单不在待处理列表中,说明它可能已经成交了 (在on_trade中已处理)
|
||||
# 或者在上一根K线已经被取消/过期
|
||||
self._last_order_id = None # 清理状态
|
||||
# print(f"[{bar.datetime}] 订单 {self._last_order_id} 不在待处理列表,无需撤销。")
|
||||
|
||||
# 2. 判断是否需要下单
|
||||
# 如果当前没有多头持仓,并且没有待处理的买入订单
|
||||
# (注: _last_order_id被清除后,_last_order_id为None表示当前没有待处理的我们发出的买单)
|
||||
current_positions = self.context.get_current_positions()
|
||||
self._current_long_position = current_positions.get(self.symbol, 0) # 从模拟器获取最新持仓
|
||||
|
||||
if self._current_long_position == 0 and self._last_order_id is None: # 确保只有一笔买单
|
||||
# 计算限价价格
|
||||
limit_price = bar.open * self.limit_price_factor
|
||||
trade_volume = self.trade_volume
|
||||
|
||||
# 生成唯一的订单ID
|
||||
order_id = f"{self.symbol}_BUY_{bar.datetime.strftime('%Y%m%d%H%M%S')}_{self.order_id_counter}"
|
||||
self.order_id_counter += 1
|
||||
|
||||
# 创建一个限价多单
|
||||
order = Order(
|
||||
id=order_id,
|
||||
symbol=self.symbol,
|
||||
direction="SELL",
|
||||
volume=trade_volume,
|
||||
price_type="LIMIT",
|
||||
limit_price=limit_price,
|
||||
submitted_time=bar.datetime
|
||||
)
|
||||
|
||||
# 通过上下文发送订单
|
||||
trade = self.context.send_order(order)
|
||||
if trade:
|
||||
print(
|
||||
f"[{bar.datetime}] 策略: 发送并立即成交限价买单 {trade.volume} 股 @ {trade.price:.2f} (订单ID: {order.id})")
|
||||
# 如果立即成交,_last_order_id 仍然保持 None
|
||||
else:
|
||||
# 如果未立即成交,将订单ID记录下来,以便下一根Bar撤销
|
||||
self._last_order_id = order.id
|
||||
print(
|
||||
f"[{bar.datetime}] 策略: 发送限价买单 {trade_volume} 股 @ {limit_price:.2f} (未成交,订单ID: {order.id} 已挂单)")
|
||||
else:
|
||||
# print(f"[{bar.datetime}] 策略: 当前已有持仓或有未成交订单,不重复下单。")
|
||||
pass
|
||||
Reference in New Issue
Block a user