Files
ProStock/tests/test_financial_price_merge.py
liaozhaorun 505279c08b fix(data): 修复财务因子计算非确定性问题
重构 financial_loader 的去重逻辑,确保截面排名计算的股票集合一致:
- 引入"高水位线"算法剔除陈旧历史财报(解决2026年发布2021年财报的问题)
- 改变去重策略:按报告期(end_date)而非更新标识(update_flag)保留最新数据
- 扩展回看期从1年到2年,防止ST/停牌公司财报缺失
- 确保相同交易日在不同查询范围下返回一致的财务数据
2026-03-08 20:58:35 +08:00

352 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""财务数据与行情数据拼接测试。
测试场景:
1. 普通财务数据:正常公告,之后无修改
2. 隔日修改:公告后几天发布修正版
3. 当日修改:同一天发布多版,取 update_flag=1 的
4. 边界条件:财务数据缺失、行情数据早于最早财务数据
"""
import polars as pl
from datetime import date
from src.data.financial_loader import FinancialLoader
def create_mock_price_data() -> pl.DataFrame:
"""创建模拟行情数据。"""
return pl.DataFrame(
{
"ts_code": ["000001.SZ"] * 12,
"trade_date": [
"20240101",
"20240102",
"20240103",
"20240104",
"20240105",
"20240108",
"20240109",
"20240110",
"20240111",
"20240112",
# 添加2024-04-30之后的日期用于测试同日不同报告期场景
"20240501",
"20240502",
],
"close": [
10.0,
10.2,
10.3,
10.1,
10.5,
10.6,
10.4,
10.7,
10.8,
10.9,
11.0,
11.1,
],
}
)
def create_mock_financial_data() -> pl.DataFrame:
"""创建模拟财务数据(覆盖多种场景)。
场景说明:
1. 2024-01-02 发布 2023Q3 报告end_date=20230930
2. 2024-01-02 发布 2023Q3 更正版update_flag=1
3. 2024-04-30 同时发布 2023年报end_date=20231231和 2024Q1季报end_date=20240331
4. 2024-04-30 发布 2023年报更正版
预期结果:
- 2024-01-02 保留 2023Q3 更正版
- 2024-04-30 保留 2024Q1 季报end_date 最新)
注意f_ann_date 必须是 Date 类型(与数据库保持一致)。
"""
return pl.DataFrame(
{
"ts_code": [
"000001.SZ",
"000001.SZ",
"000001.SZ",
"000001.SZ",
"000001.SZ",
],
"f_ann_date": [
date(2024, 1, 2),
date(2024, 1, 2), # 同日多版
date(2024, 4, 30),
date(2024, 4, 30),
date(2024, 4, 30), # 同日不同报告期
],
"end_date": [
"20230930",
"20230930", # 2023Q3
"20231231",
"20240331",
"20231231", # 年报和季报同一天发布
],
"report_type": [1, 1, 1, 1, 1], # 整数类型(与数据库一致)
"update_flag": [0, 1, 0, 0, 1], # 年报也有更正版
"net_profit": [
1000000.0,
1100000.0, # 2023Q3
5000000.0,
1500000.0,
5500000.0, # 年报更正后550万季报150万
],
"revenue": [
5000000.0,
5200000.0, # 2023Q3
20000000.0,
8000000.0,
22000000.0,
],
}
)
def test_financial_data_cleaning():
"""测试财务数据清洗逻辑 - 确保同日多报告期时选 end_date 最新的。"""
print("=== 测试 1: 财务数据清洗 ===")
df_finance = create_mock_financial_data()
print("原始财务数据:")
print(df_finance)
loader = FinancialLoader()
# 手动执行新的清洗逻辑
df = df_finance.filter(pl.col("report_type") == 1)
# 添加辅助列
df = df.with_columns(
[
pl.col("end_date").cast(pl.Int32).alias("end_date_int"),
pl.col("update_flag")
.fill_null("0")
.cast(pl.Int32, strict=False)
.fill_null(0)
.alias("update_flag_int"),
]
)
# 确定性排序
df = df.sort(["ts_code", "f_ann_date", "end_date_int", "update_flag_int"])
# 累积最大报告期
df = df.with_columns(
pl.col("end_date_int").cum_max().over("ts_code").alias("max_end_date_seen")
)
# 过滤历史包袱
df = df.filter(pl.col("end_date_int") == pl.col("max_end_date_seen"))
# 去重保留最后一条end_date 最大的)
df = df.unique(subset=["ts_code", "f_ann_date"], keep="last")
# 清理辅助列
df = df.drop(["end_date_int", "update_flag_int", "max_end_date_seen"])
df = df.sort(["ts_code", "f_ann_date"])
print("\n清洗后的财务数据:")
print(df)
# 验证应该有2条记录2024-01-02 和 2024-04-30
assert len(df) == 2, f"清洗后应该有2条记录实际有 {len(df)}"
# 验证2024-01-02 的 end_date 应该是 20230930
row_jan02 = df.filter(pl.col("f_ann_date") == date(2024, 1, 2))
assert len(row_jan02) == 1
assert row_jan02["end_date"][0] == "20230930"
assert row_jan02["update_flag"][0] == 1
print("[验证 1] 2024-01-02 正确保留了 2023Q3 更正版")
# 验证2024-04-30 应该保留 2024Q1end_date=20240331而不是年报
row_apr30 = df.filter(pl.col("f_ann_date") == date(2024, 4, 30))
assert len(row_apr30) == 1
assert row_apr30["end_date"][0] == "20240331", (
f"2024-04-30 应该保留 end_date 最新的 20240331"
f"实际为 {row_apr30['end_date'][0]}"
)
assert row_apr30["net_profit"][0] == 1500000.0
print("[验证 2] 2024-04-30 正确保留了 2024Q1 季报end_date 最新)")
print("\n[通过] 财务数据清洗测试通过!")
return df
def test_financial_price_merge():
"""测试财务数据拼接逻辑(无未来函数验证)。"""
print("\n=== 测试 2: 财务数据与行情数据拼接 ===")
df_price = create_mock_price_data()
df_finance_raw = create_mock_financial_data()
loader = FinancialLoader()
# 步骤1: 清洗财务数据(手动执行新的清洗逻辑)
# 注意f_ann_date 已经是 Date 类型,不需要转换
df_finance = df_finance_raw.filter(pl.col("report_type") == 1)
# 添加辅助列
df_finance = df_finance.with_columns(
[
pl.col("end_date").cast(pl.Int32).alias("end_date_int"),
pl.col("update_flag")
.fill_null("0")
.cast(pl.Int32, strict=False)
.fill_null(0)
.alias("update_flag_int"),
]
)
# 确定性排序
df_finance = df_finance.sort(
["ts_code", "f_ann_date", "end_date_int", "update_flag_int"]
)
# 累积最大报告期
df_finance = df_finance.with_columns(
pl.col("end_date_int").cum_max().over("ts_code").alias("max_end_date_seen")
)
# 过滤历史包袱
df_finance = df_finance.filter(
pl.col("end_date_int") == pl.col("max_end_date_seen")
)
# 去重保留最后一条end_date 最大的)
df_finance = df_finance.unique(subset=["ts_code", "f_ann_date"], keep="last")
# 清理辅助列
df_finance = df_finance.drop(
["end_date_int", "update_flag_int", "max_end_date_seen"]
)
df_finance = df_finance.sort(["ts_code", "f_ann_date"])
print("清洗后的财务数据:")
print(df_finance)
# 步骤2: 转换行情数据日期为 Date 类型
df_price = df_price.with_columns(
[pl.col("trade_date").str.strptime(pl.Date, "%Y%m%d").alias("trade_date")]
)
df_price = df_price.sort(["ts_code", "trade_date"])
# 步骤3: 拼接
financial_cols = ["net_profit", "revenue"]
merged = loader.merge_financial_with_price(df_price, df_finance, financial_cols)
# 步骤4: 转回字符串格式
merged = merged.with_columns(
[pl.col("trade_date").dt.strftime("%Y%m%d").alias("trade_date")]
)
print("\n拼接结果:")
print(merged)
# 验证无未来函数:
# 20240101 之前不应有 2023Q3 数据(因为 20240102 才公告)
jan01 = merged.filter(pl.col("trade_date") == "20240101")
assert jan01["net_profit"].is_null().all(), (
"2024-01-01 不应有 2023Q3 数据(尚未公告)"
)
print("[验证 1] 2024-01-01 net_profit 为 null - 正确(公告前无数据)")
# 20240102 及之后应该看到 net_profit=1100000update_flag=1 的版本)
jan02 = merged.filter(pl.col("trade_date") == "20240102")
assert jan02["net_profit"][0] == 1100000.0, "2024-01-02 应使用 update_flag=1 的数据"
print("[验证 2] 2024-01-02 net_profit=1100000 - 正确(使用 update_flag=1")
# 20240104 应延续使用 2023Q3 数据
jan04 = merged.filter(pl.col("trade_date") == "20240104")
assert jan04["net_profit"][0] == 1100000.0, "2024-01-04 应延续使用 2023Q3 数据"
print("[验证 3] 2024-01-04 net_profit=1100000 - 正确(延续使用)")
# 20240110 应延续使用 2023Q3 数据2024-04-30 还未公告)
jan10 = merged.filter(pl.col("trade_date") == "20240110")
assert jan10["net_profit"][0] == 1100000.0, "2024-01-10 应延续使用 2023Q3 数据"
print("[验证 4] 2024-01-10 net_profit=1100000 - 正确(延续使用 2023Q3")
# 20240112 应继续延续使用 2023Q3 数据
jan12 = merged.filter(pl.col("trade_date") == "20240112")
assert jan12["net_profit"][0] == 1100000.0, "2024-01-12 应继续使用 2023Q3 数据"
print("[验证 5] 2024-01-12 net_profit=1100000 - 正确(延续使用 2023Q3")
# 20240501 应切换到 2024Q1 数据2024-04-30 已公告,且选择 end_date 最新的)
may01 = merged.filter(pl.col("trade_date") == "20240501")
assert may01["net_profit"][0] == 1500000.0, "2024-05-01 应切换到 2024Q1 数据"
print(
"[验证 6] 2024-05-01 net_profit=1500000 - 正确(切换到 2024Q1end_date 最新)"
)
print("\n[通过] 所有验证通过,无未来函数!")
return merged
def test_empty_financial_data():
"""测试财务数据为空的情况。"""
print("\n=== 测试 3: 空财务数据场景 ===")
df_price = create_mock_price_data()
df_empty = pl.DataFrame()
loader = FinancialLoader()
# 转换行情数据日期为 Date 类型
df_price = df_price.with_columns(
[pl.col("trade_date").str.strptime(pl.Date, "%Y%m%d").alias("trade_date")]
)
df_price = df_price.sort(["ts_code", "trade_date"])
# 拼接空财务数据
merged = loader.merge_financial_with_price(df_price, df_empty, ["net_profit"])
# 转回字符串格式
merged = merged.with_columns(
[pl.col("trade_date").dt.strftime("%Y%m%d").alias("trade_date")]
)
# 验证财务列为空
assert merged["net_profit"].is_null().all(), (
"财务数据为空时net_profit 应全为 null"
)
print("空财务数据拼接结果:")
print(merged)
print("\n[通过] 空财务数据场景测试通过!")
def run_all_tests():
"""运行所有测试。"""
print("开始运行财务数据拼接功能测试...\n")
print("=" * 60)
try:
# 测试 1: 数据清洗
test_financial_data_cleaning()
# 测试 2: 数据拼接
test_financial_price_merge()
# 测试 3: 空数据场景
test_empty_financial_data()
print("\n" + "=" * 60)
print("所有测试通过!")
print("=" * 60)
except AssertionError as e:
print(f"\n[失败] 测试断言失败: {e}")
raise
except Exception as e:
print(f"\n[错误] 测试执行出错: {e}")
raise
if __name__ == "__main__":
run_all_tests()