feat: initialize ProStock project structure and configuration

- Add project rule documents (development guidelines, security rules, configuration management)
- Implement core data-module features (API client, rate limiter, storage manager, config loading)
- Add .gitignore and .kilocodeignore
- Add environment-variable template
- Add unit tests for the daily module
73
.gitignore
vendored
Normal file
@@ -0,0 +1,73 @@
# ===========================================
# Python project .gitignore template
# ===========================================

# Python bytecode cache
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# Virtual environments
venv/
ENV/
env/
.venv/

# Configuration files (sensitive)
.env.local
config/*.local
!config/.env.example

# Log files
*.log
logs/
*.log.*

# IDEs and editors
.vscode/
.idea/
*.swp
*.swo
*~

# OS files
.DS_Store
Thumbs.db

# Test coverage
.coverage
htmlcov/
.pytest_cache/
cover/

# Database files
*.db
*.sqlite
*.sqlite3

# Temporary files
*.tmp
*.temp
tmp/
temp/
72
.kilocode/rules/project_rules.md
Normal file
@@ -0,0 +1,72 @@
# ProStock Project Rules

This project uses the Kilo Code rules system to ensure code quality and consistency.

## Project Overview

- **Project name**: ProStock
- **Primary language**: Python
- **Framework**: TBD
- **Source directory**: `src/`
- **Configuration directory**: `config/`

## Core Principles

1. **Code quality**: all code must follow the Python PEP 8 style guide
2. **Type hints**: public functions and classes should carry type annotations
3. **Docstrings**: use Google-style docstrings
4. **Test coverage**: critical business logic should have corresponding unit tests

## Directory Layout

```
project/
├── src/                    # Main source directory
├── config/                 # Configuration directory (direct reads forbidden)
├── docs/                   # Documentation
├── tests/                  # Tests
├── .kilocode/              # Kilo Code configuration
│   └── rules/              # Rule files
├── requirements.txt        # Dependency management
└── README.md               # Project description
```

## File Naming

- **Python files**: snake_case (`snake_case.py`)
- **Config files**: snake_case (`config_settings.py`)
- **Test files**: `test_` prefix (`test_example.py`)
- **Constant files**: `constants_` prefix (`constants_sizes.py`)

## Code Organization

- Keep each Python module small and single-purpose
- Avoid putting heavy logic in `__init__.py`
- Use absolute imports (`from src.module import ...`) rather than relative imports
- Centralize configuration; avoid hard-coding

## Import Order

```python
# 1. Standard library
import os
import sys
from datetime import datetime

# 2. Third-party libraries
import pandas as pd
from flask import Flask

# 3. Local application imports
from src.config.settings import Settings
from src.models.user import User
```

## Pre-commit Checklist

Before committing, make sure:
- [ ] All code passes type checking (e.g. with mypy)
- [ ] Formatting follows the conventions (black/isort)
- [ ] No unused imports or variables
- [ ] Critical features are covered by tests
- [ ] Documentation is updated (where needed)
325
.kilocode/rules/python-development-guidelines.md
Normal file
@@ -0,0 +1,325 @@
# Python Development Guidelines

## 1 Code Style

### 1.1 Indentation and Whitespace
- Indent with **4 spaces** (tabs are forbidden)
- Maximum line length is **120 characters**
- Put spaces around binary operators (except in chained assignments)
- Be consistent with whitespace around function parameters

```python
# Good
def calculate_total_price(price: float, tax_rate: float) -> float:
    total = price * (1 + tax_rate)
    return round(total, 2)

# Bad
def calculate_total_price(price:float,tax_rate:float)->float:
    total=price*(1+tax_rate)
    return round(total,2)
```

### 1.2 Imports
- Imports go at the top of the file, after the module comment and docstring
- Standard library first, then third-party libraries, then local application modules
- Use absolute imports; wildcard imports (`from module import *`) are forbidden
- Separate each import group with a blank line

```python
# Standard library
import os
import sys
from datetime import datetime

# Third-party
import numpy as np
import pandas as pd
from pydantic import BaseModel

# Local modules
from src.config import settings
from src.utils.logger import get_logger
```

### 1.3 Naming

| Kind | Convention | Examples |
|------|------|------|
| Module | lowercase with underscores | `data_processor.py` |
| Package | lowercase, no underscores | `src/utils` |
| Class | PascalCase | `UserAccount`, `DataValidator` |
| Function | snake_case | `get_user_data()`, `calculate_total()` |
| Variable | snake_case | `user_name`, `total_count` |
| Constant | UPPER_SNAKE_CASE | `MAX_RETRY_COUNT`, `DEFAULT_TIMEOUT` |
| Private method/variable | single leading underscore | `_internal_method()`, `_private_var` |
| Type variable | PascalCase | `UserType`, `T = TypeVar('T')` |

### 1.4 Comments
- Write comments in English; Chinese comments are forbidden
- Complex logic must carry a comment explaining the intent
- Keep comments in sync with the code; stale comments are forbidden

```python
# Single-line comment (when on the same line as code, leave 2 spaces before the #)
def process_data(data: list) -> dict:  # Process input data and return statistics
    pass

# Multi-line docstring (Google style)
def calculate_metrics(values: list[float]) -> dict[str, float]:
    """Calculate statistical metrics for a list of values.

    Args:
        values: List of numeric values to analyze

    Returns:
        Dictionary containing mean, median, and standard deviation

    Raises:
        ValueError: If input list is empty
    """
    pass
```

## 2 Architecture

### 2.1 Module Layout
```
src/
├── core/           # Core business logic
├── modules/        # Feature modules
├── utils/          # Utility functions
├── config/         # Configuration management
├── services/       # Service layer
├── models/         # Data models
├── repositories/   # Data access layer
└── schemas/        # Pydantic models
```

### 2.2 Dependency Principles
- **Dependency inversion**: high-level modules must not depend on low-level modules; both depend on abstractions
- **No circular dependencies**: module references must form a directed acyclic graph
- **Single responsibility**: each class/module owns exactly one responsibility

```python
# Bad: the service depends directly on several concrete collaborators
class UserService:
    def __init__(self, db_connection, email_sender, cache_manager):
        pass

# Good: decoupled through an abstract interface
class UserService:
    def __init__(self, user_repository: IUserRepository):
        self.repository = user_repository
```

### 2.3 Function Design
- **Single responsibility**: each function does one thing
- **Parameter budget**: at most 5 parameters; wrap more in an object
- **Explicit returns**: return types must be annotated

```python
# Bad: the function does too many things
def process_user_registration(name, email, password, send_email, create_session):
    if send_email:
        send_verification_email(email)
    if create_session:
        create_user_session(email)
    return save_user(name, email, password)

# Good: responsibilities split out
class UserRegistrationService:
    def register(self, user_data: UserCreateDto) -> User:
        user = self._validate_and_create_user(user_data)
        self._send_verification_email(user)
        self._create_session(user)
        return user

    def _validate_and_create_user(self, data: UserCreateDto) -> User:
        pass

    def _send_verification_email(self, user: User) -> None:
        pass
```

## 3 Configuration Management

### 3.1 No Hard-Coding
**All critical configuration must live outside the code; hard-coding it is forbidden**

| Kind | Configuration that must be externalized |
|------|----------------|
| Database connections | HOST, PORT, USER, PASSWORD, DATABASE |
| API keys | SECRET_KEY, API_KEY |
| External services | ENDPOINT, TIMEOUT, RETRY_COUNT |
| Business parameters | MAX_RETRIES, CACHE_TTL, RATE_LIMIT |

### 3.2 Configuration Directory Layout
```
config/
├── .env.example                # Environment-variable template (no secrets)
├── .env.local                  # Local environment config (ignored by .gitignore)
├── config.yaml                 # Shared configuration
├── config.development.yaml     # Development configuration
├── config.production.yaml      # Production configuration
└── config.test.yaml            # Test configuration
```

### 3.3 Configuration Loading Example
```python
# src/config/settings.py
from pydantic_settings import BaseSettings
from functools import lru_cache


class Settings(BaseSettings):
    """Application settings loaded from environment variables."""

    # Database
    database_host: str = "localhost"
    database_port: int = 5432
    database_name: str = "prostock"
    database_user: str
    database_password: str

    # API
    api_key: str
    secret_key: str
    jwt_algorithm: str = "HS256"
    access_token_expire_minutes: int = 30

    # Redis
    redis_host: str = "localhost"
    redis_port: int = 6379

    class Config:
        env_file = ".env.local"
        env_file_encoding = "utf-8"


@lru_cache()
def get_settings() -> Settings:
    """Return the settings singleton."""
    return Settings()
```

### 3.4 .env.example Template
```bash
# ===========================================
# ProStock environment-variable template
# Copy this file to .env.local and fill in real values
# ===========================================

# Database
DATABASE_HOST=localhost
DATABASE_PORT=5432
DATABASE_NAME=prostock
DATABASE_USER=your_username
DATABASE_PASSWORD=your_password

# API keys
API_KEY=your_api_key
SECRET_KEY=your_secret_key_here

# Redis (optional)
REDIS_HOST=localhost
REDIS_PORT=6379
```

## 4 Error Handling

### 4.1 Exception Hierarchy
```python
# src/core/exceptions.py
class BaseCustomException(Exception):
    """Base exception class."""
    status_code: int = 500
    detail: str = "An unexpected error occurred"

    def __init__(self, detail: str = None):
        self.detail = detail or self.detail
        super().__init__(self.detail)


class ValidationError(BaseCustomException):
    """Data validation error."""
    status_code = 422
    detail = "Validation error"


class AuthenticationError(BaseCustomException):
    """Authentication error."""
    status_code = 401
    detail = "Authentication required"


class AuthorizationError(BaseCustomException):
    """Authorization error."""
    status_code = 403
    detail = "Permission denied"


class NotFoundError(BaseCustomException):
    """Resource not found."""
    status_code = 404
    detail = "Resource not found"
```

### 4.2 Error-Handling Principles
- Propagate meaningful information to upper layers
- Log details (never including sensitive data)
- Distinguish recoverable from unrecoverable errors
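These principles fit in a few lines of Python. In the sketch below, `fetch_user` and its in-memory `db` dict are hypothetical stand-ins, and the local `NotFoundError` only mirrors the hierarchy from section 4.1 so the example is self-contained:

```python
import logging

logger = logging.getLogger(__name__)


class NotFoundError(Exception):
    """Resource not found (mirrors the subclasses shown in section 4.1)."""
    status_code = 404


def fetch_user(user_id: str, db: dict) -> dict:
    """Translate a low-level lookup failure into a meaningful domain error."""
    try:
        return db[user_id]
    except KeyError:
        # Log the operation and identifier only -- never credentials or tokens
        logger.warning("user lookup failed: user_id=%s", user_id)
        # Unrecoverable for this request: raise a typed domain exception upward
        raise NotFoundError(f"user {user_id} not found") from None
```

Catching `KeyError` low and re-raising a typed domain exception lets callers write an honest `except NotFoundError:` branch for the one error they can actually recover from.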
## 5 Testing

### 5.1 Requirements
- All core features must have unit tests
- Coverage of critical business logic must be at least 80%
- Use `pytest` as the test framework
- Use `pytest-cov` for coverage reports

### 5.2 Test File Layout
```
tests/
├── conftest.py          # Shared fixtures
├── unit/                # Unit tests
│   ├── test_models.py
│   ├── test_services.py
│   └── test_utils.py
├── integration/         # Integration tests
│   └── test_api.py
└── fixtures/            # Test data
```
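A minimal test module in this layout can be nothing more than plain functions with asserts (pytest discovers `test_*` functions automatically). Here `calculate_total_price` is inlined from the section 1.1 example so the file is self-contained:

```python
# tests/unit/test_pricing.py -- a minimal pytest-style sketch (plain asserts, no fixtures)


def calculate_total_price(price: float, tax_rate: float) -> float:
    """Example function from section 1.1, inlined to keep the test self-contained."""
    total = price * (1 + tax_rate)
    return round(total, 2)


def test_total_price_applies_tax():
    assert calculate_total_price(100.0, 0.07) == 107.0


def test_zero_tax_rate_is_identity():
    assert calculate_total_price(19.99, 0.0) == 19.99
```

Run with `pytest tests/unit/test_pricing.py`; add `--cov=src` once `pytest-cov` is installed.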
## 6 Git Commit Conventions

### 6.1 Message Format
```
<type>(<scope>): <subject>

<body>

<footer>
```

### 6.2 Types
| Type | Meaning |
|------|------|
| feat | New feature |
| fix | Bug fix |
| docs | Documentation update |
| style | Formatting change |
| refactor | Refactoring |
| test | Test-related change |
| chore | Build/tooling change |

## 7 Code Review Checklist

- [ ] Code follows PEP 8
- [ ] No hard-coded critical configuration
- [ ] Functions/classes carry type annotations
- [ ] Complex logic is commented
- [ ] Unit tests cover the critical logic
- [ ] No circular dependencies
- [ ] Names follow the conventions
- [ ] Logs contain no sensitive information
271
.kilocode/rules/security_rules.md
Normal file
@@ -0,0 +1,271 @@
# Security and Access-Control Rules

This file defines the project's security constraints and access-control rules.

## 🔒 Configuration Files Are Off-Limits

### Core Constraint

**Accessing any file under the root-level `config` directory during coding is strictly forbidden**. This means:

1. **No reading** - do not read files under `config/` by any means
2. **No editing** - do not modify configuration files under `config/`
3. **No viewing** - do not inspect the contents of files under `config/`
4. **No searching** - do not run any search inside `config/`
5. **No executing** - do not run any command inside `config/`

All configuration reads must go through the centralized configuration module (`src/config/`). **`config/` and `src/config/` are entirely different directories: the former is protected, the latter holds the configuration-module code**.

### Directory Layout

```
ProStock/
├── config/              # Protected configuration directory (no access of any kind)
│   ├── .env.example     # Environment-variable template
│   ├── .env.local       # Local environment config (sensitive)
│   ├── config.yaml      # Shared configuration
│   └── ...
├── src/config/          # Configuration-module code (internal access only)
│   ├── __init__.py
│   ├── settings.py      # Configuration-loading logic
│   └── ...
└── ...
```

### Full List of Restricted Tools

All of the following tools are **strictly forbidden** from touching the `config/` directory:

| Tool category | Restricted tool | Forbidden operation |
|---------|-----------|---------|
| File reading | `read_file` | Reading any file under `config/` |
| File reading | `list_files` | Listing the contents of `config/` |
| File editing | `edit_file` | Editing any file under `config/` |
| File editing | `search_and_replace` | Replacing content in files under `config/` |
| File editing | `write_to_file` | Writing into `config/` (create or overwrite) |
| File editing | `delete_file` | Deleting any file under `config/` |
| Search | `search_files` | Regex searches inside `config/` |
| Command execution | `execute_command` | Any command that touches `config/` |
| Code mode | All code-writing operations | Writing code or configuration into `config/` |

### Full List of Forbidden Behaviors ❌

```python
# Forbidden: reading a .env file directly
import dotenv
dotenv.load_dotenv('config/.env')  # Forbidden!

# Forbidden: reading a configuration file directly
with open('config/settings.json', 'r') as f:  # Forbidden!
    config = json.load(f)

# Forbidden: hard-coding a configuration path
config_path = os.path.join('config', 'database.yml')  # Forbidden!

# Forbidden: viewing a config file with the read_file tool
read_file(path='config/.env')  # Forbidden!

# Forbidden: editing config files with edit_file/search_and_replace
edit_file(path='config/settings.py', ...)  # Forbidden!

# Forbidden: searching the config directory with search_files
search_files(path='config', regex='.*')  # Forbidden!

# Forbidden: listing the config directory with list_files
list_files(path='config', recursive=True)  # Forbidden!

# Forbidden: creating or modifying config files with write_to_file
write_to_file(path='config/custom.py', ...)  # Forbidden!

# Forbidden: deleting config files with delete_file
delete_file(path='config/old.env')  # Forbidden!

# Forbidden: Python filesystem operations
import os
os.listdir('config')  # Forbidden!
os.path.exists('config/.env')  # Forbidden!
os.path.isfile('config/settings.yaml')  # Forbidden!
os.walk('config')  # Forbidden!

# Forbidden: pathlib operations
from pathlib import Path
Path('config/.env').read_text()  # Forbidden!
Path('config').iterdir()  # Forbidden!

# Forbidden: glob pattern matching
import glob
glob.glob('config/**/*')  # Forbidden!
glob.iglob('config/*.yaml')  # Forbidden!

# Forbidden: shutil operations
import shutil
shutil.copy('config/.env', 'backup/')  # Forbidden!
```

### Forbidden Command Execution ❌

```bash
# Forbidden: entering the config directory
cd config  # Forbidden!

# Forbidden: listing the config directory
ls config/  # Forbidden!
dir config  # Forbidden!

# Forbidden: reading config files
cat config/.env  # Forbidden!
type config\.env  # Forbidden!

# Forbidden: searching the config directory
grep -r "SECRET" config/  # Forbidden!

# Forbidden: any other command touching the config directory
find config -name "*.py"  # Forbidden!
```

### Compliant Behavior ✅

```python
# Correct: use the configuration-management module
from src.config.settings import Settings
settings = Settings()
db_config = settings.database

# Correct: let the configuration module handle loading internally
from src.config import get_config
config = get_config()
```

### Configuration-File Protection Rules

1. **The `config` directory**: the root-level `config/` directory is **fully protected**; no tool may access it
2. **The `src/config` directory**: configuration-module code, accessed only from within `src/config/` itself
3. **Sensitive files**: `.env` files must be listed in `.gitignore`
4. **Configuration loading**: load once at application startup rather than reading repeatedly at runtime
5. **Tool-call restriction**: every tool call must verify that the target path does not start with the `config/` prefix
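The prefix check in rule 5 can be sketched in a few lines; `is_path_allowed` is a hypothetical helper, not an existing part of the tooling:

```python
from pathlib import PurePosixPath


def is_path_allowed(path: str, protected_prefix: str = "config") -> bool:
    """Return False for any path inside the protected top-level directory.

    PurePosixPath collapses single-dot segments, so 'config/.env',
    './config/x', and 'config' itself are all rejected, while
    'src/config/settings.py' stays allowed.  ('..' segments are not
    resolved here; a real guard should resolve against the project
    root first.)
    """
    parts = PurePosixPath(path).parts
    return not (parts and parts[0] == protected_prefix)
```

A tool runner would call this once per target path and refuse the call when it returns `False`.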
## 🔐 Handling Sensitive Information

### Forbidden Behaviors

```python
# Forbidden: hard-coding secrets in code
API_KEY = "sk-1234567890abcdef"  # Forbidden!

# Forbidden: printing sensitive information
print(f"Password: {password}")  # Forbidden!

# Forbidden: writing secrets to logs
logger.debug(f"API Key: {api_key}")  # Forbidden!
```

### Compliant Practices

```python
# Correct: read from environment variables
import os
API_KEY = os.environ.get('API_KEY')

# Correct: use configuration management
from src.config.settings import Settings
settings = Settings()
api_key = settings.api_key
```

## 🛡️ Security Best Practices

### Input Validation

- All external input must be validated and sanitized
- Use parameterized queries to prevent SQL injection
- Escape and filter user input appropriately

### Dependency Security

- Update dependencies regularly to pick up security fixes
- Audit dependencies with `pip audit` or `safety`
- Avoid packages with known security issues

### Logging

- **Never** log sensitive data (passwords, keys, tokens, etc.)
- **Never** log complete user records (consider masking)
- **Do** log the operation type, the user ID (without sensitive fields), and a timestamp

### Error Handling

- Never expose detailed stack traces to users
- Record sensitive errors in a security log instead of returning them to clients
- Show generic error messages externally

## ⚠️ Handling Violations

### Automatic Detection

1. **Pre-commit scanning**: Git hooks scan commit contents automatically
2. **CI/CD pipeline checks**: security scans run in continuous integration
3. **Static analysis**: static-analysis tools detect violating patterns
4. **Tool-call monitoring**: every tool call is checked for `config/` access

### Enforcement

Violating the rules above leads to:

1. **Failed code review**: the commit is rejected automatically
2. **Security-scanner alerts**: a security alert is raised
3. **Vulnerability rating**: the finding is flagged as a high-priority vulnerability
4. **Blocked builds**: the CI/CD pipeline fails automatically
5. **Audit logging**: the violation is recorded for audit tracking

### Violation Severity Levels

| Level | Violation | Response |
|------|---------|---------|
| Critical | Deliberately reading a sensitive configuration file (e.g. `.env`) | Review rejected, team notified |
| High | Accessing `config/` with a tool | Review rejected, fix required |
| Medium | Hard-coding a configuration path in code | Change required, flagged in review |
| Low | Potentially risky operation (needs human review) | Reviewer reminder |

## 📋 Rule Checklists

### Pre-commit Checks

- [ ] No tool was used to access the `config/` directory
- [ ] No hard-coded `config/` paths
- [ ] Configuration files are accessed only through the configuration module
- [ ] No hard-coded secrets
- [ ] Secrets come from environment variables or secure storage
- [ ] Logs contain no sensitive data
- [ ] Error handling exposes no sensitive information

### Tool-Call Checks

- [ ] `read_file` did not touch `config/`
- [ ] `edit_file` did not target `config/`
- [ ] `write_to_file` did not write into `config/`
- [ ] `delete_file` did not delete from `config/`
- [ ] `search_files` did not search `config/`
- [ ] `list_files` did not list `config/`
- [ ] `execute_command` did not involve a `config/` path

## 🔧 Compliance Verification Script

To verify the rules are followed, run checks such as:

```bash
# Look for code that touches the config directory
grep -r "config/\." --include="*.py" src/ tests/ docs/

# Look for config paths in tool calls
grep -rn "path='config" --include="*.py"

# Check whether any .env file has been committed
git ls-files | grep "^config/.env"
```

## 📝 Training and Documentation

1. **Onboarding**: new members must study these security rules when joining
2. **Regular audits**: rule compliance is audited monthly
3. **Rule updates**: the rules are revised regularly as the project evolves
4. **Violation reports**: violations are reported periodically to raise security awareness
20
.kilocodeignore
Normal file
@@ -0,0 +1,20 @@
# Environment variables
.env*
.env.local
.env.*.local

# Keys and certificates
*.pem
*.key
id_rsa
id_dsa
*.pfx

# Database files
*.sqlite
*.db

# Logs and caches
logs/
cache/
*.log
28
config/.env.example
Normal file
@@ -0,0 +1,28 @@
# ===========================================
# ProStock environment-variable template
#
# Usage:
# 1. Copy this file to .env.local
# 2. Fill in the real values
# 3. Keep it under the config/ directory
# ===========================================

# Database
DATABASE_HOST=localhost
DATABASE_PORT=5432
DATABASE_NAME=prostock
DATABASE_USER=postgres
DATABASE_PASSWORD=your_password

# API keys (important: never leak these into version control)
API_KEY=your_api_key_here
SECRET_KEY=your_secret_key_here

# Redis (optional)
REDIS_HOST=localhost
REDIS_PORT=6379

# Application
APP_ENV=development
APP_DEBUG=true
APP_PORT=8000
5
src/__init__.py
Normal file
@@ -0,0 +1,5 @@
"""ProStock stock-analysis project.

Provides stock data analysis, trading strategies, and related features.
"""
__version__ = "1.0.0"
7
src/config/__init__.py
Normal file
@@ -0,0 +1,7 @@
"""Configuration management module.

Provides configuration loading and management.
"""
from src.config.settings import Settings, get_settings, settings

__all__ = ["Settings", "get_settings", "settings"]
58
src/config/settings.py
Normal file
@@ -0,0 +1,58 @@
"""Configuration management module.

Loads application settings from environment variables, with type validation
via pydantic-settings.
"""
from functools import lru_cache
from pathlib import Path
from typing import Optional

from pydantic_settings import BaseSettings


# Project root (the directory that contains the config/ folder)
PROJECT_ROOT = Path(__file__).parent.parent.parent
CONFIG_DIR = PROJECT_ROOT / "config"


class Settings(BaseSettings):
    """Application settings loaded from environment variables."""

    # Database
    database_host: str = "localhost"
    database_port: int = 5432
    database_name: str = "prostock"
    database_user: str
    database_password: str

    # API keys
    api_key: str
    secret_key: str

    # Redis
    redis_host: str = "localhost"
    redis_port: int = 6379
    redis_password: Optional[str] = None

    # Application
    app_env: str = "development"
    app_debug: bool = False
    app_port: int = 8000

    class Config:
        # Read the .env.local file from the config/ directory
        env_file = str(CONFIG_DIR / ".env.local")
        env_file_encoding = "utf-8"
        case_sensitive = False


@lru_cache()
def get_settings() -> Settings:
    """Return the settings singleton.

    lru_cache ensures the settings are loaded only once.
    """
    return Settings()


# Module-level instance for global use
settings = get_settings()
15
src/data/__init__.py
Normal file
@@ -0,0 +1,15 @@
"""Data collection module for Tushare.

Provides simplified interfaces for fetching and storing Tushare data.
"""

from src.data.config import Config, get_config
from src.data.client import TushareClient
from src.data.storage import Storage

__all__ = [
    "Config",
    "get_config",
    "TushareClient",
    "Storage",
]
47
src/data/api.md
Normal file
@@ -0,0 +1,47 @@
# Tushare API Notes

## Generic Bar Interface (pro_bar)

Interface: https://tushare.pro/document/2?doc_id=109; available fields are listed at https://tushare.pro/document/2?doc_id=27. Requirement: store every output field of the A-share daily bars plus the `tor` (turnover rate) and `vr` (volume ratio) factors.

Input parameters:

| Name | Type | Required | Description |
|------|------|------|------|
| ts_code | str | Y | Security code. Multiple values are not supported; passing several codes produces duplicate records |
| start_date | str | N | Start date (daily format: YYYYMMDD; for minute data use the `2019-09-01 09:00:00` format) |
| end_date | str | N | End date (daily format: YYYYMMDD) |
| asset | str | Y | Asset class: E stocks, I SSE/SZSE indexes, C crypto, FT futures, FD funds, O options, CB convertible bonds (v1.2.39). Default E |
| adj | str | N | Adjustment type (stocks only): None unadjusted, qfq forward-adjusted, hfq backward-adjusted. Default None. Only daily bars support adjustment; it is computed dynamically relative to the end_date parameter using the dividend-reinvestment model (see the FAQ) |
| freq | str | Y | Bar frequency: minute (min) / daily (D) / weekly (W) / monthly (M); 1min means 1-minute bars (likewise 1/5/15/30/60 minutes). Default D. Minute data is available on trial for 600-credit users (2 requests); see the permission list and the stock-minute usage guide |
| ma | list | N | Moving averages; any reasonable int. Averages are computed dynamically, so the date range must exceed the window (e.g. a 5-day MA needs start/end dates more than 5 days apart). Only single-stock extraction is supported, i.e. ts_code is required. E.g. ma_5 is the 5-day average price, ma_v_5 the 5-day average volume |
| factors | list | N | Stock factors (valid with asset='E'): tor turnover rate, vr volume ratio |
| adjfactor | str | N | If True, adjusted data includes the adjustment-factor column. Default False. Available since v1.2.33 |

Output fields: see the per-asset daily documentation:

- Stock daily: https://tushare.pro/document/2?doc_id=27
- Fund daily: https://tushare.pro/document/2?doc_id=127
- Futures daily: https://tushare.pro/document/2?doc_id=138
- Options daily: https://tushare.pro/document/2?doc_id=159
- Index daily: https://tushare.pro/document/2?doc_id=95

## A-Share Daily Bars (daily)

Interface: `daily` (can be explored and debugged with the data tool).
Data notes: loaded between 15:00 and 16:00 on each trading day. This interface returns unadjusted bars and provides no data during trading halts.
Quota: with base credits, up to 500 calls per minute and 6000 rows per call; one request covers roughly 23 years of history for a single stock.
Description: fetches stock bar data; adjusted data is available through the generic bar interface.

Input parameters:

| Name | Type | Required | Description |
|------|------|------|------|
| ts_code | str | N | Stock code (multiple codes supported, comma-separated) |
| trade_date | str | N | Trade date (YYYYMMDD) |
| start_date | str | N | Start date (YYYYMMDD) |
| end_date | str | N | End date (YYYYMMDD) |

Note: all dates use the YYYYMMDD format, e.g. 20181010.

Output parameters:

| Name | Type | Description |
|------|------|------|
| ts_code | str | Stock code |
| trade_date | str | Trade date |
| open | float | Open price |
| high | float | High price |
| low | float | Low price |
| close | float | Close price |
| pre_close | float | Previous close (ex-rights) |
| change | float | Price change |
| pct_chg | float | Percent change, based on the ex-rights previous close: (close - pre_close) / pre_close |
| vol | float | Volume (lots) |
| amount | float | Turnover (thousand CNY) |
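The `daily` interface above is reachable through the official `tushare` package. The sketch below only builds the request parameters; the network call itself is left commented out because it requires a valid API token:

```python
# Minimal sketch of calling the daily interface described above.
# The helper is purely illustrative; only the commented-out lines
# touch the real Tushare API.


def build_daily_params(ts_code: str, start_date: str = None,
                       end_date: str = None, trade_date: str = None) -> dict:
    """Collect only the parameters that were actually supplied (all YYYYMMDD)."""
    params = {"ts_code": ts_code}
    if trade_date:
        # A specific trade date takes precedence over a date range
        params["trade_date"] = trade_date
    else:
        if start_date:
            params["start_date"] = start_date
        if end_date:
            params["end_date"] = end_date
    return params


params = build_daily_params("000001.SZ", start_date="20240101", end_date="20240131")

# import tushare as ts
# pro = ts.pro_api("your_token_here")
# df = pro.daily(**params)  # columns: ts_code, trade_date, open, ..., vol, amount
```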
89
src/data/client.py
Normal file
@@ -0,0 +1,89 @@
"""Simplified Tushare client with rate limiting and retry logic."""
import time
import pandas as pd
from typing import Optional
from src.data.config import get_config
from src.data.rate_limiter import TokenBucketRateLimiter


class TushareClient:
    """Tushare API client with rate limiting and retry."""

    def __init__(self, token: Optional[str] = None):
        """Initialize client.

        Args:
            token: Tushare API token (auto-loaded from config if not provided)
        """
        cfg = get_config()
        token = token or cfg.tushare_token

        if not token:
            raise ValueError("Tushare token is required")

        self.token = token
        self.config = cfg

        # Initialize rate limiter: capacity = rate_limit, refill_rate = rate_limit/60 per second
        rate_per_second = cfg.rate_limit / 60.0
        self.rate_limiter = TokenBucketRateLimiter(
            capacity=cfg.rate_limit,
            refill_rate_per_second=rate_per_second,
        )

        self._api = None

    def _get_api(self):
        """Get Tushare API instance."""
        if self._api is None:
            import tushare as ts
            self._api = ts.pro_api(self.token)
        return self._api

    def query(self, api_name: str, timeout: float = 30.0, **params) -> pd.DataFrame:
        """Execute API query with rate limiting and retry.

        Args:
            api_name: API name (e.g., 'daily')
            timeout: Timeout for rate limiting
            **params: API parameters

        Returns:
            DataFrame with query results
        """
        # Acquire rate limit token
        success, wait_time = self.rate_limiter.acquire(timeout=timeout)

        if not success:
            raise RuntimeError(f"Rate limit exceeded after {timeout}s timeout")

        if wait_time > 0:
            print(f"[RateLimit] Waited {wait_time:.2f}s for token")

        # Execute with retry
        max_retries = 3
        retry_delays = [1, 3, 10]

        for attempt in range(max_retries):
            try:
                api = self._get_api()
                data = api.query(api_name, **params)

                available = self.rate_limiter.get_available_tokens()
                print(f"[Tushare] {api_name} | tokens: {available:.0f}/{self.rate_limiter.capacity}")

                return data

            except Exception as e:
                if attempt < max_retries - 1:
                    delay = retry_delays[attempt]
                    print(f"[Retry] {api_name} failed (attempt {attempt + 1}): {e}, retry in {delay}s")
                    time.sleep(delay)
                else:
                    raise RuntimeError(f"API call failed after {max_retries} attempts: {e}")

    def close(self):
        """Close client."""
        pass
33
src/data/config.py
Normal file
@@ -0,0 +1,33 @@
"""Configuration management for data collection module."""
from pathlib import Path
from pydantic_settings import BaseSettings


class Config(BaseSettings):
    """Application configuration loaded from environment variables."""

    # Tushare API token
    tushare_token: str = ""

    # Data storage path
    data_path: Path = Path("./data")

    # Rate limit: requests per minute
    rate_limit: int = 100

    # Thread pool size
    threads: int = 2

    class Config:
        env_file = ".env.local"
        env_file_encoding = "utf-8"
        case_sensitive = False


# Global config instance
config = Config()


def get_config() -> Config:
    """Get configuration instance."""
    return config
70
src/data/daily.py
Normal file
@@ -0,0 +1,70 @@
"""Simplified daily market data interface.

A single function to fetch A-share daily bar data from Tushare.
Supports all output fields, including tor (turnover rate) and vr (volume ratio).
"""
import pandas as pd
from typing import Optional, List, Literal
from src.data.client import TushareClient


def get_daily(
    ts_code: str,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    trade_date: Optional[str] = None,
    adj: Literal[None, "qfq", "hfq"] = None,
    factors: Optional[List[Literal["tor", "vr"]]] = None,
    adjfactor: bool = False,
) -> pd.DataFrame:
    """Fetch daily market data for A-share stocks.

    This is a simplified interface that combines rate limiting, API calls,
    and error handling into a single function.

    Args:
        ts_code: Stock code (e.g., '000001.SZ', '600000.SH')
        start_date: Start date in YYYYMMDD format
        end_date: End date in YYYYMMDD format
        trade_date: Specific trade date in YYYYMMDD format
        adj: Adjustment type - None, 'qfq' (forward), 'hfq' (backward)
        factors: List of factors to include - 'tor' (turnover rate), 'vr' (volume ratio)
        adjfactor: Whether to include adjustment factor

    Returns:
        pd.DataFrame with daily market data containing:
            - Base fields: ts_code, trade_date, open, high, low, close, pre_close,
              change, pct_chg, vol, amount
            - Factor fields (if requested): tor, vr
            - Adjustment factor (if adjfactor=True): adjfactor

    Example:
        >>> data = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
        >>> data = get_daily('600000.SH', factors=['tor', 'vr'])
    """
    # Initialize client
    client = TushareClient()

    # Build parameters
    params = {"ts_code": ts_code}

    if start_date:
        params["start_date"] = start_date
    if end_date:
        params["end_date"] = end_date
    if trade_date:
        params["trade_date"] = trade_date
    if adj:
        params["adj"] = adj
    if factors:
        params["factors"] = factors
    if adjfactor:
        params["adjfactor"] = "True"

    # Fetch data
    data = client.query("daily", **params)

    if data.empty:
        print(f"[get_daily] No data for ts_code={ts_code}")

    return data
167
src/data/rate_limiter.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""Token bucket rate limiter implementation.
|
||||
|
||||
This module provides a thread-safe token bucket algorithm for rate limiting.
|
||||
"""
|
||||
import time
|
||||
import threading
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimiterStats:
|
||||
"""Statistics for rate limiter."""
|
||||
total_requests: int = 0
|
||||
successful_requests: int = 0
|
||||
denied_requests: int = 0
|
||||
total_wait_time: float = 0.0
|
||||
current_tokens: float = field(default=None, init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
self.current_tokens = field(default=None)
|
||||
|
||||
|
||||
class TokenBucketRateLimiter:
    """Thread-safe token bucket rate limiter.

    Implements a token bucket algorithm for controlling request rate.
    Tokens are added at a fixed rate up to the bucket capacity.

    Attributes:
        capacity: Maximum number of tokens in the bucket
        refill_rate: Number of tokens added per second
        initial_tokens: Initial number of tokens (default: capacity)
    """

    def __init__(
        self,
        capacity: int = 100,
        refill_rate_per_second: float = 1.67,
        initial_tokens: Optional[int] = None,
    ) -> None:
        """Initialize the token bucket rate limiter.

        Args:
            capacity: Maximum token capacity
            refill_rate_per_second: Token refill rate per second
            initial_tokens: Initial token count (default: capacity)
        """
        self.capacity = capacity
        self.refill_rate = refill_rate_per_second
        self.tokens = float(initial_tokens if initial_tokens is not None else capacity)
        self.last_refill_time = time.monotonic()
        self._lock = threading.RLock()
        self._stats = RateLimiterStats()
        self._stats.current_tokens = self.tokens

    def acquire(self, timeout: float = 30.0) -> tuple[bool, float]:
        """Acquire a token from the bucket.

        Blocks until a token is available or the timeout expires.

        Args:
            timeout: Maximum time to wait for a token in seconds

        Returns:
            Tuple of (success, wait_time):
            - success: True if a token was acquired, False if timed out
            - wait_time: Time spent waiting for the token
        """
        start_time = time.monotonic()
        wait_time = 0.0

        with self._lock:
            self._refill()

            if self.tokens >= 1:
                self.tokens -= 1
                self._stats.total_requests += 1
                self._stats.successful_requests += 1
                self._stats.current_tokens = self.tokens
                return True, 0.0

            # Calculate time to wait for the next token
            tokens_needed = 1 - self.tokens
            time_to_refill = tokens_needed / self.refill_rate

            if time_to_refill > timeout:
                self._stats.total_requests += 1
                self._stats.denied_requests += 1
                return False, timeout

            # Drop the (reentrant) lock while sleeping so other threads
            # can make progress, then reacquire before retrying
            self._lock.release()
            time.sleep(time_to_refill)
            self._lock.acquire()

            wait_time = time.monotonic() - start_time

            with self._lock:
                self._refill()
                if self.tokens >= 1:
                    self.tokens -= 1
                    self._stats.total_requests += 1
                    self._stats.successful_requests += 1
                    self._stats.total_wait_time += wait_time
                    self._stats.current_tokens = self.tokens
                    return True, wait_time

            self._stats.total_requests += 1
            self._stats.denied_requests += 1
            return False, wait_time

    def acquire_nonblocking(self) -> tuple[bool, float]:
        """Try to acquire a token without blocking.

        Returns:
            Tuple of (success, wait_time):
            - success: True if a token was acquired, False otherwise
            - wait_time: 0 on success, or the required wait time on failure
        """
        with self._lock:
            self._refill()

            if self.tokens >= 1:
                self.tokens -= 1
                self._stats.total_requests += 1
                self._stats.successful_requests += 1
                self._stats.current_tokens = self.tokens
                return True, 0.0

            # Calculate the time needed for the next token
            tokens_needed = 1 - self.tokens
            time_to_refill = tokens_needed / self.refill_rate

            self._stats.total_requests += 1
            self._stats.denied_requests += 1
            return False, time_to_refill

    def _refill(self) -> None:
        """Refill tokens based on elapsed time."""
        current_time = time.monotonic()
        elapsed = current_time - self.last_refill_time
        self.last_refill_time = current_time

        tokens_to_add = elapsed * self.refill_rate
        self.tokens = min(self.capacity, self.tokens + tokens_to_add)

    def get_available_tokens(self) -> float:
        """Get the current number of available tokens.

        Returns:
            Current token count
        """
        with self._lock:
            self._refill()
            return self.tokens

    def get_stats(self) -> RateLimiterStats:
        """Get rate limiter statistics.

        Returns:
            RateLimiterStats instance
        """
        with self._lock:
            self._refill()
            self._stats.current_tokens = self.tokens
            return self._stats

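A standalone sketch of the same refill-then-spend logic (a trimmed re-implementation for illustration, not an import of the class above). With a refill rate of 0 the bucket simply counts down:

```python
import time
import threading

class MiniBucket:
    """Trimmed token bucket: continuous refill, capped at capacity."""

    def __init__(self, capacity: int, refill_rate: float):
        self.capacity = capacity
        self.refill_rate = refill_rate
        self.tokens = float(capacity)
        self.last = time.monotonic()
        self.lock = threading.Lock()

    def try_acquire(self) -> bool:
        with self.lock:
            now = time.monotonic()
            # Same refill rule as _refill() above: elapsed * rate, capped
            self.tokens = min(self.capacity,
                              self.tokens + (now - self.last) * self.refill_rate)
            self.last = now
            if self.tokens >= 1:
                self.tokens -= 1
                return True
            return False

bucket = MiniBucket(capacity=2, refill_rate=0.0)
print([bucket.try_acquire() for _ in range(3)])  # [True, True, False]
```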
133 src/data/storage.py Normal file
@@ -0,0 +1,133 @@
"""Simplified HDF5 storage for data persistence."""
import pandas as pd
from pathlib import Path
from typing import Optional
from src.data.config import get_config


class Storage:
    """HDF5 storage manager for saving and loading data."""

    def __init__(self, path: Optional[Path] = None):
        """Initialize storage.

        Args:
            path: Base path for data storage (auto-loaded from config if not provided)
        """
        cfg = get_config()
        self.base_path = path or cfg.data_path
        self.base_path.mkdir(parents=True, exist_ok=True)

    def _get_file_path(self, name: str) -> Path:
        """Get the full path for an HDF5 file."""
        return self.base_path / f"{name}.h5"

    def save(self, name: str, data: pd.DataFrame, mode: str = "append") -> dict:
        """Save data to an HDF5 file.

        Args:
            name: Dataset name (also used as the filename)
            data: DataFrame to save
            mode: 'append' or 'replace'

        Returns:
            Dict describing the save result
        """
        if data.empty:
            return {"status": "skipped", "rows": 0}

        file_path = self._get_file_path(name)

        try:
            with pd.HDFStore(file_path, mode="a") as store:
                # HDFStore.keys() returns paths with a leading '/',
                # e.g. '/daily', so compare against the prefixed key
                if mode == "replace" or f"/{name}" not in store.keys():
                    store.put(name, data, format="table")
                else:
                    # Merge with existing data, keeping the newest row
                    # per (ts_code, trade_date) pair
                    existing = store[name]
                    combined = pd.concat([existing, data], ignore_index=True)
                    combined = combined.drop_duplicates(subset=["ts_code", "trade_date"], keep="last")
                    store.put(name, combined, format="table")

            print(f"[Storage] Saved {len(data)} rows to {file_path}")
            return {"status": "success", "rows": len(data), "path": str(file_path)}

        except Exception as e:
            print(f"[Storage] Error saving {name}: {e}")
            return {"status": "error", "error": str(e)}

    def load(self, name: str,
             start_date: Optional[str] = None,
             end_date: Optional[str] = None,
             ts_code: Optional[str] = None) -> pd.DataFrame:
        """Load data from an HDF5 file.

        Args:
            name: Dataset name
            start_date: Start date filter (YYYYMMDD)
            end_date: End date filter (YYYYMMDD)
            ts_code: Stock code filter

        Returns:
            DataFrame with the loaded data
        """
        file_path = self._get_file_path(name)

        if not file_path.exists():
            print(f"[Storage] File not found: {file_path}")
            return pd.DataFrame()

        try:
            with pd.HDFStore(file_path, mode="r") as store:
                if f"/{name}" not in store.keys():
                    return pd.DataFrame()

                data = store[name]

                # Apply filters
                if start_date and end_date and "trade_date" in data.columns:
                    data = data[(data["trade_date"] >= start_date) & (data["trade_date"] <= end_date)]

                if ts_code and "ts_code" in data.columns:
                    data = data[data["ts_code"] == ts_code]

                return data

        except Exception as e:
            print(f"[Storage] Error loading {name}: {e}")
            return pd.DataFrame()

    def get_last_date(self, name: str) -> Optional[str]:
        """Get the latest date in storage.

        Args:
            name: Dataset name

        Returns:
            Latest date string or None
        """
        data = self.load(name)
        if data.empty or "trade_date" not in data.columns:
            return None
        return str(data["trade_date"].max())

    def exists(self, name: str) -> bool:
        """Check if a dataset exists."""
        return self._get_file_path(name).exists()

    def delete(self, name: str) -> bool:
        """Delete a dataset.

        Args:
            name: Dataset name

        Returns:
            True if the file was deleted
        """
        file_path = self._get_file_path(name)
        if file_path.exists():
            file_path.unlink()
            print(f"[Storage] Deleted {file_path}")
            return True
        return False
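The `append` branch in `save()` above relies on `concat` plus `drop_duplicates(keep="last")` so that a re-download overwrites older rows for the same (ts_code, trade_date) pair. The merge logic standalone, with small made-up frames:

```python
import pandas as pd

existing = pd.DataFrame({
    "ts_code": ["000001.SZ", "000001.SZ"],
    "trade_date": ["20240102", "20240103"],
    "close": [10.8, 10.9],
})
incoming = pd.DataFrame({
    "ts_code": ["000001.SZ", "000001.SZ"],
    "trade_date": ["20240103", "20240104"],
    "close": [11.0, 11.2],  # revised 20240103 row plus a new day
})

combined = pd.concat([existing, incoming], ignore_index=True)
# keep="last" means the incoming (later) row wins on key collisions
combined = combined.drop_duplicates(subset=["ts_code", "trade_date"], keep="last")
combined = combined.sort_values("trade_date").reset_index(drop=True)
print(combined)  # 3 rows; the 20240103 close is now 11.0
```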
419 tests/test_daily.py Normal file
@@ -0,0 +1,419 @@
"""Tests for the daily market data API.

Tests the daily interface implementation against the api.md requirements:
- all output fields of the A-share daily bar endpoint
- tor (turnover rate)
- vr (volume ratio)
"""
import pytest
import pandas as pd
from unittest.mock import Mock, patch
from src.data.daily import get_daily
from src.data.client import TushareClient


# Expected output fields according to api.md
EXPECTED_BASE_FIELDS = [
    'ts_code',     # stock code
    'trade_date',  # trade date
    'open',        # open price
    'high',        # high price
    'low',         # low price
    'close',       # close price
    'pre_close',   # previous close
    'change',      # price change
    'pct_chg',     # percent change
    'vol',         # volume
    'amount',      # turnover amount
]

EXPECTED_FACTOR_FIELDS = [
    'tor',  # turnover rate
    'vr',   # volume ratio
]

def run_tests_with_print():
    """Run all tests and print results."""
    print("=" * 60)
    print("Daily API test run")
    print("=" * 60)

    test_results = []

    # Test 1: Basic daily data fetch
    print("\n[Test 1] Basic daily data fetch")
    print("-" * 40)
    mock_data = pd.DataFrame({
        'ts_code': ['000001.SZ'],
        'trade_date': ['20240102'],
        'open': [10.5],
        'high': [11.0],
        'low': [10.2],
        'close': [10.8],
        'pre_close': [10.3],
        'change': [0.5],
        'pct_chg': [4.85],
        'vol': [1000000],
        'amount': [10800000],
    })

    with patch.object(TushareClient, '__init__', lambda self, token=None: None):
        with patch.object(TushareClient, 'query', return_value=mock_data):
            result = get_daily('000001.SZ', start_date='20240101', end_date='20240131')

    print(f"Result shape: {result.shape}")
    print(f"Result columns: {result.columns.tolist()}")
    print(f"Data:\n{result}")

    # Verify
    tests_passed = isinstance(result, pd.DataFrame)
    tests_passed &= len(result) == 1
    tests_passed &= result['ts_code'].iloc[0] == '000001.SZ'

    print(f"\nResult: {'PASS ✓' if tests_passed else 'FAIL ✗'}")
    test_results.append(("Basic daily data fetch", tests_passed))

    # Test 2: Fetch with factors
    print("\n[Test 2] Fetch with turnover rate and volume ratio")
    print("-" * 40)
    mock_data_factors = pd.DataFrame({
        'ts_code': ['000001.SZ'],
        'trade_date': ['20240102'],
        'open': [10.5],
        'high': [11.0],
        'low': [10.2],
        'close': [10.8],
        'pre_close': [10.3],
        'change': [0.5],
        'pct_chg': [4.85],
        'vol': [1000000],
        'amount': [10800000],
        'tor': [2.5],
        'vr': [1.2],
    })

    with patch.object(TushareClient, '__init__', lambda self, token=None: None):
        with patch.object(TushareClient, 'query', return_value=mock_data_factors):
            result = get_daily(
                '000001.SZ',
                start_date='20240101',
                end_date='20240131',
                factors=['tor', 'vr'],
            )

    print(f"Result shape: {result.shape}")
    print(f"Result columns: {result.columns.tolist()}")
    print(f"Data:\n{result}")

    # Verify columns
    tests_passed = isinstance(result, pd.DataFrame)
    missing_base = [f for f in EXPECTED_BASE_FIELDS if f not in result.columns]
    missing_factors = [f for f in EXPECTED_FACTOR_FIELDS if f not in result.columns]

    print(f"\nBase columns: {'all present ✓' if not missing_base else f'missing: {missing_base} ✗'}")
    print(f"Factor columns: {'all present ✓' if not missing_factors else f'missing: {missing_factors} ✗'}")

    tests_passed &= len(missing_base) == 0
    tests_passed &= len(missing_factors) == 0

    print(f"\nResult: {'PASS ✓' if tests_passed else 'FAIL ✗'}")
    test_results.append(("Fetch with factors", tests_passed))

    # Test 3: Output field completeness
    print("\n[Test 3] Output field completeness")
    print("-" * 40)
    mock_data = pd.DataFrame({
        'ts_code': ['600000.SH'],
        'trade_date': ['20240102'],
        'open': [10.5],
        'high': [11.0],
        'low': [10.2],
        'close': [10.8],
        'pre_close': [10.3],
        'change': [0.5],
        'pct_chg': [4.85],
        'vol': [1000000],
        'amount': [10800000],
    })

    with patch.object(TushareClient, '__init__', lambda self, token=None: None):
        with patch.object(TushareClient, 'query', return_value=mock_data):
            result = get_daily('600000.SH')

    print(f"Result shape: {result.shape}")
    print(f"Result columns: {result.columns.tolist()}")
    print(f"Expected base columns: {EXPECTED_BASE_FIELDS}")

    missing = set(EXPECTED_BASE_FIELDS) - set(result.columns)
    print(f"Missing columns: {missing if missing else 'none'}")

    tests_passed = set(EXPECTED_BASE_FIELDS).issubset(result.columns.tolist())
    print(f"\nResult: {'PASS ✓' if tests_passed else 'FAIL ✗'}")
    test_results.append(("Output field completeness", tests_passed))

    # Test 4: Empty result
    print("\n[Test 4] Empty result handling")
    print("-" * 40)
    mock_data = pd.DataFrame()

    with patch.object(TushareClient, '__init__', lambda self, token=None: None):
        with patch.object(TushareClient, 'query', return_value=mock_data):
            result = get_daily('INVALID.SZ')

    print(f"Result is empty: {result.empty}")
    tests_passed = result.empty
    print(f"\nResult: {'PASS ✓' if tests_passed else 'FAIL ✗'}")
    test_results.append(("Empty result handling", tests_passed))

    # Test 5: Date range query
    print("\n[Test 5] Date range query")
    print("-" * 40)
    mock_data = pd.DataFrame({
        'ts_code': ['000001.SZ', '000001.SZ'],
        'trade_date': ['20240102', '20240103'],
        'open': [10.5, 10.6],
        'high': [11.0, 11.1],
        'low': [10.2, 10.3],
        'close': [10.8, 10.9],
        'pre_close': [10.3, 10.8],
        'change': [0.5, 0.1],
        'pct_chg': [4.85, 0.93],
        'vol': [1000000, 1100000],
        'amount': [10800000, 11900000],
    })

    with patch.object(TushareClient, '__init__', lambda self, token=None: None):
        with patch.object(TushareClient, 'query', return_value=mock_data):
            result = get_daily(
                '000001.SZ',
                start_date='20240101',
                end_date='20240131',
            )

    print(f"Rows returned: {len(result)}")
    print(f"Rows expected: 2")
    print(f"Data:\n{result}")

    tests_passed = len(result) == 2
    print(f"\nResult: {'PASS ✓' if tests_passed else 'FAIL ✗'}")
    test_results.append(("Date range query", tests_passed))

    # Test 6: With adjustment
    print("\n[Test 6] Query with adjustment parameter")
    print("-" * 40)
    mock_data = pd.DataFrame({
        'ts_code': ['000001.SZ'],
        'trade_date': ['20240102'],
        'open': [10.5],
        'high': [11.0],
        'low': [10.2],
        'close': [10.8],
        'pre_close': [10.3],
        'change': [0.5],
        'pct_chg': [4.85],
        'vol': [1000000],
        'amount': [10800000],
    })

    with patch.object(TushareClient, '__init__', lambda self, token=None: None):
        with patch.object(TushareClient, 'query', return_value=mock_data):
            result = get_daily('000001.SZ', adj='qfq')

    print(f"Result shape: {result.shape}")
    print(f"Data:\n{result}")

    tests_passed = isinstance(result, pd.DataFrame)
    print(f"\nResult: {'PASS ✓' if tests_passed else 'FAIL ✗'}")
    test_results.append(("Adjustment parameter query", tests_passed))

    # Summary
    print("\n" + "=" * 60)
    print("Test summary")
    print("=" * 60)
    passed = sum(1 for _, r in test_results if r)
    total = len(test_results)
    print(f"Total tests: {total}")
    print(f"Passed: {passed}")
    print(f"Failed: {total - passed}")
    print(f"Pass rate: {passed/total*100:.1f}%")
    print("\nDetails:")
    for name, passed in test_results:
        status = "PASS ✓" if passed else "FAIL ✗"
        print(f"  - {name}: {status}")

    return all(r for _, r in test_results)


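The patching pattern used throughout these tests (stub `__init__` so no token is needed, stub `query` to return canned data) can be shown standalone, with a stand-in `Client` class in place of `TushareClient`:

```python
from unittest.mock import patch

class Client:
    """Stand-in for a client that normally needs a token and a network."""
    def __init__(self, token):
        raise RuntimeError("needs a real token")

    def query(self, api, **params):
        raise RuntimeError("needs a network")

# Stub out both methods, exactly as the tests do for TushareClient
with patch.object(Client, '__init__', lambda self, token=None: None):
    with patch.object(Client, 'query', return_value={"rows": 1}):
        c = Client()
        print(c.query("daily", ts_code="000001.SZ"))  # {'rows': 1}
```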
class TestGetDaily:
    """Test cases for the simplified get_daily function."""

    def test_fetch_basic(self):
        """Test basic daily data fetch."""
        mock_data = pd.DataFrame({
            'ts_code': ['000001.SZ'],
            'trade_date': ['20240102'],
            'open': [10.5],
            'high': [11.0],
            'low': [10.2],
            'close': [10.8],
            'pre_close': [10.3],
            'change': [0.5],
            'pct_chg': [4.85],
            'vol': [1000000],
            'amount': [10800000],
        })

        with patch.object(TushareClient, '__init__', lambda self, token=None: None):
            with patch.object(TushareClient, 'query', return_value=mock_data):
                result = get_daily('000001.SZ', start_date='20240101', end_date='20240131')

        assert isinstance(result, pd.DataFrame)
        assert len(result) == 1
        assert result['ts_code'].iloc[0] == '000001.SZ'

    def test_fetch_with_factors(self):
        """Test fetch with tor and vr factors."""
        mock_data = pd.DataFrame({
            'ts_code': ['000001.SZ'],
            'trade_date': ['20240102'],
            'open': [10.5],
            'high': [11.0],
            'low': [10.2],
            'close': [10.8],
            'pre_close': [10.3],
            'change': [0.5],
            'pct_chg': [4.85],
            'vol': [1000000],
            'amount': [10800000],
            'tor': [2.5],  # turnover rate
            'vr': [1.2],   # volume ratio
        })

        with patch.object(TushareClient, '__init__', lambda self, token=None: None):
            with patch.object(TushareClient, 'query', return_value=mock_data):
                result = get_daily(
                    '000001.SZ',
                    start_date='20240101',
                    end_date='20240131',
                    factors=['tor', 'vr'],
                )

        assert isinstance(result, pd.DataFrame)
        # Check all base fields are present
        for field in EXPECTED_BASE_FIELDS:
            assert field in result.columns, f"Missing base field: {field}"
        # Check factor fields are present
        for field in EXPECTED_FACTOR_FIELDS:
            assert field in result.columns, f"Missing factor field: {field}"

    def test_output_fields_completeness(self):
        """Verify all required output fields are returned."""
        mock_data = pd.DataFrame({
            'ts_code': ['600000.SH'],
            'trade_date': ['20240102'],
            'open': [10.5],
            'high': [11.0],
            'low': [10.2],
            'close': [10.8],
            'pre_close': [10.3],
            'change': [0.5],
            'pct_chg': [4.85],
            'vol': [1000000],
            'amount': [10800000],
        })

        with patch.object(TushareClient, '__init__', lambda self, token=None: None):
            with patch.object(TushareClient, 'query', return_value=mock_data):
                result = get_daily('600000.SH')

        # Verify all base fields are present
        assert set(EXPECTED_BASE_FIELDS).issubset(result.columns.tolist()), \
            f"Missing fields: {set(EXPECTED_BASE_FIELDS) - set(result.columns)}"

    def test_empty_result(self):
        """Test handling of empty results."""
        mock_data = pd.DataFrame()

        with patch.object(TushareClient, '__init__', lambda self, token=None: None):
            with patch.object(TushareClient, 'query', return_value=mock_data):
                result = get_daily('INVALID.SZ')

        assert result.empty

    def test_date_range_query(self):
        """Test query with a date range."""
        mock_data = pd.DataFrame({
            'ts_code': ['000001.SZ', '000001.SZ'],
            'trade_date': ['20240102', '20240103'],
            'open': [10.5, 10.6],
            'high': [11.0, 11.1],
            'low': [10.2, 10.3],
            'close': [10.8, 10.9],
            'pre_close': [10.3, 10.8],
            'change': [0.5, 0.1],
            'pct_chg': [4.85, 0.93],
            'vol': [1000000, 1100000],
            'amount': [10800000, 11900000],
        })

        with patch.object(TushareClient, '__init__', lambda self, token=None: None):
            with patch.object(TushareClient, 'query', return_value=mock_data):
                result = get_daily(
                    '000001.SZ',
                    start_date='20240101',
                    end_date='20240131',
                )

        assert len(result) == 2

    def test_with_adj(self):
        """Test fetch with an adjustment type."""
        mock_data = pd.DataFrame({
            'ts_code': ['000001.SZ'],
            'trade_date': ['20240102'],
            'open': [10.5],
            'high': [11.0],
            'low': [10.2],
            'close': [10.8],
            'pre_close': [10.3],
            'change': [0.5],
            'pct_chg': [4.85],
            'vol': [1000000],
            'amount': [10800000],
        })

        with patch.object(TushareClient, '__init__', lambda self, token=None: None):
            with patch.object(TushareClient, 'query', return_value=mock_data):
                result = get_daily('000001.SZ', adj='qfq')

        assert isinstance(result, pd.DataFrame)


def test_integration():
    """Integration test with the real Tushare API (requires a valid token)."""
    import os
    token = os.environ.get('TUSHARE_TOKEN')
    if not token:
        pytest.skip("TUSHARE_TOKEN not configured")

    result = get_daily('000001.SZ', start_date='20240101', end_date='20240131', factors=['tor', 'vr'])

    # Verify structure
    assert isinstance(result, pd.DataFrame)
    if not result.empty:
        # Check base fields
        for field in EXPECTED_BASE_FIELDS:
            assert field in result.columns, f"Missing base field: {field}"
        # Check factor fields
        for field in EXPECTED_FACTOR_FIELDS:
            assert field in result.columns, f"Missing factor field: {field}"


if __name__ == '__main__':
    # Run the verbose, print-based tests first
    run_tests_with_print()
    print("\n" + "=" * 60)
    print("Running pytest unit tests")
    print("=" * 60 + "\n")
    pytest.main([__file__, '-v'])