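"""Data router.

Fetches data on demand and assembles the core wide table.
Pulls data from the data source according to the data specs and assembles it
into a unified wide-table format.
Supports an in-memory data source (for testing) and a real database connection.
Supports two join modes: standard equality matching and asof_backward
(financial data).
"""
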
import threading
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import polars as pl

from src.data.financial_loader import FinancialLoader
from src.data.storage import Storage
from src.factors.engine.data_spec import DataSpec


class DataRouter:
    """Route data requests to a source and assemble the core wide table.

    Given a list of data specs, the router loads every referenced table
    (from an in-memory dict of DataFrames, or from the database-backed
    ``Storage`` / ``FinancialLoader``) and joins the results into one frame.
    Two join modes are supported: standard equality joins on
    ``ts_code`` + ``trade_date`` and ``asof_backward`` joins (financial data).

    Loaded frames are memoised in ``_cache``. Reads and writes of the cache
    are guarded by ``_lock``; the loading itself runs outside the lock, so two
    concurrent callers may load the same data twice.

    Attributes:
        data_source: In-memory tables as ``{table name: DataFrame}``; an empty
            dict when no source was supplied.
        is_memory_mode: ``True`` when a ``data_source`` was supplied.
    """

    # Key columns returned by every standard table load.
    _BASE_COLUMNS = ("ts_code", "trade_date")
    # Key / metadata columns that are never forwarded as financial value columns.
    _FINANCIAL_META_COLUMNS = (
        "ts_code",
        "f_ann_date",
        "report_type",
        "update_flag",
        "end_date",
    )
    # join_type value that marks a table as a financial (asof) table.
    _ASOF_JOIN = "asof_backward"

    def __init__(self, data_source: Optional[Dict[str, pl.DataFrame]] = None) -> None:
        """Initialise the router.

        Args:
            data_source: In-memory tables as ``{table name: DataFrame}``.
                When ``None`` the router runs in database mode and creates a
                ``Storage`` and a ``FinancialLoader``.
        """
        self.data_source = data_source or {}
        # An explicitly passed empty dict still selects memory mode.
        self.is_memory_mode = data_source is not None
        self._cache: Dict[Tuple[Any, ...], pl.DataFrame] = {}
        self._lock = threading.Lock()

        # Storage / FinancialLoader are only created in database mode.
        if self.is_memory_mode:
            self._storage = None
            self._financial_loader = None
        else:
            self._storage = Storage()
            self._financial_loader = FinancialLoader()

    @staticmethod
    def _make_cache_key(
        *parts: Any,
        columns: List[str],
        stock_codes: Optional[List[str]],
    ) -> Tuple[Any, ...]:
        """Build a hashable cache key.

        The requested columns are part of the key so that a later request for
        the same table and date range but different columns is not served a
        cached frame that lacks them. Columns and stock codes are
        de-duplicated and sorted, so their order does not affect the key.
        """
        codes = None if stock_codes is None else tuple(sorted(set(stock_codes)))
        return (*parts, tuple(sorted(set(columns))), codes)

    @classmethod
    def _value_columns(
        cls, df: pl.DataFrame, table_name: str, columns: List[str]
    ) -> List[str]:
        """Return the requested non-key columns, de-duplicated in request order.

        Raises:
            ValueError: If a requested non-key column is missing from ``df``.
        """
        extras = [c for c in dict.fromkeys(columns) if c not in cls._BASE_COLUMNS]
        for col in extras:
            if col not in df.columns:
                raise ValueError(f"表 {table_name} 缺少字段: {col}")
        return extras

    def fetch_data(
        self,
        data_specs: List[DataSpec],
        start_date: str,
        end_date: str,
        stock_codes: Optional[List[str]] = None,
    ) -> pl.DataFrame:
        """Load every table named by ``data_specs`` and assemble the wide table.

        Args:
            data_specs: Data specs to satisfy.
            start_date: Start date (YYYYMMDD).
            end_date: End date (YYYYMMDD).
            stock_codes: Stock codes to keep; ``None`` means the whole market.

        Returns:
            The assembled core wide table.

        Raises:
            ValueError: If ``data_specs`` is empty, or a required table or
                column is missing from the data source.
        """
        if not data_specs:
            raise ValueError("数据规格不能为空")

        # Merge the requested columns per table so each table is loaded once.
        # A dict (not a set) keeps first-seen order, which makes the column
        # order of the result reproducible across runs.
        required_tables: Dict[str, Dict[str, None]] = {}
        # First asof_backward spec seen per table; supplies the join config.
        asof_specs: Dict[str, DataSpec] = {}
        for spec in data_specs:
            required_tables.setdefault(spec.table, {}).update(
                dict.fromkeys(spec.columns)
            )
            if spec.join_type == self._ASOF_JOIN:
                asof_specs.setdefault(spec.table, spec)

        table_data: Dict[str, pl.DataFrame] = {}
        for table_name, columns in required_tables.items():
            financial_spec = asof_specs.get(table_name)
            if financial_spec is not None:
                # Financial table: reuse the join configuration of its spec.
                merged_spec = DataSpec(
                    table=table_name,
                    columns=list(columns),
                    join_type="asof_backward",
                    left_on=financial_spec.left_on,
                    right_on=financial_spec.right_on,
                )
            else:
                merged_spec = DataSpec(
                    table=table_name,
                    columns=list(columns),
                    join_type="standard",
                )

            table_data[table_name] = self._load_table_from_spec(
                spec=merged_spec,
                start_date=start_date,
                end_date=end_date,
                stock_codes=stock_codes,
            )

        return self._assemble_wide_table_with_specs(
            table_data, data_specs, start_date, end_date
        )

    def _load_table_from_spec(
        self,
        spec: DataSpec,
        start_date: str,
        end_date: str,
        stock_codes: Optional[List[str]] = None,
    ) -> pl.DataFrame:
        """Load one table according to ``spec.join_type``.

        - ``asof_backward``: loaded through ``FinancialLoader`` with the start
          date widened by ``get_date_range_with_lookback`` (the original
          comment describes this as a one-year look-back).
        - anything else: loaded through ``_load_table``.

        Args:
            spec: Data spec describing the table and its columns.
            start_date: Start date.
            end_date: End date.
            stock_codes: Stock codes to keep; ``None`` means no filtering.

        Returns:
            The loaded (and filtered) DataFrame.

        Raises:
            RuntimeError: If an asof table is requested but no
                ``FinancialLoader`` is available (e.g. in memory mode).
        """
        cache_key = self._make_cache_key(
            "spec",
            spec.table,
            spec.join_type,
            start_date,
            end_date,
            columns=spec.columns,
            stock_codes=stock_codes,
        )

        with self._lock:
            if cache_key in self._cache:
                return self._cache[cache_key]

        if spec.join_type == self._ASOF_JOIN:
            if self._financial_loader is None:
                raise RuntimeError("FinancialLoader 未初始化")

            adjusted_start, _ = self._financial_loader.get_date_range_with_lookback(
                start_date, end_date
            )

            # Only a single code is passed down to the loader.
            single_code = (
                stock_codes[0]
                if stock_codes is not None and len(stock_codes) == 1
                else None
            )

            df = self._financial_loader.load_financial_data(
                table_name=spec.table,
                columns=spec.columns,
                start_date=adjusted_start,
                end_date=end_date,
                ts_code=single_code,
            )

            # Every other explicit list is filtered in memory. This includes
            # the empty list, which must yield no rows rather than the whole
            # market.
            if stock_codes is not None and len(stock_codes) != 1:
                df = df.filter(pl.col("ts_code").is_in(stock_codes))
        else:
            df = self._load_table(
                table_name=spec.table,
                columns=spec.columns,
                start_date=start_date,
                end_date=end_date,
                stock_codes=stock_codes,
            )

        with self._lock:
            self._cache[cache_key] = df

        return df

    def _load_table(
        self,
        table_name: str,
        columns: List[str],
        start_date: str,
        end_date: str,
        stock_codes: Optional[List[str]] = None,
    ) -> pl.DataFrame:
        """Load one standard table from memory or the database, with caching.

        Args:
            table_name: Table name.
            columns: Requested columns.
            start_date: Start date.
            end_date: End date.
            stock_codes: Stock codes to keep; ``None`` means no filtering.

        Returns:
            The filtered DataFrame.
        """
        cache_key = self._make_cache_key(
            "table",
            table_name,
            start_date,
            end_date,
            columns=columns,
            stock_codes=stock_codes,
        )

        with self._lock:
            if cache_key in self._cache:
                return self._cache[cache_key]

        loader = (
            self._load_from_memory if self.is_memory_mode else self._load_from_database
        )
        df = loader(table_name, columns, start_date, end_date, stock_codes)

        with self._lock:
            self._cache[cache_key] = df

        return df

    def _load_from_memory(
        self,
        table_name: str,
        columns: List[str],
        start_date: str,
        end_date: str,
        stock_codes: Optional[List[str]] = None,
    ) -> pl.DataFrame:
        """Load a table from the in-memory data source.

        Raises:
            ValueError: If the table or a requested column is missing.
        """
        if table_name not in self.data_source:
            raise ValueError(f"内存数据源中缺少表: {table_name}")

        df = self.data_source[table_name]
        extra_cols = self._value_columns(df, table_name, columns)

        # Restrict to the date range, then to the requested stocks.
        df = df.filter(
            (pl.col("trade_date") >= start_date) & (pl.col("trade_date") <= end_date)
        )
        if stock_codes is not None:
            df = df.filter(pl.col("ts_code").is_in(stock_codes))

        return df.select(list(self._BASE_COLUMNS) + extra_cols)

    def _load_from_database(
        self,
        table_name: str,
        columns: List[str],
        start_date: str,
        end_date: str,
        stock_codes: Optional[List[str]] = None,
    ) -> pl.DataFrame:
        """Load a table from the database through ``Storage.load_polars``.

        Raises:
            RuntimeError: If ``Storage`` is not initialised or loading fails.
            ValueError: If the table or a requested column does not exist.
        """
        if self._storage is None:
            raise RuntimeError("Storage 未初始化")

        if not self._storage.exists(table_name):
            raise ValueError(f"数据库中不存在表: {table_name}")

        # Per the original comment, Storage.load_polars only supports a single
        # ts_code, so only a one-element list is pushed down.
        single_code = (
            stock_codes[0] if stock_codes is not None and len(stock_codes) == 1 else None
        )

        try:
            df = self._storage.load_polars(
                name=table_name,
                start_date=start_date,
                end_date=end_date,
                ts_code=single_code,
            )
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(f"从数据库加载表 {table_name} 失败: {e}") from e

        # Every other explicit list (including an empty one) is filtered in
        # memory, matching the behaviour of the in-memory source.
        if stock_codes is not None and len(stock_codes) != 1:
            df = df.filter(pl.col("ts_code").is_in(stock_codes))

        extra_cols = self._value_columns(df, table_name, columns)
        return df.select(list(self._BASE_COLUMNS) + extra_cols)

    def _assemble_wide_table(
        self,
        table_data: Dict[str, pl.DataFrame],
        required_tables: Dict[str, Set[str]],
    ) -> pl.DataFrame:
        """Left-join all tables onto the first one.

        Args:
            table_data: Mapping of table name to DataFrame.
            required_tables: Mapping of table name to requested columns
                (not used by this method).

        Returns:
            The assembled wide table.

        Raises:
            ValueError: If ``table_data`` is empty.
        """
        if not table_data:
            raise ValueError("没有数据可组装")

        frames = iter(table_data.values())
        # The first table is the base; the rest are joined on the key columns.
        result = next(frames)
        for df in frames:
            result = result.join(df, on=["ts_code", "trade_date"], how="left")

        return result

    def _assemble_wide_table_with_specs(
        self,
        table_data: Dict[str, pl.DataFrame],
        data_specs: List[DataSpec],
        start_date: str,
        end_date: str,
    ) -> pl.DataFrame:
        """Assemble the wide table, handling standard and asof tables.

        Standard tables are left-joined on ``ts_code`` + ``trade_date``. If
        there are asof tables, ``trade_date`` is parsed to ``pl.Date`` once,
        the frame is sorted, every asof table is merged through
        ``FinancialLoader.merge_financial_with_price``, and ``trade_date`` is
        formatted back to a ``%Y%m%d`` string once at the end.

        A table counts as asof when any of its specs has join type
        ``asof_backward``, which is the same rule ``fetch_data`` uses when
        loading it. Every other table is treated as standard.

        Args:
            table_data: Mapping of table name to DataFrame.
            data_specs: Data specs the tables were loaded for.
            start_date: Start date (not used by this method).
            end_date: End date (not used by this method).

        Returns:
            The assembled wide table.

        Raises:
            ValueError: If there is no data or no standard table to build on.
            RuntimeError: If asof tables are present but no
                ``FinancialLoader`` is available.
        """
        if not table_data:
            raise ValueError("没有数据可组装")

        asof_names = {s.table for s in data_specs if s.join_type == self._ASOF_JOIN}
        standard_tables = [t for t in table_data if t not in asof_names]
        asof_tables = [t for t in table_data if t in asof_names]

        # Merge all standard tables first, using the first one as the base.
        base_df = None
        for table_name in standard_tables:
            df = table_data[table_name]
            if base_df is None:
                base_df = df
            else:
                # Apart from the join keys the tables are expected to have no
                # overlapping columns.
                base_df = base_df.join(
                    df,
                    on=["ts_code", "trade_date"],
                    how="left",
                )

        if base_df is None:
            raise ValueError("至少需要一张标准行情表作为基础")

        if not asof_tables:
            return base_df

        if self._financial_loader is None:
            raise RuntimeError("FinancialLoader 未初始化")

        # Convert trade_date to Date once, and sort; the original comment
        # notes that sorting is required for the asof join.
        base_df = base_df.with_columns(
            [pl.col("trade_date").str.strptime(pl.Date, "%Y%m%d").alias("trade_date")]
        )
        base_df = base_df.sort(["ts_code", "trade_date"])

        for table_name in asof_tables:
            # Columns requested for this table across all specs, in
            # first-seen order, minus join keys and metadata columns.
            requested = dict.fromkeys(
                col
                for spec in data_specs
                if spec.table == table_name
                for col in spec.columns
            )
            financial_cols = [
                c for c in requested if c not in self._FINANCIAL_META_COLUMNS
            ]

            base_df = self._financial_loader.merge_financial_with_price(
                base_df, table_data[table_name], financial_cols
            )

        # Convert trade_date back to its string form once all joins are done.
        return base_df.with_columns(
            [pl.col("trade_date").dt.strftime("%Y%m%d").alias("trade_date")]
        )

    def clear_cache(self) -> None:
        """Drop every cached DataFrame.

        The Storage connection is left untouched; per the original comment,
        Storage is a singleton and does not need to be closed here.
        """
        with self._lock:
            self._cache.clear()