Files
ProStock/src/factors/engine/schema_cache.py

248 lines
7.4 KiB
Python
Raw Normal View History

"""表结构缓存管理器。
提供动态扫描数据库表结构并缓存的功能避免重复扫描
"""
from typing import Dict, List, Optional, Set
from src.data.storage import Storage
class SchemaCache:
"""表结构缓存管理器(单例模式)。
动态扫描数据库中所有表的字段信息并在内存中缓存
使用 @lru_cache 确保整个进程生命周期中只扫描一次
Attributes:
_instance: 单例实例
_field_to_table_map: 字段到表的映射缓存
_table_to_fields_map: 表到字段列表的映射缓存
"""
_instance: Optional["SchemaCache"] = None
_field_to_table_map: Optional[Dict[str, str]] = None
_table_to_fields_map: Optional[Dict[str, List[str]]] = None
def __new__(cls) -> "SchemaCache":
"""确保单例模式。"""
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
@classmethod
def get_instance(cls) -> "SchemaCache":
"""获取 SchemaCache 单例实例。
Returns:
SchemaCache 实例
"""
if cls._instance is None:
cls._instance = cls()
return cls._instance
@classmethod
def reset_cache(cls) -> None:
"""重置缓存(主要用于测试)。"""
cls._field_to_table_map = None
cls._table_to_fields_map = None
def _scan_table_schemas(self) -> Dict[str, List[str]]:
"""扫描数据库中所有表的字段信息。
Returns:
表名到字段列表的映射字典
"""
storage = Storage()
table_fields: Dict[str, List[str]] = {}
try:
conn = storage._connection
# 检查连接是否可用
if conn is None:
print("[SchemaCache] 数据库连接不可用")
return {}
# 使用断言帮助类型检查器
assert conn is not None
# 获取所有表名(排除系统表)
tables_result = conn.execute("""
SELECT table_name
FROM information_schema.tables
WHERE table_schema = 'main'
AND table_type = 'BASE TABLE'
""").fetchall()
tables = [row[0] for row in tables_result]
# 获取每个表的字段信息
for table_name in tables:
columns_result = conn.execute(
"""
SELECT column_name
FROM information_schema.columns
WHERE table_name = ?
ORDER BY ordinal_position
""",
[table_name],
).fetchall()
columns = [row[0] for row in columns_result]
table_fields[table_name] = columns
except Exception as e:
print(f"[SchemaCache] 扫描表结构失败: {e}")
# 返回空字典,后续可以使用硬编码的默认配置
table_fields = {}
return table_fields
def _ensure_scanned(self) -> None:
"""确保表结构已扫描(只执行一次)。"""
if self._table_to_fields_map is None or self._field_to_table_map is None:
table_fields = self._scan_table_schemas()
# 表到字段的映射
self._table_to_fields_map = table_fields
# 字段到表的映射(一个字段可能在多个表中存在)
field_to_tables: Dict[str, List[str]] = {}
for table, fields in table_fields.items():
for field in fields:
if field not in field_to_tables:
field_to_tables[field] = []
field_to_tables[field].append(table)
# 优先选择最常用的表pro_bar > daily_basic > daily
priority_order = {"pro_bar": 1, "daily_basic": 2, "daily": 3}
self._field_to_table_map = {}
for field, tables in field_to_tables.items():
# 按优先级排序,选择优先级最高的表
sorted_tables = sorted(tables, key=lambda t: priority_order.get(t, 999))
self._field_to_table_map[field] = sorted_tables[0]
def get_table_fields(self, table_name: str) -> List[str]:
"""获取指定表的字段列表。
Args:
table_name: 表名
Returns:
字段列表表不存在时返回空列表
"""
self._ensure_scanned()
if self._table_to_fields_map is None:
return []
return self._table_to_fields_map.get(table_name, [])
def get_field_table(self, field_name: str) -> Optional[str]:
"""获取包含指定字段的表名。
如果多个表包含该字段返回优先级最高的表
Args:
field_name: 字段名
Returns:
表名字段不存在时返回 None
"""
self._ensure_scanned()
if self._field_to_table_map is None:
return None
return self._field_to_table_map.get(field_name)
def get_all_tables(self) -> List[str]:
"""获取所有表名列表。
Returns:
表名列表
"""
self._ensure_scanned()
if self._table_to_fields_map is None:
return []
return list(self._table_to_fields_map.keys())
def field_exists(self, field_name: str) -> bool:
"""检查字段是否存在于任何表中。
Args:
field_name: 字段名
Returns:
是否存在
"""
self._ensure_scanned()
if self._field_to_table_map is None:
return False
return field_name in self._field_to_table_map
def match_fields_to_tables(self, field_names: Set[str]) -> Dict[str, List[str]]:
"""将字段集合按表分组。
Args:
field_names: 字段名集合
Returns:
表名到字段列表的映射
"""
self._ensure_scanned()
table_to_fields: Dict[str, List[str]] = {}
for field in field_names:
table = self.get_field_table(field)
if table is not None:
if table not in table_to_fields:
table_to_fields[table] = []
table_to_fields[table].append(field)
else:
# 字段不存在于任何表,归入 "daily" 表(默认表)
if "daily" not in table_to_fields:
table_to_fields["daily"] = []
table_to_fields["daily"].append(field)
# 对字段列表排序以保持确定性输出
for fields in table_to_fields.values():
fields.sort()
return table_to_fields
# 模块级便捷函数
def get_schema_cache() -> SchemaCache:
"""获取 SchemaCache 单例实例。
Returns:
SchemaCache 实例
"""
return SchemaCache.get_instance()
def get_field_table(field_name: str) -> Optional[str]:
"""获取包含指定字段的表名。
Args:
field_name: 字段名
Returns:
表名字段不存在时返回 None
"""
return SchemaCache.get_instance().get_field_table(field_name)
def match_fields_to_tables(field_names: Set[str]) -> Dict[str, List[str]]:
"""将字段集合按表分组。
Args:
field_names: 字段名集合
Returns:
表名到字段列表的映射
"""
return SchemaCache.get_instance().match_fields_to_tables(field_names)