refactor: clean up debug logging, rework rate limiter, switch storage format

- Remove debug logging from client.py and daily.py
- Rework rate_limiter to support infinite timeouts and more precise token acquisition
- Change stock_basic storage from HDF5 to CSV
- Update project rules: mandate uv, forbid reading the config/ directory
- Add data sync module sync.py with tests
- Add !data/ to .gitignore so data files can be tracked
2026-02-01 02:29:54 +08:00
parent 38e78a5326
commit ec08a2578c
13 changed files with 710 additions and 47 deletions

.gitignore vendored

@@ -71,3 +71,6 @@ cover/
*.temp
tmp/
temp/
# Data directory (negate ignore so data files can be tracked)
!data/


@@ -17,6 +17,41 @@
3. **Docstrings**: use Google-style docstrings
4. **Test coverage**: key business logic should have corresponding unit tests
## Python Execution Rules
**⚠️ This project mandates uv as the Python package manager and run tool. Direct use of the `python` and `pip` commands is forbidden.**
### Forbidden Commands ❌
```bash
# Do not invoke python directly
python -c "..."        # forbidden!
python script.py       # forbidden!
python -m pytest       # forbidden!
python -m pip install  # forbidden!
# Do not invoke pip directly
pip install -e .       # forbidden!
pip install package    # forbidden!
pip list               # forbidden!
```
### Correct uv Usage ✅
```bash
# Run Python code
uv run python -c "..."    # ✅ correct
uv run python script.py   # ✅ correct
# Install dependencies
uv pip install -e .       # ✅ correct
uv pip install package    # ✅ correct
# Run tests
uv run pytest                     # ✅ correct
uv run pytest tests/test_sync.py  # ✅ correct
```
## Directory Structure Rules
```


@@ -14,6 +14,18 @@
4. **No searching** - no search operations of any kind under the `config/` directory
5. **No execution** - no commands may be executed under the `config/` directory
### Absolute Prohibition
**Even if a task cannot otherwise be completed, reading any file under the `config/` directory is strictly forbidden.**
This is a non-negotiable security red line:
- **Forbidden**: reading config files in order to complete a task
- **Forbidden**: viewing config files for debugging purposes
- **Forbidden**: reading files on the grounds of verifying configuration correctness
- **Forbidden**: access for any reason, under any pretext, in any circumstances
**If a task needs configuration values, obtain them through the API exposed by the `src/config/` module rather than reading files directly.**
All configuration reads must go through the centrally managed config module (`src/config/`). **`config/` and `src/config/` are entirely different directories: the former is protected, the latter holds the config module's code.**
### Directory Structure Notes
@@ -216,12 +228,14 @@ api_key = settings.api_key
3. **Security severity rating**: flagged as a high-priority security vulnerability
4. **Build blocking**: the CI/CD pipeline fails automatically
5. **Audit logging**: violations are recorded for audit trails
6. **Immediate termination**: any attempt to read the `config/` directory is blocked immediately
### Violation Severity Levels
| Level | Violation | Consequence |
|------|---------|---------|
| Critical | Deliberately reading sensitive config files (e.g. `.env`) | Code review rejection, team notification |
| Critical | Deliberately reading sensitive config files (e.g. `.env`) | Code review rejection, team notification, immediate block |
| Critical | Reading config files on the grounds of "cannot complete the task" | Code review rejection, team notification, immediate block |
| High | Using tools to access the `config/` directory | Code review rejection, rework required |
| Medium | Hardcoding config paths in code | Change required, flagged in code review |
| Low | Potentially risky operations (manual review needed) | Code review reminder |

AGENTS.md Normal file

@@ -0,0 +1,242 @@
# ProStock Agent Guide
A-share quantitative investment framework - a Python project for quantitative stock investment analysis.
## Build/Lint/Test Commands
**⚠️ Important: this project mandates uv as the Python package manager and run tool. Direct use of the `python` and `pip` commands is forbidden.**
```bash
# Install dependencies (uv is mandatory)
uv pip install -e .
# Run all tests
uv run pytest
# Run a single test file
uv run pytest tests/test_sync.py
# Run a single test class
uv run pytest tests/test_sync.py::TestDataSync
# Run a single test method
uv run pytest tests/test_sync.py::TestDataSync::test_get_all_stock_codes_from_daily
# Run with verbose output
uv run pytest -v
# Run with coverage (if pytest-cov is installed)
uv run pytest --cov=src --cov-report=term-missing
```
### Forbidden Commands ❌
The following commands are **strictly forbidden** in this project:
```bash
# Do not invoke python directly
python -c "..."        # forbidden!
python script.py       # forbidden!
python -m pytest       # forbidden!
python -m pip install  # forbidden!
# Do not invoke pip directly
pip install -e .       # forbidden!
pip install package    # forbidden!
pip list               # forbidden!
```
### Correct uv Usage ✅
```bash
# Run Python code
uv run python -c "..."    # ✅ correct
uv run python script.py   # ✅ correct
# Install dependencies
uv pip install -e .       # ✅ correct
uv pip install package    # ✅ correct
# Run tests
uv run pytest                     # ✅ correct
uv run pytest tests/test_sync.py  # ✅ correct
```
## Project Structure
```
ProStock/
├── src/                     # Source code
│   ├── data/                # Data acquisition module
│   │   ├── __init__.py
│   │   ├── client.py        # Tushare API client with rate limiting
│   │   ├── config.py        # Configuration (pydantic-settings)
│   │   ├── daily.py         # Daily market data
│   │   ├── rate_limiter.py  # Token bucket rate limiter
│   │   ├── stock_basic.py   # Basic stock information
│   │   ├── storage.py       # HDF5 storage manager
│   │   └── sync.py          # Data synchronization
│   ├── config/              # Global configuration
│   │   ├── __init__.py
│   │   └── settings.py      # App settings (pydantic-settings)
│   └── __init__.py
├── tests/                   # Test files
│   ├── test_sync.py
│   └── test_daily.py
├── config/                  # Config files
│   └── .env.local           # Environment variables (not in git)
├── data/                    # Data storage (HDF5 files)
├── docs/                    # Documentation
├── pyproject.toml           # Project configuration
└── README.md
```
## Code Style Guide
### Python Version
- **Python 3.10+ required**
- Use modern Python features: match/case, the walrus operator, type hints
### Imports
```python
# Standard library first
import os
import time
from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional, Dict, Callable
from concurrent.futures import ThreadPoolExecutor
import threading
# Third-party packages
import pandas as pd
import numpy as np
from tqdm import tqdm
from pydantic_settings import BaseSettings
# Local modules (absolute imports from src)
from src.data.client import TushareClient
from src.data.storage import Storage
from src.config.settings import get_settings
```
### Type Hints
- **Always use type hints** for function parameters and return values
- Use `Optional[X]` for nullable types
- Use the modern union syntax `X | Y` where available (Python 3.10+)
- Import types from `typing`: `Optional`, `Dict`, `Callable`
```python
def sync_single_stock(
    self,
    ts_code: str,
    start_date: str,
    end_date: str,
) -> pd.DataFrame:
    ...
```
### Docstrings
- Use **Google-style docstrings**
- Include Args and Returns sections
- Keep the first line a short summary
```python
def get_next_date(date_str: str) -> str:
    """Get the day after the given date.
    Args:
        date_str: Date in YYYYMMDD format
    Returns:
        The next day's date in YYYYMMDD format
    """
    ...
```
### Naming Conventions
- `snake_case` for variables, functions, and methods
- `PascalCase` for classes
- `UPPER_CASE` for constants
- Private methods: `_leading_underscore`
- Protected attributes: `_single_underscore`
### Error Handling
- Use specific exceptions; never a bare `except:`
- Log errors with context: `print(f"[ERROR] context: {e}")`
- Use retry logic with exponential backoff for API calls
- Stop immediately on critical errors (set the stop flag)
```python
try:
    data = api.query(...)
except Exception as e:
    print(f"[ERROR] Failed to fetch {ts_code}: {e}")
    raise  # Re-raise after logging
```
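The backoff guidance above can be sketched as a small wrapper; `query_with_retry` and its parameters are illustrative names, not part of the codebase:

```python
import time


def query_with_retry(fn, max_retries: int = 3, base_delay: float = 0.1):
    """Call fn(), retrying with exponential backoff (base_delay, 2x, 4x, ...)."""
    for attempt in range(max_retries):
        try:
            return fn()
        except Exception as e:
            if attempt == max_retries - 1:
                raise  # Out of retries: re-raise after logging
            delay = base_delay * (2 ** attempt)
            print(f"[ERROR] attempt {attempt + 1} failed: {e}; retrying in {delay:.2f}s")
            time.sleep(delay)
```

On each failure the wait doubles, so transient API errors are absorbed without hammering the server, while the final failure still propagates to the caller.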
### Configuration
- Use **pydantic-settings** for all configuration
- Load from the `config/.env.local` file
- Environment variables convert automatically: `tushare_token` ↔ `TUSHARE_TOKEN`
- Use `@lru_cache()` for the settings singleton
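A minimal sketch of the cached-singleton pattern using only the standard library (the project itself uses pydantic-settings; the `Settings` class and its defaults here are illustrative stand-ins):

```python
import os
from functools import lru_cache


class Settings:
    """Stand-in for a pydantic-settings class: env vars map to lowercase attributes."""

    def __init__(self) -> None:
        self.tushare_token = os.getenv("TUSHARE_TOKEN", "")
        self.data_path = os.getenv("DATA_PATH", "data")


@lru_cache()
def get_settings() -> Settings:
    """Return a process-wide Settings instance, built once and cached."""
    return Settings()


assert get_settings() is get_settings()  # same cached instance on every call
```

Because `@lru_cache()` memoizes the zero-argument call, every module that imports `get_settings()` shares one instance, and the environment is read only once per process.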
### Data Storage
- Persist in **HDF5 format** via `pandas.HDFStore`
- Store under the `data/` directory (configured via the `DATA_PATH` environment variable)
- Use `format="table"` for appendable datasets
- Handle duplicates on append: `drop_duplicates(subset=[...])`
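The append-time deduplication rule can be illustrated with plain DataFrames (column names mirror the daily dataset; the HDF5 I/O itself is omitted):

```python
import pandas as pd

existing = pd.DataFrame({
    "ts_code": ["000001.SZ"],
    "trade_date": ["20240101"],
    "close": [10.0],
})
incoming = pd.DataFrame({
    "ts_code": ["000001.SZ", "000001.SZ"],
    "trade_date": ["20240101", "20240102"],  # first row duplicates existing data
    "close": [10.0, 10.5],
})

# Append, then drop duplicate (ts_code, trade_date) pairs, keeping the newest row
combined = (
    pd.concat([existing, incoming], ignore_index=True)
    .drop_duplicates(subset=["ts_code", "trade_date"], keep="last")
    .reset_index(drop=True)
)
print(len(combined))  # 2 rows: one per unique (ts_code, trade_date)
```

Deduplicating on the natural key makes incremental syncs idempotent: re-fetching an overlapping date range cannot produce duplicate rows in storage.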
### Threading & Concurrency
- Use `ThreadPoolExecutor` for I/O-bound tasks (API calls)
- Implement a stop flag for graceful shutdown: `threading.Event()`
- Default worker count for data sync: 10
- On error, always call `executor.shutdown(wait=False, cancel_futures=True)`
### Logging
- Use print statements with a prefix: `[module_name] message`
- Error format: `[ERROR] context: exception`
- Progress: use `tqdm` in loops
### Testing
- Use the **pytest** framework
- Mock external dependencies (the Tushare API)
- Use `@pytest.fixture` for test setup
- Patch at the import location: `patch('src.data.sync.Storage')`
- Test both success and error cases
### Date Format
- Dates are `YYYYMMDD` strings
- Helper functions: `get_today_date()`, `get_next_date()`
- Default start date for a full sync: `20180101`
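The date helpers reduce to a few lines each (a sketch; the canonical implementations live in sync.py):

```python
from datetime import datetime, timedelta


def get_today_date() -> str:
    """Today's date as a YYYYMMDD string."""
    return datetime.now().strftime("%Y%m%d")


def get_next_date(date_str: str) -> str:
    """The day after date_str, both as YYYYMMDD strings."""
    parsed = datetime.strptime(date_str, "%Y%m%d")
    return (parsed + timedelta(days=1)).strftime("%Y%m%d")


print(get_next_date("20241231"))  # 20250101 - rolls over the year boundary
```

Parsing through `datetime` rather than string arithmetic handles month, year, and leap-day boundaries for free.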
### Dependencies
Key packages:
- `pandas>=2.0.0` - data manipulation
- `numpy>=1.24.0` - numerical computing
- `tushare>=2.0.0` - A-share data API
- `pydantic>=2.0.0` and `pydantic-settings>=2.0.0` - configuration
- `tqdm>=4.65.0` - progress bars
- `pytest` - testing (dev)
### Environment Variables
Create `config/.env.local`:
```bash
TUSHARE_TOKEN=your_token_here
DATA_PATH=data
RATE_LIMIT=100
THREADS=10
```
## Common Tasks
```bash
# Sync all stocks (incremental)
uv run python -c "from src.data.sync import sync_all; sync_all()"
# Force a full sync
uv run python -c "from src.data.sync import sync_all; sync_all(force_full=True)"
# Custom worker count
uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)"
```


@@ -2,6 +2,34 @@
A-share quantitative investment framework
## Quick Start
### Install Dependencies
**⚠️ This project mandates uv as the Python package manager; direct use of the `python` and `pip` commands is forbidden.**
```bash
# Install with uv (mandatory)
uv pip install -e .
```
### Data Sync
```bash
# Incremental sync (starts automatically from the latest local date)
uv run python -c "from src.data.sync import sync_all; sync_all()"
# Full sync (from 20180101)
uv run python -c "from src.data.sync import sync_all; sync_all(force_full=True)"
# Custom worker count
uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)"
```
## Documentation
- [Data Sync Module](docs/data_sync.md) - detailed usage guide for data synchronization
## Modules
- `data/` - data acquisition

config/.env.test Normal file

@@ -0,0 +1,43 @@
# ===========================================
# ProStock local environment configuration
# This file is not committed to version control
# ===========================================
# Database configuration
DATABASE_HOST=localhost
DATABASE_PORT=5432
DATABASE_NAME=prostock
DATABASE_USER=postgres
DATABASE_PASSWORD=your_password
# API key configuration (important: do not leak)
API_KEY=your_api_key_here
SECRET_KEY=your_secret_key_here
# Redis configuration (optional)
REDIS_HOST=localhost
REDIS_PORT=6379
# Application configuration
APP_ENV=development
APP_DEBUG=true
APP_PORT=8000
# ===========================================
# Tushare data collection configuration
# ===========================================
# Tushare Pro API token (important: register at https://tushare.pro to obtain one)
TUSHARE_TOKEN=3a0741c702ee7e5e5f2bf1f0846bafaafe4e320833240b2a7e4a685f
# Data storage path
DATA_PATH=./data
# Rate limit (requests per minute, default 100)
RATE_LIMIT=200
# Thread count (default 2)
THREADS=2
ROOT_PATH=D:/PyProject/ProStock

Binary file not shown.

pyproject.toml Normal file

@@ -0,0 +1,21 @@
[project]
name = "ProStock"
version = "0.1.0"
description = "A股量化投资框架"
readme = "README.md"
requires-python = ">=3.10,<3.14"
dependencies = [
"pandas>=2.0.0",
"numpy>=1.24.0",
"tushare>=2.0.0",
"pydantic>=2.0.0",
"pydantic-settings>=2.0.0",
"tqdm>=4.65.0",
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.uv]
package = false


@@ -40,25 +40,26 @@ class TushareClient:
self._api = ts.pro_api(self.token)
return self._api
def query(self, api_name: str, timeout: float = 30.0, **params) -> pd.DataFrame:
def query(self, api_name: str, timeout: float | None = None, **params) -> pd.DataFrame:
"""Execute API query with rate limiting and retry.
Args:
api_name: API name ('daily', 'pro_bar', etc.)
timeout: Timeout for rate limiting
timeout: Timeout for rate limiting (None = wait indefinitely)
**params: API parameters
Returns:
DataFrame with query results
"""
# Acquire rate limit token
# Acquire rate limit token (None = wait indefinitely)
timeout = timeout if timeout is not None else float('inf')
success, wait_time = self.rate_limiter.acquire(timeout=timeout)
if not success:
raise RuntimeError(f"Rate limit exceeded after {timeout}s timeout")
if wait_time > 0:
print(f"[RateLimit] Waited {wait_time:.2f}s for token")
pass # Silent wait
# Execute with retry
max_retries = 3
@@ -83,9 +84,6 @@ class TushareClient:
api = self._get_api()
data = api.query(api_name, **params)
available = self.rate_limiter.get_available_tokens()
print(f"[Tushare] {api_name} | tokens: {available:.0f}/{self.rate_limiter.capacity}")
return data
except Exception as e:


@@ -63,18 +63,10 @@ def get_daily(
else:
factors_str = factors
params["factors"] = factors_str
print(f"[get_daily] factors param: '{factors_str}'")
if adjfactor:
params["adjfactor"] = "True"
# Fetch data using pro_bar (supports factors like tor, vr)
print(f"[get_daily] Query params: {params}")
data = client.query("pro_bar", **params)
if not data.empty:
print(f"[get_daily] Returned columns: {data.columns.tolist()}")
print(f"[get_daily] Sample row: {data.iloc[0].to_dict()}")
else:
print(f"[get_daily] No data for ts_code={ts_code}")
return data


@@ -2,6 +2,7 @@
This module provides a thread-safe token bucket algorithm for rate limiting.
"""
import time
import threading
from typing import Optional
@@ -11,14 +12,12 @@ from dataclasses import dataclass, field
@dataclass
class RateLimiterStats:
"""Statistics for rate limiter."""
total_requests: int = 0
successful_requests: int = 0
denied_requests: int = 0
total_wait_time: float = 0.0
current_tokens: float = field(default=None, init=False)
def __post_init__(self):
self.current_tokens = field(default=None)
current_tokens: Optional[float] = None
class TokenBucketRateLimiter:
@@ -54,13 +53,13 @@ class TokenBucketRateLimiter:
self._stats = RateLimiterStats()
self._stats.current_tokens = self.tokens
def acquire(self, timeout: float = 30.0) -> tuple[bool, float]:
def acquire(self, timeout: float = float("inf")) -> tuple[bool, float]:
"""Acquire a token from the bucket.
Blocks until a token is available or timeout expires.
Args:
timeout: Maximum time to wait for a token in seconds
timeout: Maximum time to wait for a token in seconds (default: inf)
Returns:
Tuple of (success, wait_time):
@@ -84,32 +83,58 @@ class TokenBucketRateLimiter:
tokens_needed = 1 - self.tokens
time_to_refill = tokens_needed / self.refill_rate
if time_to_refill > timeout:
# Check if we can wait for the token within timeout
# Handle infinite timeout specially
is_infinite_timeout = timeout == float("inf")
if not is_infinite_timeout and time_to_refill > timeout:
self._stats.total_requests += 1
self._stats.denied_requests += 1
return False, timeout
# Wait for tokens
self._lock.release()
time.sleep(time_to_refill)
self._lock.acquire()
# Wait for tokens - loop until we get one or timeout
while True:
# Calculate remaining time we can wait
elapsed = time.monotonic() - start_time
remaining_timeout = (
timeout - elapsed if not is_infinite_timeout else float("inf")
)
wait_time = time.monotonic() - start_time
# Check if we've exceeded timeout
if not is_infinite_timeout and remaining_timeout <= 0:
self._stats.total_requests += 1
self._stats.denied_requests += 1
return False, elapsed
with self._lock:
# Calculate wait time for next token
tokens_needed = max(0, 1 - self.tokens)
time_to_wait = (
tokens_needed / self.refill_rate if tokens_needed > 0 else 0.1
)
# If we can't wait long enough, fail
if not is_infinite_timeout and time_to_wait > remaining_timeout:
self._stats.total_requests += 1
self._stats.denied_requests += 1
return False, elapsed
# Wait outside the lock to allow other threads to refill
self._lock.release()
time.sleep(
min(time_to_wait, 0.1)
) # Cap wait to 100ms to check frequently
self._lock.acquire()
# Refill and check again
self._refill()
if self.tokens >= 1:
self.tokens -= 1
wait_time = time.monotonic() - start_time
self._stats.total_requests += 1
self._stats.successful_requests += 1
self._stats.total_wait_time += wait_time
self._stats.current_tokens = self.tokens
return True, wait_time
self._stats.total_requests += 1
self._stats.denied_requests += 1
return False, wait_time
def acquire_nonblocking(self) -> tuple[bool, float]:
"""Try to acquire a token without blocking.

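For reference, the token-bucket mechanics in the diff above reduce to this single-threaded sketch (`capacity`, `refill_rate`, and the monotonic-clock refill mirror the real class; locking and timeouts are omitted):

```python
import time


class TokenBucket:
    """Minimal token bucket: holds up to `capacity` tokens, refilled at `refill_rate`/s."""

    def __init__(self, capacity: float, refill_rate: float) -> None:
        self.capacity = capacity
        self.refill_rate = refill_rate
        self.tokens = capacity
        self.last = time.monotonic()

    def _refill(self) -> None:
        now = time.monotonic()
        self.tokens = min(self.capacity, self.tokens + (now - self.last) * self.refill_rate)
        self.last = now

    def try_acquire(self) -> bool:
        """Take one token if available; never blocks."""
        self._refill()
        if self.tokens >= 1:
            self.tokens -= 1
            return True
        return False


bucket = TokenBucket(capacity=2, refill_rate=100.0)
assert bucket.try_acquire() and bucket.try_acquire()  # burst up to capacity
assert not bucket.try_acquire()                       # bucket drained
time.sleep(0.05)                                      # refills ~5 tokens, capped at 2
assert bucket.try_acquire()
```

The blocking `acquire` in the diff builds on exactly this: it loops, sleeping in short capped intervals, until `try_acquire`-style logic succeeds or the (possibly infinite) timeout expires.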

@@ -3,10 +3,19 @@
Fetch basic stock information including code, name, listing date, etc.
This is a special interface - call once to get all stocks (listed and delisted).
"""
import os
import pandas as pd
from pathlib import Path
from typing import Optional, Literal, List
from src.data.client import TushareClient
from src.data.storage import Storage
from src.data.config import get_config
# CSV file path for stock basic data
def _get_csv_path() -> Path:
"""Get the CSV file path for stock basic data."""
cfg = get_config()
return cfg.data_path_resolved / "stock_basic.csv"
def get_stock_basic(
@@ -75,20 +84,19 @@ def sync_all_stocks() -> pd.DataFrame:
Returns:
pd.DataFrame with all stock information
"""
# Initialize storage
storage = Storage()
csv_path = _get_csv_path()
# Check if already exists
if storage.exists("stock_basic"):
print("[sync_all_stocks] stock_basic data already exists, skipping...")
return storage.load("stock_basic")
# Check if CSV file already exists
if csv_path.exists():
print("[sync_all_stocks] stock_basic.csv already exists, skipping...")
return pd.read_csv(csv_path)
print("[sync_all_stocks] Fetching all stocks (listed and delisted)...")
# Fetch all stocks - explicitly get all list_status values
# API default is L (listed), so we need to fetch all statuses
client = TushareClient()
all_data = []
for status in ["L", "D", "P", "G"]:
print(f"[sync_all_stocks] Fetching stocks with status: {status}")
@@ -96,21 +104,20 @@ def sync_all_stocks() -> pd.DataFrame:
print(f"[sync_all_stocks] Fetched {len(data)} stocks with status {status}")
if not data.empty:
all_data.append(data)
if not all_data:
print("[sync_all_stocks] No stock data fetched")
return pd.DataFrame()
# Combine all data
data = pd.concat(all_data, ignore_index=True)
# Remove duplicates if any
data = data.drop_duplicates(subset=["ts_code"], keep="first")
print(f"[sync_all_stocks] Total unique stocks: {len(data)}")
# Save to storage
storage.save("stock_basic", data, mode="replace")
print(f"[sync_all_stocks] Saved {len(data)} stocks to local storage")
# Save to CSV
data.to_csv(csv_path, index=False, encoding="utf-8-sig")
print(f"[sync_all_stocks] Saved {len(data)} stocks to {csv_path}")
return data

tests/test_sync.py Normal file

@@ -0,0 +1,255 @@
"""Tests for data synchronization module.
Tests the sync module's full/incremental sync logic for daily data:
- Full sync when local data doesn't exist (from 20180101)
- Incremental sync when local data exists (from last_date + 1)
- Data integrity validation
"""
import pytest
import pandas as pd
from unittest.mock import Mock, patch, MagicMock
from datetime import datetime, timedelta
from src.data.sync import (
    DataSync,
    sync_all,
    get_today_date,
    get_next_date,
    DEFAULT_START_DATE,
)
from src.data.storage import Storage
from src.data.client import TushareClient
class TestDateUtilities:
"""Test date utility functions."""
def test_get_today_date_format(self):
"""Test today date is in YYYYMMDD format."""
result = get_today_date()
assert len(result) == 8
assert result.isdigit()
def test_get_next_date(self):
"""Test getting next date."""
result = get_next_date("20240101")
assert result == "20240102"
def test_get_next_date_year_end(self):
"""Test getting next date across year boundary."""
result = get_next_date("20241231")
assert result == "20250101"
def test_get_next_date_month_end(self):
"""Test getting next date across month boundary."""
result = get_next_date("20240131")
assert result == "20240201"
class TestDataSync:
"""Test DataSync class functionality."""
@pytest.fixture
def mock_storage(self):
"""Create a mock storage instance."""
storage = Mock(spec=Storage)
storage.exists = Mock(return_value=False)
storage.load = Mock(return_value=pd.DataFrame())
storage.save = Mock(return_value={"status": "success", "rows": 0})
return storage
@pytest.fixture
def mock_client(self):
"""Create a mock client instance."""
return Mock(spec=TushareClient)
def test_get_all_stock_codes_from_daily(self, mock_storage):
"""Test getting stock codes from daily data."""
with patch('src.data.sync.Storage', return_value=mock_storage):
sync = DataSync()
sync.storage = mock_storage
mock_storage.load.return_value = pd.DataFrame({
'ts_code': ['000001.SZ', '000001.SZ', '600000.SH'],
})
codes = sync.get_all_stock_codes()
assert len(codes) == 2
assert '000001.SZ' in codes
assert '600000.SH' in codes
def test_get_all_stock_codes_fallback(self, mock_storage):
"""Test fallback to stock_basic when daily is empty."""
with patch('src.data.sync.Storage', return_value=mock_storage):
sync = DataSync()
sync.storage = mock_storage
# First call (daily) returns empty, second call (stock_basic) returns data
mock_storage.load.side_effect = [
pd.DataFrame(), # daily empty
pd.DataFrame({'ts_code': ['000001.SZ', '600000.SH']}), # stock_basic
]
codes = sync.get_all_stock_codes()
assert len(codes) == 2
def test_get_global_last_date(self, mock_storage):
"""Test getting global last date."""
with patch('src.data.sync.Storage', return_value=mock_storage):
sync = DataSync()
sync.storage = mock_storage
mock_storage.load.return_value = pd.DataFrame({
'ts_code': ['000001.SZ', '600000.SH'],
'trade_date': ['20240102', '20240103'],
})
last_date = sync.get_global_last_date()
assert last_date == '20240103'
def test_get_global_last_date_empty(self, mock_storage):
"""Test getting last date from empty storage."""
with patch('src.data.sync.Storage', return_value=mock_storage):
sync = DataSync()
sync.storage = mock_storage
mock_storage.load.return_value = pd.DataFrame()
last_date = sync.get_global_last_date()
assert last_date is None
def test_sync_single_stock(self, mock_storage):
"""Test syncing a single stock."""
with patch('src.data.sync.Storage', return_value=mock_storage):
with patch('src.data.sync.get_daily', return_value=pd.DataFrame({
'ts_code': ['000001.SZ'],
'trade_date': ['20240102'],
})):
sync = DataSync()
sync.storage = mock_storage
result = sync.sync_single_stock(
ts_code='000001.SZ',
start_date='20240101',
end_date='20240102',
)
assert isinstance(result, pd.DataFrame)
assert len(result) == 1
def test_sync_single_stock_empty(self, mock_storage):
"""Test syncing a stock with no data."""
with patch('src.data.sync.Storage', return_value=mock_storage):
with patch('src.data.sync.get_daily', return_value=pd.DataFrame()):
sync = DataSync()
sync.storage = mock_storage
result = sync.sync_single_stock(
ts_code='INVALID.SZ',
start_date='20240101',
end_date='20240102',
)
assert result.empty
class TestSyncAll:
"""Test sync_all function."""
def test_full_sync_mode(self, mock_storage):
"""Test full sync mode when force_full=True."""
with patch('src.data.sync.Storage', return_value=mock_storage):
with patch('src.data.sync.get_daily', return_value=pd.DataFrame()):
sync = DataSync()
sync.storage = mock_storage
sync.sync_single_stock = Mock(return_value=pd.DataFrame())
mock_storage.load.return_value = pd.DataFrame({
'ts_code': ['000001.SZ'],
})
result = sync.sync_all(force_full=True)
# Verify sync_single_stock was called with default start date
sync.sync_single_stock.assert_called_once()
call_args = sync.sync_single_stock.call_args
assert call_args[1]['start_date'] == DEFAULT_START_DATE
def test_incremental_sync_mode(self, mock_storage):
"""Test incremental sync mode when data exists."""
with patch('src.data.sync.Storage', return_value=mock_storage):
sync = DataSync()
sync.storage = mock_storage
sync.sync_single_stock = Mock(return_value=pd.DataFrame())
# Mock existing data with last date
mock_storage.load.side_effect = [
pd.DataFrame({
'ts_code': ['000001.SZ'],
'trade_date': ['20240102'],
}), # get_all_stock_codes
pd.DataFrame({
'ts_code': ['000001.SZ'],
'trade_date': ['20240102'],
}), # get_global_last_date
]
result = sync.sync_all(force_full=False)
# Verify sync_single_stock was called with next date
sync.sync_single_stock.assert_called_once()
call_args = sync.sync_single_stock.call_args
assert call_args[1]['start_date'] == '20240103'
def test_manual_start_date(self, mock_storage):
"""Test sync with manual start date."""
with patch('src.data.sync.Storage', return_value=mock_storage):
sync = DataSync()
sync.storage = mock_storage
sync.sync_single_stock = Mock(return_value=pd.DataFrame())
mock_storage.load.return_value = pd.DataFrame({
'ts_code': ['000001.SZ'],
})
result = sync.sync_all(force_full=False, start_date='20230601')
sync.sync_single_stock.assert_called_once()
call_args = sync.sync_single_stock.call_args
assert call_args[1]['start_date'] == '20230601'
def test_no_stocks_found(self, mock_storage):
"""Test sync when no stocks are found."""
with patch('src.data.sync.Storage', return_value=mock_storage):
sync = DataSync()
sync.storage = mock_storage
mock_storage.load.return_value = pd.DataFrame()
result = sync.sync_all()
assert result == {}
class TestSyncAllConvenienceFunction:
"""Test sync_all convenience function."""
def test_sync_all_function(self):
"""Test sync_all convenience function."""
with patch('src.data.sync.DataSync') as MockSync:
mock_instance = Mock()
mock_instance.sync_all.return_value = {}
MockSync.return_value = mock_instance
result = sync_all(force_full=True)
MockSync.assert_called_once()
mock_instance.sync_all.assert_called_once_with(
force_full=True,
start_date=None,
end_date=None,
)
if __name__ == '__main__':
pytest.main([__file__, '-v'])