feat(factors/engine): 添加性能分析器支持 debug 模式
- 新增 PerformanceProfiler 组件,与 FactorEngine 解耦
- 支持上下文管理器用法:安全计时,自动处理异常
- 为 DataRouter/ComputeEngine/FactorEngine 添加 debug 参数
This commit is contained in:
@@ -4,8 +4,10 @@
|
||||
|
||||
使用方法:
|
||||
uv run python src/experiment/probe_selection/run_probe_selection_all_factors.py
|
||||
uv run python src/experiment/probe_selection/run_probe_selection_all_factors.py --debug
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
@@ -180,15 +182,26 @@ def apply_preprocessing_for_probe(
|
||||
return data
|
||||
|
||||
|
||||
def run_probe_feature_selection_with_all_factors():
|
||||
"""执行探针法因子筛选(使用 FactorManager 中所有因子)"""
|
||||
def run_probe_feature_selection_with_all_factors(debug: bool = True):
|
||||
"""执行探针法因子筛选(使用 FactorManager 中所有因子)
|
||||
|
||||
Args:
|
||||
debug: 是否启用 debug 模式,显示详细的性能统计信息
|
||||
|
||||
Returns:
|
||||
筛选后的特征列表
|
||||
"""
|
||||
print("\n" + "=" * 80)
|
||||
print("增强探针法因子筛选 - 使用 FactorManager 全部因子")
|
||||
if debug:
|
||||
print("[DEBUG] 性能监控模式已启用")
|
||||
print("=" * 80)
|
||||
|
||||
# 1. 创建 FactorEngine
|
||||
print("\n[1] 创建 FactorEngine")
|
||||
engine = FactorEngine()
|
||||
if debug:
|
||||
print("[DEBUG] 启用性能监控模式")
|
||||
engine = FactorEngine(debug=debug)
|
||||
|
||||
# 2. 从 FactorManager 注册所有因子
|
||||
print("\n[2] 从 FactorManager 注册所有因子")
|
||||
@@ -203,6 +216,11 @@ def run_probe_feature_selection_with_all_factors():
|
||||
end_date=VAL_END, # 包含验证集,增加样本量
|
||||
)
|
||||
|
||||
# 3.5. 打印性能报告(如果启用了 debug 模式)
|
||||
if debug:
|
||||
print("\n[DEBUG] 生成性能报告")
|
||||
engine.profiler.print_report()
|
||||
|
||||
# 4. 股票池筛选
|
||||
print("\n[4] 执行股票池筛选")
|
||||
pool_manager = StockPoolManager(
|
||||
@@ -309,5 +327,23 @@ def run_probe_feature_selection_with_all_factors():
|
||||
return selected_features
|
||||
|
||||
|
||||
def main():
    """Command-line entry point for the probe-based factor selection.

    Parses the optional ``--debug`` / ``-d`` switch and forwards it to
    ``run_probe_feature_selection_with_all_factors``.

    Returns:
        The value returned by ``run_probe_feature_selection_with_all_factors``.
    """
    cli = argparse.ArgumentParser(
        description="探针法因子筛选 - 使用 FactorManager 中所有因子"
    )
    # store_true: the flag defaults to False and flips to True when given.
    cli.add_argument(
        "--debug",
        "-d",
        action="store_true",
        help="启用 debug 模式,显示详细的性能统计信息",
    )
    options = cli.parse_args()

    return run_probe_feature_selection_with_all_factors(debug=options.debug)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Single entry point: main() parses --debug/-d and runs the selection
    # exactly once. Calling the pipeline function directly here as well would
    # run the whole selection a second time and bypass the CLI flag.
    main()
|
||||
|
||||
@@ -6,7 +6,8 @@
|
||||
避免 Python 层面的多进程/多线程开销。
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Set
|
||||
import time
|
||||
from typing import Dict, List, Optional, Set
|
||||
|
||||
import polars as pl
|
||||
|
||||
@@ -25,9 +26,13 @@ class ComputeEngine:
|
||||
4. Polars 自动在所有 CPU 核心上并行计算,零拷贝内存
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化计算引擎。"""
|
||||
pass
|
||||
def __init__(self, debug: bool = False) -> None:
    """Initialise the compute engine.

    Args:
        debug: Whether debug mode is enabled; stored on ``self._debug``.
    """
    self._debug = debug
|
||||
|
||||
def execute(
|
||||
self,
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
支持标准等值匹配和 asof_backward(财务数据)两种拼接模式。
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
import threading
|
||||
|
||||
@@ -25,19 +26,24 @@ class DataRouter:
|
||||
Attributes:
|
||||
data_source: 数据源,可以是内存 DataFrame 字典或数据库连接
|
||||
is_memory_mode: 是否为内存模式
|
||||
_debug: 是否启用调试模式
|
||||
"""
|
||||
|
||||
def __init__(self, data_source: Optional[Dict[str, pl.DataFrame]] = None) -> None:
|
||||
def __init__(
|
||||
self, data_source: Optional[Dict[str, pl.DataFrame]] = None, debug: bool = False
|
||||
) -> None:
|
||||
"""初始化数据路由器。
|
||||
|
||||
Args:
|
||||
data_source: 内存数据源,字典格式 {表名: DataFrame}
|
||||
为 None 时自动连接 DuckDB 数据库
|
||||
debug: 是否启用调试模式
|
||||
"""
|
||||
self.data_source = data_source or {}
|
||||
self.is_memory_mode = data_source is not None
|
||||
self._cache: Dict[str, pl.DataFrame] = {}
|
||||
self._lock = threading.Lock()
|
||||
self._debug = debug
|
||||
|
||||
# 数据库模式下初始化 Storage 和 FinancialLoader
|
||||
if not self.is_memory_mode:
|
||||
@@ -71,6 +77,14 @@ class DataRouter:
|
||||
if not data_specs:
|
||||
raise ValueError("数据规格不能为空")
|
||||
|
||||
if self._debug:
|
||||
print(f"\n[DataRouter Debug] ========== 开始数据获取 ==========")
|
||||
print(f"[DataRouter Debug] 日期范围: {start_date} - {end_date}")
|
||||
print(
|
||||
f"[DataRouter Debug] 股票代码: {stock_codes if stock_codes else '全市场'}"
|
||||
)
|
||||
print(f"[DataRouter Debug] 数据规格数: {len(data_specs)}")
|
||||
|
||||
# 收集所有需要的表和字段
|
||||
required_tables: Dict[str, Set[str]] = {}
|
||||
|
||||
@@ -79,8 +93,16 @@ class DataRouter:
|
||||
required_tables[spec.table] = set()
|
||||
required_tables[spec.table].update(spec.columns)
|
||||
|
||||
if self._debug:
|
||||
print(f"\n[DataRouter Debug] 需要加载的表:")
|
||||
for table_name, columns in required_tables.items():
|
||||
print(f" - {table_name}: {len(columns)}个字段")
|
||||
|
||||
# 从数据源获取各表数据(使用合并后的 required_tables,避免重复加载)
|
||||
table_data = {}
|
||||
table_load_times: Dict[str, float] = {}
|
||||
t_load_start = time.perf_counter()
|
||||
|
||||
for table_name, columns in required_tables.items():
|
||||
# 判断是标准表还是财务表
|
||||
is_financial = any(
|
||||
@@ -110,18 +132,40 @@ class DataRouter:
|
||||
join_type="standard",
|
||||
)
|
||||
|
||||
t_table_start = time.perf_counter()
|
||||
df = self._load_table_from_spec(
|
||||
spec=spec,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
stock_codes=stock_codes,
|
||||
)
|
||||
t_table_end = time.perf_counter()
|
||||
table_data[table_name] = df
|
||||
table_load_times[table_name] = t_table_end - t_table_start
|
||||
|
||||
t_load_end = time.perf_counter()
|
||||
|
||||
if self._debug:
|
||||
print(f"\n[DataRouter Debug] 表加载详情:")
|
||||
for table_name, load_time in table_load_times.items():
|
||||
rows = len(table_data[table_name])
|
||||
print(f" - {table_name}: {rows:,}行, {load_time * 1000:.2f}ms")
|
||||
print(f" - 表加载总耗时: {(t_load_end - t_load_start) * 1000:.2f}ms")
|
||||
|
||||
# 组装核心宽表(支持多种 join 类型)
|
||||
t_assemble_start = time.perf_counter()
|
||||
core_table = self._assemble_wide_table_with_specs(
|
||||
table_data, data_specs, start_date, end_date
|
||||
)
|
||||
t_assemble_end = time.perf_counter()
|
||||
|
||||
if self._debug:
|
||||
print(f"\n[DataRouter Debug] 宽表组装:")
|
||||
print(f" - 结果行数: {len(core_table):,}")
|
||||
print(f" - 结果列数: {len(core_table.columns)}")
|
||||
print(f" - 组装耗时: {(t_assemble_end - t_assemble_start) * 1000:.2f}ms")
|
||||
print(f"\n[DataRouter Debug] ========== 数据获取完成 ==========")
|
||||
print(f" - 总耗时: {(t_assemble_end - t_load_start) * 1000:.2f}ms")
|
||||
|
||||
return core_table
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
5. 返回包含因子结果的数据表
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Set, Union, TYPE_CHECKING
|
||||
|
||||
import polars as pl
|
||||
@@ -31,6 +32,7 @@ from src.factors.engine.data_router import DataRouter
|
||||
from src.factors.engine.planner import ExecutionPlanner
|
||||
from src.factors.engine.compute_engine import ComputeEngine
|
||||
from src.factors.engine.ast_optimizer import ExpressionFlattener
|
||||
from src.factors.engine.performance_profiler import PerformanceProfiler
|
||||
|
||||
|
||||
class FactorEngine:
|
||||
@@ -52,25 +54,28 @@ class FactorEngine:
|
||||
registered_expressions: 注册的表达式字典
|
||||
_registry: 函数注册表
|
||||
_parser: 公式解析器
|
||||
_debug: 是否启用调试模式
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_source: Optional[Dict[str, pl.DataFrame]] = None,
|
||||
registry: Optional["FunctionRegistry"] = None,
|
||||
debug: bool = False,
|
||||
) -> None:
|
||||
"""初始化因子引擎。
|
||||
|
||||
Args:
|
||||
data_source: 内存数据源,为 None 时使用数据库连接
|
||||
registry: 函数注册表,None 时创建独立实例
|
||||
debug: 是否启用调试模式,启用后打印详细的性能信息
|
||||
"""
|
||||
from src.factors.registry import FunctionRegistry
|
||||
from src.factors.parser import FormulaParser
|
||||
|
||||
self.router = DataRouter(data_source)
|
||||
self.router = DataRouter(data_source, debug=debug)
|
||||
self.planner = ExecutionPlanner()
|
||||
self.compute_engine = ComputeEngine()
|
||||
self.compute_engine = ComputeEngine(debug=debug)
|
||||
self.registered_expressions: Dict[str, Node] = {}
|
||||
self._plans: Dict[str, ExecutionPlan] = {}
|
||||
|
||||
@@ -83,6 +88,21 @@ class FactorEngine:
|
||||
|
||||
self._metadata = FactorManager()
|
||||
|
||||
# 调试模式配置
|
||||
self._debug = debug
|
||||
|
||||
# 初始化性能分析器
|
||||
self._profiler = PerformanceProfiler(enabled=debug)
|
||||
|
||||
@property
def profiler(self) -> PerformanceProfiler:
    """Return the performance profiler owned by this engine.

    Returns:
        The PerformanceProfiler instance stored in ``self._profiler``.
    """
    return self._profiler
|
||||
|
||||
def _register_internal(
|
||||
self,
|
||||
name: str,
|
||||
@@ -312,8 +332,28 @@ class FactorEngine:
|
||||
if isinstance(factor_names, str):
|
||||
factor_names = [factor_names]
|
||||
|
||||
# 重置性能分析器(如果启用)
|
||||
if self._profiler.enabled:
|
||||
self._profiler.reset()
|
||||
|
||||
if self._debug:
|
||||
print(f"\n[FactorEngine Debug] ========== 开始计算 ==========")
|
||||
print(f"[FactorEngine Debug] 目标因子: {factor_names}")
|
||||
print(f"[FactorEngine Debug] 日期范围: {start_date} - {end_date}")
|
||||
print(
|
||||
f"[FactorEngine Debug] 股票代码: {stock_codes if stock_codes else '全市场'}"
|
||||
)
|
||||
|
||||
# 1. 收集所有需要的因子(包括临时因子依赖)
|
||||
t_start = time.perf_counter()
|
||||
all_factor_names = self._collect_all_dependencies(factor_names)
|
||||
t_deps = time.perf_counter()
|
||||
|
||||
if self._debug:
|
||||
print(f"\n[FactorEngine Debug] 依赖收集:")
|
||||
print(f" - 原始因子数: {len(factor_names)}")
|
||||
print(f" - 总因子数(含依赖): {len(all_factor_names)}")
|
||||
print(f" - 耗时: {(t_deps - t_start) * 1000:.2f}ms")
|
||||
|
||||
# 2. 获取执行计划
|
||||
plans = []
|
||||
@@ -350,19 +390,42 @@ class FactorEngine:
|
||||
)
|
||||
)
|
||||
|
||||
if self._debug:
|
||||
print(f"\n[FactorEngine Debug] 数据规格:")
|
||||
print(f" - 涉及的表: {list(table_to_columns.keys())}")
|
||||
for table, cols in table_to_columns.items():
|
||||
print(f" {table}: {len(cols)}个字段")
|
||||
|
||||
# 4. 从路由器获取核心宽表
|
||||
t_fetch_start = time.perf_counter()
|
||||
core_data = self.router.fetch_data(
|
||||
data_specs=unique_specs,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
stock_codes=stock_codes,
|
||||
)
|
||||
t_fetch_end = time.perf_counter()
|
||||
|
||||
if self._debug:
|
||||
print(f"\n[FactorEngine Debug] 数据获取:")
|
||||
print(f" - 数据行数: {len(core_data):,}")
|
||||
print(f" - 数据列数: {len(core_data.columns)}")
|
||||
print(f" - 耗时: {(t_fetch_end - t_fetch_start) * 1000:.2f}ms")
|
||||
|
||||
if len(core_data) == 0:
|
||||
raise ValueError("未获取到任何数据,请检查日期范围和股票代码")
|
||||
|
||||
# 5. 按依赖顺序执行计算(包含临时因子)
|
||||
t_compute_start = time.perf_counter()
|
||||
result = self._execute_with_dependencies(all_factor_names, core_data)
|
||||
t_compute_end = time.perf_counter()
|
||||
|
||||
if self._debug:
|
||||
print(f"\n[FactorEngine Debug] 因子计算:")
|
||||
print(f" - 计算因子数: {len(all_factor_names)}")
|
||||
print(f" - 耗时: {(t_compute_end - t_compute_start) * 1000:.2f}ms")
|
||||
print(f"\n[FactorEngine Debug] ========== 计算完成 ==========")
|
||||
print(f" - 总耗时: {(t_compute_end - t_start) * 1000:.2f}ms")
|
||||
|
||||
# 6. 清理内存宽表,过滤掉临时因子列(__tmp_X)
|
||||
# 保留所有非临时因子列(包括原始数据列和用户请求的因子列)
|
||||
@@ -427,14 +490,31 @@ class FactorEngine:
|
||||
|
||||
# 2. 按顺序执行
|
||||
result = core_data
|
||||
for name in sorted_names:
|
||||
plan = self._plans[name]
|
||||
|
||||
# 创建新的执行计划,引用已计算的依赖列
|
||||
new_plan = self._create_optimized_plan(plan, result)
|
||||
if self._profiler.enabled:
|
||||
# Debug 模式:逐个贪婪计算以获取真实耗时
|
||||
for name in sorted_names:
|
||||
plan = self._plans[name]
|
||||
|
||||
# 执行计算
|
||||
result = self.compute_engine.execute(new_plan, result)
|
||||
# 创建新的执行计划,引用已计算的依赖列
|
||||
new_plan = self._create_optimized_plan(plan, result)
|
||||
|
||||
# 使用性能分析器计时,触发真实计算
|
||||
with self._profiler.measure(name):
|
||||
# 逐个执行,触发 Polars 底层计算
|
||||
result = result.with_columns(
|
||||
[new_plan.polars_expr.alias(new_plan.output_name)]
|
||||
)
|
||||
else:
|
||||
# 生产模式:批量执行,享受 Polars 并行优化
|
||||
for name in sorted_names:
|
||||
plan = self._plans[name]
|
||||
|
||||
# 创建新的执行计划,引用已计算的依赖列
|
||||
new_plan = self._create_optimized_plan(plan, result)
|
||||
|
||||
# 执行计算
|
||||
result = self.compute_engine.execute(new_plan, result)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
138
src/factors/engine/performance_profiler.py
Normal file
138
src/factors/engine/performance_profiler.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""性能分析器 - 独立组件,与 FactorEngine 解耦
|
||||
|
||||
支持上下文管理器用法:
|
||||
with profiler.measure("factor_name"):
|
||||
df = df.with_columns(expr) # 触发真实计算
|
||||
"""
|
||||
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple
|
||||
|
||||
|
||||
@dataclass
class ProfileRecord:
    """Accumulated timing for one named task."""

    # NOTE: despite the ``_ms`` suffix, PerformanceProfiler.measure adds raw
    # ``time.perf_counter()`` differences to this field, so the stored value
    # is in seconds. PerformanceProfiler.get_report converts to milliseconds.
    time_ms: float
    # Reserved for future extension (not implemented):
    # memory_mb: float = 0.0
    # call_count: int = 1


class PerformanceProfiler:
    """Standalone wall-clock profiler, decoupled from FactorEngine.

    Timings are collected through a context manager::

        with profiler.measure("factor_name"):
            df = df.with_columns(expr)  # the work being timed

    Unit caveat: ``ProfileRecord.time_ms`` holds seconds (see ``measure``).
    ``get_report`` multiplies by 1000 to report milliseconds, while
    ``get_factor_time`` returns the stored seconds unchanged.
    """

    def __init__(self, enabled: bool = False):
        """Create a profiler.

        Args:
            enabled: When False, ``measure`` records nothing.
        """
        self.enabled = enabled
        self.records: Dict[str, ProfileRecord] = {}
        # Names of the measure() blocks that are currently open (innermost last).
        self._context_stack: List[str] = []

    @contextmanager
    def measure(self, name: str) -> Iterator[None]:
        """Time the enclosed block and add the elapsed time to ``records[name]``.

        The elapsed time is recorded even when the block raises; the exception
        then propagates unchanged. Repeated measurements under the same name
        accumulate. With nested blocks, the outer measurement includes the
        time spent in the inner one.

        Args:
            name: Label for the timed task (typically a factor name).

        Yields:
            None
        """
        if not self.enabled:
            yield
            return

        start_time = time.perf_counter()
        self._context_stack.append(name)

        try:
            yield
        finally:
            elapsed = time.perf_counter() - start_time

            # reset() may have emptied the stack while this block was open.
            # An unconditional pop() would raise IndexError here and mask any
            # exception propagating out of the timed block.
            if self._context_stack:
                self._context_stack.pop()

            record = self.records.get(name)
            if record is None:
                record = self.records[name] = ProfileRecord(time_ms=0.0)
            record.time_ms += elapsed  # seconds, see class docstring

    def get_report(self) -> Dict[str, Any]:
        """Build a summary of all recorded timings.

        Returns:
            An empty dict when nothing has been recorded. Otherwise a dict with:
                - enabled: whether the profiler is enabled
                - total_time_ms: sum of all recorded times, in milliseconds
                - factor_count: number of distinct recorded names
                - avg_time_ms: mean time per recorded name, in milliseconds
                - slowest_factors: up to five ``(name, milliseconds)`` tuples,
                  slowest first
                - all_records: ``{name: {"time_ms": milliseconds}}`` for every
                  recorded name
        """
        if not self.records:
            return {}

        total_seconds = sum(record.time_ms for record in self.records.values())
        ranked = sorted(
            self.records.items(), key=lambda item: item[1].time_ms, reverse=True
        )

        return {
            "enabled": self.enabled,
            "total_time_ms": total_seconds * 1000,
            "factor_count": len(self.records),
            "avg_time_ms": (total_seconds / len(self.records)) * 1000,
            "slowest_factors": [
                (name, record.time_ms * 1000) for name, record in ranked[:5]
            ],
            "all_records": {
                name: {"time_ms": record.time_ms * 1000}
                for name, record in self.records.items()
            },
        }

    def print_report(self) -> None:
        """Print a formatted report, or a one-line notice when disabled or empty."""
        if not self.enabled:
            print("[Performance] 性能监控未启用")
            return

        report = self.get_report()
        if not report:
            print("[Performance] 无性能数据")
            return

        print("\n" + "=" * 60)
        print("Factor Engine 性能报告")
        print("=" * 60)
        print(f"总计算时间: {report['total_time_ms']:.2f}ms")
        print(f"因子数量: {report['factor_count']}")
        print(f"平均耗时: {report['avg_time_ms']:.2f}ms")
        print("\n最慢的 5 个因子:")
        for i, (name, time_ms) in enumerate(report["slowest_factors"], 1):
            print(f"  {i}. {name}: {time_ms:.2f}ms")
        print("=" * 60)

    def reset(self) -> None:
        """Discard all recorded timings and the open-block stack."""
        self.records.clear()
        self._context_stack.clear()

    def get_factor_time(self, name: str) -> Optional[float]:
        """Return the accumulated time recorded under ``name``.

        Args:
            name: Label passed to ``measure``.

        Returns:
            The accumulated time in seconds, or None when ``name`` has no record.
        """
        record = self.records.get(name)
        return None if record is None else record.time_ms
|
||||
193
tests/test_performance_profiler.py
Normal file
193
tests/test_performance_profiler.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""性能分析器测试"""
|
||||
|
||||
import pytest
|
||||
import time
|
||||
from src.factors.engine.performance_profiler import PerformanceProfiler, ProfileRecord
|
||||
|
||||
|
||||
class TestProfileRecord:
    """Tests for the ProfileRecord dataclass."""

    def test_profile_record_creation(self):
        """A record exposes the value it was constructed with."""
        expected = 1.5
        created = ProfileRecord(time_ms=expected)
        assert created.time_ms == expected
|
||||
|
||||
|
||||
class TestPerformanceProfiler:
    """Tests for the PerformanceProfiler class.

    Recorded values are compared against second-scale thresholds (e.g. 0.01)
    because the profiler stores ``time.perf_counter()`` differences directly.
    """

    def test_profiler_disabled_by_default(self):
        """A profiler built without arguments is disabled."""
        profiler = PerformanceProfiler()
        assert profiler.enabled is False

    def test_profiler_enabled(self):
        """``enabled=True`` turns the profiler on."""
        profiler = PerformanceProfiler(enabled=True)
        assert profiler.enabled is True

    def test_measure_context_manager_disabled(self):
        """measure() records nothing while the profiler is disabled."""
        profiler = PerformanceProfiler(enabled=False)

        with profiler.measure("test_task"):
            time.sleep(0.001)

        assert len(profiler.records) == 0

    def test_measure_context_manager_enabled(self):
        """measure() records the elapsed time under the given name."""
        profiler = PerformanceProfiler(enabled=True)

        with profiler.measure("test_task"):
            time.sleep(0.01)  # 10ms

        assert "test_task" in profiler.records
        assert profiler.records["test_task"].time_ms >= 0.01

    def test_measure_multiple_calls_accumulate(self):
        """Measuring the same name twice accumulates the time."""
        profiler = PerformanceProfiler(enabled=True)

        with profiler.measure("task"):
            time.sleep(0.01)

        first_time = profiler.records["task"].time_ms

        with profiler.measure("task"):
            time.sleep(0.01)

        assert profiler.records["task"].time_ms > first_time

    def test_measure_exception_handling(self):
        """An exception inside the block propagates and the time is still recorded."""
        profiler = PerformanceProfiler(enabled=True)

        # pytest.raises fails the test if measure() swallows the error, which a
        # bare try/except/pass would silently accept.
        with pytest.raises(ValueError, match="Test error"):
            with profiler.measure("failing_task"):
                time.sleep(0.01)
                raise ValueError("Test error")

        assert "failing_task" in profiler.records
        assert profiler.records["failing_task"].time_ms >= 0.01

    def test_get_report_empty(self):
        """With no records the report is an empty dict."""
        profiler = PerformanceProfiler(enabled=True)
        report = profiler.get_report()
        assert report == {}

    def test_get_report_with_data(self):
        """The report summarises every recorded task."""
        profiler = PerformanceProfiler(enabled=True)

        with profiler.measure("task1"):
            time.sleep(0.02)
        with profiler.measure("task2"):
            time.sleep(0.01)

        report = profiler.get_report()

        assert report["enabled"] is True
        assert report["factor_count"] == 2
        assert report["total_time_ms"] > 0
        assert report["avg_time_ms"] > 0
        # Two records and a top-5 cut-off: both must be listed.
        assert len(report["slowest_factors"]) == 2
        assert set(report["all_records"]) == {"task1", "task2"}

    def test_get_report_sorts_by_time(self):
        """slowest_factors is ordered slowest first."""
        profiler = PerformanceProfiler(enabled=True)

        # task1 sleeps longer than task2
        with profiler.measure("task1"):
            time.sleep(0.02)

        with profiler.measure("task2"):
            time.sleep(0.01)

        report = profiler.get_report()
        slowest = report["slowest_factors"]

        assert slowest[0][0] == "task1"

    def test_reset_clears_data(self):
        """reset() empties both the records and the open-block stack."""
        profiler = PerformanceProfiler(enabled=True)

        with profiler.measure("task"):
            time.sleep(0.01)

        assert len(profiler.records) == 1

        profiler.reset()

        assert len(profiler.records) == 0
        assert len(profiler._context_stack) == 0

    def test_print_report_disabled(self, capsys):
        """A disabled profiler prints the 'not enabled' notice instead of a report."""
        profiler = PerformanceProfiler(enabled=False)
        profiler.print_report()

        captured = capsys.readouterr()
        assert "性能监控未启用" in captured.out

    def test_print_report_empty(self, capsys):
        """An enabled profiler without records prints the 'no data' notice."""
        profiler = PerformanceProfiler(enabled=True)
        profiler.print_report()

        captured = capsys.readouterr()
        assert "无性能数据" in captured.out

    def test_print_report_with_data(self, capsys):
        """The printed report contains the header and every task name."""
        profiler = PerformanceProfiler(enabled=True)

        with profiler.measure("slow_task"):
            time.sleep(0.02)
        with profiler.measure("fast_task"):
            time.sleep(0.01)

        profiler.print_report()

        captured = capsys.readouterr()
        assert "Factor Engine 性能报告" in captured.out
        assert "slow_task" in captured.out
        assert "fast_task" in captured.out

    def test_get_factor_time(self):
        """get_factor_time returns the recorded time, or None for unknown names."""
        profiler = PerformanceProfiler(enabled=True)

        with profiler.measure("task"):
            time.sleep(0.01)

        time_val = profiler.get_factor_time("task")
        assert time_val is not None
        assert time_val >= 0.01

        assert profiler.get_factor_time("nonexistent") is None

    def test_nested_measure(self):
        """Nested blocks are recorded separately; the outer one covers the inner one."""
        profiler = PerformanceProfiler(enabled=True)

        with profiler.measure("outer"):
            time.sleep(0.01)
            with profiler.measure("inner"):
                time.sleep(0.01)

        assert "outer" in profiler.records
        assert "inner" in profiler.records
        assert profiler.records["outer"].time_ms >= 0.01
        assert profiler.records["inner"].time_ms >= 0.01
        # The outer block was open for the whole duration of the inner block.
        assert (
            profiler.records["outer"].time_ms >= profiler.records["inner"].time_ms
        )
|
||||
Reference in New Issue
Block a user