feat(data): 财务数据加载与清洗模块
新增 FinancialLoader 类,提供:
- 财务数据加载与清洗(保留合并报表,按 update_flag 去重)
- 支持 as-of join 拼接行情数据(无未来函数)
- 自动识别财务表并配置 asof_backward 拼接模式
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
按需取数、组装核心宽表。
|
||||
负责根据数据规格从数据源拉取数据,并组装成统一的宽表格式。
|
||||
支持内存数据源(用于测试)和真实数据库连接。
|
||||
支持标准等值匹配和 asof_backward(财务数据)两种拼接模式。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
@@ -12,6 +13,7 @@ import polars as pl
|
||||
|
||||
from src.factors.engine.data_spec import DataSpec
|
||||
from src.data.storage import Storage
|
||||
from src.data.financial_loader import FinancialLoader
|
||||
|
||||
|
||||
class DataRouter:
|
||||
@@ -37,11 +39,13 @@ class DataRouter:
|
||||
self._cache: Dict[str, pl.DataFrame] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# 数据库模式下初始化 Storage
|
||||
# 数据库模式下初始化 Storage 和 FinancialLoader
|
||||
if not self.is_memory_mode:
|
||||
self._storage = Storage()
|
||||
self._financial_loader = FinancialLoader()
|
||||
else:
|
||||
self._storage = None
|
||||
self._financial_loader = None
|
||||
|
||||
def fetch_data(
|
||||
self,
|
||||
@@ -75,23 +79,122 @@ class DataRouter:
|
||||
required_tables[spec.table] = set()
|
||||
required_tables[spec.table].update(spec.columns)
|
||||
|
||||
# 从数据源获取各表数据
|
||||
# 从数据源获取各表数据(使用合并后的 required_tables,避免重复加载)
|
||||
table_data = {}
|
||||
for table_name, columns in required_tables.items():
|
||||
df = self._load_table(
|
||||
table_name=table_name,
|
||||
columns=list(columns),
|
||||
# 判断是标准表还是财务表
|
||||
is_financial = any(
|
||||
s.table == table_name and s.join_type == "asof_backward"
|
||||
for s in data_specs
|
||||
)
|
||||
|
||||
if is_financial:
|
||||
# 财务表:找到对应的 spec 获取 join 配置
|
||||
financial_spec = next(
|
||||
s
|
||||
for s in data_specs
|
||||
if s.table == table_name and s.join_type == "asof_backward"
|
||||
)
|
||||
spec = DataSpec(
|
||||
table=table_name,
|
||||
columns=list(columns),
|
||||
join_type="asof_backward",
|
||||
left_on=financial_spec.left_on,
|
||||
right_on=financial_spec.right_on,
|
||||
)
|
||||
else:
|
||||
# 标准表
|
||||
spec = DataSpec(
|
||||
table=table_name,
|
||||
columns=list(columns),
|
||||
join_type="standard",
|
||||
)
|
||||
|
||||
df = self._load_table_from_spec(
|
||||
spec=spec,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
stock_codes=stock_codes,
|
||||
)
|
||||
table_data[table_name] = df
|
||||
|
||||
# 组装核心宽表
|
||||
core_table = self._assemble_wide_table(table_data, required_tables)
|
||||
# 组装核心宽表(支持多种 join 类型)
|
||||
core_table = self._assemble_wide_table_with_specs(
|
||||
table_data, data_specs, start_date, end_date
|
||||
)
|
||||
|
||||
return core_table
|
||||
|
||||
def _load_table_from_spec(
    self,
    spec: DataSpec,
    start_date: str,
    end_date: str,
    stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
    """Load the data for a single table as described by ``spec``.

    The loading strategy depends on ``spec.join_type``:

    - ``"asof_backward"``: the table is read through ``self._financial_loader``
      with the start date replaced by the one returned from
      ``get_date_range_with_lookback`` (the original author's note says the
      look-back is one year -- confirm against FinancialLoader).
    - anything else: delegated to ``self._load_table`` with the dates as given.

    Results are memoised in ``self._cache``.

    Args:
        spec: Data specification (table, columns, join type).
        start_date: Start date of the requested range.
        end_date: End date of the requested range.
        stock_codes: Optional list of stock codes to restrict the result to.

    Returns:
        The loaded (and, where applicable, filtered) DataFrame.

    Raises:
        RuntimeError: If an ``asof_backward`` table is requested while no
            FinancialLoader is available.
    """
    # The requested columns are part of the cache key: without them, a second
    # request for the same table/date range with a different column set would
    # be served a cached frame that lacks the newly requested columns.
    columns_key = ",".join(sorted(spec.columns))
    cache_key = (
        f"{spec.table}_{spec.join_type}_{start_date}_{end_date}_{stock_codes}"
        f"_{columns_key}"
    )

    with self._lock:
        if cache_key in self._cache:
            return self._cache[cache_key]

    if spec.join_type == "asof_backward":
        # Financial tables go through the FinancialLoader.
        if self._financial_loader is None:
            raise RuntimeError("FinancialLoader 未初始化")

        # Widen the start of the range so that reports announced before
        # start_date are available for the backward as-of match.
        adjusted_start, _ = self._financial_loader.get_date_range_with_lookback(
            start_date, end_date
        )

        # A single code is passed to the loader as ts_code; several codes are
        # loaded unfiltered and narrowed down in memory afterwards.
        ts_code = stock_codes[0] if stock_codes and len(stock_codes) == 1 else None

        df = self._financial_loader.load_financial_data(
            table_name=spec.table,
            columns=spec.columns,
            start_date=adjusted_start,
            end_date=end_date,
            ts_code=ts_code,
        )

        if stock_codes is not None and len(stock_codes) > 1:
            df = df.filter(pl.col("ts_code").is_in(stock_codes))
    else:
        # Standard tables keep using the pre-existing loading path.
        df = self._load_table(
            table_name=spec.table,
            columns=spec.columns,
            start_date=start_date,
            end_date=end_date,
            stock_codes=stock_codes,
        )

    with self._lock:
        self._cache[cache_key] = df

    return df
|
||||
|
||||
def _load_table(
|
||||
self,
|
||||
table_name: str,
|
||||
@@ -255,6 +358,119 @@ class DataRouter:
|
||||
|
||||
return result
|
||||
|
||||
def _assemble_wide_table_with_specs(
    self,
    table_data: Dict[str, pl.DataFrame],
    data_specs: List[DataSpec],
    start_date: str,
    end_date: str,
) -> pl.DataFrame:
    """Assemble the per-table frames into one wide table.

    Tables whose spec has ``join_type == "standard"`` (or that have no spec)
    are left-joined on ``["ts_code", "trade_date"]`` in ``table_data`` order;
    the first one becomes the base. Tables marked ``"asof_backward"`` are then
    merged one by one via ``FinancialLoader.merge_financial_with_price``.

    When at least one as-of table is present, ``trade_date`` is parsed once
    from ``"%Y%m%d"`` strings to ``pl.Date`` before the merges and formatted
    back to ``"%Y%m%d"`` strings before returning.

    Args:
        table_data: Mapping of table name to its loaded DataFrame.
        data_specs: Data specifications; the first spec seen for a table
            decides that table's join type.
        start_date: Start date (currently unused by this method).
        end_date: End date (currently unused by this method).

    Returns:
        The assembled wide table.

    Raises:
        ValueError: If ``table_data`` is empty or contains no standard table.
        RuntimeError: If as-of tables are present but no FinancialLoader is
            available.
    """
    if not table_data:
        raise ValueError("没有数据可组装")

    # The first spec for a table wins when several specs name the same table.
    table_join_types = {}
    for spec in data_specs:
        table_join_types.setdefault(spec.table, spec.join_type)

    # Partition by table_data keys so every table is handled exactly once.
    standard_tables = [
        t for t in table_data if table_join_types.get(t, "standard") == "standard"
    ]
    asof_tables = [
        t for t in table_data if table_join_types.get(t) == "asof_backward"
    ]

    # Merge all standard tables on the (ts_code, trade_date) key.
    base_df = None
    for table_name in standard_tables:
        df = table_data[table_name]
        if base_df is None:
            base_df = df
        else:
            # Apart from the join keys no column is expected to appear in
            # more than one table; a duplicate points at a field-mapping
            # problem upstream.
            base_df = base_df.join(
                df,
                on=["ts_code", "trade_date"],
                how="left",
            )

    if base_df is None:
        raise ValueError("至少需要一张标准行情表作为基础")

    if not asof_tables:
        return base_df

    # Fail before doing any conversion work rather than midway through the
    # merge loop.
    if self._financial_loader is None:
        raise RuntimeError("FinancialLoader 未初始化")

    # Convert trade_date to Date once; all as-of merges run on the Date type.
    base_df = base_df.with_columns(
        [
            pl.col("trade_date")
            .str.strptime(pl.Date, "%Y%m%d")
            .alias("trade_date")
        ]
    )
    # The as-of merge needs the left side sorted.
    base_df = base_df.sort(["ts_code", "trade_date"])

    # Join keys and metadata fields are not carried over into the wide table.
    excluded = {"ts_code", "f_ann_date", "report_type", "update_flag", "end_date"}

    for table_name in asof_tables:
        # dict.fromkeys de-duplicates while keeping the order in which the
        # specs list the columns, so the output column order is stable across
        # runs (iterating a set would depend on string hashing).
        requested = dict.fromkeys(
            col
            for spec in data_specs
            if spec.table == table_name
            for col in spec.columns
        )
        financial_cols = [c for c in requested if c not in excluded]

        base_df = self._financial_loader.merge_financial_with_price(
            base_df, table_data[table_name], financial_cols
        )

    # Restore the string representation of trade_date for callers.
    base_df = base_df.with_columns(
        [pl.col("trade_date").dt.strftime("%Y%m%d").alias("trade_date")]
    )

    return base_df
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""清除数据缓存。"""
|
||||
with self._lock:
|
||||
|
||||
@@ -4,24 +4,38 @@
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Union
|
||||
|
||||
import polars as pl
|
||||
|
||||
|
||||
@dataclass
class DataSpec:
    """Data specification for one table (supports several join types).

    Describes which table and fields a factor computation needs and how the
    table is to be joined into the wide table.

    Attributes:
        table: Name of the data table.
        columns: Fields required from the table.
        join_type: How the table is joined:
            - "standard": plain equality match (default)
            - "asof_backward": match the most recent earlier record
              (used for financial data)
        left_on: Join key on the left table (required in asof mode).
        right_on: Join key on the right table (required in asof mode).
    """

    table: str
    columns: List[str]
    join_type: Literal["standard", "asof_backward"] = "standard"
    left_on: Optional[str] = None  # date column name on the price table
    right_on: Optional[str] = None  # date column name on the financial table

    def __post_init__(self):
        """Validate the join configuration.

        Raises:
            ValueError: If ``join_type`` is not one of the supported values,
                or if ``asof_backward`` is requested without both ``left_on``
                and ``right_on``.
        """
        # Literal is not enforced at runtime; an unknown join type would
        # otherwise pass through unnoticed and the table would match neither
        # the standard nor the as-of path downstream.
        if self.join_type not in ("standard", "asof_backward"):
            raise ValueError(f"不支持的 join_type: {self.join_type!r}")

        if self.join_type == "asof_backward":
            if not self.left_on or not self.right_on:
                raise ValueError("asof_backward 模式必须指定 left_on 和 right_on")
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -72,9 +72,10 @@ class ExecutionPlanner:
|
||||
dependencies: Set[str],
|
||||
expression: Node,
|
||||
) -> List[DataSpec]:
|
||||
"""从依赖推导数据规格。
|
||||
"""从依赖推导数据规格(支持财务数据自动识别)。
|
||||
|
||||
使用 SchemaCache 动态扫描数据库表结构,自动匹配字段到对应的表。
|
||||
自动识别财务数据表并配置 asof_backward 模式。
|
||||
表结构只扫描一次并缓存在内存中。
|
||||
|
||||
Args:
|
||||
@@ -90,11 +91,21 @@ class ExecutionPlanner:
|
||||
|
||||
data_specs = []
|
||||
for table_name, columns in table_to_fields.items():
|
||||
data_specs.append(
|
||||
DataSpec(
|
||||
if schema_cache.is_financial_table(table_name):
|
||||
# 财务表使用 asof_backward 模式
|
||||
spec = DataSpec(
|
||||
table=table_name,
|
||||
columns=columns,
|
||||
join_type="asof_backward",
|
||||
left_on="trade_date",
|
||||
right_on="f_ann_date",
|
||||
)
|
||||
else:
|
||||
# 标准表使用默认模式
|
||||
spec = DataSpec(
|
||||
table=table_name,
|
||||
columns=columns,
|
||||
)
|
||||
)
|
||||
data_specs.append(spec)
|
||||
|
||||
return data_specs
|
||||
|
||||
@@ -115,7 +115,7 @@ class SchemaCache:
|
||||
field_to_tables[field] = []
|
||||
field_to_tables[field].append(table)
|
||||
|
||||
# 优先选择最常用的表(pro_bar > daily_basic > daily)
|
||||
# 优先选择最常用的表(pro_bar > daily_basic > daily > financial)
|
||||
priority_order = {"pro_bar": 1, "daily_basic": 2, "daily": 3}
|
||||
|
||||
self._field_to_table_map = {}
|
||||
@@ -124,6 +124,18 @@ class SchemaCache:
|
||||
sorted_tables = sorted(tables, key=lambda t: priority_order.get(t, 999))
|
||||
self._field_to_table_map[field] = sorted_tables[0]
|
||||
|
||||
def is_financial_table(self, table_name: str) -> bool:
    """Report whether a table name denotes a financial-data table.

    The check is a case-insensitive prefix match against a fixed list of
    prefixes.

    Args:
        table_name: Name of the table.

    Returns:
        True if the lower-cased name starts with one of the financial
        prefixes, otherwise False.
    """
    normalized = table_name.lower()
    return any(
        normalized.startswith(prefix)
        for prefix in ("financial_", "income", "balance", "cashflow")
    )
|
||||
|
||||
def get_table_fields(self, table_name: str) -> List[str]:
|
||||
"""获取指定表的字段列表。
|
||||
|
||||
|
||||
Reference in New Issue
Block a user