Files
NewQuant/main.py

78 lines
2.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# main.py
from src.analysis.result_analyzer import ResultAnalyzer
# 导入所有必要的模块
from src.data_manager import DataManager
from src.backtest_engine import BacktestEngine
from src.strategies.simple_limit_buy_strategy import SimpleLimitBuyStrategy
def main():
# --- 配置参数 ---
# 获取当前脚本所在目录,假设数据文件在项目根目录下的 data 文件夹内
data_file_path = '/mnt/d/PyProject/NewQuant/data/data/SHFE_rb2501/SHFE_rb2501_m60_20240901_20241201_min60.csv'
initial_capital = 100000.0
slippage_rate = 0.001 # 假设每笔交易0.1%的滑点
commission_rate = 0.0002 # 假设每笔交易0.02%的佣金
strategy_parameters = {
'symbol': "SHFE_rb2501", # 根据您的数据文件中的品种名称调整
'trade_volume': 1, # 每次交易1手/股
'limit_price_factor': 0.995, # 限价单价格为开盘价的99.5%
'max_position': 10 # 最大持仓10手/股
}
# --- 1. 初始化数据管理器 ---
print("初始化数据管理器...")
data_manager = DataManager(file_path=data_file_path)
# 确保 DataManager 能够重置以进行多次回测
# data_manager.reset() # 首次运行不需要重置
# --- 2. 初始化回测引擎并运行 ---
print("\n初始化回测引擎...")
engine = BacktestEngine(
data_manager=data_manager,
strategy_class=SimpleLimitBuyStrategy,
strategy_params=strategy_parameters,
initial_capital=initial_capital,
slippage_rate=slippage_rate,
commission_rate=commission_rate
)
print("\n开始运行回测...")
engine.run_backtest()
print("\n回测运行完毕。")
# --- 3. 获取回测结果 ---
results = engine.get_backtest_results()
portfolio_snapshots = results["portfolio_snapshots"]
trade_history = results["trade_history"]
initial_capital_result = results["initial_capital"]
# --- 4. 结果分析与可视化 ---
if portfolio_snapshots:
analyzer = ResultAnalyzer(portfolio_snapshots, trade_history, initial_capital_result)
analyzer.generate_report()
analyzer.plot_performance()
else:
print("\n没有生成投资组合快照,无法进行结果分析。")
# --- 4. 结果分析与可视化 (待实现) ---
# if portfolio_snapshots:
# analyzer = ResultAnalyzer(portfolio_snapshots, trade_history, initial_capital_result)
# metrics = analyzer.calculate_all_metrics()
# print("\n--- 绩效指标 ---")
# for key, value in metrics.items():
# print(f" {key}: {value:.4f}")
#
# print("\n--- 绘制绩效图表 ---")
# analyzer.plot_performance()
# else:
# print("\n没有生成投资组合快照无法进行结果分析。")
if __name__ == '__main__':
main()