feat(factors/engine): 添加性能分析器支持 debug 模式

- 新增 PerformanceProfiler 组件,与 FactorEngine 解耦
- 支持上下文管理器用法,安全计时处理异常
- 为 DataRouter/ComputeEngine/FactorEngine 添加 debug 参数
This commit is contained in:
2026-03-15 19:34:33 +08:00
parent 7bf2699652
commit 81e89f3796
6 changed files with 513 additions and 17 deletions

View File

@@ -4,8 +4,10 @@
使用方法:
uv run python src/experiment/probe_selection/run_probe_selection_all_factors.py
uv run python src/experiment/probe_selection/run_probe_selection_all_factors.py --debug
"""
import argparse
import os
import sys
@@ -180,15 +182,26 @@ def apply_preprocessing_for_probe(
return data
def run_probe_feature_selection_with_all_factors():
"""执行探针法因子筛选(使用 FactorManager 中所有因子)"""
def run_probe_feature_selection_with_all_factors(debug: bool = True):
"""执行探针法因子筛选(使用 FactorManager 中所有因子)
Args:
debug: 是否启用 debug 模式,显示详细的性能统计信息
Returns:
筛选后的特征列表
"""
print("\n" + "=" * 80)
print("增强探针法因子筛选 - 使用 FactorManager 全部因子")
if debug:
print("[DEBUG] 性能监控模式已启用")
print("=" * 80)
# 1. 创建 FactorEngine
print("\n[1] 创建 FactorEngine")
engine = FactorEngine()
if debug:
print("[DEBUG] 启用性能监控模式")
engine = FactorEngine(debug=debug)
# 2. 从 FactorManager 注册所有因子
print("\n[2] 从 FactorManager 注册所有因子")
@@ -203,6 +216,11 @@ def run_probe_feature_selection_with_all_factors():
end_date=VAL_END, # 包含验证集,增加样本量
)
# 3.5. 打印性能报告(如果启用了 debug 模式)
if debug:
print("\n[DEBUG] 生成性能报告")
engine.profiler.print_report()
# 4. 股票池筛选
print("\n[4] 执行股票池筛选")
pool_manager = StockPoolManager(
@@ -309,5 +327,23 @@ def run_probe_feature_selection_with_all_factors():
return selected_features
def main():
"""主入口函数,支持命令行参数"""
parser = argparse.ArgumentParser(
description="探针法因子筛选 - 使用 FactorManager 中所有因子"
)
parser.add_argument(
"--debug",
"-d",
action="store_true",
help="启用 debug 模式,显示详细的性能统计信息",
)
args = parser.parse_args()
selected = run_probe_feature_selection_with_all_factors(debug=args.debug)
return selected
if __name__ == "__main__":
selected = run_probe_feature_selection_with_all_factors()
main()

View File

@@ -6,7 +6,8 @@
避免 Python 层面的多进程/多线程开销。
"""
from typing import Dict, List, Set
import time
from typing import Dict, List, Optional, Set
import polars as pl
@@ -25,9 +26,13 @@ class ComputeEngine:
4. Polars 自动在所有 CPU 核心上并行计算,零拷贝内存
"""
def __init__(self) -> None:
"""初始化计算引擎。"""
pass
def __init__(self, debug: bool = False) -> None:
"""初始化计算引擎。
Args:
debug: 是否启用调试模式
"""
self._debug = debug
def execute(
self,

View File

@@ -6,6 +6,7 @@
支持标准等值匹配和 asof_backward财务数据两种拼接模式。
"""
import time
from typing import Any, Dict, List, Optional, Set, Union
import threading
@@ -25,19 +26,24 @@ class DataRouter:
Attributes:
data_source: 数据源,可以是内存 DataFrame 字典或数据库连接
is_memory_mode: 是否为内存模式
_debug: 是否启用调试模式
"""
def __init__(self, data_source: Optional[Dict[str, pl.DataFrame]] = None) -> None:
def __init__(
self, data_source: Optional[Dict[str, pl.DataFrame]] = None, debug: bool = False
) -> None:
"""初始化数据路由器。
Args:
data_source: 内存数据源,字典格式 {表名: DataFrame}
为 None 时自动连接 DuckDB 数据库
debug: 是否启用调试模式
"""
self.data_source = data_source or {}
self.is_memory_mode = data_source is not None
self._cache: Dict[str, pl.DataFrame] = {}
self._lock = threading.Lock()
self._debug = debug
# 数据库模式下初始化 Storage 和 FinancialLoader
if not self.is_memory_mode:
@@ -71,6 +77,14 @@ class DataRouter:
if not data_specs:
raise ValueError("数据规格不能为空")
if self._debug:
print(f"\n[DataRouter Debug] ========== 开始数据获取 ==========")
print(f"[DataRouter Debug] 日期范围: {start_date} - {end_date}")
print(
f"[DataRouter Debug] 股票代码: {stock_codes if stock_codes else '全市场'}"
)
print(f"[DataRouter Debug] 数据规格数: {len(data_specs)}")
# 收集所有需要的表和字段
required_tables: Dict[str, Set[str]] = {}
@@ -79,8 +93,16 @@ class DataRouter:
required_tables[spec.table] = set()
required_tables[spec.table].update(spec.columns)
if self._debug:
print(f"\n[DataRouter Debug] 需要加载的表:")
for table_name, columns in required_tables.items():
print(f" - {table_name}: {len(columns)}个字段")
# 从数据源获取各表数据(使用合并后的 required_tables避免重复加载
table_data = {}
table_load_times: Dict[str, float] = {}
t_load_start = time.perf_counter()
for table_name, columns in required_tables.items():
# 判断是标准表还是财务表
is_financial = any(
@@ -110,18 +132,40 @@ class DataRouter:
join_type="standard",
)
t_table_start = time.perf_counter()
df = self._load_table_from_spec(
spec=spec,
start_date=start_date,
end_date=end_date,
stock_codes=stock_codes,
)
t_table_end = time.perf_counter()
table_data[table_name] = df
table_load_times[table_name] = t_table_end - t_table_start
t_load_end = time.perf_counter()
if self._debug:
print(f"\n[DataRouter Debug] 表加载详情:")
for table_name, load_time in table_load_times.items():
rows = len(table_data[table_name])
print(f" - {table_name}: {rows:,}行, {load_time * 1000:.2f}ms")
print(f" - 表加载总耗时: {(t_load_end - t_load_start) * 1000:.2f}ms")
# 组装核心宽表(支持多种 join 类型)
t_assemble_start = time.perf_counter()
core_table = self._assemble_wide_table_with_specs(
table_data, data_specs, start_date, end_date
)
t_assemble_end = time.perf_counter()
if self._debug:
print(f"\n[DataRouter Debug] 宽表组装:")
print(f" - 结果行数: {len(core_table):,}")
print(f" - 结果列数: {len(core_table.columns)}")
print(f" - 组装耗时: {(t_assemble_end - t_assemble_start) * 1000:.2f}ms")
print(f"\n[DataRouter Debug] ========== 数据获取完成 ==========")
print(f" - 总耗时: {(t_assemble_end - t_load_start) * 1000:.2f}ms")
return core_table

View File

@@ -10,6 +10,7 @@
5. 返回包含因子结果的数据表
"""
import time
from typing import Any, Dict, List, Optional, Set, Union, TYPE_CHECKING
import polars as pl
@@ -31,6 +32,7 @@ from src.factors.engine.data_router import DataRouter
from src.factors.engine.planner import ExecutionPlanner
from src.factors.engine.compute_engine import ComputeEngine
from src.factors.engine.ast_optimizer import ExpressionFlattener
from src.factors.engine.performance_profiler import PerformanceProfiler
class FactorEngine:
@@ -52,25 +54,28 @@ class FactorEngine:
registered_expressions: 注册的表达式字典
_registry: 函数注册表
_parser: 公式解析器
_debug: 是否启用调试模式
"""
def __init__(
self,
data_source: Optional[Dict[str, pl.DataFrame]] = None,
registry: Optional["FunctionRegistry"] = None,
debug: bool = False,
) -> None:
"""初始化因子引擎。
Args:
data_source: 内存数据源,为 None 时使用数据库连接
registry: 函数注册表None 时创建独立实例
debug: 是否启用调试模式,启用后打印详细的性能信息
"""
from src.factors.registry import FunctionRegistry
from src.factors.parser import FormulaParser
self.router = DataRouter(data_source)
self.router = DataRouter(data_source, debug=debug)
self.planner = ExecutionPlanner()
self.compute_engine = ComputeEngine()
self.compute_engine = ComputeEngine(debug=debug)
self.registered_expressions: Dict[str, Node] = {}
self._plans: Dict[str, ExecutionPlan] = {}
@@ -83,6 +88,21 @@ class FactorEngine:
self._metadata = FactorManager()
# 调试模式配置
self._debug = debug
# 初始化性能分析器
self._profiler = PerformanceProfiler(enabled=debug)
@property
def profiler(self) -> PerformanceProfiler:
"""获取性能分析器实例
Returns:
PerformanceProfiler 实例
"""
return self._profiler
def _register_internal(
self,
name: str,
@@ -312,8 +332,28 @@ class FactorEngine:
if isinstance(factor_names, str):
factor_names = [factor_names]
# 重置性能分析器(如果启用)
if self._profiler.enabled:
self._profiler.reset()
if self._debug:
print(f"\n[FactorEngine Debug] ========== 开始计算 ==========")
print(f"[FactorEngine Debug] 目标因子: {factor_names}")
print(f"[FactorEngine Debug] 日期范围: {start_date} - {end_date}")
print(
f"[FactorEngine Debug] 股票代码: {stock_codes if stock_codes else '全市场'}"
)
# 1. 收集所有需要的因子(包括临时因子依赖)
t_start = time.perf_counter()
all_factor_names = self._collect_all_dependencies(factor_names)
t_deps = time.perf_counter()
if self._debug:
print(f"\n[FactorEngine Debug] 依赖收集:")
print(f" - 原始因子数: {len(factor_names)}")
print(f" - 总因子数(含依赖): {len(all_factor_names)}")
print(f" - 耗时: {(t_deps - t_start) * 1000:.2f}ms")
# 2. 获取执行计划
plans = []
@@ -350,19 +390,42 @@ class FactorEngine:
)
)
if self._debug:
print(f"\n[FactorEngine Debug] 数据规格:")
print(f" - 涉及的表: {list(table_to_columns.keys())}")
for table, cols in table_to_columns.items():
print(f" {table}: {len(cols)}个字段")
# 4. 从路由器获取核心宽表
t_fetch_start = time.perf_counter()
core_data = self.router.fetch_data(
data_specs=unique_specs,
start_date=start_date,
end_date=end_date,
stock_codes=stock_codes,
)
t_fetch_end = time.perf_counter()
if self._debug:
print(f"\n[FactorEngine Debug] 数据获取:")
print(f" - 数据行数: {len(core_data):,}")
print(f" - 数据列数: {len(core_data.columns)}")
print(f" - 耗时: {(t_fetch_end - t_fetch_start) * 1000:.2f}ms")
if len(core_data) == 0:
raise ValueError("未获取到任何数据,请检查日期范围和股票代码")
# 5. 按依赖顺序执行计算(包含临时因子)
t_compute_start = time.perf_counter()
result = self._execute_with_dependencies(all_factor_names, core_data)
t_compute_end = time.perf_counter()
if self._debug:
print(f"\n[FactorEngine Debug] 因子计算:")
print(f" - 计算因子数: {len(all_factor_names)}")
print(f" - 耗时: {(t_compute_end - t_compute_start) * 1000:.2f}ms")
print(f"\n[FactorEngine Debug] ========== 计算完成 ==========")
print(f" - 总耗时: {(t_compute_end - t_start) * 1000:.2f}ms")
# 6. 清理内存宽表过滤掉临时因子列__tmp_X
# 保留所有非临时因子列(包括原始数据列和用户请求的因子列)
@@ -427,14 +490,31 @@ class FactorEngine:
# 2. 按顺序执行
result = core_data
for name in sorted_names:
plan = self._plans[name]
# 创建新的执行计划,引用已计算的依赖列
new_plan = self._create_optimized_plan(plan, result)
if self._profiler.enabled:
# Debug 模式:逐个贪婪计算以获取真实耗时
for name in sorted_names:
plan = self._plans[name]
# 执行计算
result = self.compute_engine.execute(new_plan, result)
# 创建新的执行计划,引用已计算的依赖列
new_plan = self._create_optimized_plan(plan, result)
# 使用性能分析器计时,触发真实计算
with self._profiler.measure(name):
# 逐个执行,触发 Polars 底层计算
result = result.with_columns(
[new_plan.polars_expr.alias(new_plan.output_name)]
)
else:
# 生产模式:批量执行,享受 Polars 并行优化
for name in sorted_names:
plan = self._plans[name]
# 创建新的执行计划,引用已计算的依赖列
new_plan = self._create_optimized_plan(plan, result)
# 执行计算
result = self.compute_engine.execute(new_plan, result)
return result

View File

@@ -0,0 +1,138 @@
"""性能分析器 - 独立组件,与 FactorEngine 解耦
支持上下文管理器用法:
with profiler.measure("factor_name"):
df = df.with_columns(expr) # 触发真实计算
"""
import time
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Optional, Tuple
@dataclass
class ProfileRecord:
"""性能记录数据结构"""
time_ms: float
# 预留扩展字段
# memory_mb: float = 0.0
# call_count: int = 1
class PerformanceProfiler:
"""性能分析器 - 独立组件,与 FactorEngine 解耦
支持上下文管理器用法:
with profiler.measure("factor_name"):
df = df.with_columns(expr) # 触发真实计算
"""
def __init__(self, enabled: bool = False):
self.enabled = enabled
self.records: Dict[str, ProfileRecord] = {}
self._context_stack: List[str] = [] # 支持嵌套计时
@contextmanager
def measure(self, name: str) -> Iterator[None]:
"""上下文管理器:安全计时,自动处理异常
Args:
name: 计时任务名称(通常是因子名称)
Yields:
None
"""
if not self.enabled:
yield
return
start_time = time.perf_counter()
self._context_stack.append(name)
try:
yield
finally:
elapsed = time.perf_counter() - start_time
self._context_stack.pop()
# 更加简洁和安全的更新方式
if name not in self.records:
self.records[name] = ProfileRecord(time_ms=0.0)
self.records[name].time_ms += elapsed
def get_report(self) -> Dict[str, Any]:
"""生成性能报告
Returns:
包含性能统计数据的字典:
- enabled: 是否启用了性能监控
- total_time_ms: 总计算时间(毫秒)
- factor_count: 因子数量
- avg_time_ms: 平均耗时(毫秒)
- slowest_factors: 最慢的5个因子列表
- all_records: 所有因子的详细记录
"""
if not self.records:
return {}
times = [r.time_ms for r in self.records.values()]
sorted_records = sorted(
self.records.items(), key=lambda x: x[1].time_ms, reverse=True
)
return {
"enabled": self.enabled,
"total_time_ms": sum(times) * 1000,
"factor_count": len(self.records),
"avg_time_ms": (sum(times) / len(times)) * 1000,
"slowest_factors": [
(name, record.time_ms * 1000) for name, record in sorted_records[:5]
],
"all_records": {
name: {"time_ms": record.time_ms * 1000}
for name, record in self.records.items()
},
}
def print_report(self) -> None:
"""打印格式化的性能报告"""
if not self.enabled:
print("[Performance] 性能监控未启用")
return
report = self.get_report()
if not report:
print("[Performance] 无性能数据")
return
print("\n" + "=" * 60)
print("Factor Engine 性能报告")
print("=" * 60)
print(f"总计算时间: {report['total_time_ms']:.2f}ms")
print(f"因子数量: {report['factor_count']}")
print(f"平均耗时: {report['avg_time_ms']:.2f}ms")
print(f"\n最慢的 5 个因子:")
for i, (name, time_ms) in enumerate(report["slowest_factors"], 1):
print(f" {i}. {name}: {time_ms:.2f}ms")
print("=" * 60)
def reset(self) -> None:
"""重置性能数据"""
self.records.clear()
self._context_stack.clear()
def get_factor_time(self, name: str) -> Optional[float]:
"""获取指定因子的计算时间
Args:
name: 因子名称
Returns:
计算时间(秒),如果未找到则返回 None
"""
if name in self.records:
return self.records[name].time_ms
return None

View File

@@ -0,0 +1,193 @@
"""性能分析器测试"""
import pytest
import time
from src.factors.engine.performance_profiler import PerformanceProfiler, ProfileRecord
class TestProfileRecord:
"""测试 ProfileRecord 数据类"""
def test_profile_record_creation(self):
"""测试创建 ProfileRecord"""
record = ProfileRecord(time_ms=1.5)
assert record.time_ms == 1.5
class TestPerformanceProfiler:
"""测试 PerformanceProfiler 类"""
def test_profiler_disabled_by_default(self):
"""测试默认情况下禁用"""
profiler = PerformanceProfiler()
assert profiler.enabled is False
def test_profiler_enabled(self):
"""测试启用性能分析器"""
profiler = PerformanceProfiler(enabled=True)
assert profiler.enabled is True
def test_measure_context_manager_disabled(self):
"""测试禁用时 measure 上下文管理器"""
profiler = PerformanceProfiler(enabled=False)
with profiler.measure("test_task"):
time.sleep(0.001)
# 禁用时不会记录数据
assert len(profiler.records) == 0
def test_measure_context_manager_enabled(self):
"""测试启用时 measure 上下文管理器"""
profiler = PerformanceProfiler(enabled=True)
with profiler.measure("test_task"):
time.sleep(0.01) # 10ms
assert "test_task" in profiler.records
assert profiler.records["test_task"].time_ms >= 0.01
def test_measure_multiple_calls_accumulate(self):
"""测试多次调用同名任务会累加时间"""
profiler = PerformanceProfiler(enabled=True)
# 第一次调用
with profiler.measure("task"):
time.sleep(0.01)
first_time = profiler.records["task"].time_ms
# 第二次调用
with profiler.measure("task"):
time.sleep(0.01)
# 时间应该累加
assert profiler.records["task"].time_ms > first_time
def test_measure_exception_handling(self):
"""测试异常发生时计时器正常闭合"""
profiler = PerformanceProfiler(enabled=True)
try:
with profiler.measure("failing_task"):
time.sleep(0.01)
raise ValueError("Test error")
except ValueError:
pass
# 异常发生时仍然记录了时间
assert "failing_task" in profiler.records
assert profiler.records["failing_task"].time_ms >= 0.01
def test_get_report_empty(self):
"""测试空记录时的报告"""
profiler = PerformanceProfiler(enabled=True)
report = profiler.get_report()
assert report == {}
def test_get_report_with_data(self):
"""测试有数据时的报告"""
profiler = PerformanceProfiler(enabled=True)
with profiler.measure("task1"):
time.sleep(0.02)
with profiler.measure("task2"):
time.sleep(0.01)
report = profiler.get_report()
assert report["enabled"] is True
assert report["factor_count"] == 2
assert report["total_time_ms"] > 0
assert report["avg_time_ms"] > 0
assert len(report["slowest_factors"]) <= 2
assert "all_records" in report
def test_get_report_sorts_by_time(self):
"""测试报告按时间排序"""
profiler = PerformanceProfiler(enabled=True)
# task1 耗时更长
with profiler.measure("task1"):
time.sleep(0.02)
with profiler.measure("task2"):
time.sleep(0.01)
report = profiler.get_report()
slowest = report["slowest_factors"]
assert slowest[0][0] == "task1" # 最慢的应该是 task1
def test_reset_clears_data(self):
"""测试重置清除数据"""
profiler = PerformanceProfiler(enabled=True)
with profiler.measure("task"):
time.sleep(0.01)
assert len(profiler.records) == 1
profiler.reset()
assert len(profiler.records) == 0
assert len(profiler._context_stack) == 0
def test_print_report_disabled(self):
"""测试禁用时打印报告"""
profiler = PerformanceProfiler(enabled=False)
# 应该正常执行不报错
profiler.print_report()
def test_print_report_empty(self, capsys):
"""测试空数据时打印报告"""
profiler = PerformanceProfiler(enabled=True)
profiler.print_report()
captured = capsys.readouterr()
assert "无性能数据" in captured.out
def test_print_report_with_data(self, capsys):
"""测试有数据时打印报告"""
profiler = PerformanceProfiler(enabled=True)
with profiler.measure("slow_task"):
time.sleep(0.02)
with profiler.measure("fast_task"):
time.sleep(0.01)
profiler.print_report()
captured = capsys.readouterr()
assert "Factor Engine 性能报告" in captured.out
assert "slow_task" in captured.out
assert "fast_task" in captured.out
def test_get_factor_time(self):
"""测试获取单个因子的计算时间"""
profiler = PerformanceProfiler(enabled=True)
with profiler.measure("task"):
time.sleep(0.01)
time_val = profiler.get_factor_time("task")
assert time_val is not None
assert time_val >= 0.01
# 不存在的因子返回 None
assert profiler.get_factor_time("nonexistent") is None
def test_nested_measure(self):
"""测试嵌套计时"""
profiler = PerformanceProfiler(enabled=True)
with profiler.measure("outer"):
time.sleep(0.01)
with profiler.measure("inner"):
time.sleep(0.01)
assert "outer" in profiler.records
assert "inner" in profiler.records
# 嵌套计时应该分别记录
assert profiler.records["outer"].time_ms >= 0.01
assert profiler.records["inner"].time_ms >= 0.01