refactor: clean up debug logging, rework the rate limiter, switch storage format

- Remove debug logging from client.py and daily.py
- Rework rate_limiter to support infinite timeouts and more precise token acquisition
- Change stock_basic storage from HDF5 to CSV
- Update project rules: mandate uv, forbid reading the config/ directory
- Add data sync module sync.py plus tests
- Add !data/ to .gitignore to allow tracking data files
3  .gitignore  (vendored)
@@ -71,3 +71,6 @@ cover/
 *.temp
 tmp/
 temp/
+
+# Data directory (allow tracking)
+data/
@@ -17,6 +17,41 @@
 3. **Docstrings**: use Google-style docstrings
 4. **Test coverage**: key business logic should have corresponding unit tests
+
+## Python Execution Rules
+
+**⚠️ This project mandates uv as the Python package manager and runner. Never invoke `python` or `pip` directly.**
+
+### Forbidden commands ❌
+
+```bash
+# Do not invoke python directly
+python -c "..."           # forbidden!
+python script.py          # forbidden!
+python -m pytest          # forbidden!
+python -m pip install     # forbidden!
+
+# Do not invoke pip directly
+pip install -e .          # forbidden!
+pip install package       # forbidden!
+pip list                  # forbidden!
+```
+
+### Correct uv usage ✅
+
+```bash
+# Run Python code
+uv run python -c "..."            # ✅ correct
+uv run python script.py           # ✅ correct
+
+# Install dependencies
+uv pip install -e .               # ✅ correct
+uv pip install package            # ✅ correct
+
+# Run tests
+uv run pytest                     # ✅ correct
+uv run pytest tests/test_sync.py  # ✅ correct
+```
+
 ## Directory Structure Rules
 
 ```
@@ -14,6 +14,18 @@
 4. **No searching** - do not perform any search operations under the `config/` directory
 5. **No execution** - do not execute any commands under the `config/` directory
+
+### Absolute Prohibition Principle
+
+**Even if a task cannot be completed, reading any file under the `config/` directory is strictly forbidden.**
+
+This is a non-negotiable security red line:
+- ❌ **Forbidden**: reading config files in order to complete a task
+- ❌ **Forbidden**: viewing config files for debugging purposes
+- ❌ **Forbidden**: reading files on the pretext of verifying configuration correctness
+- ❌ **Forbidden**: access for any reason, under any excuse, in any circumstance
+
+**If a task needs configuration values, obtain them through the API exposed by the `src/config/` module instead of reading files directly.**
+
 All configuration reads must go through the centrally managed configuration module (`src/config/`). **`config/` and `src/config/` are entirely different directories: the former is protected, the latter holds the configuration module's code.**
 
 ### Directory Structure Notes
@@ -216,12 +228,14 @@ api_key = settings.api_key
 3. **Security vulnerability rating**: flagged as a high-priority security vulnerability
 4. **Build blocking**: the CI/CD pipeline fails automatically
 5. **Audit logging**: violations are recorded for audit trails
+6. **Immediate termination**: any attempt to read the `config/` directory is blocked immediately
 
 ### Violation Severity Levels
 
 | Level    | Violation type                                                    | Consequence                                               |
 |----------|-------------------------------------------------------------------|-----------------------------------------------------------|
-| Critical | Deliberately reading sensitive config files (e.g. `.env`)          | Code review rejection, team notification                  |
+| Critical | Deliberately reading sensitive config files (e.g. `.env`)          | Code review rejection, team notification, immediate block |
+| Critical | Reading config files on the pretext of "cannot complete the task"  | Code review rejection, team notification, immediate block |
 | High     | Using tools to access the `config/` directory                      | Code review rejection, remediation required               |
 | Medium   | Hard-coding config paths in code                                   | Change required, flagged in code review                   |
 | Low      | Potentially risky operations (human review needed)                 | Code review reminder                                      |
242  AGENTS.md  (new file)
@@ -0,0 +1,242 @@
# ProStock Agent Guide

A-share quantitative investment framework - a Python project for quantitative stock analysis.

## Build / Lint / Test Commands

**⚠️ Important: this project mandates uv as the Python package manager and runner. Never invoke `python` or `pip` directly.**

```bash
# Install dependencies (must use uv)
uv pip install -e .

# Run all tests
uv run pytest

# Run a single test file
uv run pytest tests/test_sync.py

# Run a single test class
uv run pytest tests/test_sync.py::TestDataSync

# Run a single test method
uv run pytest tests/test_sync.py::TestDataSync::test_get_all_stock_codes_from_daily

# Run with verbose output
uv run pytest -v

# Run with coverage (if pytest-cov is installed)
uv run pytest --cov=src --cov-report=term-missing
```

### Forbidden commands ❌

The following commands are **strictly forbidden** in this project:

```bash
# Do not invoke python directly
python -c "..."           # forbidden!
python script.py          # forbidden!
python -m pytest          # forbidden!
python -m pip install     # forbidden!

# Do not invoke pip directly
pip install -e .          # forbidden!
pip install package       # forbidden!
pip list                  # forbidden!
```

### Correct uv usage ✅

```bash
# Run Python code
uv run python -c "..."            # ✅ correct
uv run python script.py           # ✅ correct

# Install dependencies
uv pip install -e .               # ✅ correct
uv pip install package            # ✅ correct

# Run tests
uv run pytest                     # ✅ correct
uv run pytest tests/test_sync.py  # ✅ correct
```

## Project Structure

```
ProStock/
├── src/                      # Source code
│   ├── data/                 # Data collection module
│   │   ├── __init__.py
│   │   ├── client.py         # Tushare API client with rate limiting
│   │   ├── config.py         # Configuration (pydantic-settings)
│   │   ├── daily.py          # Daily market data
│   │   ├── rate_limiter.py   # Token bucket rate limiter
│   │   ├── stock_basic.py    # Basic stock information
│   │   ├── storage.py        # HDF5 storage manager
│   │   └── sync.py           # Data synchronization
│   ├── config/               # Global configuration
│   │   ├── __init__.py
│   │   └── settings.py       # Application settings (pydantic-settings)
│   └── __init__.py
├── tests/                    # Test files
│   ├── test_sync.py
│   └── test_daily.py
├── config/                   # Configuration files
│   └── .env.local            # Environment variables (not in git)
├── data/                     # Data storage (HDF5 files)
├── docs/                     # Documentation
├── pyproject.toml            # Project configuration
└── README.md
```

## Code Style Guide

### Python Version
- **Python 3.10+ required**
- Use modern Python features (match/case, the walrus operator, type hints)

### Imports
```python
# Standard library first
import os
import time
import threading
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional, Dict, Callable

# Third-party packages
import pandas as pd
import numpy as np
from tqdm import tqdm
from pydantic_settings import BaseSettings

# Local modules (absolute imports from src)
from src.data.client import TushareClient
from src.data.storage import Storage
from src.config.settings import get_settings
```

### Type Hints
- **Always use type hints** for function parameters and return values
- Use `Optional[X]` for nullable types
- Prefer the modern union syntax `X | Y` where available (Python 3.10+)
- Import types from `typing`: `Optional`, `Dict`, `Callable`, etc.

```python
def sync_single_stock(
    self,
    ts_code: str,
    start_date: str,
    end_date: str,
) -> pd.DataFrame:
    ...
```

### Docstrings
- Use **Google-style docstrings**
- Include Args and Returns sections
- Keep the first line a short summary

```python
def get_next_date(date_str: str) -> str:
    """Get the day after the given date.

    Args:
        date_str: Date in YYYYMMDD format

    Returns:
        The next day in YYYYMMDD format
    """
    ...
```

### Naming Conventions
- `snake_case` for variables, functions, and methods
- `PascalCase` for classes
- `UPPER_CASE` for constants
- Private methods: `_leading_underscore`
- Protected attributes: `_single_underscore`

### Error Handling
- Use specific exceptions; never use a bare `except:`
- Log errors with context: `print(f"[ERROR] context: {e}")`
- Use exponential-backoff retry logic for API calls
- Stop immediately on critical errors (set the stop flag)

```python
try:
    data = api.query(...)
except Exception as e:
    print(f"[ERROR] Failed to fetch {ts_code}: {e}")
    raise  # Re-raise after logging
```
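A minimal sketch of the exponential-backoff retry pattern described above. The `retry_with_backoff` helper and its parameter names are illustrative, not the project's actual implementation:

```python
import time


def retry_with_backoff(fetch, max_retries: int = 3, base_delay: float = 1.0):
    """Call fetch(), retrying with exponential backoff on failure.

    Delays grow as base_delay * 2**attempt: 1s, 2s, 4s, ...
    """
    for attempt in range(max_retries):
        try:
            return fetch()
        except Exception as e:
            if attempt == max_retries - 1:
                print(f"[ERROR] giving up after {max_retries} attempts: {e}")
                raise
            delay = base_delay * (2 ** attempt)
            print(f"[ERROR] attempt {attempt + 1} failed: {e}; retrying in {delay}s")
            time.sleep(delay)
```

The final attempt re-raises after logging, matching the error-handling rule above.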
### Configuration
- Use **pydantic-settings** for all configuration
- Load from the `config/.env.local` file
- Environment variables map automatically: `tushare_token` → `TUSHARE_TOKEN`
- Use `@lru_cache()` for the settings singleton
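The cached-singleton pattern can be sketched with the standard library alone; the field names below mirror the project's environment variables, but the `Settings` class here is a hand-rolled stand-in for the pydantic-settings model (which performs the env-var mapping automatically):

```python
import os
from dataclasses import dataclass
from functools import lru_cache


@dataclass(frozen=True)
class Settings:
    """Stand-in for the project's pydantic-settings model."""
    tushare_token: str = ""
    data_path: str = "data"
    rate_limit: int = 100


@lru_cache()
def get_settings() -> Settings:
    # pydantic-settings does this mapping automatically
    # (tushare_token -> TUSHARE_TOKEN); here it is spelled out by hand.
    return Settings(
        tushare_token=os.environ.get("TUSHARE_TOKEN", ""),
        data_path=os.environ.get("DATA_PATH", "data"),
        rate_limit=int(os.environ.get("RATE_LIMIT", "100")),
    )
```

Because of `@lru_cache()`, every caller receives the same `Settings` instance.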
### Data Storage
- Persist in **HDF5 format** via `pandas.HDFStore`
- Stored in the `data/` directory (configured via the `DATA_PATH` environment variable)
- Use `format="table"` for appendable datasets
- Handle duplicates when appending: `drop_duplicates(subset=[...])`
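The dedupe-on-append step can be sketched in isolation (the `append_dedup` helper and the `(ts_code, trade_date)` key are illustrative; the actual storage layer may key differently):

```python
import pandas as pd


def append_dedup(existing: pd.DataFrame, new: pd.DataFrame) -> pd.DataFrame:
    """Append new rows, dropping duplicate (ts_code, trade_date) pairs.

    Mirrors the drop_duplicates step applied before writing back to storage;
    keep="last" lets re-fetched rows overwrite stale ones.
    """
    combined = pd.concat([existing, new], ignore_index=True)
    return combined.drop_duplicates(subset=["ts_code", "trade_date"], keep="last")
```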
### Threading and Concurrency
- Use `ThreadPoolExecutor` for I/O-bound tasks (API calls)
- Implement a stop flag for graceful shutdown: `threading.Event()`
- Default worker count for data sync: 10
- On errors, always call `executor.shutdown(wait=False, cancel_futures=True)`
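A minimal sketch combining the three rules above (stop flag, pool, cancel-on-error). The `run_tasks` helper is illustrative, not the project's sync loop:

```python
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed


def run_tasks(items, worker, max_workers: int = 10):
    """Run worker(item) concurrently, stopping all pending work on first error."""
    stop_flag = threading.Event()
    results = {}

    def guarded(item):
        if stop_flag.is_set():
            return None  # skip work once an error has occurred
        return worker(item)

    executor = ThreadPoolExecutor(max_workers=max_workers)
    try:
        futures = {executor.submit(guarded, item): item for item in items}
        for future in as_completed(futures):
            try:
                results[futures[future]] = future.result()
            except Exception:
                stop_flag.set()
                # Drop queued work immediately; running tasks see the flag.
                executor.shutdown(wait=False, cancel_futures=True)
                raise
    finally:
        executor.shutdown(wait=False)
    return results
```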
### Logging
- Use prefixed print statements: `[module_name] message`
- Error format: `[ERROR] context: exception`
- Progress: use `tqdm` in loops

### Testing
- Use the **pytest** framework
- Mock external dependencies (the Tushare API)
- Use `@pytest.fixture` for test setup
- Patch at the import site: `patch('src.data.sync.Storage')`
- Test both success and error cases
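"Patch at the import site" means targeting the module that *uses* the name, not the module that defines it. A self-contained illustration (the `demo_sync` module below is synthetic, standing in for `src.data.sync`):

```python
import sys
import types
from unittest.mock import patch

# Fake module that imported Storage at top level, like src.data.sync does.
demo_sync = types.ModuleType("demo_sync")
exec(
    "class Storage:\n"
    "    def load(self):\n"
    "        return 'real data'\n"
    "\n"
    "def get_codes():\n"
    "    return Storage().load()\n",
    demo_sync.__dict__,
)
sys.modules["demo_sync"] = demo_sync


def test_patch_at_import_site():
    # Patch the name inside demo_sync, where get_codes looks it up.
    with patch("demo_sync.Storage") as MockStorage:
        MockStorage.return_value.load.return_value = "mock data"
        assert demo_sync.get_codes() == "mock data"
    # Outside the context manager the real class is restored.
    assert demo_sync.get_codes() == "real data"
```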
### Date Formats
- Dates use the `YYYYMMDD` string format
- Helpers: `get_today_date()`, `get_next_date()`
- Default start date for a full sync: `20180101`
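The two date helpers named above can be sketched as follows (a plausible implementation matching their documented behavior, not necessarily the project's exact code):

```python
from datetime import datetime, timedelta


def get_today_date() -> str:
    """Today's date as a YYYYMMDD string."""
    return datetime.now().strftime("%Y%m%d")


def get_next_date(date_str: str) -> str:
    """The day after date_str, in YYYYMMDD format.

    Handles month and year boundaries via timedelta arithmetic.
    """
    next_day = datetime.strptime(date_str, "%Y%m%d") + timedelta(days=1)
    return next_day.strftime("%Y%m%d")
```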
### Dependencies
Key packages:
- `pandas>=2.0.0` - data processing
- `numpy>=1.24.0` - numerical computing
- `tushare>=2.0.0` - A-share data API
- `pydantic>=2.0.0`, `pydantic-settings>=2.0.0` - configuration
- `tqdm>=4.65.0` - progress bars
- `pytest` - testing (dev)

### Environment Variables
Create `config/.env.local`:
```bash
TUSHARE_TOKEN=your_token_here
DATA_PATH=data
RATE_LIMIT=100
THREADS=10
```

## Common Tasks

```bash
# Sync all stocks (incremental)
uv run python -c "from src.data.sync import sync_all; sync_all()"

# Force a full sync
uv run python -c "from src.data.sync import sync_all; sync_all(force_full=True)"

# Custom worker count
uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)"
```
28  README.md
@@ -2,6 +2,34 @@
 
 A-share quantitative investment framework
 
+## Quick Start
+
+### Install dependencies
+
+**⚠️ This project mandates uv as the Python package manager; never invoke `python` or `pip` directly.**
+
+```bash
+# Install with uv (required)
+uv pip install -e .
+```
+
+### Data sync
+
+```bash
+# Incremental sync (starts automatically from the latest date)
+uv run python -c "from src.data.sync import sync_all; sync_all()"
+
+# Full sync (from 20180101)
+uv run python -c "from src.data.sync import sync_all; sync_all(force_full=True)"
+
+# Custom worker count
+uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)"
+```
+
+## Documentation
+
+- [Data sync module](docs/data_sync.md) - detailed data-sync usage
+
 ## Modules
 
 - `data/` - data acquisition
43  config/.env.test  (new file)
@@ -0,0 +1,43 @@
# ===========================================
# ProStock local environment configuration
# This file is not committed to version control
# ===========================================

# Database configuration
DATABASE_HOST=localhost
DATABASE_PORT=5432
DATABASE_NAME=prostock
DATABASE_USER=postgres
DATABASE_PASSWORD=your_password

# API key configuration (important: do not leak)
API_KEY=your_api_key_here
SECRET_KEY=your_secret_key_here

# Redis configuration (optional)
REDIS_HOST=localhost
REDIS_PORT=6379

# Application configuration
APP_ENV=development
APP_DEBUG=true
APP_PORT=8000

# ===========================================
# Tushare data collection configuration
# ===========================================

# Tushare Pro API token (important: register at https://tushare.pro to obtain one)
TUSHARE_TOKEN=3a0741c702ee7e5e5f2bf1f0846bafaafe4e320833240b2a7e4a685f

# Data storage path
DATA_PATH=./data

# Rate limit: requests per minute (default 100)
RATE_LIMIT=200

# Thread count (default 2)
THREADS=2

ROOT_PATH=D:/PyProject/ProStock
Binary file not shown.
21  pyproject.toml  (new file)
@@ -0,0 +1,21 @@
[project]
name = "ProStock"
version = "0.1.0"
description = "A股量化投资框架"
readme = "README.md"
requires-python = ">=3.10,<3.14"
dependencies = [
    "pandas>=2.0.0",
    "numpy>=1.24.0",
    "tushare>=2.0.0",
    "pydantic>=2.0.0",
    "pydantic-settings>=2.0.0",
    "tqdm>=4.65.0",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.uv]
package = false
@@ -40,25 +40,26 @@ class TushareClient:
         self._api = ts.pro_api(self.token)
         return self._api
 
-    def query(self, api_name: str, timeout: float = 30.0, **params) -> pd.DataFrame:
+    def query(self, api_name: str, timeout: float = None, **params) -> pd.DataFrame:
         """Execute API query with rate limiting and retry.
 
         Args:
             api_name: API name ('daily', 'pro_bar', etc.)
-            timeout: Timeout for rate limiting
+            timeout: Timeout for rate limiting (None = wait indefinitely)
             **params: API parameters
 
         Returns:
             DataFrame with query results
         """
-        # Acquire rate limit token
+        # Acquire rate limit token (None = wait indefinitely)
+        timeout = timeout if timeout is not None else float('inf')
         success, wait_time = self.rate_limiter.acquire(timeout=timeout)
 
         if not success:
             raise RuntimeError(f"Rate limit exceeded after {timeout}s timeout")
 
         if wait_time > 0:
-            print(f"[RateLimit] Waited {wait_time:.2f}s for token")
+            pass  # Silent wait
 
         # Execute with retry
         max_retries = 3
@@ -83,9 +84,6 @@ class TushareClient:
                 api = self._get_api()
                 data = api.query(api_name, **params)
 
-                available = self.rate_limiter.get_available_tokens()
-                print(f"[Tushare] {api_name} | tokens: {available:.0f}/{self.rate_limiter.capacity}")
-
                 return data
 
             except Exception as e:
@@ -63,18 +63,10 @@ def get_daily(
     else:
         factors_str = factors
     params["factors"] = factors_str
-    print(f"[get_daily] factors param: '{factors_str}'")
     if adjfactor:
         params["adjfactor"] = "True"
 
     # Fetch data using pro_bar (supports factors like tor, vr)
-    print(f"[get_daily] Query params: {params}")
     data = client.query("pro_bar", **params)
 
-    if not data.empty:
-        print(f"[get_daily] Returned columns: {data.columns.tolist()}")
-        print(f"[get_daily] Sample row: {data.iloc[0].to_dict()}")
-    else:
-        print(f"[get_daily] No data for ts_code={ts_code}")
-
     return data
@@ -2,6 +2,7 @@
 
 This module provides a thread-safe token bucket algorithm for rate limiting.
 """
 
 import time
 import threading
+from typing import Optional
@@ -11,14 +12,12 @@ from dataclasses import dataclass, field
 @dataclass
 class RateLimiterStats:
     """Statistics for rate limiter."""
 
     total_requests: int = 0
     successful_requests: int = 0
     denied_requests: int = 0
     total_wait_time: float = 0.0
-    current_tokens: float = field(default=None, init=False)
-
-    def __post_init__(self):
-        self.current_tokens = field(default=None)
+    current_tokens: Optional[float] = None
 
 
 class TokenBucketRateLimiter:
@@ -54,13 +53,13 @@ class TokenBucketRateLimiter:
         self._stats = RateLimiterStats()
         self._stats.current_tokens = self.tokens
 
-    def acquire(self, timeout: float = 30.0) -> tuple[bool, float]:
+    def acquire(self, timeout: float = float("inf")) -> tuple[bool, float]:
         """Acquire a token from the bucket.
 
         Blocks until a token is available or timeout expires.
 
         Args:
-            timeout: Maximum time to wait for a token in seconds
+            timeout: Maximum time to wait for a token in seconds (default: inf)
 
         Returns:
             Tuple of (success, wait_time):
@@ -84,32 +83,58 @@ class TokenBucketRateLimiter:
             tokens_needed = 1 - self.tokens
             time_to_refill = tokens_needed / self.refill_rate
 
-            if time_to_refill > timeout:
+            # Check if we can wait for the token within timeout
+            # Handle infinite timeout specially
+            is_infinite_timeout = timeout == float("inf")
+            if not is_infinite_timeout and time_to_refill > timeout:
                 self._stats.total_requests += 1
                 self._stats.denied_requests += 1
                 return False, timeout
 
-            # Wait for tokens
-            self._lock.release()
-            time.sleep(time_to_refill)
-            self._lock.acquire()
+        # Wait for tokens - loop until we get one or timeout
+        while True:
+            # Calculate remaining time we can wait
+            elapsed = time.monotonic() - start_time
+            remaining_timeout = (
+                timeout - elapsed if not is_infinite_timeout else float("inf")
+            )
 
-            wait_time = time.monotonic() - start_time
+            # Check if we've exceeded timeout
+            if not is_infinite_timeout and remaining_timeout <= 0:
+                self._stats.total_requests += 1
+                self._stats.denied_requests += 1
+                return False, elapsed
+
+            with self._lock:
+                # Calculate wait time for next token
+                tokens_needed = max(0, 1 - self.tokens)
+                time_to_wait = (
+                    tokens_needed / self.refill_rate if tokens_needed > 0 else 0.1
+                )
+
+                # If we can't wait long enough, fail
+                if not is_infinite_timeout and time_to_wait > remaining_timeout:
+                    self._stats.total_requests += 1
+                    self._stats.denied_requests += 1
+                    return False, elapsed
+
+                # Wait outside the lock to allow other threads to refill
+                self._lock.release()
+                time.sleep(
+                    min(time_to_wait, 0.1)
+                )  # Cap wait to 100ms to check frequently
+                self._lock.acquire()
+
+                # Refill and check again
+                self._refill()
+                if self.tokens >= 1:
+                    self.tokens -= 1
+                    wait_time = time.monotonic() - start_time
+                    self._stats.total_requests += 1
+                    self._stats.successful_requests += 1
+                    self._stats.total_wait_time += wait_time
+                    self._stats.current_tokens = self.tokens
+                    return True, wait_time
 
             self._stats.total_requests += 1
             self._stats.denied_requests += 1
             return False, wait_time
 
     def acquire_nonblocking(self) -> tuple[bool, float]:
         """Try to acquire a token without blocking.
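The diff above replaces a single computed sleep with a polling loop. The same idea can be sketched standalone (simplified names, no stats tracking; `TokenBucket` here is illustrative, not the project's class):

```python
import threading
import time


class TokenBucket:
    """Minimal token bucket: refill_rate tokens/second, up to capacity."""

    def __init__(self, capacity: float, refill_rate: float):
        self.capacity = capacity
        self.refill_rate = refill_rate
        self.tokens = capacity
        self._last = time.monotonic()
        self._lock = threading.Lock()

    def _refill(self) -> None:
        now = time.monotonic()
        self.tokens = min(
            self.capacity, self.tokens + (now - self._last) * self.refill_rate
        )
        self._last = now

    def acquire(self, timeout: float = float("inf")) -> tuple[bool, float]:
        """Poll in short sleeps until a token is available or timeout expires."""
        start = time.monotonic()
        while True:
            with self._lock:
                self._refill()
                if self.tokens >= 1:
                    self.tokens -= 1
                    return True, time.monotonic() - start
            if time.monotonic() - start >= timeout:
                return False, time.monotonic() - start
            time.sleep(0.01)  # short poll so timeouts stay precise
```

Polling in small increments is what makes finite timeouts precise while still allowing `float("inf")` to wait forever.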
@@ -3,10 +3,19 @@
 Fetch basic stock information including code, name, listing date, etc.
 This is a special interface - call once to get all stocks (listed and delisted).
 """
 import os
 import pandas as pd
+from pathlib import Path
 from typing import Optional, Literal, List
 from src.data.client import TushareClient
 from src.data.storage import Storage
 from src.data.config import get_config
+
+
+# CSV file path for stock basic data
+def _get_csv_path() -> Path:
+    """Get the CSV file path for stock basic data."""
+    cfg = get_config()
+    return cfg.data_path_resolved / "stock_basic.csv"
 
 
 def get_stock_basic(
@@ -75,20 +84,19 @@ def sync_all_stocks() -> pd.DataFrame:
     Returns:
         pd.DataFrame with all stock information
     """
-    # Initialize storage
-    storage = Storage()
+    csv_path = _get_csv_path()
 
-    # Check if already exists
-    if storage.exists("stock_basic"):
-        print("[sync_all_stocks] stock_basic data already exists, skipping...")
-        return storage.load("stock_basic")
+    # Check if CSV file already exists
+    if csv_path.exists():
+        print("[sync_all_stocks] stock_basic.csv already exists, skipping...")
+        return pd.read_csv(csv_path)
 
     print("[sync_all_stocks] Fetching all stocks (listed and delisted)...")
 
     # Fetch all stocks - explicitly get all list_status values
     # API default is L (listed), so we need to fetch all statuses
     client = TushareClient()
 
     all_data = []
     for status in ["L", "D", "P", "G"]:
         print(f"[sync_all_stocks] Fetching stocks with status: {status}")
@@ -96,21 +104,20 @@ def sync_all_stocks() -> pd.DataFrame:
         print(f"[sync_all_stocks] Fetched {len(data)} stocks with status {status}")
         if not data.empty:
             all_data.append(data)
 
     if not all_data:
         print("[sync_all_stocks] No stock data fetched")
         return pd.DataFrame()
 
     # Combine all data
     data = pd.concat(all_data, ignore_index=True)
     # Remove duplicates if any
     data = data.drop_duplicates(subset=["ts_code"], keep="first")
     print(f"[sync_all_stocks] Total unique stocks: {len(data)}")
 
-    # Save to storage
-    storage.save("stock_basic", data, mode="replace")
-
-    print(f"[sync_all_stocks] Saved {len(data)} stocks to local storage")
+    # Save to CSV
+    data.to_csv(csv_path, index=False, encoding="utf-8-sig")
+    print(f"[sync_all_stocks] Saved {len(data)} stocks to {csv_path}")
     return data
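The pattern in this diff (check for a cached CSV, otherwise fetch, dedupe, and write it) can be sketched in isolation. The `load_or_fetch` helper and the `fetch` callable are illustrative names, not the project's API:

```python
from pathlib import Path

import pandas as pd


def load_or_fetch(csv_path: Path, fetch) -> pd.DataFrame:
    """Return cached CSV data, or fetch it, dedupe, and cache to CSV."""
    if csv_path.exists():
        return pd.read_csv(csv_path)
    data = fetch().drop_duplicates(subset=["ts_code"], keep="first")
    # utf-8-sig keeps Chinese stock names readable when opened in Excel
    data.to_csv(csv_path, index=False, encoding="utf-8-sig")
    return data
```

The second call for the same path never hits the network, matching the "already exists, skipping" branch above.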
255  tests/test_sync.py  (new file)
@@ -0,0 +1,255 @@
"""Tests for data synchronization module.

Tests the sync module's full/incremental sync logic for daily data:
- Full sync when local data doesn't exist (from 20180101)
- Incremental sync when local data exists (from last_date + 1)
- Data integrity validation
"""
import pytest
import pandas as pd
from unittest.mock import Mock, patch, MagicMock
from datetime import datetime, timedelta

from src.data.client import TushareClient
from src.data.storage import Storage
from src.data.sync import (
    DataSync,
    sync_all,
    get_today_date,
    get_next_date,
    DEFAULT_START_DATE,
)


# Module-level fixtures so every test class can use them.
@pytest.fixture
def mock_storage():
    """Create a mock storage instance."""
    storage = Mock(spec=Storage)
    storage.exists = Mock(return_value=False)
    storage.load = Mock(return_value=pd.DataFrame())
    storage.save = Mock(return_value={"status": "success", "rows": 0})
    return storage


@pytest.fixture
def mock_client():
    """Create a mock client instance."""
    return Mock(spec=TushareClient)


class TestDateUtilities:
    """Test date utility functions."""

    def test_get_today_date_format(self):
        """Test today date is in YYYYMMDD format."""
        result = get_today_date()
        assert len(result) == 8
        assert result.isdigit()

    def test_get_next_date(self):
        """Test getting next date."""
        result = get_next_date("20240101")
        assert result == "20240102"

    def test_get_next_date_year_end(self):
        """Test getting next date across year boundary."""
        result = get_next_date("20241231")
        assert result == "20250101"

    def test_get_next_date_month_end(self):
        """Test getting next date across month boundary."""
        result = get_next_date("20240131")
        assert result == "20240201"


class TestDataSync:
    """Test DataSync class functionality."""

    def test_get_all_stock_codes_from_daily(self, mock_storage):
        """Test getting stock codes from daily data."""
        with patch('src.data.sync.Storage', return_value=mock_storage):
            sync = DataSync()
            sync.storage = mock_storage

            mock_storage.load.return_value = pd.DataFrame({
                'ts_code': ['000001.SZ', '000001.SZ', '600000.SH'],
            })

            codes = sync.get_all_stock_codes()

            assert len(codes) == 2
            assert '000001.SZ' in codes
            assert '600000.SH' in codes

    def test_get_all_stock_codes_fallback(self, mock_storage):
        """Test fallback to stock_basic when daily is empty."""
        with patch('src.data.sync.Storage', return_value=mock_storage):
            sync = DataSync()
            sync.storage = mock_storage

            # First call (daily) returns empty, second call (stock_basic) returns data
            mock_storage.load.side_effect = [
                pd.DataFrame(),  # daily empty
                pd.DataFrame({'ts_code': ['000001.SZ', '600000.SH']}),  # stock_basic
            ]

            codes = sync.get_all_stock_codes()

            assert len(codes) == 2

    def test_get_global_last_date(self, mock_storage):
        """Test getting global last date."""
        with patch('src.data.sync.Storage', return_value=mock_storage):
            sync = DataSync()
            sync.storage = mock_storage

            mock_storage.load.return_value = pd.DataFrame({
                'ts_code': ['000001.SZ', '600000.SH'],
                'trade_date': ['20240102', '20240103'],
            })

            last_date = sync.get_global_last_date()
            assert last_date == '20240103'

    def test_get_global_last_date_empty(self, mock_storage):
        """Test getting last date from empty storage."""
        with patch('src.data.sync.Storage', return_value=mock_storage):
            sync = DataSync()
            sync.storage = mock_storage

            mock_storage.load.return_value = pd.DataFrame()

            last_date = sync.get_global_last_date()
            assert last_date is None

    def test_sync_single_stock(self, mock_storage):
        """Test syncing a single stock."""
        with patch('src.data.sync.Storage', return_value=mock_storage):
            with patch('src.data.sync.get_daily', return_value=pd.DataFrame({
                'ts_code': ['000001.SZ'],
                'trade_date': ['20240102'],
            })):
                sync = DataSync()
                sync.storage = mock_storage

                result = sync.sync_single_stock(
                    ts_code='000001.SZ',
                    start_date='20240101',
                    end_date='20240102',
                )

                assert isinstance(result, pd.DataFrame)
                assert len(result) == 1

    def test_sync_single_stock_empty(self, mock_storage):
        """Test syncing a stock with no data."""
        with patch('src.data.sync.Storage', return_value=mock_storage):
            with patch('src.data.sync.get_daily', return_value=pd.DataFrame()):
                sync = DataSync()
                sync.storage = mock_storage

                result = sync.sync_single_stock(
                    ts_code='INVALID.SZ',
                    start_date='20240101',
                    end_date='20240102',
                )

                assert result.empty


class TestSyncAll:
    """Test sync_all function."""

    def test_full_sync_mode(self, mock_storage):
        """Test full sync mode when force_full=True."""
        with patch('src.data.sync.Storage', return_value=mock_storage):
            with patch('src.data.sync.get_daily', return_value=pd.DataFrame()):
                sync = DataSync()
                sync.storage = mock_storage
                sync.sync_single_stock = Mock(return_value=pd.DataFrame())

                mock_storage.load.return_value = pd.DataFrame({
                    'ts_code': ['000001.SZ'],
                })

                result = sync.sync_all(force_full=True)

                # Verify sync_single_stock was called with default start date
                sync.sync_single_stock.assert_called_once()
                call_args = sync.sync_single_stock.call_args
                assert call_args[1]['start_date'] == DEFAULT_START_DATE

    def test_incremental_sync_mode(self, mock_storage):
        """Test incremental sync mode when data exists."""
        with patch('src.data.sync.Storage', return_value=mock_storage):
            sync = DataSync()
            sync.storage = mock_storage
            sync.sync_single_stock = Mock(return_value=pd.DataFrame())

            # Mock existing data with last date
            mock_storage.load.side_effect = [
                pd.DataFrame({
                    'ts_code': ['000001.SZ'],
                    'trade_date': ['20240102'],
                }),  # get_all_stock_codes
                pd.DataFrame({
                    'ts_code': ['000001.SZ'],
                    'trade_date': ['20240102'],
                }),  # get_global_last_date
            ]

            result = sync.sync_all(force_full=False)

            # Verify sync_single_stock was called with next date
            sync.sync_single_stock.assert_called_once()
            call_args = sync.sync_single_stock.call_args
            assert call_args[1]['start_date'] == '20240103'

    def test_manual_start_date(self, mock_storage):
        """Test sync with manual start date."""
        with patch('src.data.sync.Storage', return_value=mock_storage):
            sync = DataSync()
            sync.storage = mock_storage
            sync.sync_single_stock = Mock(return_value=pd.DataFrame())

            mock_storage.load.return_value = pd.DataFrame({
                'ts_code': ['000001.SZ'],
            })

            result = sync.sync_all(force_full=False, start_date='20230601')

            sync.sync_single_stock.assert_called_once()
            call_args = sync.sync_single_stock.call_args
            assert call_args[1]['start_date'] == '20230601'

    def test_no_stocks_found(self, mock_storage):
        """Test sync when no stocks are found."""
        with patch('src.data.sync.Storage', return_value=mock_storage):
            sync = DataSync()
            sync.storage = mock_storage

            mock_storage.load.return_value = pd.DataFrame()

            result = sync.sync_all()

            assert result == {}


class TestSyncAllConvenienceFunction:
    """Test sync_all convenience function."""

    def test_sync_all_function(self):
        """Test sync_all convenience function."""
        with patch('src.data.sync.DataSync') as MockSync:
            mock_instance = Mock()
            mock_instance.sync_all.return_value = {}
            MockSync.return_value = mock_instance

            result = sync_all(force_full=True)

            MockSync.assert_called_once()
            mock_instance.sync_all.assert_called_once_with(
                force_full=True,
                start_date=None,
                end_date=None,
            )


if __name__ == '__main__':
    pytest.main([__file__, '-v'])