refactor(factors): 拆分 engine.py 为模块化包

将单文件 engine.py (1064行) 拆分为 engine/ 包:
- 数据规格、路由器、计划器、计算引擎、因子引擎分离
- 保持向后兼容,API 无变化
This commit is contained in:
2026-03-02 22:29:18 +08:00
parent 1c0c4a0de1
commit 77e4e94e05
7 changed files with 1146 additions and 0 deletions

View File

@@ -0,0 +1,304 @@
"""数据路由器。
按需取数、组装核心宽表。
负责根据数据规格从数据源拉取数据,并组装成统一的宽表格式。
支持内存数据源(用于测试)和真实数据库连接。
"""
from typing import Any, Dict, List, Optional, Set, Union
import threading
import polars as pl
from src.factors.engine.data_spec import DataSpec
from src.data.storage import Storage
class DataRouter:
"""数据路由器 - 按需取数、组装核心宽表。
负责根据数据规格从数据源拉取数据,并组装成统一的宽表格式。
支持内存数据源(用于测试)和真实数据库连接。
Attributes:
data_source: 数据源,可以是内存 DataFrame 字典或数据库连接
is_memory_mode: 是否为内存模式
"""
def __init__(self, data_source: Optional[Dict[str, pl.DataFrame]] = None) -> None:
"""初始化数据路由器。
Args:
data_source: 内存数据源,字典格式 {表名: DataFrame}
为 None 时自动连接 DuckDB 数据库
"""
self.data_source = data_source or {}
self.is_memory_mode = data_source is not None
self._cache: Dict[str, pl.DataFrame] = {}
self._lock = threading.Lock()
# 数据库模式下初始化 Storage
if not self.is_memory_mode:
self._storage = Storage()
else:
self._storage = None
def fetch_data(
self,
data_specs: List[DataSpec],
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
"""根据数据规格获取并组装核心宽表。
Args:
data_specs: 数据规格列表
start_date: 开始日期 (YYYYMMDD)
end_date: 结束日期 (YYYYMMDD)
stock_codes: 股票代码列表None 表示全市场
Returns:
组装好的核心宽表 DataFrame
Raises:
ValueError: 当数据源中缺少必要的表或字段时
"""
if not data_specs:
raise ValueError("数据规格不能为空")
# 收集所有需要的表和字段
required_tables: Dict[str, Set[str]] = {}
max_lookback = 0
for spec in data_specs:
if spec.table not in required_tables:
required_tables[spec.table] = set()
required_tables[spec.table].update(spec.columns)
max_lookback = max(max_lookback, spec.lookback_days)
# 调整日期范围以包含回看期
adjusted_start = self._adjust_start_date(start_date, max_lookback)
# 从数据源获取各表数据
table_data = {}
for table_name, columns in required_tables.items():
df = self._load_table(
table_name=table_name,
columns=list(columns),
start_date=adjusted_start,
end_date=end_date,
stock_codes=stock_codes,
)
table_data[table_name] = df
# 组装核心宽表
core_table = self._assemble_wide_table(table_data, required_tables)
# 过滤到实际请求日期范围
core_table = core_table.filter(
(pl.col("trade_date") >= start_date) & (pl.col("trade_date") <= end_date)
)
return core_table
def _load_table(
self,
table_name: str,
columns: List[str],
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
"""加载单个表的数据。
Args:
table_name: 表名
columns: 需要的字段
start_date: 开始日期
end_date: 结束日期
stock_codes: 股票代码过滤
Returns:
过滤后的 DataFrame
"""
cache_key = f"{table_name}_{start_date}_{end_date}_{stock_codes}"
with self._lock:
if cache_key in self._cache:
return self._cache[cache_key]
if self.is_memory_mode:
df = self._load_from_memory(
table_name, columns, start_date, end_date, stock_codes
)
else:
df = self._load_from_database(
table_name, columns, start_date, end_date, stock_codes
)
with self._lock:
self._cache[cache_key] = df
return df
def _load_from_memory(
self,
table_name: str,
columns: List[str],
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
"""从内存数据源加载数据。"""
if table_name not in self.data_source:
raise ValueError(f"内存数据源中缺少表: {table_name}")
df = self.data_source[table_name]
# 确保必需字段存在
for col in columns:
if col not in df.columns and col not in ["ts_code", "trade_date"]:
raise ValueError(f"{table_name} 缺少字段: {col}")
# 过滤日期和股票
df = df.filter(
(pl.col("trade_date") >= start_date) & (pl.col("trade_date") <= end_date)
)
if stock_codes is not None:
df = df.filter(pl.col("ts_code").is_in(stock_codes))
# 选择需要的列
select_cols = ["ts_code", "trade_date"] + [
c for c in columns if c in df.columns
]
return df.select(select_cols)
def _load_from_database(
self,
table_name: str,
columns: List[str],
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
"""从 DuckDB 数据库加载数据。
利用 Storage.load_polars() 方法,支持 SQL 查询下推。
"""
if self._storage is None:
raise RuntimeError("Storage 未初始化")
# 检查表是否存在
if not self._storage.exists(table_name):
raise ValueError(f"数据库中不存在表: {table_name}")
# 构建查询参数
# Storage.load_polars 目前只支持单个 ts_code需要处理列表情况
if stock_codes is not None and len(stock_codes) == 1:
ts_code_filter = stock_codes[0]
else:
ts_code_filter = None
try:
# 从数据库加载原始数据
df = self._storage.load_polars(
name=table_name,
start_date=start_date,
end_date=end_date,
ts_code=ts_code_filter,
)
except Exception as e:
raise RuntimeError(f"从数据库加载表 {table_name} 失败: {e}")
# 如果 stock_codes 是列表且长度 > 1在内存中过滤
if stock_codes is not None and len(stock_codes) > 1:
df = df.filter(pl.col("ts_code").is_in(stock_codes))
# 检查必需字段
for col in columns:
if col not in df.columns and col not in ["ts_code", "trade_date"]:
raise ValueError(f"{table_name} 缺少字段: {col}")
# 选择需要的列
select_cols = ["ts_code", "trade_date"] + [
c for c in columns if c in df.columns
]
return df.select(select_cols)
def _assemble_wide_table(
self,
table_data: Dict[str, pl.DataFrame],
required_tables: Dict[str, Set[str]],
) -> pl.DataFrame:
"""组装多表数据为核心宽表。
使用 left join 合并各表数据,以第一个表为基准。
Args:
table_data: 表名到 DataFrame 的映射
required_tables: 表名到字段集合的映射
Returns:
组装后的宽表
"""
if not table_data:
raise ValueError("没有数据可组装")
# 以第一个表为基准
base_table_name = list(table_data.keys())[0]
result = table_data[base_table_name]
# 与其他表 join
for table_name, df in table_data.items():
if table_name == base_table_name:
continue
# 使用 ts_code 和 trade_date 作为 join 键
result = result.join(
df,
on=["ts_code", "trade_date"],
how="left",
)
return result
def _adjust_start_date(self, start_date: str, lookback_days: int) -> str:
"""根据回看天数调整开始日期。
Args:
start_date: 原始开始日期 (YYYYMMDD)
lookback_days: 需要回看的交易日数
Returns:
调整后的开始日期
"""
# 简化的日期调整假设每月30天向前推移
# 实际应用中应该使用交易日历
year = int(start_date[:4])
month = int(start_date[4:6])
day = int(start_date[6:8])
total_days = lookback_days + 30 # 额外缓冲
day -= total_days
while day <= 0:
month -= 1
if month <= 0:
month = 12
year -= 1
day += 30
return f"{year:04d}{month:02d}{day:02d}"
def clear_cache(self) -> None:
"""清除数据缓存。"""
with self._lock:
self._cache.clear()
# 数据库模式下清理 Storage 连接(可选)
if not self.is_memory_mode and self._storage is not None:
# Storage 使用单例模式,不需要关闭连接
pass