78 lines
2.8 KiB
Python
78 lines
2.8 KiB
Python
# main.py
|
||
|
||
from src.analysis.result_analyzer import ResultAnalyzer
|
||
# 导入所有必要的模块
|
||
from src.data_manager import DataManager
|
||
from src.backtest_engine import BacktestEngine
|
||
from src.strategies.simple_limit_buy_strategy import SimpleLimitBuyStrategy
|
||
|
||
|
||
def main():
|
||
# --- 配置参数 ---
|
||
# 获取当前脚本所在目录,假设数据文件在项目根目录下的 data 文件夹内
|
||
data_file_path = '/mnt/d/PyProject/NewQuant/data/data/SHFE_rb2501/SHFE_rb2501_m60_20240901_20241201_min60.csv'
|
||
|
||
initial_capital = 100000.0
|
||
slippage_rate = 0.001 # 假设每笔交易0.1%的滑点
|
||
commission_rate = 0.0002 # 假设每笔交易0.02%的佣金
|
||
|
||
strategy_parameters = {
|
||
'symbol': "SHFE_rb2501", # 根据您的数据文件中的品种名称调整
|
||
'trade_volume': 1, # 每次交易1手/股
|
||
'limit_price_factor': 0.995, # 限价单价格为开盘价的99.5%
|
||
'max_position': 10 # 最大持仓10手/股
|
||
}
|
||
|
||
# --- 1. 初始化数据管理器 ---
|
||
print("初始化数据管理器...")
|
||
data_manager = DataManager(file_path=data_file_path)
|
||
# 确保 DataManager 能够重置以进行多次回测
|
||
# data_manager.reset() # 首次运行不需要重置
|
||
|
||
# --- 2. 初始化回测引擎并运行 ---
|
||
print("\n初始化回测引擎...")
|
||
engine = BacktestEngine(
|
||
data_manager=data_manager,
|
||
strategy_class=SimpleLimitBuyStrategy,
|
||
strategy_params=strategy_parameters,
|
||
initial_capital=initial_capital,
|
||
slippage_rate=slippage_rate,
|
||
commission_rate=commission_rate
|
||
)
|
||
|
||
print("\n开始运行回测...")
|
||
engine.run_backtest()
|
||
print("\n回测运行完毕。")
|
||
|
||
# --- 3. 获取回测结果 ---
|
||
results = engine.get_backtest_results()
|
||
portfolio_snapshots = results["portfolio_snapshots"]
|
||
trade_history = results["trade_history"]
|
||
initial_capital_result = results["initial_capital"]
|
||
|
||
# --- 4. 结果分析与可视化 ---
|
||
if portfolio_snapshots:
|
||
analyzer = ResultAnalyzer(portfolio_snapshots, trade_history, initial_capital_result)
|
||
|
||
analyzer.generate_report()
|
||
analyzer.plot_performance()
|
||
else:
|
||
print("\n没有生成投资组合快照,无法进行结果分析。")
|
||
|
||
# --- 4. 结果分析与可视化 (待实现) ---
|
||
# if portfolio_snapshots:
|
||
# analyzer = ResultAnalyzer(portfolio_snapshots, trade_history, initial_capital_result)
|
||
# metrics = analyzer.calculate_all_metrics()
|
||
# print("\n--- 绩效指标 ---")
|
||
# for key, value in metrics.items():
|
||
# print(f" {key}: {value:.4f}")
|
||
#
|
||
# print("\n--- 绘制绩效图表 ---")
|
||
# analyzer.plot_performance()
|
||
# else:
|
||
# print("\n没有生成投资组合快照,无法进行结果分析。")
|
||
|
||
|
||
if __name__ == '__main__':
|
||
main()
|