Files
ProStock/src/factors/engine/schema_cache.py

248 lines
7.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""表结构缓存管理器。
提供动态扫描数据库表结构并缓存的功能,避免重复扫描。
"""
from typing import Dict, List, Optional, Set
from src.data.storage import Storage
class SchemaCache:
    """Process-wide singleton cache of database table schemas.

    On first use the cache scans the database for every base table in the
    ``main`` schema together with its column names, then answers all later
    lookups from memory.  The scan result is stored on the class (not on the
    instance), so ``reset_cache()`` reliably forces the next lookup to scan
    again.

    Attributes:
        _instance: The singleton instance, or ``None`` before first creation.
        _field_to_table_map: Column name -> the one preferred table that
            contains it; ``None`` until the first scan.
        _table_to_fields_map: Table name -> column names in ordinal order;
            ``None`` until the first scan.
    """

    _instance: Optional["SchemaCache"] = None
    _field_to_table_map: Optional[Dict[str, str]] = None
    _table_to_fields_map: Optional[Dict[str, List[str]]] = None

    def __new__(cls) -> "SchemaCache":
        """Return the shared instance, creating it on the first call."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @classmethod
    def get_instance(cls) -> "SchemaCache":
        """Return the singleton instance.

        Returns:
            The shared ``SchemaCache`` object (created on demand).
        """
        # __new__ already implements create-or-reuse, so delegate to it.
        return cls()

    @classmethod
    def reset_cache(cls) -> None:
        """Discard the cached schema so the next lookup rescans the database.

        Intended mainly for tests.
        """
        cls._field_to_table_map = None
        cls._table_to_fields_map = None
        # An attribute set directly on the instance would shadow the class
        # attributes cleared above and keep serving stale data, so remove
        # any such shadows as well.
        instance = cls._instance
        if instance is not None:
            instance.__dict__.pop("_field_to_table_map", None)
            instance.__dict__.pop("_table_to_fields_map", None)

    def _scan_table_schemas(self) -> Dict[str, List[str]]:
        """Read every base table in schema ``main`` and its column names.

        Returns:
            Mapping of table name to its column names ordered by
            ``ordinal_position``.  An empty dict is returned when the
            connection is unavailable or when any query fails; the failure
            is reported on stdout rather than raised.
        """
        storage = Storage()
        try:
            conn = storage._connection
            if conn is None:
                print("[SchemaCache] 数据库连接不可用")
                return {}

            # User tables only: system tables live outside schema 'main'.
            table_rows = conn.execute(
                """
                SELECT table_name
                FROM information_schema.tables
                WHERE table_schema = 'main'
                AND table_type = 'BASE TABLE'
                """
            ).fetchall()

            table_fields: Dict[str, List[str]] = {}
            for table_row in table_rows:
                table_name = table_row[0]
                # Restrict to schema 'main' here too; otherwise a table with
                # the same name in another schema would contribute columns.
                column_rows = conn.execute(
                    """
                    SELECT column_name
                    FROM information_schema.columns
                    WHERE table_name = ?
                    AND table_schema = 'main'
                    ORDER BY ordinal_position
                    """,
                    [table_name],
                ).fetchall()
                table_fields[table_name] = [row[0] for row in column_rows]
            return table_fields
        except Exception as e:
            # Deliberate best effort: a partial scan is thrown away and an
            # empty mapping is returned so callers can fall back to their
            # own defaults.
            print(f"[SchemaCache] 扫描表结构失败: {e}")
            return {}

    def _ensure_scanned(self) -> None:
        """Populate both lookup maps on first use; do nothing afterwards."""
        if (
            self._table_to_fields_map is not None
            and self._field_to_table_map is not None
        ):
            return

        table_fields = self._scan_table_schemas()

        # A column name may occur in several tables; collect all of them.
        field_to_tables: Dict[str, List[str]] = {}
        for table, fields in table_fields.items():
            for field in fields:
                field_to_tables.setdefault(field, []).append(table)

        # Preferred owner of a shared column: pro_bar > daily_basic > daily,
        # then any other table.  min() returns the first candidate among
        # equal ranks, i.e. the earliest table in scan order.
        priority_order = {"pro_bar": 1, "daily_basic": 2, "daily": 3}
        field_to_table = {
            field: min(tables, key=lambda t: priority_order.get(t, 999))
            for field, tables in field_to_tables.items()
        }

        # Store on the class so reset_cache(), which clears the class
        # attributes, actually invalidates what the lookups read.
        owner = type(self)
        owner._table_to_fields_map = table_fields
        owner._field_to_table_map = field_to_table

    def get_table_fields(self, table_name: str) -> List[str]:
        """Return the column names of a table.

        Args:
            table_name: Name of the table to look up.

        Returns:
            A new list of column names; empty when the table is unknown.
        """
        self._ensure_scanned()
        if self._table_to_fields_map is None:
            return []
        # Hand out a copy so callers cannot mutate the cached list.
        return list(self._table_to_fields_map.get(table_name, ()))

    def get_field_table(self, field_name: str) -> Optional[str]:
        """Return the table that provides a column.

        When several tables contain the column, the one with the highest
        priority (pro_bar > daily_basic > daily > others) is returned.

        Args:
            field_name: Column name to look up.

        Returns:
            The table name, or ``None`` when no scanned table has the column.
        """
        self._ensure_scanned()
        if self._field_to_table_map is None:
            return None
        return self._field_to_table_map.get(field_name)

    def get_all_tables(self) -> List[str]:
        """Return the names of all scanned tables.

        Returns:
            List of table names in scan order.
        """
        self._ensure_scanned()
        if self._table_to_fields_map is None:
            return []
        return list(self._table_to_fields_map)

    def field_exists(self, field_name: str) -> bool:
        """Report whether any scanned table contains the column.

        Args:
            field_name: Column name to check.

        Returns:
            ``True`` if the column was found in at least one table.
        """
        self._ensure_scanned()
        if self._field_to_table_map is None:
            return False
        return field_name in self._field_to_table_map

    def match_fields_to_tables(self, field_names: Set[str]) -> Dict[str, List[str]]:
        """Group column names by the table that provides them.

        Args:
            field_names: Column names to distribute.

        Returns:
            Mapping of table name to a sorted list of the requested columns
            assigned to it.  Columns not found in any table are placed under
            ``"daily"``, which acts as the default table.
        """
        self._ensure_scanned()
        grouped: Dict[str, List[str]] = {}
        for field in field_names:
            table = self.get_field_table(field)
            # Unknown columns fall back to the default "daily" table.
            target = table if table is not None else "daily"
            grouped.setdefault(target, []).append(field)

        # Sets iterate in arbitrary order; sort for deterministic output.
        for fields in grouped.values():
            fields.sort()
        return grouped
# 模块级便捷函数
def get_schema_cache() -> SchemaCache:
    """Return the shared ``SchemaCache`` object.

    Module-level shortcut for ``SchemaCache.get_instance()``.

    Returns:
        The object returned by ``SchemaCache.get_instance()``.
    """
    cache = SchemaCache.get_instance()
    return cache
def get_field_table(field_name: str) -> Optional[str]:
    """Look up which table provides a column, using the shared cache.

    Module-level shortcut for ``SchemaCache.get_instance().get_field_table``.

    Args:
        field_name: Column name to look up.

    Returns:
        The table name, or ``None`` when the column is not found.
    """
    cache = SchemaCache.get_instance()
    return cache.get_field_table(field_name)
def match_fields_to_tables(field_names: Set[str]) -> Dict[str, List[str]]:
    """Group column names by table, using the shared cache.

    Module-level shortcut for
    ``SchemaCache.get_instance().match_fields_to_tables``.

    Args:
        field_names: Column names to distribute.

    Returns:
        Mapping of table name to the list of requested columns assigned to it.
    """
    cache = SchemaCache.get_instance()
    grouped = cache.match_fields_to_tables(field_names)
    return grouped