"""数据路由器。 按需取数、组装核心宽表。 负责根据数据规格从数据源拉取数据,并组装成统一的宽表格式。 支持内存数据源(用于测试)和真实数据库连接。 支持标准等值匹配和 asof_backward(财务数据)两种拼接模式。 """ from typing import Any, Dict, List, Optional, Set, Union import threading import polars as pl from src.factors.engine.data_spec import DataSpec from src.data.storage import Storage from src.data.financial_loader import FinancialLoader class DataRouter: """数据路由器 - 按需取数、组装核心宽表。 负责根据数据规格从数据源拉取数据,并组装成统一的宽表格式。 支持内存数据源(用于测试)和真实数据库连接。 Attributes: data_source: 数据源,可以是内存 DataFrame 字典或数据库连接 is_memory_mode: 是否为内存模式 """ def __init__(self, data_source: Optional[Dict[str, pl.DataFrame]] = None) -> None: """初始化数据路由器。 Args: data_source: 内存数据源,字典格式 {表名: DataFrame} 为 None 时自动连接 DuckDB 数据库 """ self.data_source = data_source or {} self.is_memory_mode = data_source is not None self._cache: Dict[str, pl.DataFrame] = {} self._lock = threading.Lock() # 数据库模式下初始化 Storage 和 FinancialLoader if not self.is_memory_mode: self._storage = Storage() self._financial_loader = FinancialLoader() else: self._storage = None self._financial_loader = None def fetch_data( self, data_specs: List[DataSpec], start_date: str, end_date: str, stock_codes: Optional[List[str]] = None, ) -> pl.DataFrame: """根据数据规格获取并组装核心宽表。 Args: data_specs: 数据规格列表 start_date: 开始日期 (YYYYMMDD) end_date: 结束日期 (YYYYMMDD) stock_codes: 股票代码列表,None 表示全市场 Returns: 组装好的核心宽表 DataFrame Raises: ValueError: 当数据源中缺少必要的表或字段时 """ if not data_specs: raise ValueError("数据规格不能为空") # 收集所有需要的表和字段 required_tables: Dict[str, Set[str]] = {} for spec in data_specs: if spec.table not in required_tables: required_tables[spec.table] = set() required_tables[spec.table].update(spec.columns) # 从数据源获取各表数据(使用合并后的 required_tables,避免重复加载) table_data = {} for table_name, columns in required_tables.items(): # 判断是标准表还是财务表 is_financial = any( s.table == table_name and s.join_type == "asof_backward" for s in data_specs ) if is_financial: # 财务表:找到对应的 spec 获取 join 配置 financial_spec = next( s for s in data_specs if s.table == table_name and s.join_type == "asof_backward" ) spec = DataSpec( table=table_name, columns=list(columns), join_type="asof_backward", left_on=financial_spec.left_on, right_on=financial_spec.right_on, ) else: # 标准表 spec = DataSpec( table=table_name, columns=list(columns), join_type="standard", ) df = self._load_table_from_spec( spec=spec, start_date=start_date, end_date=end_date, stock_codes=stock_codes, ) table_data[table_name] = df # 组装核心宽表(支持多种 join 类型) core_table = self._assemble_wide_table_with_specs( table_data, data_specs, start_date, end_date ) return core_table def _load_table_from_spec( self, spec: DataSpec, start_date: str, end_date: str, stock_codes: Optional[List[str]] = None, ) -> pl.DataFrame: """根据数据规格加载单个表的数据。 根据 spec.join_type 选择不同的加载方式: - standard: 使用原有逻辑,基于 trade_date - asof_backward: 使用 FinancialLoader,基于 f_ann_date,扩展回看期 Args: spec: 数据规格 start_date: 开始日期 end_date: 结束日期 stock_codes: 股票代码过滤 Returns: 过滤后的 DataFrame """ cache_key = ( f"{spec.table}_{spec.join_type}_{start_date}_{end_date}_{stock_codes}" ) with self._lock: if cache_key in self._cache: return self._cache[cache_key] if spec.join_type == "asof_backward": # 财务数据使用 FinancialLoader if self._financial_loader is None: raise RuntimeError("FinancialLoader 未初始化") # 扩展日期范围(回看1年) adjusted_start, _ = self._financial_loader.get_date_range_with_lookback( start_date, end_date ) # 处理 stock_codes ts_code = stock_codes[0] if stock_codes and len(stock_codes) == 1 else None df = self._financial_loader.load_financial_data( table_name=spec.table, columns=spec.columns, start_date=adjusted_start, end_date=end_date, ts_code=ts_code, ) # 如果 stock_codes 是列表且长度 > 1,在内存中过滤 if stock_codes is not None and len(stock_codes) > 1: df = df.filter(pl.col("ts_code").is_in(stock_codes)) else: # 标准表使用原有逻辑 df = self._load_table( table_name=spec.table, columns=spec.columns, start_date=start_date, end_date=end_date, stock_codes=stock_codes, ) with self._lock: self._cache[cache_key] = df return df def _load_table( self, table_name: str, columns: List[str], start_date: str, end_date: str, stock_codes: Optional[List[str]] = None, ) -> pl.DataFrame: """加载单个表的数据。 Args: table_name: 表名 columns: 需要的字段 start_date: 开始日期 end_date: 结束日期 stock_codes: 股票代码过滤 Returns: 过滤后的 DataFrame """ cache_key = f"{table_name}_{start_date}_{end_date}_{stock_codes}" with self._lock: if cache_key in self._cache: return self._cache[cache_key] if self.is_memory_mode: df = self._load_from_memory( table_name, columns, start_date, end_date, stock_codes ) else: df = self._load_from_database( table_name, columns, start_date, end_date, stock_codes ) with self._lock: self._cache[cache_key] = df return df def _load_from_memory( self, table_name: str, columns: List[str], start_date: str, end_date: str, stock_codes: Optional[List[str]] = None, ) -> pl.DataFrame: """从内存数据源加载数据。""" if table_name not in self.data_source: raise ValueError(f"内存数据源中缺少表: {table_name}") df = self.data_source[table_name] # 确保必需字段存在 for col in columns: if col not in df.columns and col not in ["ts_code", "trade_date"]: raise ValueError(f"表 {table_name} 缺少字段: {col}") # 过滤日期和股票 df = df.filter( (pl.col("trade_date") >= start_date) & (pl.col("trade_date") <= end_date) ) if stock_codes is not None: df = df.filter(pl.col("ts_code").is_in(stock_codes)) # 选择需要的列(避免重复) base_cols = ["ts_code", "trade_date"] extra_cols = [c for c in columns if c in df.columns and c not in base_cols] select_cols = base_cols + extra_cols return df.select(select_cols) def _load_from_database( self, table_name: str, columns: List[str], start_date: str, end_date: str, stock_codes: Optional[List[str]] = None, ) -> pl.DataFrame: """从 DuckDB 数据库加载数据。 利用 Storage.load_polars() 方法,支持 SQL 查询下推。 """ if self._storage is None: raise RuntimeError("Storage 未初始化") # 检查表是否存在 if not self._storage.exists(table_name): raise ValueError(f"数据库中不存在表: {table_name}") # 构建查询参数 # Storage.load_polars 目前只支持单个 ts_code,需要处理列表情况 if stock_codes is not None and len(stock_codes) == 1: ts_code_filter = stock_codes[0] else: ts_code_filter = None try: # 从数据库加载原始数据 df = self._storage.load_polars( name=table_name, start_date=start_date, end_date=end_date, ts_code=ts_code_filter, ) except Exception as e: raise RuntimeError(f"从数据库加载表 {table_name} 失败: {e}") # 如果 stock_codes 是列表且长度 > 1,在内存中过滤 if stock_codes is not None and len(stock_codes) > 1: df = df.filter(pl.col("ts_code").is_in(stock_codes)) # 检查必需字段 for col in columns: if col not in df.columns and col not in ["ts_code", "trade_date"]: raise ValueError(f"表 {table_name} 缺少字段: {col}") # 选择需要的列(避免重复) base_cols = ["ts_code", "trade_date"] extra_cols = [c for c in columns if c in df.columns and c not in base_cols] select_cols = base_cols + extra_cols return df.select(select_cols) def _assemble_wide_table( self, table_data: Dict[str, pl.DataFrame], required_tables: Dict[str, Set[str]], ) -> pl.DataFrame: """组装多表数据为核心宽表。 使用 left join 合并各表数据,以第一个表为基准。 Args: table_data: 表名到 DataFrame 的映射 required_tables: 表名到字段集合的映射 Returns: 组装后的宽表 """ if not table_data: raise ValueError("没有数据可组装") # 以第一个表为基准 base_table_name = list(table_data.keys())[0] result = table_data[base_table_name] # 与其他表 join for table_name, df in table_data.items(): if table_name == base_table_name: continue # 使用 ts_code 和 trade_date 作为 join 键 result = result.join( df, on=["ts_code", "trade_date"], how="left", ) return result def _assemble_wide_table_with_specs( self, table_data: Dict[str, pl.DataFrame], data_specs: List[DataSpec], start_date: str, end_date: str, ) -> pl.DataFrame: """组装多表数据为核心宽表(支持多种 join 类型)。 支持标准等值匹配和 asof_backward 两种模式。 性能优化: - 在开始时统一将 trade_date 转为 pl.Date - 所有 asof join 全部在 pl.Date 类型下完成 - 返回前统一转回字符串格式 Args: table_data: 表名到 DataFrame 的映射 data_specs: 数据规格列表 start_date: 开始日期 end_date: 结束日期 Returns: 组装后的宽表 """ if not table_data: raise ValueError("没有数据可组装") # 从 data_specs 判断每个表的 join 类型 table_join_types = {} for spec in data_specs: if spec.table not in table_join_types: table_join_types[spec.table] = spec.join_type # 分离标准表和 asof 表(基于 table_data 的表名,避免重复) standard_tables = [ t for t in table_data.keys() if table_join_types.get(t, "standard") == "standard" ] asof_tables = [ t for t in table_data.keys() if table_join_types.get(t) == "asof_backward" ] # 先合并所有标准表(使用 trade_date) base_df = None for table_name in standard_tables: df = table_data[table_name] if base_df is None: base_df = df else: # 使用 ts_code 和 trade_date 作为 join 键 # 注:根据动态路由原则,除 ts_code/trade_date 外不应有重复字段 # 如果出现重复,说明 SchemaCache 的字段映射有问题 base_df = base_df.join( df, on=["ts_code", "trade_date"], how="left", ) if base_df is None: raise ValueError("至少需要一张标准行情表作为基础") # 【性能优化】统一转换 trade_date 为 Date 类型(只转换一次) if asof_tables: base_df = base_df.with_columns( [ pl.col("trade_date") .str.strptime(pl.Date, "%Y%m%d") .alias("trade_date") ] ) # 确保已排序(join_asof 要求) base_df = base_df.sort(["ts_code", "trade_date"]) # 逐个合并 asof 表(所有 join 都在 Date 类型下进行) for table_name in asof_tables: df_financial = table_data[table_name] # 提取需要保留的字段(排除 join 键和元数据字段) # 从 data_specs 中找到对应表的 columns table_columns = set() for spec in data_specs: if spec.table == table_name: table_columns.update(spec.columns) financial_cols = [ c for c in table_columns if c not in [ "ts_code", "f_ann_date", "report_type", "update_flag", "end_date", ] ] if self._financial_loader is None: raise RuntimeError("FinancialLoader 未初始化") base_df = self._financial_loader.merge_financial_with_price( base_df, df_financial, financial_cols ) # 【性能优化】所有 asof join 完成后,统一转回字符串格式 if asof_tables: base_df = base_df.with_columns( [pl.col("trade_date").dt.strftime("%Y%m%d").alias("trade_date")] ) return base_df def clear_cache(self) -> None: """清除数据缓存。""" with self._lock: self._cache.clear() # 数据库模式下清理 Storage 连接(可选) if not self.is_memory_mode and self._storage is not None: # Storage 使用单例模式,不需要关闭连接 pass