Files
ProStock/src/data/catalog.py

665 lines
22 KiB
Python
Raw Normal View History

"""数据目录与动态 SQL 路由模块。
用于动态 SQL 生成和数据拉取解决多表架构下的数据查询痛点
支持 DAILY日频精确对齐 PIT低频财务数据按披露日对齐两种表类型
核心特性
- 自动发现 DuckDB 数据库中的表结构
- 支持通过配置覆盖自动发现的元数据
- 智能识别 PIT 类型表通过 ann_date/f_ann_date 字段
"""
from typing import Dict, List, Set, Optional, Literal
from dataclasses import dataclass, field
from enum import Enum
import polars as pl
import duckdb
from pathlib import Path
class TableFrequency(Enum):
"""表频度类型。"""
DAILY = "daily" # 日频数据,精确对齐
PIT = "pit" # 低频数据,按披露日对齐 (Point-In-Time)
@dataclass
class TableMetadata:
"""表元数据配置。
Attributes:
name: 表名
frequency: 表频度类型DAILY PIT
date_field: 日期字段名DAILY 表为 trade_datePIT 表为 ann_date
code_field: 资产代码字段名通常为 ts_code
fields: 表中所有字段列表
description: 表描述
"""
name: str
frequency: TableFrequency
date_field: str
code_field: str = "ts_code"
fields: List[str] = field(default_factory=list)
description: str = ""
@dataclass
class FieldMapping:
"""字段映射配置。
Attributes:
field_name: 字段名
table_name: 所属表名
description: 字段描述
"""
field_name: str
table_name: str
description: str = ""
class DatabaseCatalog:
"""数据库目录类,管理字段到表的映射关系。
核心职责
1. 自动从 DuckDB 数据库中发现表结构
2. 维护字段到表的映射关系
3. 管理表的元数据频度类型日期字段等
4. 提供字段解析和表路由功能
表类型自动识别规则
- 如果表包含 ann_date f_ann_date 字段识别为 PIT 类型
- 否则如果包含 trade_date 字段识别为 DAILY 类型
Attributes:
tables: 表元数据字典表名 -> TableMetadata
field_mappings: 字段映射字典字段名 -> FieldMapping
db_path: 数据库文件路径
Example:
>>> catalog = DatabaseCatalog("data/prostock.db")
>>> # 自动发现所有表结构
>>> catalog.discover_tables()
>>> table = catalog.get_table_for_field("close")
>>> print(table) # "daily"
"""
# PIT 类型表的标识字段(优先级顺序)
PIT_DATE_FIELDS = ["ann_date", "f_ann_date", "publish_date"]
# DAILY 类型表的标识字段
DAILY_DATE_FIELDS = ["trade_date", "cal_date", "date"]
def __init__(self, db_path: Optional[str] = None):
"""初始化数据库目录。
Args:
db_path: 数据库文件路径如果为 None 则使用默认配置
"""
self.tables: Dict[str, TableMetadata] = {}
self.field_mappings: Dict[str, FieldMapping] = {}
self.db_path = db_path
self._table_frequency_overrides: Dict[str, TableFrequency] = {}
if db_path:
self.discover_tables(db_path)
def set_table_frequency_override(
self, table_name: str, frequency: TableFrequency
) -> None:
"""设置表频度类型覆盖。
用于手动指定表的频度类型覆盖自动识别的结果
Args:
table_name: 表名
frequency: 频度类型DAILY PIT
"""
self._table_frequency_overrides[table_name] = frequency
def discover_tables(self, db_path: str) -> None:
"""自动发现数据库中的所有表结构。
information_schema 中读取表和列信息自动识别
- 表名和字段列表
- 资产代码字段ts_code
- 日期字段根据字段名智能识别表类型
- 表频度类型DAILY PIT
Args:
db_path: DuckDB 数据库文件路径
"""
db_file = db_path.replace("duckdb://", "").lstrip("/")
if not Path(db_file).exists():
print(f"[DatabaseCatalog] 数据库文件不存在: {db_file}")
return
conn = duckdb.connect(db_file, read_only=True)
try:
# 获取所有表
tables_query = """
SELECT table_name
FROM information_schema.tables
WHERE table_schema = 'main'
ORDER BY table_name
"""
tables_result = conn.execute(tables_query).fetchall()
for (table_name,) in tables_result:
# 获取表的列信息
columns_query = """
SELECT column_name, data_type
FROM information_schema.columns
WHERE table_name = ? AND table_schema = 'main'
ORDER BY ordinal_position
"""
columns_result = conn.execute(columns_query, [table_name]).fetchall()
fields = [col[0] for col in columns_result]
# 自动识别表类型和日期字段
frequency, date_field = self._detect_table_type(fields, table_name)
# 检查是否有资产代码字段
code_field = "ts_code" if "ts_code" in fields else None
if code_field and date_field:
# 创建表元数据
metadata = TableMetadata(
name=table_name,
frequency=frequency,
date_field=date_field,
code_field=code_field,
fields=fields,
description=f"自动发现的表: {table_name}",
)
self.register_table(metadata)
print(
f"[DatabaseCatalog] 发现表: {table_name} ({frequency.value}, "
f"日期字段: {date_field})"
)
finally:
conn.close()
def _detect_table_type(
self, fields: List[str], table_name: str
) -> tuple[TableFrequency, Optional[str]]:
"""自动检测表的频度类型和日期字段。
检测规则按优先级
1. 检查是否有手动覆盖配置
2. 检查是否包含 PIT 标识字段ann_date, f_ann_date
3. 检查是否包含 DAILY 标识字段trade_date, cal_date
Args:
fields: 表的字段列表
table_name: 表名
Returns:
(频度类型, 日期字段名)
"""
# 检查手动覆盖配置
if table_name in self._table_frequency_overrides:
frequency = self._table_frequency_overrides[table_name]
if frequency == TableFrequency.PIT:
for field in self.PIT_DATE_FIELDS:
if field in fields:
return frequency, field
else:
for field in self.DAILY_DATE_FIELDS:
if field in fields:
return frequency, field
# 检查 PIT 标识字段
for field in self.PIT_DATE_FIELDS:
if field in fields:
return TableFrequency.PIT, field
# 检查 DAILY 标识字段
for field in self.DAILY_DATE_FIELDS:
if field in fields:
return TableFrequency.DAILY, field
# 默认返回 DAILY但无日期字段
return TableFrequency.DAILY, None
def register_table(self, metadata: TableMetadata) -> None:
"""注册表元数据。
Args:
metadata: 表元数据配置
"""
self.tables[metadata.name] = metadata
# 自动注册字段映射(如果字段已存在,保留第一个表的映射)
for field_name in metadata.fields:
if field_name not in self.field_mappings:
self.field_mappings[field_name] = FieldMapping(
field_name=field_name,
table_name=metadata.name,
description=f"{metadata.description} - {field_name}",
)
def get_table_for_field(self, field: str) -> Optional[str]:
"""获取字段对应的表名。
Args:
field: 字段名
Returns:
表名如果字段不存在则返回 None
"""
mapping = self.field_mappings.get(field)
return mapping.table_name if mapping else None
def get_table_metadata(self, table_name: str) -> Optional[TableMetadata]:
"""获取表的元数据。
Args:
table_name: 表名
Returns:
表元数据如果不存在则返回 None
"""
return self.tables.get(table_name)
def get_table_frequency(self, table_name: str) -> Optional[TableFrequency]:
"""获取表的频度类型。
Args:
table_name: 表名
Returns:
表频度类型DAILY PIT如果不存在则返回 None
"""
metadata = self.tables.get(table_name)
return metadata.frequency if metadata else None
def get_required_tables(self, fields: List[str]) -> Set[str]:
"""获取所需字段涉及的所有表名。
Args:
fields: 字段列表
Returns:
涉及的表名集合
"""
tables = set()
for field in fields:
table = self.get_table_for_field(field)
if table:
tables.add(table)
return tables
def get_fields_for_table(
self, table_name: str, required_fields: List[str]
) -> List[str]:
"""获取指定表需要的字段列表(包含必要的键字段)。
Args:
table_name: 表名
required_fields: 用户请求的所有字段
Returns:
该表需要查询的字段列表包含键字段
"""
metadata = self.tables.get(table_name)
if not metadata:
return []
# 基础键字段
fields = [metadata.code_field, metadata.date_field]
# 添加用户请求的字段(属于该表的)
for field in required_fields:
if self.get_table_for_field(field) == table_name and field not in fields:
fields.append(field)
return fields
def is_pit_table(self, table_name: str) -> bool:
"""判断表是否为 PIT 类型。
Args:
table_name: 表名
Returns:
是否为 PIT 类型表
"""
frequency = self.get_table_frequency(table_name)
return frequency == TableFrequency.PIT
class SQLQueryBuilder:
"""SQL 查询构建器。
根据表类型DAILY/PIT构建优化的 SQL 查询
"""
def __init__(self, catalog: DatabaseCatalog):
"""初始化 SQL 构建器。
Args:
catalog: 数据库目录实例
"""
self.catalog = catalog
def build_query(
self,
table_name: str,
fields: List[str],
start_date: str,
end_date: str,
lookback_days: int = 90,
) -> str:
"""构建优化的 SQL 查询。
对于 PIT 类型表会自动向前回溯 lookback_days
以确保起始日期能匹配到最近的旧数据
Args:
table_name: 表名
fields: 需要查询的字段列表
start_date: 开始日期YYYYMMDD 格式
end_date: 结束日期YYYYMMDD 格式
lookback_days: PIT 表回溯天数默认90天
Returns:
构建好的 SQL 查询语句
"""
metadata = self.catalog.get_table_metadata(table_name)
if not metadata:
raise ValueError(f"未知的表: {table_name}")
# 构建字段列表
fields_str = ", ".join(fields)
# 根据表类型构建 WHERE 条件
if metadata.frequency == TableFrequency.PIT:
# PIT 表:按公告日期查询,需要向前回溯
date_field = metadata.date_field
query_start = self._adjust_start_date(start_date, lookback_days)
query_start_fmt = self._format_date(query_start)
end_date_fmt = self._format_date(end_date)
sql = f"""
SELECT {fields_str}
FROM {table_name}
WHERE {date_field} >= '{query_start_fmt}'
AND {date_field} <= '{end_date_fmt}'
ORDER BY {metadata.code_field}, {date_field}
"""
else:
# DAILY 表:直接按交易日期查询
date_field = metadata.date_field
start_date_fmt = self._format_date(start_date)
end_date_fmt = self._format_date(end_date)
sql = f"""
SELECT {fields_str}
FROM {table_name}
WHERE {date_field} >= '{start_date_fmt}'
AND {date_field} <= '{end_date_fmt}'
ORDER BY {metadata.code_field}, {date_field}
"""
return sql.strip()
def _format_date(self, date_str: str) -> str:
"""将 YYYYMMDD 格式转换为 YYYY-MM-DD 格式。
Args:
date_str: 日期字符串YYYYMMDD 格式
Returns:
格式化后的日期字符串YYYY-MM-DD 格式
"""
return f"{date_str[:4]}-{date_str[4:6]}-{date_str[6:8]}"
def _adjust_start_date(self, start_date: str, days: int) -> str:
"""调整开始日期(向前回溯指定天数)。
Args:
start_date: 开始日期YYYYMMDD 格式
days: 回溯天数
Returns:
调整后的日期YYYYMMDD 格式
"""
from datetime import datetime, timedelta
dt = datetime.strptime(start_date, "%Y%m%d")
adjusted_dt = dt - timedelta(days=days)
return adjusted_dt.strftime("%Y%m%d")
def query_duckdb_to_polars(query: str, db_path: str) -> pl.LazyFrame:
"""执行 DuckDB 查询并返回 Polars LazyFrame。
使用 duckdb.connect().sql(query).pl() 实现高速数据流转
默认使用 read_only=True 模式允许多进程并发读取
Args:
query: SQL 查询语句
db_path: DuckDB 数据库文件路径
Returns:
Polars LazyFrame
"""
conn = duckdb.connect(db_path, read_only=True)
try:
# DuckDB -> Polars 高速转换
df = conn.sql(query).pl()
return df.lazy()
finally:
conn.close()
def build_context_lazyframe(
required_fields: List[str],
start_date: str,
end_date: str,
db_uri: str,
catalog: Optional[DatabaseCatalog] = None,
lookback_days: int = 90,
) -> pl.LazyFrame:
"""构建上下文 LazyFrame根据所需字段动态生成 SQL 并合并数据。
核心逻辑
1. 根据 required_fields 反查涉及的表名
2. 对每个表生成精简的 SQL 查询
3. DuckDB 加载数据到 Polars LazyFrame
4. 合并不同表的数据
- DAILY 表按 ["trade_date", "ts_code"] 进行 left_join
- PIT 表使用 join_asof 按公告日期对齐
5. 最终按 ["ts_code", "trade_date"] 排序
Args:
required_fields: 需要的字段列表
start_date: 开始日期YYYYMMDD 格式
end_date: 结束日期YYYYMMDD 格式
db_uri: 数据库连接 URI "duckdb:///data/prostock.db"
catalog: 数据库目录实例如果为 None 则自动创建并发现表
lookback_days: PIT 表回溯天数默认90天
Returns:
合并后的 LazyFrame包含所有请求的字段
Example:
>>> lf = build_context_lazyframe(
... required_fields=["close", "vol", "basic_eps"],
... start_date="20240101",
... end_date="20240131",
... db_uri="duckdb:///data/prostock.db"
... )
>>> df = lf.collect()
"""
# 解析数据库路径
db_path = db_uri.replace("duckdb://", "").lstrip("/")
# 如果没有提供 catalog自动创建并发现表
if catalog is None:
catalog = DatabaseCatalog(db_path)
# 获取涉及的表
tables = catalog.get_required_tables(required_fields)
if not tables:
# 如果没有涉及的表,返回空 DataFrame
return pl.LazyFrame({"ts_code": [], "trade_date": []})
# 分离 DAILY 表和 PIT 表
daily_tables: List[str] = []
pit_tables: List[str] = []
for table_name in tables:
if catalog.is_pit_table(table_name):
pit_tables.append(table_name)
else:
daily_tables.append(table_name)
# 构建 SQL 查询器
query_builder = SQLQueryBuilder(catalog)
# 加载 DAILY 表数据
daily_lfs: Dict[str, pl.LazyFrame] = {}
for table_name in daily_tables:
fields = catalog.get_fields_for_table(table_name, required_fields)
sql = query_builder.build_query(
table_name=table_name,
fields=fields,
start_date=start_date,
end_date=end_date,
)
print(f"[SQL] {sql[:100]}...")
lf = query_duckdb_to_polars(sql, db_path)
# 统一列名:将表的 date_field 重命名为 trade_date
metadata = catalog.get_table_metadata(table_name)
if metadata and metadata.date_field != "trade_date":
lf = lf.rename({metadata.date_field: "trade_date"})
daily_lfs[table_name] = lf
# 加载 PIT 表数据
pit_lfs: Dict[str, pl.LazyFrame] = {}
for table_name in pit_tables:
fields = catalog.get_fields_for_table(table_name, required_fields)
sql = query_builder.build_query(
table_name=table_name,
fields=fields,
start_date=start_date,
end_date=end_date,
lookback_days=lookback_days,
)
print(f"[SQL] {sql[:100]}...")
lf = query_duckdb_to_polars(sql, db_path)
# PIT 表保持原始公告日期字段(用于 join_asof
pit_lfs[table_name] = lf
# 合并所有 DAILY 表(以第一个 daily 表为基准)
result_lf: Optional[pl.LazyFrame] = None
if daily_lfs:
# 使用第一个 daily 表作为基准
first_table = daily_tables[0]
result_lf = daily_lfs[first_table]
# 合并其他 daily 表
for table_name in daily_tables[1:]:
lf = daily_lfs[table_name]
result_lf = result_lf.join(lf, on=["trade_date", "ts_code"], how="left")
elif pit_lfs:
# 如果没有 daily 表,从 PIT 表创建基准时间轴
# 使用第一个 PIT 表的日期范围
first_pit = pit_tables[0]
pit_metadata = catalog.get_table_metadata(first_pit)
# 从 PIT 表提取所有日期和股票代码组合
result_lf = (
pit_lfs[first_pit]
.select([pl.col(pit_metadata.date_field).alias("trade_date"), "ts_code"])
.unique()
)
# 如果没有结果,返回空 DataFrame
if result_lf is None:
return pl.LazyFrame({"ts_code": [], "trade_date": []})
# 合并 PIT 表(使用 join_asof 按公告日期对齐)
for table_name in pit_tables:
pit_metadata = catalog.get_table_metadata(table_name)
lf = pit_lfs[table_name]
# join_asof: 按 ts_code 分组,将 PIT 数据对齐到交易日
# 策略为 backward使用小于等于当前交易日的最新公告数据
result_lf = result_lf.join_asof(
lf,
left_on="trade_date",
right_on=pit_metadata.date_field,
by="ts_code",
strategy="backward",
)
# 最终排序:按 ["ts_code", "trade_date"] 确保时序计算要求
result_lf = result_lf.sort(["ts_code", "trade_date"])
return result_lf
if __name__ == "__main__":
# 测试代码
print("=" * 60)
print("DatabaseCatalog 自动发现测试")
print("=" * 60)
# 测试自动发现
catalog = DatabaseCatalog("data/prostock.db")
print("\n=== 测试字段到表映射 ===")
print(f"字段 'close' 对应的表: {catalog.get_table_for_field('close')}")
print(f"字段 'vol' 对应的表: {catalog.get_table_for_field('vol')}")
print(f"字段 'pe' 对应的表: {catalog.get_table_for_field('pe')}")
print(f"字段 'basic_eps' 对应的表: {catalog.get_table_for_field('basic_eps')}")
print("\n=== 测试表频度类型 ===")
for table_name in catalog.tables:
freq = catalog.get_table_frequency(table_name)
print(f"'{table_name}' 的频度: {freq.value if freq else 'Unknown'}")
print("\n=== 测试 SQL 构建 ===")
query_builder = SQLQueryBuilder(catalog)
daily_sql = query_builder.build_query(
table_name="daily",
fields=["ts_code", "trade_date", "close", "vol"],
start_date="20240101",
end_date="20240131",
)
print(f"\nDAILY 表 SQL:\n{daily_sql}")
pit_sql = query_builder.build_query(
table_name="financial_income",
fields=["ts_code", "ann_date", "basic_eps", "total_revenue"],
start_date="20240101",
end_date="20240131",
lookback_days=90,
)
print(f"\nPIT 表 SQL:\n{pit_sql}")
print("\n=== 测试多表字段收集 ===")
required_fields = ["close", "vol", "pe", "basic_eps", "total_revenue"]
tables = catalog.get_required_tables(required_fields)
print(f"字段 {required_fields} 涉及的表: {tables}")
for table_name in tables:
fields = catalog.get_fields_for_table(table_name, required_fields)
print(f"'{table_name}' 需要查询的字段: {fields}")
print("\n所有测试通过!")