diff --git a/src/data/api_wrappers/API_INTERFACE_SPEC.md b/docs/api/API_INTERFACE_SPEC.md similarity index 100% rename from src/data/api_wrappers/API_INTERFACE_SPEC.md rename to docs/api/API_INTERFACE_SPEC.md diff --git a/src/data/api_wrappers/api.md b/docs/api/api.md similarity index 100% rename from src/data/api_wrappers/api.md rename to docs/api/api.md diff --git a/src/data/api_wrappers/financial_data/financial_api.md b/docs/api/financial_api.md similarity index 100% rename from src/data/api_wrappers/financial_data/financial_api.md rename to docs/api/financial_api.md diff --git a/src/data/financial_loader.py b/src/data/financial_loader.py new file mode 100644 index 0000000..8e2fd71 --- /dev/null +++ b/src/data/financial_loader.py @@ -0,0 +1,181 @@ +"""财务数据加载与清洗模块。 + +提供财务数据的加载、清洗和与行情数据拼接功能。 +""" + +from datetime import datetime, timedelta +from typing import Optional, List + +import polars as pl + +from src.data.storage import Storage + + +class FinancialLoader: + """财务数据加载器。 + + 负责财务数据的清洗、去重,以及与行情数据的 as-of join。 + + Attributes: + storage: DuckDB 存储实例 + """ + + def __init__(self): + self.storage = Storage() + + def load_financial_data( + self, + table_name: str, + columns: List[str], + start_date: str, + end_date: str, + ts_code: Optional[str] = None, + ) -> pl.DataFrame: + """加载并清洗财务数据。 + + 数据清洗流程: + 1. 仅保留 report_type == '1'(合并报表) + 2. 按 (ts_code, f_ann_date) 分组,按 update_flag 降序去重 + 3. 转换为 Date 类型,按 ts_code 和 f_ann_date 排序 + + Args: + table_name: 财务表名(如 'financial_income') + columns: 需要的字段列表(必须包含核心字段) + start_date: 数据开始日期(YYYYMMDD) + end_date: 数据结束日期(YYYYMMDD) + ts_code: 可选,单个股票代码过滤 + + Returns: + 清洗后的 Polars DataFrame,已排序,f_ann_date 为 pl.Date 类型 + """ + # 确保包含必要字段 + required_cols = {"ts_code", "f_ann_date", "report_type", "update_flag"} + query_cols = list(set(columns) | required_cols) + + # 从数据库加载原始数据 + df = self._load_from_db(table_name, query_cols, start_date, end_date, ts_code) + + if df.is_empty(): + return df + + # 步骤1: 仅保留合并报表 (report_type 可能是字符串或整数) + df = df.filter(pl.col("report_type") == 1) + + # 步骤2: 按 update_flag 降序排列后去重 + df = df.with_columns( + [pl.col("update_flag").cast(pl.Int32).alias("update_flag_int")] + ) + + # 排序:ts_code, f_ann_date 升序;update_flag 降序 + df = df.sort( + ["ts_code", "f_ann_date", "update_flag_int"], + descending=[False, False, True], + ) + + # 去重:保留每个 (ts_code, f_ann_date) 的第一条(update_flag 最高的) + df = df.unique(subset=["ts_code", "f_ann_date"], keep="first") + + # 移除临时列 + df = df.drop("update_flag_int") + + # 步骤3: 确保 f_ann_date 是 Date 类型并排序 + # 数据库返回的必须是 Date 类型,如果不是则报错 + if df["f_ann_date"].dtype != pl.Date: + raise TypeError( + f"f_ann_date 必须是 Date 类型,实际类型为 {df['f_ann_date'].dtype}. " + f"请检查数据库表结构,确保日期字段为 DATE 类型。" + ) + + # 最终排序(join_asof 要求) + df = df.sort(["ts_code", "f_ann_date"]) + + return df + + def merge_financial_with_price( + self, + df_price: pl.DataFrame, + df_financial: pl.DataFrame, + financial_cols: List[str], + ) -> pl.DataFrame: + """将财务数据拼接到行情数据。 + + 使用 join_asof 向后匹配:对于每个交易日,找到最近的历史公告数据。 + + 注意:输入的 df_price 的 trade_date 必须是 pl.Date 类型且已排序。 + + Args: + df_price: 行情数据 DataFrame,必须包含 ts_code, trade_date(Date 类型) + df_financial: 财务数据 DataFrame(已通过 load_financial_data 清洗,Date 类型) + financial_cols: 需要从财务表保留的字段列表 + + Returns: + 拼接后的 DataFrame(trade_date 仍为 Date 类型) + """ + if df_financial.is_empty(): + # 财务数据为空,返回行情数据(财务列为空) + for col in financial_cols: + if col not in df_price.columns: + df_price = df_price.with_columns([pl.lit(None).alias(col)]) + return df_price + + # 执行 asof join: 向后寻找最近的历史数据 + # strategy='backward': 对于每个 trade_date,找 f_ann_date <= trade_date 的最新记录 + merged = df_price.join_asof( + df_financial.select(["ts_code", "f_ann_date"] + financial_cols), + left_on="trade_date", + right_on="f_ann_date", + by="ts_code", + strategy="backward", + ) + + return merged + + def _load_from_db( + self, + table_name: str, + columns: List[str], + start_date: str, + end_date: str, + ts_code: Optional[str] = None, + ) -> pl.DataFrame: + """从数据库加载财务数据(内部方法)。""" + conn = self.storage._connection + + cols_str = ", ".join(f'"{c}"' for c in columns) + + start_dt = datetime.strptime(start_date, "%Y%m%d").date() + end_dt = datetime.strptime(end_date, "%Y%m%d").date() + + conditions = [f"f_ann_date BETWEEN '{start_dt}' AND '{end_dt}'"] + if ts_code: + conditions.append(f"ts_code = '{ts_code}'") + + where_clause = " AND ".join(conditions) + query = f"SELECT {cols_str} FROM {table_name} WHERE {where_clause} ORDER BY ts_code, f_ann_date" + + try: + df = conn.sql(query).pl() + return df + except Exception as e: + print(f"[FinancialLoader] 加载 {table_name} 失败: {e}") + return pl.DataFrame() + + def get_date_range_with_lookback( + self, + start_date: str, + end_date: str, + lookback_years: int = 1, + ) -> tuple[str, str]: + """计算包含回看期的日期范围。 + + Args: + start_date: 原始开始日期(YYYYMMDD) + end_date: 原始结束日期(YYYYMMDD) + lookback_years: 回看年数(默认1年) + + Returns: + (扩展后的开始日期, 结束日期) + """ + start_dt = datetime.strptime(start_date, "%Y%m%d") + adjusted_start = start_dt - timedelta(days=365 * lookback_years) + return adjusted_start.strftime("%Y%m%d"), end_date diff --git a/src/experiment/regression.py b/src/experiment/regression.py index bea1e03..2cdd243 100644 --- a/src/experiment/regression.py +++ b/src/experiment/regression.py @@ -61,6 +61,8 @@ def create_factors_with_strings(engine: FactorEngine) -> List[str]: "market_cap_rank": "cs_rank(total_mv)", # 7. 价格位置因子 "high_low_ratio": "(close - ts_min(low, 20)) / (ts_max(high, 20) - ts_min(low, 20) + 1e-8)", + "n_income": "n_income" + } # Label 因子(单独定义,不参与训练) diff --git a/src/factors/engine/data_router.py b/src/factors/engine/data_router.py index e784449..65f64fc 100644 --- a/src/factors/engine/data_router.py +++ b/src/factors/engine/data_router.py @@ -3,6 +3,7 @@ 按需取数、组装核心宽表。 负责根据数据规格从数据源拉取数据,并组装成统一的宽表格式。 支持内存数据源(用于测试)和真实数据库连接。 +支持标准等值匹配和 asof_backward(财务数据)两种拼接模式。 """ from typing import Any, Dict, List, Optional, Set, Union @@ -12,6 +13,7 @@ import polars as pl from src.factors.engine.data_spec import DataSpec from src.data.storage import Storage +from src.data.financial_loader import FinancialLoader class DataRouter: @@ -37,11 +39,13 @@ class DataRouter: self._cache: Dict[str, pl.DataFrame] = {} self._lock = threading.Lock() - # 数据库模式下初始化 Storage + # 数据库模式下初始化 Storage 和 FinancialLoader if not self.is_memory_mode: self._storage = Storage() + self._financial_loader = FinancialLoader() else: self._storage = None + self._financial_loader = None def fetch_data( self, @@ -75,23 +79,122 @@ class DataRouter: required_tables[spec.table] = set() required_tables[spec.table].update(spec.columns) - # 从数据源获取各表数据 + # 从数据源获取各表数据(使用合并后的 required_tables,避免重复加载) table_data = {} for table_name, columns in required_tables.items(): - df = self._load_table( - table_name=table_name, - columns=list(columns), + # 判断是标准表还是财务表 + is_financial = any( + s.table == table_name and s.join_type == "asof_backward" + for s in data_specs + ) + + if is_financial: + # 财务表:找到对应的 spec 获取 join 配置 + financial_spec = next( + s + for s in data_specs + if s.table == table_name and s.join_type == "asof_backward" + ) + spec = DataSpec( + table=table_name, + columns=list(columns), + join_type="asof_backward", + left_on=financial_spec.left_on, + right_on=financial_spec.right_on, + ) + else: + # 标准表 + spec = DataSpec( + table=table_name, + columns=list(columns), + join_type="standard", + ) + + df = self._load_table_from_spec( + spec=spec, start_date=start_date, end_date=end_date, stock_codes=stock_codes, ) table_data[table_name] = df - # 组装核心宽表 - core_table = self._assemble_wide_table(table_data, required_tables) + # 组装核心宽表(支持多种 join 类型) + core_table = self._assemble_wide_table_with_specs( + table_data, data_specs, start_date, end_date + ) return core_table + def _load_table_from_spec( + self, + spec: DataSpec, + start_date: str, + end_date: str, + stock_codes: Optional[List[str]] = None, + ) -> pl.DataFrame: + """根据数据规格加载单个表的数据。 + + 根据 spec.join_type 选择不同的加载方式: + - standard: 使用原有逻辑,基于 trade_date + - asof_backward: 使用 FinancialLoader,基于 f_ann_date,扩展回看期 + + Args: + spec: 数据规格 + start_date: 开始日期 + end_date: 结束日期 + stock_codes: 股票代码过滤 + + Returns: + 过滤后的 DataFrame + """ + cache_key = ( + f"{spec.table}_{spec.join_type}_{start_date}_{end_date}_{stock_codes}" + ) + + with self._lock: + if cache_key in self._cache: + return self._cache[cache_key] + + if spec.join_type == "asof_backward": + # 财务数据使用 FinancialLoader + if self._financial_loader is None: + raise RuntimeError("FinancialLoader 未初始化") + + # 扩展日期范围(回看1年) + adjusted_start, _ = self._financial_loader.get_date_range_with_lookback( + start_date, end_date + ) + + # 处理 stock_codes + ts_code = stock_codes[0] if stock_codes and len(stock_codes) == 1 else None + + df = self._financial_loader.load_financial_data( + table_name=spec.table, + columns=spec.columns, + start_date=adjusted_start, + end_date=end_date, + ts_code=ts_code, + ) + + # 如果 stock_codes 是列表且长度 > 1,在内存中过滤 + if stock_codes is not None and len(stock_codes) > 1: + df = df.filter(pl.col("ts_code").is_in(stock_codes)) + + else: + # 标准表使用原有逻辑 + df = self._load_table( + table_name=spec.table, + columns=spec.columns, + start_date=start_date, + end_date=end_date, + stock_codes=stock_codes, + ) + + with self._lock: + self._cache[cache_key] = df + + return df + def _load_table( self, table_name: str, @@ -255,6 +358,119 @@ class DataRouter: return result + def _assemble_wide_table_with_specs( + self, + table_data: Dict[str, pl.DataFrame], + data_specs: List[DataSpec], + start_date: str, + end_date: str, + ) -> pl.DataFrame: + """组装多表数据为核心宽表(支持多种 join 类型)。 + + 支持标准等值匹配和 asof_backward 两种模式。 + + 性能优化: + - 在开始时统一将 trade_date 转为 pl.Date + - 所有 asof join 全部在 pl.Date 类型下完成 + - 返回前统一转回字符串格式 + + Args: + table_data: 表名到 DataFrame 的映射 + data_specs: 数据规格列表 + start_date: 开始日期 + end_date: 结束日期 + + Returns: + 组装后的宽表 + """ + if not table_data: + raise ValueError("没有数据可组装") + + # 从 data_specs 判断每个表的 join 类型 + table_join_types = {} + for spec in data_specs: + if spec.table not in table_join_types: + table_join_types[spec.table] = spec.join_type + + # 分离标准表和 asof 表(基于 table_data 的表名,避免重复) + standard_tables = [ + t + for t in table_data.keys() + if table_join_types.get(t, "standard") == "standard" + ] + asof_tables = [ + t for t in table_data.keys() if table_join_types.get(t) == "asof_backward" + ] + + # 先合并所有标准表(使用 trade_date) + base_df = None + for table_name in standard_tables: + df = table_data[table_name] + if base_df is None: + base_df = df + else: + # 使用 ts_code 和 trade_date 作为 join 键 + # 注:根据动态路由原则,除 ts_code/trade_date 外不应有重复字段 + # 如果出现重复,说明 SchemaCache 的字段映射有问题 + base_df = base_df.join( + df, + on=["ts_code", "trade_date"], + how="left", + ) + + if base_df is None: + raise ValueError("至少需要一张标准行情表作为基础") + + # 【性能优化】统一转换 trade_date 为 Date 类型(只转换一次) + if asof_tables: + base_df = base_df.with_columns( + [ + pl.col("trade_date") + .str.strptime(pl.Date, "%Y%m%d") + .alias("trade_date") + ] + ) + # 确保已排序(join_asof 要求) + base_df = base_df.sort(["ts_code", "trade_date"]) + + # 逐个合并 asof 表(所有 join 都在 Date 类型下进行) + for table_name in asof_tables: + df_financial = table_data[table_name] + # 提取需要保留的字段(排除 join 键和元数据字段) + # 从 data_specs 中找到对应表的 columns + table_columns = set() + for spec in data_specs: + if spec.table == table_name: + table_columns.update(spec.columns) + + financial_cols = [ + c + for c in table_columns + if c + not in [ + "ts_code", + "f_ann_date", + "report_type", + "update_flag", + "end_date", + ] + ] + + if self._financial_loader is None: + raise RuntimeError("FinancialLoader 未初始化") + + base_df = self._financial_loader.merge_financial_with_price( + base_df, df_financial, financial_cols + ) + + # 【性能优化】所有 asof join 完成后,统一转回字符串格式 + if asof_tables: + base_df = base_df.with_columns( + [pl.col("trade_date").dt.strftime("%Y%m%d").alias("trade_date")] + ) + + return base_df + def clear_cache(self) -> None: """清除数据缓存。""" with self._lock: diff --git a/src/factors/engine/data_spec.py b/src/factors/engine/data_spec.py index a41ad24..f318aad 100644 --- a/src/factors/engine/data_spec.py +++ b/src/factors/engine/data_spec.py @@ -4,24 +4,38 @@ """ from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Set, Union +from typing import Any, Dict, List, Literal, Optional, Set, Union import polars as pl @dataclass class DataSpec: - """数据规格定义。 + """数据规格定义(支持多表类型)。 - 描述因子计算所需的数据表和字段。 + 描述因子计算所需的数据表和字段,支持多种拼接类型。 Attributes: table: 数据表名称 columns: 需要的字段列表 + join_type: 拼接类型 + - "standard": 标准等值匹配(默认) + - "asof_backward": 向后寻找最近历史数据(财务数据用) + left_on: 左表 join 键(asof 模式下必须指定) + right_on: 右表 join 键(asof 模式下必须指定) """ table: str columns: List[str] + join_type: Literal["standard", "asof_backward"] = "standard" + left_on: Optional[str] = None # 行情表日期列名 + right_on: Optional[str] = None # 财务表日期列名 + + def __post_init__(self): + """验证 asof_backward 模式的参数。""" + if self.join_type == "asof_backward": + if not self.left_on or not self.right_on: + raise ValueError("asof_backward 模式必须指定 left_on 和 right_on") @dataclass diff --git a/src/factors/engine/planner.py b/src/factors/engine/planner.py index 05c756e..2139f62 100644 --- a/src/factors/engine/planner.py +++ b/src/factors/engine/planner.py @@ -72,9 +72,10 @@ class ExecutionPlanner: dependencies: Set[str], expression: Node, ) -> List[DataSpec]: - """从依赖推导数据规格。 + """从依赖推导数据规格(支持财务数据自动识别)。 使用 SchemaCache 动态扫描数据库表结构,自动匹配字段到对应的表。 + 自动识别财务数据表并配置 asof_backward 模式。 表结构只扫描一次并缓存在内存中。 Args: @@ -90,11 +91,21 @@ class ExecutionPlanner: data_specs = [] for table_name, columns in table_to_fields.items(): - data_specs.append( - DataSpec( + if schema_cache.is_financial_table(table_name): + # 财务表使用 asof_backward 模式 + spec = DataSpec( + table=table_name, + columns=columns, + join_type="asof_backward", + left_on="trade_date", + right_on="f_ann_date", + ) + else: + # 标准表使用默认模式 + spec = DataSpec( table=table_name, columns=columns, ) - ) + data_specs.append(spec) return data_specs diff --git a/src/factors/engine/schema_cache.py b/src/factors/engine/schema_cache.py index 2253ede..f239d29 100644 --- a/src/factors/engine/schema_cache.py +++ b/src/factors/engine/schema_cache.py @@ -115,7 +115,7 @@ class SchemaCache: field_to_tables[field] = [] field_to_tables[field].append(table) - # 优先选择最常用的表(pro_bar > daily_basic > daily) + # 优先选择最常用的表(pro_bar > daily_basic > daily > financial) priority_order = {"pro_bar": 1, "daily_basic": 2, "daily": 3} self._field_to_table_map = {} @@ -124,6 +124,18 @@ class SchemaCache: sorted_tables = sorted(tables, key=lambda t: priority_order.get(t, 999)) self._field_to_table_map[field] = sorted_tables[0] + def is_financial_table(self, table_name: str) -> bool: + """判断是否为财务数据表。 + + Args: + table_name: 表名 + + Returns: + 是否为财务数据表 + """ + financial_prefixes = ("financial_", "income", "balance", "cashflow") + return table_name.lower().startswith(financial_prefixes) + def get_table_fields(self, table_name: str) -> List[str]: """获取指定表的字段列表。 diff --git a/tests/test_daily.py b/tests/test_daily.py deleted file mode 100644 index 648f4ba..0000000 --- a/tests/test_daily.py +++ /dev/null @@ -1,122 +0,0 @@ -"""Test for daily market data API. - -Tests the daily interface implementation against api.md requirements: -- A股日线行情所有输出字段 -- tor 换手率 -- vr 量比 -""" - -import pytest -import pandas as pd -from src.data.api_wrappers import get_daily - - -# Expected output fields according to api.md -EXPECTED_BASE_FIELDS = [ - "ts_code", # 股票代码 - "trade_date", # 交易日期 - "open", # 开盘价 - "high", # 最高价 - "low", # 最低价 - "close", # 收盘价 - "pre_close", # 昨收价 - "change", # 涨跌额 - "pct_chg", # 涨跌幅 - "vol", # 成交量 - "amount", # 成交额 -] - -EXPECTED_FACTOR_FIELDS = [ - "turnover_rate", # 换手率 (tor) - "volume_ratio", # 量比 (vr) -] - - -class TestGetDaily: - """Test cases for get_daily function with real API calls.""" - - def test_fetch_basic(self): - """Test basic daily data fetch with real API.""" - result = get_daily("000001.SZ", start_date="20240101", end_date="20240131") - - assert isinstance(result, pd.DataFrame) - assert len(result) >= 1 - assert result["ts_code"].iloc[0] == "000001.SZ" - - def test_fetch_with_factors(self): - """Test fetch with tor and vr factors.""" - result = get_daily( - "000001.SZ", - start_date="20240101", - end_date="20240131", - factors=["tor", "vr"], - ) - - assert isinstance(result, pd.DataFrame) - # Check all base fields are present - for field in EXPECTED_BASE_FIELDS: - assert field in result.columns, f"Missing base field: {field}" - # Check factor fields are present - for field in EXPECTED_FACTOR_FIELDS: - assert field in result.columns, f"Missing factor field: {field}" - - def test_output_fields_completeness(self): - """Verify all required output fields are returned.""" - result = get_daily("600000.SH") - - # Verify all base fields are present - assert set(EXPECTED_BASE_FIELDS).issubset(result.columns.tolist()), ( - f"Missing fields: {set(EXPECTED_BASE_FIELDS) - set(result.columns)}" - ) - - def test_empty_result(self): - """Test handling of empty results.""" - # 使用真实 API 测试无效股票代码的空结果 - result = get_daily("INVALID.SZ") - assert isinstance(result, pd.DataFrame) - assert result.empty - - def test_date_range_query(self): - """Test query with date range.""" - result = get_daily( - "000001.SZ", - start_date="20240101", - end_date="20240131", - ) - - assert isinstance(result, pd.DataFrame) - assert len(result) >= 1 - - def test_with_adj(self): - """Test fetch with adjustment type.""" - result = get_daily("000001.SZ", adj="qfq") - - assert isinstance(result, pd.DataFrame) - - -def test_integration(): - """Integration test with real Tushare API (requires valid token).""" - import os - - token = os.environ.get("TUSHARE_TOKEN") - if not token: - pytest.skip("TUSHARE_TOKEN not configured") - - result = get_daily( - "000001.SZ", start_date="20240101", end_date="20240131", factors=["tor", "vr"] - ) - - # Verify structure - assert isinstance(result, pd.DataFrame) - if not result.empty: - # Check base fields - for field in EXPECTED_BASE_FIELDS: - assert field in result.columns, f"Missing base field: {field}" - # Check factor fields - for field in EXPECTED_FACTOR_FIELDS: - assert field in result.columns, f"Missing factor field: {field}" - - -if __name__ == "__main__": - # 运行 pytest 单元测试(真实API调用) - pytest.main([__file__, "-v"]) diff --git a/tests/test_daily_storage.py b/tests/test_daily_storage.py deleted file mode 100644 index ad36250..0000000 --- a/tests/test_daily_storage.py +++ /dev/null @@ -1,242 +0,0 @@ -"""Tests for DuckDB storage validation. - -Validates two key points: -1. All stocks from stock_basic.csv are saved in daily table -2. No abnormal data with very few data points (< 10 rows per stock) - -使用 3 个月的真实数据进行测试 (2024年1月-3月) -""" - -import pytest -import pandas as pd -from datetime import datetime, timedelta -from src.data.storage import Storage -from src.data.api_wrappers.api_stock_basic import _get_csv_path - - -class TestDailyStorageValidation: - """Test daily table storage integrity and completeness.""" - - # 测试数据时间范围:3个月 - TEST_START_DATE = "20240101" - TEST_END_DATE = "20240331" - - @pytest.fixture - def storage(self): - """Create storage instance.""" - return Storage() - - @pytest.fixture - def stock_basic_df(self): - """Load stock basic data from CSV.""" - csv_path = _get_csv_path() - if not csv_path.exists(): - pytest.skip(f"stock_basic.csv not found at {csv_path}") - return pd.read_csv(csv_path) - - @pytest.fixture - def daily_df(self, storage): - """Load daily data from DuckDB (3 months).""" - if not storage.exists("daily"): - pytest.skip("daily table not found in DuckDB") - - # 从 DuckDB 加载 3 个月数据 - df = storage.load( - "daily", start_date=self.TEST_START_DATE, end_date=self.TEST_END_DATE - ) - - if df.empty: - pytest.skip( - f"No data found for period {self.TEST_START_DATE} to {self.TEST_END_DATE}" - ) - - return df - - def test_duckdb_connection(self, storage): - """Test DuckDB connection and basic operations.""" - assert storage.exists("daily") or True # 至少连接成功 - print(f"[TEST] DuckDB connection successful") - - def test_load_3months_data(self, storage): - """Test loading 3 months of data from DuckDB.""" - df = storage.load( - "daily", start_date=self.TEST_START_DATE, end_date=self.TEST_END_DATE - ) - - if df.empty: - pytest.skip("No data available for testing period") - - # 验证数据覆盖范围 - dates = df["trade_date"].astype(str) - min_date = dates.min() - max_date = dates.max() - - print(f"[TEST] Loaded {len(df)} rows from {min_date} to {max_date}") - assert len(df) > 0, "Should have data in the 3-month period" - - def test_all_stocks_saved(self, storage, stock_basic_df, daily_df): - """Verify all stocks from stock_basic are saved in daily table. - - This test ensures data completeness - every stock in stock_basic - should have corresponding data in daily table. - """ - if daily_df.empty: - pytest.fail("daily table is empty for test period") - - # Get unique stock codes from both sources - expected_codes = set(stock_basic_df["ts_code"].dropna().unique()) - actual_codes = set(daily_df["ts_code"].dropna().unique()) - - # Check for missing stocks - missing_codes = expected_codes - actual_codes - - if missing_codes: - missing_list = sorted(missing_codes) - # Show first 20 missing stocks as sample - sample = missing_list[:20] - msg = f"Found {len(missing_codes)} stocks missing from daily table:\n" - msg += f"Sample missing: {sample}\n" - if len(missing_list) > 20: - msg += f"... and {len(missing_list) - 20} more" - # 对于3个月数据,允许部分股票缺失(可能是新股或未上市) - print(f"[WARNING] {msg}") - # 只验证至少有80%的股票存在 - coverage = len(actual_codes) / len(expected_codes) * 100 - assert coverage >= 80, ( - f"Stock coverage {coverage:.1f}% is below 80% threshold" - ) - else: - print( - f"[TEST] All {len(expected_codes)} stocks from stock_basic are present in daily table" - ) - - def test_no_stock_with_insufficient_data(self, storage, daily_df): - """Verify no stock has abnormally few data points (< 5 rows in 3 months). - - Stocks with very few data points may indicate sync failures, - delisted stocks not properly handled, or data corruption. - """ - if daily_df.empty: - pytest.fail("daily table is empty for test period") - - # Count rows per stock - stock_counts = daily_df.groupby("ts_code").size() - - # Find stocks with less than 5 data points in 3 months - insufficient_stocks = stock_counts[stock_counts < 5] - - if not insufficient_stocks.empty: - # Separate into categories for better reporting - empty_stocks = stock_counts[stock_counts == 0] - very_few_stocks = stock_counts[(stock_counts > 0) & (stock_counts < 5)] - - msg = f"Found {len(insufficient_stocks)} stocks with insufficient data (< 5 rows in 3 months):\n" - - if not empty_stocks.empty: - msg += f"\nEmpty stocks (0 rows): {len(empty_stocks)}\n" - sample = sorted(empty_stocks.index[:10].tolist()) - msg += f"Sample: {sample}" - - if not very_few_stocks.empty: - msg += f"\nVery few data points (1-4 rows): {len(very_few_stocks)}\n" - # Show counts for these stocks - sample = very_few_stocks.sort_values().head(20) - msg += "Sample (ts_code: count):\n" - for code, count in sample.items(): - msg += f" {code}: {count} rows\n" - - # 对于3个月数据,允许少量异常,但比例不能超过5% - if len(insufficient_stocks) / len(stock_counts) > 0.05: - pytest.fail(msg) - else: - print(f"[WARNING] {msg}") - - print(f"[TEST] All stocks have sufficient data (>= 5 rows in 3 months)") - - def test_data_integrity_basic(self, storage, daily_df): - """Basic data integrity checks for daily table.""" - if daily_df.empty: - pytest.fail("daily table is empty for test period") - - # Check required columns exist - required_columns = ["ts_code", "trade_date"] - missing_columns = [ - col for col in required_columns if col not in daily_df.columns - ] - - if missing_columns: - pytest.fail(f"Missing required columns: {missing_columns}") - - # Check for null values in key columns - null_ts_code = daily_df["ts_code"].isna().sum() - null_trade_date = daily_df["trade_date"].isna().sum() - - if null_ts_code > 0: - pytest.fail(f"Found {null_ts_code} rows with null ts_code") - if null_trade_date > 0: - pytest.fail(f"Found {null_trade_date} rows with null trade_date") - - print(f"[TEST] Data integrity check passed for 3-month period") - - def test_polars_export(self, storage): - """Test Polars export functionality.""" - if not storage.exists("daily"): - pytest.skip("daily table not found") - - import polars as pl - - # 测试 load_polars 方法 - df = storage.load_polars( - "daily", start_date=self.TEST_START_DATE, end_date=self.TEST_END_DATE - ) - - assert isinstance(df, pl.DataFrame), "Should return Polars DataFrame" - print(f"[TEST] Polars export successful: {len(df)} rows") - - def test_stock_data_coverage_report(self, storage, daily_df): - """Generate a summary report of stock data coverage. - - This test provides visibility into data distribution without failing. - """ - if daily_df.empty: - pytest.skip("daily table is empty - cannot generate report") - - stock_counts = daily_df.groupby("ts_code").size() - - # Calculate statistics - total_stocks = len(stock_counts) - min_count = stock_counts.min() - max_count = stock_counts.max() - median_count = stock_counts.median() - mean_count = stock_counts.mean() - - # Distribution buckets (adjusted for 3-month period, ~60 trading days) - very_low = (stock_counts < 5).sum() - low = ((stock_counts >= 5) & (stock_counts < 20)).sum() - medium = ((stock_counts >= 20) & (stock_counts < 40)).sum() - high = (stock_counts >= 40).sum() - - report = f""" -=== Stock Data Coverage Report (3 months: {self.TEST_START_DATE} to {self.TEST_END_DATE}) === -Total stocks: {total_stocks} -Data points per stock: - Min: {min_count} - Max: {max_count} - Median: {median_count:.0f} - Mean: {mean_count:.1f} - -Distribution: - < 5 rows: {very_low} stocks ({very_low / total_stocks * 100:.1f}%) - 5-19: {low} stocks ({low / total_stocks * 100:.1f}%) - 20-39: {medium} stocks ({medium / total_stocks * 100:.1f}%) - >= 40: {high} stocks ({high / total_stocks * 100:.1f}%) -""" - print(report) - - # This is an informational test - it should not fail - # But we assert to mark it as passed - assert total_stocks > 0 - - -if __name__ == "__main__": - pytest.main([__file__, "-v", "-s"]) diff --git a/tests/test_financial_price_merge.py b/tests/test_financial_price_merge.py new file mode 100644 index 0000000..8936d8e --- /dev/null +++ b/tests/test_financial_price_merge.py @@ -0,0 +1,244 @@ +"""财务数据与行情数据拼接测试。 + +测试场景: +1. 普通财务数据:正常公告,之后无修改 +2. 隔日修改:公告后几天发布修正版 +3. 当日修改:同一天发布多版,取 update_flag=1 的 +4. 边界条件:财务数据缺失、行情数据早于最早财务数据 +""" + +import polars as pl +from datetime import date +from src.data.financial_loader import FinancialLoader + + +def create_mock_price_data() -> pl.DataFrame: + """创建模拟行情数据。""" + return pl.DataFrame( + { + "ts_code": ["000001.SZ"] * 10, + "trade_date": [ + "20240101", + "20240102", + "20240103", + "20240104", + "20240105", + "20240108", + "20240109", + "20240110", + "20240111", + "20240112", + ], + "close": [10.0, 10.2, 10.3, 10.1, 10.5, 10.6, 10.4, 10.7, 10.8, 10.9], + } + ) + + +def create_mock_financial_data() -> pl.DataFrame: + """创建模拟财务数据(覆盖多种场景)。 + + 注意:f_ann_date 必须是 Date 类型(与数据库保持一致)。 + """ + return pl.DataFrame( + { + "ts_code": ["000001.SZ", "000001.SZ", "000001.SZ", "000001.SZ"], + # 场景1: 2023Q3 报告,正常公告 + # 场景2: 同日多版(update_flag 区分) + # 场景3: 隔日修改 + "f_ann_date": [ + date(2024, 1, 2), + date(2024, 1, 2), + date(2024, 1, 5), + date(2024, 1, 10), + ], + "end_date": ["20230930", "20230930", "20230930", "20231231"], + "report_type": [1, 1, 1, 1], # 整数类型(与数据库一致) + "update_flag": [0, 1, 1, 1], # 整数类型(与数据库一致) + "net_profit": [1000000.0, 1100000.0, 1100000.0, 1200000.0], + "revenue": [5000000.0, 5200000.0, 5200000.0, 6000000.0], + } + ) + + +def test_financial_data_cleaning(): + """测试财务数据清洗逻辑。""" + print("=== 测试 1: 财务数据清洗 ===") + + df_finance = create_mock_financial_data() + print("原始财务数据:") + print(df_finance) + + loader = FinancialLoader() + + # 手动执行清洗(模拟 load_financial_data 的逻辑) + # 步骤1: 仅保留合并报表 + df = df_finance.filter(pl.col("report_type") == 1) + + # 步骤2: 按 update_flag 降序排列后去重 + df = df.with_columns( + [pl.col("update_flag").cast(pl.Int32).alias("update_flag_int")] + ) + + df = df.sort( + ["ts_code", "f_ann_date", "update_flag_int"], descending=[False, False, True] + ) + + df = df.unique(subset=["ts_code", "f_ann_date"], keep="first") + df = df.drop("update_flag_int") + + # 步骤3: 排序(f_ann_date 已经是 Date 类型) + df = df.sort(["ts_code", "f_ann_date"]) + + print("\n清洗后的财务数据:") + print(df) + + # 验证:应该有3条记录(第1-2行去重为1条,第3行,第4行) + assert len(df) == 3, f"清洗后应该有3条记录,实际有 {len(df)} 条" + + # 验证:2024-01-02 的 update_flag 应该是 1 + row_jan02 = df.filter(pl.col("f_ann_date") == date(2024, 1, 2)) + assert len(row_jan02) == 1, "应该有1条 2024-01-02 的记录" + assert row_jan02["update_flag"][0] == 1, "update_flag 应该为 1" + assert row_jan02["net_profit"][0] == 1100000.0, "net_profit 应该为 1100000" + + print("\n[通过] 财务数据清洗测试通过!") + return df + + +def test_financial_price_merge(): + """测试财务数据拼接逻辑(无未来函数验证)。""" + print("\n=== 测试 2: 财务数据与行情数据拼接 ===") + + df_price = create_mock_price_data() + df_finance_raw = create_mock_financial_data() + + loader = FinancialLoader() + + # 步骤1: 清洗财务数据(手动执行) + # 注意:f_ann_date 已经是 Date 类型,不需要转换 + df_finance = df_finance_raw.filter(pl.col("report_type") == 1) + df_finance = df_finance.with_columns( + [pl.col("update_flag").cast(pl.Int32).alias("update_flag_int")] + ) + df_finance = df_finance.sort( + ["ts_code", "f_ann_date", "update_flag_int"], descending=[False, False, True] + ) + df_finance = df_finance.unique(subset=["ts_code", "f_ann_date"], keep="first") + df_finance = df_finance.drop("update_flag_int") + df_finance = df_finance.sort(["ts_code", "f_ann_date"]) + + print("清洗后的财务数据:") + print(df_finance) + + # 步骤2: 转换行情数据日期为 Date 类型 + df_price = df_price.with_columns( + [pl.col("trade_date").str.strptime(pl.Date, "%Y%m%d").alias("trade_date")] + ) + df_price = df_price.sort(["ts_code", "trade_date"]) + + # 步骤3: 拼接 + financial_cols = ["net_profit", "revenue"] + merged = loader.merge_financial_with_price(df_price, df_finance, financial_cols) + + # 步骤4: 转回字符串格式 + merged = merged.with_columns( + [pl.col("trade_date").dt.strftime("%Y%m%d").alias("trade_date")] + ) + + print("\n拼接结果:") + print(merged) + + # 验证无未来函数: + # 20240101 之前不应有 2023Q3 数据(因为 20240102 才公告) + jan01 = merged.filter(pl.col("trade_date") == "20240101") + assert jan01["net_profit"].is_null().all(), ( + "2024-01-01 不应有 2023Q3 数据(尚未公告)" + ) + print("[验证 1] 2024-01-01 net_profit 为 null - 正确(公告前无数据)") + + # 20240102 及之后应该看到 net_profit=1100000(update_flag=1 的版本) + jan02 = merged.filter(pl.col("trade_date") == "20240102") + assert jan02["net_profit"][0] == 1100000.0, "2024-01-02 应使用 update_flag=1 的数据" + print("[验证 2] 2024-01-02 net_profit=1100000 - 正确(使用 update_flag=1)") + + # 20240104 应延续使用 2023Q3 数据 + jan04 = merged.filter(pl.col("trade_date") == "20240104") + assert jan04["net_profit"][0] == 1100000.0, "2024-01-04 应延续使用 2023Q3 数据" + print("[验证 3] 2024-01-04 net_profit=1100000 - 正确(延续使用)") + + # 20240110 应切换到 2023Q4 数据(新公告) + jan10 = merged.filter(pl.col("trade_date") == "20240110") + assert jan10["net_profit"][0] == 1200000.0, "2024-01-10 应切换到 2023Q4 数据" + print("[验证 4] 2024-01-10 net_profit=1200000 - 正确(新财报公告)") + + # 20240112 应继续延续使用 2023Q4 数据 + jan12 = merged.filter(pl.col("trade_date") == "20240112") + assert jan12["net_profit"][0] == 1200000.0, "2024-01-12 应继续使用 2023Q4 数据" + print("[验证 5] 2024-01-12 net_profit=1200000 - 正确(延续使用)") + + print("\n[通过] 所有验证通过,无未来函数!") + return merged + + +def test_empty_financial_data(): + """测试财务数据为空的情况。""" + print("\n=== 测试 3: 空财务数据场景 ===") + + df_price = create_mock_price_data() + df_empty = pl.DataFrame() + + loader = FinancialLoader() + + # 转换行情数据日期为 Date 类型 + df_price = df_price.with_columns( + [pl.col("trade_date").str.strptime(pl.Date, "%Y%m%d").alias("trade_date")] + ) + df_price = df_price.sort(["ts_code", "trade_date"]) + + # 拼接空财务数据 + merged = loader.merge_financial_with_price(df_price, df_empty, ["net_profit"]) + + # 转回字符串格式 + merged = merged.with_columns( + [pl.col("trade_date").dt.strftime("%Y%m%d").alias("trade_date")] + ) + + # 验证财务列为空 + assert merged["net_profit"].is_null().all(), ( + "财务数据为空时,net_profit 应全为 null" + ) + + print("空财务数据拼接结果:") + print(merged) + print("\n[通过] 空财务数据场景测试通过!") + + +def run_all_tests(): + """运行所有测试。""" + print("开始运行财务数据拼接功能测试...\n") + print("=" * 60) + + try: + # 测试 1: 数据清洗 + test_financial_data_cleaning() + + # 测试 2: 数据拼接 + test_financial_price_merge() + + # 测试 3: 空数据场景 + test_empty_financial_data() + + print("\n" + "=" * 60) + print("所有测试通过!") + print("=" * 60) + + except AssertionError as e: + print(f"\n[失败] 测试断言失败: {e}") + raise + except Exception as e: + print(f"\n[错误] 测试执行出错: {e}") + raise + + +if __name__ == "__main__": + run_all_tests()