"""数据目录与动态 SQL 路由模块。 用于动态 SQL 生成和数据拉取,解决多表架构下的数据查询痛点。 支持 DAILY(日频精确对齐)和 PIT(低频财务数据,按披露日对齐)两种表类型。 核心特性: - 自动发现 DuckDB 数据库中的表结构 - 支持通过配置覆盖自动发现的元数据 - 智能识别 PIT 类型表(通过 ann_date/f_ann_date 字段) """ from typing import Dict, List, Set, Optional, Literal from dataclasses import dataclass, field from enum import Enum import polars as pl import duckdb from pathlib import Path class TableFrequency(Enum): """表频度类型。""" DAILY = "daily" # 日频数据,精确对齐 PIT = "pit" # 低频数据,按披露日对齐 (Point-In-Time) @dataclass class TableMetadata: """表元数据配置。 Attributes: name: 表名 frequency: 表频度类型(DAILY 或 PIT) date_field: 日期字段名(DAILY 表为 trade_date,PIT 表为 ann_date) code_field: 资产代码字段名(通常为 ts_code) fields: 表中所有字段列表 description: 表描述 """ name: str frequency: TableFrequency date_field: str code_field: str = "ts_code" fields: List[str] = field(default_factory=list) description: str = "" @dataclass class FieldMapping: """字段映射配置。 Attributes: field_name: 字段名 table_name: 所属表名 description: 字段描述 """ field_name: str table_name: str description: str = "" class DatabaseCatalog: """数据库目录类,管理字段到表的映射关系。 核心职责: 1. 自动从 DuckDB 数据库中发现表结构 2. 维护字段到表的映射关系 3. 管理表的元数据(频度类型、日期字段等) 4. 提供字段解析和表路由功能 表类型自动识别规则: - 如果表包含 ann_date 或 f_ann_date 字段,识别为 PIT 类型 - 否则,如果包含 trade_date 字段,识别为 DAILY 类型 Attributes: tables: 表元数据字典,表名 -> TableMetadata field_mappings: 字段映射字典,字段名 -> FieldMapping db_path: 数据库文件路径 Example: >>> catalog = DatabaseCatalog("data/prostock.db") >>> # 自动发现所有表结构 >>> catalog.discover_tables() >>> table = catalog.get_table_for_field("close") >>> print(table) # "daily" """ # PIT 类型表的标识字段(优先级顺序) PIT_DATE_FIELDS = ["ann_date", "f_ann_date", "publish_date"] # DAILY 类型表的标识字段 DAILY_DATE_FIELDS = ["trade_date", "cal_date", "date"] def __init__(self, db_path: Optional[str] = None): """初始化数据库目录。 Args: db_path: 数据库文件路径,如果为 None 则使用默认配置 """ self.tables: Dict[str, TableMetadata] = {} self.field_mappings: Dict[str, FieldMapping] = {} self.db_path = db_path self._table_frequency_overrides: Dict[str, TableFrequency] = {} if db_path: self.discover_tables(db_path) def set_table_frequency_override( self, table_name: str, frequency: TableFrequency ) -> None: """设置表频度类型覆盖。 用于手动指定表的频度类型,覆盖自动识别的结果。 Args: table_name: 表名 frequency: 频度类型(DAILY 或 PIT) """ self._table_frequency_overrides[table_name] = frequency def discover_tables(self, db_path: str) -> None: """自动发现数据库中的所有表结构。 从 information_schema 中读取表和列信息,自动识别: - 表名和字段列表 - 资产代码字段(ts_code) - 日期字段(根据字段名智能识别表类型) - 表频度类型(DAILY 或 PIT) Args: db_path: DuckDB 数据库文件路径 """ db_file = db_path.replace("duckdb://", "").lstrip("/") if not Path(db_file).exists(): print(f"[DatabaseCatalog] 数据库文件不存在: {db_file}") return conn = duckdb.connect(db_file, read_only=True) try: # 获取所有表 tables_query = """ SELECT table_name FROM information_schema.tables WHERE table_schema = 'main' ORDER BY table_name """ tables_result = conn.execute(tables_query).fetchall() for (table_name,) in tables_result: # 获取表的列信息 columns_query = """ SELECT column_name, data_type FROM information_schema.columns WHERE table_name = ? AND table_schema = 'main' ORDER BY ordinal_position """ columns_result = conn.execute(columns_query, [table_name]).fetchall() fields = [col[0] for col in columns_result] # 自动识别表类型和日期字段 frequency, date_field = self._detect_table_type(fields, table_name) # 检查是否有资产代码字段 code_field = "ts_code" if "ts_code" in fields else None if code_field and date_field: # 创建表元数据 metadata = TableMetadata( name=table_name, frequency=frequency, date_field=date_field, code_field=code_field, fields=fields, description=f"自动发现的表: {table_name}", ) self.register_table(metadata) print( f"[DatabaseCatalog] 发现表: {table_name} ({frequency.value}, " f"日期字段: {date_field})" ) finally: conn.close() def _detect_table_type( self, fields: List[str], table_name: str ) -> tuple[TableFrequency, Optional[str]]: """自动检测表的频度类型和日期字段。 检测规则(按优先级): 1. 检查是否有手动覆盖配置 2. 检查是否包含 PIT 标识字段(ann_date, f_ann_date 等) 3. 检查是否包含 DAILY 标识字段(trade_date, cal_date 等) Args: fields: 表的字段列表 table_name: 表名 Returns: (频度类型, 日期字段名) """ # 检查手动覆盖配置 if table_name in self._table_frequency_overrides: frequency = self._table_frequency_overrides[table_name] if frequency == TableFrequency.PIT: for field in self.PIT_DATE_FIELDS: if field in fields: return frequency, field else: for field in self.DAILY_DATE_FIELDS: if field in fields: return frequency, field # 检查 PIT 标识字段 for field in self.PIT_DATE_FIELDS: if field in fields: return TableFrequency.PIT, field # 检查 DAILY 标识字段 for field in self.DAILY_DATE_FIELDS: if field in fields: return TableFrequency.DAILY, field # 默认返回 DAILY,但无日期字段 return TableFrequency.DAILY, None def register_table(self, metadata: TableMetadata) -> None: """注册表元数据。 Args: metadata: 表元数据配置 """ self.tables[metadata.name] = metadata # 自动注册字段映射(如果字段已存在,保留第一个表的映射) for field_name in metadata.fields: if field_name not in self.field_mappings: self.field_mappings[field_name] = FieldMapping( field_name=field_name, table_name=metadata.name, description=f"{metadata.description} - {field_name}", ) def get_table_for_field(self, field: str) -> Optional[str]: """获取字段对应的表名。 Args: field: 字段名 Returns: 表名,如果字段不存在则返回 None """ mapping = self.field_mappings.get(field) return mapping.table_name if mapping else None def get_table_metadata(self, table_name: str) -> Optional[TableMetadata]: """获取表的元数据。 Args: table_name: 表名 Returns: 表元数据,如果不存在则返回 None """ return self.tables.get(table_name) def get_table_frequency(self, table_name: str) -> Optional[TableFrequency]: """获取表的频度类型。 Args: table_name: 表名 Returns: 表频度类型(DAILY 或 PIT),如果不存在则返回 None """ metadata = self.tables.get(table_name) return metadata.frequency if metadata else None def get_required_tables(self, fields: List[str]) -> Set[str]: """获取所需字段涉及的所有表名。 Args: fields: 字段列表 Returns: 涉及的表名集合 """ tables = set() for field in fields: table = self.get_table_for_field(field) if table: tables.add(table) return tables def get_fields_for_table( self, table_name: str, required_fields: List[str] ) -> List[str]: """获取指定表需要的字段列表(包含必要的键字段)。 Args: table_name: 表名 required_fields: 用户请求的所有字段 Returns: 该表需要查询的字段列表(包含键字段) """ metadata = self.tables.get(table_name) if not metadata: return [] # 基础键字段 fields = [metadata.code_field, metadata.date_field] # 添加用户请求的字段(属于该表的) for field in required_fields: if self.get_table_for_field(field) == table_name and field not in fields: fields.append(field) return fields def is_pit_table(self, table_name: str) -> bool: """判断表是否为 PIT 类型。 Args: table_name: 表名 Returns: 是否为 PIT 类型表 """ frequency = self.get_table_frequency(table_name) return frequency == TableFrequency.PIT class SQLQueryBuilder: """SQL 查询构建器。 根据表类型(DAILY/PIT)构建优化的 SQL 查询。 """ def __init__(self, catalog: DatabaseCatalog): """初始化 SQL 构建器。 Args: catalog: 数据库目录实例 """ self.catalog = catalog def build_query( self, table_name: str, fields: List[str], start_date: str, end_date: str, lookback_days: int = 90, ) -> str: """构建优化的 SQL 查询。 对于 PIT 类型表,会自动向前回溯 lookback_days 天, 以确保起始日期能匹配到最近的旧数据。 Args: table_name: 表名 fields: 需要查询的字段列表 start_date: 开始日期(YYYYMMDD 格式) end_date: 结束日期(YYYYMMDD 格式) lookback_days: PIT 表回溯天数(默认90天) Returns: 构建好的 SQL 查询语句 """ metadata = self.catalog.get_table_metadata(table_name) if not metadata: raise ValueError(f"未知的表: {table_name}") # 构建字段列表 fields_str = ", ".join(fields) # 根据表类型构建 WHERE 条件 if metadata.frequency == TableFrequency.PIT: # PIT 表:按公告日期查询,需要向前回溯 date_field = metadata.date_field query_start = self._adjust_start_date(start_date, lookback_days) query_start_fmt = self._format_date(query_start) end_date_fmt = self._format_date(end_date) sql = f""" SELECT {fields_str} FROM {table_name} WHERE {date_field} >= '{query_start_fmt}' AND {date_field} <= '{end_date_fmt}' ORDER BY {metadata.code_field}, {date_field} """ else: # DAILY 表:直接按交易日期查询 date_field = metadata.date_field start_date_fmt = self._format_date(start_date) end_date_fmt = self._format_date(end_date) sql = f""" SELECT {fields_str} FROM {table_name} WHERE {date_field} >= '{start_date_fmt}' AND {date_field} <= '{end_date_fmt}' ORDER BY {metadata.code_field}, {date_field} """ return sql.strip() def _format_date(self, date_str: str) -> str: """将 YYYYMMDD 格式转换为 YYYY-MM-DD 格式。 Args: date_str: 日期字符串(YYYYMMDD 格式) Returns: 格式化后的日期字符串(YYYY-MM-DD 格式) """ return f"{date_str[:4]}-{date_str[4:6]}-{date_str[6:8]}" def _adjust_start_date(self, start_date: str, days: int) -> str: """调整开始日期(向前回溯指定天数)。 Args: start_date: 开始日期(YYYYMMDD 格式) days: 回溯天数 Returns: 调整后的日期(YYYYMMDD 格式) """ from datetime import datetime, timedelta dt = datetime.strptime(start_date, "%Y%m%d") adjusted_dt = dt - timedelta(days=days) return adjusted_dt.strftime("%Y%m%d") def query_duckdb_to_polars(query: str, db_path: str) -> pl.LazyFrame: """执行 DuckDB 查询并返回 Polars LazyFrame。 使用 duckdb.connect().sql(query).pl() 实现高速数据流转。 默认使用 read_only=True 模式,允许多进程并发读取。 Args: query: SQL 查询语句 db_path: DuckDB 数据库文件路径 Returns: Polars LazyFrame """ conn = duckdb.connect(db_path, read_only=True) try: # DuckDB -> Polars 高速转换 df = conn.sql(query).pl() return df.lazy() finally: conn.close() def build_context_lazyframe( required_fields: List[str], start_date: str, end_date: str, db_uri: str, catalog: Optional[DatabaseCatalog] = None, lookback_days: int = 90, ) -> pl.LazyFrame: """构建上下文 LazyFrame,根据所需字段动态生成 SQL 并合并数据。 核心逻辑: 1. 根据 required_fields 反查涉及的表名 2. 对每个表生成精简的 SQL 查询 3. 从 DuckDB 加载数据到 Polars LazyFrame 4. 合并不同表的数据: - DAILY 表按 ["trade_date", "ts_code"] 进行 left_join - PIT 表使用 join_asof 按公告日期对齐 5. 最终按 ["ts_code", "trade_date"] 排序 Args: required_fields: 需要的字段列表 start_date: 开始日期(YYYYMMDD 格式) end_date: 结束日期(YYYYMMDD 格式) db_uri: 数据库连接 URI(如 "duckdb:///data/prostock.db") catalog: 数据库目录实例,如果为 None 则自动创建并发现表 lookback_days: PIT 表回溯天数(默认90天) Returns: 合并后的 LazyFrame,包含所有请求的字段 Example: >>> lf = build_context_lazyframe( ... required_fields=["close", "vol", "basic_eps"], ... start_date="20240101", ... end_date="20240131", ... db_uri="duckdb:///data/prostock.db" ... ) >>> df = lf.collect() """ # 解析数据库路径 db_path = db_uri.replace("duckdb://", "").lstrip("/") # 如果没有提供 catalog,自动创建并发现表 if catalog is None: catalog = DatabaseCatalog(db_path) # 获取涉及的表 tables = catalog.get_required_tables(required_fields) if not tables: # 如果没有涉及的表,返回空 DataFrame return pl.LazyFrame({"ts_code": [], "trade_date": []}) # 分离 DAILY 表和 PIT 表 daily_tables: List[str] = [] pit_tables: List[str] = [] for table_name in tables: if catalog.is_pit_table(table_name): pit_tables.append(table_name) else: daily_tables.append(table_name) # 构建 SQL 查询器 query_builder = SQLQueryBuilder(catalog) # 加载 DAILY 表数据 daily_lfs: Dict[str, pl.LazyFrame] = {} for table_name in daily_tables: fields = catalog.get_fields_for_table(table_name, required_fields) sql = query_builder.build_query( table_name=table_name, fields=fields, start_date=start_date, end_date=end_date, ) print(f"[SQL] {sql[:100]}...") lf = query_duckdb_to_polars(sql, db_path) # 统一列名:将表的 date_field 重命名为 trade_date metadata = catalog.get_table_metadata(table_name) if metadata and metadata.date_field != "trade_date": lf = lf.rename({metadata.date_field: "trade_date"}) daily_lfs[table_name] = lf # 加载 PIT 表数据 pit_lfs: Dict[str, pl.LazyFrame] = {} for table_name in pit_tables: fields = catalog.get_fields_for_table(table_name, required_fields) sql = query_builder.build_query( table_name=table_name, fields=fields, start_date=start_date, end_date=end_date, lookback_days=lookback_days, ) print(f"[SQL] {sql[:100]}...") lf = query_duckdb_to_polars(sql, db_path) # PIT 表保持原始公告日期字段(用于 join_asof) pit_lfs[table_name] = lf # 合并所有 DAILY 表(以第一个 daily 表为基准) result_lf: Optional[pl.LazyFrame] = None if daily_lfs: # 使用第一个 daily 表作为基准 first_table = daily_tables[0] result_lf = daily_lfs[first_table] # 合并其他 daily 表 for table_name in daily_tables[1:]: lf = daily_lfs[table_name] result_lf = result_lf.join(lf, on=["trade_date", "ts_code"], how="left") elif pit_lfs: # 如果没有 daily 表,从 PIT 表创建基准时间轴 # 使用第一个 PIT 表的日期范围 first_pit = pit_tables[0] pit_metadata = catalog.get_table_metadata(first_pit) # 从 PIT 表提取所有日期和股票代码组合 result_lf = ( pit_lfs[first_pit] .select([pl.col(pit_metadata.date_field).alias("trade_date"), "ts_code"]) .unique() ) # 如果没有结果,返回空 DataFrame if result_lf is None: return pl.LazyFrame({"ts_code": [], "trade_date": []}) # 合并 PIT 表(使用 join_asof 按公告日期对齐) for table_name in pit_tables: pit_metadata = catalog.get_table_metadata(table_name) lf = pit_lfs[table_name] # join_asof: 按 ts_code 分组,将 PIT 数据对齐到交易日 # 策略为 backward:使用小于等于当前交易日的最新公告数据 result_lf = result_lf.join_asof( lf, left_on="trade_date", right_on=pit_metadata.date_field, by="ts_code", strategy="backward", ) # 最终排序:按 ["ts_code", "trade_date"] 确保时序计算要求 result_lf = result_lf.sort(["ts_code", "trade_date"]) return result_lf if __name__ == "__main__": # 测试代码 print("=" * 60) print("DatabaseCatalog 自动发现测试") print("=" * 60) # 测试自动发现 catalog = DatabaseCatalog("data/prostock.db") print("\n=== 测试字段到表映射 ===") print(f"字段 'close' 对应的表: {catalog.get_table_for_field('close')}") print(f"字段 'vol' 对应的表: {catalog.get_table_for_field('vol')}") print(f"字段 'pe' 对应的表: {catalog.get_table_for_field('pe')}") print(f"字段 'basic_eps' 对应的表: {catalog.get_table_for_field('basic_eps')}") print("\n=== 测试表频度类型 ===") for table_name in catalog.tables: freq = catalog.get_table_frequency(table_name) print(f"表 '{table_name}' 的频度: {freq.value if freq else 'Unknown'}") print("\n=== 测试 SQL 构建 ===") query_builder = SQLQueryBuilder(catalog) daily_sql = query_builder.build_query( table_name="daily", fields=["ts_code", "trade_date", "close", "vol"], start_date="20240101", end_date="20240131", ) print(f"\nDAILY 表 SQL:\n{daily_sql}") pit_sql = query_builder.build_query( table_name="financial_income", fields=["ts_code", "ann_date", "basic_eps", "total_revenue"], start_date="20240101", end_date="20240131", lookback_days=90, ) print(f"\nPIT 表 SQL:\n{pit_sql}") print("\n=== 测试多表字段收集 ===") required_fields = ["close", "vol", "pe", "basic_eps", "total_revenue"] tables = catalog.get_required_tables(required_fields) print(f"字段 {required_fields} 涉及的表: {tables}") for table_name in tables: fields = catalog.get_fields_for_table(table_name, required_fields) print(f" 表 '{table_name}' 需要查询的字段: {fields}") print("\n所有测试通过!")