feat(data): 财务数据加载与清洗模块

新增 FinancialLoader 类,提供:
- 财务数据加载与清洗(保留合并报表,按 update_flag 去重)
- 支持 as-of join 拼接行情数据(无未来函数)
- 自动识别财务表并配置 asof_backward 拼接模式
This commit is contained in:
2026-03-04 23:35:20 +08:00
parent 620696c842
commit 3b42093100
12 changed files with 695 additions and 379 deletions

View File

@@ -0,0 +1,181 @@
"""财务数据加载与清洗模块。
提供财务数据的加载、清洗和与行情数据拼接功能。
"""
from datetime import datetime, timedelta
from typing import Optional, List
import polars as pl
from src.data.storage import Storage
class FinancialLoader:
    """Financial statement loader.

    Cleans and deduplicates financial-report data and joins it onto daily
    price data with a backward as-of join, so no future information leaks
    into any trading day.

    Attributes:
        storage: DuckDB storage instance used for raw queries.
    """

    def __init__(self):
        self.storage = Storage()

    def load_financial_data(
        self,
        table_name: str,
        columns: List[str],
        start_date: str,
        end_date: str,
        ts_code: Optional[str] = None,
    ) -> pl.DataFrame:
        """Load and clean financial data.

        Cleaning pipeline:
        1. Keep consolidated statements only (report_type == '1').
        2. Per (ts_code, f_ann_date), keep the row with the highest
           update_flag (the latest revision announced that day).
        3. Validate that f_ann_date is pl.Date and sort by
           (ts_code, f_ann_date), as required by join_asof.

        Args:
            table_name: Financial table name (e.g. 'financial_income').
            columns: Requested columns; core columns are added automatically.
            start_date: Start date (YYYYMMDD).
            end_date: End date (YYYYMMDD).
            ts_code: Optional single stock-code filter.

        Returns:
            Cleaned Polars DataFrame, sorted, with f_ann_date as pl.Date.

        Raises:
            TypeError: If f_ann_date is not stored as a DATE column.
        """
        # Always select the columns the cleaning steps depend on.
        required_cols = {"ts_code", "f_ann_date", "report_type", "update_flag"}
        query_cols = list(set(columns) | required_cols)

        df = self._load_from_db(table_name, query_cols, start_date, end_date, ts_code)
        if df.is_empty():
            return df

        # Step 1: consolidated statements only. report_type may be stored as a
        # string or an integer (see docstring), so compare on the string form;
        # the original `== 1` silently mismatched string-typed columns.
        df = df.filter(pl.col("report_type").cast(pl.Utf8) == "1")

        # Step 2: keep the highest update_flag per (ts_code, f_ann_date).
        # maintain_order=True makes the keep="first" selection deterministic.
        df = (
            df.with_columns(pl.col("update_flag").cast(pl.Int32).alias("_uf"))
            .sort(["ts_code", "f_ann_date", "_uf"], descending=[False, False, True])
            .unique(subset=["ts_code", "f_ann_date"], keep="first", maintain_order=True)
            .drop("_uf")
        )

        # Step 3: the database must hand back a real DATE column; fail loudly
        # instead of silently parsing strings (keeps join_asof semantics safe).
        if df["f_ann_date"].dtype != pl.Date:
            raise TypeError(
                f"f_ann_date 必须是 Date 类型,实际类型为 {df['f_ann_date'].dtype}. "
                f"请检查数据库表结构,确保日期字段为 DATE 类型。"
            )

        # Final sort, required by join_asof.
        return df.sort(["ts_code", "f_ann_date"])

    def merge_financial_with_price(
        self,
        df_price: pl.DataFrame,
        df_financial: pl.DataFrame,
        financial_cols: List[str],
    ) -> pl.DataFrame:
        """Attach financial data to price rows via a backward as-of join.

        For every trading day the most recent announcement with
        f_ann_date <= trade_date is used, so there is no look-ahead bias.

        Note: df_price.trade_date must already be pl.Date and sorted.

        Args:
            df_price: Price frame; must contain ts_code and trade_date (Date).
            df_financial: Financial frame cleaned by load_financial_data.
            financial_cols: Financial columns to keep in the result.

        Returns:
            Joined DataFrame (trade_date stays a Date).
        """
        if df_financial.is_empty():
            # No financial data: return prices with null financial columns.
            missing = [c for c in financial_cols if c not in df_price.columns]
            if missing:
                df_price = df_price.with_columns(
                    [pl.lit(None).alias(c) for c in missing]
                )
            return df_price

        # strategy='backward': for each trade_date, pick the latest record
        # whose f_ann_date <= trade_date.
        return df_price.join_asof(
            df_financial.select(["ts_code", "f_ann_date"] + financial_cols),
            left_on="trade_date",
            right_on="f_ann_date",
            by="ts_code",
            strategy="backward",
        )

    def _load_from_db(
        self,
        table_name: str,
        columns: List[str],
        start_date: str,
        end_date: str,
        ts_code: Optional[str] = None,
    ) -> pl.DataFrame:
        """Load raw financial rows from DuckDB (internal).

        Values are bound as prepared-statement parameters; only identifiers
        (table/column names, which come from internal configuration, not
        user input) are interpolated into the SQL text.
        """
        conn = self.storage._connection
        cols_str = ", ".join(f'"{c}"' for c in columns)
        start_dt = datetime.strptime(start_date, "%Y%m%d").date()
        end_dt = datetime.strptime(end_date, "%Y%m%d").date()
        conditions = ["f_ann_date BETWEEN ? AND ?"]
        params: List[object] = [start_dt, end_dt]
        if ts_code:
            conditions.append("ts_code = ?")
            params.append(ts_code)
        where_clause = " AND ".join(conditions)
        query = (
            f"SELECT {cols_str} FROM {table_name} "
            f"WHERE {where_clause} ORDER BY ts_code, f_ann_date"
        )
        try:
            # Parameter binding avoids SQL-injection issues the old
            # f-string-built WHERE clause had.
            return conn.execute(query, params).pl()
        except Exception as e:  # best-effort: log and return an empty frame
            print(f"[FinancialLoader] 加载 {table_name} 失败: {e}")
            return pl.DataFrame()

    def get_date_range_with_lookback(
        self,
        start_date: str,
        end_date: str,
        lookback_years: int = 1,
    ) -> tuple[str, str]:
        """Extend the start date backwards by a lookback window.

        Financial statements announced before the backtest window must be
        loaded so the first trading days already have data to join against.

        Args:
            start_date: Original start date (YYYYMMDD).
            end_date: Original end date (YYYYMMDD).
            lookback_years: Years to look back (default 1; approximated as
                365 days per year, leap days ignored).

        Returns:
            (extended start date, unchanged end date), both YYYYMMDD strings.
        """
        start_dt = datetime.strptime(start_date, "%Y%m%d")
        adjusted_start = start_dt - timedelta(days=365 * lookback_years)
        return adjusted_start.strftime("%Y%m%d"), end_date

View File

@@ -61,6 +61,8 @@ def create_factors_with_strings(engine: FactorEngine) -> List[str]:
"market_cap_rank": "cs_rank(total_mv)", "market_cap_rank": "cs_rank(total_mv)",
# 7. 价格位置因子 # 7. 价格位置因子
"high_low_ratio": "(close - ts_min(low, 20)) / (ts_max(high, 20) - ts_min(low, 20) + 1e-8)", "high_low_ratio": "(close - ts_min(low, 20)) / (ts_max(high, 20) - ts_min(low, 20) + 1e-8)",
"n_income": "n_income"
} }
# Label 因子(单独定义,不参与训练) # Label 因子(单独定义,不参与训练)

View File

@@ -3,6 +3,7 @@
按需取数、组装核心宽表。 按需取数、组装核心宽表。
负责根据数据规格从数据源拉取数据,并组装成统一的宽表格式。 负责根据数据规格从数据源拉取数据,并组装成统一的宽表格式。
支持内存数据源(用于测试)和真实数据库连接。 支持内存数据源(用于测试)和真实数据库连接。
支持标准等值匹配和 asof_backward财务数据两种拼接模式。
""" """
from typing import Any, Dict, List, Optional, Set, Union from typing import Any, Dict, List, Optional, Set, Union
@@ -12,6 +13,7 @@ import polars as pl
from src.factors.engine.data_spec import DataSpec from src.factors.engine.data_spec import DataSpec
from src.data.storage import Storage from src.data.storage import Storage
from src.data.financial_loader import FinancialLoader
class DataRouter: class DataRouter:
@@ -37,11 +39,13 @@ class DataRouter:
self._cache: Dict[str, pl.DataFrame] = {} self._cache: Dict[str, pl.DataFrame] = {}
self._lock = threading.Lock() self._lock = threading.Lock()
# 数据库模式下初始化 Storage # 数据库模式下初始化 Storage 和 FinancialLoader
if not self.is_memory_mode: if not self.is_memory_mode:
self._storage = Storage() self._storage = Storage()
self._financial_loader = FinancialLoader()
else: else:
self._storage = None self._storage = None
self._financial_loader = None
def fetch_data( def fetch_data(
self, self,
@@ -75,23 +79,122 @@ class DataRouter:
required_tables[spec.table] = set() required_tables[spec.table] = set()
required_tables[spec.table].update(spec.columns) required_tables[spec.table].update(spec.columns)
# 从数据源获取各表数据 # 从数据源获取各表数据(使用合并后的 required_tables避免重复加载
table_data = {} table_data = {}
for table_name, columns in required_tables.items(): for table_name, columns in required_tables.items():
df = self._load_table( # 判断是标准表还是财务表
table_name=table_name, is_financial = any(
s.table == table_name and s.join_type == "asof_backward"
for s in data_specs
)
if is_financial:
# 财务表:找到对应的 spec 获取 join 配置
financial_spec = next(
s
for s in data_specs
if s.table == table_name and s.join_type == "asof_backward"
)
spec = DataSpec(
table=table_name,
columns=list(columns), columns=list(columns),
join_type="asof_backward",
left_on=financial_spec.left_on,
right_on=financial_spec.right_on,
)
else:
# 标准表
spec = DataSpec(
table=table_name,
columns=list(columns),
join_type="standard",
)
df = self._load_table_from_spec(
spec=spec,
start_date=start_date, start_date=start_date,
end_date=end_date, end_date=end_date,
stock_codes=stock_codes, stock_codes=stock_codes,
) )
table_data[table_name] = df table_data[table_name] = df
# 组装核心宽表 # 组装核心宽表(支持多种 join 类型)
core_table = self._assemble_wide_table(table_data, required_tables) core_table = self._assemble_wide_table_with_specs(
table_data, data_specs, start_date, end_date
)
return core_table return core_table
def _load_table_from_spec(
self,
spec: DataSpec,
start_date: str,
end_date: str,
stock_codes: Optional[List[str]] = None,
) -> pl.DataFrame:
"""根据数据规格加载单个表的数据。
根据 spec.join_type 选择不同的加载方式:
- standard: 使用原有逻辑,基于 trade_date
- asof_backward: 使用 FinancialLoader基于 f_ann_date扩展回看期
Args:
spec: 数据规格
start_date: 开始日期
end_date: 结束日期
stock_codes: 股票代码过滤
Returns:
过滤后的 DataFrame
"""
cache_key = (
f"{spec.table}_{spec.join_type}_{start_date}_{end_date}_{stock_codes}"
)
with self._lock:
if cache_key in self._cache:
return self._cache[cache_key]
if spec.join_type == "asof_backward":
# 财务数据使用 FinancialLoader
if self._financial_loader is None:
raise RuntimeError("FinancialLoader 未初始化")
# 扩展日期范围回看1年
adjusted_start, _ = self._financial_loader.get_date_range_with_lookback(
start_date, end_date
)
# 处理 stock_codes
ts_code = stock_codes[0] if stock_codes and len(stock_codes) == 1 else None
df = self._financial_loader.load_financial_data(
table_name=spec.table,
columns=spec.columns,
start_date=adjusted_start,
end_date=end_date,
ts_code=ts_code,
)
# 如果 stock_codes 是列表且长度 > 1在内存中过滤
if stock_codes is not None and len(stock_codes) > 1:
df = df.filter(pl.col("ts_code").is_in(stock_codes))
else:
# 标准表使用原有逻辑
df = self._load_table(
table_name=spec.table,
columns=spec.columns,
start_date=start_date,
end_date=end_date,
stock_codes=stock_codes,
)
with self._lock:
self._cache[cache_key] = df
return df
def _load_table( def _load_table(
self, self,
table_name: str, table_name: str,
@@ -255,6 +358,119 @@ class DataRouter:
return result return result
def _assemble_wide_table_with_specs(
self,
table_data: Dict[str, pl.DataFrame],
data_specs: List[DataSpec],
start_date: str,
end_date: str,
) -> pl.DataFrame:
"""组装多表数据为核心宽表(支持多种 join 类型)。
支持标准等值匹配和 asof_backward 两种模式。
性能优化:
- 在开始时统一将 trade_date 转为 pl.Date
- 所有 asof join 全部在 pl.Date 类型下完成
- 返回前统一转回字符串格式
Args:
table_data: 表名到 DataFrame 的映射
data_specs: 数据规格列表
start_date: 开始日期
end_date: 结束日期
Returns:
组装后的宽表
"""
if not table_data:
raise ValueError("没有数据可组装")
# 从 data_specs 判断每个表的 join 类型
table_join_types = {}
for spec in data_specs:
if spec.table not in table_join_types:
table_join_types[spec.table] = spec.join_type
# 分离标准表和 asof 表(基于 table_data 的表名,避免重复)
standard_tables = [
t
for t in table_data.keys()
if table_join_types.get(t, "standard") == "standard"
]
asof_tables = [
t for t in table_data.keys() if table_join_types.get(t) == "asof_backward"
]
# 先合并所有标准表(使用 trade_date
base_df = None
for table_name in standard_tables:
df = table_data[table_name]
if base_df is None:
base_df = df
else:
# 使用 ts_code 和 trade_date 作为 join 键
# 注:根据动态路由原则,除 ts_code/trade_date 外不应有重复字段
# 如果出现重复,说明 SchemaCache 的字段映射有问题
base_df = base_df.join(
df,
on=["ts_code", "trade_date"],
how="left",
)
if base_df is None:
raise ValueError("至少需要一张标准行情表作为基础")
# 【性能优化】统一转换 trade_date 为 Date 类型(只转换一次)
if asof_tables:
base_df = base_df.with_columns(
[
pl.col("trade_date")
.str.strptime(pl.Date, "%Y%m%d")
.alias("trade_date")
]
)
# 确保已排序join_asof 要求)
base_df = base_df.sort(["ts_code", "trade_date"])
# 逐个合并 asof 表(所有 join 都在 Date 类型下进行)
for table_name in asof_tables:
df_financial = table_data[table_name]
# 提取需要保留的字段(排除 join 键和元数据字段)
# 从 data_specs 中找到对应表的 columns
table_columns = set()
for spec in data_specs:
if spec.table == table_name:
table_columns.update(spec.columns)
financial_cols = [
c
for c in table_columns
if c
not in [
"ts_code",
"f_ann_date",
"report_type",
"update_flag",
"end_date",
]
]
if self._financial_loader is None:
raise RuntimeError("FinancialLoader 未初始化")
base_df = self._financial_loader.merge_financial_with_price(
base_df, df_financial, financial_cols
)
# 【性能优化】所有 asof join 完成后,统一转回字符串格式
if asof_tables:
base_df = base_df.with_columns(
[pl.col("trade_date").dt.strftime("%Y%m%d").alias("trade_date")]
)
return base_df
def clear_cache(self) -> None: def clear_cache(self) -> None:
"""清除数据缓存。""" """清除数据缓存。"""
with self._lock: with self._lock:

View File

@@ -4,24 +4,38 @@
""" """
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set, Union from typing import Any, Dict, List, Literal, Optional, Set, Union
import polars as pl import polars as pl
@dataclass @dataclass
class DataSpec: class DataSpec:
"""数据规格定义。 """数据规格定义(支持多表类型)
描述因子计算所需的数据表和字段。 描述因子计算所需的数据表和字段,支持多种拼接类型
Attributes: Attributes:
table: 数据表名称 table: 数据表名称
columns: 需要的字段列表 columns: 需要的字段列表
join_type: 拼接类型
- "standard": 标准等值匹配(默认)
- "asof_backward": 向后寻找最近历史数据(财务数据用)
left_on: 左表 join 键asof 模式下必须指定)
right_on: 右表 join 键asof 模式下必须指定)
""" """
table: str table: str
columns: List[str] columns: List[str]
join_type: Literal["standard", "asof_backward"] = "standard"
left_on: Optional[str] = None # 行情表日期列名
right_on: Optional[str] = None # 财务表日期列名
def __post_init__(self):
"""验证 asof_backward 模式的参数。"""
if self.join_type == "asof_backward":
if not self.left_on or not self.right_on:
raise ValueError("asof_backward 模式必须指定 left_on 和 right_on")
@dataclass @dataclass

View File

@@ -72,9 +72,10 @@ class ExecutionPlanner:
dependencies: Set[str], dependencies: Set[str],
expression: Node, expression: Node,
) -> List[DataSpec]: ) -> List[DataSpec]:
"""从依赖推导数据规格。 """从依赖推导数据规格(支持财务数据自动识别)
使用 SchemaCache 动态扫描数据库表结构,自动匹配字段到对应的表。 使用 SchemaCache 动态扫描数据库表结构,自动匹配字段到对应的表。
自动识别财务数据表并配置 asof_backward 模式。
表结构只扫描一次并缓存在内存中。 表结构只扫描一次并缓存在内存中。
Args: Args:
@@ -90,11 +91,21 @@ class ExecutionPlanner:
data_specs = [] data_specs = []
for table_name, columns in table_to_fields.items(): for table_name, columns in table_to_fields.items():
data_specs.append( if schema_cache.is_financial_table(table_name):
DataSpec( # 财务表使用 asof_backward 模式
spec = DataSpec(
table=table_name,
columns=columns,
join_type="asof_backward",
left_on="trade_date",
right_on="f_ann_date",
)
else:
# 标准表使用默认模式
spec = DataSpec(
table=table_name, table=table_name,
columns=columns, columns=columns,
) )
) data_specs.append(spec)
return data_specs return data_specs

View File

@@ -115,7 +115,7 @@ class SchemaCache:
field_to_tables[field] = [] field_to_tables[field] = []
field_to_tables[field].append(table) field_to_tables[field].append(table)
# 优先选择最常用的表pro_bar > daily_basic > daily # 优先选择最常用的表pro_bar > daily_basic > daily > financial
priority_order = {"pro_bar": 1, "daily_basic": 2, "daily": 3} priority_order = {"pro_bar": 1, "daily_basic": 2, "daily": 3}
self._field_to_table_map = {} self._field_to_table_map = {}
@@ -124,6 +124,18 @@ class SchemaCache:
sorted_tables = sorted(tables, key=lambda t: priority_order.get(t, 999)) sorted_tables = sorted(tables, key=lambda t: priority_order.get(t, 999))
self._field_to_table_map[field] = sorted_tables[0] self._field_to_table_map[field] = sorted_tables[0]
def is_financial_table(self, table_name: str) -> bool:
"""判断是否为财务数据表。
Args:
table_name: 表名
Returns:
是否为财务数据表
"""
financial_prefixes = ("financial_", "income", "balance", "cashflow")
return table_name.lower().startswith(financial_prefixes)
def get_table_fields(self, table_name: str) -> List[str]: def get_table_fields(self, table_name: str) -> List[str]:
"""获取指定表的字段列表。 """获取指定表的字段列表。

View File

@@ -1,122 +0,0 @@
"""Test for daily market data API.
Tests the daily interface implementation against api.md requirements:
- A股日线行情所有输出字段
- tor 换手率
- vr 量比
"""
import pytest
import pandas as pd
from src.data.api_wrappers import get_daily
# Expected output fields according to api.md
# Base OHLCV fields every daily bar must contain (per api.md).
EXPECTED_BASE_FIELDS = [
    "ts_code",  # stock code
    "trade_date",  # trade date
    "open",  # open price
    "high",  # highest price
    "low",  # lowest price
    "close",  # close price
    "pre_close",  # previous close
    "change",  # price change
    "pct_chg",  # percent change
    "vol",  # volume
    "amount",  # turnover amount
]
# Optional factor fields, returned when requested via `factors=`.
EXPECTED_FACTOR_FIELDS = [
    "turnover_rate",  # turnover rate (tor)
    "volume_ratio",  # volume ratio (vr)
]
class TestGetDaily:
    """Test cases for get_daily function with real API calls."""

    def test_fetch_basic(self):
        """Test basic daily data fetch with real API."""
        df = get_daily("000001.SZ", start_date="20240101", end_date="20240131")
        assert isinstance(df, pd.DataFrame)
        assert len(df) >= 1
        assert df["ts_code"].iloc[0] == "000001.SZ"

    def test_fetch_with_factors(self):
        """Test fetch with tor and vr factors."""
        df = get_daily(
            "000001.SZ",
            start_date="20240101",
            end_date="20240131",
            factors=["tor", "vr"],
        )
        assert isinstance(df, pd.DataFrame)
        # Both the base columns and the requested factor columns must appear.
        for field in EXPECTED_BASE_FIELDS:
            assert field in df.columns, f"Missing base field: {field}"
        for field in EXPECTED_FACTOR_FIELDS:
            assert field in df.columns, f"Missing factor field: {field}"

    def test_output_fields_completeness(self):
        """Verify all required output fields are returned."""
        df = get_daily("600000.SH")
        # The full base-field set must be a subset of the returned columns.
        assert set(EXPECTED_BASE_FIELDS).issubset(df.columns.tolist()), (
            f"Missing fields: {set(EXPECTED_BASE_FIELDS) - set(df.columns)}"
        )

    def test_empty_result(self):
        """Test handling of empty results."""
        # Hit the real API with an invalid code and expect an empty frame.
        df = get_daily("INVALID.SZ")
        assert isinstance(df, pd.DataFrame)
        assert df.empty

    def test_date_range_query(self):
        """Test query with date range."""
        df = get_daily(
            "000001.SZ",
            start_date="20240101",
            end_date="20240131",
        )
        assert isinstance(df, pd.DataFrame)
        assert len(df) >= 1

    def test_with_adj(self):
        """Test fetch with adjustment type."""
        df = get_daily("000001.SZ", adj="qfq")
        assert isinstance(df, pd.DataFrame)
def test_integration():
    """Integration test with real Tushare API (requires valid token)."""
    import os

    if not os.environ.get("TUSHARE_TOKEN"):
        pytest.skip("TUSHARE_TOKEN not configured")
    df = get_daily(
        "000001.SZ", start_date="20240101", end_date="20240131", factors=["tor", "vr"]
    )
    # Verify the structure of the response.
    assert isinstance(df, pd.DataFrame)
    if df.empty:
        return
    # Non-empty result: both base and factor columns must be present.
    for field in EXPECTED_BASE_FIELDS:
        assert field in df.columns, f"Missing base field: {field}"
    for field in EXPECTED_FACTOR_FIELDS:
        assert field in df.columns, f"Missing factor field: {field}"


if __name__ == "__main__":
    # Run the pytest unit tests (real API calls).
    pytest.main([__file__, "-v"])

View File

@@ -1,242 +0,0 @@
"""Tests for DuckDB storage validation.
Validates two key points:
1. All stocks from stock_basic.csv are saved in daily table
2. No abnormal data with very few data points (< 10 rows per stock)
使用 3 个月的真实数据进行测试 (2024年1月-3月)
"""
import pytest
import pandas as pd
from datetime import datetime, timedelta
from src.data.storage import Storage
from src.data.api_wrappers.api_stock_basic import _get_csv_path
class TestDailyStorageValidation:
    """Test daily table storage integrity and completeness.

    Validates two key points against three months of real data:
    1. Stocks from stock_basic.csv are present in the `daily` table.
    2. No stock has an abnormally small number of rows.
    """

    # Test data window: three months (2024-01 .. 2024-03).
    TEST_START_DATE = "20240101"
    TEST_END_DATE = "20240331"

    @pytest.fixture
    def storage(self):
        """Create storage instance."""
        return Storage()

    @pytest.fixture
    def stock_basic_df(self):
        """Load stock basic data from CSV."""
        csv_path = _get_csv_path()
        if not csv_path.exists():
            pytest.skip(f"stock_basic.csv not found at {csv_path}")
        return pd.read_csv(csv_path)

    @pytest.fixture
    def daily_df(self, storage):
        """Load daily data from DuckDB (3 months)."""
        if not storage.exists("daily"):
            pytest.skip("daily table not found in DuckDB")
        # Pull the 3-month window from DuckDB.
        df = storage.load(
            "daily", start_date=self.TEST_START_DATE, end_date=self.TEST_END_DATE
        )
        if df.empty:
            pytest.skip(
                f"No data found for period {self.TEST_START_DATE} to {self.TEST_END_DATE}"
            )
        return df

    def test_duckdb_connection(self, storage):
        """Test DuckDB connection and basic operations."""
        assert storage.exists("daily") or True  # at minimum the connection succeeded
        print(f"[TEST] DuckDB connection successful")

    def test_load_3months_data(self, storage):
        """Test loading 3 months of data from DuckDB."""
        df = storage.load(
            "daily", start_date=self.TEST_START_DATE, end_date=self.TEST_END_DATE
        )
        if df.empty:
            pytest.skip("No data available for testing period")
        # Check the date coverage of what came back.
        dates = df["trade_date"].astype(str)
        min_date = dates.min()
        max_date = dates.max()
        print(f"[TEST] Loaded {len(df)} rows from {min_date} to {max_date}")
        assert len(df) > 0, "Should have data in the 3-month period"

    def test_all_stocks_saved(self, storage, stock_basic_df, daily_df):
        """Verify all stocks from stock_basic are saved in daily table.

        This test ensures data completeness - every stock in stock_basic
        should have corresponding data in daily table.
        """
        if daily_df.empty:
            pytest.fail("daily table is empty for test period")
        # Get unique stock codes from both sources
        expected_codes = set(stock_basic_df["ts_code"].dropna().unique())
        actual_codes = set(daily_df["ts_code"].dropna().unique())
        # Check for missing stocks
        missing_codes = expected_codes - actual_codes
        if missing_codes:
            missing_list = sorted(missing_codes)
            # Show first 20 missing stocks as sample
            sample = missing_list[:20]
            msg = f"Found {len(missing_codes)} stocks missing from daily table:\n"
            msg += f"Sample missing: {sample}\n"
            if len(missing_list) > 20:
                msg += f"... and {len(missing_list) - 20} more"
            # Over a 3-month window some stocks may legitimately be missing
            # (e.g. new listings or not yet listed), so only warn here.
            print(f"[WARNING] {msg}")
            # Require at least 80% of the expected stocks to be present.
            coverage = len(actual_codes) / len(expected_codes) * 100
            assert coverage >= 80, (
                f"Stock coverage {coverage:.1f}% is below 80% threshold"
            )
        else:
            print(
                f"[TEST] All {len(expected_codes)} stocks from stock_basic are present in daily table"
            )

    def test_no_stock_with_insufficient_data(self, storage, daily_df):
        """Verify no stock has abnormally few data points (< 5 rows in 3 months).

        Stocks with very few data points may indicate sync failures,
        delisted stocks not properly handled, or data corruption.
        """
        if daily_df.empty:
            pytest.fail("daily table is empty for test period")
        # Count rows per stock
        stock_counts = daily_df.groupby("ts_code").size()
        # Find stocks with less than 5 data points in 3 months
        insufficient_stocks = stock_counts[stock_counts < 5]
        if not insufficient_stocks.empty:
            # Separate into categories for better reporting
            empty_stocks = stock_counts[stock_counts == 0]
            very_few_stocks = stock_counts[(stock_counts > 0) & (stock_counts < 5)]
            msg = f"Found {len(insufficient_stocks)} stocks with insufficient data (< 5 rows in 3 months):\n"
            if not empty_stocks.empty:
                msg += f"\nEmpty stocks (0 rows): {len(empty_stocks)}\n"
                sample = sorted(empty_stocks.index[:10].tolist())
                msg += f"Sample: {sample}"
            if not very_few_stocks.empty:
                msg += f"\nVery few data points (1-4 rows): {len(very_few_stocks)}\n"
                # Show counts for these stocks
                sample = very_few_stocks.sort_values().head(20)
                msg += "Sample (ts_code: count):\n"
                for code, count in sample.items():
                    msg += f"  {code}: {count} rows\n"
            # For a 3-month window tolerate a few outliers, but fail
            # if they exceed 5% of all stocks.
            if len(insufficient_stocks) / len(stock_counts) > 0.05:
                pytest.fail(msg)
            else:
                print(f"[WARNING] {msg}")
        print(f"[TEST] All stocks have sufficient data (>= 5 rows in 3 months)")

    def test_data_integrity_basic(self, storage, daily_df):
        """Basic data integrity checks for daily table."""
        if daily_df.empty:
            pytest.fail("daily table is empty for test period")
        # Check required columns exist
        required_columns = ["ts_code", "trade_date"]
        missing_columns = [
            col for col in required_columns if col not in daily_df.columns
        ]
        if missing_columns:
            pytest.fail(f"Missing required columns: {missing_columns}")
        # Check for null values in key columns
        null_ts_code = daily_df["ts_code"].isna().sum()
        null_trade_date = daily_df["trade_date"].isna().sum()
        if null_ts_code > 0:
            pytest.fail(f"Found {null_ts_code} rows with null ts_code")
        if null_trade_date > 0:
            pytest.fail(f"Found {null_trade_date} rows with null trade_date")
        print(f"[TEST] Data integrity check passed for 3-month period")

    def test_polars_export(self, storage):
        """Test Polars export functionality."""
        if not storage.exists("daily"):
            pytest.skip("daily table not found")
        import polars as pl

        # Exercise the load_polars code path.
        df = storage.load_polars(
            "daily", start_date=self.TEST_START_DATE, end_date=self.TEST_END_DATE
        )
        assert isinstance(df, pl.DataFrame), "Should return Polars DataFrame"
        print(f"[TEST] Polars export successful: {len(df)} rows")

    def test_stock_data_coverage_report(self, storage, daily_df):
        """Generate a summary report of stock data coverage.

        This test provides visibility into data distribution without failing.
        """
        if daily_df.empty:
            pytest.skip("daily table is empty - cannot generate report")
        stock_counts = daily_df.groupby("ts_code").size()
        # Calculate statistics
        total_stocks = len(stock_counts)
        min_count = stock_counts.min()
        max_count = stock_counts.max()
        median_count = stock_counts.median()
        mean_count = stock_counts.mean()
        # Distribution buckets (adjusted for 3-month period, ~60 trading days)
        very_low = (stock_counts < 5).sum()
        low = ((stock_counts >= 5) & (stock_counts < 20)).sum()
        medium = ((stock_counts >= 20) & (stock_counts < 40)).sum()
        high = (stock_counts >= 40).sum()
        report = f"""
=== Stock Data Coverage Report (3 months: {self.TEST_START_DATE} to {self.TEST_END_DATE}) ===
Total stocks: {total_stocks}
Data points per stock:
  Min: {min_count}
  Max: {max_count}
  Median: {median_count:.0f}
  Mean: {mean_count:.1f}
Distribution:
  < 5 rows: {very_low} stocks ({very_low / total_stocks * 100:.1f}%)
  5-19: {low} stocks ({low / total_stocks * 100:.1f}%)
  20-39: {medium} stocks ({medium / total_stocks * 100:.1f}%)
  >= 40: {high} stocks ({high / total_stocks * 100:.1f}%)
"""
        print(report)
        # This is an informational test - it should not fail
        # But we assert to mark it as passed
        assert total_stocks > 0
if __name__ == "__main__":
    # Allow running this validation suite directly as a script.
    pytest.main([__file__, "-v", "-s"])

View File

@@ -0,0 +1,244 @@
"""财务数据与行情数据拼接测试。
测试场景:
1. 普通财务数据:正常公告,之后无修改
2. 隔日修改:公告后几天发布修正版
3. 当日修改:同一天发布多版,取 update_flag=1 的
4. 边界条件:财务数据缺失、行情数据早于最早财务数据
"""
import polars as pl
from datetime import date
from src.data.financial_loader import FinancialLoader
def create_mock_price_data() -> pl.DataFrame:
    """Build a small mock daily-price frame for one stock (string dates)."""
    trade_dates = [
        "20240101", "20240102", "20240103", "20240104", "20240105",
        "20240108", "20240109", "20240110", "20240111", "20240112",
    ]
    closes = [10.0, 10.2, 10.3, 10.1, 10.5, 10.6, 10.4, 10.7, 10.8, 10.9]
    return pl.DataFrame(
        {
            "ts_code": ["000001.SZ"] * len(trade_dates),
            "trade_date": trade_dates,
            "close": closes,
        }
    )
def create_mock_financial_data() -> pl.DataFrame:
    """Build mock financial data covering several announcement scenarios.

    Scenarios: a normal announcement, two same-day versions distinguished
    by update_flag, and a later revision for a new report period.

    Note: f_ann_date must be a Date, matching the database schema.
    """
    code = "000001.SZ"
    return pl.DataFrame(
        {
            "ts_code": [code] * 4,
            "f_ann_date": [
                date(2024, 1, 2),
                date(2024, 1, 2),
                date(2024, 1, 5),
                date(2024, 1, 10),
            ],
            "end_date": ["20230930", "20230930", "20230930", "20231231"],
            "report_type": [1, 1, 1, 1],  # integers, matching the database
            "update_flag": [0, 1, 1, 1],  # integers, matching the database
            "net_profit": [1000000.0, 1100000.0, 1100000.0, 1200000.0],
            "revenue": [5000000.0, 5200000.0, 5200000.0, 6000000.0],
        }
    )
def test_financial_data_cleaning():
    """Test the financial-data cleaning logic.

    Replays the load_financial_data cleaning steps on mock data and checks
    that same-day duplicates collapse to the update_flag=1 revision.

    Fix: the original constructed an unused FinancialLoader(), which opened
    a real DuckDB connection inside a pure in-memory unit test; removed.
    """
    print("=== 测试 1: 财务数据清洗 ===")
    df_finance = create_mock_financial_data()
    print("原始财务数据:")
    print(df_finance)
    # Step 1: keep consolidated statements only.
    df = df_finance.filter(pl.col("report_type") == 1)
    # Step 2: sort with update_flag descending, then deduplicate.
    df = df.with_columns(
        [pl.col("update_flag").cast(pl.Int32).alias("update_flag_int")]
    )
    df = df.sort(
        ["ts_code", "f_ann_date", "update_flag_int"], descending=[False, False, True]
    )
    df = df.unique(subset=["ts_code", "f_ann_date"], keep="first")
    df = df.drop("update_flag_int")
    # Step 3: final sort (f_ann_date is already a Date).
    df = df.sort(["ts_code", "f_ann_date"])
    print("\n清洗后的财务数据:")
    print(df)
    # Rows 1-2 collapse into one record; rows 3 and 4 survive.
    assert len(df) == 3, f"清洗后应该有3条记录实际有 {len(df)}"
    # The surviving 2024-01-02 row must be the update_flag=1 revision.
    row_jan02 = df.filter(pl.col("f_ann_date") == date(2024, 1, 2))
    assert len(row_jan02) == 1, "应该有1条 2024-01-02 的记录"
    assert row_jan02["update_flag"][0] == 1, "update_flag 应该为 1"
    assert row_jan02["net_profit"][0] == 1100000.0, "net_profit 应该为 1100000"
    print("\n[通过] 财务数据清洗测试通过!")
    return df
def test_financial_price_merge():
    """Test joining financial data onto prices (verifies no look-ahead)."""
    print("\n=== 测试 2: 财务数据与行情数据拼接 ===")
    prices = create_mock_price_data()
    raw = create_mock_financial_data()
    loader = FinancialLoader()
    # Step 1: clean the financial data (manually replaying the pipeline).
    # f_ann_date is already a Date, so no conversion is needed.
    fin = (
        raw.filter(pl.col("report_type") == 1)
        .with_columns([pl.col("update_flag").cast(pl.Int32).alias("update_flag_int")])
        .sort(
            ["ts_code", "f_ann_date", "update_flag_int"],
            descending=[False, False, True],
        )
        .unique(subset=["ts_code", "f_ann_date"], keep="first")
        .drop("update_flag_int")
        .sort(["ts_code", "f_ann_date"])
    )
    print("清洗后的财务数据:")
    print(fin)
    # Step 2: parse trade_date into a Date and sort for join_asof.
    prices = prices.with_columns(
        [pl.col("trade_date").str.strptime(pl.Date, "%Y%m%d").alias("trade_date")]
    ).sort(["ts_code", "trade_date"])
    # Step 3: backward as-of join.
    merged = loader.merge_financial_with_price(prices, fin, ["net_profit", "revenue"])
    # Step 4: convert trade_date back to string form.
    merged = merged.with_columns(
        [pl.col("trade_date").dt.strftime("%Y%m%d").alias("trade_date")]
    )
    print("\n拼接结果:")
    print(merged)
    # No look-ahead: 2024-01-01 precedes the first announcement (2024-01-02).
    jan01 = merged.filter(pl.col("trade_date") == "20240101")
    assert jan01["net_profit"].is_null().all(), (
        "2024-01-01 不应有 2023Q3 数据(尚未公告)"
    )
    print("[验证 1] 2024-01-01 net_profit 为 null - 正确(公告前无数据)")
    # From 2024-01-02 on, the update_flag=1 revision (1,100,000) applies.
    jan02 = merged.filter(pl.col("trade_date") == "20240102")
    assert jan02["net_profit"][0] == 1100000.0, "2024-01-02 应使用 update_flag=1 的数据"
    print("[验证 2] 2024-01-02 net_profit=1100000 - 正确(使用 update_flag=1")
    # 2024-01-04 still carries the 2023Q3 figures forward.
    jan04 = merged.filter(pl.col("trade_date") == "20240104")
    assert jan04["net_profit"][0] == 1100000.0, "2024-01-04 应延续使用 2023Q3 数据"
    print("[验证 3] 2024-01-04 net_profit=1100000 - 正确(延续使用)")
    # 2024-01-10 switches to the newly announced 2023Q4 figures.
    jan10 = merged.filter(pl.col("trade_date") == "20240110")
    assert jan10["net_profit"][0] == 1200000.0, "2024-01-10 应切换到 2023Q4 数据"
    print("[验证 4] 2024-01-10 net_profit=1200000 - 正确(新财报公告)")
    # 2024-01-12 keeps carrying the 2023Q4 figures forward.
    jan12 = merged.filter(pl.col("trade_date") == "20240112")
    assert jan12["net_profit"][0] == 1200000.0, "2024-01-12 应继续使用 2023Q4 数据"
    print("[验证 5] 2024-01-12 net_profit=1200000 - 正确(延续使用)")
    print("\n[通过] 所有验证通过,无未来函数!")
    return merged
def test_empty_financial_data():
    """Test merging behaviour when the financial frame is empty."""
    print("\n=== 测试 3: 空财务数据场景 ===")
    # Prepare prices: parse trade_date into a Date and sort.
    prices = (
        create_mock_price_data()
        .with_columns(
            [pl.col("trade_date").str.strptime(pl.Date, "%Y%m%d").alias("trade_date")]
        )
        .sort(["ts_code", "trade_date"])
    )
    loader = FinancialLoader()
    # Join against an empty financial frame.
    merged = loader.merge_financial_with_price(prices, pl.DataFrame(), ["net_profit"])
    # Convert trade_date back to string form.
    merged = merged.with_columns(
        [pl.col("trade_date").dt.strftime("%Y%m%d").alias("trade_date")]
    )
    # The financial column must be entirely null.
    assert merged["net_profit"].is_null().all(), (
        "财务数据为空时net_profit 应全为 null"
    )
    print("空财务数据拼接结果:")
    print(merged)
    print("\n[通过] 空财务数据场景测试通过!")
def run_all_tests():
    """Run every test scenario in order, reporting failures explicitly."""
    print("开始运行财务数据拼接功能测试...\n")
    print("=" * 60)
    try:
        # Scenarios run in dependency order: cleaning, merge, empty input.
        for case in (
            test_financial_data_cleaning,
            test_financial_price_merge,
            test_empty_financial_data,
        ):
            case()
        print("\n" + "=" * 60)
        print("所有测试通过!")
        print("=" * 60)
    except AssertionError as e:
        print(f"\n[失败] 测试断言失败: {e}")
        raise
    except Exception as e:
        print(f"\n[错误] 测试执行出错: {e}")
        raise
if __name__ == "__main__":
    # Execute the full scenario suite when run directly as a script.
    run_all_tests()