diff --git a/.gitignore b/.gitignore index 77773ca..e610afc 100644 --- a/.gitignore +++ b/.gitignore @@ -71,3 +71,6 @@ cover/ *.temp tmp/ temp/ + +# 数据目录(允许跟踪) +data/ diff --git a/.kilocode/rules/project_rules.md b/.kilocode/rules/project_rules.md index f10b51c..4887b93 100644 --- a/.kilocode/rules/project_rules.md +++ b/.kilocode/rules/project_rules.md @@ -17,6 +17,41 @@ 3. **文档字符串**: 使用 Google 风格的 docstring 4. **测试覆盖**: 关键业务逻辑应有对应的单元测试 +## Python 运行规范 + +**⚠️ 本项目强制使用 uv 作为 Python 包管理器和运行工具。禁止直接使用 `python` 或 `pip` 命令。** + +### 禁止的命令 ❌ + +```bash +# 禁止直接使用 python +python -c "..." # 禁止! +python script.py # 禁止! +python -m pytest # 禁止! +python -m pip install # 禁止! + +# 禁止直接使用 pip +pip install -e . # 禁止! +pip install package # 禁止! +pip list # 禁止! +``` + +### 正确的 uv 用法 ✅ + +```bash +# 运行 Python 代码 +uv run python -c "..." # ✅ 正确 +uv run python script.py # ✅ 正确 + +# 安装依赖 +uv pip install -e . # ✅ 正确 +uv pip install package # ✅ 正确 + +# 运行测试 +uv run pytest # ✅ 正确 +uv run pytest tests/test_sync.py # ✅ 正确 +``` + ## 目录结构规范 ``` diff --git a/.kilocode/rules/security_rules.md b/.kilocode/rules/security_rules.md index 14dc61f..0a03c20 100644 --- a/.kilocode/rules/security_rules.md +++ b/.kilocode/rules/security_rules.md @@ -14,6 +14,18 @@ 4. **禁止搜索** - 不得在 `config/` 目录下进行任何搜索操作 5. **禁止执行** - 不得在 `config/` 目录下执行任何命令 +### 绝对禁止原则 + +**即使无法完成任务,也严禁读取 `config/` 目录下的任何文件。** + +这是不可妥协的安全红线: +- ❌ **禁止**:为了完成任务而读取配置文件 +- ❌ **禁止**:以调试为目的查看配置文件 +- ❌ **禁止**:以验证配置正确性为由读取文件 +- ❌ **禁止**:任何理由、任何借口、任何情况下的访问 + +**如果任务需要配置信息,必须通过 `src/config/` 模块提供的 API 获取,而不是直接读取文件。** + 所有配置读取必须通过集中管理的配置模块(`src/config/`)进行。**`config/` 与 `src/config/` 是完全不同的目录,前者受保护,后者是配置模块代码目录**。 ### 目录结构说明 @@ -216,12 +228,14 @@ api_key = settings.api_key 3. **安全漏洞评级**:标记为高优先级安全漏洞 4. **构建阻断**:CI/CD 流水线自动失败 5. **审计日志记录**:记录违规行为用于审计追踪 +6. 
**立即终止**:任何尝试读取 `config/` 目录的操作将被立即阻止 ### 违规严重程度分类 | 等级 | 违规类型 | 处罚措施 | |------|---------|---------| -| 严重 | 故意读取敏感配置文件(如 `.env`) | 代码审查拒绝、团队通知 | +| 严重 | 故意读取敏感配置文件(如 `.env`) | 代码审查拒绝、团队通知、立即阻止 | +| 严重 | 以"无法完成任务"为由读取配置文件 | 代码审查拒绝、团队通知、立即阻止 | | 高 | 使用工具访问 `config/` 目录 | 代码审查拒绝、要求整改 | | 中 | 在代码中硬编码配置路径 | 要求修改、代码审查标记 | | 低 | 潜在风险操作(需人工审核) | 代码审查提醒 | diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..bca8d6e --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,242 @@ +# ProStock 代理指南 + +A股量化投资框架 - Python 项目,用于量化股票投资分析。 + +## 构建/检查/测试命令 + +**⚠️ 重要:本项目强制使用 uv 作为 Python 包管理器和运行工具。禁止直接使用 `python` 或 `pip` 命令。** + +```bash +# 安装依赖(必须使用 uv) +uv pip install -e . + +# 运行所有测试 +uv run pytest + +# 运行单个测试文件 +uv run pytest tests/test_sync.py + +# 运行单个测试类 +uv run pytest tests/test_sync.py::TestDataSync + +# 运行单个测试方法 +uv run pytest tests/test_sync.py::TestDataSync::test_get_all_stock_codes_from_daily + +# 使用详细输出运行 +uv run pytest -v + +# 运行覆盖率测试(如果安装了 pytest-cov) +uv run pytest --cov=src --cov-report=term-missing +``` + +### 禁止的命令 ❌ + +以下命令在本项目中**严格禁止**: + +```bash +# 禁止直接使用 python +python -c "..." # 禁止! +python script.py # 禁止! +python -m pytest # 禁止! +python -m pip install # 禁止! + +# 禁止直接使用 pip +pip install -e . # 禁止! +pip install package # 禁止! +pip list # 禁止! +``` + +### 正确的 uv 用法 ✅ + +```bash +# 运行 Python 代码 +uv run python -c "..." # ✅ 正确 +uv run python script.py # ✅ 正确 + +# 安装依赖 +uv pip install -e . 
# ✅ 正确 +uv pip install package # ✅ 正确 + +# 运行测试 +uv run pytest # ✅ 正确 +uv run pytest tests/test_sync.py # ✅ 正确 +``` + +## 项目结构 + +``` +ProStock/ +├── src/ # 源代码 +│ ├── data/ # 数据采集模块 +│ │ ├── __init__.py +│ │ ├── client.py # Tushare API 客户端,带速率限制 +│ │ ├── config.py # 配置(pydantic-settings) +│ │ ├── daily.py # 日线市场数据 +│ │ ├── rate_limiter.py # 令牌桶速率限制器 +│ │ ├── stock_basic.py # 股票基本信息 +│ │ ├── storage.py # HDF5 存储管理器 +│ │ └── sync.py # 数据同步 +│ ├── config/ # 全局配置 +│ │ ├── __init__.py +│ │ └── settings.py # 应用设置(pydantic-settings) +│ └── __init__.py +├── tests/ # 测试文件 +│ ├── test_sync.py +│ └── test_daily.py +├── config/ # 配置文件 +│ └── .env.local # 环境变量(不在 git 中) +├── data/ # 数据存储(HDF5 文件) +├── docs/ # 文档 +├── pyproject.toml # 项目配置 +└── README.md +``` + +## 代码风格指南 + +### Python 版本 +- **需要 Python 3.10+** +- 使用现代 Python 特性(match/case、海象运算符、类型提示) + +### 导入 +```python +# 标准库优先 +import os +import time +from datetime import datetime, timedelta +from pathlib import Path +from typing import Optional, Dict, Callable +from concurrent.futures import ThreadPoolExecutor +import threading + +# 第三方包 +import pandas as pd +import numpy as np +from tqdm import tqdm +from pydantic_settings import BaseSettings + +# 本地模块(使用来自 src 的绝对导入) +from src.data.client import TushareClient +from src.data.storage import Storage +from src.config.settings import get_settings +``` + +### 类型提示 +- **始终使用类型提示** 用于函数参数和返回值 +- 对可空类型使用 `Optional[X]` +- 当可用时使用现代联合语法 `X | Y`(Python 3.10+) +- 从 `typing` 导入类型:`Optional`、`Dict`、`Callable` 等 + +```python +def sync_single_stock( + self, + ts_code: str, + start_date: str, + end_date: str, +) -> pd.DataFrame: + ... +``` + +### 文档字符串 +- 使用 **Google 风格文档字符串** +- 包含 Args、Returns 部分 +- 第一行保持简短摘要 + +```python +def get_next_date(date_str: str) -> str: + """获取给定日期之后的下一天。 + + Args: + date_str: YYYYMMDD 格式的日期 + + Returns: + YYYYMMDD 格式的下一天日期 + """ + ... 
+``` + +### 命名约定 +- 变量、函数、方法使用 `snake_case` +- 类使用 `PascalCase` +- 常量使用 `UPPER_CASE` +- 私有方法:`_leading_underscore` +- 受保护属性:`_single_underscore` + +### 错误处理 +- 使用特定的异常,不要使用裸 `except:` +- 使用上下文记录错误:`print(f"[ERROR] 上下文: {e}")` +- 对 API 调用使用指数退避重试逻辑 +- 在关键错误时立即停止(设置停止标志) + +```python +try: + data = api.query(...) +except Exception as e: + print(f"[ERROR] 获取 {ts_code} 失败: {e}") + raise # 记录后重新抛出 +``` + +### 配置 +- 对所有配置使用 **pydantic-settings** +- 从 `config/.env.local` 文件加载 +- 环境变量自动转换:`tushare_token` → `TUSHARE_TOKEN` +- 对配置单例使用 `@lru_cache()` + +### 数据存储 +- 通过 `pandas.HDFStore` 使用 **HDF5 格式** 进行持久化 +- 存储在 `data/` 目录中(通过 `DATA_PATH` 环境变量配置) +- 对可追加数据集使用 `format="table"` +- 追加时处理重复项:`drop_duplicates(subset=[...])` + +### 线程与并发 +- 对 I/O 密集型任务(API 调用)使用 `ThreadPoolExecutor` +- 实现停止标志以实现优雅关闭:`threading.Event()` +- 数据同步默认工作线程数:10 +- 出错时始终使用 `executor.shutdown(wait=False, cancel_futures=True)` + +### 日志记录 +- 使用带前缀的 print 语句:`[模块名] 消息` +- 错误格式:`[ERROR] 上下文: 异常` +- 进度:循环中使用 `tqdm` + +### 测试 +- 使用 **pytest** 框架 +- 模拟外部依赖(Tushare API) +- 使用 `@pytest.fixture` 进行测试设置 +- 在导入位置打补丁:`patch('src.data.sync.Storage')` +- 测试成功和错误两种情况 + +### 日期格式 +- 使用 `YYYYMMDD` 字符串格式表示日期 +- 辅助函数:`get_today_date()`、`get_next_date()` +- 完全同步的默认开始日期:`20180101` + +### 依赖项 +关键包: +- `pandas>=2.0.0` - 数据处理 +- `numpy>=1.24.0` - 数值计算 +- `tushare>=2.0.0` - A股数据 API +- `pydantic>=2.0.0`、`pydantic-settings>=2.0.0` - 配置 +- `tqdm>=4.65.0` - 进度条 +- `pytest` - 测试(开发) + +### 环境变量 +创建 `config/.env.local`: +```bash +TUSHARE_TOKEN=your_token_here +DATA_PATH=data +RATE_LIMIT=100 +THREADS=10 +``` + +## 常见任务 + +```bash +# 同步所有股票(增量) +uv run python -c "from src.data.sync import sync_all; sync_all()" + +# 强制完全同步 +uv run python -c "from src.data.sync import sync_all; sync_all(force_full=True)" + +# 自定义线程数 +uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)" +``` diff --git a/README.md b/README.md index c18e334..3954294 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,34 @@ A股量化投资框架 +## 快速开始 + +### 
安装依赖
+
+**⚠️ 本项目强制使用 uv 作为 Python 包管理器,禁止直接使用 `python` 或 `pip` 命令。**
+
+```bash
+# 使用 uv 安装(必须)
+uv pip install -e .
+```
+
+### 数据同步
+
+```bash
+# 增量同步(自动从最新日期开始)
+uv run python -c "from src.data.sync import sync_all; sync_all()"
+
+# 全量同步(从 20180101 开始)
+uv run python -c "from src.data.sync import sync_all; sync_all(force_full=True)"
+
+# 自定义线程数
+uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)"
+```
+
+## 文档
+
+- [数据同步模块](docs/data_sync.md) - 详细的数据同步使用说明
+
 ## 模块
 
 - `data/` - 数据获取
diff --git a/config/.env.test b/config/.env.test
new file mode 100644
index 0000000..915e716
--- /dev/null
+++ b/config/.env.test
@@ -0,0 +1,43 @@
+# ===========================================
+# ProStock 测试环境配置模板
+# 注意:此文件会被提交到版本控制,严禁填写真实凭据
+# ===========================================
+
+# 数据库配置
+DATABASE_HOST=localhost
+DATABASE_PORT=5432
+DATABASE_NAME=prostock
+DATABASE_USER=postgres
+DATABASE_PASSWORD=your_password
+
+# API密钥配置(重要:不要泄露)
+API_KEY=your_api_key_here
+SECRET_KEY=your_secret_key_here
+
+# Redis配置(可选)
+REDIS_HOST=localhost
+REDIS_PORT=6379
+
+# 应用配置
+APP_ENV=development
+APP_DEBUG=true
+APP_PORT=8000
+
+# ===========================================
+# Tushare数据采集配置
+# ===========================================
+
+# Tushare Pro API Token(重要:去 https://tushare.pro 注册获取)
+TUSHARE_TOKEN=your_tushare_token_here
+
+# 数据存储路径
+DATA_PATH=./data
+
+# 限流配置:每分钟请求数(默认100)
+RATE_LIMIT=200
+
+# 线程数(默认2)
+THREADS=2
+
+ROOT_PATH=D:/PyProject/ProStock
+
diff --git a/data/stock_basic.h5 b/data/stock_basic.h5
deleted file mode 100644
index 5e4b36d..0000000
Binary files a/data/stock_basic.h5 and /dev/null differ
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..7d51d5e
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,21 @@
+[project]
+name = "ProStock"
+version = "0.1.0"
+description = "A股量化投资框架"
+readme = "README.md"
+requires-python = ">=3.10,<3.14"
+dependencies = [
+    "pandas>=2.0.0",
+    "numpy>=1.24.0",
+    
"tushare>=2.0.0", + "pydantic>=2.0.0", + "pydantic-settings>=2.0.0", + "tqdm>=4.65.0", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.uv] +package = false diff --git a/src/data/client.py b/src/data/client.py index 04ec97e..c66edf3 100644 --- a/src/data/client.py +++ b/src/data/client.py @@ -40,25 +40,26 @@ class TushareClient: self._api = ts.pro_api(self.token) return self._api - def query(self, api_name: str, timeout: float = 30.0, **params) -> pd.DataFrame: + def query(self, api_name: str, timeout: float = None, **params) -> pd.DataFrame: """Execute API query with rate limiting and retry. Args: api_name: API name ('daily', 'pro_bar', etc.) - timeout: Timeout for rate limiting + timeout: Timeout for rate limiting (None = wait indefinitely) **params: API parameters Returns: DataFrame with query results """ - # Acquire rate limit token + # Acquire rate limit token (None = wait indefinitely) + timeout = timeout if timeout is not None else float('inf') success, wait_time = self.rate_limiter.acquire(timeout=timeout) if not success: raise RuntimeError(f"Rate limit exceeded after {timeout}s timeout") if wait_time > 0: - print(f"[RateLimit] Waited {wait_time:.2f}s for token") + pass # Silent wait # Execute with retry max_retries = 3 @@ -83,9 +84,6 @@ class TushareClient: api = self._get_api() data = api.query(api_name, **params) - available = self.rate_limiter.get_available_tokens() - print(f"[Tushare] {api_name} | tokens: {available:.0f}/{self.rate_limiter.capacity}") - return data except Exception as e: diff --git a/src/data/daily.py b/src/data/daily.py index cd1b048..bc7c86e 100644 --- a/src/data/daily.py +++ b/src/data/daily.py @@ -63,18 +63,10 @@ def get_daily( else: factors_str = factors params["factors"] = factors_str - print(f"[get_daily] factors param: '{factors_str}'") if adjfactor: params["adjfactor"] = "True" # Fetch data using pro_bar (supports factors like tor, vr) - print(f"[get_daily] Query params: {params}") 
data = client.query("pro_bar", **params) - if not data.empty: - print(f"[get_daily] Returned columns: {data.columns.tolist()}") - print(f"[get_daily] Sample row: {data.iloc[0].to_dict()}") - else: - print(f"[get_daily] No data for ts_code={ts_code}") - return data diff --git a/src/data/rate_limiter.py b/src/data/rate_limiter.py index e9e8051..1ade893 100644 --- a/src/data/rate_limiter.py +++ b/src/data/rate_limiter.py @@ -2,6 +2,7 @@ This module provides a thread-safe token bucket algorithm for rate limiting. """ + import time import threading from typing import Optional @@ -11,14 +12,12 @@ from dataclasses import dataclass, field @dataclass class RateLimiterStats: """Statistics for rate limiter.""" + total_requests: int = 0 successful_requests: int = 0 denied_requests: int = 0 total_wait_time: float = 0.0 - current_tokens: float = field(default=None, init=False) - - def __post_init__(self): - self.current_tokens = field(default=None) + current_tokens: Optional[float] = None class TokenBucketRateLimiter: @@ -54,13 +53,13 @@ class TokenBucketRateLimiter: self._stats = RateLimiterStats() self._stats.current_tokens = self.tokens - def acquire(self, timeout: float = 30.0) -> tuple[bool, float]: + def acquire(self, timeout: float = float("inf")) -> tuple[bool, float]: """Acquire a token from the bucket. Blocks until a token is available or timeout expires. 
Args: - timeout: Maximum time to wait for a token in seconds + timeout: Maximum time to wait for a token in seconds (default: inf) Returns: Tuple of (success, wait_time): @@ -84,32 +83,58 @@ class TokenBucketRateLimiter: tokens_needed = 1 - self.tokens time_to_refill = tokens_needed / self.refill_rate - if time_to_refill > timeout: + # Check if we can wait for the token within timeout + # Handle infinite timeout specially + is_infinite_timeout = timeout == float("inf") + if not is_infinite_timeout and time_to_refill > timeout: self._stats.total_requests += 1 self._stats.denied_requests += 1 return False, timeout - # Wait for tokens - self._lock.release() - time.sleep(time_to_refill) - self._lock.acquire() + # Wait for tokens - loop until we get one or timeout + while True: + # Calculate remaining time we can wait + elapsed = time.monotonic() - start_time + remaining_timeout = ( + timeout - elapsed if not is_infinite_timeout else float("inf") + ) - wait_time = time.monotonic() - start_time + # Check if we've exceeded timeout + if not is_infinite_timeout and remaining_timeout <= 0: + self._stats.total_requests += 1 + self._stats.denied_requests += 1 + return False, elapsed - with self._lock: + # Calculate wait time for next token + tokens_needed = max(0, 1 - self.tokens) + time_to_wait = ( + tokens_needed / self.refill_rate if tokens_needed > 0 else 0.1 + ) + + # If we can't wait long enough, fail + if not is_infinite_timeout and time_to_wait > remaining_timeout: + self._stats.total_requests += 1 + self._stats.denied_requests += 1 + return False, elapsed + + # Wait outside the lock to allow other threads to refill + self._lock.release() + time.sleep( + min(time_to_wait, 0.1) + ) # Cap wait to 100ms to check frequently + self._lock.acquire() + + # Refill and check again self._refill() if self.tokens >= 1: self.tokens -= 1 + wait_time = time.monotonic() - start_time self._stats.total_requests += 1 self._stats.successful_requests += 1 self._stats.total_wait_time += 
wait_time self._stats.current_tokens = self.tokens return True, wait_time - self._stats.total_requests += 1 - self._stats.denied_requests += 1 - return False, wait_time - def acquire_nonblocking(self) -> tuple[bool, float]: """Try to acquire a token without blocking. diff --git a/src/data/stock_basic.py b/src/data/stock_basic.py index ee28852..b8f36b3 100644 --- a/src/data/stock_basic.py +++ b/src/data/stock_basic.py @@ -3,10 +3,19 @@ Fetch basic stock information including code, name, listing date, etc. This is a special interface - call once to get all stocks (listed and delisted). """ +import os import pandas as pd +from pathlib import Path from typing import Optional, Literal, List from src.data.client import TushareClient -from src.data.storage import Storage +from src.data.config import get_config + + +# CSV file path for stock basic data +def _get_csv_path() -> Path: + """Get the CSV file path for stock basic data.""" + cfg = get_config() + return cfg.data_path_resolved / "stock_basic.csv" def get_stock_basic( @@ -75,20 +84,19 @@ def sync_all_stocks() -> pd.DataFrame: Returns: pd.DataFrame with all stock information """ - # Initialize storage - storage = Storage() + csv_path = _get_csv_path() - # Check if already exists - if storage.exists("stock_basic"): - print("[sync_all_stocks] stock_basic data already exists, skipping...") - return storage.load("stock_basic") + # Check if CSV file already exists + if csv_path.exists(): + print("[sync_all_stocks] stock_basic.csv already exists, skipping...") + return pd.read_csv(csv_path) print("[sync_all_stocks] Fetching all stocks (listed and delisted)...") # Fetch all stocks - explicitly get all list_status values # API default is L (listed), so we need to fetch all statuses client = TushareClient() - + all_data = [] for status in ["L", "D", "P", "G"]: print(f"[sync_all_stocks] Fetching stocks with status: {status}") @@ -96,21 +104,20 @@ def sync_all_stocks() -> pd.DataFrame: print(f"[sync_all_stocks] Fetched 
{len(data)} stocks with status {status}") if not data.empty: all_data.append(data) - + if not all_data: print("[sync_all_stocks] No stock data fetched") return pd.DataFrame() - + # Combine all data data = pd.concat(all_data, ignore_index=True) # Remove duplicates if any data = data.drop_duplicates(subset=["ts_code"], keep="first") print(f"[sync_all_stocks] Total unique stocks: {len(data)}") - # Save to storage - storage.save("stock_basic", data, mode="replace") - - print(f"[sync_all_stocks] Saved {len(data)} stocks to local storage") + # Save to CSV + data.to_csv(csv_path, index=False, encoding="utf-8-sig") + print(f"[sync_all_stocks] Saved {len(data)} stocks to {csv_path}") return data diff --git a/tests/test_sync.py b/tests/test_sync.py new file mode 100644 index 0000000..f336562 --- /dev/null +++ b/tests/test_sync.py @@ -0,0 +1,255 @@ +"""Tests for data synchronization module. + +Tests the sync module's full/incremental sync logic for daily data: +- Full sync when local data doesn't exist (from 20180101) +- Incremental sync when local data exists (from last_date + 1) +- Data integrity validation +""" +import pytest +import pandas as pd +from unittest.mock import Mock, patch, MagicMock +from datetime import datetime, timedelta + +from src.data.sync import ( + DataSync, + sync_all, + get_today_date, + get_next_date, + DEFAULT_START_DATE, +) + + +class TestDateUtilities: + """Test date utility functions.""" + + def test_get_today_date_format(self): + """Test today date is in YYYYMMDD format.""" + result = get_today_date() + assert len(result) == 8 + assert result.isdigit() + + def test_get_next_date(self): + """Test getting next date.""" + result = get_next_date("20240101") + assert result == "20240102" + + def test_get_next_date_year_end(self): + """Test getting next date across year boundary.""" + result = get_next_date("20241231") + assert result == "20250101" + + def test_get_next_date_month_end(self): + """Test getting next date across month boundary.""" + 
result = get_next_date("20240131") + assert result == "20240201" + + +class TestDataSync: + """Test DataSync class functionality.""" + + @pytest.fixture + def mock_storage(self): + """Create a mock storage instance.""" + storage = Mock(spec=Storage) + storage.exists = Mock(return_value=False) + storage.load = Mock(return_value=pd.DataFrame()) + storage.save = Mock(return_value={"status": "success", "rows": 0}) + return storage + + @pytest.fixture + def mock_client(self): + """Create a mock client instance.""" + return Mock(spec=TushareClient) + + def test_get_all_stock_codes_from_daily(self, mock_storage): + """Test getting stock codes from daily data.""" + with patch('src.data.sync.Storage', return_value=mock_storage): + sync = DataSync() + sync.storage = mock_storage + + mock_storage.load.return_value = pd.DataFrame({ + 'ts_code': ['000001.SZ', '000001.SZ', '600000.SH'], + }) + + codes = sync.get_all_stock_codes() + + assert len(codes) == 2 + assert '000001.SZ' in codes + assert '600000.SH' in codes + + def test_get_all_stock_codes_fallback(self, mock_storage): + """Test fallback to stock_basic when daily is empty.""" + with patch('src.data.sync.Storage', return_value=mock_storage): + sync = DataSync() + sync.storage = mock_storage + + # First call (daily) returns empty, second call (stock_basic) returns data + mock_storage.load.side_effect = [ + pd.DataFrame(), # daily empty + pd.DataFrame({'ts_code': ['000001.SZ', '600000.SH']}), # stock_basic + ] + + codes = sync.get_all_stock_codes() + + assert len(codes) == 2 + + def test_get_global_last_date(self, mock_storage): + """Test getting global last date.""" + with patch('src.data.sync.Storage', return_value=mock_storage): + sync = DataSync() + sync.storage = mock_storage + + mock_storage.load.return_value = pd.DataFrame({ + 'ts_code': ['000001.SZ', '600000.SH'], + 'trade_date': ['20240102', '20240103'], + }) + + last_date = sync.get_global_last_date() + assert last_date == '20240103' + + def 
test_get_global_last_date_empty(self, mock_storage): + """Test getting last date from empty storage.""" + with patch('src.data.sync.Storage', return_value=mock_storage): + sync = DataSync() + sync.storage = mock_storage + + mock_storage.load.return_value = pd.DataFrame() + + last_date = sync.get_global_last_date() + assert last_date is None + + def test_sync_single_stock(self, mock_storage): + """Test syncing a single stock.""" + with patch('src.data.sync.Storage', return_value=mock_storage): + with patch('src.data.sync.get_daily', return_value=pd.DataFrame({ + 'ts_code': ['000001.SZ'], + 'trade_date': ['20240102'], + })): + sync = DataSync() + sync.storage = mock_storage + + result = sync.sync_single_stock( + ts_code='000001.SZ', + start_date='20240101', + end_date='20240102', + ) + + assert isinstance(result, pd.DataFrame) + assert len(result) == 1 + + def test_sync_single_stock_empty(self, mock_storage): + """Test syncing a stock with no data.""" + with patch('src.data.sync.Storage', return_value=mock_storage): + with patch('src.data.sync.get_daily', return_value=pd.DataFrame()): + sync = DataSync() + sync.storage = mock_storage + + result = sync.sync_single_stock( + ts_code='INVALID.SZ', + start_date='20240101', + end_date='20240102', + ) + + assert result.empty + + +class TestSyncAll: + """Test sync_all function.""" + + def test_full_sync_mode(self, mock_storage): + """Test full sync mode when force_full=True.""" + with patch('src.data.sync.Storage', return_value=mock_storage): + with patch('src.data.sync.get_daily', return_value=pd.DataFrame()): + sync = DataSync() + sync.storage = mock_storage + sync.sync_single_stock = Mock(return_value=pd.DataFrame()) + + mock_storage.load.return_value = pd.DataFrame({ + 'ts_code': ['000001.SZ'], + }) + + result = sync.sync_all(force_full=True) + + # Verify sync_single_stock was called with default start date + sync.sync_single_stock.assert_called_once() + call_args = sync.sync_single_stock.call_args + assert 
call_args[1]['start_date'] == DEFAULT_START_DATE + + def test_incremental_sync_mode(self, mock_storage): + """Test incremental sync mode when data exists.""" + with patch('src.data.sync.Storage', return_value=mock_storage): + sync = DataSync() + sync.storage = mock_storage + sync.sync_single_stock = Mock(return_value=pd.DataFrame()) + + # Mock existing data with last date + mock_storage.load.side_effect = [ + pd.DataFrame({ + 'ts_code': ['000001.SZ'], + 'trade_date': ['20240102'], + }), # get_all_stock_codes + pd.DataFrame({ + 'ts_code': ['000001.SZ'], + 'trade_date': ['20240102'], + }), # get_global_last_date + ] + + result = sync.sync_all(force_full=False) + + # Verify sync_single_stock was called with next date + sync.sync_single_stock.assert_called_once() + call_args = sync.sync_single_stock.call_args + assert call_args[1]['start_date'] == '20240103' + + def test_manual_start_date(self, mock_storage): + """Test sync with manual start date.""" + with patch('src.data.sync.Storage', return_value=mock_storage): + sync = DataSync() + sync.storage = mock_storage + sync.sync_single_stock = Mock(return_value=pd.DataFrame()) + + mock_storage.load.return_value = pd.DataFrame({ + 'ts_code': ['000001.SZ'], + }) + + result = sync.sync_all(force_full=False, start_date='20230601') + + sync.sync_single_stock.assert_called_once() + call_args = sync.sync_single_stock.call_args + assert call_args[1]['start_date'] == '20230601' + + def test_no_stocks_found(self, mock_storage): + """Test sync when no stocks are found.""" + with patch('src.data.sync.Storage', return_value=mock_storage): + sync = DataSync() + sync.storage = mock_storage + + mock_storage.load.return_value = pd.DataFrame() + + result = sync.sync_all() + + assert result == {} + + +class TestSyncAllConvenienceFunction: + """Test sync_all convenience function.""" + + def test_sync_all_function(self): + """Test sync_all convenience function.""" + with patch('src.data.sync.DataSync') as MockSync: + mock_instance = Mock() + 
mock_instance.sync_all.return_value = {} + MockSync.return_value = mock_instance + + result = sync_all(force_full=True) + + MockSync.assert_called_once() + mock_instance.sync_all.assert_called_once_with( + force_full=True, + start_date=None, + end_date=None, + ) + + +if __name__ == '__main__': + pytest.main([__file__, '-v'])