"""表结构缓存管理器。 提供动态扫描数据库表结构并缓存的功能,避免重复扫描。 """ from typing import Dict, List, Optional, Set from src.data.storage import Storage class SchemaCache: """表结构缓存管理器(单例模式)。 动态扫描数据库中所有表的字段信息,并在内存中缓存。 使用 @lru_cache 确保整个进程生命周期中只扫描一次。 Attributes: _instance: 单例实例 _field_to_table_map: 字段到表的映射缓存 _table_to_fields_map: 表到字段列表的映射缓存 """ _instance: Optional["SchemaCache"] = None _field_to_table_map: Optional[Dict[str, str]] = None _table_to_fields_map: Optional[Dict[str, List[str]]] = None def __new__(cls) -> "SchemaCache": """确保单例模式。""" if cls._instance is None: cls._instance = super().__new__(cls) return cls._instance @classmethod def get_instance(cls) -> "SchemaCache": """获取 SchemaCache 单例实例。 Returns: SchemaCache 实例 """ if cls._instance is None: cls._instance = cls() return cls._instance @classmethod def reset_cache(cls) -> None: """重置缓存(主要用于测试)。""" cls._field_to_table_map = None cls._table_to_fields_map = None def _scan_table_schemas(self) -> Dict[str, List[str]]: """扫描数据库中所有表的字段信息。 Returns: 表名到字段列表的映射字典 """ storage = Storage() table_fields: Dict[str, List[str]] = {} try: conn = storage._connection # 检查连接是否可用 if conn is None: print("[SchemaCache] 数据库连接不可用") return {} # 使用断言帮助类型检查器 assert conn is not None # 获取所有表名(排除系统表) tables_result = conn.execute(""" SELECT table_name FROM information_schema.tables WHERE table_schema = 'main' AND table_type = 'BASE TABLE' """).fetchall() tables = [row[0] for row in tables_result] # 获取每个表的字段信息 for table_name in tables: columns_result = conn.execute( """ SELECT column_name FROM information_schema.columns WHERE table_name = ? ORDER BY ordinal_position """, [table_name], ).fetchall() columns = [row[0] for row in columns_result] table_fields[table_name] = columns except Exception as e: print(f"[SchemaCache] 扫描表结构失败: {e}") # 返回空字典,后续可以使用硬编码的默认配置 table_fields = {} return table_fields def _ensure_scanned(self) -> None: """确保表结构已扫描(只执行一次)。""" if self._table_to_fields_map is None or self._field_to_table_map is None: table_fields = self._scan_table_schemas() # 表到字段的映射 self._table_to_fields_map = table_fields # 字段到表的映射(一个字段可能在多个表中存在) field_to_tables: Dict[str, List[str]] = {} for table, fields in table_fields.items(): for field in fields: if field not in field_to_tables: field_to_tables[field] = [] field_to_tables[field].append(table) # 优先选择最常用的表(pro_bar > daily_basic > daily) priority_order = {"pro_bar": 1, "daily_basic": 2, "daily": 3} self._field_to_table_map = {} for field, tables in field_to_tables.items(): # 按优先级排序,选择优先级最高的表 sorted_tables = sorted(tables, key=lambda t: priority_order.get(t, 999)) self._field_to_table_map[field] = sorted_tables[0] def get_table_fields(self, table_name: str) -> List[str]: """获取指定表的字段列表。 Args: table_name: 表名 Returns: 字段列表,表不存在时返回空列表 """ self._ensure_scanned() if self._table_to_fields_map is None: return [] return self._table_to_fields_map.get(table_name, []) def get_field_table(self, field_name: str) -> Optional[str]: """获取包含指定字段的表名。 如果多个表包含该字段,返回优先级最高的表。 Args: field_name: 字段名 Returns: 表名,字段不存在时返回 None """ self._ensure_scanned() if self._field_to_table_map is None: return None return self._field_to_table_map.get(field_name) def get_all_tables(self) -> List[str]: """获取所有表名列表。 Returns: 表名列表 """ self._ensure_scanned() if self._table_to_fields_map is None: return [] return list(self._table_to_fields_map.keys()) def field_exists(self, field_name: str) -> bool: """检查字段是否存在于任何表中。 Args: field_name: 字段名 Returns: 是否存在 """ self._ensure_scanned() if self._field_to_table_map is None: return False return field_name in self._field_to_table_map def match_fields_to_tables(self, field_names: Set[str]) -> Dict[str, List[str]]: """将字段集合按表分组。 Args: field_names: 字段名集合 Returns: 表名到字段列表的映射 """ self._ensure_scanned() table_to_fields: Dict[str, List[str]] = {} for field in field_names: table = self.get_field_table(field) if table is not None: if table not in table_to_fields: table_to_fields[table] = [] table_to_fields[table].append(field) else: # 字段不存在于任何表,归入 "daily" 表(默认表) if "daily" not in table_to_fields: table_to_fields["daily"] = [] table_to_fields["daily"].append(field) # 对字段列表排序以保持确定性输出 for fields in table_to_fields.values(): fields.sort() return table_to_fields # 模块级便捷函数 def get_schema_cache() -> SchemaCache: """获取 SchemaCache 单例实例。 Returns: SchemaCache 实例 """ return SchemaCache.get_instance() def get_field_table(field_name: str) -> Optional[str]: """获取包含指定字段的表名。 Args: field_name: 字段名 Returns: 表名,字段不存在时返回 None """ return SchemaCache.get_instance().get_field_table(field_name) def match_fields_to_tables(field_names: Set[str]) -> Dict[str, List[str]]: """将字段集合按表分组。 Args: field_names: 字段名集合 Returns: 表名到字段列表的映射 """ return SchemaCache.get_instance().match_fields_to_tables(field_names)