Files
ProStock/src/factors/engine/data_router.py

267 lines
8.2 KiB
Python
Raw Normal View History

"""数据路由器。
按需取数组装核心宽表
负责根据数据规格从数据源拉取数据并组装成统一的宽表格式
支持内存数据源用于测试和真实数据库连接
"""
from typing import Any, Dict, List, Optional, Set, Union
import threading
import polars as pl
from src.factors.engine.data_spec import DataSpec
from src.data.storage import Storage
class DataRouter:
"""数据路由器 - 按需取数、组装核心宽表。
负责根据数据规格从数据源拉取数据并组装成统一的宽表格式
支持内存数据源用于测试和真实数据库连接
Attributes:
data_source: 数据源可以是内存 DataFrame 字典或数据库连接
is_memory_mode: 是否为内存模式
"""
def __init__(self, data_source: Optional[Dict[str, pl.DataFrame]] = None) -> None:
"""初始化数据路由器。
Args:
data_source: 内存数据源字典格式 {表名: DataFrame}
None 时自动连接 DuckDB 数据库
"""
self.data_source = data_source or {}
self.is_memory_mode = data_source is not None
self._cache: Dict[str, pl.DataFrame] = {}
self._lock = threading.Lock()
# 数据库模式下初始化 Storage
if not self.is_memory_mode:
self._storage = Storage()
else:
self._storage = None
def fetch_data(
self,
data_specs: List[DataSpec],
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
"""根据数据规格获取并组装核心宽表。
Args:
data_specs: 数据规格列表
start_date: 开始日期 (YYYYMMDD)
end_date: 结束日期 (YYYYMMDD)
stock_codes: 股票代码列表None 表示全市场
Returns:
组装好的核心宽表 DataFrame
Raises:
ValueError: 当数据源中缺少必要的表或字段时
"""
if not data_specs:
raise ValueError("数据规格不能为空")
# 收集所有需要的表和字段
required_tables: Dict[str, Set[str]] = {}
for spec in data_specs:
if spec.table not in required_tables:
required_tables[spec.table] = set()
required_tables[spec.table].update(spec.columns)
# 从数据源获取各表数据
table_data = {}
for table_name, columns in required_tables.items():
df = self._load_table(
table_name=table_name,
columns=list(columns),
start_date=start_date,
end_date=end_date,
stock_codes=stock_codes,
)
table_data[table_name] = df
# 组装核心宽表
core_table = self._assemble_wide_table(table_data, required_tables)
return core_table
def _load_table(
self,
table_name: str,
columns: List[str],
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
"""加载单个表的数据。
Args:
table_name: 表名
columns: 需要的字段
start_date: 开始日期
end_date: 结束日期
stock_codes: 股票代码过滤
Returns:
过滤后的 DataFrame
"""
cache_key = f"{table_name}_{start_date}_{end_date}_{stock_codes}"
with self._lock:
if cache_key in self._cache:
return self._cache[cache_key]
if self.is_memory_mode:
df = self._load_from_memory(
table_name, columns, start_date, end_date, stock_codes
)
else:
df = self._load_from_database(
table_name, columns, start_date, end_date, stock_codes
)
with self._lock:
self._cache[cache_key] = df
return df
def _load_from_memory(
self,
table_name: str,
columns: List[str],
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
"""从内存数据源加载数据。"""
if table_name not in self.data_source:
raise ValueError(f"内存数据源中缺少表: {table_name}")
df = self.data_source[table_name]
# 确保必需字段存在
for col in columns:
if col not in df.columns and col not in ["ts_code", "trade_date"]:
raise ValueError(f"{table_name} 缺少字段: {col}")
# 过滤日期和股票
df = df.filter(
(pl.col("trade_date") >= start_date) & (pl.col("trade_date") <= end_date)
)
if stock_codes is not None:
df = df.filter(pl.col("ts_code").is_in(stock_codes))
# 选择需要的列
select_cols = ["ts_code", "trade_date"] + [
c for c in columns if c in df.columns
]
return df.select(select_cols)
def _load_from_database(
self,
table_name: str,
columns: List[str],
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
"""从 DuckDB 数据库加载数据。
利用 Storage.load_polars() 方法支持 SQL 查询下推
"""
if self._storage is None:
raise RuntimeError("Storage 未初始化")
# 检查表是否存在
if not self._storage.exists(table_name):
raise ValueError(f"数据库中不存在表: {table_name}")
# 构建查询参数
# Storage.load_polars 目前只支持单个 ts_code需要处理列表情况
if stock_codes is not None and len(stock_codes) == 1:
ts_code_filter = stock_codes[0]
else:
ts_code_filter = None
try:
# 从数据库加载原始数据
df = self._storage.load_polars(
name=table_name,
start_date=start_date,
end_date=end_date,
ts_code=ts_code_filter,
)
except Exception as e:
raise RuntimeError(f"从数据库加载表 {table_name} 失败: {e}")
# 如果 stock_codes 是列表且长度 > 1在内存中过滤
if stock_codes is not None and len(stock_codes) > 1:
df = df.filter(pl.col("ts_code").is_in(stock_codes))
# 检查必需字段
for col in columns:
if col not in df.columns and col not in ["ts_code", "trade_date"]:
raise ValueError(f"{table_name} 缺少字段: {col}")
# 选择需要的列
select_cols = ["ts_code", "trade_date"] + [
c for c in columns if c in df.columns
]
return df.select(select_cols)
def _assemble_wide_table(
self,
table_data: Dict[str, pl.DataFrame],
required_tables: Dict[str, Set[str]],
) -> pl.DataFrame:
"""组装多表数据为核心宽表。
使用 left join 合并各表数据以第一个表为基准
Args:
table_data: 表名到 DataFrame 的映射
required_tables: 表名到字段集合的映射
Returns:
组装后的宽表
"""
if not table_data:
raise ValueError("没有数据可组装")
# 以第一个表为基准
base_table_name = list(table_data.keys())[0]
result = table_data[base_table_name]
# 与其他表 join
for table_name, df in table_data.items():
if table_name == base_table_name:
continue
# 使用 ts_code 和 trade_date 作为 join 键
result = result.join(
df,
on=["ts_code", "trade_date"],
how="left",
)
return result
def clear_cache(self) -> None:
"""清除数据缓存。"""
with self._lock:
self._cache.clear()
# 数据库模式下清理 Storage 连接(可选)
if not self.is_memory_mode and self._storage is not None:
# Storage 使用单例模式,不需要关闭连接
pass