feat(data): 财务数据加载与清洗模块
新增 FinancialLoader 类,提供: - 财务数据加载与清洗(保留合并报表,按 update_flag 去重) - 支持 as-of join 拼接行情数据(无未来函数) - 自动识别财务表并配置 asof_backward 拼接模式
This commit is contained in:
@@ -1,122 +0,0 @@
|
||||
"""Test for daily market data API.
|
||||
|
||||
Tests the daily interface implementation against api.md requirements:
|
||||
- A股日线行情所有输出字段
|
||||
- tor 换手率
|
||||
- vr 量比
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from src.data.api_wrappers import get_daily
|
||||
|
||||
|
||||
# Expected output fields according to api.md
|
||||
EXPECTED_BASE_FIELDS = [
|
||||
"ts_code", # 股票代码
|
||||
"trade_date", # 交易日期
|
||||
"open", # 开盘价
|
||||
"high", # 最高价
|
||||
"low", # 最低价
|
||||
"close", # 收盘价
|
||||
"pre_close", # 昨收价
|
||||
"change", # 涨跌额
|
||||
"pct_chg", # 涨跌幅
|
||||
"vol", # 成交量
|
||||
"amount", # 成交额
|
||||
]
|
||||
|
||||
EXPECTED_FACTOR_FIELDS = [
|
||||
"turnover_rate", # 换手率 (tor)
|
||||
"volume_ratio", # 量比 (vr)
|
||||
]
|
||||
|
||||
|
||||
class TestGetDaily:
|
||||
"""Test cases for get_daily function with real API calls."""
|
||||
|
||||
def test_fetch_basic(self):
|
||||
"""Test basic daily data fetch with real API."""
|
||||
result = get_daily("000001.SZ", start_date="20240101", end_date="20240131")
|
||||
|
||||
assert isinstance(result, pd.DataFrame)
|
||||
assert len(result) >= 1
|
||||
assert result["ts_code"].iloc[0] == "000001.SZ"
|
||||
|
||||
def test_fetch_with_factors(self):
|
||||
"""Test fetch with tor and vr factors."""
|
||||
result = get_daily(
|
||||
"000001.SZ",
|
||||
start_date="20240101",
|
||||
end_date="20240131",
|
||||
factors=["tor", "vr"],
|
||||
)
|
||||
|
||||
assert isinstance(result, pd.DataFrame)
|
||||
# Check all base fields are present
|
||||
for field in EXPECTED_BASE_FIELDS:
|
||||
assert field in result.columns, f"Missing base field: {field}"
|
||||
# Check factor fields are present
|
||||
for field in EXPECTED_FACTOR_FIELDS:
|
||||
assert field in result.columns, f"Missing factor field: {field}"
|
||||
|
||||
def test_output_fields_completeness(self):
|
||||
"""Verify all required output fields are returned."""
|
||||
result = get_daily("600000.SH")
|
||||
|
||||
# Verify all base fields are present
|
||||
assert set(EXPECTED_BASE_FIELDS).issubset(result.columns.tolist()), (
|
||||
f"Missing fields: {set(EXPECTED_BASE_FIELDS) - set(result.columns)}"
|
||||
)
|
||||
|
||||
def test_empty_result(self):
|
||||
"""Test handling of empty results."""
|
||||
# 使用真实 API 测试无效股票代码的空结果
|
||||
result = get_daily("INVALID.SZ")
|
||||
assert isinstance(result, pd.DataFrame)
|
||||
assert result.empty
|
||||
|
||||
def test_date_range_query(self):
|
||||
"""Test query with date range."""
|
||||
result = get_daily(
|
||||
"000001.SZ",
|
||||
start_date="20240101",
|
||||
end_date="20240131",
|
||||
)
|
||||
|
||||
assert isinstance(result, pd.DataFrame)
|
||||
assert len(result) >= 1
|
||||
|
||||
def test_with_adj(self):
|
||||
"""Test fetch with adjustment type."""
|
||||
result = get_daily("000001.SZ", adj="qfq")
|
||||
|
||||
assert isinstance(result, pd.DataFrame)
|
||||
|
||||
|
||||
def test_integration():
|
||||
"""Integration test with real Tushare API (requires valid token)."""
|
||||
import os
|
||||
|
||||
token = os.environ.get("TUSHARE_TOKEN")
|
||||
if not token:
|
||||
pytest.skip("TUSHARE_TOKEN not configured")
|
||||
|
||||
result = get_daily(
|
||||
"000001.SZ", start_date="20240101", end_date="20240131", factors=["tor", "vr"]
|
||||
)
|
||||
|
||||
# Verify structure
|
||||
assert isinstance(result, pd.DataFrame)
|
||||
if not result.empty:
|
||||
# Check base fields
|
||||
for field in EXPECTED_BASE_FIELDS:
|
||||
assert field in result.columns, f"Missing base field: {field}"
|
||||
# Check factor fields
|
||||
for field in EXPECTED_FACTOR_FIELDS:
|
||||
assert field in result.columns, f"Missing factor field: {field}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 运行 pytest 单元测试(真实API调用)
|
||||
pytest.main([__file__, "-v"])
|
||||
@@ -1,242 +0,0 @@
|
||||
"""Tests for DuckDB storage validation.
|
||||
|
||||
Validates two key points:
|
||||
1. All stocks from stock_basic.csv are saved in daily table
|
||||
2. No abnormal data with very few data points (< 10 rows per stock)
|
||||
|
||||
使用 3 个月的真实数据进行测试 (2024年1月-3月)
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from datetime import datetime, timedelta
|
||||
from src.data.storage import Storage
|
||||
from src.data.api_wrappers.api_stock_basic import _get_csv_path
|
||||
|
||||
|
||||
class TestDailyStorageValidation:
|
||||
"""Test daily table storage integrity and completeness."""
|
||||
|
||||
# 测试数据时间范围:3个月
|
||||
TEST_START_DATE = "20240101"
|
||||
TEST_END_DATE = "20240331"
|
||||
|
||||
@pytest.fixture
|
||||
def storage(self):
|
||||
"""Create storage instance."""
|
||||
return Storage()
|
||||
|
||||
@pytest.fixture
|
||||
def stock_basic_df(self):
|
||||
"""Load stock basic data from CSV."""
|
||||
csv_path = _get_csv_path()
|
||||
if not csv_path.exists():
|
||||
pytest.skip(f"stock_basic.csv not found at {csv_path}")
|
||||
return pd.read_csv(csv_path)
|
||||
|
||||
@pytest.fixture
|
||||
def daily_df(self, storage):
|
||||
"""Load daily data from DuckDB (3 months)."""
|
||||
if not storage.exists("daily"):
|
||||
pytest.skip("daily table not found in DuckDB")
|
||||
|
||||
# 从 DuckDB 加载 3 个月数据
|
||||
df = storage.load(
|
||||
"daily", start_date=self.TEST_START_DATE, end_date=self.TEST_END_DATE
|
||||
)
|
||||
|
||||
if df.empty:
|
||||
pytest.skip(
|
||||
f"No data found for period {self.TEST_START_DATE} to {self.TEST_END_DATE}"
|
||||
)
|
||||
|
||||
return df
|
||||
|
||||
def test_duckdb_connection(self, storage):
|
||||
"""Test DuckDB connection and basic operations."""
|
||||
assert storage.exists("daily") or True # 至少连接成功
|
||||
print(f"[TEST] DuckDB connection successful")
|
||||
|
||||
def test_load_3months_data(self, storage):
|
||||
"""Test loading 3 months of data from DuckDB."""
|
||||
df = storage.load(
|
||||
"daily", start_date=self.TEST_START_DATE, end_date=self.TEST_END_DATE
|
||||
)
|
||||
|
||||
if df.empty:
|
||||
pytest.skip("No data available for testing period")
|
||||
|
||||
# 验证数据覆盖范围
|
||||
dates = df["trade_date"].astype(str)
|
||||
min_date = dates.min()
|
||||
max_date = dates.max()
|
||||
|
||||
print(f"[TEST] Loaded {len(df)} rows from {min_date} to {max_date}")
|
||||
assert len(df) > 0, "Should have data in the 3-month period"
|
||||
|
||||
def test_all_stocks_saved(self, storage, stock_basic_df, daily_df):
|
||||
"""Verify all stocks from stock_basic are saved in daily table.
|
||||
|
||||
This test ensures data completeness - every stock in stock_basic
|
||||
should have corresponding data in daily table.
|
||||
"""
|
||||
if daily_df.empty:
|
||||
pytest.fail("daily table is empty for test period")
|
||||
|
||||
# Get unique stock codes from both sources
|
||||
expected_codes = set(stock_basic_df["ts_code"].dropna().unique())
|
||||
actual_codes = set(daily_df["ts_code"].dropna().unique())
|
||||
|
||||
# Check for missing stocks
|
||||
missing_codes = expected_codes - actual_codes
|
||||
|
||||
if missing_codes:
|
||||
missing_list = sorted(missing_codes)
|
||||
# Show first 20 missing stocks as sample
|
||||
sample = missing_list[:20]
|
||||
msg = f"Found {len(missing_codes)} stocks missing from daily table:\n"
|
||||
msg += f"Sample missing: {sample}\n"
|
||||
if len(missing_list) > 20:
|
||||
msg += f"... and {len(missing_list) - 20} more"
|
||||
# 对于3个月数据,允许部分股票缺失(可能是新股或未上市)
|
||||
print(f"[WARNING] {msg}")
|
||||
# 只验证至少有80%的股票存在
|
||||
coverage = len(actual_codes) / len(expected_codes) * 100
|
||||
assert coverage >= 80, (
|
||||
f"Stock coverage {coverage:.1f}% is below 80% threshold"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"[TEST] All {len(expected_codes)} stocks from stock_basic are present in daily table"
|
||||
)
|
||||
|
||||
def test_no_stock_with_insufficient_data(self, storage, daily_df):
|
||||
"""Verify no stock has abnormally few data points (< 5 rows in 3 months).
|
||||
|
||||
Stocks with very few data points may indicate sync failures,
|
||||
delisted stocks not properly handled, or data corruption.
|
||||
"""
|
||||
if daily_df.empty:
|
||||
pytest.fail("daily table is empty for test period")
|
||||
|
||||
# Count rows per stock
|
||||
stock_counts = daily_df.groupby("ts_code").size()
|
||||
|
||||
# Find stocks with less than 5 data points in 3 months
|
||||
insufficient_stocks = stock_counts[stock_counts < 5]
|
||||
|
||||
if not insufficient_stocks.empty:
|
||||
# Separate into categories for better reporting
|
||||
empty_stocks = stock_counts[stock_counts == 0]
|
||||
very_few_stocks = stock_counts[(stock_counts > 0) & (stock_counts < 5)]
|
||||
|
||||
msg = f"Found {len(insufficient_stocks)} stocks with insufficient data (< 5 rows in 3 months):\n"
|
||||
|
||||
if not empty_stocks.empty:
|
||||
msg += f"\nEmpty stocks (0 rows): {len(empty_stocks)}\n"
|
||||
sample = sorted(empty_stocks.index[:10].tolist())
|
||||
msg += f"Sample: {sample}"
|
||||
|
||||
if not very_few_stocks.empty:
|
||||
msg += f"\nVery few data points (1-4 rows): {len(very_few_stocks)}\n"
|
||||
# Show counts for these stocks
|
||||
sample = very_few_stocks.sort_values().head(20)
|
||||
msg += "Sample (ts_code: count):\n"
|
||||
for code, count in sample.items():
|
||||
msg += f" {code}: {count} rows\n"
|
||||
|
||||
# 对于3个月数据,允许少量异常,但比例不能超过5%
|
||||
if len(insufficient_stocks) / len(stock_counts) > 0.05:
|
||||
pytest.fail(msg)
|
||||
else:
|
||||
print(f"[WARNING] {msg}")
|
||||
|
||||
print(f"[TEST] All stocks have sufficient data (>= 5 rows in 3 months)")
|
||||
|
||||
def test_data_integrity_basic(self, storage, daily_df):
|
||||
"""Basic data integrity checks for daily table."""
|
||||
if daily_df.empty:
|
||||
pytest.fail("daily table is empty for test period")
|
||||
|
||||
# Check required columns exist
|
||||
required_columns = ["ts_code", "trade_date"]
|
||||
missing_columns = [
|
||||
col for col in required_columns if col not in daily_df.columns
|
||||
]
|
||||
|
||||
if missing_columns:
|
||||
pytest.fail(f"Missing required columns: {missing_columns}")
|
||||
|
||||
# Check for null values in key columns
|
||||
null_ts_code = daily_df["ts_code"].isna().sum()
|
||||
null_trade_date = daily_df["trade_date"].isna().sum()
|
||||
|
||||
if null_ts_code > 0:
|
||||
pytest.fail(f"Found {null_ts_code} rows with null ts_code")
|
||||
if null_trade_date > 0:
|
||||
pytest.fail(f"Found {null_trade_date} rows with null trade_date")
|
||||
|
||||
print(f"[TEST] Data integrity check passed for 3-month period")
|
||||
|
||||
def test_polars_export(self, storage):
|
||||
"""Test Polars export functionality."""
|
||||
if not storage.exists("daily"):
|
||||
pytest.skip("daily table not found")
|
||||
|
||||
import polars as pl
|
||||
|
||||
# 测试 load_polars 方法
|
||||
df = storage.load_polars(
|
||||
"daily", start_date=self.TEST_START_DATE, end_date=self.TEST_END_DATE
|
||||
)
|
||||
|
||||
assert isinstance(df, pl.DataFrame), "Should return Polars DataFrame"
|
||||
print(f"[TEST] Polars export successful: {len(df)} rows")
|
||||
|
||||
def test_stock_data_coverage_report(self, storage, daily_df):
|
||||
"""Generate a summary report of stock data coverage.
|
||||
|
||||
This test provides visibility into data distribution without failing.
|
||||
"""
|
||||
if daily_df.empty:
|
||||
pytest.skip("daily table is empty - cannot generate report")
|
||||
|
||||
stock_counts = daily_df.groupby("ts_code").size()
|
||||
|
||||
# Calculate statistics
|
||||
total_stocks = len(stock_counts)
|
||||
min_count = stock_counts.min()
|
||||
max_count = stock_counts.max()
|
||||
median_count = stock_counts.median()
|
||||
mean_count = stock_counts.mean()
|
||||
|
||||
# Distribution buckets (adjusted for 3-month period, ~60 trading days)
|
||||
very_low = (stock_counts < 5).sum()
|
||||
low = ((stock_counts >= 5) & (stock_counts < 20)).sum()
|
||||
medium = ((stock_counts >= 20) & (stock_counts < 40)).sum()
|
||||
high = (stock_counts >= 40).sum()
|
||||
|
||||
report = f"""
|
||||
=== Stock Data Coverage Report (3 months: {self.TEST_START_DATE} to {self.TEST_END_DATE}) ===
|
||||
Total stocks: {total_stocks}
|
||||
Data points per stock:
|
||||
Min: {min_count}
|
||||
Max: {max_count}
|
||||
Median: {median_count:.0f}
|
||||
Mean: {mean_count:.1f}
|
||||
|
||||
Distribution:
|
||||
< 5 rows: {very_low} stocks ({very_low / total_stocks * 100:.1f}%)
|
||||
5-19: {low} stocks ({low / total_stocks * 100:.1f}%)
|
||||
20-39: {medium} stocks ({medium / total_stocks * 100:.1f}%)
|
||||
>= 40: {high} stocks ({high / total_stocks * 100:.1f}%)
|
||||
"""
|
||||
print(report)
|
||||
|
||||
# This is an informational test - it should not fail
|
||||
# But we assert to mark it as passed
|
||||
assert total_stocks > 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
244
tests/test_financial_price_merge.py
Normal file
244
tests/test_financial_price_merge.py
Normal file
@@ -0,0 +1,244 @@
|
||||
"""财务数据与行情数据拼接测试。
|
||||
|
||||
测试场景:
|
||||
1. 普通财务数据:正常公告,之后无修改
|
||||
2. 隔日修改:公告后几天发布修正版
|
||||
3. 当日修改:同一天发布多版,取 update_flag=1 的
|
||||
4. 边界条件:财务数据缺失、行情数据早于最早财务数据
|
||||
"""
|
||||
|
||||
import polars as pl
|
||||
from datetime import date
|
||||
from src.data.financial_loader import FinancialLoader
|
||||
|
||||
|
||||
def create_mock_price_data() -> pl.DataFrame:
|
||||
"""创建模拟行情数据。"""
|
||||
return pl.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ"] * 10,
|
||||
"trade_date": [
|
||||
"20240101",
|
||||
"20240102",
|
||||
"20240103",
|
||||
"20240104",
|
||||
"20240105",
|
||||
"20240108",
|
||||
"20240109",
|
||||
"20240110",
|
||||
"20240111",
|
||||
"20240112",
|
||||
],
|
||||
"close": [10.0, 10.2, 10.3, 10.1, 10.5, 10.6, 10.4, 10.7, 10.8, 10.9],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def create_mock_financial_data() -> pl.DataFrame:
|
||||
"""创建模拟财务数据(覆盖多种场景)。
|
||||
|
||||
注意:f_ann_date 必须是 Date 类型(与数据库保持一致)。
|
||||
"""
|
||||
return pl.DataFrame(
|
||||
{
|
||||
"ts_code": ["000001.SZ", "000001.SZ", "000001.SZ", "000001.SZ"],
|
||||
# 场景1: 2023Q3 报告,正常公告
|
||||
# 场景2: 同日多版(update_flag 区分)
|
||||
# 场景3: 隔日修改
|
||||
"f_ann_date": [
|
||||
date(2024, 1, 2),
|
||||
date(2024, 1, 2),
|
||||
date(2024, 1, 5),
|
||||
date(2024, 1, 10),
|
||||
],
|
||||
"end_date": ["20230930", "20230930", "20230930", "20231231"],
|
||||
"report_type": [1, 1, 1, 1], # 整数类型(与数据库一致)
|
||||
"update_flag": [0, 1, 1, 1], # 整数类型(与数据库一致)
|
||||
"net_profit": [1000000.0, 1100000.0, 1100000.0, 1200000.0],
|
||||
"revenue": [5000000.0, 5200000.0, 5200000.0, 6000000.0],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_financial_data_cleaning():
|
||||
"""测试财务数据清洗逻辑。"""
|
||||
print("=== 测试 1: 财务数据清洗 ===")
|
||||
|
||||
df_finance = create_mock_financial_data()
|
||||
print("原始财务数据:")
|
||||
print(df_finance)
|
||||
|
||||
loader = FinancialLoader()
|
||||
|
||||
# 手动执行清洗(模拟 load_financial_data 的逻辑)
|
||||
# 步骤1: 仅保留合并报表
|
||||
df = df_finance.filter(pl.col("report_type") == 1)
|
||||
|
||||
# 步骤2: 按 update_flag 降序排列后去重
|
||||
df = df.with_columns(
|
||||
[pl.col("update_flag").cast(pl.Int32).alias("update_flag_int")]
|
||||
)
|
||||
|
||||
df = df.sort(
|
||||
["ts_code", "f_ann_date", "update_flag_int"], descending=[False, False, True]
|
||||
)
|
||||
|
||||
df = df.unique(subset=["ts_code", "f_ann_date"], keep="first")
|
||||
df = df.drop("update_flag_int")
|
||||
|
||||
# 步骤3: 排序(f_ann_date 已经是 Date 类型)
|
||||
df = df.sort(["ts_code", "f_ann_date"])
|
||||
|
||||
print("\n清洗后的财务数据:")
|
||||
print(df)
|
||||
|
||||
# 验证:应该有3条记录(第1-2行去重为1条,第3行,第4行)
|
||||
assert len(df) == 3, f"清洗后应该有3条记录,实际有 {len(df)} 条"
|
||||
|
||||
# 验证:2024-01-02 的 update_flag 应该是 1
|
||||
row_jan02 = df.filter(pl.col("f_ann_date") == date(2024, 1, 2))
|
||||
assert len(row_jan02) == 1, "应该有1条 2024-01-02 的记录"
|
||||
assert row_jan02["update_flag"][0] == 1, "update_flag 应该为 1"
|
||||
assert row_jan02["net_profit"][0] == 1100000.0, "net_profit 应该为 1100000"
|
||||
|
||||
print("\n[通过] 财务数据清洗测试通过!")
|
||||
return df
|
||||
|
||||
|
||||
def test_financial_price_merge():
|
||||
"""测试财务数据拼接逻辑(无未来函数验证)。"""
|
||||
print("\n=== 测试 2: 财务数据与行情数据拼接 ===")
|
||||
|
||||
df_price = create_mock_price_data()
|
||||
df_finance_raw = create_mock_financial_data()
|
||||
|
||||
loader = FinancialLoader()
|
||||
|
||||
# 步骤1: 清洗财务数据(手动执行)
|
||||
# 注意:f_ann_date 已经是 Date 类型,不需要转换
|
||||
df_finance = df_finance_raw.filter(pl.col("report_type") == 1)
|
||||
df_finance = df_finance.with_columns(
|
||||
[pl.col("update_flag").cast(pl.Int32).alias("update_flag_int")]
|
||||
)
|
||||
df_finance = df_finance.sort(
|
||||
["ts_code", "f_ann_date", "update_flag_int"], descending=[False, False, True]
|
||||
)
|
||||
df_finance = df_finance.unique(subset=["ts_code", "f_ann_date"], keep="first")
|
||||
df_finance = df_finance.drop("update_flag_int")
|
||||
df_finance = df_finance.sort(["ts_code", "f_ann_date"])
|
||||
|
||||
print("清洗后的财务数据:")
|
||||
print(df_finance)
|
||||
|
||||
# 步骤2: 转换行情数据日期为 Date 类型
|
||||
df_price = df_price.with_columns(
|
||||
[pl.col("trade_date").str.strptime(pl.Date, "%Y%m%d").alias("trade_date")]
|
||||
)
|
||||
df_price = df_price.sort(["ts_code", "trade_date"])
|
||||
|
||||
# 步骤3: 拼接
|
||||
financial_cols = ["net_profit", "revenue"]
|
||||
merged = loader.merge_financial_with_price(df_price, df_finance, financial_cols)
|
||||
|
||||
# 步骤4: 转回字符串格式
|
||||
merged = merged.with_columns(
|
||||
[pl.col("trade_date").dt.strftime("%Y%m%d").alias("trade_date")]
|
||||
)
|
||||
|
||||
print("\n拼接结果:")
|
||||
print(merged)
|
||||
|
||||
# 验证无未来函数:
|
||||
# 20240101 之前不应有 2023Q3 数据(因为 20240102 才公告)
|
||||
jan01 = merged.filter(pl.col("trade_date") == "20240101")
|
||||
assert jan01["net_profit"].is_null().all(), (
|
||||
"2024-01-01 不应有 2023Q3 数据(尚未公告)"
|
||||
)
|
||||
print("[验证 1] 2024-01-01 net_profit 为 null - 正确(公告前无数据)")
|
||||
|
||||
# 20240102 及之后应该看到 net_profit=1100000(update_flag=1 的版本)
|
||||
jan02 = merged.filter(pl.col("trade_date") == "20240102")
|
||||
assert jan02["net_profit"][0] == 1100000.0, "2024-01-02 应使用 update_flag=1 的数据"
|
||||
print("[验证 2] 2024-01-02 net_profit=1100000 - 正确(使用 update_flag=1)")
|
||||
|
||||
# 20240104 应延续使用 2023Q3 数据
|
||||
jan04 = merged.filter(pl.col("trade_date") == "20240104")
|
||||
assert jan04["net_profit"][0] == 1100000.0, "2024-01-04 应延续使用 2023Q3 数据"
|
||||
print("[验证 3] 2024-01-04 net_profit=1100000 - 正确(延续使用)")
|
||||
|
||||
# 20240110 应切换到 2023Q4 数据(新公告)
|
||||
jan10 = merged.filter(pl.col("trade_date") == "20240110")
|
||||
assert jan10["net_profit"][0] == 1200000.0, "2024-01-10 应切换到 2023Q4 数据"
|
||||
print("[验证 4] 2024-01-10 net_profit=1200000 - 正确(新财报公告)")
|
||||
|
||||
# 20240112 应继续延续使用 2023Q4 数据
|
||||
jan12 = merged.filter(pl.col("trade_date") == "20240112")
|
||||
assert jan12["net_profit"][0] == 1200000.0, "2024-01-12 应继续使用 2023Q4 数据"
|
||||
print("[验证 5] 2024-01-12 net_profit=1200000 - 正确(延续使用)")
|
||||
|
||||
print("\n[通过] 所有验证通过,无未来函数!")
|
||||
return merged
|
||||
|
||||
|
||||
def test_empty_financial_data():
|
||||
"""测试财务数据为空的情况。"""
|
||||
print("\n=== 测试 3: 空财务数据场景 ===")
|
||||
|
||||
df_price = create_mock_price_data()
|
||||
df_empty = pl.DataFrame()
|
||||
|
||||
loader = FinancialLoader()
|
||||
|
||||
# 转换行情数据日期为 Date 类型
|
||||
df_price = df_price.with_columns(
|
||||
[pl.col("trade_date").str.strptime(pl.Date, "%Y%m%d").alias("trade_date")]
|
||||
)
|
||||
df_price = df_price.sort(["ts_code", "trade_date"])
|
||||
|
||||
# 拼接空财务数据
|
||||
merged = loader.merge_financial_with_price(df_price, df_empty, ["net_profit"])
|
||||
|
||||
# 转回字符串格式
|
||||
merged = merged.with_columns(
|
||||
[pl.col("trade_date").dt.strftime("%Y%m%d").alias("trade_date")]
|
||||
)
|
||||
|
||||
# 验证财务列为空
|
||||
assert merged["net_profit"].is_null().all(), (
|
||||
"财务数据为空时,net_profit 应全为 null"
|
||||
)
|
||||
|
||||
print("空财务数据拼接结果:")
|
||||
print(merged)
|
||||
print("\n[通过] 空财务数据场景测试通过!")
|
||||
|
||||
|
||||
def run_all_tests():
|
||||
"""运行所有测试。"""
|
||||
print("开始运行财务数据拼接功能测试...\n")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
# 测试 1: 数据清洗
|
||||
test_financial_data_cleaning()
|
||||
|
||||
# 测试 2: 数据拼接
|
||||
test_financial_price_merge()
|
||||
|
||||
# 测试 3: 空数据场景
|
||||
test_empty_financial_data()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("所有测试通过!")
|
||||
print("=" * 60)
|
||||
|
||||
except AssertionError as e:
|
||||
print(f"\n[失败] 测试断言失败: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
print(f"\n[错误] 测试执行出错: {e}")
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_all_tests()
|
||||
Reference in New Issue
Block a user