feat: initialize ProStock project skeleton and configuration

- Add project rule documents (development standards, security rules, configuration management)
- Implement core data-module features (API client, rate limiter, storage manager, config loading)
- Add .gitignore and .kilocodeignore
- Add environment variable template
- Write unit tests for the daily module
2026-01-31 03:04:51 +08:00
parent f3bb1d8933
commit e625a53162
17 changed files with 1832 additions and 0 deletions

.gitignore (new file)
@@ -0,0 +1,73 @@
# ===========================================
# .gitignore template for Python projects
# ===========================================
# Python bytecode caches
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# Virtual environments
venv/
ENV/
env/
.venv/
# Configuration files (sensitive information)
.env.local
config/*.local
!config/.env.example
# Log files
*.log
logs/
*.log.*
# IDEs and editors
.vscode/
.idea/
*.swp
*.swo
*~
# OS files
.DS_Store
Thumbs.db
# Test coverage
.coverage
htmlcov/
.pytest_cache/
cover/
# Database files
*.db
*.sqlite
*.sqlite3
# Temporary files
*.tmp
*.temp
tmp/
temp/


@@ -0,0 +1,72 @@
# ProStock Project Rules
This project uses the Kilo Code rules system to ensure code quality and consistency.
## Project Overview
- **Project name**: ProStock
- **Primary language**: Python
- **Framework**: TBD
- **Code directory**: `src/`
- **Config directory**: `config/`
## Core Principles
1. **Code quality**: All code must follow the Python PEP 8 style guide
2. **Type hints**: Public functions and classes should carry type annotations
3. **Docstrings**: Use Google-style docstrings
4. **Test coverage**: Key business logic should have corresponding unit tests
## Directory Layout
```
project/
├── src/                  # Main source directory
├── config/               # Configuration files (direct reads forbidden)
├── docs/                 # Documentation
├── tests/                # Tests
├── .kilocode/            # Kilo Code configuration
│   └── rules/            # Rule files
├── requirements.txt      # Dependency management
└── README.md             # Project description
```
## File Naming
- **Python files**: lowercase with underscores (`snake_case.py`)
- **Config files**: lowercase with underscores (`config_settings.py`)
- **Test files**: `test_` prefix (`test_example.py`)
- **Constant files**: `constants_` prefix (`constants_sizes.py`)
## Code Organization
- Keep each Python module small and single-purpose
- Avoid putting heavy logic in `__init__.py`
- Use absolute imports (`from src.module import ...`), matching the development standards
- Centralize configuration; avoid hard-coded values
## Import Order
```python
# 1. Standard library imports
import os
import sys
from datetime import datetime

# 2. Third-party imports
import pandas as pd
from flask import Flask

# 3. Local application imports
from src.config.settings import Settings
from src.models.user import User
```
## Pre-commit Checklist
Before committing code, make sure:
- [ ] All code passes type checking (e.g. with mypy)
- [ ] Code formatting follows the standard (using black/isort)
- [ ] No unused imports or variables
- [ ] Key functionality has corresponding tests
- [ ] Documentation is updated (where needed)


@@ -0,0 +1,325 @@
# Python Development Standards
## 1 Code Style
### 1.1 Indentation and Whitespace
- Use **4 spaces** for indentation; tabs are forbidden
- Maximum line length is **120 characters**
- Put spaces around binary operators (except in chained assignments)
- Be consistent with whitespace around function parameters
```python
# Correct
def calculate_total_price(price: float, tax_rate: float) -> float:
    total = price * (1 + tax_rate)
    return round(total, 2)

# Incorrect
def calculate_total_price(price:float,tax_rate:float)->float:
    total=price*(1+tax_rate)
    return round(total,2)
```
### 1.2 Imports
- Imports go at the top of the file, after the module comment and docstring
- Standard library first, then third-party libraries, then local application modules
- Use absolute imports; wildcard imports (`from module import *`) are forbidden
- Leave one blank line between import groups
```python
# Standard library
import os
import sys
from datetime import datetime

# Third-party libraries
import numpy as np
import pandas as pd
from pydantic import BaseModel

# Local modules
from src.config import settings
from src.utils.logger import get_logger
```
### 1.3 Naming
| Type | Convention | Example |
|------|------|------|
| Module | all lowercase, underscore-separated | `data_processor.py` |
| Package | all lowercase, no underscores | `src/utils` |
| Class | PascalCase (leading capital) | `UserAccount`, `DataValidator` |
| Function | snake_case (all lowercase) | `get_user_data()`, `calculate_total()` |
| Variable | snake_case (all lowercase) | `user_name`, `total_count` |
| Constant | UPPER_SNAKE_CASE | `MAX_RETRY_COUNT`, `DEFAULT_TIMEOUT` |
| Private method/variable | single leading underscore | `_internal_method()`, `_private_var` |
| Type variable | PascalCase | `UserType`, `T = TypeVar('T')` |
### 1.4 Comments
- Write comments in English; Chinese comments are forbidden
- Complex logic must carry comments explaining the intent
- Keep comments in sync with the code; stale comments are forbidden
```python
# Inline comments on the same line as code are preceded by two spaces
def process_data(data: list) -> dict:  # Process input data and return statistics
    pass

# Multi-line docstrings use the Google style
def calculate_metrics(values: list[float]) -> dict[str, float]:
    """Calculate statistical metrics for a list of values.

    Args:
        values: List of numeric values to analyze

    Returns:
        Dictionary containing mean, median, and standard deviation

    Raises:
        ValueError: If input list is empty
    """
    pass
```
## 2 Architecture
### 2.1 Module Structure
```
src/
├── core/          # Core business logic
├── modules/       # Feature modules
├── utils/         # Utility functions
├── config/        # Configuration management
├── services/      # Service layer
├── models/        # Data models
├── repositories/  # Data access layer
└── schemas/       # Pydantic models
```
### 2.2 Dependency Principles
- **Dependency inversion**: high-level modules must not depend on low-level modules; both depend on abstractions
- **No circular dependencies**: module references must form a directed acyclic graph
- **Single responsibility**: every class/module owns exactly one responsibility
```python
# Incorrect: tightly coupled to concrete collaborators
class UserService:
    def __init__(self, db_connection, email_sender, cache_manager):
        pass

# Correct: decoupled through an abstract interface
class UserService:
    def __init__(self, user_repository: IUserRepository):
        self.repository = user_repository
```
### 2.3 Function Design
- **Single responsibility**: each function does one thing
- **Parameter count**: no more than 5 parameters per function; beyond that, wrap them in an object
- **Explicit return values**: the return type must be annotated
```python
# Incorrect: the function does too many things
def process_user_registration(name, email, password, send_email, create_session):
    if send_email:
        send_verification_email(email)
    if create_session:
        create_user_session(email)
    return save_user(name, email, password)

# Correct: split by single responsibility
class UserRegistrationService:
    def register(self, user_data: UserCreateDto) -> User:
        user = self._validate_and_create_user(user_data)
        self._send_verification_email(user)
        self._create_session(user)
        return user

    def _validate_and_create_user(self, data: UserCreateDto) -> User:
        pass

    def _send_verification_email(self, user: User) -> None:
        pass
```
## 3 Configuration Management
### 3.1 No Hard-coding
**All critical configuration must live outside the code; hard-coding it is forbidden**
| Type | Settings that must be externalized |
|------|----------------|
| Database connection | HOST, PORT, USER, PASSWORD, DATABASE |
| API keys | SECRET_KEY, API_KEY |
| External services | ENDPOINT, TIMEOUT, RETRY_COUNT |
| Business parameters | MAX_RETRIES, CACHE_TTL, RATE_LIMIT |
### 3.2 Config Directory Layout
```
config/
├── .env.example              # Environment variable template (no secrets)
├── .env.local                # Local environment config (ignored by .gitignore)
├── config.yaml               # Shared configuration
├── config.development.yaml   # Development environment
├── config.production.yaml    # Production environment
└── config.test.yaml          # Test environment
```
### 3.3 Config Loading Example
```python
# src/config/settings.py
from functools import lru_cache

from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    """Application settings loaded from environment variables."""

    # Database settings
    database_host: str = "localhost"
    database_port: int = 5432
    database_name: str = "prostock"
    database_user: str
    database_password: str

    # API settings
    api_key: str
    secret_key: str
    jwt_algorithm: str = "HS256"
    access_token_expire_minutes: int = 30

    # Redis settings
    redis_host: str = "localhost"
    redis_port: int = 6379

    class Config:
        env_file = ".env.local"
        env_file_encoding = "utf-8"


@lru_cache()
def get_settings() -> Settings:
    """Return the settings singleton."""
    return Settings()
```
### 3.4 .env.example Template
```bash
# ===========================================
# ProStock environment variable template
# Copy this file to .env.local and fill in real values
# ===========================================

# Database settings
DATABASE_HOST=localhost
DATABASE_PORT=5432
DATABASE_NAME=prostock
DATABASE_USER=your_username
DATABASE_PASSWORD=your_password

# API key settings
API_KEY=your_api_key
SECRET_KEY=your_secret_key_here

# Redis settings (optional)
REDIS_HOST=localhost
REDIS_PORT=6379
```
## 4 Error Handling
### 4.1 Exception Hierarchy
```python
# src/core/exceptions.py
class BaseCustomException(Exception):
    """Base exception class."""
    status_code: int = 500
    detail: str = "An unexpected error occurred"

    def __init__(self, detail: str = None):
        self.detail = detail or self.detail
        super().__init__(self.detail)


class ValidationError(BaseCustomException):
    """Data validation error."""
    status_code = 422
    detail = "Validation error"


class AuthenticationError(BaseCustomException):
    """Authentication error."""
    status_code = 401
    detail = "Authentication required"


class AuthorizationError(BaseCustomException):
    """Authorization error."""
    status_code = 403
    detail = "Permission denied"


class NotFoundError(BaseCustomException):
    """Resource not found."""
    status_code = 404
    detail = "Resource not found"
```
### 4.2 Error Handling Principles
- Propagate meaningful information to upper layers
- Log detailed context (excluding sensitive data)
- Distinguish recoverable from unrecoverable errors
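A minimal sketch of these principles, reusing the `NotFoundError` idea from section 4.1 (the `load_user` helper and its `db` dict are hypothetical, for illustration only): catch the low-level error, log a non-sensitive detail, and re-raise a domain exception with a meaningful message.

```python
import logging

logger = logging.getLogger(__name__)


class NotFoundError(Exception):
    """Raised when a requested resource does not exist."""
    status_code = 404


def load_user(user_id: int, db: dict) -> dict:
    """Look up a user record, translating the low-level KeyError."""
    try:
        return db[user_id]
    except KeyError as exc:
        # Log only the technical detail (no sensitive fields),
        # then surface a meaningful domain error upward.
        logger.warning("user lookup failed: id=%s", user_id)
        raise NotFoundError(f"user {user_id} not found") from exc
```

The `raise ... from exc` form preserves the original cause for debugging while upper layers only see the domain exception.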
## 5 Testing
### 5.1 Requirements
- All core functionality must have unit tests
- Coverage of key business logic must be at least 80%
- Use `pytest` as the test framework
- Use `pytest-cov` for coverage reports
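The requirements above can be checked from the command line; the following is a sketch assuming `pytest` and `pytest-cov` are installed (real flags from pytest-cov, but the paths follow this project's layout):

```bash
# Run the unit tests with a line-by-line coverage report
pytest tests/ --cov=src --cov-report=term-missing

# Fail the run if overall coverage drops below the 80% threshold
pytest tests/unit/ --cov=src --cov-fail-under=80
```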
### 5.2 Test Layout
```
tests/
├── conftest.py       # Shared fixtures
├── unit/             # Unit tests
│   ├── test_models.py
│   ├── test_services.py
│   └── test_utils.py
├── integration/      # Integration tests
│   └── test_api.py
└── fixtures/         # Test data
```
## 6 Git Commit Standards
### 6.1 Commit Message Format
```
<type>(<scope>): <subject>

<body>

<footer>
```
### 6.2 Type Identifiers
| Type | Meaning |
|------|------|
| feat | New feature |
| fix | Bug fix |
| docs | Documentation update |
| style | Formatting change |
| refactor | Refactoring |
| test | Test-related |
| chore | Build/tooling |
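Putting the format and the type table together, a hypothetical commit for this repository might look like:

```
feat(data): add token-bucket rate limiter for the Tushare client

Implement TokenBucketRateLimiter with blocking and non-blocking
acquire paths, and wire it into TushareClient.query().
```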
## 7 Code Review Checklist
- [ ] Code follows PEP 8
- [ ] No hard-coded critical configuration
- [ ] Functions/classes have type annotations
- [ ] Complex logic is commented
- [ ] Unit tests cover key logic
- [ ] No circular dependencies
- [ ] Naming follows the conventions
- [ ] Logs contain no sensitive information


@@ -0,0 +1,271 @@
# Security and Access Control Rules
This file defines the project's security constraints and access control rules.
## 🔒 Config File Access Is Forbidden
### Core Constraint
**Never access any file under the root-level `config` directory while coding.** This means:
1. **No reading** - do not read files under `config/` by any means
2. **No editing** - do not modify configuration files under `config/`
3. **No viewing** - do not view the contents of files under `config/`
4. **No searching** - do not run any search inside `config/`
5. **No executing** - do not run any command inside `config/`
All configuration reads must go through the centralized configuration module (`src/config/`). **`config/` and `src/config/` are entirely different directories: the former is protected, the latter holds the configuration module's code.**
### Directory Layout
```
ProStock/
├── config/              # Protected config directory (no access of any kind)
│   ├── .env.example     # Environment variable template
│   ├── .env.local       # Local environment config (sensitive)
│   ├── config.yaml      # Shared configuration
│   └── ...
├── src/config/          # Config module code (internal access only)
│   ├── __init__.py
│   ├── settings.py      # Config loading logic
│   └── ...
└── ...
```
### Complete List of Restricted Tools
All of the following tools are **strictly forbidden** from touching the `config/` directory:
| Tool category | Restricted tool | Forbidden operation |
|---------|-----------|---------|
| File reading | `read_file` | Reading any file under `config/` |
| File reading | `list_files` | Listing the contents of `config/` |
| File editing | `edit_file` | Editing any file under `config/` |
| File editing | `search_and_replace` | Replacing content in files under `config/` |
| File editing | `write_to_file` | Writing into `config/` (create or overwrite) |
| File editing | `delete_file` | Deleting any file under `config/` |
| Search | `search_files` | Regex searches inside `config/` |
| Command execution | `execute_command` | Any command that touches `config/` |
| Code mode | All code-writing operations | Writing code or config into `config/` |
### Forbidden Actions ❌
```python
# Forbidden: reading the .env file directly
import dotenv
dotenv.load_dotenv('config/.env')  # Forbidden!

# Forbidden: reading config files directly
with open('config/settings.json', 'r') as f:  # Forbidden!
    config = json.load(f)

# Forbidden: hard-coding config paths
config_path = os.path.join('config', 'database.yml')  # Forbidden!

# Forbidden: viewing config files with the read_file tool
read_file(path='config/.env')  # Forbidden!

# Forbidden: editing config files with edit_file/search_and_replace
edit_file(path='config/settings.py', ...)  # Forbidden!

# Forbidden: searching the config directory with search_files
search_files(path='config', regex='.*')  # Forbidden!

# Forbidden: listing the config directory with list_files
list_files(path='config', recursive=True)  # Forbidden!

# Forbidden: creating or modifying config files with write_to_file
write_to_file(path='config/custom.py', ...)  # Forbidden!

# Forbidden: deleting config files with delete_file
delete_file(path='config/old.env')  # Forbidden!

# Forbidden: Python filesystem operations
import os
os.listdir('config')  # Forbidden!
os.path.exists('config/.env')  # Forbidden!
os.path.isfile('config/settings.yaml')  # Forbidden!
os.walk('config')  # Forbidden!

# Forbidden: pathlib operations
from pathlib import Path
Path('config/.env').read_text()  # Forbidden!
Path('config').iterdir()  # Forbidden!

# Forbidden: glob pattern matching
import glob
glob.glob('config/**/*')  # Forbidden!
glob.iglob('config/*.yaml')  # Forbidden!

# Forbidden: shutil operations
import shutil
shutil.copy('config/.env', 'backup/')  # Forbidden!
```
### Forbidden Commands ❌
```bash
# Forbidden: entering the config directory
cd config  # Forbidden!

# Forbidden: listing the config directory
ls config/  # Forbidden!
dir config  # Forbidden!

# Forbidden: reading config files
cat config/.env  # Forbidden!
type config\.env  # Forbidden!

# Forbidden: searching the config directory
grep -r "SECRET" config/  # Forbidden!

# Forbidden: any other command touching the config directory
find config -name "*.py"  # Forbidden!
```
### Compliant Actions ✅
```python
# Correct: use the configuration management module
from src.config.settings import Settings
settings = Settings()
db_config = settings.database

# Correct: handled inside the config module
from src.config import get_config
config = get_config()
```
### Config File Protection Rules
1. **`config` directory**: the root-level `config/` directory is **fully protected**; no tool may access it
2. **`src/config` directory**: the config module's code directory; accessed only by code inside `src/config/`
3. **Sensitive files**: `.env` files must be listed in `.gitignore`
4. **Config loading**: load once at application startup, not repeatedly at runtime
5. **Tool call restriction**: every tool call must verify that the target path does not fall under `config/`
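The path check in rule 5 can be sketched as follows, assuming the tool layer can invoke a Python guard before any file operation (`is_path_allowed` is a hypothetical helper, not part of this commit):

```python
from pathlib import Path


def is_path_allowed(target: str, project_root: str = ".") -> bool:
    """Return False for any path that resolves inside the protected config/ dir.

    Resolving both paths first defeats tricks like 'config/../config/.env'
    or relative segments; src/config/ remains allowed because it does not
    resolve under the root-level config/ directory.
    """
    resolved = (Path(project_root) / target).resolve()
    protected = (Path(project_root) / "config").resolve()
    return resolved != protected and protected not in resolved.parents
```

A tool dispatcher would call this once per operation and reject the call when it returns `False`.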
## 🔐 Handling Sensitive Information
### Forbidden
```python
# Forbidden: hard-coding keys in source
API_KEY = "sk-1234567890abcdef"  # Forbidden!

# Forbidden: printing sensitive information
print(f"Password: {password}")  # Forbidden!

# Forbidden: writing keys to logs
logger.debug(f"API Key: {api_key}")  # Forbidden!
```
### Compliant
```python
# Correct: read from environment variables
import os
API_KEY = os.environ.get('API_KEY')

# Correct: use configuration management
from src.config.settings import Settings
settings = Settings()
api_key = settings.api_key
```
## 🛡️ Security Best Practices
### Input Validation
- All external input must be validated and sanitized
- Use parameterized queries to prevent SQL injection
- Escape and filter user input appropriately
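A minimal illustration of the parameterized-query rule, using the standard library's sqlite3 (the `users` table and `find_user` helper are hypothetical):

```python
import sqlite3


def find_user(conn: sqlite3.Connection, name: str) -> list:
    """Look up users by name with a parameterized query.

    The driver binds `name` as data, so injection-style input such as
    "x' OR '1'='1" cannot alter the SQL structure.
    """
    cur = conn.execute("SELECT id, name FROM users WHERE name = ?", (name,))
    return cur.fetchall()
```

The same pattern applies with any DB-API driver: never build SQL by string concatenation with user input; always pass values through the placeholder mechanism.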
### Dependency Security
- Update dependencies regularly to pick up security fixes
- Check dependencies with `pip audit` or `safety`
- Avoid packages with known security issues
### Logging Rules
- **Never** log sensitive information (passwords, keys, tokens, etc.)
- **Never** log complete user records (consider masking)
- **Do** log the operation type, user ID (without sensitive fields), and timestamp
### Error Handling
- Never expose detailed stack traces to users
- Record sensitive errors in a security log instead of returning them to clients
- Show generic error messages externally
## ⚠️ Violation Handling
### Automated Detection
1. **Pre-commit scanning**: Git hooks scan commit contents automatically
2. **CI/CD pipeline checks**: security scans run in continuous integration
3. **Static code analysis**: static analysis tools detect violating patterns
4. **Tool call monitoring**: every tool call is checked for `config/` access
### Consequences
Violating the rules above leads to:
1. **Failed code review**: the commit is rejected automatically
2. **Security scanner alerts**: a security alert is raised
3. **Vulnerability rating**: flagged as a high-priority security issue
4. **Blocked builds**: the CI/CD pipeline fails automatically
5. **Audit logging**: the violation is recorded for audit trails
### Violation Severity Levels
| Level | Violation | Action |
|------|---------|---------|
| Critical | Deliberately reading sensitive config files (e.g. `.env`) | Review rejected, team notified |
| High | Using a tool to access the `config/` directory | Review rejected, fix required |
| Medium | Hard-coding config paths in code | Change required, flagged in review |
| Low | Potentially risky operation (manual review needed) | Reviewer reminder |
## 📋 Rule Checklists
### Pre-commit Checks
- [ ] No tool was used to access the `config/` directory
- [ ] No hard-coded `config/` paths
- [ ] Configuration is accessed only through the config module
- [ ] No hard-coded secrets
- [ ] Secrets come from environment variables or secure storage
- [ ] Logs contain no sensitive data
- [ ] Error handling exposes no sensitive information
### Tool Call Checks
- [ ] `read_file` was not pointed at `config/`
- [ ] `edit_file` was not pointed at `config/`
- [ ] `write_to_file` did not write into `config/`
- [ ] `delete_file` did not delete from `config/`
- [ ] `search_files` did not search `config/`
- [ ] `list_files` did not list `config/`
- [ ] `execute_command` did not touch a `config/` path
## 🔧 Compliance Verification Script
To verify that the rules are being followed, the following checks can be used:
```bash
# Look for code that touches the config directory
grep -r "config/\." --include="*.py" src/ tests/ docs/

# Look for config paths in tool calls
grep -rn "path='config" --include="*.py"

# Check whether any .env file was committed
git ls-files | grep "^config/.env"
```
## 📝 Training and Documentation
1. **Onboarding**: new members must study these security rules when they join
2. **Regular audits**: rule compliance is audited monthly
3. **Rule updates**: the security rules evolve with the project
4. **Violation reports**: violations are reported periodically to raise awareness

.kilocodeignore (new file)

@@ -0,0 +1,20 @@
# Environment variables
.env*
.env.local
.env.*.local
# Keys and certificates
*.pem
*.key
id_rsa
id_dsa
*.pfx
# Database files
*.sqlite
*.db
# Logs and caches
logs/
cache/
*.log

config/.env.example (new file)

@@ -0,0 +1,28 @@
# ===========================================
# ProStock environment variable template
#
# Usage:
# 1. Copy this file to .env.local
# 2. Fill in the real values
# 3. Keep it in the config/ directory
# ===========================================
# Database settings
DATABASE_HOST=localhost
DATABASE_PORT=5432
DATABASE_NAME=prostock
DATABASE_USER=postgres
DATABASE_PASSWORD=your_password
# API keys (important: never leak these into version control)
API_KEY=your_api_key_here
SECRET_KEY=your_secret_key_here
# Redis settings (optional)
REDIS_HOST=localhost
REDIS_PORT=6379
# Application settings
APP_ENV=development
APP_DEBUG=true
APP_PORT=8000

src/__init__.py (new file)

@@ -0,0 +1,5 @@
"""ProStock 股票分析项目
提供股票数据分析和交易策略等功能
"""
__version__ = "1.0.0"

src/config/__init__.py (new file)

@@ -0,0 +1,7 @@
"""配置管理模块
提供配置加载和管理功能
"""
from src.config.settings import Settings, get_settings, settings
__all__ = ["Settings", "get_settings", "settings"]

src/config/settings.py (new file)

@@ -0,0 +1,58 @@
"""配置管理模块
从环境变量加载应用配置使用pydantic-settings进行类型验证
"""
import os
from pathlib import Path
from pydantic_settings import BaseSettings
from functools import lru_cache
from typing import Optional
# 获取项目根目录config文件夹所在目录
PROJECT_ROOT = Path(__file__).parent.parent.parent
CONFIG_DIR = PROJECT_ROOT / "config"
class Settings(BaseSettings):
"""应用配置类,从环境变量加载"""
# 数据库配置
database_host: str = "localhost"
database_port: int = 5432
database_name: str = "prostock"
database_user: str
database_password: str
# API密钥配置
api_key: str
secret_key: str
# Redis配置
redis_host: str = "localhost"
redis_port: int = 6379
redis_password: Optional[str] = None
# 应用配置
app_env: str = "development"
app_debug: bool = False
app_port: int = 8000
class Config:
# 从 config/ 目录读取 .env.local 文件
env_file = str(CONFIG_DIR / ".env.local")
env_file_encoding = "utf-8"
case_sensitive = False
@lru_cache()
def get_settings() -> Settings:
"""获取配置单例
使用lru_cache确保配置只加载一次
"""
return Settings()
# 导出配置实例供全局使用
settings = get_settings()

src/data/__init__.py (new file)

@@ -0,0 +1,15 @@
"""Data collection module for Tushare.
Provides simplified interfaces for fetching and storing Tushare data.
"""
from src.data.config import Config, get_config
from src.data.client import TushareClient
from src.data.storage import Storage
__all__ = [
"Config",
"get_config",
"TushareClient",
"Storage",
]

src/data/api.md (new file)

@@ -0,0 +1,47 @@
1. Generic bar interface: https://tushare.pro/document/2?doc_id=109; the available fields are listed at https://tushare.pro/document/2?doc_id=27. Requirement: store every output field of the A-share daily bars, plus tor (turnover rate) and vr (volume ratio).

Input parameters:

| Name | Type | Required | Description |
|------|------|------|------|
| ts_code | str | Y | Security code; multiple values are not supported (passing several yields duplicate records) |
| start_date | str | N | Start date (daily format YYYYMMDD; for minute data use a format like 2019-09-01 09:00:00) |
| end_date | str | N | End date (daily format YYYYMMDD) |
| asset | str | Y | Asset class: E stock, I SSE/SZSE index, C cryptocurrency, FT futures, FD fund, O option, CB convertible bond (since v1.2.39). Default E |
| adj | str | N | Adjustment type (stocks only): None unadjusted, qfq forward-adjusted, hfq backward-adjusted. Default None. Only daily bars support adjustment; it is computed dynamically against the given end_date using the dividend-reinvestment model (see the FAQ) |
| freq | str | Y | Data frequency: minute (min), daily (D), weekly (W), or monthly (M) bars; 1min means 1-minute bars, likewise 1/5/15/30/60 minutes. Default D. Users with 600 points may trial minute data with 2 requests; see the permission list and the minute-data usage guide for full access |
| ma | list | N | Moving averages; any reasonable int. Note: MAs are computed dynamically, so the date range must exceed the window (e.g. a 5-day MA needs a span longer than 5 days). Currently supported for a single stock only, i.e. ts_code is required. E.g. ma_5 is the 5-day average price, ma_v_5 the 5-day average volume |
| factors | list | N | Stock factors (valid for asset='E'): tor turnover rate, vr volume ratio |
| adjfactor | str | N | Adjustment factor; when True the returned data includes the adjustment factor. Default False. Effective since v1.2.33 |

Output fields:
See the daily-bar documentation for each asset class:
- Stock daily: https://tushare.pro/document/2?doc_id=27
- Fund daily: https://tushare.pro/document/2?doc_id=127
- Futures daily: https://tushare.pro/document/2?doc_id=138
- Option daily: https://tushare.pro/document/2?doc_id=159
- Index daily: https://tushare.pro/document/2?doc_id=95

A-share daily bars
Interface: daily (can be explored and debugged with the data tool)
Data notes: loaded between 15:00 and 16:00 on each trading day. This interface returns unadjusted bars; no data is provided while trading is suspended.
Rate notes: with basic points, up to 500 calls per minute and 6000 rows per call (one request covers roughly 23 years of history for a single stock).
Description: fetch daily bars directly, or via the generic bar interface, which also covers forward/backward-adjusted data.

Input parameters:

| Name | Type | Required | Description |
|------|------|------|------|
| ts_code | str | N | Stock code (multiple codes supported, comma-separated) |
| trade_date | str | N | Trade date (YYYYMMDD) |
| start_date | str | N | Start date (YYYYMMDD) |
| end_date | str | N | End date (YYYYMMDD) |

All dates use the YYYYMMDD format, e.g. 20181010.

Output parameters:

| Name | Type | Description |
|------|------|------|
| ts_code | str | Stock code |
| trade_date | str | Trade date |
| open | float | Open price |
| high | float | High price |
| low | float | Low price |
| close | float | Close price |
| pre_close | float | Previous close (ex-rights) |
| change | float | Price change |
| pct_chg | float | Percentage change, based on the ex-rights previous close: (close - pre_close) / pre_close |
| vol | float | Volume (lots) |
| amount | float | Turnover (thousand CNY) |

src/data/client.py (new file)

@@ -0,0 +1,89 @@
"""Simplified Tushare client with rate limiting and retry logic."""
import time
import pandas as pd
from typing import Optional
from src.data.config import get_config
from src.data.rate_limiter import TokenBucketRateLimiter
class TushareClient:
"""Tushare API client with rate limiting and retry."""
def __init__(self, token: Optional[str] = None):
"""Initialize client.
Args:
token: Tushare API token (auto-loaded from config if not provided)
"""
cfg = get_config()
token = token or cfg.tushare_token
if not token:
raise ValueError("Tushare token is required")
self.token = token
self.config = cfg
# Initialize rate limiter: capacity = rate_limit, refill_rate = rate_limit/60 per second
rate_per_second = cfg.rate_limit / 60.0
self.rate_limiter = TokenBucketRateLimiter(
capacity=cfg.rate_limit,
refill_rate_per_second=rate_per_second,
)
self._api = None
def _get_api(self):
"""Get Tushare API instance."""
if self._api is None:
import tushare as ts
self._api = ts.pro_api(self.token)
return self._api
def query(self, api_name: str, timeout: float = 30.0, **params) -> pd.DataFrame:
"""Execute API query with rate limiting and retry.
Args:
api_name: API name (e.g., 'daily')
timeout: Timeout for rate limiting
**params: API parameters
Returns:
DataFrame with query results
"""
# Acquire rate limit token
success, wait_time = self.rate_limiter.acquire(timeout=timeout)
if not success:
raise RuntimeError(f"Rate limit exceeded after {timeout}s timeout")
if wait_time > 0:
print(f"[RateLimit] Waited {wait_time:.2f}s for token")
# Execute with retry
max_retries = 3
retry_delays = [1, 3, 10]
for attempt in range(max_retries):
try:
api = self._get_api()
data = api.query(api_name, **params)
available = self.rate_limiter.get_available_tokens()
print(f"[Tushare] {api_name} | tokens: {available:.0f}/{self.rate_limiter.capacity}")
return data
except Exception as e:
if attempt < max_retries - 1:
delay = retry_delays[attempt]
print(f"[Retry] {api_name} failed (attempt {attempt + 1}): {e}, retry in {delay}s")
time.sleep(delay)
else:
raise RuntimeError(f"API call failed after {max_retries} attempts: {e}")
return pd.DataFrame()
def close(self):
"""Close client."""
pass

src/data/config.py (new file)

@@ -0,0 +1,33 @@
"""Configuration management for data collection module."""
from pathlib import Path
from pydantic_settings import BaseSettings
class Config(BaseSettings):
"""Application configuration loaded from environment variables."""
# Tushare API token
tushare_token: str = ""
# Data storage path
data_path: Path = Path("./data")
# Rate limit: requests per minute
rate_limit: int = 100
# Thread pool size
threads: int = 2
class Config:
env_file = ".env.local"
env_file_encoding = "utf-8"
case_sensitive = False
# Global config instance
config = Config()
def get_config() -> Config:
"""Get configuration instance."""
return config

src/data/daily.py (new file)

@@ -0,0 +1,70 @@
"""Simplified daily market data interface.
A single function to fetch A股日线行情 data from Tushare.
Supports all output fields including tor (换手率) and vr (量比).
"""
import pandas as pd
from typing import Optional, List, Literal
from src.data.client import TushareClient
def get_daily(
ts_code: str,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
trade_date: Optional[str] = None,
adj: Literal[None, "qfq", "hfq"] = None,
factors: Optional[List[Literal["tor", "vr"]]] = None,
adjfactor: bool = False,
) -> pd.DataFrame:
"""Fetch daily market data for A-share stocks.
This is a simplified interface that combines rate limiting, API calls,
and error handling into a single function.
Args:
ts_code: Stock code (e.g., '000001.SZ', '600000.SH')
start_date: Start date in YYYYMMDD format
end_date: End date in YYYYMMDD format
trade_date: Specific trade date in YYYYMMDD format
adj: Adjustment type - None, 'qfq' (forward), 'hfq' (backward)
factors: List of factors to include - 'tor' (turnover rate), 'vr' (volume ratio)
adjfactor: Whether to include adjustment factor
Returns:
pd.DataFrame with daily market data containing:
- Base fields: ts_code, trade_date, open, high, low, close, pre_close,
change, pct_chg, vol, amount
- Factor fields (if requested): tor, vr
- Adjustment factor (if adjfactor=True): adjfactor
Example:
>>> data = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
>>> data = get_daily('600000.SH', factors=['tor', 'vr'])
"""
# Initialize client
client = TushareClient()
# Build parameters
params = {"ts_code": ts_code}
if start_date:
params["start_date"] = start_date
if end_date:
params["end_date"] = end_date
if trade_date:
params["trade_date"] = trade_date
if adj:
params["adj"] = adj
if factors:
params["factors"] = factors
if adjfactor:
params["adjfactor"] = "True"
# Fetch data
data = client.query("daily", **params)
if data.empty:
print(f"[get_daily] No data for ts_code={ts_code}")
return data

src/data/rate_limiter.py (new file)

@@ -0,0 +1,167 @@
"""Token bucket rate limiter implementation.
This module provides a thread-safe token bucket algorithm for rate limiting.
"""
import time
import threading
from typing import Optional
from dataclasses import dataclass, field
@dataclass
class RateLimiterStats:
"""Statistics for rate limiter."""
total_requests: int = 0
successful_requests: int = 0
denied_requests: int = 0
total_wait_time: float = 0.0
current_tokens: float = field(default=None, init=False)
def __post_init__(self):
self.current_tokens = field(default=None)
class TokenBucketRateLimiter:
"""Thread-safe token bucket rate limiter.
Implements a token bucket algorithm for controlling request rate.
Tokens are added at a fixed rate up to the bucket capacity.
Attributes:
capacity: Maximum number of tokens in the bucket
refill_rate: Number of tokens added per second
initial_tokens: Initial number of tokens (default: capacity)
"""
def __init__(
self,
capacity: int = 100,
refill_rate_per_second: float = 1.67,
initial_tokens: Optional[int] = None,
) -> None:
"""Initialize the token bucket rate limiter.
Args:
capacity: Maximum token capacity
refill_rate_per_second: Token refill rate per second
initial_tokens: Initial token count (default: capacity)
"""
self.capacity = capacity
self.refill_rate = refill_rate_per_second
self.tokens = float(initial_tokens if initial_tokens is not None else capacity)
self.last_refill_time = time.monotonic()
self._lock = threading.RLock()
self._stats = RateLimiterStats()
self._stats.current_tokens = self.tokens
def acquire(self, timeout: float = 30.0) -> tuple[bool, float]:
"""Acquire a token from the bucket.
Blocks until a token is available or timeout expires.
Args:
timeout: Maximum time to wait for a token in seconds
Returns:
Tuple of (success, wait_time):
- success: True if token was acquired, False if timed out
- wait_time: Time spent waiting for token
"""
start_time = time.monotonic()
wait_time = 0.0
with self._lock:
self._refill()
if self.tokens >= 1:
self.tokens -= 1
self._stats.total_requests += 1
self._stats.successful_requests += 1
self._stats.current_tokens = self.tokens
return True, 0.0
# Calculate time to wait for next token
tokens_needed = 1 - self.tokens
time_to_refill = tokens_needed / self.refill_rate
if time_to_refill > timeout:
self._stats.total_requests += 1
self._stats.denied_requests += 1
return False, timeout
# Wait for tokens
self._lock.release()
time.sleep(time_to_refill)
self._lock.acquire()
wait_time = time.monotonic() - start_time
with self._lock:
self._refill()
if self.tokens >= 1:
self.tokens -= 1
self._stats.total_requests += 1
self._stats.successful_requests += 1
self._stats.total_wait_time += wait_time
self._stats.current_tokens = self.tokens
return True, wait_time
self._stats.total_requests += 1
self._stats.denied_requests += 1
return False, wait_time
def acquire_nonblocking(self) -> tuple[bool, float]:
"""Try to acquire a token without blocking.
Returns:
Tuple of (success, wait_time):
- success: True if token was acquired, False otherwise
- wait_time: 0 for non-blocking, or required wait time if failed
"""
with self._lock:
self._refill()
if self.tokens >= 1:
self.tokens -= 1
self._stats.total_requests += 1
self._stats.successful_requests += 1
self._stats.current_tokens = self.tokens
return True, 0.0
# Calculate time needed
tokens_needed = 1 - self.tokens
time_to_refill = tokens_needed / self.refill_rate
self._stats.total_requests += 1
self._stats.denied_requests += 1
return False, time_to_refill
def _refill(self) -> None:
"""Refill tokens based on elapsed time."""
current_time = time.monotonic()
elapsed = current_time - self.last_refill_time
self.last_refill_time = current_time
tokens_to_add = elapsed * self.refill_rate
self.tokens = min(self.capacity, self.tokens + tokens_to_add)
def get_available_tokens(self) -> float:
"""Get the current number of available tokens.
Returns:
Current token count
"""
with self._lock:
self._refill()
return self.tokens
def get_stats(self) -> RateLimiterStats:
"""Get rate limiter statistics.
Returns:
RateLimiterStats instance
"""
with self._lock:
self._refill()
self._stats.current_tokens = self.tokens
return self._stats

src/data/storage.py (new file)

@@ -0,0 +1,133 @@
"""Simplified HDF5 storage for data persistence."""
import os
import pandas as pd
from pathlib import Path
from typing import Optional
from src.data.config import get_config
class Storage:
"""HDF5 storage manager for saving and loading data."""
def __init__(self, path: Optional[Path] = None):
"""Initialize storage.
Args:
path: Base path for data storage (auto-loaded from config if not provided)
"""
cfg = get_config()
self.base_path = path or cfg.data_path
self.base_path.mkdir(parents=True, exist_ok=True)
def _get_file_path(self, name: str) -> Path:
"""Get full path for an HDF5 file."""
return self.base_path / f"{name}.h5"
def save(self, name: str, data: pd.DataFrame, mode: str = "append") -> dict:
"""Save data to HDF5 file.
Args:
name: Dataset name (also used as filename)
data: DataFrame to save
mode: 'append' or 'replace'
Returns:
Dict with save result
"""
if data.empty:
return {"status": "skipped", "rows": 0}
file_path = self._get_file_path(name)
try:
with pd.HDFStore(file_path, mode="a") as store:
if mode == "replace" or name not in store.keys():
store.put(name, data, format="table")
else:
# Merge with existing data
existing = store[name]
combined = pd.concat([existing, data], ignore_index=True)
combined = combined.drop_duplicates(subset=["ts_code", "trade_date"], keep="last")
store.put(name, combined, format="table")
print(f"[Storage] Saved {len(data)} rows to {file_path}")
return {"status": "success", "rows": len(data), "path": str(file_path)}
except Exception as e:
print(f"[Storage] Error saving {name}: {e}")
return {"status": "error", "error": str(e)}
def load(self, name: str,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
ts_code: Optional[str] = None) -> pd.DataFrame:
"""Load data from HDF5 file.
Args:
name: Dataset name
start_date: Start date filter (YYYYMMDD)
end_date: End date filter (YYYYMMDD)
ts_code: Stock code filter
Returns:
DataFrame with loaded data
"""
file_path = self._get_file_path(name)
if not file_path.exists():
print(f"[Storage] File not found: {file_path}")
return pd.DataFrame()
try:
with pd.HDFStore(file_path, mode="r") as store:
if name not in store.keys():
return pd.DataFrame()
data = store[name]
# Apply filters
if start_date and end_date and "trade_date" in data.columns:
data = data[(data["trade_date"] >= start_date) & (data["trade_date"] <= end_date)]
if ts_code and "ts_code" in data.columns:
data = data[data["ts_code"] == ts_code]
return data
except Exception as e:
print(f"[Storage] Error loading {name}: {e}")
return pd.DataFrame()
def get_last_date(self, name: str) -> Optional[str]:
"""Get the latest date in storage.
Args:
name: Dataset name
Returns:
Latest date string or None
"""
data = self.load(name)
if data.empty or "trade_date" not in data.columns:
return None
return str(data["trade_date"].max())
def exists(self, name: str) -> bool:
"""Check if dataset exists."""
return self._get_file_path(name).exists()
def delete(self, name: str) -> bool:
"""Delete a dataset.
Args:
name: Dataset name
Returns:
True if deleted
"""
file_path = self._get_file_path(name)
if file_path.exists():
file_path.unlink()
print(f"[Storage] Deleted {file_path}")
return True
return False

tests/test_daily.py (new file)

@@ -0,0 +1,419 @@
"""Test for daily market data API.
Tests the daily interface implementation against api.md requirements:
- A股日线行情所有输出字段
- tor 换手率
- vr 量比
"""
import pytest
import pandas as pd
from unittest.mock import Mock, patch
from src.data.daily import get_daily
from src.data.client import TushareClient
# Expected output fields according to api.md
EXPECTED_BASE_FIELDS = [
'ts_code', # 股票代码
'trade_date', # 交易日期
'open', # 开盘价
'high', # 最高价
'low', # 最低价
'close', # 收盘价
'pre_close', # 昨收价
'change', # 涨跌额
'pct_chg', # 涨跌幅
'vol', # 成交量
'amount', # 成交额
]
EXPECTED_FACTOR_FIELDS = [
'tor', # 换手率
'vr', # 量比
]
def run_tests_with_print():
"""Run all tests and print results."""
print("=" * 60)
print("Daily API 测试开始")
print("=" * 60)
test_results = []
# Test 1: Basic daily data fetch
print("\n【测试1】基本日线数据获取")
print("-" * 40)
mock_data = pd.DataFrame({
'ts_code': ['000001.SZ'],
'trade_date': ['20240102'],
'open': [10.5],
'high': [11.0],
'low': [10.2],
'close': [10.8],
'pre_close': [10.3],
'change': [0.5],
'pct_chg': [4.85],
'vol': [1000000],
'amount': [10800000],
})
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data):
result = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
print(f"获取数据形状: {result.shape}")
print(f"获取数据列: {result.columns.tolist()}")
print(f"数据内容:\n{result}")
# Verify
tests_passed = isinstance(result, pd.DataFrame)
tests_passed &= len(result) == 1
tests_passed &= result['ts_code'].iloc[0] == '000001.SZ'
print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}")
test_results.append(("基本日线数据获取", tests_passed))
# Test 2: Fetch with factors
print("\n【测试2】获取含换手率和量比的数据")
print("-" * 40)
mock_data_factors = pd.DataFrame({
'ts_code': ['000001.SZ'],
'trade_date': ['20240102'],
'open': [10.5],
'high': [11.0],
'low': [10.2],
'close': [10.8],
'pre_close': [10.3],
'change': [0.5],
'pct_chg': [4.85],
'vol': [1000000],
'amount': [10800000],
'tor': [2.5],
'vr': [1.2],
})
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data_factors):
result = get_daily(
'000001.SZ',
start_date='20240101',
end_date='20240131',
factors=['tor', 'vr'],
)
print(f"获取数据形状: {result.shape}")
print(f"获取数据列: {result.columns.tolist()}")
print(f"数据内容:\n{result}")
# Verify columns
tests_passed = isinstance(result, pd.DataFrame)
missing_base = [f for f in EXPECTED_BASE_FIELDS if f not in result.columns]
missing_factors = [f for f in EXPECTED_FACTOR_FIELDS if f not in result.columns]
print(f"\n基础列检查: {'全部存在 ✓' if not missing_base else f'缺失: {missing_base}'}")
print(f"因子列检查: {'全部存在 ✓' if not missing_factors else f'缺失: {missing_factors}'}")
tests_passed &= len(missing_base) == 0
tests_passed &= len(missing_factors) == 0
print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}")
test_results.append(("含因子数据获取", tests_passed))
# Test 3: Output fields completeness
print("\n【测试3】输出字段完整性检查")
print("-" * 40)
mock_data = pd.DataFrame({
'ts_code': ['600000.SH'],
'trade_date': ['20240102'],
'open': [10.5],
'high': [11.0],
'low': [10.2],
'close': [10.8],
'pre_close': [10.3],
'change': [0.5],
'pct_chg': [4.85],
'vol': [1000000],
'amount': [10800000],
})
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data):
result = get_daily('600000.SH')
print(f"获取数据形状: {result.shape}")
print(f"获取数据列: {result.columns.tolist()}")
print(f"期望基础列: {EXPECTED_BASE_FIELDS}")
missing = set(EXPECTED_BASE_FIELDS) - set(result.columns)
print(f"缺失列: {missing if missing else ''}")
tests_passed = set(EXPECTED_BASE_FIELDS).issubset(result.columns.tolist())
print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}")
test_results.append(("输出字段完整性", tests_passed))
# Test 4: Empty result
print("\n【测试4】空结果处理")
print("-" * 40)
mock_data = pd.DataFrame()
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data):
result = get_daily('INVALID.SZ')
print(f"获取数据是否为空: {result.empty}")
tests_passed = result.empty
print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}")
test_results.append(("空结果处理", tests_passed))
# Test 5: Date range query
print("\n【测试5】日期范围查询")
print("-" * 40)
mock_data = pd.DataFrame({
'ts_code': ['000001.SZ', '000001.SZ'],
'trade_date': ['20240102', '20240103'],
'open': [10.5, 10.6],
'high': [11.0, 11.1],
'low': [10.2, 10.3],
'close': [10.8, 10.9],
'pre_close': [10.3, 10.8],
'change': [0.5, 0.1],
'pct_chg': [4.85, 0.93],
'vol': [1000000, 1100000],
'amount': [10800000, 11900000],
})
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data):
result = get_daily(
'000001.SZ',
start_date='20240101',
end_date='20240131',
)
print(f"获取数据数量: {len(result)}")
print(f"期望数量: 2")
print(f"数据内容:\n{result}")
tests_passed = len(result) == 2
print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}")
test_results.append(("日期范围查询", tests_passed))
# Test 6: With adjustment
print("\n【测试6】带复权参数查询")
print("-" * 40)
mock_data = pd.DataFrame({
'ts_code': ['000001.SZ'],
'trade_date': ['20240102'],
'open': [10.5],
'high': [11.0],
'low': [10.2],
'close': [10.8],
'pre_close': [10.3],
'change': [0.5],
'pct_chg': [4.85],
'vol': [1000000],
'amount': [10800000],
})
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data):
result = get_daily('000001.SZ', adj='qfq')
print(f"获取数据形状: {result.shape}")
print(f"数据内容:\n{result}")
tests_passed = isinstance(result, pd.DataFrame)
print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}")
test_results.append(("复权参数查询", tests_passed))
# Summary
print("\n" + "=" * 60)
print("测试汇总")
print("=" * 60)
passed = sum(1 for _, r in test_results if r)
total = len(test_results)
print(f"总测试数: {total}")
print(f"通过: {passed}")
print(f"失败: {total - passed}")
print(f"通过率: {passed/total*100:.1f}%")
print("\n详细结果:")
    for name, ok in test_results:
        status = "通过 ✓" if ok else "失败 ✗"
        print(f"  - {name}: {status}")
return all(r for _, r in test_results)
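

# Hedged sketch (hypothetical helper, not in the original file): every test
# repeats the same nested patch.object pair; they could be folded into one
# context manager. `target_cls` would be TushareClient in these tests.
from contextlib import contextmanager


@contextmanager
def mock_client(target_cls, frame):
    """Patch `target_cls` so construction is a no-op and query() returns `frame`."""
    with patch.object(target_cls, '__init__', lambda self, token=None: None):
        with patch.object(target_cls, 'query', return_value=frame):
            yield
# Usage sketch: with mock_client(TushareClient, mock_data): result = get_daily('000001.SZ')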


class TestGetDaily:
    """Test cases for the simplified get_daily function."""

    def test_fetch_basic(self):
"""Test basic daily data fetch."""
mock_data = pd.DataFrame({
'ts_code': ['000001.SZ'],
'trade_date': ['20240102'],
'open': [10.5],
'high': [11.0],
'low': [10.2],
'close': [10.8],
'pre_close': [10.3],
'change': [0.5],
'pct_chg': [4.85],
'vol': [1000000],
'amount': [10800000],
})
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data):
result = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
assert isinstance(result, pd.DataFrame)
assert len(result) == 1
assert result['ts_code'].iloc[0] == '000001.SZ'

def test_fetch_with_factors(self):
"""Test fetch with tor and vr factors."""
mock_data = pd.DataFrame({
'ts_code': ['000001.SZ'],
'trade_date': ['20240102'],
'open': [10.5],
'high': [11.0],
'low': [10.2],
'close': [10.8],
'pre_close': [10.3],
'change': [0.5],
'pct_chg': [4.85],
'vol': [1000000],
'amount': [10800000],
            'tor': [2.5],  # turnover rate
            'vr': [1.2],   # volume ratio
})
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data):
result = get_daily(
'000001.SZ',
start_date='20240101',
end_date='20240131',
factors=['tor', 'vr'],
)
assert isinstance(result, pd.DataFrame)
# Check all base fields are present
for field in EXPECTED_BASE_FIELDS:
assert field in result.columns, f"Missing base field: {field}"
# Check factor fields are present
for field in EXPECTED_FACTOR_FIELDS:
assert field in result.columns, f"Missing factor field: {field}"

def test_output_fields_completeness(self):
"""Verify all required output fields are returned."""
mock_data = pd.DataFrame({
'ts_code': ['600000.SH'],
'trade_date': ['20240102'],
'open': [10.5],
'high': [11.0],
'low': [10.2],
'close': [10.8],
'pre_close': [10.3],
'change': [0.5],
'pct_chg': [4.85],
'vol': [1000000],
'amount': [10800000],
})
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data):
result = get_daily('600000.SH')
# Verify all base fields are present
assert set(EXPECTED_BASE_FIELDS).issubset(result.columns.tolist()), \
f"Missing fields: {set(EXPECTED_BASE_FIELDS) - set(result.columns)}"

def test_empty_result(self):
"""Test handling of empty results."""
mock_data = pd.DataFrame()
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data):
result = get_daily('INVALID.SZ')
assert result.empty

def test_date_range_query(self):
"""Test query with date range."""
mock_data = pd.DataFrame({
'ts_code': ['000001.SZ', '000001.SZ'],
'trade_date': ['20240102', '20240103'],
'open': [10.5, 10.6],
'high': [11.0, 11.1],
'low': [10.2, 10.3],
'close': [10.8, 10.9],
'pre_close': [10.3, 10.8],
'change': [0.5, 0.1],
'pct_chg': [4.85, 0.93],
'vol': [1000000, 1100000],
'amount': [10800000, 11900000],
})
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data):
result = get_daily(
'000001.SZ',
start_date='20240101',
end_date='20240131',
)
assert len(result) == 2

def test_with_adj(self):
"""Test fetch with adjustment type."""
mock_data = pd.DataFrame({
'ts_code': ['000001.SZ'],
'trade_date': ['20240102'],
'open': [10.5],
'high': [11.0],
'low': [10.2],
'close': [10.8],
'pre_close': [10.3],
'change': [0.5],
'pct_chg': [4.85],
'vol': [1000000],
'amount': [10800000],
})
with patch.object(TushareClient, '__init__', lambda self, token=None: None):
with patch.object(TushareClient, 'query', return_value=mock_data):
result = get_daily('000001.SZ', adj='qfq')
assert isinstance(result, pd.DataFrame)


def test_integration():
    """Integration test against the real Tushare API (requires a valid token)."""
import os
token = os.environ.get('TUSHARE_TOKEN')
if not token:
pytest.skip("TUSHARE_TOKEN not configured")
result = get_daily('000001.SZ', start_date='20240101', end_date='20240131', factors=['tor', 'vr'])
# Verify structure
assert isinstance(result, pd.DataFrame)
if not result.empty:
# Check base fields
for field in EXPECTED_BASE_FIELDS:
assert field in result.columns, f"Missing base field: {field}"
# Check factor fields
for field in EXPECTED_FACTOR_FIELDS:
assert field in result.columns, f"Missing factor field: {field}"


if __name__ == '__main__':
    # Run the verbose print-based tests first
    run_tests_with_print()
print("\n" + "=" * 60)
print("运行 pytest 单元测试")
print("=" * 60 + "\n")
pytest.main([__file__, '-v'])