Files
ProStock/tests/test_financial_price_merge.py
liaozhaorun 3b42093100 feat(data): 财务数据加载与清洗模块
新增 FinancialLoader 类,提供:
- 财务数据加载与清洗(保留合并报表,按 update_flag 去重)
- 支持 as-of join 拼接行情数据(无未来函数)
- 自动识别财务表并配置 asof_backward 拼接模式
2026-03-04 23:35:20 +08:00

245 lines
8.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""财务数据与行情数据拼接测试。
测试场景:
1. 普通财务数据:正常公告,之后无修改
2. 隔日修改:公告后几天发布修正版
3. 当日修改:同一天发布多版,取 update_flag=1 的
4. 边界条件:财务数据缺失、行情数据早于最早财务数据
"""
import polars as pl
from datetime import date
from src.data.financial_loader import FinancialLoader
def create_mock_price_data() -> pl.DataFrame:
"""创建模拟行情数据。"""
return pl.DataFrame(
{
"ts_code": ["000001.SZ"] * 10,
"trade_date": [
"20240101",
"20240102",
"20240103",
"20240104",
"20240105",
"20240108",
"20240109",
"20240110",
"20240111",
"20240112",
],
"close": [10.0, 10.2, 10.3, 10.1, 10.5, 10.6, 10.4, 10.7, 10.8, 10.9],
}
)
def create_mock_financial_data() -> pl.DataFrame:
"""创建模拟财务数据(覆盖多种场景)。
注意f_ann_date 必须是 Date 类型(与数据库保持一致)。
"""
return pl.DataFrame(
{
"ts_code": ["000001.SZ", "000001.SZ", "000001.SZ", "000001.SZ"],
# 场景1: 2023Q3 报告,正常公告
# 场景2: 同日多版update_flag 区分)
# 场景3: 隔日修改
"f_ann_date": [
date(2024, 1, 2),
date(2024, 1, 2),
date(2024, 1, 5),
date(2024, 1, 10),
],
"end_date": ["20230930", "20230930", "20230930", "20231231"],
"report_type": [1, 1, 1, 1], # 整数类型(与数据库一致)
"update_flag": [0, 1, 1, 1], # 整数类型(与数据库一致)
"net_profit": [1000000.0, 1100000.0, 1100000.0, 1200000.0],
"revenue": [5000000.0, 5200000.0, 5200000.0, 6000000.0],
}
)
def test_financial_data_cleaning():
"""测试财务数据清洗逻辑。"""
print("=== 测试 1: 财务数据清洗 ===")
df_finance = create_mock_financial_data()
print("原始财务数据:")
print(df_finance)
loader = FinancialLoader()
# 手动执行清洗(模拟 load_financial_data 的逻辑)
# 步骤1: 仅保留合并报表
df = df_finance.filter(pl.col("report_type") == 1)
# 步骤2: 按 update_flag 降序排列后去重
df = df.with_columns(
[pl.col("update_flag").cast(pl.Int32).alias("update_flag_int")]
)
df = df.sort(
["ts_code", "f_ann_date", "update_flag_int"], descending=[False, False, True]
)
df = df.unique(subset=["ts_code", "f_ann_date"], keep="first")
df = df.drop("update_flag_int")
# 步骤3: 排序f_ann_date 已经是 Date 类型)
df = df.sort(["ts_code", "f_ann_date"])
print("\n清洗后的财务数据:")
print(df)
# 验证应该有3条记录第1-2行去重为1条第3行第4行
assert len(df) == 3, f"清洗后应该有3条记录实际有 {len(df)}"
# 验证2024-01-02 的 update_flag 应该是 1
row_jan02 = df.filter(pl.col("f_ann_date") == date(2024, 1, 2))
assert len(row_jan02) == 1, "应该有1条 2024-01-02 的记录"
assert row_jan02["update_flag"][0] == 1, "update_flag 应该为 1"
assert row_jan02["net_profit"][0] == 1100000.0, "net_profit 应该为 1100000"
print("\n[通过] 财务数据清洗测试通过!")
return df
def test_financial_price_merge():
"""测试财务数据拼接逻辑(无未来函数验证)。"""
print("\n=== 测试 2: 财务数据与行情数据拼接 ===")
df_price = create_mock_price_data()
df_finance_raw = create_mock_financial_data()
loader = FinancialLoader()
# 步骤1: 清洗财务数据(手动执行)
# 注意f_ann_date 已经是 Date 类型,不需要转换
df_finance = df_finance_raw.filter(pl.col("report_type") == 1)
df_finance = df_finance.with_columns(
[pl.col("update_flag").cast(pl.Int32).alias("update_flag_int")]
)
df_finance = df_finance.sort(
["ts_code", "f_ann_date", "update_flag_int"], descending=[False, False, True]
)
df_finance = df_finance.unique(subset=["ts_code", "f_ann_date"], keep="first")
df_finance = df_finance.drop("update_flag_int")
df_finance = df_finance.sort(["ts_code", "f_ann_date"])
print("清洗后的财务数据:")
print(df_finance)
# 步骤2: 转换行情数据日期为 Date 类型
df_price = df_price.with_columns(
[pl.col("trade_date").str.strptime(pl.Date, "%Y%m%d").alias("trade_date")]
)
df_price = df_price.sort(["ts_code", "trade_date"])
# 步骤3: 拼接
financial_cols = ["net_profit", "revenue"]
merged = loader.merge_financial_with_price(df_price, df_finance, financial_cols)
# 步骤4: 转回字符串格式
merged = merged.with_columns(
[pl.col("trade_date").dt.strftime("%Y%m%d").alias("trade_date")]
)
print("\n拼接结果:")
print(merged)
# 验证无未来函数:
# 20240101 之前不应有 2023Q3 数据(因为 20240102 才公告)
jan01 = merged.filter(pl.col("trade_date") == "20240101")
assert jan01["net_profit"].is_null().all(), (
"2024-01-01 不应有 2023Q3 数据(尚未公告)"
)
print("[验证 1] 2024-01-01 net_profit 为 null - 正确(公告前无数据)")
# 20240102 及之后应该看到 net_profit=1100000update_flag=1 的版本)
jan02 = merged.filter(pl.col("trade_date") == "20240102")
assert jan02["net_profit"][0] == 1100000.0, "2024-01-02 应使用 update_flag=1 的数据"
print("[验证 2] 2024-01-02 net_profit=1100000 - 正确(使用 update_flag=1")
# 20240104 应延续使用 2023Q3 数据
jan04 = merged.filter(pl.col("trade_date") == "20240104")
assert jan04["net_profit"][0] == 1100000.0, "2024-01-04 应延续使用 2023Q3 数据"
print("[验证 3] 2024-01-04 net_profit=1100000 - 正确(延续使用)")
# 20240110 应切换到 2023Q4 数据(新公告)
jan10 = merged.filter(pl.col("trade_date") == "20240110")
assert jan10["net_profit"][0] == 1200000.0, "2024-01-10 应切换到 2023Q4 数据"
print("[验证 4] 2024-01-10 net_profit=1200000 - 正确(新财报公告)")
# 20240112 应继续延续使用 2023Q4 数据
jan12 = merged.filter(pl.col("trade_date") == "20240112")
assert jan12["net_profit"][0] == 1200000.0, "2024-01-12 应继续使用 2023Q4 数据"
print("[验证 5] 2024-01-12 net_profit=1200000 - 正确(延续使用)")
print("\n[通过] 所有验证通过,无未来函数!")
return merged
def test_empty_financial_data():
"""测试财务数据为空的情况。"""
print("\n=== 测试 3: 空财务数据场景 ===")
df_price = create_mock_price_data()
df_empty = pl.DataFrame()
loader = FinancialLoader()
# 转换行情数据日期为 Date 类型
df_price = df_price.with_columns(
[pl.col("trade_date").str.strptime(pl.Date, "%Y%m%d").alias("trade_date")]
)
df_price = df_price.sort(["ts_code", "trade_date"])
# 拼接空财务数据
merged = loader.merge_financial_with_price(df_price, df_empty, ["net_profit"])
# 转回字符串格式
merged = merged.with_columns(
[pl.col("trade_date").dt.strftime("%Y%m%d").alias("trade_date")]
)
# 验证财务列为空
assert merged["net_profit"].is_null().all(), (
"财务数据为空时net_profit 应全为 null"
)
print("空财务数据拼接结果:")
print(merged)
print("\n[通过] 空财务数据场景测试通过!")
def run_all_tests():
"""运行所有测试。"""
print("开始运行财务数据拼接功能测试...\n")
print("=" * 60)
try:
# 测试 1: 数据清洗
test_financial_data_cleaning()
# 测试 2: 数据拼接
test_financial_price_merge()
# 测试 3: 空数据场景
test_empty_financial_data()
print("\n" + "=" * 60)
print("所有测试通过!")
print("=" * 60)
except AssertionError as e:
print(f"\n[失败] 测试断言失败: {e}")
raise
except Exception as e:
print(f"\n[错误] 测试执行出错: {e}")
raise
if __name__ == "__main__":
run_all_tests()