diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..c7b23ec --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +pythonpath = . +testpaths = tests diff --git a/src/data/API_INTERFACE_SPEC.md b/src/data/API_INTERFACE_SPEC.md deleted file mode 100644 index 5280510..0000000 --- a/src/data/API_INTERFACE_SPEC.md +++ /dev/null @@ -1,847 +0,0 @@ -# ProStock 数据接口封装规范 - -## 1. 概述 - -本文档定义了在 `src/data/` 目录下新增 Tushare API 接口封装的标准规范。所有非特殊接口(因子和基础数据)必须遵循此规范,以确保: -- 代码风格统一 -- 自动 sync 支持 -- 增量更新逻辑一致 -- 减少存储写入压力 - -## 2. 接口分类 - -### 2.1 特殊接口(不参与统一 sync) - -以下接口有独立的同步逻辑,不参与本文档定义的自动 sync 机制: - -| 接口类型 | 示例 | 说明 | -|---------|------|------| -| 交易日历 | `trade_cal` | 全局数据,按日期范围获取 | -| 股票基础信息 | `stock_basic` | 一次性全量获取,CSV 存储 | -| 辅助数据 | 行业分类、概念分类 | 低频更新,独立管理 | - -### 2.2 标准接口(必须遵循本规范) - -所有**按股票**或**按日期**获取的因子数据、行情数据、财务数据等,必须遵循本规范。 - -## 3. 文件结构 - -### 3.1 文件命名 - -``` -{data_type}.py -``` - -示例: -- `daily.py` - 日线行情 -- `moneyflow.py` - 资金流向 -- `limit_list.py` - 涨跌停数据 -- `stk_holdernumber.py` - 股东人数 - -### 3.2 文件位置 - -``` -src/data/ -├── __init__.py # 导出公共接口 -├── client.py # TushareClient(已有) -├── config.py # 配置管理(已有) -├── storage.py # 存储管理(已有) -├── rate_limiter.py # 速率限制(已有) -├── trade_cal.py # 交易日历(特殊接口) -├── stock_basic.py # 股票基础(特殊接口) -├── daily.py # 日线行情(参考示例) -└── {new_data_type}.py # 新增接口文件 -``` - -## 4. 接口设计规范 - -### 4.1 数据获取函数 - -#### 4.1.1 按股票获取的接口 - -适用于:日线行情、分钟线、资金流向等 - -```python -def get_{data_type}( - ts_code: str, - start_date: Optional[str] = None, - end_date: Optional[str] = None, - # 其他可选参数... -) -> pd.DataFrame: - """获取 {数据描述}。 - - Args: - ts_code: 股票代码(如 '000001.SZ') - start_date: 开始日期(YYYYMMDD格式) - end_date: 结束日期(YYYYMMDD格式) - # 其他参数说明... - - Returns: - pd.DataFrame 包含以下字段: - - ts_code: 股票代码 - - trade_date: 交易日期 - # 其他字段... 
- - Example: - >>> data = get_{data_type}('000001.SZ', start_date='20240101', end_date='20240131') - """ - client = TushareClient() - - params = {"ts_code": ts_code} - if start_date: - params["start_date"] = start_date - if end_date: - params["end_date"] = end_date - # 其他参数... - - data = client.query("{api_name}", **params) - return data -``` - -#### 4.1.2 按日期获取的接口 - -适用于:每日涨跌停、每日龙虎榜、每日筹码分布等 - -```python -def get_{data_type}( - trade_date: Optional[str] = None, - start_date: Optional[str] = None, - end_date: Optional[str] = None, - ts_code: Optional[str] = None, - # 其他可选参数... -) -> pd.DataFrame: - """获取 {数据描述}。 - - **优先按日期获取**(推荐): - - 使用 trade_date 获取单日全市场数据 - - 或使用 start_date + end_date 获取区间数据 - - Args: - trade_date: 交易日期(YYYYMMDD格式),获取单日全市场数据 - start_date: 开始日期(YYYYMMDD格式) - end_date: 结束日期(YYYYMMDD格式) - ts_code: 股票代码(可选,用于过滤特定股票) - # 其他参数说明... - - Returns: - pd.DataFrame 包含以下字段: - - ts_code: 股票代码 - - trade_date: 交易日期 - # 其他字段... - - Example: - >>> # 获取单日全市场数据(推荐) - >>> data = get_{data_type}(trade_date='20240115') - >>> # 获取区间数据 - >>> data = get_{data_type}(start_date='20240101', end_date='20240131') - """ - client = TushareClient() - - params = {} - if trade_date: - params["trade_date"] = trade_date - if start_date: - params["start_date"] = start_date - if end_date: - params["end_date"] = end_date - if ts_code: - params["ts_code"] = ts_code - # 其他参数... - - data = client.query("{api_name}", **params) - return data -``` - -### 4.2 关键设计原则 - -#### 4.2.1 优先按日期获取 - -**强烈建议**优先实现按日期获取的接口: - -1. **效率更高**:一次请求获取全市场数据 -2. **API 调用更少**:N 天 = N 次调用,而非 N 天 × M 只股票 -3. **更适合增量更新**:按天检查本地数据,只获取缺失日期 - -#### 4.2.2 日期字段统一 - -- 统一使用 `trade_date` 作为日期字段名 -- 日期格式:`YYYYMMDD` 字符串 -- 如果 API 返回其他字段名(如 `date`、`end_date`),在返回前重命名为 `trade_date` - -#### 4.2.3 股票代码字段 - -- 统一使用 `ts_code` 作为股票代码字段名 -- 格式:`{code}.{exchange}`,如 `000001.SZ`、`600000.SH` - -## 5. 
Sync 集成规范 - -### 5.1 在 sync.py 中注册新数据类型 - -在 `DataSync` 类中添加新数据类型的同步支持: - -```python -class DataSync: - """Data synchronization manager with full/incremental sync support.""" - - DEFAULT_MAX_WORKERS = 10 - - # 数据类型配置 - DATASET_CONFIG = { - "daily": { - "api_name": "pro_bar", - "fetch_by": "stock", # 按股票获取 - "date_field": "trade_date", - "key_fields": ["ts_code", "trade_date"], - }, - "moneyflow": { - "api_name": "moneyflow", - "fetch_by": "stock", # 按股票获取 - "date_field": "trade_date", - "key_fields": ["ts_code", "trade_date"], - }, - "limit_list": { - "api_name": "limit_list", - "fetch_by": "date", # 按日期获取(优先) - "date_field": "trade_date", - "key_fields": ["ts_code", "trade_date"], - }, - # 新增数据类型... - "{new_data_type}": { - "api_name": "{tushare_api_name}", - "fetch_by": "date", # "date" 或 "stock" - "date_field": "trade_date", - "key_fields": ["ts_code", "trade_date"], # 用于去重的主键 - }, - } -``` - -### 5.2 实现同步方法 - -#### 5.2.1 按日期获取的同步方法(推荐) - -```python -def sync_by_date( - self, - dataset_name: str, - start_date: str, - end_date: str, -) -> pd.DataFrame: - """Sync data by date (fetch all stocks for each date). 
- - This is the RECOMMENDED approach for date-based data like: - - limit_list (涨跌停) - - top_list (龙虎榜) - - cyq_perf (筹码分布) - - Args: - dataset_name: Name of the dataset in DATASET_CONFIG - start_date: Start date (YYYYMMDD) - end_date: End date (YYYYMMDD) - - Returns: - Combined DataFrame with all data - """ - from src.data.trade_cal import get_trading_days - - config = self.DATASET_CONFIG[dataset_name] - api_name = config["api_name"] - date_field = config["date_field"] - - # Get trading days in the range - trading_days = get_trading_days(start_date, end_date) - if not trading_days: - print(f"[DataSync] No trading days in range {start_date} to {end_date}") - return pd.DataFrame() - - print(f"[DataSync] Fetching {dataset_name} for {len(trading_days)} trading days") - - all_data = [] - error_occurred = False - - for trade_date in tqdm(trading_days, desc=f"Syncing {dataset_name}"): - if not self._stop_flag.is_set(): - break - - try: - data = self.client.query( - api_name, - trade_date=trade_date, - ) - if not data.empty: - all_data.append(data) - except Exception as e: - self._stop_flag.clear() - error_occurred = True - print(f"[ERROR] Failed to fetch {dataset_name} for {trade_date}: {e}") - raise - - if error_occurred or not all_data: - return pd.DataFrame() - - # Combine all data - combined = pd.concat(all_data, ignore_index=True) - - # Ensure date field is consistent - if date_field not in combined.columns and "trade_date" in combined.columns: - combined = combined.rename(columns={"trade_date": date_field}) - - return combined -``` - -#### 5.2.2 按股票获取的同步方法 - -```python -def sync_by_stock( - self, - dataset_name: str, - ts_code: str, - start_date: str, - end_date: str, -) -> pd.DataFrame: - """Sync data by stock (fetch all dates for each stock). 
- - Use this for stock-based data like: - - daily (日线行情) - - moneyflow (资金流向) - - stk_holdernumber (股东人数) - - Args: - dataset_name: Name of the dataset in DATASET_CONFIG - ts_code: Stock code - start_date: Start date (YYYYMMDD) - end_date: End date (YYYYMMDD) - - Returns: - DataFrame with data for the stock - """ - config = self.DATASET_CONFIG[dataset_name] - api_name = config["api_name"] - - if not self._stop_flag.is_set(): - return pd.DataFrame() - - try: - data = self.client.query( - api_name, - ts_code=ts_code, - start_date=start_date, - end_date=end_date, - ) - return data - except Exception as e: - self._stop_flag.clear() - print(f"[ERROR] Exception syncing {dataset_name} for {ts_code}: {e}") - raise -``` - -### 5.3 增量更新逻辑 - -#### 5.3.1 通用增量更新检查 - -```python -def check_incremental_sync( - self, - dataset_name: str, - force_full: bool = False, -) -> tuple[bool, Optional[str], Optional[str], Optional[str]]: - """Check if incremental sync is needed for a dataset. - - Args: - dataset_name: Name of the dataset - force_full: If True, force full sync - - Returns: - Tuple of (sync_needed, start_date, end_date, local_last_date) - """ - config = self.DATASET_CONFIG[dataset_name] - date_field = config["date_field"] - - # If force_full, always sync from default start - if force_full: - print(f"[DataSync] Force full sync for {dataset_name}") - return (True, DEFAULT_START_DATE, get_today_date(), None) - - # Check local data - local_data = self.storage.load(dataset_name) - if local_data.empty or date_field not in local_data.columns: - print(f"[DataSync] No local {dataset_name} data, full sync needed") - return (True, DEFAULT_START_DATE, get_today_date(), None) - - # Get local last date - local_last_date = str(local_data[date_field].max()) - print(f"[DataSync] Local {dataset_name} last date: {local_last_date}") - - # Get calendar last trading day - today = get_today_date() - _, cal_last = self.get_trade_calendar_bounds(DEFAULT_START_DATE, today) - - if cal_last is None: - 
print(f"[DataSync] Failed to get trade calendar, proceeding with sync") - return (True, DEFAULT_START_DATE, today, local_last_date) - - print(f"[DataSync] Calendar last trading day: {cal_last}") - - # Compare dates - if int(local_last_date) >= int(cal_last): - print(f"[DataSync] {dataset_name} is up-to-date, skipping sync") - return (False, None, None, None) - - # Need incremental sync - sync_start = get_next_date(local_last_date) - print(f"[DataSync] Incremental sync for {dataset_name} from {sync_start} to {cal_last}") - return (True, sync_start, cal_last, local_last_date) -``` - -#### 5.3.2 完整的同步入口 - -```python -def sync_dataset( - self, - dataset_name: str, - force_full: bool = False, - max_workers: Optional[int] = None, -) -> pd.DataFrame: - """Sync a dataset with automatic incremental update. - - This is the main entry point for syncing any dataset. - - Args: - dataset_name: Name of the dataset in DATASET_CONFIG - force_full: If True, force full reload - max_workers: Number of worker threads (for stock-based sync) - - Returns: - DataFrame with synced data - """ - print("\n" + "=" * 60) - print(f"[DataSync] Starting {dataset_name} sync...") - print("=" * 60) - - # Ensure trade calendar is up-to-date - sync_trade_cal_cache() - - # Check if sync is needed - sync_needed, start_date, end_date, local_last = self.check_incremental_sync( - dataset_name, force_full - ) - - if not sync_needed: - print(f"[DataSync] {dataset_name} is up-to-date, skipping") - return pd.DataFrame() - - config = self.DATASET_CONFIG[dataset_name] - fetch_by = config["fetch_by"] - - # Fetch data based on strategy - if fetch_by == "date": - # Fetch by date (all stocks per day) - data = self.sync_by_date(dataset_name, start_date, end_date) - else: - # Fetch by stock (all dates per stock) - data = self._sync_all_stocks(dataset_name, start_date, end_date, max_workers) - - if data.empty: - print(f"[DataSync] No new data for {dataset_name}") - return pd.DataFrame() - - # Save to storage (single 
write) - self.storage.save(dataset_name, data, mode="append") - - print(f"[DataSync] Synced {len(data)} rows for {dataset_name}") - return data - -def _sync_all_stocks( - self, - dataset_name: str, - start_date: str, - end_date: str, - max_workers: Optional[int] = None, -) -> pd.DataFrame: - """Sync data for all stocks (stock-based fetch).""" - stock_codes = self.get_all_stock_codes() - if not stock_codes: - return pd.DataFrame() - - print(f"[DataSync] Syncing {dataset_name} for {len(stock_codes)} stocks") - - self._stop_flag.set() - results = [] - - workers = max_workers or self.max_workers - with ThreadPoolExecutor(max_workers=workers) as executor: - future_to_code = { - executor.submit( - self.sync_by_stock, dataset_name, ts_code, start_date, end_date - ): ts_code - for ts_code in stock_codes - } - - with tqdm(total=len(stock_codes), desc=f"Syncing {dataset_name}") as pbar: - for future in as_completed(future_to_code): - try: - data = future.result() - if not data.empty: - results.append(data) - except Exception as e: - executor.shutdown(wait=False, cancel_futures=True) - raise - pbar.update(1) - - if not results: - return pd.DataFrame() - - return pd.concat(results, ignore_index=True) -``` - -## 6. 
存储规范 - -### 6.1 Storage 类使用 - -所有数据通过 `Storage` 类进行 HDF5 存储: - -```python -from src.data.storage import Storage - -storage = Storage() - -# 保存数据(自动增量合并) -storage.save("dataset_name", data, mode="append") - -# 加载数据 -all_data = storage.load("dataset_name") -filtered_data = storage.load("dataset_name", start_date="20240101", end_date="20240131") - -# 获取最新日期 -last_date = storage.get_last_date("dataset_name") - -# 检查是否存在 -exists = storage.exists("dataset_name") -``` - -### 6.2 增量写入策略 - -**关键原则**:所有数据在请求完成后**一次性写入**,而非逐条写入: - -```python -# ❌ 错误:逐条写入(性能差) -for date in dates: - data = fetch(date) - storage.save("dataset", data, mode="append") # 多次写入 - -# ✅ 正确:批量写入(性能好) -all_data = [] -for date in dates: - data = fetch(date) - all_data.append(data) -combined = pd.concat(all_data, ignore_index=True) -storage.save("dataset", combined, mode="append") # 单次写入 -``` - -### 6.3 去重策略 - -`Storage.save()` 方法会自动去重,基于配置中的 `key_fields`: - -```python -# storage.py 中的实现 -combined = pd.concat([existing, data], ignore_index=True) -combined = combined.drop_duplicates( - subset=["ts_code", "trade_date"], # 使用 key_fields - keep="last" # 保留最新数据 -) -``` - -## 7. 完整示例:新增涨跌停数据接口 - -### 7.1 创建 limit_list.py - -```python -"""Limit up/down list interface. - -Fetch stocks that hit limit up or limit down for a specific trade date. -This is a date-based interface (recommended approach). 
-""" -import pandas as pd -from typing import Optional -from src.data.client import TushareClient - - -def get_limit_list( - trade_date: Optional[str] = None, - ts_code: Optional[str] = None, - start_date: Optional[str] = None, - end_date: Optional[str] = None, -) -> pd.DataFrame: - """获取涨跌停数据。 - - **优先按日期获取**(推荐): - - 使用 trade_date 获取单日全市场涨跌停数据 - - 或使用 start_date + end_date 获取区间数据 - - Args: - trade_date: 交易日期(YYYYMMDD格式),获取单日全市场数据 - ts_code: 股票代码(可选,用于过滤) - start_date: 开始日期(YYYYMMDD格式) - end_date: 结束日期(YYYYMMDD格式) - - Returns: - pd.DataFrame 包含以下字段: - - ts_code: 股票代码 - - trade_date: 交易日期 - - name: 股票名称 - - close: 收盘价 - - pct_chg: 涨跌幅 - - amp: 振幅 - - fc_ratio: 封单金额/日成交额 - - fl_ratio: 封单手数/流通股本 - - fd_amount: 封单金额 - - first_time: 首次涨停时间 - - last_time: 最后封板时间 - - open_times: 打开次数 - - strth: 涨停强度 - - limit: 涨停类型(U涨停D跌停) - - Example: - >>> # 获取单日全市场涨跌停数据(推荐) - >>> data = get_limit_list(trade_date='20240115') - >>> # 获取区间数据 - >>> data = get_limit_list(start_date='20240101', end_date='20240131') - """ - client = TushareClient() - - params = {} - if trade_date: - params["trade_date"] = trade_date - if ts_code: - params["ts_code"] = ts_code - if start_date: - params["start_date"] = start_date - if end_date: - params["end_date"] = end_date - - data = client.query("limit_list", **params) - return data -``` - -### 7.2 在 sync.py 中注册 - -```python -class DataSync: - """Data synchronization manager with full/incremental sync support.""" - - DATASET_CONFIG = { - # ... 其他配置 ... - "limit_list": { - "api_name": "limit_list", - "fetch_by": "date", # 按日期获取 - "date_field": "trade_date", - "key_fields": ["ts_code", "trade_date"], - }, - } - - # ... 其他方法 ... 
- - def sync_limit_list( - self, - force_full: bool = False, - ) -> pd.DataFrame: - """Sync limit list data.""" - return self.sync_dataset("limit_list", force_full) - - -# 便捷函数 -def sync_limit_list(force_full: bool = False) -> pd.DataFrame: - """Sync limit up/down data.""" - sync_manager = DataSync() - return sync_manager.sync_limit_list(force_full) -``` - -### 7.3 更新 __init__.py - -```python -from src.data.limit_list import get_limit_list - -__all__ = [ - # ... 其他导出 ... - "get_limit_list", -] -``` - -## 8. 测试规范 - -### 8.1 测试文件结构 - -``` -tests/ -├── test_sync.py # sync 模块测试 -├── test_daily.py # daily 模块测试 -└── test_{new_module}.py # 新增模块测试 -``` - -### 8.2 测试模板 - -```python -"""Tests for {module_name} module.""" -import pytest -from unittest.mock import patch, MagicMock -import pandas as pd -from src.data.{module_name} import get_{data_type} - - -class Test{DataType}: - """Test cases for {data_type} data fetching.""" - - @patch("src.data.{module_name}.TushareClient") - def test_get_{data_type}_by_date(self, mock_client_class): - """Test fetching data by date.""" - # Setup mock - mock_client = MagicMock() - mock_client_class.return_value = mock_client - mock_client.query.return_value = pd.DataFrame({ - "ts_code": ["000001.SZ"], - "trade_date": ["20240115"], - # ... 其他字段 ... - }) - - # Call function - result = get_{data_type}(trade_date="20240115") - - # Verify - assert not result.empty - mock_client.query.assert_called_once_with( - "{api_name}", - trade_date="20240115", - ) - - @patch("src.data.{module_name}.TushareClient") - def test_get_{data_type}_by_stock(self, mock_client_class): - """Test fetching data by stock code.""" - # Setup mock - mock_client = MagicMock() - mock_client_class.return_value = mock_client - mock_client.query.return_value = pd.DataFrame({ - "ts_code": ["000001.SZ"], - "trade_date": ["20240115"], - # ... 其他字段 ... 
- }) - - # Call function - result = get_{data_type}( - ts_code="000001.SZ", - start_date="20240101", - end_date="20240131", - ) - - # Verify - assert not result.empty - mock_client.query.assert_called_once() -``` - -## 9. 检查清单 - -在提交新接口前,请确认以下事项: - -### 9.1 文件结构 -- [ ] 文件位于 `src/data/{data_type}.py` -- [ ] 已更新 `src/data/__init__.py` 导出公共接口 -- [ ] 已创建 `tests/test_{data_type}.py` 测试文件 - -### 9.2 接口实现 -- [ ] 数据获取函数使用 `TushareClient` -- [ ] 函数包含完整的 Google 风格文档字符串 -- [ ] 日期参数使用 `YYYYMMDD` 格式 -- [ ] 返回的 DataFrame 包含 `ts_code` 和 `trade_date` 字段 -- [ ] 优先实现按日期获取的接口(如果 API 支持) - -### 9.3 Sync 集成 -- [ ] 已在 `DataSync.DATASET_CONFIG` 中注册 -- [ ] 正确设置 `fetch_by`("date" 或 "stock") -- [ ] 正确设置 `date_field` 和 `key_fields` -- [ ] 已实现对应的 sync 方法或复用通用方法 -- [ ] 增量更新逻辑正确(检查本地最新日期) - -### 9.4 存储优化 -- [ ] 所有数据一次性写入(非逐条) -- [ ] 使用 `storage.save(mode="append")` 进行增量保存 -- [ ] 去重字段配置正确 - -### 9.5 测试 -- [ ] 已编写单元测试 -- [ ] 已 mock TushareClient -- [ ] 测试覆盖正常和异常情况 - -## 10. 常见问题 - -### Q1: API 返回的日期字段名不是 trade_date 怎么办? - -在返回前重命名: - -```python -data = client.query("api_name", **params) -if "end_date" in data.columns: - data = data.rename(columns={"end_date": "trade_date"}) -return data -``` - -### Q2: 如何处理分页(limit/offset)? - -Tushare Pro API 通常不需要手动分页,但如果需要: - -```python -all_data = [] -offset = 0 -limit = 5000 - -while True: - data = client.query( - "api_name", - trade_date=trade_date, - limit=limit, - offset=offset, - ) - if data.empty or len(data) < limit: - all_data.append(data) - break - all_data.append(data) - offset += limit - -return pd.concat(all_data, ignore_index=True) -``` - -### Q3: 如何处理需要额外参数的接口? 
- -在函数签名中添加参数,并传递给 client.query: - -```python -def get_data( - ts_code: str, - start_date: Optional[str] = None, - end_date: Optional[str] = None, - fields: Optional[list] = None, # 额外参数 -) -> pd.DataFrame: - params = {"ts_code": ts_code} - if start_date: - params["start_date"] = start_date - if end_date: - params["end_date"] = end_date - if fields: - params["fields"] = ",".join(fields) - - return client.query("api_name", **params) -``` - -### Q4: 如何处理没有 trade_date 字段的数据? - -如果数据确实不包含日期字段(如静态数据),可以: -1. 将其归类为"特殊接口",独立管理 -2. 或者添加一个 `sync_date` 字段记录同步时间 - -### Q5: 如何处理按日期获取但 API 不支持的情况? - -如果 API 只支持按股票获取: -1. 在 `DATASET_CONFIG` 中设置 `fetch_by: "stock"` -2. 使用 `_sync_all_stocks` 方法进行同步 -3. 在文档中说明这是按股票获取的接口 - ---- - -**最后更新**: 2026-02-01 diff --git a/src/data/__init__.py b/src/data/__init__.py index 45c5221..a3c6ac6 100644 --- a/src/data/__init__.py +++ b/src/data/__init__.py @@ -6,7 +6,7 @@ Provides simplified interfaces for fetching and storing Tushare data. from src.data.config import Config, get_config from src.data.client import TushareClient from src.data.storage import Storage -from src.data.stock_basic import get_stock_basic, sync_all_stocks +from src.data.api_wrappers import get_stock_basic, sync_all_stocks __all__ = [ "Config", diff --git a/src/data/api_wrappers/API_INTERFACE_SPEC.md b/src/data/api_wrappers/API_INTERFACE_SPEC.md new file mode 100644 index 0000000..95ccf1e --- /dev/null +++ b/src/data/api_wrappers/API_INTERFACE_SPEC.md @@ -0,0 +1,244 @@ +# ProStock 数据接口封装规范 + +## 1. 概述 + +本文档定义了新增 Tushare API 接口封装的标准规范。所有非特殊接口必须遵循此规范,确保: +- 代码风格统一 +- 自动 sync 支持 +- 增量更新逻辑一致 +- 减少存储写入压力 + +## 2. 接口分类 + +### 2.1 特殊接口(不参与统一 sync) + +以下接口有独立的同步逻辑,不参与自动 sync 机制: + +| 接口类型 | 示例 | 说明 | +|---------|------|------| +| 交易日历 | `trade_cal` | 全局数据,按日期范围获取 | +| 股票基础信息 | `stock_basic` | 一次性全量获取,CSV 存储 | +| 辅助数据 | 行业分类、概念分类 | 低频更新,独立管理 | + +### 2.2 标准接口(必须遵循本规范) + +所有按股票或按日期获取的因子数据、行情数据、财务数据等,必须遵循本规范。 + +## 3. 
文件结构要求 + +### 3.1 文件命名 + +``` +api_{data_type}.py +``` + +示例:`api_daily.py`、`api_moneyflow.py`、`api_limit_list.py` + +### 3.2 文件位置 + +所有接口文件必须位于 `src/data/api_wrappers/` 目录下。 + +### 3.3 导出要求 + +新接口必须在 `src/data/api_wrappers/__init__.py` 中导出: + +```python +from src.data.api_wrappers.api_{data_type} import get_{data_type} + +__all__ = [ + # ... 其他导出 ... + "get_{data_type}", +] +``` + +## 4. 接口设计规范 + +### 4.1 数据获取函数签名要求 + +函数必须返回 `pd.DataFrame`,参数必须包含以下之一: + +#### 4.1.1 按日期获取的接口(优先) + +适用于:涨跌停、龙虎榜、筹码分布等。 + +**函数签名要求**: + +```python +def get_{data_type}( + trade_date: Optional[str] = None, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + ts_code: Optional[str] = None, + # 其他可选参数... +) -> pd.DataFrame: +``` + +**要求**: +- 优先使用 `trade_date` 获取单日全市场数据 +- 支持 `start_date + end_date` 获取区间数据 +- `ts_code` 作为可选过滤参数 + +#### 4.1.2 按股票获取的接口 + +适用于:日线行情、资金流向等。 + +**函数签名要求**: + +```python +def get_{data_type}( + ts_code: str, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + # 其他可选参数... +) -> pd.DataFrame: +``` + +### 4.2 文档字符串要求 + +函数必须包含 Google 风格的完整文档字符串,包含: +- 函数功能描述 +- `Args` 部分:所有参数说明 +- `Returns` 部分:返回的 DataFrame 包含的字段说明 +- `Example` 部分:使用示例 + +### 4.3 日期格式要求 + +- 所有日期参数和返回值使用 `YYYYMMDD` 字符串格式 +- 统一使用 `trade_date` 作为日期字段名 +- 如果 API 返回其他日期字段名(如 `date`、`end_date`),必须在返回前重命名为 `trade_date` + +### 4.4 股票代码要求 + +- 统一使用 `ts_code` 作为股票代码字段名 +- 格式:`{code}.{exchange}`,如 `000001.SZ`、`600000.SH` + +### 4.5 令牌桶限速要求 + +所有 API 调用必须通过 `TushareClient`,自动满足令牌桶限速要求。 + +## 5. 
Sync 集成规范 + +### 5.1 DATASET_CONFIG 注册要求 + +新接口必须在 `DataSync.DATASET_CONFIG` 中注册,配置项: + +```python +"{new_data_type}": { + "api_name": "{tushare_api_name}", # Tushare API 名称 + "fetch_by": "date", # "date" 或 "stock" + "date_field": "trade_date", + "key_fields": ["ts_code", "trade_date"], # 用于去重的主键 +} +``` + +### 5.2 fetch_by 取值规则 + +- **优先使用 `"date"`**:如果 API 支持按日期获取全市场数据 +- 仅当 API 不支持按日期获取时才使用 `"stock"` + +### 5.3 sync 方法要求 + +必须实现对应的 sync 方法或复用通用方法: + +```python +def sync_{data_type}(self, force_full: bool = False) -> pd.DataFrame: + """Sync {数据描述}。""" + return self.sync_dataset("{data_type}", force_full) +``` + +同时提供便捷函数: + +```python +def sync_{data_type}(force_full: bool = False) -> pd.DataFrame: + """Sync {数据描述}。""" + sync_manager = DataSync() + return sync_manager.sync_{data_type}(force_full) +``` + +### 5.4 增量更新要求 + +- 必须实现增量更新逻辑(自动检查本地最新日期) +- 使用 `force_full` 参数支持强制全量同步 + +## 6. 存储规范 + +### 6.1 存储方式 + +所有数据通过 `Storage` 类进行 HDF5 存储。 + +### 6.2 写入策略 + +**要求**:所有数据在请求完成后**一次性写入**,而非逐条写入。 + +### 6.3 去重要求 + +使用 `key_fields` 配置的字段进行去重,默认使用 `["ts_code", "trade_date"]`。 + +## 7. 测试规范 + +### 7.1 测试文件要求 + +必须创建对应的测试文件:`tests/test_{data_type}.py` + +### 7.2 测试覆盖要求 + +- 测试按日期获取 +- 测试按股票获取(如果支持) +- 必须 mock `TushareClient` +- 测试覆盖正常和异常情况 + +## 8. 新增接口完整流程 + +### 8.1 创建接口文件 + +1. 在 `src/data/api_wrappers/` 下创建 `api_{data_type}.py` +2. 实现数据获取函数,遵循第 4 节规范 + +### 8.2 注册 sync 支持 + +1. 在 `sync.py` 的 `DataSync.DATASET_CONFIG` 中注册 +2. 实现对应的 sync 方法 +3. 提供便捷函数 + +### 8.3 更新导出 + +在 `src/data/api_wrappers/__init__.py` 中导出接口函数。 + +### 8.4 创建测试 + +创建 `tests/test_{data_type}.py`,覆盖关键场景。 + +## 9. 
检查清单 + +### 9.1 文件结构 +- [ ] 文件位于 `src/data/api_wrappers/api_{data_type}.py` +- [ ] 已更新 `src/data/api_wrappers/__init__.py` 导出公共接口 +- [ ] 已创建 `tests/test_{data_type}.py` 测试文件 + +### 9.2 接口实现 +- [ ] 数据获取函数使用 `TushareClient` +- [ ] 函数包含完整的 Google 风格文档字符串 +- [ ] 日期参数使用 `YYYYMMDD` 格式 +- [ ] 返回的 DataFrame 包含 `ts_code` 和 `trade_date` 字段 +- [ ] 优先实现按日期获取的接口(如果 API 支持) + +### 9.3 Sync 集成 +- [ ] 已在 `DataSync.DATASET_CONFIG` 中注册 +- [ ] 正确设置 `fetch_by`("date" 或 "stock") +- [ ] 正确设置 `date_field` 和 `key_fields` +- [ ] 已实现对应的 sync 方法或复用通用方法 +- [ ] 增量更新逻辑正确(检查本地最新日期) + +### 9.4 存储优化 +- [ ] 所有数据一次性写入(非逐条) +- [ ] 使用 `storage.save(mode="append")` 进行增量保存 +- [ ] 去重字段配置正确 + +### 9.5 测试 +- [ ] 已编写单元测试 +- [ ] 已 mock TushareClient +- [ ] 测试覆盖正常和异常情况 + +--- + +**最后更新**: 2026-02-01 diff --git a/src/data/api_wrappers/__init__.py b/src/data/api_wrappers/__init__.py new file mode 100644 index 0000000..cd64427 --- /dev/null +++ b/src/data/api_wrappers/__init__.py @@ -0,0 +1,40 @@ +"""Tushare API wrapper modules. + +This package contains simplified interfaces for fetching data from Tushare API. 
+All wrapper files follow the naming convention: api_{data_type}.py + +Available APIs: + - api_daily: Daily market data (日线行情) + - api_stock_basic: Stock basic information (股票基本信息) + - api_trade_cal: Trading calendar (交易日历) + +Example: + >>> from src.data.api_wrappers import get_daily, get_stock_basic, get_trade_cal + >>> data = get_daily('000001.SZ', start_date='20240101', end_date='20240131') + >>> stocks = get_stock_basic() + >>> calendar = get_trade_cal('20240101', '20240131') +""" + +from src.data.api_wrappers.api_daily import get_daily +from src.data.api_wrappers.api_stock_basic import get_stock_basic, sync_all_stocks +from src.data.api_wrappers.api_trade_cal import ( + get_trade_cal, + get_trading_days, + get_first_trading_day, + get_last_trading_day, + sync_trade_cal_cache, +) + +__all__ = [ + # Daily market data + "get_daily", + # Stock basic information + "get_stock_basic", + "sync_all_stocks", + # Trade calendar + "get_trade_cal", + "get_trading_days", + "get_first_trading_day", + "get_last_trading_day", + "sync_trade_cal_cache", +] diff --git a/src/data/api.md b/src/data/api_wrappers/api.md similarity index 71% rename from src/data/api.md rename to src/data/api_wrappers/api.md index ba96d04..4c3835b 100644 --- a/src/data/api.md +++ b/src/data/api_wrappers/api.md @@ -179,4 +179,75 @@ df = pro.query('trade_cal', start_date='20180101', end_date='20181231') 17 SSE 20180118 1 18 SSE 20180119 1 19 SSE 20180120 0 -20 SSE 20180121 0 \ No newline at end of file +20 SSE 20180121 0 + + +每日指标 +接口:daily_basic,可以通过数据工具调试和查看数据。 +更新时间:交易日每日15点~17点之间 +描述:获取全部股票每日重要的基本面指标,可用于选股分析、报表展示等。单次请求最大返回6000条数据,可按日线循环提取全部历史。 +积分:至少2000积分才可以调取,5000积分无总量限制,具体请参阅积分获取办法 + +输入参数 + +名称 类型 必选 描述 +ts_code str Y 股票代码(二选一) +trade_date str N 交易日期 (二选一) +start_date str N 开始日期(YYYYMMDD) +end_date str N 结束日期(YYYYMMDD) +注:日期都填YYYYMMDD格式,比如20181010 + +输出参数 + +名称 类型 描述 +ts_code str TS股票代码 +trade_date str 交易日期 +close float 当日收盘价 +turnover_rate float 换手率(%) +turnover_rate_f float 换手率(自由流通股) 
+volume_ratio float 量比 +pe float 市盈率(总市值/净利润, 亏损的PE为空) +pe_ttm float 市盈率(TTM,亏损的PE为空) +pb float 市净率(总市值/净资产) +ps float 市销率 +ps_ttm float 市销率(TTM) +dv_ratio float 股息率 (%) +dv_ttm float 股息率(TTM)(%) +total_share float 总股本 (万股) +float_share float 流通股本 (万股) +free_share float 自由流通股本 (万) +total_mv float 总市值 (万元) +circ_mv float 流通市值(万元) +接口用法 + + +pro = ts.pro_api() + +df = pro.daily_basic(ts_code='', trade_date='20180726', fields='ts_code,trade_date,turnover_rate,volume_ratio,pe,pb') +或者 + + +df = pro.query('daily_basic', ts_code='', trade_date='20180726',fields='ts_code,trade_date,turnover_rate,volume_ratio,pe,pb') +数据样例 + + ts_code trade_date turnover_rate volume_ratio pe pb +0 600230.SH 20180726 2.4584 0.72 8.6928 3.7203 +1 600237.SH 20180726 1.4737 0.88 166.4001 1.8868 +2 002465.SZ 20180726 0.7489 0.72 71.8943 2.6391 +3 300732.SZ 20180726 6.7083 0.77 21.8101 3.2513 +4 600007.SH 20180726 0.0381 0.61 23.7696 2.3774 +5 300068.SZ 20180726 1.4583 0.52 27.8166 1.7549 +6 300552.SZ 20180726 2.0728 0.95 56.8004 2.9279 +7 601369.SH 20180726 0.2088 0.95 44.1163 1.8001 +8 002518.SZ 20180726 0.5814 0.76 15.1004 2.5626 +9 002913.SZ 20180726 12.1096 1.03 33.1279 2.9217 +10 601818.SH 20180726 0.1893 0.86 6.3064 0.7209 +11 600926.SH 20180726 0.6065 0.46 9.1772 0.9808 +12 002166.SZ 20180726 0.7582 0.82 16.9868 3.3452 +13 600841.SH 20180726 0.3754 1.02 66.2647 2.2302 +14 300634.SZ 20180726 23.1127 1.26 120.3053 14.3168 +15 300126.SZ 20180726 1.2304 1.11 348.4306 1.5171 +16 300718.SZ 20180726 17.6612 0.92 32.0239 3.8661 +17 000708.SZ 20180726 0.5575 0.70 10.3674 1.0276 +18 002626.SZ 20180726 0.6187 0.83 22.7580 4.2446 +19 600816.SH 20180726 0.6745 0.65 11.0778 3.2214 \ No newline at end of file diff --git a/src/data/daily.py b/src/data/api_wrappers/api_daily.py similarity index 99% rename from src/data/daily.py rename to src/data/api_wrappers/api_daily.py index bc7c86e..3be4b3f 100644 --- a/src/data/daily.py +++ b/src/data/api_wrappers/api_daily.py @@ -3,6 +3,7 @@ A single function to 
fetch A股日线行情 data from Tushare. Supports all output fields including tor (换手率) and vr (量比). """ + import pandas as pd from typing import Optional, List, Literal from src.data.client import TushareClient @@ -33,7 +34,7 @@ def get_daily( Returns: pd.DataFrame with daily market data containing: - - Base fields: ts_code, trade_date, open, high, low, close, pre_close, + - Base fields: ts_code, trade_date, open, high, low, close, pre_close, change, pct_chg, vol, amount - Factor fields (if requested): tor, vr - Adjustment factor (if adjfactor=True): adjfactor diff --git a/src/data/stock_basic.py b/src/data/api_wrappers/api_stock_basic.py similarity index 99% rename from src/data/stock_basic.py rename to src/data/api_wrappers/api_stock_basic.py index b8f36b3..0cdd48b 100644 --- a/src/data/stock_basic.py +++ b/src/data/api_wrappers/api_stock_basic.py @@ -3,6 +3,7 @@ Fetch basic stock information including code, name, listing date, etc. This is a special interface - call once to get all stocks (listed and delisted). 
""" + import os import pandas as pd from pathlib import Path diff --git a/src/data/trade_cal.py b/src/data/api_wrappers/api_trade_cal.py similarity index 100% rename from src/data/trade_cal.py rename to src/data/api_wrappers/api_trade_cal.py diff --git a/src/data/storage.py b/src/data/storage.py index c9904f6..af15d55 100644 --- a/src/data/storage.py +++ b/src/data/storage.py @@ -1,4 +1,5 @@ """Simplified HDF5 storage for data persistence.""" + import os import pandas as pd from pathlib import Path @@ -47,7 +48,9 @@ class Storage: # Merge with existing data existing = store[name] combined = pd.concat([existing, data], ignore_index=True) - combined = combined.drop_duplicates(subset=["ts_code", "trade_date"], keep="last") + combined = combined.drop_duplicates( + subset=["ts_code", "trade_date"], keep="last" + ) store.put(name, combined, format="table") print(f"[Storage] Saved {len(data)} rows to {file_path}") @@ -57,10 +60,13 @@ class Storage: print(f"[Storage] Error saving {name}: {e}") return {"status": "error", "error": str(e)} - def load(self, name: str, - start_date: Optional[str] = None, - end_date: Optional[str] = None, - ts_code: Optional[str] = None) -> pd.DataFrame: + def load( + self, + name: str, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + ts_code: Optional[str] = None, + ) -> pd.DataFrame: """Load data from HDF5 file. 
Args: @@ -80,14 +86,25 @@ class Storage: try: with pd.HDFStore(file_path, mode="r") as store: - if name not in store.keys(): + keys = store.keys() + # Handle both '/daily' and 'daily' keys + actual_key = None + if name in keys: + actual_key = name + elif f"/{name}" in keys: + actual_key = f"/{name}" + + if actual_key is None: return pd.DataFrame() - data = store[name] + data = store[actual_key] # Apply filters if start_date and end_date and "trade_date" in data.columns: - data = data[(data["trade_date"] >= start_date) & (data["trade_date"] <= end_date)] + data = data[ + (data["trade_date"] >= start_date) + & (data["trade_date"] <= end_date) + ] if ts_code and "ts_code" in data.columns: data = data[data["ts_code"] == ts_code] diff --git a/src/data/sync.py b/src/data/sync.py index ee0ef3a..e414ded 100644 --- a/src/data/sync.py +++ b/src/data/sync.py @@ -5,11 +5,15 @@ This module provides data fetching functions with intelligent sync logic: - If local file exists: incremental update (fetch from latest date + 1 day) - Multi-threaded concurrent fetching for improved performance - Stop immediately on any exception +- Preview mode: check data volume and samples before actual sync Currently supported data types: - daily: Daily market data (with turnover rate and volume ratio) Usage: + # Preview sync (check data volume and samples without writing) + preview_sync() + # Sync all stocks (full load) sync_all() @@ -18,6 +22,9 @@ Usage: # Force full reload sync_all(force_full=True) + + # Dry run (preview only, no write) + sync_all(dry_run=True) """ import pandas as pd @@ -30,8 +37,8 @@ import sys from src.data.client import TushareClient from src.data.storage import Storage -from src.data.daily import get_daily -from src.data.trade_cal import ( +from src.data.api_wrappers import get_daily +from src.data.api_wrappers import ( get_first_trading_day, get_last_trading_day, sync_trade_cal_cache, @@ -114,7 +121,8 @@ class DataSync: List of stock codes """ # Import sync_all_stocks here 
to avoid circular imports - from src.data.stock_basic import sync_all_stocks, _get_csv_path + from src.data.api_wrappers import sync_all_stocks + from src.data.api_wrappers.api_stock_basic import _get_csv_path # First, ensure stock_basic.csv is up-to-date with all stocks print("[DataSync] Ensuring stock_basic.csv is up-to-date...") @@ -278,6 +286,184 @@ class DataSync: print(f"[DataSync] Incremental sync needed from {sync_start} to {cal_last}") return (True, sync_start, cal_last, local_last_date) + def preview_sync( + self, + force_full: bool = False, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + sample_size: int = 3, + ) -> dict: + """Preview sync data volume and samples without actually syncing. + + This method provides a preview of what would be synced, including: + - Number of stocks to be synced + - Date range for sync + - Estimated total records + - Sample data from first few stocks + + Args: + force_full: If True, preview full sync from 20180101 + start_date: Manual start date (overrides auto-detection) + end_date: Manual end date (defaults to today) + sample_size: Number of sample stocks to fetch for preview (default: 3) + + Returns: + Dictionary with preview information: + { + 'sync_needed': bool, + 'stock_count': int, + 'start_date': str, + 'end_date': str, + 'estimated_records': int, + 'sample_data': pd.DataFrame, + 'mode': str, # 'full', 'incremental', 'partial', or 'none' + } + """ + print("\n" + "=" * 60) + print("[DataSync] Preview Mode - Analyzing sync requirements...") + print("=" * 60) + + # First, ensure trade calendar cache is up-to-date + print("[DataSync] Syncing trade calendar cache...") + sync_trade_cal_cache() + + # Determine date range + if end_date is None: + end_date = get_today_date() + + # Check if sync is needed + sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full) + + if not sync_needed: + print("\n" + "=" * 60) + print("[DataSync] Preview Result") + print("=" * 60) + print(" Sync Status: NOT
NEEDED") + print(" Reason: Local data is up-to-date with trade calendar") + print("=" * 60) + return { + "sync_needed": False, + "stock_count": 0, + "start_date": None, + "end_date": None, + "estimated_records": 0, + "sample_data": pd.DataFrame(), + "mode": "none", + } + + # Use dates from check_sync_needed + if cal_start and cal_end: + sync_start_date = cal_start + end_date = cal_end + else: + sync_start_date = start_date or DEFAULT_START_DATE + if end_date is None: + end_date = get_today_date() + + # Determine sync mode + if force_full: + mode = "full" + print(f"[DataSync] Mode: FULL SYNC from {sync_start_date} to {end_date}") + elif local_last and cal_start and sync_start_date == get_next_date(local_last): + mode = "incremental" + print(f"[DataSync] Mode: INCREMENTAL SYNC (bandwidth optimized)") + print(f"[DataSync] Sync from: {sync_start_date} to {end_date}") + else: + mode = "partial" + print(f"[DataSync] Mode: SYNC from {sync_start_date} to {end_date}") + + # Get all stock codes + stock_codes = self.get_all_stock_codes() + if not stock_codes: + print("[DataSync] No stocks found to sync") + return { + "sync_needed": False, + "stock_count": 0, + "start_date": None, + "end_date": None, + "estimated_records": 0, + "sample_data": pd.DataFrame(), + "mode": "none", + } + + stock_count = len(stock_codes) + print(f"[DataSync] Total stocks to sync: {stock_count}") + + # Fetch sample data from first few stocks + print(f"[DataSync] Fetching sample data from {sample_size} stocks...") + sample_data_list = [] + sample_codes = stock_codes[:sample_size] + + for ts_code in sample_codes: + try: + data = self.client.query( + "pro_bar", + ts_code=ts_code, + start_date=sync_start_date, + end_date=end_date, + factors="tor,vr", + ) + if not data.empty: + sample_data_list.append(data) + print(f" - {ts_code}: {len(data)} records") + except Exception as e: + print(f" - {ts_code}: Error fetching - {e}") + + # Combine sample data + sample_df = ( + pd.concat(sample_data_list, 
ignore_index=True) + if sample_data_list + else pd.DataFrame() + ) + + # Estimate total records based on sample + if not sample_df.empty: + avg_records_per_stock = len(sample_df) / len(sample_data_list) + estimated_records = int(avg_records_per_stock * stock_count) + else: + estimated_records = 0 + + # Display preview results + print("\n" + "=" * 60) + print("[DataSync] Preview Result") + print("=" * 60) + print(f" Sync Mode: {mode.upper()}") + print(f" Date Range: {sync_start_date} to {end_date}") + print(f" Stocks to Sync: {stock_count}") + print(f" Sample Stocks Checked: {len(sample_data_list)}/{sample_size}") + print(f" Estimated Total Records: ~{estimated_records:,}") + + if not sample_df.empty: + print(f"\n Sample Data Preview (first {len(sample_df)} rows):") + print(" " + "-" * 56) + # Display sample data in a compact format + preview_cols = [ + "ts_code", + "trade_date", + "open", + "high", + "low", + "close", + "vol", + ] + available_cols = [c for c in preview_cols if c in sample_df.columns] + sample_display = sample_df[available_cols].head(10) + for idx, row in sample_display.iterrows(): + print(f" {row.to_dict()}") + print(" " + "-" * 56) + + print("=" * 60) + + return { + "sync_needed": True, + "stock_count": stock_count, + "start_date": sync_start_date, + "end_date": end_date, + "estimated_records": estimated_records, + "sample_data": sample_df, + "mode": mode, + } + def sync_single_stock( self, ts_code: str, @@ -320,6 +506,7 @@ class DataSync: start_date: Optional[str] = None, end_date: Optional[str] = None, max_workers: Optional[int] = None, + dry_run: bool = False, ) -> Dict[str, pd.DataFrame]: """Sync daily data for all stocks in local storage. 
@@ -337,9 +524,10 @@ class DataSync: start_date: Manual start date (overrides auto-detection) end_date: Manual end date (defaults to today) max_workers: Number of worker threads (default: 10) + dry_run: If True, only preview what would be synced without writing data Returns: - Dict mapping ts_code to DataFrame (empty if sync skipped) + Dict mapping ts_code to DataFrame (empty if sync skipped or dry_run) """ print("\n" + "=" * 60) print("[DataSync] Starting daily data sync...") @@ -378,11 +566,14 @@ class DataSync: # Determine sync mode if force_full: + mode = "full" print(f"[DataSync] Mode: FULL SYNC from {sync_start_date} to {end_date}") elif local_last and cal_start and sync_start_date == get_next_date(local_last): + mode = "incremental" print(f"[DataSync] Mode: INCREMENTAL SYNC (bandwidth optimized)") print(f"[DataSync] Sync from: {sync_start_date} to {end_date}") else: + mode = "partial" print(f"[DataSync] Mode: SYNC from {sync_start_date} to {end_date}") # Get all stock codes @@ -394,6 +585,17 @@ class DataSync: print(f"[DataSync] Total stocks to sync: {len(stock_codes)}") print(f"[DataSync] Using {max_workers or self.max_workers} worker threads") + # Handle dry run mode + if dry_run: + print("\n" + "=" * 60) + print("[DataSync] DRY RUN MODE - No data will be written") + print("=" * 60) + print(f" Would sync {len(stock_codes)} stocks") + print(f" Date range: {sync_start_date} to {end_date}") + print(f" Mode: {mode}") + print("=" * 60) + return {} + # Reset stop flag for new sync self._stop_flag.set() @@ -492,11 +694,62 @@ class DataSync: # Convenience functions +def preview_sync( + force_full: bool = False, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + sample_size: int = 3, + max_workers: Optional[int] = None, +) -> dict: + """Preview sync data volume and samples without actually syncing. + + This is the recommended way to check what would be synced before + running the actual synchronization. 
+ + Args: + force_full: If True, preview full sync from 20180101 + start_date: Manual start date (overrides auto-detection) + end_date: Manual end date (defaults to today) + sample_size: Number of sample stocks to fetch for preview (default: 3) + max_workers: Number of worker threads (not used in preview, for API compatibility) + + Returns: + Dictionary with preview information: + { + 'sync_needed': bool, + 'stock_count': int, + 'start_date': str, + 'end_date': str, + 'estimated_records': int, + 'sample_data': pd.DataFrame, + 'mode': str, # 'full', 'incremental', 'partial', or 'none' + } + + Example: + >>> # Preview what would be synced + >>> preview = preview_sync() + >>> + >>> # Preview full sync + >>> preview = preview_sync(force_full=True) + >>> + >>> # Preview with more samples + >>> preview = preview_sync(sample_size=5) + """ + sync_manager = DataSync(max_workers=max_workers) + return sync_manager.preview_sync( + force_full=force_full, + start_date=start_date, + end_date=end_date, + sample_size=sample_size, + ) + + def sync_all( force_full: bool = False, start_date: Optional[str] = None, end_date: Optional[str] = None, max_workers: Optional[int] = None, + dry_run: bool = False, ) -> Dict[str, pd.DataFrame]: """Sync daily data for all stocks. 
@@ -507,6 +760,7 @@ def sync_all( start_date: Manual start date (YYYYMMDD) end_date: Manual end date (defaults to today) max_workers: Number of worker threads (default: 10) + dry_run: If True, only preview what would be synced without writing data Returns: Dict mapping ts_code to DataFrame @@ -526,12 +780,16 @@ def sync_all( >>> >>> # Custom thread count >>> result = sync_all(max_workers=20) + >>> + >>> # Dry run (preview only) + >>> result = sync_all(dry_run=True) """ sync_manager = DataSync(max_workers=max_workers) return sync_manager.sync_all( force_full=force_full, start_date=start_date, end_date=end_date, + dry_run=dry_run, ) @@ -540,11 +798,32 @@ if __name__ == "__main__": print("Data Sync Module") print("=" * 60) print("\nUsage:") - print(" from src.data.sync import sync_all") + print(" from src.data.sync import sync_all, preview_sync") + print("") + print(" # Preview before sync (recommended)") + print(" preview = preview_sync()") + print("") + print(" # Dry run (preview only)") + print(" result = sync_all(dry_run=True)") + print("") + print(" # Actual sync") print(" result = sync_all() # Incremental sync") print(" result = sync_all(force_full=True) # Full reload") print("\n" + "=" * 60) - # Run sync - result = sync_all() - print(f"\nSynced {len(result)} stocks") + # Run preview first + print("\n[Main] Running preview first...") + preview = preview_sync() + + if preview["sync_needed"]: + # Ask for confirmation + print("\n" + "=" * 60) + response = input("Proceed with sync? 
(y/n): ").strip().lower() + if response in ("y", "yes"): + print("\n[Main] Starting actual sync...") + result = sync_all() + print(f"\nSynced {len(result)} stocks") + else: + print("\n[Main] Sync cancelled by user") + else: + print("\n[Main] No sync needed - data is up to date") diff --git a/tests/test_daily.py b/tests/test_daily.py index 9f775fd..648f4ba 100644 --- a/tests/test_daily.py +++ b/tests/test_daily.py @@ -5,29 +5,30 @@ Tests the daily interface implementation against api.md requirements: - tor 换手率 - vr 量比 """ + import pytest import pandas as pd -from src.data.daily import get_daily +from src.data.api_wrappers import get_daily # Expected output fields according to api.md EXPECTED_BASE_FIELDS = [ - 'ts_code', # 股票代码 - 'trade_date', # 交易日期 - 'open', # 开盘价 - 'high', # 最高价 - 'low', # 最低价 - 'close', # 收盘价 - 'pre_close', # 昨收价 - 'change', # 涨跌额 - 'pct_chg', # 涨跌幅 - 'vol', # 成交量 - 'amount', # 成交额 + "ts_code", # 股票代码 + "trade_date", # 交易日期 + "open", # 开盘价 + "high", # 最高价 + "low", # 最低价 + "close", # 收盘价 + "pre_close", # 昨收价 + "change", # 涨跌额 + "pct_chg", # 涨跌幅 + "vol", # 成交量 + "amount", # 成交额 ] EXPECTED_FACTOR_FIELDS = [ - 'turnover_rate', # 换手率 (tor) - 'volume_ratio', # 量比 (vr) + "turnover_rate", # 换手率 (tor) + "volume_ratio", # 量比 (vr) ] @@ -36,19 +37,19 @@ class TestGetDaily: def test_fetch_basic(self): """Test basic daily data fetch with real API.""" - result = get_daily('000001.SZ', start_date='20240101', end_date='20240131') + result = get_daily("000001.SZ", start_date="20240101", end_date="20240131") assert isinstance(result, pd.DataFrame) assert len(result) >= 1 - assert result['ts_code'].iloc[0] == '000001.SZ' + assert result["ts_code"].iloc[0] == "000001.SZ" def test_fetch_with_factors(self): """Test fetch with tor and vr factors.""" result = get_daily( - '000001.SZ', - start_date='20240101', - end_date='20240131', - factors=['tor', 'vr'], + "000001.SZ", + start_date="20240101", + end_date="20240131", + factors=["tor", "vr"], ) assert isinstance(result, 
pd.DataFrame) @@ -61,25 +62,26 @@ class TestGetDaily: def test_output_fields_completeness(self): """Verify all required output fields are returned.""" - result = get_daily('600000.SH') + result = get_daily("600000.SH") # Verify all base fields are present - assert set(EXPECTED_BASE_FIELDS).issubset(result.columns.tolist()), \ + assert set(EXPECTED_BASE_FIELDS).issubset(result.columns.tolist()), ( f"Missing fields: {set(EXPECTED_BASE_FIELDS) - set(result.columns)}" + ) def test_empty_result(self): """Test handling of empty results.""" # 使用真实 API 测试无效股票代码的空结果 - result = get_daily('INVALID.SZ') + result = get_daily("INVALID.SZ") assert isinstance(result, pd.DataFrame) assert result.empty def test_date_range_query(self): """Test query with date range.""" result = get_daily( - '000001.SZ', - start_date='20240101', - end_date='20240131', + "000001.SZ", + start_date="20240101", + end_date="20240131", ) assert isinstance(result, pd.DataFrame) @@ -87,7 +89,7 @@ class TestGetDaily: def test_with_adj(self): """Test fetch with adjustment type.""" - result = get_daily('000001.SZ', adj='qfq') + result = get_daily("000001.SZ", adj="qfq") assert isinstance(result, pd.DataFrame) @@ -95,11 +97,14 @@ class TestGetDaily: def test_integration(): """Integration test with real Tushare API (requires valid token).""" import os - token = os.environ.get('TUSHARE_TOKEN') + + token = os.environ.get("TUSHARE_TOKEN") if not token: pytest.skip("TUSHARE_TOKEN not configured") - result = get_daily('000001.SZ', start_date='20240101', end_date='20240131', factors=['tor', 'vr']) + result = get_daily( + "000001.SZ", start_date="20240101", end_date="20240131", factors=["tor", "vr"] + ) # Verify structure assert isinstance(result, pd.DataFrame) @@ -112,6 +117,6 @@ def test_integration(): assert field in result.columns, f"Missing factor field: {field}" -if __name__ == '__main__': +if __name__ == "__main__": # 运行 pytest 单元测试(真实API调用) - pytest.main([__file__, '-v']) + pytest.main([__file__, "-v"]) diff 
--git a/tests/test_daily_storage.py b/tests/test_daily_storage.py index fcd5048..cb848a3 100644 --- a/tests/test_daily_storage.py +++ b/tests/test_daily_storage.py @@ -9,7 +9,7 @@ import pytest import pandas as pd from pathlib import Path from src.data.storage import Storage -from src.data.stock_basic import _get_csv_path +from src.data.api_wrappers.api_stock_basic import _get_csv_path class TestDailyStorageValidation: diff --git a/tests/test_sync.py b/tests/test_sync.py index f336562..ce1ac72 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -5,6 +5,7 @@ Tests the sync module's full/incremental sync logic for daily data: - Incremental sync when local data exists (from last_date + 1) - Data integrity validation """ + import pytest import pandas as pd from unittest.mock import Mock, patch, MagicMock @@ -17,6 +18,8 @@ from src.data.sync import ( get_next_date, DEFAULT_START_DATE, ) +from src.data.storage import Storage +from src.data.client import TushareClient class TestDateUtilities: @@ -63,30 +66,32 @@ class TestDataSync: def test_get_all_stock_codes_from_daily(self, mock_storage): """Test getting stock codes from daily data.""" - with patch('src.data.sync.Storage', return_value=mock_storage): + with patch("src.data.sync.Storage", return_value=mock_storage): sync = DataSync() sync.storage = mock_storage - mock_storage.load.return_value = pd.DataFrame({ - 'ts_code': ['000001.SZ', '000001.SZ', '600000.SH'], - }) + mock_storage.load.return_value = pd.DataFrame( + { + "ts_code": ["000001.SZ", "000001.SZ", "600000.SH"], + } + ) codes = sync.get_all_stock_codes() assert len(codes) == 2 - assert '000001.SZ' in codes - assert '600000.SH' in codes + assert "000001.SZ" in codes + assert "600000.SH" in codes def test_get_all_stock_codes_fallback(self, mock_storage): """Test fallback to stock_basic when daily is empty.""" - with patch('src.data.sync.Storage', return_value=mock_storage): + with patch("src.data.sync.Storage", return_value=mock_storage): sync = 
DataSync() sync.storage = mock_storage # First call (daily) returns empty, second call (stock_basic) returns data mock_storage.load.side_effect = [ pd.DataFrame(), # daily empty - pd.DataFrame({'ts_code': ['000001.SZ', '600000.SH']}), # stock_basic + pd.DataFrame({"ts_code": ["000001.SZ", "600000.SH"]}), # stock_basic ] codes = sync.get_all_stock_codes() @@ -95,21 +100,23 @@ class TestDataSync: def test_get_global_last_date(self, mock_storage): """Test getting global last date.""" - with patch('src.data.sync.Storage', return_value=mock_storage): + with patch("src.data.sync.Storage", return_value=mock_storage): sync = DataSync() sync.storage = mock_storage - mock_storage.load.return_value = pd.DataFrame({ - 'ts_code': ['000001.SZ', '600000.SH'], - 'trade_date': ['20240102', '20240103'], - }) + mock_storage.load.return_value = pd.DataFrame( + { + "ts_code": ["000001.SZ", "600000.SH"], + "trade_date": ["20240102", "20240103"], + } + ) last_date = sync.get_global_last_date() - assert last_date == '20240103' + assert last_date == "20240103" def test_get_global_last_date_empty(self, mock_storage): """Test getting last date from empty storage.""" - with patch('src.data.sync.Storage', return_value=mock_storage): + with patch("src.data.sync.Storage", return_value=mock_storage): sync = DataSync() sync.storage = mock_storage @@ -120,18 +127,23 @@ class TestDataSync: def test_sync_single_stock(self, mock_storage): """Test syncing a single stock.""" - with patch('src.data.sync.Storage', return_value=mock_storage): - with patch('src.data.sync.get_daily', return_value=pd.DataFrame({ - 'ts_code': ['000001.SZ'], - 'trade_date': ['20240102'], - })): + with patch("src.data.sync.Storage", return_value=mock_storage): + with patch( + "src.data.sync.get_daily", + return_value=pd.DataFrame( + { + "ts_code": ["000001.SZ"], + "trade_date": ["20240102"], + } + ), + ): sync = DataSync() sync.storage = mock_storage result = sync.sync_single_stock( - ts_code='000001.SZ', - 
start_date='20240101', - end_date='20240102', + ts_code="000001.SZ", + start_date="20240101", + end_date="20240102", ) assert isinstance(result, pd.DataFrame) @@ -139,15 +151,15 @@ class TestDataSync: def test_sync_single_stock_empty(self, mock_storage): """Test syncing a stock with no data.""" - with patch('src.data.sync.Storage', return_value=mock_storage): - with patch('src.data.sync.get_daily', return_value=pd.DataFrame()): + with patch("src.data.sync.Storage", return_value=mock_storage): + with patch("src.data.sync.get_daily", return_value=pd.DataFrame()): sync = DataSync() sync.storage = mock_storage result = sync.sync_single_stock( - ts_code='INVALID.SZ', - start_date='20240101', - end_date='20240102', + ts_code="INVALID.SZ", + start_date="20240101", + end_date="20240102", ) assert result.empty @@ -158,40 +170,46 @@ class TestSyncAll: def test_full_sync_mode(self, mock_storage): """Test full sync mode when force_full=True.""" - with patch('src.data.sync.Storage', return_value=mock_storage): - with patch('src.data.sync.get_daily', return_value=pd.DataFrame()): + with patch("src.data.sync.Storage", return_value=mock_storage): + with patch("src.data.sync.get_daily", return_value=pd.DataFrame()): sync = DataSync() sync.storage = mock_storage sync.sync_single_stock = Mock(return_value=pd.DataFrame()) - mock_storage.load.return_value = pd.DataFrame({ - 'ts_code': ['000001.SZ'], - }) + mock_storage.load.return_value = pd.DataFrame( + { + "ts_code": ["000001.SZ"], + } + ) result = sync.sync_all(force_full=True) # Verify sync_single_stock was called with default start date sync.sync_single_stock.assert_called_once() call_args = sync.sync_single_stock.call_args - assert call_args[1]['start_date'] == DEFAULT_START_DATE + assert call_args[1]["start_date"] == DEFAULT_START_DATE def test_incremental_sync_mode(self, mock_storage): """Test incremental sync mode when data exists.""" - with patch('src.data.sync.Storage', return_value=mock_storage): + with 
patch("src.data.sync.Storage", return_value=mock_storage): sync = DataSync() sync.storage = mock_storage sync.sync_single_stock = Mock(return_value=pd.DataFrame()) # Mock existing data with last date mock_storage.load.side_effect = [ - pd.DataFrame({ - 'ts_code': ['000001.SZ'], - 'trade_date': ['20240102'], - }), # get_all_stock_codes - pd.DataFrame({ - 'ts_code': ['000001.SZ'], - 'trade_date': ['20240102'], - }), # get_global_last_date + pd.DataFrame( + { + "ts_code": ["000001.SZ"], + "trade_date": ["20240102"], + } + ), # get_all_stock_codes + pd.DataFrame( + { + "ts_code": ["000001.SZ"], + "trade_date": ["20240102"], + } + ), # get_global_last_date ] result = sync.sync_all(force_full=False) @@ -199,28 +217,30 @@ class TestSyncAll: # Verify sync_single_stock was called with next date sync.sync_single_stock.assert_called_once() call_args = sync.sync_single_stock.call_args - assert call_args[1]['start_date'] == '20240103' + assert call_args[1]["start_date"] == "20240103" def test_manual_start_date(self, mock_storage): """Test sync with manual start date.""" - with patch('src.data.sync.Storage', return_value=mock_storage): + with patch("src.data.sync.Storage", return_value=mock_storage): sync = DataSync() sync.storage = mock_storage sync.sync_single_stock = Mock(return_value=pd.DataFrame()) - mock_storage.load.return_value = pd.DataFrame({ - 'ts_code': ['000001.SZ'], - }) + mock_storage.load.return_value = pd.DataFrame( + { + "ts_code": ["000001.SZ"], + } + ) - result = sync.sync_all(force_full=False, start_date='20230601') + result = sync.sync_all(force_full=False, start_date="20230601") sync.sync_single_stock.assert_called_once() call_args = sync.sync_single_stock.call_args - assert call_args[1]['start_date'] == '20230601' + assert call_args[1]["start_date"] == "20230601" def test_no_stocks_found(self, mock_storage): """Test sync when no stocks are found.""" - with patch('src.data.sync.Storage', return_value=mock_storage): + with patch("src.data.sync.Storage", 
return_value=mock_storage): sync = DataSync() sync.storage = mock_storage @@ -236,7 +256,7 @@ class TestSyncAllConvenienceFunction: def test_sync_all_function(self): """Test sync_all convenience function.""" - with patch('src.data.sync.DataSync') as MockSync: + with patch("src.data.sync.DataSync") as MockSync: mock_instance = Mock() mock_instance.sync_all.return_value = {} MockSync.return_value = mock_instance @@ -251,5 +271,5 @@ class TestSyncAllConvenienceFunction: ) -if __name__ == '__main__': - pytest.main([__file__, '-v']) +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_sync_real.py b/tests/test_sync_real.py new file mode 100644 index 0000000..eacaf3f --- /dev/null +++ b/tests/test_sync_real.py @@ -0,0 +1,256 @@ +"""Tests for data sync with REAL data (read-only). + +Tests verify: +1. get_global_last_date() correctly reads local data's max date +2. Incremental sync date calculation (local_last_date + 1) +3. Full sync date calculation (20180101) +4. Multi-stock scenario with real data + +⚠️ IMPORTANT: These tests ONLY read data, no write operations. +- NO sync_all() calls (writes daily.h5) +- NO check_sync_needed() calls (writes trade_cal.h5) +""" + +import pytest +import pandas as pd +from pathlib import Path + +from src.data.sync import ( + DataSync, + get_next_date, + DEFAULT_START_DATE, +) +from src.data.storage import Storage + + +class TestDataSyncReadOnly: + """Read-only tests for data sync - verify date calculation logic.""" + + @pytest.fixture + def storage(self): + """Create storage instance.""" + return Storage() + + @pytest.fixture + def data_sync(self): + """Create DataSync instance.""" + return DataSync() + + @pytest.fixture + def daily_exists(self, storage): + """Check if daily.h5 exists.""" + return storage.exists("daily") + + def test_daily_h5_exists(self, storage): + """Verify daily.h5 data file exists before running tests.""" + assert storage.exists("daily"), ( + "daily.h5 not found. 
Please run full sync first: " + "uv run python -c 'from src.data.sync import sync_all; sync_all(force_full=True)'" + ) + + def test_get_global_last_date(self, data_sync, daily_exists): + """Test get_global_last_date returns correct max date from local data.""" + if not daily_exists: + pytest.skip("daily.h5 not found") + + last_date = data_sync.get_global_last_date() + + # Verify it's a valid date string + assert last_date is not None, "get_global_last_date returned None" + assert isinstance(last_date, str), f"Expected str, got {type(last_date)}" + assert len(last_date) == 8, f"Expected 8-digit date, got {last_date}" + assert last_date.isdigit(), f"Expected numeric date, got {last_date}" + + # Verify by reading storage directly + daily_data = data_sync.storage.load("daily") + expected_max = str(daily_data["trade_date"].max()) + + assert last_date == expected_max, ( + f"get_global_last_date returned {last_date}, " + f"but actual max date is {expected_max}" + ) + + print(f"[TEST] Local data last date: {last_date}") + + def test_incremental_sync_date_calculation(self, data_sync, daily_exists): + """Test incremental sync: start_date = local_last_date + 1. + + This verifies that when local data exists, incremental sync should + fetch data from (local_last_date + 1), not from 20180101. 
+ """ + if not daily_exists: + pytest.skip("daily.h5 not found") + + # Get local last date + local_last_date = data_sync.get_global_last_date() + assert local_last_date is not None, "No local data found" + + # Calculate expected incremental start date + expected_start_date = get_next_date(local_last_date) + + # Verify the calculation is correct + local_last_int = int(local_last_date) + expected_int = local_last_int + 1 + actual_int = int(expected_start_date) + + assert actual_int == expected_int, ( + f"Incremental start date calculation error: " + f"expected {expected_int}, got {actual_int}" + ) + + print( + f"[TEST] Incremental sync: local_last={local_last_date}, " + f"start_date should be {expected_start_date}" + ) + + # Verify this is NOT 20180101 (would be full sync) + assert expected_start_date != DEFAULT_START_DATE, ( + f"Incremental sync should NOT start from {DEFAULT_START_DATE}" + ) + + def test_full_sync_date_calculation(self): + """Test full sync: start_date = 20180101 when force_full=True. + + This verifies that force_full=True always starts from 20180101. + """ + # Full sync should always use DEFAULT_START_DATE + full_sync_start = DEFAULT_START_DATE + + assert full_sync_start == "20180101", ( + f"Full sync should start from 20180101, got {full_sync_start}" + ) + + print(f"[TEST] Full sync start date: {full_sync_start}") + + def test_date_comparison_logic(self, data_sync, daily_exists): + """Test date comparison: incremental vs full sync selection logic. 
+ + Verify that: + - If local_last_date < today: incremental sync needed + - If local_last_date >= today: no sync needed + """ + if not daily_exists: + pytest.skip("daily.h5 not found") + + from datetime import datetime + + local_last_date = data_sync.get_global_last_date() + today = datetime.now().strftime("%Y%m%d") + + local_last_int = int(local_last_date) + today_int = int(today) + + # Log the comparison + print( + f"[TEST] Date comparison: local_last={local_last_date} ({local_last_int}), " + f"today={today} ({today_int})" + ) + + # This test just verifies the comparison logic works + if local_last_int < today_int: + print("[TEST] Local data is older than today - sync needed") + # Incremental sync should fetch from local_last_date + 1 + sync_start = get_next_date(local_last_date) + assert int(sync_start) > local_last_int, ( + "Sync start should be after local last" + ) + else: + print("[TEST] Local data is up-to-date - no sync needed") + + def test_get_all_stock_codes_real_data(self, data_sync, daily_exists): + """Test get_all_stock_codes returns multiple real stock codes.""" + if not daily_exists: + pytest.skip("daily.h5 not found") + + codes = data_sync.get_all_stock_codes() + + # Verify it's a list + assert isinstance(codes, list), f"Expected list, got {type(codes)}" + assert len(codes) > 0, "No stock codes found" + + # Verify multiple stocks + assert len(codes) >= 10, ( + f"Expected at least 10 stocks for multi-stock test, got {len(codes)}" + ) + + # Verify format (should be like 000001.SZ, 600000.SH) + sample_codes = codes[:5] + for code in sample_codes: + assert "." in code, f"Invalid stock code format: {code}" + suffix = code.split(".")[-1] + assert suffix in ["SZ", "SH"], f"Invalid exchange suffix: {suffix}" + + print(f"[TEST] Found {len(codes)} stock codes (sample: {sample_codes})") + + def test_multi_stock_date_range(self, data_sync, daily_exists): + """Test that multiple stocks share the same date range in local data. 
+ + This verifies that local data has consistent date coverage across stocks. + """ + if not daily_exists: + pytest.skip("daily.h5 not found") + + daily_data = data_sync.storage.load("daily") + + # Get date range for each stock + stock_dates = daily_data.groupby("ts_code")["trade_date"].agg(["min", "max"]) + + # Get global min and max + global_min = str(daily_data["trade_date"].min()) + global_max = str(daily_data["trade_date"].max()) + + print(f"[TEST] Global date range: {global_min} to {global_max}") + print(f"[TEST] Total stocks: {len(stock_dates)}") + + # Verify we have data for multiple stocks + assert len(stock_dates) >= 10, ( + f"Expected at least 10 stocks, got {len(stock_dates)}" + ) + + # Verify date range is reasonable (YYYYMMDD integer difference > 100, not a true day count) + global_min_int = int(global_min) + global_max_int = int(global_max) + days_span = global_max_int - global_min_int + + assert days_span > 100, ( + f"Date range too small: {days_span} days. " + f"Expected at least 100 days of data." + ) + + print(f"[TEST] Date span: {days_span} days") + + +class TestDateUtilities: + """Test date utility functions.""" + + def test_get_next_date(self): + """Test get_next_date correctly calculates next day.""" + # Test normal cases + assert get_next_date("20240101") == "20240102" + assert get_next_date("20240131") == "20240201" # Month boundary + assert get_next_date("20241231") == "20250101" # Year boundary + + def test_incremental_vs_full_sync_logic(self): + """Test the logic difference between incremental and full sync.
+ + Incremental: start_date = local_last_date + 1 + Full: start_date = 20180101 + """ + # Scenario 1: Local data exists + local_last_date = "20240115" + incremental_start = get_next_date(local_last_date) + + assert incremental_start == "20240116" + assert incremental_start != DEFAULT_START_DATE + + # Scenario 2: Force full sync + full_sync_start = DEFAULT_START_DATE # "20180101" + + assert full_sync_start == "20180101" + assert incremental_start != full_sync_start + + print("[TEST] Incremental vs Full sync logic verified") + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"])