78 lines
2.8 KiB
Python
78 lines
2.8 KiB
Python
|
|
# main.py
|
|||
|
|
|
|||
|
|
from src.analysis.result_analyzer import ResultAnalyzer
|
|||
|
|
# 导入所有必要的模块
|
|||
|
|
from src.data_manager import DataManager
|
|||
|
|
from src.backtest_engine import BacktestEngine
|
|||
|
|
from src.strategies.simple_limit_buy_strategy import SimpleLimitBuyStrategy
|
|||
|
|
|
|||
|
|
|
|||
|
|
def main():
|
|||
|
|
# --- 配置参数 ---
|
|||
|
|
# 获取当前脚本所在目录,假设数据文件在项目根目录下的 data 文件夹内
|
|||
|
|
data_file_path = '/mnt/d/PyProject/NewQuant/data/data/SHFE_rb2501/SHFE_rb2501_m60_20240901_20241201_min60.csv'
|
|||
|
|
|
|||
|
|
initial_capital = 100000.0
|
|||
|
|
slippage_rate = 0.001 # 假设每笔交易0.1%的滑点
|
|||
|
|
commission_rate = 0.0002 # 假设每笔交易0.02%的佣金
|
|||
|
|
|
|||
|
|
strategy_parameters = {
|
|||
|
|
'symbol': "SHFE_rb2501", # 根据您的数据文件中的品种名称调整
|
|||
|
|
'trade_volume': 1, # 每次交易1手/股
|
|||
|
|
'limit_price_factor': 0.995, # 限价单价格为开盘价的99.5%
|
|||
|
|
'max_position': 10 # 最大持仓10手/股
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
# --- 1. 初始化数据管理器 ---
|
|||
|
|
print("初始化数据管理器...")
|
|||
|
|
data_manager = DataManager(file_path=data_file_path)
|
|||
|
|
# 确保 DataManager 能够重置以进行多次回测
|
|||
|
|
# data_manager.reset() # 首次运行不需要重置
|
|||
|
|
|
|||
|
|
# --- 2. 初始化回测引擎并运行 ---
|
|||
|
|
print("\n初始化回测引擎...")
|
|||
|
|
engine = BacktestEngine(
|
|||
|
|
data_manager=data_manager,
|
|||
|
|
strategy_class=SimpleLimitBuyStrategy,
|
|||
|
|
strategy_params=strategy_parameters,
|
|||
|
|
initial_capital=initial_capital,
|
|||
|
|
slippage_rate=slippage_rate,
|
|||
|
|
commission_rate=commission_rate
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
print("\n开始运行回测...")
|
|||
|
|
engine.run_backtest()
|
|||
|
|
print("\n回测运行完毕。")
|
|||
|
|
|
|||
|
|
# --- 3. 获取回测结果 ---
|
|||
|
|
results = engine.get_backtest_results()
|
|||
|
|
portfolio_snapshots = results["portfolio_snapshots"]
|
|||
|
|
trade_history = results["trade_history"]
|
|||
|
|
initial_capital_result = results["initial_capital"]
|
|||
|
|
|
|||
|
|
# --- 4. 结果分析与可视化 ---
|
|||
|
|
if portfolio_snapshots:
|
|||
|
|
analyzer = ResultAnalyzer(portfolio_snapshots, trade_history, initial_capital_result)
|
|||
|
|
|
|||
|
|
analyzer.generate_report()
|
|||
|
|
analyzer.plot_performance()
|
|||
|
|
else:
|
|||
|
|
print("\n没有生成投资组合快照,无法进行结果分析。")
|
|||
|
|
|
|||
|
|
# --- 4. 结果分析与可视化 (待实现) ---
|
|||
|
|
# if portfolio_snapshots:
|
|||
|
|
# analyzer = ResultAnalyzer(portfolio_snapshots, trade_history, initial_capital_result)
|
|||
|
|
# metrics = analyzer.calculate_all_metrics()
|
|||
|
|
# print("\n--- 绩效指标 ---")
|
|||
|
|
# for key, value in metrics.items():
|
|||
|
|
# print(f" {key}: {value:.4f}")
|
|||
|
|
#
|
|||
|
|
# print("\n--- 绘制绩效图表 ---")
|
|||
|
|
# analyzer.plot_performance()
|
|||
|
|
# else:
|
|||
|
|
# print("\n没有生成投资组合快照,无法进行结果分析。")
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == '__main__':
|
|||
|
|
main()
|