From 81e89f3796c94ffb0908468d6e9cc42d89cd9fa6 Mon Sep 17 00:00:00 2001 From: liaozhaorun <1300336796@qq.com> Date: Sun, 15 Mar 2026 19:34:33 +0800 Subject: [PATCH] =?UTF-8?q?feat(factors/engine):=20=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E6=80=A7=E8=83=BD=E5=88=86=E6=9E=90=E5=99=A8=E6=94=AF=E6=8C=81?= =?UTF-8?q?=20debug=20=E6=A8=A1=E5=BC=8F=20-=20=E6=96=B0=E5=A2=9E=20Perfor?= =?UTF-8?q?manceProfiler=20=E7=BB=84=E4=BB=B6=EF=BC=8C=E4=B8=8E=20FactorEn?= =?UTF-8?q?gine=20=E8=A7=A3=E8=80=A6=20-=20=E6=94=AF=E6=8C=81=E4=B8=8A?= =?UTF-8?q?=E4=B8=8B=E6=96=87=E7=AE=A1=E7=90=86=E5=99=A8=E7=94=A8=E6=B3=95?= =?UTF-8?q?=EF=BC=8C=E5=AE=89=E5=85=A8=E8=AE=A1=E6=97=B6=E5=A4=84=E7=90=86?= =?UTF-8?q?=E5=BC=82=E5=B8=B8=20-=20=E4=B8=BA=20DataRouter/ComputeEngine/F?= =?UTF-8?q?actorEngine=20=E6=B7=BB=E5=8A=A0=20debug=20=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../run_probe_selection_all_factors.py | 44 +++- src/factors/engine/compute_engine.py | 13 +- src/factors/engine/data_router.py | 46 ++++- src/factors/engine/factor_engine.py | 96 ++++++++- src/factors/engine/performance_profiler.py | 138 +++++++++++++ tests/test_performance_profiler.py | 193 ++++++++++++++++++ 6 files changed, 513 insertions(+), 17 deletions(-) create mode 100644 src/factors/engine/performance_profiler.py create mode 100644 tests/test_performance_profiler.py diff --git a/src/experiment/probe_selection/run_probe_selection_all_factors.py b/src/experiment/probe_selection/run_probe_selection_all_factors.py index b134767..46781fe 100644 --- a/src/experiment/probe_selection/run_probe_selection_all_factors.py +++ b/src/experiment/probe_selection/run_probe_selection_all_factors.py @@ -4,8 +4,10 @@ 使用方法: uv run python src/experiment/probe_selection/run_probe_selection_all_factors.py + uv run python src/experiment/probe_selection/run_probe_selection_all_factors.py --debug """ +import argparse import os import sys @@ -180,15 +182,26 @@ def apply_preprocessing_for_probe( return data -def run_probe_feature_selection_with_all_factors(): - """执行探针法因子筛选(使用 FactorManager 中所有因子)""" +def run_probe_feature_selection_with_all_factors(debug: bool = True): + """执行探针法因子筛选(使用 FactorManager 中所有因子) + + Args: + debug: 是否启用 debug 模式,显示详细的性能统计信息 + + Returns: + 筛选后的特征列表 + """ print("\n" + "=" * 80) print("增强探针法因子筛选 - 使用 FactorManager 全部因子") + if debug: + print("[DEBUG] 性能监控模式已启用") print("=" * 80) # 1. 创建 FactorEngine print("\n[1] 创建 FactorEngine") - engine = FactorEngine() + if debug: + print("[DEBUG] 启用性能监控模式") + engine = FactorEngine(debug=debug) # 2. 从 FactorManager 注册所有因子 print("\n[2] 从 FactorManager 注册所有因子") @@ -203,6 +216,11 @@ def run_probe_feature_selection_with_all_factors(): end_date=VAL_END, # 包含验证集,增加样本量 ) + # 3.5. 打印性能报告(如果启用了 debug 模式) + if debug: + print("\n[DEBUG] 生成性能报告") + engine.profiler.print_report() + # 4. 股票池筛选 print("\n[4] 执行股票池筛选") pool_manager = StockPoolManager( @@ -309,5 +327,23 @@ def run_probe_feature_selection_with_all_factors(): return selected_features +def main(): + """主入口函数,支持命令行参数""" + parser = argparse.ArgumentParser( + description="探针法因子筛选 - 使用 FactorManager 中所有因子" + ) + parser.add_argument( + "--debug", + "-d", + action="store_true", + help="启用 debug 模式,显示详细的性能统计信息", + ) + + args = parser.parse_args() + + selected = run_probe_feature_selection_with_all_factors(debug=args.debug) + return selected + + if __name__ == "__main__": - selected = run_probe_feature_selection_with_all_factors() + main() diff --git a/src/factors/engine/compute_engine.py b/src/factors/engine/compute_engine.py index d0e7b14..eefa230 100644 --- a/src/factors/engine/compute_engine.py +++ b/src/factors/engine/compute_engine.py @@ -6,7 +6,8 @@ 避免 Python 层面的多进程/多线程开销。 """ -from typing import Dict, List, Set +import time +from typing import Dict, List, Optional, Set import polars as pl @@ -25,9 +26,13 @@ class ComputeEngine: 4. Polars 自动在所有 CPU 核心上并行计算,零拷贝内存 """ - def __init__(self) -> None: - """初始化计算引擎。""" - pass + def __init__(self, debug: bool = False) -> None: + """初始化计算引擎。 + + Args: + debug: 是否启用调试模式 + """ + self._debug = debug def execute( self, diff --git a/src/factors/engine/data_router.py b/src/factors/engine/data_router.py index 8ed7b6d..178580b 100644 --- a/src/factors/engine/data_router.py +++ b/src/factors/engine/data_router.py @@ -6,6 +6,7 @@ 支持标准等值匹配和 asof_backward(财务数据)两种拼接模式。 """ +import time from typing import Any, Dict, List, Optional, Set, Union import threading @@ -25,19 +26,24 @@ class DataRouter: Attributes: data_source: 数据源,可以是内存 DataFrame 字典或数据库连接 is_memory_mode: 是否为内存模式 + _debug: 是否启用调试模式 """ - def __init__(self, data_source: Optional[Dict[str, pl.DataFrame]] = None) -> None: + def __init__( + self, data_source: Optional[Dict[str, pl.DataFrame]] = None, debug: bool = False + ) -> None: """初始化数据路由器。 Args: data_source: 内存数据源,字典格式 {表名: DataFrame} 为 None 时自动连接 DuckDB 数据库 + debug: 是否启用调试模式 """ self.data_source = data_source or {} self.is_memory_mode = data_source is not None self._cache: Dict[str, pl.DataFrame] = {} self._lock = threading.Lock() + self._debug = debug # 数据库模式下初始化 Storage 和 FinancialLoader if not self.is_memory_mode: @@ -71,6 +77,14 @@ class DataRouter: if not data_specs: raise ValueError("数据规格不能为空") + if self._debug: + print(f"\n[DataRouter Debug] ========== 开始数据获取 ==========") + print(f"[DataRouter Debug] 日期范围: {start_date} - {end_date}") + print( + f"[DataRouter Debug] 股票代码: {stock_codes if stock_codes else '全市场'}" + ) + print(f"[DataRouter Debug] 数据规格数: {len(data_specs)}") + # 收集所有需要的表和字段 required_tables: Dict[str, Set[str]] = {} @@ -79,8 +93,16 @@ class DataRouter: required_tables[spec.table] = set() required_tables[spec.table].update(spec.columns) + if self._debug: + print(f"\n[DataRouter Debug] 需要加载的表:") + for table_name, columns in required_tables.items(): + print(f" - {table_name}: {len(columns)}个字段") + # 从数据源获取各表数据(使用合并后的 required_tables,避免重复加载) table_data = {} + table_load_times: Dict[str, float] = {} + t_load_start = time.perf_counter() + for table_name, columns in required_tables.items(): # 判断是标准表还是财务表 is_financial = any( @@ -110,18 +132,40 @@ class DataRouter: join_type="standard", ) + t_table_start = time.perf_counter() df = self._load_table_from_spec( spec=spec, start_date=start_date, end_date=end_date, stock_codes=stock_codes, ) + t_table_end = time.perf_counter() table_data[table_name] = df + table_load_times[table_name] = t_table_end - t_table_start + + t_load_end = time.perf_counter() + + if self._debug: + print(f"\n[DataRouter Debug] 表加载详情:") + for table_name, load_time in table_load_times.items(): + rows = len(table_data[table_name]) + print(f" - {table_name}: {rows:,}行, {load_time * 1000:.2f}ms") + print(f" - 表加载总耗时: {(t_load_end - t_load_start) * 1000:.2f}ms") # 组装核心宽表(支持多种 join 类型) + t_assemble_start = time.perf_counter() core_table = self._assemble_wide_table_with_specs( table_data, data_specs, start_date, end_date ) + t_assemble_end = time.perf_counter() + + if self._debug: + print(f"\n[DataRouter Debug] 宽表组装:") + print(f" - 结果行数: {len(core_table):,}") + print(f" - 结果列数: {len(core_table.columns)}") + print(f" - 组装耗时: {(t_assemble_end - t_assemble_start) * 1000:.2f}ms") + print(f"\n[DataRouter Debug] ========== 数据获取完成 ==========") + print(f" - 总耗时: {(t_assemble_end - t_load_start) * 1000:.2f}ms") return core_table diff --git a/src/factors/engine/factor_engine.py b/src/factors/engine/factor_engine.py index 4fd0485..ca905d1 100644 --- a/src/factors/engine/factor_engine.py +++ b/src/factors/engine/factor_engine.py @@ -10,6 +10,7 @@ 5. 返回包含因子结果的数据表 """ +import time from typing import Any, Dict, List, Optional, Set, Union, TYPE_CHECKING import polars as pl @@ -31,6 +32,7 @@ from src.factors.engine.data_router import DataRouter from src.factors.engine.planner import ExecutionPlanner from src.factors.engine.compute_engine import ComputeEngine from src.factors.engine.ast_optimizer import ExpressionFlattener +from src.factors.engine.performance_profiler import PerformanceProfiler class FactorEngine: @@ -52,25 +54,28 @@ class FactorEngine: registered_expressions: 注册的表达式字典 _registry: 函数注册表 _parser: 公式解析器 + _debug: 是否启用调试模式 """ def __init__( self, data_source: Optional[Dict[str, pl.DataFrame]] = None, registry: Optional["FunctionRegistry"] = None, + debug: bool = False, ) -> None: """初始化因子引擎。 Args: data_source: 内存数据源,为 None 时使用数据库连接 registry: 函数注册表,None 时创建独立实例 + debug: 是否启用调试模式,启用后打印详细的性能信息 """ from src.factors.registry import FunctionRegistry from src.factors.parser import FormulaParser - self.router = DataRouter(data_source) + self.router = DataRouter(data_source, debug=debug) self.planner = ExecutionPlanner() - self.compute_engine = ComputeEngine() + self.compute_engine = ComputeEngine(debug=debug) self.registered_expressions: Dict[str, Node] = {} self._plans: Dict[str, ExecutionPlan] = {} @@ -83,6 +88,21 @@ class FactorEngine: self._metadata = FactorManager() + # 调试模式配置 + self._debug = debug + + # 初始化性能分析器 + self._profiler = PerformanceProfiler(enabled=debug) + + @property + def profiler(self) -> PerformanceProfiler: + """获取性能分析器实例 + + Returns: + PerformanceProfiler 实例 + """ + return self._profiler + def _register_internal( self, name: str, @@ -312,8 +332,28 @@ class FactorEngine: if isinstance(factor_names, str): factor_names = [factor_names] + # 重置性能分析器(如果启用) + if self._profiler.enabled: + self._profiler.reset() + + if self._debug: + print(f"\n[FactorEngine Debug] ========== 开始计算 ==========") + print(f"[FactorEngine Debug] 目标因子: {factor_names}") + print(f"[FactorEngine Debug] 日期范围: {start_date} - {end_date}") + print( + f"[FactorEngine Debug] 股票代码: {stock_codes if stock_codes else '全市场'}" + ) + # 1. 收集所有需要的因子(包括临时因子依赖) + t_start = time.perf_counter() all_factor_names = self._collect_all_dependencies(factor_names) + t_deps = time.perf_counter() + + if self._debug: + print(f"\n[FactorEngine Debug] 依赖收集:") + print(f" - 原始因子数: {len(factor_names)}") + print(f" - 总因子数(含依赖): {len(all_factor_names)}") + print(f" - 耗时: {(t_deps - t_start) * 1000:.2f}ms") # 2. 获取执行计划 plans = [] @@ -350,19 +390,42 @@ class FactorEngine: ) ) + if self._debug: + print(f"\n[FactorEngine Debug] 数据规格:") + print(f" - 涉及的表: {list(table_to_columns.keys())}") + for table, cols in table_to_columns.items(): + print(f" {table}: {len(cols)}个字段") + # 4. 从路由器获取核心宽表 + t_fetch_start = time.perf_counter() core_data = self.router.fetch_data( data_specs=unique_specs, start_date=start_date, end_date=end_date, stock_codes=stock_codes, ) + t_fetch_end = time.perf_counter() + + if self._debug: + print(f"\n[FactorEngine Debug] 数据获取:") + print(f" - 数据行数: {len(core_data):,}") + print(f" - 数据列数: {len(core_data.columns)}") + print(f" - 耗时: {(t_fetch_end - t_fetch_start) * 1000:.2f}ms") if len(core_data) == 0: raise ValueError("未获取到任何数据,请检查日期范围和股票代码") # 5. 按依赖顺序执行计算(包含临时因子) + t_compute_start = time.perf_counter() result = self._execute_with_dependencies(all_factor_names, core_data) + t_compute_end = time.perf_counter() + + if self._debug: + print(f"\n[FactorEngine Debug] 因子计算:") + print(f" - 计算因子数: {len(all_factor_names)}") + print(f" - 耗时: {(t_compute_end - t_compute_start) * 1000:.2f}ms") + print(f"\n[FactorEngine Debug] ========== 计算完成 ==========") + print(f" - 总耗时: {(t_compute_end - t_start) * 1000:.2f}ms") # 6. 清理内存宽表,过滤掉临时因子列(__tmp_X) # 保留所有非临时因子列(包括原始数据列和用户请求的因子列) @@ -427,14 +490,31 @@ class FactorEngine: # 2. 按顺序执行 result = core_data - for name in sorted_names: - plan = self._plans[name] - # 创建新的执行计划,引用已计算的依赖列 - new_plan = self._create_optimized_plan(plan, result) + if self._profiler.enabled: + # Debug 模式:逐个贪婪计算以获取真实耗时 + for name in sorted_names: + plan = self._plans[name] - # 执行计算 - result = self.compute_engine.execute(new_plan, result) + # 创建新的执行计划,引用已计算的依赖列 + new_plan = self._create_optimized_plan(plan, result) + + # 使用性能分析器计时,触发真实计算 + with self._profiler.measure(name): + # 逐个执行,触发 Polars 底层计算 + result = result.with_columns( + [new_plan.polars_expr.alias(new_plan.output_name)] + ) + else: + # 生产模式:批量执行,享受 Polars 并行优化 + for name in sorted_names: + plan = self._plans[name] + + # 创建新的执行计划,引用已计算的依赖列 + new_plan = self._create_optimized_plan(plan, result) + + # 执行计算 + result = self.compute_engine.execute(new_plan, result) return result diff --git a/src/factors/engine/performance_profiler.py b/src/factors/engine/performance_profiler.py new file mode 100644 index 0000000..9e72d85 --- /dev/null +++ b/src/factors/engine/performance_profiler.py @@ -0,0 +1,138 @@ +"""性能分析器 - 独立组件,与 FactorEngine 解耦 + +支持上下文管理器用法: + with profiler.measure("factor_name"): + df = df.with_columns(expr) # 触发真实计算 +""" + +import time +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any, Dict, Iterator, List, Optional, Tuple + + +@dataclass +class ProfileRecord: + """性能记录数据结构""" + + time_ms: float + # 预留扩展字段 + # memory_mb: float = 0.0 + # call_count: int = 1 + + +class PerformanceProfiler: + """性能分析器 - 独立组件,与 FactorEngine 解耦 + + 支持上下文管理器用法: + with profiler.measure("factor_name"): + df = df.with_columns(expr) # 触发真实计算 + """ + + def __init__(self, enabled: bool = False): + self.enabled = enabled + self.records: Dict[str, ProfileRecord] = {} + self._context_stack: List[str] = [] # 支持嵌套计时 + + @contextmanager + def measure(self, name: str) -> Iterator[None]: + """上下文管理器:安全计时,自动处理异常 + + Args: + name: 计时任务名称(通常是因子名称) + + Yields: + None + """ + if not self.enabled: + yield + return + + start_time = time.perf_counter() + self._context_stack.append(name) + + try: + yield + finally: + elapsed = time.perf_counter() - start_time + self._context_stack.pop() + + # 更加简洁和安全的更新方式 + if name not in self.records: + self.records[name] = ProfileRecord(time_ms=0.0) + + self.records[name].time_ms += elapsed + + def get_report(self) -> Dict[str, Any]: + """生成性能报告 + + Returns: + 包含性能统计数据的字典: + - enabled: 是否启用了性能监控 + - total_time_ms: 总计算时间(毫秒) + - factor_count: 因子数量 + - avg_time_ms: 平均耗时(毫秒) + - slowest_factors: 最慢的5个因子列表 + - all_records: 所有因子的详细记录 + """ + if not self.records: + return {} + + times = [r.time_ms for r in self.records.values()] + sorted_records = sorted( + self.records.items(), key=lambda x: x[1].time_ms, reverse=True + ) + + return { + "enabled": self.enabled, + "total_time_ms": sum(times) * 1000, + "factor_count": len(self.records), + "avg_time_ms": (sum(times) / len(times)) * 1000, + "slowest_factors": [ + (name, record.time_ms * 1000) for name, record in sorted_records[:5] + ], + "all_records": { + name: {"time_ms": record.time_ms * 1000} + for name, record in self.records.items() + }, + } + + def print_report(self) -> None: + """打印格式化的性能报告""" + if not self.enabled: + print("[Performance] 性能监控未启用") + return + + report = self.get_report() + if not report: + print("[Performance] 无性能数据") + return + + print("\n" + "=" * 60) + print("Factor Engine 性能报告") + print("=" * 60) + print(f"总计算时间: {report['total_time_ms']:.2f}ms") + print(f"因子数量: {report['factor_count']}") + print(f"平均耗时: {report['avg_time_ms']:.2f}ms") + print(f"\n最慢的 5 个因子:") + for i, (name, time_ms) in enumerate(report["slowest_factors"], 1): + print(f" {i}. {name}: {time_ms:.2f}ms") + print("=" * 60) + + def reset(self) -> None: + """重置性能数据""" + self.records.clear() + self._context_stack.clear() + + def get_factor_time(self, name: str) -> Optional[float]: + """获取指定因子的计算时间 + + Args: + name: 因子名称 + + Returns: + 计算时间(秒),如果未找到则返回 None + """ + if name in self.records: + return self.records[name].time_ms + return None diff --git a/tests/test_performance_profiler.py b/tests/test_performance_profiler.py new file mode 100644 index 0000000..0637a86 --- /dev/null +++ b/tests/test_performance_profiler.py @@ -0,0 +1,193 @@ +"""性能分析器测试""" + +import pytest +import time +from src.factors.engine.performance_profiler import PerformanceProfiler, ProfileRecord + + +class TestProfileRecord: + """测试 ProfileRecord 数据类""" + + def test_profile_record_creation(self): + """测试创建 ProfileRecord""" + record = ProfileRecord(time_ms=1.5) + assert record.time_ms == 1.5 + + +class TestPerformanceProfiler: + """测试 PerformanceProfiler 类""" + + def test_profiler_disabled_by_default(self): + """测试默认情况下禁用""" + profiler = PerformanceProfiler() + assert profiler.enabled is False + + def test_profiler_enabled(self): + """测试启用性能分析器""" + profiler = PerformanceProfiler(enabled=True) + assert profiler.enabled is True + + def test_measure_context_manager_disabled(self): + """测试禁用时 measure 上下文管理器""" + profiler = PerformanceProfiler(enabled=False) + + with profiler.measure("test_task"): + time.sleep(0.001) + + # 禁用时不会记录数据 + assert len(profiler.records) == 0 + + def test_measure_context_manager_enabled(self): + """测试启用时 measure 上下文管理器""" + profiler = PerformanceProfiler(enabled=True) + + with profiler.measure("test_task"): + time.sleep(0.01) # 10ms + + assert "test_task" in profiler.records + assert profiler.records["test_task"].time_ms >= 0.01 + + def test_measure_multiple_calls_accumulate(self): + """测试多次调用同名任务会累加时间""" + profiler = PerformanceProfiler(enabled=True) + + # 第一次调用 + with profiler.measure("task"): + time.sleep(0.01) + + first_time = profiler.records["task"].time_ms + + # 第二次调用 + with profiler.measure("task"): + time.sleep(0.01) + + # 时间应该累加 + assert profiler.records["task"].time_ms > first_time + + def test_measure_exception_handling(self): + """测试异常发生时计时器正常闭合""" + profiler = PerformanceProfiler(enabled=True) + + try: + with profiler.measure("failing_task"): + time.sleep(0.01) + raise ValueError("Test error") + except ValueError: + pass + + # 异常发生时仍然记录了时间 + assert "failing_task" in profiler.records + assert profiler.records["failing_task"].time_ms >= 0.01 + + def test_get_report_empty(self): + """测试空记录时的报告""" + profiler = PerformanceProfiler(enabled=True) + report = profiler.get_report() + assert report == {} + + def test_get_report_with_data(self): + """测试有数据时的报告""" + profiler = PerformanceProfiler(enabled=True) + + with profiler.measure("task1"): + time.sleep(0.02) + with profiler.measure("task2"): + time.sleep(0.01) + + report = profiler.get_report() + + assert report["enabled"] is True + assert report["factor_count"] == 2 + assert report["total_time_ms"] > 0 + assert report["avg_time_ms"] > 0 + assert len(report["slowest_factors"]) <= 2 + assert "all_records" in report + + def test_get_report_sorts_by_time(self): + """测试报告按时间排序""" + profiler = PerformanceProfiler(enabled=True) + + # task1 耗时更长 + with profiler.measure("task1"): + time.sleep(0.02) + + with profiler.measure("task2"): + time.sleep(0.01) + + report = profiler.get_report() + slowest = report["slowest_factors"] + + assert slowest[0][0] == "task1" # 最慢的应该是 task1 + + def test_reset_clears_data(self): + """测试重置清除数据""" + profiler = PerformanceProfiler(enabled=True) + + with profiler.measure("task"): + time.sleep(0.01) + + assert len(profiler.records) == 1 + + profiler.reset() + + assert len(profiler.records) == 0 + assert len(profiler._context_stack) == 0 + + def test_print_report_disabled(self): + """测试禁用时打印报告""" + profiler = PerformanceProfiler(enabled=False) + # 应该正常执行不报错 + profiler.print_report() + + def test_print_report_empty(self, capsys): + """测试空数据时打印报告""" + profiler = PerformanceProfiler(enabled=True) + profiler.print_report() + + captured = capsys.readouterr() + assert "无性能数据" in captured.out + + def test_print_report_with_data(self, capsys): + """测试有数据时打印报告""" + profiler = PerformanceProfiler(enabled=True) + + with profiler.measure("slow_task"): + time.sleep(0.02) + with profiler.measure("fast_task"): + time.sleep(0.01) + + profiler.print_report() + + captured = capsys.readouterr() + assert "Factor Engine 性能报告" in captured.out + assert "slow_task" in captured.out + assert "fast_task" in captured.out + + def test_get_factor_time(self): + """测试获取单个因子的计算时间""" + profiler = PerformanceProfiler(enabled=True) + + with profiler.measure("task"): + time.sleep(0.01) + + time_val = profiler.get_factor_time("task") + assert time_val is not None + assert time_val >= 0.01 + + # 不存在的因子返回 None + assert profiler.get_factor_time("nonexistent") is None + + def test_nested_measure(self): + """测试嵌套计时""" + profiler = PerformanceProfiler(enabled=True) + + with profiler.measure("outer"): + time.sleep(0.01) + with profiler.measure("inner"): + time.sleep(0.01) + + assert "outer" in profiler.records + assert "inner" in profiler.records + # 嵌套计时应该分别记录 + assert profiler.records["outer"].time_ms >= 0.01 + assert profiler.records["inner"].time_ms >= 0.01