feat(factors): 添加 SchemaCache 实现数据库表结构自动扫描
This commit is contained in:
@@ -16,6 +16,7 @@ from src.factors.dsl import (
|
|||||||
from src.factors.compiler import DependencyExtractor
|
from src.factors.compiler import DependencyExtractor
|
||||||
from src.factors.translator import PolarsTranslator
|
from src.factors.translator import PolarsTranslator
|
||||||
from src.factors.engine.data_spec import DataSpec, ExecutionPlan
|
from src.factors.engine.data_spec import DataSpec, ExecutionPlan
|
||||||
|
from src.factors.engine.schema_cache import get_schema_cache
|
||||||
|
|
||||||
|
|
||||||
class ExecutionPlanner:
|
class ExecutionPlanner:
|
||||||
@@ -73,9 +74,8 @@ class ExecutionPlanner:
|
|||||||
) -> List[DataSpec]:
|
) -> List[DataSpec]:
|
||||||
"""从依赖推导数据规格。
|
"""从依赖推导数据规格。
|
||||||
|
|
||||||
基础行情字段(open, high, low, close, vol, amount, pre_close, change, pct_chg)
|
使用 SchemaCache 动态扫描数据库表结构,自动匹配字段到对应的表。
|
||||||
默认从 pro_bar 表获取。
|
表结构只扫描一次并缓存在内存中。
|
||||||
每日指标字段(total_mv, circ_mv, pe, pb 等)从 daily_basic 表获取。
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dependencies: 依赖的字段集合
|
dependencies: 依赖的字段集合
|
||||||
@@ -84,69 +84,16 @@ class ExecutionPlanner:
|
|||||||
Returns:
|
Returns:
|
||||||
数据规格列表
|
数据规格列表
|
||||||
"""
|
"""
|
||||||
# 基础行情字段集合(这些字段从 pro_bar 表获取)
|
# 使用 SchemaCache 自动匹配字段到表
|
||||||
pro_bar_fields = {
|
schema_cache = get_schema_cache()
|
||||||
"open",
|
table_to_fields = schema_cache.match_fields_to_tables(dependencies)
|
||||||
"high",
|
|
||||||
"low",
|
|
||||||
"close",
|
|
||||||
"vol",
|
|
||||||
"amount",
|
|
||||||
"pre_close",
|
|
||||||
"change",
|
|
||||||
"pct_chg",
|
|
||||||
"turnover_rate",
|
|
||||||
"volume_ratio",
|
|
||||||
}
|
|
||||||
|
|
||||||
# 每日指标字段集合(这些字段从 daily_basic 表获取)
|
|
||||||
daily_basic_fields = {
|
|
||||||
"turnover_rate_f",
|
|
||||||
"pe",
|
|
||||||
"pe_ttm",
|
|
||||||
"pb",
|
|
||||||
"ps",
|
|
||||||
"ps_ttm",
|
|
||||||
"dv_ratio",
|
|
||||||
"dv_ttm",
|
|
||||||
"total_share",
|
|
||||||
"float_share",
|
|
||||||
"free_share",
|
|
||||||
"total_mv",
|
|
||||||
"circ_mv",
|
|
||||||
}
|
|
||||||
|
|
||||||
# 将依赖分为不同表的字段
|
|
||||||
pro_bar_deps = dependencies & pro_bar_fields
|
|
||||||
daily_basic_deps = dependencies & daily_basic_fields
|
|
||||||
other_deps = dependencies - pro_bar_fields - daily_basic_fields
|
|
||||||
|
|
||||||
data_specs = []
|
data_specs = []
|
||||||
|
for table_name, columns in table_to_fields.items():
|
||||||
# pro_bar 表的数据规格
|
|
||||||
if pro_bar_deps:
|
|
||||||
data_specs.append(
|
data_specs.append(
|
||||||
DataSpec(
|
DataSpec(
|
||||||
table="pro_bar",
|
table=table_name,
|
||||||
columns=sorted(pro_bar_deps),
|
columns=columns,
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# daily_basic 表的数据规格
|
|
||||||
if daily_basic_deps:
|
|
||||||
data_specs.append(
|
|
||||||
DataSpec(
|
|
||||||
table="daily_basic",
|
|
||||||
columns=sorted(daily_basic_deps),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 其他字段从 daily 表获取
|
|
||||||
if other_deps:
|
|
||||||
data_specs.append(
|
|
||||||
DataSpec(
|
|
||||||
table="daily",
|
|
||||||
columns=sorted(other_deps),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
247
src/factors/engine/schema_cache.py
Normal file
247
src/factors/engine/schema_cache.py
Normal file
@@ -0,0 +1,247 @@
|
|||||||
|
"""表结构缓存管理器。
|
||||||
|
|
||||||
|
提供动态扫描数据库表结构并缓存的功能,避免重复扫描。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, List, Optional, Set
|
||||||
|
|
||||||
|
from src.data.storage import Storage
|
||||||
|
|
||||||
|
|
||||||
|
class SchemaCache:
|
||||||
|
"""表结构缓存管理器(单例模式)。
|
||||||
|
|
||||||
|
动态扫描数据库中所有表的字段信息,并在内存中缓存。
|
||||||
|
使用 @lru_cache 确保整个进程生命周期中只扫描一次。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
_instance: 单例实例
|
||||||
|
_field_to_table_map: 字段到表的映射缓存
|
||||||
|
_table_to_fields_map: 表到字段列表的映射缓存
|
||||||
|
"""
|
||||||
|
|
||||||
|
_instance: Optional["SchemaCache"] = None
|
||||||
|
_field_to_table_map: Optional[Dict[str, str]] = None
|
||||||
|
_table_to_fields_map: Optional[Dict[str, List[str]]] = None
|
||||||
|
|
||||||
|
def __new__(cls) -> "SchemaCache":
|
||||||
|
"""确保单例模式。"""
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_instance(cls) -> "SchemaCache":
|
||||||
|
"""获取 SchemaCache 单例实例。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SchemaCache 实例
|
||||||
|
"""
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = cls()
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def reset_cache(cls) -> None:
|
||||||
|
"""重置缓存(主要用于测试)。"""
|
||||||
|
cls._field_to_table_map = None
|
||||||
|
cls._table_to_fields_map = None
|
||||||
|
|
||||||
|
def _scan_table_schemas(self) -> Dict[str, List[str]]:
|
||||||
|
"""扫描数据库中所有表的字段信息。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
表名到字段列表的映射字典
|
||||||
|
"""
|
||||||
|
storage = Storage()
|
||||||
|
table_fields: Dict[str, List[str]] = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
conn = storage._connection
|
||||||
|
|
||||||
|
# 检查连接是否可用
|
||||||
|
if conn is None:
|
||||||
|
print("[SchemaCache] 数据库连接不可用")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# 使用断言帮助类型检查器
|
||||||
|
assert conn is not None
|
||||||
|
|
||||||
|
# 获取所有表名(排除系统表)
|
||||||
|
tables_result = conn.execute("""
|
||||||
|
SELECT table_name
|
||||||
|
FROM information_schema.tables
|
||||||
|
WHERE table_schema = 'main'
|
||||||
|
AND table_type = 'BASE TABLE'
|
||||||
|
""").fetchall()
|
||||||
|
|
||||||
|
tables = [row[0] for row in tables_result]
|
||||||
|
|
||||||
|
# 获取每个表的字段信息
|
||||||
|
for table_name in tables:
|
||||||
|
columns_result = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT column_name
|
||||||
|
FROM information_schema.columns
|
||||||
|
WHERE table_name = ?
|
||||||
|
ORDER BY ordinal_position
|
||||||
|
""",
|
||||||
|
[table_name],
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
columns = [row[0] for row in columns_result]
|
||||||
|
table_fields[table_name] = columns
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[SchemaCache] 扫描表结构失败: {e}")
|
||||||
|
# 返回空字典,后续可以使用硬编码的默认配置
|
||||||
|
table_fields = {}
|
||||||
|
|
||||||
|
return table_fields
|
||||||
|
|
||||||
|
def _ensure_scanned(self) -> None:
|
||||||
|
"""确保表结构已扫描(只执行一次)。"""
|
||||||
|
if self._table_to_fields_map is None or self._field_to_table_map is None:
|
||||||
|
table_fields = self._scan_table_schemas()
|
||||||
|
|
||||||
|
# 表到字段的映射
|
||||||
|
self._table_to_fields_map = table_fields
|
||||||
|
|
||||||
|
# 字段到表的映射(一个字段可能在多个表中存在)
|
||||||
|
field_to_tables: Dict[str, List[str]] = {}
|
||||||
|
for table, fields in table_fields.items():
|
||||||
|
for field in fields:
|
||||||
|
if field not in field_to_tables:
|
||||||
|
field_to_tables[field] = []
|
||||||
|
field_to_tables[field].append(table)
|
||||||
|
|
||||||
|
# 优先选择最常用的表(pro_bar > daily_basic > daily)
|
||||||
|
priority_order = {"pro_bar": 1, "daily_basic": 2, "daily": 3}
|
||||||
|
|
||||||
|
self._field_to_table_map = {}
|
||||||
|
for field, tables in field_to_tables.items():
|
||||||
|
# 按优先级排序,选择优先级最高的表
|
||||||
|
sorted_tables = sorted(tables, key=lambda t: priority_order.get(t, 999))
|
||||||
|
self._field_to_table_map[field] = sorted_tables[0]
|
||||||
|
|
||||||
|
def get_table_fields(self, table_name: str) -> List[str]:
|
||||||
|
"""获取指定表的字段列表。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table_name: 表名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
字段列表,表不存在时返回空列表
|
||||||
|
"""
|
||||||
|
self._ensure_scanned()
|
||||||
|
if self._table_to_fields_map is None:
|
||||||
|
return []
|
||||||
|
return self._table_to_fields_map.get(table_name, [])
|
||||||
|
|
||||||
|
def get_field_table(self, field_name: str) -> Optional[str]:
|
||||||
|
"""获取包含指定字段的表名。
|
||||||
|
|
||||||
|
如果多个表包含该字段,返回优先级最高的表。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field_name: 字段名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
表名,字段不存在时返回 None
|
||||||
|
"""
|
||||||
|
self._ensure_scanned()
|
||||||
|
if self._field_to_table_map is None:
|
||||||
|
return None
|
||||||
|
return self._field_to_table_map.get(field_name)
|
||||||
|
|
||||||
|
def get_all_tables(self) -> List[str]:
|
||||||
|
"""获取所有表名列表。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
表名列表
|
||||||
|
"""
|
||||||
|
self._ensure_scanned()
|
||||||
|
if self._table_to_fields_map is None:
|
||||||
|
return []
|
||||||
|
return list(self._table_to_fields_map.keys())
|
||||||
|
|
||||||
|
def field_exists(self, field_name: str) -> bool:
|
||||||
|
"""检查字段是否存在于任何表中。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field_name: 字段名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否存在
|
||||||
|
"""
|
||||||
|
self._ensure_scanned()
|
||||||
|
if self._field_to_table_map is None:
|
||||||
|
return False
|
||||||
|
return field_name in self._field_to_table_map
|
||||||
|
|
||||||
|
def match_fields_to_tables(self, field_names: Set[str]) -> Dict[str, List[str]]:
|
||||||
|
"""将字段集合按表分组。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field_names: 字段名集合
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
表名到字段列表的映射
|
||||||
|
"""
|
||||||
|
self._ensure_scanned()
|
||||||
|
|
||||||
|
table_to_fields: Dict[str, List[str]] = {}
|
||||||
|
|
||||||
|
for field in field_names:
|
||||||
|
table = self.get_field_table(field)
|
||||||
|
if table is not None:
|
||||||
|
if table not in table_to_fields:
|
||||||
|
table_to_fields[table] = []
|
||||||
|
table_to_fields[table].append(field)
|
||||||
|
else:
|
||||||
|
# 字段不存在于任何表,归入 "daily" 表(默认表)
|
||||||
|
if "daily" not in table_to_fields:
|
||||||
|
table_to_fields["daily"] = []
|
||||||
|
table_to_fields["daily"].append(field)
|
||||||
|
|
||||||
|
# 对字段列表排序以保持确定性输出
|
||||||
|
for fields in table_to_fields.values():
|
||||||
|
fields.sort()
|
||||||
|
|
||||||
|
return table_to_fields
|
||||||
|
|
||||||
|
|
||||||
|
# 模块级便捷函数
|
||||||
|
|
||||||
|
|
||||||
|
def get_schema_cache() -> SchemaCache:
|
||||||
|
"""获取 SchemaCache 单例实例。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SchemaCache 实例
|
||||||
|
"""
|
||||||
|
return SchemaCache.get_instance()
|
||||||
|
|
||||||
|
|
||||||
|
def get_field_table(field_name: str) -> Optional[str]:
|
||||||
|
"""获取包含指定字段的表名。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field_name: 字段名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
表名,字段不存在时返回 None
|
||||||
|
"""
|
||||||
|
return SchemaCache.get_instance().get_field_table(field_name)
|
||||||
|
|
||||||
|
|
||||||
|
def match_fields_to_tables(field_names: Set[str]) -> Dict[str, List[str]]:
|
||||||
|
"""将字段集合按表分组。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field_names: 字段名集合
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
表名到字段列表的映射
|
||||||
|
"""
|
||||||
|
return SchemaCache.get_instance().match_fields_to_tables(field_names)
|
||||||
Reference in New Issue
Block a user