From 12ddb19b2e0440e87c761828e0a712e7afac9fb3 Mon Sep 17 00:00:00 2001
From: liaozhaorun <1300336796@qq.com>
Date: Tue, 3 Mar 2026 17:32:58 +0800
Subject: [PATCH] =?UTF-8?q?feat(factors):=20=E6=B7=BB=E5=8A=A0=20SchemaCac?=
 =?UTF-8?q?he=20=E5=AE=9E=E7=8E=B0=E6=95=B0=E6=8D=AE=E5=BA=93=E8=A1=A8?=
 =?UTF-8?q?=E7=BB=93=E6=9E=84=E8=87=AA=E5=8A=A8=E6=89=AB=E6=8F=8F?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
Reviewer note: schema_cache.py was revised during review (the blob id on its
"index" line still refers to the pre-review content):
 * caches are stored on the class, so reset_cache() really clears them
 * the column query is restricted to the 'main' schema
 * an empty or failed scan is not cached, so a later call retries
 * table and column order of match_fields_to_tables() is deterministic

 src/factors/engine/planner.py      |  71 ++-------
 src/factors/engine/schema_cache.py | 264 +++++++++++++++++++++++++++++
 2 files changed, 273 insertions(+), 62 deletions(-)
 create mode 100644 src/factors/engine/schema_cache.py

diff --git a/src/factors/engine/planner.py b/src/factors/engine/planner.py
index 75a8bb7..05c756e 100644
--- a/src/factors/engine/planner.py
+++ b/src/factors/engine/planner.py
@@ -16,6 +16,7 @@ from src.factors.dsl import (
 from src.factors.compiler import DependencyExtractor
 from src.factors.translator import PolarsTranslator
 from src.factors.engine.data_spec import DataSpec, ExecutionPlan
+from src.factors.engine.schema_cache import get_schema_cache
 
 
 class ExecutionPlanner:
@@ -73,9 +74,8 @@ class ExecutionPlanner:
     ) -> List[DataSpec]:
         """从依赖推导数据规格。
 
-        基础行情字段(open, high, low, close, vol, amount, pre_close, change, pct_chg)
-        默认从 pro_bar 表获取。
-        每日指标字段(total_mv, circ_mv, pe, pb 等)从 daily_basic 表获取。
+        Uses SchemaCache to scan the database schema and map each field to its table.
+        A successful scan is cached in memory and reused by later calls.
 
         Args:
             dependencies: 依赖的字段集合
@@ -84,69 +84,16 @@ class ExecutionPlanner:
         Returns:
             数据规格列表
         """
-        # 基础行情字段集合(这些字段从 pro_bar 表获取)
-        pro_bar_fields = {
-            "open",
-            "high",
-            "low",
-            "close",
-            "vol",
-            "amount",
-            "pre_close",
-            "change",
-            "pct_chg",
-            "turnover_rate",
-            "volume_ratio",
-        }
-
-        # 每日指标字段集合(这些字段从 daily_basic 表获取)
-        daily_basic_fields = {
-            "turnover_rate_f",
-            "pe",
-            "pe_ttm",
-            "pb",
-            "ps",
-            "ps_ttm",
-            "dv_ratio",
-            "dv_ttm",
-            "total_share",
-            "float_share",
-            "free_share",
-            "total_mv",
-            "circ_mv",
-        }
-
-        # 将依赖分为不同表的字段
-        pro_bar_deps = dependencies & pro_bar_fields
-        daily_basic_deps = dependencies & daily_basic_fields
-        other_deps = dependencies - pro_bar_fields - daily_basic_fields
+        # Let SchemaCache map each field to its table
+        schema_cache = get_schema_cache()
+        table_to_fields = schema_cache.match_fields_to_tables(dependencies)
 
         data_specs = []
-
-        # pro_bar 表的数据规格
-        if pro_bar_deps:
+        for table_name, columns in table_to_fields.items():
             data_specs.append(
                 DataSpec(
-                    table="pro_bar",
-                    columns=sorted(pro_bar_deps),
-                )
-            )
-
-        # daily_basic 表的数据规格
-        if daily_basic_deps:
-            data_specs.append(
-                DataSpec(
-                    table="daily_basic",
-                    columns=sorted(daily_basic_deps),
-                )
-            )
-
-        # 其他字段从 daily 表获取
-        if other_deps:
-            data_specs.append(
-                DataSpec(
-                    table="daily",
-                    columns=sorted(other_deps),
+                    table=table_name,
+                    columns=columns,
                 )
             )
 
diff --git a/src/factors/engine/schema_cache.py b/src/factors/engine/schema_cache.py
new file mode 100644
index 0000000..51b07e8
--- /dev/null
+++ b/src/factors/engine/schema_cache.py
@@ -0,0 +1,264 @@
+"""Table schema cache manager.
+
+Scans the database table structure on demand and caches it in memory so the
+scan is not repeated on every lookup.
+"""
+
+from typing import Dict, List, Optional, Set, Tuple
+
+from src.data.storage import Storage
+
+# Preferred table when a column exists in several tables
+# (pro_bar > daily_basic > daily); the lowest rank wins.
+_TABLE_PRIORITY: Dict[str, int] = {"pro_bar": 1, "daily_basic": 2, "daily": 3}
+
+# Rank used for tables that are not listed in _TABLE_PRIORITY.
+_UNRANKED_PRIORITY = 999
+
+# Table that receives fields not found in any scanned table.
+_DEFAULT_TABLE = "daily"
+
+
+def _table_sort_key(table_name: str) -> Tuple[int, str]:
+    """Return a key that orders tables by priority rank, then by name."""
+    return (_TABLE_PRIORITY.get(table_name, _UNRANKED_PRIORITY), table_name)
+
+
+class SchemaCache:
+    """In-memory cache of database table schemas (singleton).
+
+    The first lookup scans every base table in the ``main`` schema and stores
+    two mappings on the class. A scan that finds no tables (for example when
+    the connection is unavailable) is not cached, so a later lookup retries.
+
+    Attributes:
+        _instance: The singleton instance.
+        _field_to_table_map: Cached column name -> chosen table name.
+        _table_to_fields_map: Cached table name -> list of column names.
+    """
+
+    _instance: Optional["SchemaCache"] = None
+    _field_to_table_map: Optional[Dict[str, str]] = None
+    _table_to_fields_map: Optional[Dict[str, List[str]]] = None
+
+    def __new__(cls) -> "SchemaCache":
+        """Create the singleton on first use; return it on later calls."""
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+    @classmethod
+    def get_instance(cls) -> "SchemaCache":
+        """Return the SchemaCache singleton.
+
+        Returns:
+            The shared SchemaCache instance.
+        """
+        # __new__ creates the singleton on first use and reuses it afterwards.
+        return cls()
+
+    @classmethod
+    def reset_cache(cls) -> None:
+        """Clear the cached mappings so the next lookup rescans (for tests)."""
+        cls._field_to_table_map = None
+        cls._table_to_fields_map = None
+
+    def _scan_table_schemas(self) -> Dict[str, List[str]]:
+        """Read the column names of every base table in the ``main`` schema.
+
+        Returns:
+            Mapping from table name to its columns in ordinal order; empty
+            when the connection is unavailable or the scan raises.
+        """
+        storage = Storage()
+        table_fields: Dict[str, List[str]] = {}
+
+        try:
+            # NOTE(review): this reads Storage's private ``_connection``
+            # attribute; prefer a public accessor if Storage provides one.
+            conn = storage._connection
+
+            if conn is None:
+                print("[SchemaCache] 数据库连接不可用")
+                return {}
+
+            # List user tables only (system tables are excluded).
+            tables_result = conn.execute("""
+                SELECT table_name
+                FROM information_schema.tables
+                WHERE table_schema = 'main'
+                AND table_type = 'BASE TABLE'
+            """).fetchall()
+
+            for row in tables_result:
+                table_name = row[0]
+                # Filter on the schema as well, so a same-named table in
+                # another schema cannot contribute columns.
+                columns_result = conn.execute(
+                    """
+                    SELECT column_name
+                    FROM information_schema.columns
+                    WHERE table_schema = 'main'
+                    AND table_name = ?
+                    ORDER BY ordinal_position
+                    """,
+                    [table_name],
+                ).fetchall()
+                table_fields[table_name] = [col[0] for col in columns_result]
+
+        except Exception as e:
+            print(f"[SchemaCache] 扫描表结构失败: {e}")
+            # Discard partial results; an empty mapping is not cached, and
+            # match_fields_to_tables() then routes fields to the default table.
+            table_fields = {}
+
+        return table_fields
+
+    def _ensure_scanned(self) -> None:
+        """Populate the class-level caches unless they are already populated."""
+        cls = type(self)
+        if (
+            cls._table_to_fields_map is not None
+            and cls._field_to_table_map is not None
+        ):
+            return
+
+        table_fields = self._scan_table_schemas()
+        if not table_fields:
+            # Leave the caches unset so a later call retries instead of
+            # pinning an empty schema for the rest of the process.
+            return
+
+        # A column may exist in several tables; collect all of them first.
+        field_to_tables: Dict[str, List[str]] = {}
+        for table, fields in table_fields.items():
+            for field in fields:
+                field_to_tables.setdefault(field, []).append(table)
+
+        # Assign on the class rather than on ``self``: reset_cache() clears the
+        # class attributes, and instance attributes would shadow them.
+        cls._table_to_fields_map = table_fields
+        cls._field_to_table_map = {
+            field: min(tables, key=_table_sort_key)
+            for field, tables in field_to_tables.items()
+        }
+
+    def get_table_fields(self, table_name: str) -> List[str]:
+        """Return the column list of a table.
+
+        Args:
+            table_name: Table name.
+
+        Returns:
+            The table's columns, or an empty list when the table is unknown.
+        """
+        self._ensure_scanned()
+        if self._table_to_fields_map is None:
+            return []
+        return self._table_to_fields_map.get(table_name, [])
+
+    def get_field_table(self, field_name: str) -> Optional[str]:
+        """Return the table that provides a field.
+
+        When several tables contain the field, the one with the best priority
+        rank is returned.
+
+        Args:
+            field_name: Field name.
+
+        Returns:
+            The table name, or None when no scanned table has the field.
+        """
+        self._ensure_scanned()
+        if self._field_to_table_map is None:
+            return None
+        return self._field_to_table_map.get(field_name)
+
+    def get_all_tables(self) -> List[str]:
+        """Return the names of all scanned tables.
+
+        Returns:
+            List of table names.
+        """
+        self._ensure_scanned()
+        if self._table_to_fields_map is None:
+            return []
+        return list(self._table_to_fields_map)
+
+    def field_exists(self, field_name: str) -> bool:
+        """Report whether any scanned table contains the field.
+
+        Args:
+            field_name: Field name.
+
+        Returns:
+            True when the field exists in at least one table.
+        """
+        self._ensure_scanned()
+        if self._field_to_table_map is None:
+            return False
+        return field_name in self._field_to_table_map
+
+    def match_fields_to_tables(self, field_names: Set[str]) -> Dict[str, List[str]]:
+        """Group a set of fields by the table that provides each of them.
+
+        Fields that no scanned table contains are assigned to the default
+        table ("daily").
+
+        Args:
+            field_names: Field names to group.
+
+        Returns:
+            Mapping from table name to its sorted field list. Tables are
+            ordered by priority rank and then by name, so the result does not
+            depend on the iteration order of ``field_names``.
+        """
+        # Scan at most once per call, even when the scan yields nothing.
+        self._ensure_scanned()
+        field_map = self._field_to_table_map or {}
+
+        grouped: Dict[str, List[str]] = {}
+        for field in field_names:
+            table = field_map.get(field, _DEFAULT_TABLE)
+            grouped.setdefault(table, []).append(field)
+
+        return {
+            table: sorted(grouped[table])
+            for table in sorted(grouped, key=_table_sort_key)
+        }
+
+
+# Module-level convenience functions
+
+
+def get_schema_cache() -> SchemaCache:
+    """Return the SchemaCache singleton.
+
+    Returns:
+        The shared SchemaCache instance.
+    """
+    return SchemaCache.get_instance()
+
+
+def get_field_table(field_name: str) -> Optional[str]:
+    """Return the table that provides a field.
+
+    Args:
+        field_name: Field name.
+
+    Returns:
+        The table name, or None when no scanned table has the field.
+    """
+    return SchemaCache.get_instance().get_field_table(field_name)
+
+
+def match_fields_to_tables(field_names: Set[str]) -> Dict[str, List[str]]:
+    """Group a set of fields by the table that provides each of them.
+
+    Args:
+        field_names: Field names to group.
+
+    Returns:
+        Mapping from table name to its field list.
+    """
+    return SchemaCache.get_instance().match_fields_to_tables(field_names)