From 1a6fc2eeba1ab4c19144fe528ab37077bebfdd8b Mon Sep 17 00:00:00 2001 From: liaozhaorun <1300336796@qq.com> Date: Mon, 2 Mar 2026 20:47:01 +0800 Subject: [PATCH] =?UTF-8?q?feat(engine):=20=E5=AE=9E=E7=8E=B0=20DataRouter?= =?UTF-8?q?=20=E6=95=B0=E6=8D=AE=E5=BA=93=E8=BF=9E=E6=8E=A5=E5=8A=9F?= =?UTF-8?q?=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/factors/engine.py | 131 +++++++++++++++++++++++++++++++++--------- 1 file changed, 104 insertions(+), 27 deletions(-) diff --git a/src/factors/engine.py b/src/factors/engine.py index 9e0ea55..7ea9f11 100644 --- a/src/factors/engine.py +++ b/src/factors/engine.py @@ -24,6 +24,7 @@ from src.factors.dsl import ( ) from src.factors.compiler import DependencyExtractor from src.factors.translator import PolarsTranslator +from src.data.storage import Storage @dataclass @@ -78,13 +79,19 @@ class DataRouter: Args: data_source: 内存数据源,字典格式 {表名: DataFrame} - 为 None 时需要在子类中实现数据库连接 + 为 None 时自动连接 DuckDB 数据库 """ self.data_source = data_source or {} self.is_memory_mode = data_source is not None self._cache: Dict[str, pl.DataFrame] = {} self._lock = threading.Lock() + # 数据库模式下初始化 Storage + if not self.is_memory_mode: + self._storage = Storage() + else: + self._storage = None + def fetch_data( self, data_specs: List[DataSpec], @@ -171,40 +178,105 @@ class DataRouter: return self._cache[cache_key] if self.is_memory_mode: - if table_name not in self.data_source: - raise ValueError(f"内存数据源中缺少表: {table_name}") - - df = self.data_source[table_name] - - # 确保必需字段存在 - for col in columns: - if col not in df.columns and col not in ["ts_code", "trade_date"]: - raise ValueError(f"表 {table_name} 缺少字段: {col}") - - # 过滤日期和股票 - df = df.filter( - (pl.col("trade_date") >= start_date) - & (pl.col("trade_date") <= end_date) + df = self._load_from_memory( + table_name, columns, start_date, end_date, stock_codes ) - - if stock_codes is not None: - df = df.filter(pl.col("ts_code").is_in(stock_codes)) - - # 选择需要的列 - select_cols = ["ts_code", "trade_date"] + [ - c for c in columns if c in df.columns - ] - df = df.select(select_cols) - else: - # TODO: 实现真实数据库连接(DuckDB) - raise NotImplementedError("数据库连接模式尚未实现") + df = self._load_from_database( + table_name, columns, start_date, end_date, stock_codes + ) with self._lock: self._cache[cache_key] = df return df + def _load_from_memory( + self, + table_name: str, + columns: List[str], + start_date: str, + end_date: str, + stock_codes: Optional[List[str]] = None, + ) -> pl.DataFrame: + """从内存数据源加载数据。""" + if table_name not in self.data_source: + raise ValueError(f"内存数据源中缺少表: {table_name}") + + df = self.data_source[table_name] + + # 确保必需字段存在 + for col in columns: + if col not in df.columns and col not in ["ts_code", "trade_date"]: + raise ValueError(f"表 {table_name} 缺少字段: {col}") + + # 过滤日期和股票 + df = df.filter( + (pl.col("trade_date") >= start_date) & (pl.col("trade_date") <= end_date) + ) + + if stock_codes is not None: + df = df.filter(pl.col("ts_code").is_in(stock_codes)) + + # 选择需要的列 + select_cols = ["ts_code", "trade_date"] + [ + c for c in columns if c in df.columns + ] + return df.select(select_cols) + + def _load_from_database( + self, + table_name: str, + columns: List[str], + start_date: str, + end_date: str, + stock_codes: Optional[List[str]] = None, + ) -> pl.DataFrame: + """从 DuckDB 数据库加载数据。 + + 利用 Storage.load_polars() 方法,支持 SQL 查询下推。 + """ + if self._storage is None: + raise RuntimeError("Storage 未初始化") + + # 检查表是否存在 + if not self._storage.exists(table_name): + raise ValueError(f"数据库中不存在表: {table_name}") + + # 构建查询参数 + # Storage.load_polars 目前只支持单个 ts_code,需要处理列表情况 + if stock_codes is not None and len(stock_codes) == 1: + ts_code_filter = stock_codes[0] + else: + ts_code_filter = None + + try: + # 从数据库加载原始数据 + df = self._storage.load_polars( + name=table_name, + start_date=start_date, + end_date=end_date, + ts_code=ts_code_filter, + ) + except Exception as e: + raise RuntimeError(f"从数据库加载表 {table_name} 失败: {e}") + + # 如果 stock_codes 是列表且长度 > 1,在内存中过滤 + if stock_codes is not None and len(stock_codes) > 1: + df = df.filter(pl.col("ts_code").is_in(stock_codes)) + + # 检查必需字段 + for col in columns: + if col not in df.columns and col not in ["ts_code", "trade_date"]: + raise ValueError(f"表 {table_name} 缺少字段: {col}") + + # 选择需要的列 + select_cols = ["ts_code", "trade_date"] + [ + c for c in columns if c in df.columns + ] + + return df.select(select_cols) + def _assemble_wide_table( self, table_data: Dict[str, pl.DataFrame], @@ -275,6 +347,11 @@ class DataRouter: with self._lock: self._cache.clear() + # 数据库模式下清理 Storage 连接(可选) + if not self.is_memory_mode and self._storage is not None: + # Storage 使用单例模式,不需要关闭连接 + pass + class ExecutionPlanner: """执行计划生成器。