feat(engine): 实现 DataRouter 数据库连接功能
This commit is contained in:
@@ -24,6 +24,7 @@ from src.factors.dsl import (
|
||||
)
|
||||
from src.factors.compiler import DependencyExtractor
|
||||
from src.factors.translator import PolarsTranslator
|
||||
from src.data.storage import Storage
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -78,13 +79,19 @@ class DataRouter:
|
||||
|
||||
Args:
|
||||
data_source: 内存数据源,字典格式 {表名: DataFrame}
|
||||
为 None 时需要在子类中实现数据库连接
|
||||
为 None 时自动连接 DuckDB 数据库
|
||||
"""
|
||||
self.data_source = data_source or {}
|
||||
self.is_memory_mode = data_source is not None
|
||||
self._cache: Dict[str, pl.DataFrame] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# 数据库模式下初始化 Storage
|
||||
if not self.is_memory_mode:
|
||||
self._storage = Storage()
|
||||
else:
|
||||
self._storage = None
|
||||
|
||||
def fetch_data(
|
||||
self,
|
||||
data_specs: List[DataSpec],
|
||||
@@ -171,40 +178,105 @@ class DataRouter:
|
||||
return self._cache[cache_key]
|
||||
|
||||
if self.is_memory_mode:
|
||||
if table_name not in self.data_source:
|
||||
raise ValueError(f"内存数据源中缺少表: {table_name}")
|
||||
|
||||
df = self.data_source[table_name]
|
||||
|
||||
# 确保必需字段存在
|
||||
for col in columns:
|
||||
if col not in df.columns and col not in ["ts_code", "trade_date"]:
|
||||
raise ValueError(f"表 {table_name} 缺少字段: {col}")
|
||||
|
||||
# 过滤日期和股票
|
||||
df = df.filter(
|
||||
(pl.col("trade_date") >= start_date)
|
||||
& (pl.col("trade_date") <= end_date)
|
||||
df = self._load_from_memory(
|
||||
table_name, columns, start_date, end_date, stock_codes
|
||||
)
|
||||
|
||||
if stock_codes is not None:
|
||||
df = df.filter(pl.col("ts_code").is_in(stock_codes))
|
||||
|
||||
# 选择需要的列
|
||||
select_cols = ["ts_code", "trade_date"] + [
|
||||
c for c in columns if c in df.columns
|
||||
]
|
||||
df = df.select(select_cols)
|
||||
|
||||
else:
|
||||
# TODO: 实现真实数据库连接(DuckDB)
|
||||
raise NotImplementedError("数据库连接模式尚未实现")
|
||||
df = self._load_from_database(
|
||||
table_name, columns, start_date, end_date, stock_codes
|
||||
)
|
||||
|
||||
with self._lock:
|
||||
self._cache[cache_key] = df
|
||||
|
||||
return df
|
||||
|
||||
def _load_from_memory(
    self,
    table_name: str,
    columns: List[str],
    start_date: str,
    end_date: str,
    stock_codes: Optional[List[str]] = None,
) -> "pl.DataFrame":
    """Load one table from the in-memory data source.

    Args:
        table_name: Key of the table inside ``self.data_source``.
        columns: Field names to return in addition to the key columns.
            Key columns and repeated names are accepted and returned once.
        start_date: Inclusive lower bound compared against ``trade_date``.
        end_date: Inclusive upper bound compared against ``trade_date``.
        stock_codes: When not None, only rows whose ``ts_code`` is in this
            list are kept (an empty list therefore yields no rows).

    Returns:
        A frame holding ``ts_code``, ``trade_date`` and then the requested
        columns, each exactly once, in first-seen order.

    Raises:
        ValueError: If the table is not in the data source, or a requested
            non-key column is not present in the table.
    """
    key_cols = ["ts_code", "trade_date"]

    if table_name not in self.data_source:
        raise ValueError(f"内存数据源中缺少表: {table_name}")

    df = self.data_source[table_name]
    available = set(df.columns)

    # Fail early on missing fields; the key columns are exempt from this check.
    for col in columns:
        if col not in available and col not in key_cols:
            raise ValueError(f"表 {table_name} 缺少字段: {col}")

    # Restrict to the requested date window (both ends inclusive).
    df = df.filter(
        (pl.col("trade_date") >= start_date) & (pl.col("trade_date") <= end_date)
    )

    if stock_codes is not None:
        df = df.filter(pl.col("ts_code").is_in(stock_codes))

    # The key columns are always selected first, so drop them (and any
    # repeated names) from the requested list: a select list with duplicate
    # column names is rejected by polars.
    extra_cols = [
        c for c in dict.fromkeys(columns) if c in available and c not in key_cols
    ]
    return df.select(key_cols + extra_cols)
|
||||
|
||||
def _load_from_database(
    self,
    table_name: str,
    columns: List[str],
    start_date: str,
    end_date: str,
    stock_codes: Optional[List[str]] = None,
) -> "pl.DataFrame":
    """Load one table through ``self._storage`` (database mode).

    The date range, and a single stock code when exactly one is requested,
    are passed to ``Storage.load_polars``; any other stock-code list is
    applied as an in-memory filter on the loaded frame.

    Args:
        table_name: Name of the table to load.
        columns: Field names to return in addition to the key columns.
            Key columns and repeated names are accepted and returned once.
        start_date: Start of the date range forwarded to ``load_polars``.
        end_date: End of the date range forwarded to ``load_polars``.
        stock_codes: When not None, only rows whose ``ts_code`` is in this
            list are kept (an empty list therefore yields no rows).

    Returns:
        A frame holding ``ts_code``, ``trade_date`` and then the requested
        columns, each exactly once, in first-seen order.

    Raises:
        RuntimeError: If the storage was never initialised, or if
            ``load_polars`` raises (the original error is chained as cause).
        ValueError: If the table does not exist, or a requested non-key
            column is not present in the loaded data.
    """
    key_cols = ["ts_code", "trade_date"]

    if self._storage is None:
        raise RuntimeError("Storage 未初始化")

    if not self._storage.exists(table_name):
        raise ValueError(f"数据库中不存在表: {table_name}")

    # load_polars is called with at most one ts_code, so only a
    # single-element list can be handed over to the storage layer.
    push_down = stock_codes is not None and len(stock_codes) == 1
    ts_code_filter = stock_codes[0] if push_down else None

    try:
        df = self._storage.load_polars(
            name=table_name,
            start_date=start_date,
            end_date=end_date,
            ts_code=ts_code_filter,
        )
    except Exception as e:
        # Keep the original exception attached for debugging.
        raise RuntimeError(f"从数据库加载表 {table_name} 失败: {e}") from e

    # Every list that was not pushed down is filtered here. This includes
    # the empty list, which must produce no rows rather than all rows.
    if stock_codes is not None and not push_down:
        df = df.filter(pl.col("ts_code").is_in(stock_codes))

    available = set(df.columns)

    # Fail on missing fields; the key columns are exempt from this check.
    for col in columns:
        if col not in available and col not in key_cols:
            raise ValueError(f"表 {table_name} 缺少字段: {col}")

    # The key columns are always selected first, so drop them (and any
    # repeated names) from the requested list: a select list with duplicate
    # column names is rejected by polars.
    extra_cols = [
        c for c in dict.fromkeys(columns) if c in available and c not in key_cols
    ]
    return df.select(key_cols + extra_cols)
|
||||
|
||||
def _assemble_wide_table(
|
||||
self,
|
||||
table_data: Dict[str, pl.DataFrame],
|
||||
@@ -275,6 +347,11 @@ class DataRouter:
|
||||
with self._lock:
|
||||
self._cache.clear()
|
||||
|
||||
# 数据库模式下清理 Storage 连接(可选)
|
||||
if not self.is_memory_mode and self._storage is not None:
|
||||
# Storage 使用单例模式,不需要关闭连接
|
||||
pass
|
||||
|
||||
|
||||
class ExecutionPlanner:
|
||||
"""执行计划生成器。
|
||||
|
||||
Reference in New Issue
Block a user