ProStock/src/data/API_INTERFACE_SPEC.md

ProStock Data Interface Wrapper Specification

1. Overview

This document defines the standard for adding new Tushare API interface wrappers under the src/data/ directory. All non-special interfaces (factor and fundamental data) must follow this specification to ensure:

  • a uniform code style
  • automatic sync support
  • consistent incremental-update logic
  • reduced storage write pressure

2. Interface Categories

2.1 Special interfaces (excluded from unified sync)

The following interfaces have their own synchronization logic and do not take part in the automatic sync mechanism defined in this document:

| Interface type | Example | Notes |
| --- | --- | --- |
| Trade calendar | trade_cal | Global data, fetched by date range |
| Stock basics | stock_basic | One-off full fetch, stored as CSV |
| Auxiliary data | industry / concept classifications | Low-frequency updates, managed independently |

2.2 Standard interfaces (must follow this specification)

All factor data, market quotes, financial data, etc. fetched per stock or per date must follow this specification.

3. File Structure

3.1 File naming

{data_type}.py

Examples:

  • daily.py - daily quotes
  • moneyflow.py - money flow
  • limit_list.py - limit up/down data
  • stk_holdernumber.py - shareholder counts

3.2 File location

src/data/
├── __init__.py          # exports the public interface
├── client.py            # TushareClient (existing)
├── config.py            # configuration management (existing)
├── storage.py           # storage management (existing)
├── rate_limiter.py      # rate limiting (existing)
├── trade_cal.py         # trade calendar (special interface)
├── stock_basic.py       # stock basics (special interface)
├── daily.py             # daily quotes (reference example)
└── {new_data_type}.py   # new interface file

4. Interface Design Specification

4.1 Data-fetch functions

4.1.1 Stock-based interfaces

Applies to: daily quotes, minute bars, money flow, etc.

def get_{data_type}(
    ts_code: str,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    # other optional parameters...
) -> pd.DataFrame:
    """Fetch {data description}.

    Args:
        ts_code: Stock code (e.g. '000001.SZ')
        start_date: Start date (YYYYMMDD)
        end_date: End date (YYYYMMDD)
        # other parameter docs...

    Returns:
        pd.DataFrame with the following columns:
        - ts_code: stock code
        - trade_date: trade date
        # other columns...

    Example:
        >>> data = get_{data_type}('000001.SZ', start_date='20240101', end_date='20240131')
    """
    client = TushareClient()
    
    params = {"ts_code": ts_code}
    if start_date:
        params["start_date"] = start_date
    if end_date:
        params["end_date"] = end_date
    # other parameters...
    
    data = client.query("{api_name}", **params)
    return data

4.1.2 Date-based interfaces

Applies to: daily limit up/down lists, daily dragon-tiger (top) lists, daily chip distribution, etc.

def get_{data_type}(
    trade_date: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    ts_code: Optional[str] = None,
    # other optional parameters...
) -> pd.DataFrame:
    """Fetch {data description}.

    **Prefer fetching by date** (recommended):
    - use trade_date to fetch full-market data for a single day
    - or use start_date + end_date to fetch a date range

    Args:
        trade_date: Trade date (YYYYMMDD); fetches full-market data for one day
        start_date: Start date (YYYYMMDD)
        end_date: End date (YYYYMMDD)
        ts_code: Stock code (optional, to filter a specific stock)
        # other parameter docs...

    Returns:
        pd.DataFrame with the following columns:
        - ts_code: stock code
        - trade_date: trade date
        # other columns...

    Example:
        >>> # full-market data for one day (recommended)
        >>> data = get_{data_type}(trade_date='20240115')
        >>> # date-range data
        >>> data = get_{data_type}(start_date='20240101', end_date='20240131')
    """
    client = TushareClient()
    
    params = {}
    if trade_date:
        params["trade_date"] = trade_date
    if start_date:
        params["start_date"] = start_date
    if end_date:
        params["end_date"] = end_date
    if ts_code:
        params["ts_code"] = ts_code
    # other parameters...
    
    data = client.query("{api_name}", **params)
    return data

4.2 Key design principles

4.2.1 Prefer fetching by date

It is strongly recommended to implement the date-based interface first:

  1. Higher efficiency: one request returns the whole market for a day
  2. Fewer API calls: N days = N calls, rather than N days × M stocks
  3. Better fit for incremental updates: check local data day by day and fetch only the missing dates
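Point 3 can be sketched with a small helper; `missing_dates` and the sample dates below are illustrative, not part of the project API:

```python
# Illustrative sketch: with date-based fetching, incremental sync reduces to
# "which calendar days are not in local storage yet?". Helper name and sample
# dates are hypothetical.
def missing_dates(local_dates: set, calendar_dates: list) -> list:
    """Return the trading days not yet present locally, in calendar order."""
    return [d for d in calendar_dates if d not in local_dates]

calendar = ["20240115", "20240116", "20240117"]
local = {"20240115"}
print(missing_dates(local, calendar))  # → ['20240116', '20240117']
```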

4.2.2 Unified date field

  • Always use trade_date as the date column name
  • Date format: YYYYMMDD string
  • If the API returns a different column name (e.g. date, end_date), rename it to trade_date before returning
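A normalization helper along these lines keeps the rename logic in one place (the helper name and alias list are assumptions; extend them per API):

```python
import pandas as pd

# Hypothetical helper for the convention above: rename a known date-column
# alias to the canonical trade_date. Extend the alias tuple per wrapped API.
def normalize_trade_date(df: pd.DataFrame) -> pd.DataFrame:
    for alias in ("date", "end_date"):
        if alias in df.columns and "trade_date" not in df.columns:
            return df.rename(columns={alias: "trade_date"})
    return df

df = pd.DataFrame({"end_date": ["20240131"], "holder_num": [350000]})
print(normalize_trade_date(df).columns.tolist())  # → ['trade_date', 'holder_num']
```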

4.2.3 Stock code field

  • Always use ts_code as the stock-code column name
  • Format: {code}.{exchange}, e.g. 000001.SZ, 600000.SH
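A quick validity check for this convention might look like the following (the exchange-suffix list is an assumption; extend it if more boards are covered):

```python
import re

# Hypothetical validator for the ts_code format {code}.{exchange}.
TS_CODE_RE = re.compile(r"^\d{6}\.(SZ|SH|BJ)$")

def is_valid_ts_code(code: str) -> bool:
    return bool(TS_CODE_RE.match(code))

print(is_valid_ts_code("000001.SZ"))  # → True
print(is_valid_ts_code("600000"))    # → False
```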

5. Sync Integration

5.1 Registering a new data type in sync.py

Add sync support for the new data type in the DataSync class:

class DataSync:
    """Data synchronization manager with full/incremental sync support."""

    DEFAULT_MAX_WORKERS = 10
    
    # Data type configuration
    DATASET_CONFIG = {
        "daily": {
            "api_name": "pro_bar",
            "fetch_by": "stock",  # fetch by stock
            "date_field": "trade_date",
            "key_fields": ["ts_code", "trade_date"],
        },
        "moneyflow": {
            "api_name": "moneyflow",
            "fetch_by": "stock",  # fetch by stock
            "date_field": "trade_date",
            "key_fields": ["ts_code", "trade_date"],
        },
        "limit_list": {
            "api_name": "limit_list",
            "fetch_by": "date",  # fetch by date (preferred)
            "date_field": "trade_date",
            "key_fields": ["ts_code", "trade_date"],
        },
        # New data types...
        "{new_data_type}": {
            "api_name": "{tushare_api_name}",
            "fetch_by": "date",  # "date" or "stock"
            "date_field": "trade_date",
            "key_fields": ["ts_code", "trade_date"],  # primary key used for deduplication
        },
    }

5.2 Implementing the sync methods

5.2.1 Date-based sync method (recommended)

def sync_by_date(
    self,
    dataset_name: str,
    start_date: str,
    end_date: str,
) -> pd.DataFrame:
    """Sync data by date (fetch all stocks for each date).

    This is the RECOMMENDED approach for date-based data like:
    - limit_list (limit up/down)
    - top_list (dragon-tiger list)
    - cyq_perf (chip distribution)

    Args:
        dataset_name: Name of the dataset in DATASET_CONFIG
        start_date: Start date (YYYYMMDD)
        end_date: End date (YYYYMMDD)

    Returns:
        Combined DataFrame with all data
    """
    from src.data.trade_cal import get_trading_days
    
    config = self.DATASET_CONFIG[dataset_name]
    api_name = config["api_name"]
    date_field = config["date_field"]
    
    # Get trading days in the range
    trading_days = get_trading_days(start_date, end_date)
    if not trading_days:
        print(f"[DataSync] No trading days in range {start_date} to {end_date}")
        return pd.DataFrame()
    
    print(f"[DataSync] Fetching {dataset_name} for {len(trading_days)} trading days")
    
    # Project convention: the flag is SET while a sync is running and is
    # cleared to request a stop (or on error). Set it so the loop runs.
    self._stop_flag.set()
    
    all_data = []
    for trade_date in tqdm(trading_days, desc=f"Syncing {dataset_name}"):
        if not self._stop_flag.is_set():  # a stop was requested
            break
        
        try:
            data = self.client.query(
                api_name,
                trade_date=trade_date,
            )
            if not data.empty:
                all_data.append(data)
        except Exception as e:
            self._stop_flag.clear()
            print(f"[ERROR] Failed to fetch {dataset_name} for {trade_date}: {e}")
            raise
    
    if not all_data:
        return pd.DataFrame()
    
    # Combine all data
    combined = pd.concat(all_data, ignore_index=True)
    
    # Ensure the date field name matches the config
    if date_field not in combined.columns and "trade_date" in combined.columns:
        combined = combined.rename(columns={"trade_date": date_field})
    
    return combined

5.2.2 Stock-based sync method

def sync_by_stock(
    self,
    dataset_name: str,
    ts_code: str,
    start_date: str,
    end_date: str,
) -> pd.DataFrame:
    """Sync data by stock (fetch all dates for each stock).

    Use this for stock-based data like:
    - daily (daily quotes)
    - moneyflow (money flow)
    - stk_holdernumber (shareholder counts)

    Args:
        dataset_name: Name of the dataset in DATASET_CONFIG
        ts_code: Stock code
        start_date: Start date (YYYYMMDD)
        end_date: End date (YYYYMMDD)

    Returns:
        DataFrame with data for the stock
    """
    config = self.DATASET_CONFIG[dataset_name]
    api_name = config["api_name"]
    
    # Bail out if a stop was requested (the flag is cleared on stop/error)
    if not self._stop_flag.is_set():
        return pd.DataFrame()
    
    try:
        data = self.client.query(
            api_name,
            ts_code=ts_code,
            start_date=start_date,
            end_date=end_date,
        )
        return data
    except Exception as e:
        self._stop_flag.clear()
        print(f"[ERROR] Exception syncing {dataset_name} for {ts_code}: {e}")
        raise
5.3 Incremental update logic

5.3.1 Generic incremental sync check

def check_incremental_sync(
    self,
    dataset_name: str,
    force_full: bool = False,
) -> tuple[bool, Optional[str], Optional[str], Optional[str]]:
    """Check if incremental sync is needed for a dataset.

    Args:
        dataset_name: Name of the dataset
        force_full: If True, force full sync

    Returns:
        Tuple of (sync_needed, start_date, end_date, local_last_date)
    """
    config = self.DATASET_CONFIG[dataset_name]
    date_field = config["date_field"]
    
    # If force_full, always sync from default start
    if force_full:
        print(f"[DataSync] Force full sync for {dataset_name}")
        return (True, DEFAULT_START_DATE, get_today_date(), None)
    
    # Check local data
    local_data = self.storage.load(dataset_name)
    if local_data.empty or date_field not in local_data.columns:
        print(f"[DataSync] No local {dataset_name} data, full sync needed")
        return (True, DEFAULT_START_DATE, get_today_date(), None)
    
    # Get local last date
    local_last_date = str(local_data[date_field].max())
    print(f"[DataSync] Local {dataset_name} last date: {local_last_date}")
    
    # Get calendar last trading day
    today = get_today_date()
    _, cal_last = self.get_trade_calendar_bounds(DEFAULT_START_DATE, today)
    
    if cal_last is None:
        print(f"[DataSync] Failed to get trade calendar, proceeding with sync")
        return (True, DEFAULT_START_DATE, today, local_last_date)
    
    print(f"[DataSync] Calendar last trading day: {cal_last}")
    
    # Compare dates
    if int(local_last_date) >= int(cal_last):
        print(f"[DataSync] {dataset_name} is up-to-date, skipping sync")
        return (False, None, None, None)
    
    # Need incremental sync
    sync_start = get_next_date(local_last_date)
    print(f"[DataSync] Incremental sync for {dataset_name} from {sync_start} to {cal_last}")
    return (True, sync_start, cal_last, local_last_date)

5.3.2 Full sync entry point

def sync_dataset(
    self,
    dataset_name: str,
    force_full: bool = False,
    max_workers: Optional[int] = None,
) -> pd.DataFrame:
    """Sync a dataset with automatic incremental update.

    This is the main entry point for syncing any dataset.

    Args:
        dataset_name: Name of the dataset in DATASET_CONFIG
        force_full: If True, force full reload
        max_workers: Number of worker threads (for stock-based sync)

    Returns:
        DataFrame with synced data
    """
    print("\n" + "=" * 60)
    print(f"[DataSync] Starting {dataset_name} sync...")
    print("=" * 60)
    
    # Ensure trade calendar is up-to-date
    sync_trade_cal_cache()
    
    # Check if sync is needed
    sync_needed, start_date, end_date, local_last = self.check_incremental_sync(
        dataset_name, force_full
    )
    
    if not sync_needed:
        print(f"[DataSync] {dataset_name} is up-to-date, skipping")
        return pd.DataFrame()
    
    config = self.DATASET_CONFIG[dataset_name]
    fetch_by = config["fetch_by"]
    
    # Fetch data based on strategy
    if fetch_by == "date":
        # Fetch by date (all stocks per day)
        data = self.sync_by_date(dataset_name, start_date, end_date)
    else:
        # Fetch by stock (all dates per stock)
        data = self._sync_all_stocks(dataset_name, start_date, end_date, max_workers)
    
    if data.empty:
        print(f"[DataSync] No new data for {dataset_name}")
        return pd.DataFrame()
    
    # Save to storage (single write)
    self.storage.save(dataset_name, data, mode="append")
    
    print(f"[DataSync] Synced {len(data)} rows for {dataset_name}")
    return data

def _sync_all_stocks(
    self,
    dataset_name: str,
    start_date: str,
    end_date: str,
    max_workers: Optional[int] = None,
) -> pd.DataFrame:
    """Sync data for all stocks (stock-based fetch)."""
    stock_codes = self.get_all_stock_codes()
    if not stock_codes:
        return pd.DataFrame()
    
    print(f"[DataSync] Syncing {dataset_name} for {len(stock_codes)} stocks")
    
    self._stop_flag.set()  # convention: flag set = running
    results = []
    
    workers = max_workers or self.max_workers
    with ThreadPoolExecutor(max_workers=workers) as executor:
        future_to_code = {
            executor.submit(
                self.sync_by_stock, dataset_name, ts_code, start_date, end_date
            ): ts_code
            for ts_code in stock_codes
        }
        
        with tqdm(total=len(stock_codes), desc=f"Syncing {dataset_name}") as pbar:
            for future in as_completed(future_to_code):
                try:
                    data = future.result()
                    if not data.empty:
                        results.append(data)
                except Exception as e:
                    executor.shutdown(wait=False, cancel_futures=True)
                    raise
                pbar.update(1)
    
    if not results:
        return pd.DataFrame()
    
    return pd.concat(results, ignore_index=True)

6. Storage Specification

6.1 Using the Storage class

All data is stored in HDF5 via the Storage class:

from src.data.storage import Storage

storage = Storage()

# Save data (automatic incremental merge)
storage.save("dataset_name", data, mode="append")

# Load data
all_data = storage.load("dataset_name")
filtered_data = storage.load("dataset_name", start_date="20240101", end_date="20240131")

# Get the latest date
last_date = storage.get_last_date("dataset_name")

# Check existence
exists = storage.exists("dataset_name")

6.2 Incremental write strategy

Key principle: write all data in a single pass after the requests complete, never row by row:

# ❌ Wrong: write inside the loop (poor performance)
for date in dates:
    data = fetch(date)
    storage.save("dataset", data, mode="append")  # many writes

# ✅ Right: batch write (good performance)
all_data = []
for date in dates:
    data = fetch(date)
    all_data.append(data)
combined = pd.concat(all_data, ignore_index=True)
storage.save("dataset", combined, mode="append")  # single write

6.3 Deduplication strategy

Storage.save() deduplicates automatically, based on the key_fields in the config:

# Implementation inside storage.py
combined = pd.concat([existing, data], ignore_index=True)
combined = combined.drop_duplicates(
    subset=["ts_code", "trade_date"],  # uses key_fields
    keep="last"  # keep the newest data
)
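The effect of keep="last" can be demonstrated standalone; the rows below are made-up sample data:

```python
import pandas as pd

# Made-up sample: the same (ts_code, trade_date) key appears in the existing
# store and in a new batch; after dedup with keep="last", the new row wins.
existing = pd.DataFrame(
    {"ts_code": ["000001.SZ"], "trade_date": ["20240115"], "close": [10.0]}
)
new = pd.DataFrame(
    {"ts_code": ["000001.SZ"], "trade_date": ["20240115"], "close": [10.5]}
)

combined = pd.concat([existing, new], ignore_index=True)
combined = combined.drop_duplicates(subset=["ts_code", "trade_date"], keep="last")
print(len(combined), combined.iloc[0]["close"])  # → 1 10.5
```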

7. Complete Example: Adding a Limit Up/Down Interface

7.1 Create limit_list.py

"""Limit up/down list interface.

Fetch stocks that hit limit up or limit down for a specific trade date.
This is a date-based interface (recommended approach).
"""
import pandas as pd
from typing import Optional
from src.data.client import TushareClient


def get_limit_list(
    trade_date: Optional[str] = None,
    ts_code: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
) -> pd.DataFrame:
    """获取涨跌停数据。

    **优先按日期获取**(推荐):
    - 使用 trade_date 获取单日全市场涨跌停数据
    - 或使用 start_date + end_date 获取区间数据

    Args:
        trade_date: 交易日期YYYYMMDD格式获取单日全市场数据
        ts_code: 股票代码(可选,用于过滤)
        start_date: 开始日期YYYYMMDD格式
        end_date: 结束日期YYYYMMDD格式

    Returns:
        pd.DataFrame 包含以下字段:
        - ts_code: 股票代码
        - trade_date: 交易日期
        - name: 股票名称
        - close: 收盘价
        - pct_chg: 涨跌幅
        - amp: 振幅
        - fc_ratio: 封单金额/日成交额
        - fl_ratio: 封单手数/流通股本
        - fd_amount: 封单金额
        - first_time: 首次涨停时间
        - last_time: 最后封板时间
        - open_times: 打开次数
        - strth: 涨停强度
        - limit: 涨停类型U涨停D跌停

    Example:
        >>> # 获取单日全市场涨跌停数据(推荐)
        >>> data = get_limit_list(trade_date='20240115')
        >>> # 获取区间数据
        >>> data = get_limit_list(start_date='20240101', end_date='20240131')
    """
    client = TushareClient()
    
    params = {}
    if trade_date:
        params["trade_date"] = trade_date
    if ts_code:
        params["ts_code"] = ts_code
    if start_date:
        params["start_date"] = start_date
    if end_date:
        params["end_date"] = end_date
    
    data = client.query("limit_list", **params)
    return data

7.2 Register in sync.py

class DataSync:
    """Data synchronization manager with full/incremental sync support."""

    DATASET_CONFIG = {
        # ... other configs ...
        "limit_list": {
            "api_name": "limit_list",
            "fetch_by": "date",  # 按日期获取
            "date_field": "trade_date",
            "key_fields": ["ts_code", "trade_date"],
        },
    }
    
    # ... other methods ...

    def sync_limit_list(
        self,
        force_full: bool = False,
    ) -> pd.DataFrame:
        """Sync limit list data."""
        return self.sync_dataset("limit_list", force_full)


# Convenience function
def sync_limit_list(force_full: bool = False) -> pd.DataFrame:
    """Sync limit up/down data."""
    sync_manager = DataSync()
    return sync_manager.sync_limit_list(force_full)

7.3 Update __init__.py

from src.data.limit_list import get_limit_list

__all__ = [
    # ... other exports ...
    "get_limit_list",
]

8. Testing Specification

8.1 Test file layout

tests/
├── test_sync.py          # tests for the sync module
├── test_daily.py         # tests for the daily module
└── test_{new_module}.py  # tests for new modules

8.2 Test template

"""Tests for {module_name} module."""
import pytest
from unittest.mock import patch, MagicMock
import pandas as pd
from src.data.{module_name} import get_{data_type}


class Test{DataType}:
    """Test cases for {data_type} data fetching."""

    @patch("src.data.{module_name}.TushareClient")
    def test_get_{data_type}_by_date(self, mock_client_class):
        """Test fetching data by date."""
        # Setup mock
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame({
            "ts_code": ["000001.SZ"],
            "trade_date": ["20240115"],
            # ... other columns ...
        })
        
        # Call function
        result = get_{data_type}(trade_date="20240115")
        
        # Verify
        assert not result.empty
        mock_client.query.assert_called_once_with(
            "{api_name}",
            trade_date="20240115",
        )

    @patch("src.data.{module_name}.TushareClient")
    def test_get_{data_type}_by_stock(self, mock_client_class):
        """Test fetching data by stock code."""
        # Setup mock
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame({
            "ts_code": ["000001.SZ"],
            "trade_date": ["20240115"],
            # ... other columns ...
        })
        
        # Call function
        result = get_{data_type}(
            ts_code="000001.SZ",
            start_date="20240101",
            end_date="20240131",
        )
        
        # Verify
        assert not result.empty
        mock_client.query.assert_called_once()

9. Checklist

Before submitting a new interface, confirm the following:

9.1 File structure

  • The file lives at src/data/{data_type}.py
  • src/data/__init__.py has been updated to export the public interface
  • tests/test_{data_type}.py has been created

9.2 Interface implementation

  • The fetch function uses TushareClient
  • The function has a complete Google-style docstring
  • Date parameters use the YYYYMMDD format
  • The returned DataFrame contains ts_code and trade_date columns
  • A date-based interface is implemented first (if the API supports it)

9.3 Sync integration

  • Registered in DataSync.DATASET_CONFIG
  • fetch_by is set correctly ("date" or "stock")
  • date_field and key_fields are set correctly
  • The corresponding sync method is implemented, or a generic method is reused
  • The incremental-update logic is correct (checks the latest local date)

9.4 Storage optimization

  • All data is written in a single pass (not row by row)
  • storage.save(mode="append") is used for incremental saves
  • The deduplication key is configured correctly

9.5 Tests

  • Unit tests are written
  • TushareClient is mocked
  • Tests cover both normal and error cases

10. FAQ

Q1: What if the API's date column is not named trade_date?

Rename it before returning:

data = client.query("api_name", **params)
if "end_date" in data.columns:
    data = data.rename(columns={"end_date": "trade_date"})
return data

Q2: How do I handle pagination (limit/offset)?

Tushare Pro APIs usually do not require manual pagination, but when one does:

all_data = []
offset = 0
limit = 5000

while True:
    data = client.query(
        "api_name",
        trade_date=trade_date,
        limit=limit,
        offset=offset,
    )
    if not data.empty:
        all_data.append(data)
    if data.empty or len(data) < limit:  # last (or empty) page
        break
    offset += limit

return pd.concat(all_data, ignore_index=True) if all_data else pd.DataFrame()

Q3: How do I handle interfaces that need extra parameters?

Add them to the function signature and pass them through to client.query:

def get_data(
    ts_code: str,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    fields: Optional[list] = None,  # extra parameter
) -> pd.DataFrame:
    client = TushareClient()  # instantiate the client
    params = {"ts_code": ts_code}
    if start_date:
        params["start_date"] = start_date
    if end_date:
        params["end_date"] = end_date
    if fields:
        params["fields"] = ",".join(fields)
    
    return client.query("api_name", **params)

Q4: How do I handle data without a trade_date column?

If the data genuinely has no date column (e.g. static data), either:

  1. Classify it as a "special interface" and manage it independently
  2. Or add a sync_date column recording when the data was synced
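Option 2 is a one-liner with pandas; sync_date is the column name suggested above, and the sample frame is made up:

```python
import pandas as pd
from datetime import date

# Option 2 from above: stamp static data with the sync date (YYYYMMDD) so
# later runs can tell when it was last refreshed. Sample data is made up.
data = pd.DataFrame({"ts_code": ["000001.SZ"], "industry": ["Banking"]})
data["sync_date"] = date.today().strftime("%Y%m%d")
print(data.columns.tolist())  # → ['ts_code', 'industry', 'sync_date']
```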

Q5: What if fetching by date is preferred but the API does not support it?

If the API only supports fetching by stock:

  1. Set fetch_by: "stock" in DATASET_CONFIG
  2. Sync it with the _sync_all_stocks method
  3. Document that this is a stock-based interface

Last updated: 2026-02-01