Files
ProStock/src/factors/engine/data_router.py

484 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""数据路由器。
按需取数、组装核心宽表。
负责根据数据规格从数据源拉取数据,并组装成统一的宽表格式。
支持内存数据源(用于测试)和真实数据库连接。
支持标准等值匹配和 asof_backward财务数据两种拼接模式。
"""
from typing import Any, Dict, List, Optional, Set, Union
import threading
import polars as pl
from src.factors.engine.data_spec import DataSpec
from src.data.storage import Storage
from src.data.financial_loader import FinancialLoader
class DataRouter:
"""数据路由器 - 按需取数、组装核心宽表。
负责根据数据规格从数据源拉取数据,并组装成统一的宽表格式。
支持内存数据源(用于测试)和真实数据库连接。
Attributes:
data_source: 数据源,可以是内存 DataFrame 字典或数据库连接
is_memory_mode: 是否为内存模式
"""
def __init__(self, data_source: Optional[Dict[str, pl.DataFrame]] = None) -> None:
"""初始化数据路由器。
Args:
data_source: 内存数据源,字典格式 {表名: DataFrame}
为 None 时自动连接 DuckDB 数据库
"""
self.data_source = data_source or {}
self.is_memory_mode = data_source is not None
self._cache: Dict[str, pl.DataFrame] = {}
self._lock = threading.Lock()
# 数据库模式下初始化 Storage 和 FinancialLoader
if not self.is_memory_mode:
self._storage = Storage()
self._financial_loader = FinancialLoader()
else:
self._storage = None
self._financial_loader = None
def fetch_data(
self,
data_specs: List[DataSpec],
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
"""根据数据规格获取并组装核心宽表。
Args:
data_specs: 数据规格列表
start_date: 开始日期 (YYYYMMDD)
end_date: 结束日期 (YYYYMMDD)
stock_codes: 股票代码列表None 表示全市场
Returns:
组装好的核心宽表 DataFrame
Raises:
ValueError: 当数据源中缺少必要的表或字段时
"""
if not data_specs:
raise ValueError("数据规格不能为空")
# 收集所有需要的表和字段
required_tables: Dict[str, Set[str]] = {}
for spec in data_specs:
if spec.table not in required_tables:
required_tables[spec.table] = set()
required_tables[spec.table].update(spec.columns)
# 从数据源获取各表数据(使用合并后的 required_tables避免重复加载
table_data = {}
for table_name, columns in required_tables.items():
# 判断是标准表还是财务表
is_financial = any(
s.table == table_name and s.join_type == "asof_backward"
for s in data_specs
)
if is_financial:
# 财务表:找到对应的 spec 获取 join 配置
financial_spec = next(
s
for s in data_specs
if s.table == table_name and s.join_type == "asof_backward"
)
spec = DataSpec(
table=table_name,
columns=list(columns),
join_type="asof_backward",
left_on=financial_spec.left_on,
right_on=financial_spec.right_on,
)
else:
# 标准表
spec = DataSpec(
table=table_name,
columns=list(columns),
join_type="standard",
)
df = self._load_table_from_spec(
spec=spec,
start_date=start_date,
end_date=end_date,
stock_codes=stock_codes,
)
table_data[table_name] = df
# 组装核心宽表(支持多种 join 类型)
core_table = self._assemble_wide_table_with_specs(
table_data, data_specs, start_date, end_date
)
return core_table
def _load_table_from_spec(
self,
spec: DataSpec,
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
"""根据数据规格加载单个表的数据。
根据 spec.join_type 选择不同的加载方式:
- standard: 使用原有逻辑,基于 trade_date
- asof_backward: 使用 FinancialLoader基于 f_ann_date扩展回看期
Args:
spec: 数据规格
start_date: 开始日期
end_date: 结束日期
stock_codes: 股票代码过滤
Returns:
过滤后的 DataFrame
"""
cache_key = (
f"{spec.table}_{spec.join_type}_{start_date}_{end_date}_{stock_codes}"
)
with self._lock:
if cache_key in self._cache:
return self._cache[cache_key]
if spec.join_type == "asof_backward":
# 财务数据使用 FinancialLoader
if self._financial_loader is None:
raise RuntimeError("FinancialLoader 未初始化")
# 扩展日期范围回看1年
adjusted_start, _ = self._financial_loader.get_date_range_with_lookback(
start_date, end_date
)
# 处理 stock_codes
ts_code = stock_codes[0] if stock_codes and len(stock_codes) == 1 else None
df = self._financial_loader.load_financial_data(
table_name=spec.table,
columns=spec.columns,
start_date=adjusted_start,
end_date=end_date,
ts_code=ts_code,
)
# 如果 stock_codes 是列表且长度 > 1在内存中过滤
if stock_codes is not None and len(stock_codes) > 1:
df = df.filter(pl.col("ts_code").is_in(stock_codes))
else:
# 标准表使用原有逻辑
df = self._load_table(
table_name=spec.table,
columns=spec.columns,
start_date=start_date,
end_date=end_date,
stock_codes=stock_codes,
)
with self._lock:
self._cache[cache_key] = df
return df
def _load_table(
self,
table_name: str,
columns: List[str],
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
"""加载单个表的数据。
Args:
table_name: 表名
columns: 需要的字段
start_date: 开始日期
end_date: 结束日期
stock_codes: 股票代码过滤
Returns:
过滤后的 DataFrame
"""
cache_key = f"{table_name}_{start_date}_{end_date}_{stock_codes}"
with self._lock:
if cache_key in self._cache:
return self._cache[cache_key]
if self.is_memory_mode:
df = self._load_from_memory(
table_name, columns, start_date, end_date, stock_codes
)
else:
df = self._load_from_database(
table_name, columns, start_date, end_date, stock_codes
)
with self._lock:
self._cache[cache_key] = df
return df
def _load_from_memory(
self,
table_name: str,
columns: List[str],
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
"""从内存数据源加载数据。"""
if table_name not in self.data_source:
raise ValueError(f"内存数据源中缺少表: {table_name}")
df = self.data_source[table_name]
# 确保必需字段存在
for col in columns:
if col not in df.columns and col not in ["ts_code", "trade_date"]:
raise ValueError(f"{table_name} 缺少字段: {col}")
# 过滤日期和股票
df = df.filter(
(pl.col("trade_date") >= start_date) & (pl.col("trade_date") <= end_date)
)
if stock_codes is not None:
df = df.filter(pl.col("ts_code").is_in(stock_codes))
# 选择需要的列(避免重复)
base_cols = ["ts_code", "trade_date"]
extra_cols = [c for c in columns if c in df.columns and c not in base_cols]
select_cols = base_cols + extra_cols
return df.select(select_cols)
def _load_from_database(
self,
table_name: str,
columns: List[str],
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
"""从 DuckDB 数据库加载数据。
利用 Storage.load_polars() 方法,支持 SQL 查询下推。
"""
if self._storage is None:
raise RuntimeError("Storage 未初始化")
# 检查表是否存在
if not self._storage.exists(table_name):
raise ValueError(f"数据库中不存在表: {table_name}")
# 构建查询参数
# Storage.load_polars 目前只支持单个 ts_code需要处理列表情况
if stock_codes is not None and len(stock_codes) == 1:
ts_code_filter = stock_codes[0]
else:
ts_code_filter = None
try:
# 从数据库加载原始数据
df = self._storage.load_polars(
name=table_name,
start_date=start_date,
end_date=end_date,
ts_code=ts_code_filter,
)
except Exception as e:
raise RuntimeError(f"从数据库加载表 {table_name} 失败: {e}")
# 如果 stock_codes 是列表且长度 > 1在内存中过滤
if stock_codes is not None and len(stock_codes) > 1:
df = df.filter(pl.col("ts_code").is_in(stock_codes))
# 检查必需字段
for col in columns:
if col not in df.columns and col not in ["ts_code", "trade_date"]:
raise ValueError(f"{table_name} 缺少字段: {col}")
# 选择需要的列(避免重复)
base_cols = ["ts_code", "trade_date"]
extra_cols = [c for c in columns if c in df.columns and c not in base_cols]
select_cols = base_cols + extra_cols
return df.select(select_cols)
def _assemble_wide_table(
self,
table_data: Dict[str, pl.DataFrame],
required_tables: Dict[str, Set[str]],
) -> pl.DataFrame:
"""组装多表数据为核心宽表。
使用 left join 合并各表数据,以第一个表为基准。
Args:
table_data: 表名到 DataFrame 的映射
required_tables: 表名到字段集合的映射
Returns:
组装后的宽表
"""
if not table_data:
raise ValueError("没有数据可组装")
# 以第一个表为基准
base_table_name = list(table_data.keys())[0]
result = table_data[base_table_name]
# 与其他表 join
for table_name, df in table_data.items():
if table_name == base_table_name:
continue
# 使用 ts_code 和 trade_date 作为 join 键
result = result.join(
df,
on=["ts_code", "trade_date"],
how="left",
)
return result
def _assemble_wide_table_with_specs(
self,
table_data: Dict[str, pl.DataFrame],
data_specs: List[DataSpec],
start_date: str,
end_date: str,
) -> pl.DataFrame:
"""组装多表数据为核心宽表(支持多种 join 类型)。
支持标准等值匹配和 asof_backward 两种模式。
性能优化:
- 在开始时统一将 trade_date 转为 pl.Date
- 所有 asof join 全部在 pl.Date 类型下完成
- 返回前统一转回字符串格式
Args:
table_data: 表名到 DataFrame 的映射
data_specs: 数据规格列表
start_date: 开始日期
end_date: 结束日期
Returns:
组装后的宽表
"""
if not table_data:
raise ValueError("没有数据可组装")
# 从 data_specs 判断每个表的 join 类型
table_join_types = {}
for spec in data_specs:
if spec.table not in table_join_types:
table_join_types[spec.table] = spec.join_type
# 分离标准表和 asof 表(基于 table_data 的表名,避免重复)
standard_tables = [
t
for t in table_data.keys()
if table_join_types.get(t, "standard") == "standard"
]
asof_tables = [
t for t in table_data.keys() if table_join_types.get(t) == "asof_backward"
]
# 先合并所有标准表(使用 trade_date
base_df = None
for table_name in standard_tables:
df = table_data[table_name]
if base_df is None:
base_df = df
else:
# 使用 ts_code 和 trade_date 作为 join 键
# 注:根据动态路由原则,除 ts_code/trade_date 外不应有重复字段
# 如果出现重复,说明 SchemaCache 的字段映射有问题
base_df = base_df.join(
df,
on=["ts_code", "trade_date"],
how="left",
)
if base_df is None:
raise ValueError("至少需要一张标准行情表作为基础")
# 【性能优化】统一转换 trade_date 为 Date 类型(只转换一次)
if asof_tables:
base_df = base_df.with_columns(
[
pl.col("trade_date")
.str.strptime(pl.Date, "%Y%m%d")
.alias("trade_date")
]
)
# 确保已排序join_asof 要求)
base_df = base_df.sort(["ts_code", "trade_date"])
# 逐个合并 asof 表(所有 join 都在 Date 类型下进行)
for table_name in asof_tables:
df_financial = table_data[table_name]
# 提取需要保留的字段(排除 join 键和元数据字段)
# 从 data_specs 中找到对应表的 columns
table_columns = set()
for spec in data_specs:
if spec.table == table_name:
table_columns.update(spec.columns)
financial_cols = [
c
for c in table_columns
if c
not in [
"ts_code",
"f_ann_date",
"report_type",
"update_flag",
"end_date",
]
]
if self._financial_loader is None:
raise RuntimeError("FinancialLoader 未初始化")
base_df = self._financial_loader.merge_financial_with_price(
base_df, df_financial, financial_cols
)
# 【性能优化】所有 asof join 完成后,统一转回字符串格式
if asof_tables:
base_df = base_df.with_columns(
[pl.col("trade_date").dt.strftime("%Y%m%d").alias("trade_date")]
)
return base_df
def clear_cache(self) -> None:
"""清除数据缓存。"""
with self._lock:
self._cache.clear()
# 数据库模式下清理 Storage 连接(可选)
if not self.is_memory_mode and self._storage is not None:
# Storage 使用单例模式,不需要关闭连接
pass