feat(data): 添加每日筹码及胜率数据接口 (cyq_perf)
- 新增 api_cyq_perf 模块,支持筹码分布数据获取和同步
- 在 sync_registry 中注册 cyq_perf 同步器
This commit is contained in:
@@ -531,6 +531,7 @@ def get_{data_type}(
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
ts_code: Optional[str] = None,
|
||||
client: Optional[TushareClient] = None, # 关键:可选客户端参数,用于共享速率限制
|
||||
) -> pd.DataFrame:
|
||||
"""Fetch {数据描述} from Tushare.
|
||||
|
||||
@@ -541,6 +542,9 @@ def get_{data_type}(
|
||||
start_date: Start date (YYYYMMDD format)
|
||||
end_date: End date (YYYYMMDD format)
|
||||
ts_code: Stock code filter (optional)
|
||||
client: Optional TushareClient instance for shared rate limiting.
|
||||
If None, creates a new client. For concurrent sync operations,
|
||||
pass a shared client to ensure proper rate limiting.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame with columns:
|
||||
@@ -552,12 +556,12 @@ def get_{data_type}(
|
||||
Example:
|
||||
>>> # Get all stocks for a single date
|
||||
>>> data = get_{data_type}(trade_date='20240101')
|
||||
>>>
|
||||
>>>
|
||||
>>> # Get date range data
|
||||
>>> data = get_{data_type}(start_date='20240101', end_date='20240131')
|
||||
"""
|
||||
client = TushareClient()
|
||||
|
||||
client = client or TushareClient() # 如果没有提供则创建新实例
|
||||
|
||||
# Build parameters
|
||||
params = {}
|
||||
if trade_date:
|
||||
@@ -568,14 +572,14 @@ def get_{data_type}(
|
||||
params["end_date"] = end_date
|
||||
if ts_code:
|
||||
params["ts_code"] = ts_code
|
||||
|
||||
|
||||
# Fetch data
|
||||
data = client.query("{tushare_api_name}", **params)
|
||||
|
||||
|
||||
# Rename date column if needed
|
||||
if "date" in data.columns:
|
||||
data = data.rename(columns={"date": "trade_date"})
|
||||
|
||||
|
||||
return data
|
||||
```
|
||||
|
||||
@@ -596,6 +600,7 @@ def get_{data_type}(
|
||||
ts_code: str,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
client: Optional[TushareClient] = None, # 关键:可选客户端参数,用于共享速率限制
|
||||
) -> pd.DataFrame:
|
||||
"""Fetch {数据描述} for a specific stock.
|
||||
|
||||
@@ -603,20 +608,23 @@ def get_{data_type}(
|
||||
ts_code: Stock code (e.g., '000001.SZ')
|
||||
start_date: Start date (YYYYMMDD format)
|
||||
end_date: End date (YYYYMMDD format)
|
||||
client: Optional TushareClient instance for shared rate limiting.
|
||||
If None, creates a new client. For concurrent sync operations,
|
||||
pass a shared client to ensure proper rate limiting.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame with {数据描述} data
|
||||
"""
|
||||
client = TushareClient()
|
||||
|
||||
client = client or TushareClient() # 如果没有提供则创建新实例
|
||||
|
||||
params = {"ts_code": ts_code}
|
||||
if start_date:
|
||||
params["start_date"] = start_date
|
||||
if end_date:
|
||||
params["end_date"] = end_date
|
||||
|
||||
|
||||
data = client.query("{tushare_api_name}", **params)
|
||||
|
||||
|
||||
return data
|
||||
```
|
||||
|
||||
@@ -751,6 +759,8 @@ Skill 会自动:
|
||||
- [ ] 已创建 `tests/test_{data_type}.py` 测试文件
|
||||
### 10.2 接口实现
|
||||
- [ ] 数据获取函数使用 `TushareClient`
|
||||
- [ ] **关键**:数据获取函数接受 `client: Optional[TushareClient] = None` 参数用于共享速率限制
|
||||
- [ ] **关键**:Sync 类在 `fetch_single_date()` / `fetch_single_stock()` 中传递 `self.client`
|
||||
- [ ] 函数包含完整的 Google 风格文档字符串
|
||||
- [ ] 日期参数使用 `YYYYMMDD` 格式
|
||||
- [ ] 返回的 DataFrame 包含 `ts_code` 和 `trade_date` 字段
|
||||
@@ -790,6 +800,6 @@ Skill 会自动:
|
||||
|
||||
---
|
||||
|
||||
**最后更新**: 2026-02-23
|
||||
**最后更新**: 2026-03-26
|
||||
|
||||
**版本**: v2.0 - 更新 DuckDB 存储规范,添加 Skill 自动化说明
|
||||
**版本**: v2.1 - 更新速率限制规范,强调多线程场景下 client 参数传递
|
||||
@@ -184,6 +184,57 @@ def get_xxx(period: str, fields: Optional[str] = None) -> pd.DataFrame:
|
||||
|
||||
---
|
||||
|
||||
## 速率限制规范(关键)
|
||||
|
||||
### 问题背景
|
||||
|
||||
财务数据同步使用 VIP 接口(如 `income_vip`、`balancesheet_vip`)按季度获取全市场数据。在并发场景下,如果每个线程创建独立的 `TushareClient` 实例,每个实例会有独立的令牌桶限流器,导致**限流失效**。
|
||||
|
||||
**实际案例**:
|
||||
- 配置 `RATE_LIMIT=150`,理论上每分钟最多 150 次请求
|
||||
- 如果 10 个线程各自创建独立客户端,实际请求速率上限 = 10 × 150 = 1500 次/分钟(而非预期的 150 次/分钟)
|
||||
- 结果:触发 Tushare API 限流,请求失败
|
||||
|
||||
### 解决方案
|
||||
|
||||
**必须**在数据获取函数中接受可选的 `client` 参数,并在同步类中传递共享实例:
|
||||
|
||||
```python
|
||||
from src.data.client import TushareClient
|
||||
from typing import Optional
|
||||
|
||||
# 1. 数据获取函数必须支持 client 参数
|
||||
def get_{data_type}(
|
||||
period: str,
|
||||
client: Optional[TushareClient] = None, # 关键参数
|
||||
) -> pd.DataFrame:
|
||||
"""Fetch financial data.
|
||||
|
||||
Args:
|
||||
period: 报告期(YYYYMMDD)
|
||||
client: Optional TushareClient for shared rate limiting
|
||||
"""
|
||||
client = client or TushareClient() # 如果没有提供则创建新实例
|
||||
return client.query("{api_name}", period=period)
|
||||
|
||||
# 2. 同步类中传递共享 client
|
||||
class XXXQuarterSync(QuarterBasedSync):
|
||||
def fetch_single_quarter(self, period: str) -> pd.DataFrame:
|
||||
# 使用 self.client(基类创建的共享实例)
|
||||
return get_{data_type}(period=period, client=self.client)
|
||||
```
|
||||
|
||||
### 关键规则
|
||||
|
||||
1. **数据获取函数**:必须接受 `client: Optional[TushareClient] = None` 参数
|
||||
2. **同步类实现**:必须在 `fetch_single_quarter()` 中传递 `self.client`
|
||||
3. **基类保证**:`QuarterBasedSync` 基类在 `__init__` 中创建 `self.client = TushareClient()`
|
||||
4. **使用模式**:数据获取函数使用 `client = client or TushareClient()` 模式
|
||||
|
||||
**注意**:`TushareClient` 内部使用**类级别共享限流器**(`_shared_limiter`),新建实例时会复用已有的限流器;但当配置的 `rate_limit` 与上次缓存的值不一致时,共享限流器会被重新创建。因此在并发同步场景下,仍应显式传递并复用同一个客户端实例,而不要依赖各线程自行创建。
|
||||
|
||||
---
|
||||
|
||||
## 类设计规范
|
||||
|
||||
### 类命名规范
|
||||
@@ -1284,6 +1335,7 @@ self.storage.flush()
|
||||
|
||||
| 日期 | 版本 | 变更内容 |
|
||||
|------|------|----------|
|
||||
| 2026-03-26 | v1.4 | 添加速率限制规范:<br>- 强调多线程场景下 client 参数传递<br>- 添加实际案例分析<br>- 说明 TushareClient 共享限流器机制 |
|
||||
| 2026-03-08 | v1.3 | 现金流量表接口实现:<br>- 完成 `api_cashflow.py` 封装<br>- 添加 95 个现金流量表完整字段<br>- 更新调度中心注册<br>- 更新文档标记现金流为已实现 |
|
||||
| 2026-03-08 | v1.2 | 资产负债表接口实现:<br>- 完成 `api_balance.py` 封装<br>- 添加 157 个资产负债表完整字段<br>- 更新调度中心注册<br>- 更新文档中的资产负债表示例为完整实现 |
|
||||
| 2026-03-08 | v1.1 | 完善实际编码细节:<br>- 添加首次同步优化说明<br>- 添加日期格式转换规范<br>- 添加存储层 UPSERT 禁用说明<br>- 添加删除计数处理说明<br>- 扩充常见问题(Q7-Q9) |
|
||||
|
||||
@@ -29,7 +29,7 @@ Example:
|
||||
>>> bak_basic = get_bak_basic(trade_date='20240101')
|
||||
>>> stock_st = get_stock_st(trade_date='20240101')
|
||||
>>> stk_limit = get_stk_limit(trade_date='20240101')
|
||||
>>> cyq_perf = get_cyq_perf('000001.SZ', start_date='20240101', end_date='20240131')
|
||||
>>> cyq_perf = get_cyq_perf(trade_date='20240115')
|
||||
"""
|
||||
|
||||
from src.data.api_wrappers.api_daily_basic import (
|
||||
|
||||
@@ -9,11 +9,12 @@ import pandas as pd
|
||||
from typing import Optional
|
||||
|
||||
from src.data.client import TushareClient
|
||||
from src.data.api_wrappers.base_sync import StockBasedSync
|
||||
from src.data.api_wrappers.base_sync import DateBasedSync
|
||||
|
||||
|
||||
def get_cyq_perf(
|
||||
ts_code: str,
|
||||
trade_date: Optional[str] = None,
|
||||
ts_code: Optional[str] = None,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
client: Optional[TushareClient] = None,
|
||||
@@ -24,9 +25,10 @@ def get_cyq_perf(
|
||||
for A-share stocks. Data starts from 2018.
|
||||
|
||||
Args:
|
||||
ts_code: Stock code (e.g., '000001.SZ', '600000.SH')
|
||||
start_date: Start date in YYYYMMDD format
|
||||
end_date: End date in YYYYMMDD format
|
||||
trade_date: Specific trade date in YYYYMMDD format
|
||||
ts_code: Stock code filter (optional, e.g., '000001.SZ')
|
||||
start_date: Start date for date range query (YYYYMMDD format)
|
||||
end_date: End date for date range query (YYYYMMDD format)
|
||||
client: Optional TushareClient instance for shared rate limiting.
|
||||
If None, creates a new client. For concurrent sync operations,
|
||||
pass a shared client to ensure proper rate limiting.
|
||||
@@ -46,19 +48,23 @@ def get_cyq_perf(
|
||||
- winner_rate: Win rate (percentage)
|
||||
|
||||
Example:
|
||||
>>> # Get chip distribution data for a stock
|
||||
>>> data = get_cyq_perf('000001.SZ', start_date='20240101', end_date='20240131')
|
||||
>>> # Get all stocks' chip distribution for a single date
|
||||
>>> data = get_cyq_perf(trade_date='20240115')
|
||||
>>>
|
||||
>>> # Get data with shared client for rate limiting
|
||||
>>> from src.data.client import TushareClient
|
||||
>>> client = TushareClient()
|
||||
>>> data = get_cyq_perf('000001.SZ', start_date='20240101', end_date='20240131', client=client)
|
||||
>>> # Get date range data for a specific stock
|
||||
>>> data = get_cyq_perf(ts_code='000001.SZ', start_date='20240101', end_date='20240131')
|
||||
>>>
|
||||
>>> # Get specific stock on specific date
|
||||
>>> data = get_cyq_perf(ts_code='000001.SZ', trade_date='20240115')
|
||||
"""
|
||||
client = client or TushareClient()
|
||||
|
||||
# Build parameters
|
||||
params = {"ts_code": ts_code}
|
||||
|
||||
params = {}
|
||||
if trade_date:
|
||||
params["trade_date"] = trade_date
|
||||
if ts_code:
|
||||
params["ts_code"] = ts_code
|
||||
if start_date:
|
||||
params["start_date"] = start_date
|
||||
if end_date:
|
||||
@@ -74,10 +80,10 @@ def get_cyq_perf(
|
||||
return data
|
||||
|
||||
|
||||
class CyqPerfSync(StockBasedSync):
|
||||
class CyqPerfSync(DateBasedSync):
|
||||
"""筹码分布数据批量同步管理器,支持全量/增量同步。
|
||||
|
||||
继承自 StockBasedSync,使用多线程按股票并发获取数据。
|
||||
继承自 DateBasedSync,使用按日期并发获取数据。
|
||||
|
||||
Example:
|
||||
>>> sync = CyqPerfSync()
|
||||
@@ -87,6 +93,7 @@ class CyqPerfSync(StockBasedSync):
|
||||
"""
|
||||
|
||||
table_name = "cyq_perf"
|
||||
default_start_date = "20180101"
|
||||
|
||||
# 表结构定义
|
||||
TABLE_SCHEMA = {
|
||||
@@ -111,52 +118,36 @@ class CyqPerfSync(StockBasedSync):
|
||||
# 主键定义
|
||||
PRIMARY_KEY = ("ts_code", "trade_date")
|
||||
|
||||
def fetch_single_stock(
|
||||
self,
|
||||
ts_code: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
) -> pd.DataFrame:
|
||||
"""获取单只股票的筹码分布数据。
|
||||
def fetch_single_date(self, trade_date: str) -> pd.DataFrame:
|
||||
"""获取单日所有股票的筹码分布数据。
|
||||
|
||||
Args:
|
||||
ts_code: 股票代码
|
||||
start_date: 起始日期(YYYYMMDD)
|
||||
end_date: 结束日期(YYYYMMDD)
|
||||
trade_date: 交易日期(YYYYMMDD)
|
||||
|
||||
Returns:
|
||||
包含筹码分布数据的 DataFrame
|
||||
包含当日所有股票筹码分布数据的 DataFrame
|
||||
"""
|
||||
# 使用 get_cyq_perf 获取数据(传递共享 client)
|
||||
data = get_cyq_perf(
|
||||
ts_code=ts_code,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
client=self.client, # 传递共享客户端以确保限流
|
||||
)
|
||||
return data
|
||||
return get_cyq_perf(trade_date=trade_date, client=self.client)
|
||||
|
||||
|
||||
def sync_cyq_perf(
|
||||
force_full: bool = False,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
max_workers: Optional[int] = None,
|
||||
dry_run: bool = False,
|
||||
) -> dict[str, pd.DataFrame]:
|
||||
"""同步所有股票的筹码分布数据。
|
||||
force_full: bool = False,
|
||||
) -> pd.DataFrame:
|
||||
"""同步筹码分布数据到 DuckDB,支持智能增量同步。
|
||||
|
||||
这是筹码分布数据同步的主要入口点。
|
||||
逻辑:
|
||||
- 若表不存在:创建表 + 复合索引 (trade_date, ts_code) + 全量同步
|
||||
- 若表存在:从 last_date + 1 开始增量同步
|
||||
|
||||
Args:
|
||||
start_date: 起始日期(YYYYMMDD 格式,默认全量从 20180101,增量从 last_date+1)
|
||||
end_date: 结束日期(YYYYMMDD 格式,默认为今天)
|
||||
force_full: 若为 True,强制从 20180101 完整重载
|
||||
start_date: 手动指定起始日期(YYYYMMDD)
|
||||
end_date: 手动指定结束日期(默认为今天)
|
||||
max_workers: 工作线程数(默认: 10)
|
||||
dry_run: 若为 True,仅预览将要同步的内容,不写入数据
|
||||
|
||||
Returns:
|
||||
映射 ts_code 到 DataFrame 的字典
|
||||
包含同步数据的 pd.DataFrame
|
||||
|
||||
Example:
|
||||
>>> # 首次同步(从 20180101 全量加载)
|
||||
@@ -170,49 +161,31 @@ def sync_cyq_perf(
|
||||
>>>
|
||||
>>> # 手动指定日期范围
|
||||
>>> result = sync_cyq_perf(start_date='20240101', end_date='20240131')
|
||||
>>>
|
||||
>>> # 自定义线程数
|
||||
>>> result = sync_cyq_perf(max_workers=20)
|
||||
>>>
|
||||
>>> # Dry run(仅预览)
|
||||
>>> result = sync_cyq_perf(dry_run=True)
|
||||
"""
|
||||
sync_manager = CyqPerfSync(max_workers=max_workers)
|
||||
sync_manager = CyqPerfSync()
|
||||
return sync_manager.sync_all(
|
||||
force_full=force_full,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
dry_run=dry_run,
|
||||
force_full=force_full,
|
||||
)
|
||||
|
||||
|
||||
def preview_cyq_perf_sync(
|
||||
force_full: bool = False,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
force_full: bool = False,
|
||||
sample_size: int = 3,
|
||||
) -> dict:
|
||||
"""预览筹码分布数据同步数据量和样本(不实际同步)。
|
||||
|
||||
这是推荐的方式,可在实际同步前检查将要同步的内容。
|
||||
|
||||
Args:
|
||||
force_full: 若为 True,预览全量同步(从 20180101)
|
||||
start_date: 手动指定起始日期(覆盖自动检测)
|
||||
end_date: 手动指定结束日期(默认为今天)
|
||||
sample_size: 预览用样本股票数量(默认: 3)
|
||||
force_full: 若为 True,预览全量同步(从 20180101)
|
||||
sample_size: 预览天数(默认: 3)
|
||||
|
||||
Returns:
|
||||
包含预览信息的字典:
|
||||
{
|
||||
'sync_needed': bool,
|
||||
'stock_count': int,
|
||||
'start_date': str,
|
||||
'end_date': str,
|
||||
'estimated_records': int,
|
||||
'sample_data': pd.DataFrame,
|
||||
'mode': str, # 'full', 'incremental', 'partial', 或 'none'
|
||||
}
|
||||
包含预览信息的字典
|
||||
|
||||
Example:
|
||||
>>> # 预览将要同步的内容
|
||||
@@ -220,14 +193,11 @@ def preview_cyq_perf_sync(
|
||||
>>>
|
||||
>>> # 预览全量同步
|
||||
>>> preview = preview_cyq_perf_sync(force_full=True)
|
||||
>>>
|
||||
>>> # 预览更多样本
|
||||
>>> preview = preview_cyq_perf_sync(sample_size=5)
|
||||
"""
|
||||
sync_manager = CyqPerfSync()
|
||||
return sync_manager.preview_sync(
|
||||
force_full=force_full,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
force_full=force_full,
|
||||
sample_size=sample_size,
|
||||
)
|
||||
|
||||
@@ -16,6 +16,7 @@ def get_stock_st(
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
ts_code: Optional[str] = None,
|
||||
client: Optional[TushareClient] = None,
|
||||
) -> pd.DataFrame:
|
||||
"""Fetch ST stock list from Tushare.
|
||||
|
||||
@@ -28,6 +29,9 @@ def get_stock_st(
|
||||
start_date: Start date for date range query (YYYYMMDD format)
|
||||
end_date: End date for date range query (YYYYMMDD format)
|
||||
ts_code: Stock code filter (optional, e.g., '000001.SZ')
|
||||
client: Optional TushareClient instance for shared rate limiting.
|
||||
If None, creates a new client. For concurrent sync operations,
|
||||
pass a shared client to ensure proper rate limiting.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame with columns:
|
||||
@@ -47,7 +51,7 @@ def get_stock_st(
|
||||
>>> # Get specific stock ST history
|
||||
>>> data = get_stock_st(ts_code='000001.SZ')
|
||||
"""
|
||||
client = TushareClient()
|
||||
client = client or TushareClient()
|
||||
|
||||
# Build parameters
|
||||
params = {}
|
||||
@@ -108,7 +112,7 @@ class StockSTSync(DateBasedSync):
|
||||
Returns:
|
||||
包含当日ST股票列表的 DataFrame
|
||||
"""
|
||||
return get_stock_st(trade_date=trade_date)
|
||||
return get_stock_st(trade_date=trade_date, client=self.client)
|
||||
|
||||
|
||||
def sync_stock_st(
|
||||
|
||||
@@ -1058,9 +1058,9 @@ class DateBasedSync(BaseDataSync):
|
||||
class_name = self.__class__.__name__
|
||||
storage = Storage()
|
||||
|
||||
# 默认结束日期
|
||||
# 默认结束日期(使用带时间逻辑的 get_today_date,9点前返回前一天)
|
||||
if end_date is None:
|
||||
end_date = datetime.now().strftime("%Y%m%d")
|
||||
end_date = get_today_date()
|
||||
|
||||
# 检查表是否存在
|
||||
table_exists = storage.exists(self.table_name)
|
||||
|
||||
@@ -12,6 +12,7 @@ class TushareClient:
|
||||
|
||||
# 类级别共享限流器(确保所有实例共享同一个限流器)
|
||||
_shared_limiter: Optional[TokenBucketRateLimiter] = None
|
||||
_cached_rate_limit: int = 0 # 缓存上次使用的 rate_limit
|
||||
|
||||
def __init__(self, token: Optional[str] = None):
|
||||
"""Initialize client.
|
||||
@@ -29,17 +30,19 @@ class TushareClient:
|
||||
self.config = cfg
|
||||
|
||||
# 初始化共享限流器(确保所有 TushareClient 实例共享同一个限流器)
|
||||
rate_per_second = cfg.rate_limit / 60.0
|
||||
capacity = cfg.rate_limit
|
||||
|
||||
if TushareClient._shared_limiter is None:
|
||||
# 首次创建:初始化共享限流器
|
||||
# 检查是否需要重新创建限流器(配置发生变化时)
|
||||
if (
|
||||
TushareClient._shared_limiter is None
|
||||
or TushareClient._cached_rate_limit != cfg.rate_limit
|
||||
):
|
||||
# 首次创建或配置变更:重新初始化共享限流器
|
||||
TushareClient._shared_limiter = TokenBucketRateLimiter(
|
||||
capacity=capacity,
|
||||
refill_rate_per_second=rate_per_second,
|
||||
rate_limit=cfg.rate_limit,
|
||||
)
|
||||
TushareClient._cached_rate_limit = cfg.rate_limit
|
||||
min_interval = 60.0 / cfg.rate_limit
|
||||
print(
|
||||
f"[TushareClient] Initialized shared rate limiter: capacity={capacity}, window=60s"
|
||||
f"[TushareClient] Initialized shared rate limiter: rate={cfg.rate_limit}/min, interval={min_interval:.2f}s"
|
||||
)
|
||||
# 复用共享限流器
|
||||
self.rate_limiter = TushareClient._shared_limiter
|
||||
@@ -65,21 +68,17 @@ class TushareClient:
|
||||
Returns:
|
||||
DataFrame with query results
|
||||
"""
|
||||
# Acquire rate limit token (None = wait indefinitely)
|
||||
timeout = timeout if timeout is not None else float("inf")
|
||||
success, wait_time = self.rate_limiter.acquire(timeout=timeout)
|
||||
|
||||
if not success:
|
||||
raise RuntimeError(f"Rate limit exceeded after {timeout}s timeout")
|
||||
|
||||
if wait_time > 0:
|
||||
pass # Silent wait
|
||||
|
||||
# Execute with retry
|
||||
max_retries = 3
|
||||
retry_delays = [1, 3, 10]
|
||||
|
||||
for attempt in range(max_retries):
|
||||
# Acquire rate limit token before each attempt (including retries)
|
||||
success, wait_time = self.rate_limiter.acquire(timeout=timeout)
|
||||
|
||||
if not success:
|
||||
raise RuntimeError(f"Rate limit exceeded after {timeout}s timeout")
|
||||
|
||||
try:
|
||||
import tushare as ts
|
||||
|
||||
@@ -108,10 +107,18 @@ class TushareClient:
|
||||
return data
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
if attempt < max_retries - 1:
|
||||
delay = retry_delays[attempt]
|
||||
# 如果触发 Tushare 限流,增加等待时间避开惩罚期
|
||||
if "最多访问该接口" in error_msg:
|
||||
delay = max(delay, 60)
|
||||
print(
|
||||
f"[RateLimit] {api_name} hit Tushare limit, waiting {delay}s..."
|
||||
)
|
||||
|
||||
print(
|
||||
f"[Retry] {api_name} failed (attempt {attempt + 1}): {e}, retry in {delay}s"
|
||||
f"[Retry] {api_name} failed (attempt {attempt + 1}): {error_msg}, retry in {delay}s"
|
||||
)
|
||||
time.sleep(delay)
|
||||
else:
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
"""API 速率限制器实现。
|
||||
|
||||
提供基于固定时间窗口的速率限制,适合 Tushare 等按分钟计费的 API。
|
||||
提供基于固定时间间隔的速率限制,强制两次请求之间保持最小时间间隔。
|
||||
适合 Tushare 等需要严格控制请求频率的 API。
|
||||
"""
|
||||
|
||||
import time
|
||||
import threading
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -17,178 +18,142 @@ class RateLimiterStats:
|
||||
successful_requests: int = 0
|
||||
denied_requests: int = 0
|
||||
total_wait_time: float = 0.0
|
||||
current_window_requests: int = 0
|
||||
window_start_time: float = 0.0
|
||||
last_request_time: Optional[float] = None # 上次请求开始时间
|
||||
|
||||
|
||||
class TokenBucketRateLimiter:
|
||||
"""基于固定时间窗口的速率限制器。
|
||||
"""基于固定时间间隔的速率限制器。
|
||||
|
||||
适合 Tushare 等按时间窗口(如每分钟)限制请求数的 API 场景。
|
||||
在窗口期内,请求数达到上限后将阻塞或等待下一个窗口。
|
||||
强制两次请求之间保持最小时间间隔,无论请求处理耗时多久。
|
||||
适合需要严格控制请求频率、避免触发服务端限流的场景。
|
||||
|
||||
Attributes:
|
||||
capacity: 每个时间窗口内允许的最大请求数
|
||||
window_seconds: 时间窗口长度(秒)
|
||||
rate_limit: 每分钟允许的请求数
|
||||
min_interval: 两次请求之间的最小时间间隔(秒)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
capacity: int = 100,
|
||||
refill_rate_per_second: float = 1.67,
|
||||
initial_tokens: Optional[int] = None,
|
||||
rate_limit: int = 150,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""初始化速率限制器。
|
||||
|
||||
Args:
|
||||
capacity: 每个时间窗口内允许的最大请求数
|
||||
refill_rate_per_second: 保留参数(向后兼容),实际使用 window_seconds=60
|
||||
initial_tokens: 保留参数(向后兼容)
|
||||
rate_limit: 每分钟允许的请求数(默认 150)
|
||||
"""
|
||||
self.capacity = capacity
|
||||
# Tushare 通常按分钟限制,所以固定使用 60 秒窗口
|
||||
self.window_seconds = 60.0
|
||||
self.rate_limit = rate_limit
|
||||
# 计算最小间隔:60秒 / 每分钟请求数
|
||||
self.min_interval = 60.0 / rate_limit
|
||||
|
||||
self._requests_in_window = 0
|
||||
self._window_start = time.monotonic()
|
||||
self._lock = threading.RLock()
|
||||
self._stats = RateLimiterStats()
|
||||
self._stats.window_start_time = self._window_start
|
||||
|
||||
def _is_new_window(self) -> bool:
|
||||
"""检查是否已进入新的时间窗口。"""
|
||||
current_time = time.monotonic()
|
||||
elapsed = current_time - self._window_start
|
||||
return elapsed >= self.window_seconds
|
||||
|
||||
def _reset_window(self) -> None:
|
||||
"""重置时间窗口。"""
|
||||
self._window_start = time.monotonic()
|
||||
self._requests_in_window = 0
|
||||
self._stats.window_start_time = self._window_start
|
||||
|
||||
def acquire(self, timeout: float = float("inf")) -> tuple[bool, float]:
|
||||
"""获取请求许可。
|
||||
"""获取请求许可,确保与上次请求间隔足够时间。
|
||||
|
||||
如果在当前窗口内请求数已达上限,则等待到下一个窗口。
|
||||
会等待直到距离上次请求的时间 >= min_interval。
|
||||
注意:全程加锁,确保多线程下严格串行执行。
|
||||
|
||||
Args:
|
||||
timeout: 最大等待时间(秒),默认无限等待
|
||||
|
||||
Returns:
|
||||
(success, wait_time): 是否成功获取许可,以及等待时间
|
||||
(success, wait_time): 是否成功,以及实际等待时间
|
||||
"""
|
||||
start_time = time.monotonic()
|
||||
|
||||
with self._lock:
|
||||
# 检查是否需要进入新窗口
|
||||
if self._is_new_window():
|
||||
self._reset_window()
|
||||
now = time.monotonic()
|
||||
|
||||
# 如果当前窗口还有余量,直接通过
|
||||
if self._requests_in_window < self.capacity:
|
||||
self._requests_in_window += 1
|
||||
self._stats.total_requests += 1
|
||||
self._stats.successful_requests += 1
|
||||
self._stats.current_window_requests = self._requests_in_window
|
||||
return True, 0.0
|
||||
# 计算距离上次请求的时间
|
||||
if self._stats.last_request_time is not None:
|
||||
elapsed = now - self._stats.last_request_time
|
||||
time_to_wait = self.min_interval - elapsed
|
||||
|
||||
# 当前窗口已满,计算需要等待的时间
|
||||
current_time = time.monotonic()
|
||||
time_to_next_window = self.window_seconds - (
|
||||
current_time - self._window_start
|
||||
)
|
||||
if time_to_wait > 0:
|
||||
# 需要等待
|
||||
if timeout != float("inf") and time_to_wait > timeout:
|
||||
# 超过最大等待时间
|
||||
self._stats.total_requests += 1
|
||||
self._stats.denied_requests += 1
|
||||
return False, time_to_wait
|
||||
|
||||
if time_to_next_window <= 0:
|
||||
# 刚好进入新窗口
|
||||
self._reset_window()
|
||||
self._requests_in_window = 1
|
||||
self._stats.total_requests += 1
|
||||
self._stats.successful_requests += 1
|
||||
self._stats.current_window_requests = 1
|
||||
return True, 0.0
|
||||
# 在锁内等待(全程加锁,确保多线程严格串行)
|
||||
time.sleep(time_to_wait)
|
||||
now = time.monotonic()
|
||||
|
||||
# 检查是否能在超时时间内等待
|
||||
if timeout != float("inf") and time_to_next_window > timeout:
|
||||
self._stats.total_requests += 1
|
||||
self._stats.denied_requests += 1
|
||||
return False, timeout
|
||||
# 更新上次请求时间(请求开始前)
|
||||
self._stats.last_request_time = now
|
||||
wait_time = now - start_time
|
||||
|
||||
# 需要等待到下一个窗口
|
||||
if timeout != float("inf"):
|
||||
time_to_wait = min(time_to_next_window, timeout)
|
||||
else:
|
||||
time_to_wait = time_to_next_window
|
||||
self._stats.total_requests += 1
|
||||
self._stats.successful_requests += 1
|
||||
self._stats.total_wait_time += wait_time
|
||||
|
||||
time.sleep(time_to_wait)
|
||||
|
||||
# 重新尝试获取许可
|
||||
with self._lock:
|
||||
# 再次检查窗口状态(可能其他线程已经重置了窗口)
|
||||
if self._is_new_window():
|
||||
self._reset_window()
|
||||
|
||||
if self._requests_in_window < self.capacity:
|
||||
self._requests_in_window += 1
|
||||
wait_time = time.monotonic() - start_time
|
||||
self._stats.total_requests += 1
|
||||
self._stats.successful_requests += 1
|
||||
self._stats.total_wait_time += wait_time
|
||||
self._stats.current_window_requests = self._requests_in_window
|
||||
return True, wait_time
|
||||
else:
|
||||
# 在极端情况下,等待后仍然无法获取(其他线程抢先)
|
||||
wait_time = time.monotonic() - start_time
|
||||
self._stats.total_requests += 1
|
||||
self._stats.denied_requests += 1
|
||||
return False, wait_time
|
||||
return True, wait_time
|
||||
|
||||
def acquire_nonblocking(self) -> tuple[bool, float]:
|
||||
"""尝试非阻塞地获取请求许可。
|
||||
|
||||
Returns:
|
||||
(success, wait_time): 是否成功获取许可,以及需要等待的时间
|
||||
(success, wait_time): 是否成功,以及需要等待的时间
|
||||
"""
|
||||
with self._lock:
|
||||
# 检查是否需要进入新窗口
|
||||
if self._is_new_window():
|
||||
self._reset_window()
|
||||
now = time.monotonic()
|
||||
|
||||
# 如果当前窗口还有余量,直接通过
|
||||
if self._requests_in_window < self.capacity:
|
||||
self._requests_in_window += 1
|
||||
self._stats.total_requests += 1
|
||||
self._stats.successful_requests += 1
|
||||
self._stats.current_window_requests = self._requests_in_window
|
||||
return True, 0.0
|
||||
if self._stats.last_request_time is not None:
|
||||
elapsed = now - self._stats.last_request_time
|
||||
time_to_wait = self.min_interval - elapsed
|
||||
|
||||
# 当前窗口已满,计算需要等待的时间
|
||||
current_time = time.monotonic()
|
||||
time_to_next_window = self.window_seconds - (
|
||||
current_time - self._window_start
|
||||
)
|
||||
if time_to_wait > 0:
|
||||
self._stats.total_requests += 1
|
||||
self._stats.denied_requests += 1
|
||||
return False, time_to_wait
|
||||
|
||||
# 立即获得许可
|
||||
self._stats.last_request_time = now
|
||||
self._stats.total_requests += 1
|
||||
self._stats.denied_requests += 1
|
||||
return False, max(0.0, time_to_next_window)
|
||||
self._stats.successful_requests += 1
|
||||
return True, 0.0
|
||||
|
||||
def get_available_tokens(self) -> float:
|
||||
"""获取当前窗口剩余可用请求数。
|
||||
def get_min_interval(self) -> float:
|
||||
"""获取最小请求间隔。
|
||||
|
||||
Returns:
|
||||
当前窗口剩余可用请求数
|
||||
两次请求之间的最小时间间隔(秒)
|
||||
"""
|
||||
return self.min_interval
|
||||
|
||||
def get_time_until_next_request(self) -> float:
|
||||
"""获取距离下次允许请求的时间。
|
||||
|
||||
Returns:
|
||||
距离下次请求还需要等待的时间(秒),0 表示可以立即请求
|
||||
"""
|
||||
with self._lock:
|
||||
if self._is_new_window():
|
||||
return float(self.capacity)
|
||||
return float(self.capacity - self._requests_in_window)
|
||||
if self._stats.last_request_time is None:
|
||||
return 0.0
|
||||
|
||||
elapsed = time.monotonic() - self._stats.last_request_time
|
||||
return max(0.0, self.min_interval - elapsed)
|
||||
|
||||
def get_stats(self) -> RateLimiterStats:
|
||||
"""获取速率限制器统计信息。
|
||||
|
||||
Returns:
|
||||
RateLimiterStats 实例
|
||||
RateLimiterStats 实例的副本
|
||||
"""
|
||||
with self._lock:
|
||||
self._stats.current_window_requests = self._requests_in_window
|
||||
return self._stats
|
||||
return RateLimiterStats(
|
||||
total_requests=self._stats.total_requests,
|
||||
successful_requests=self._stats.successful_requests,
|
||||
denied_requests=self._stats.denied_requests,
|
||||
total_wait_time=self._stats.total_wait_time,
|
||||
last_request_time=self._stats.last_request_time,
|
||||
)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""重置限流器状态(调试用)。"""
|
||||
with self._lock:
|
||||
self._stats = RateLimiterStats()
|
||||
|
||||
@@ -14,13 +14,36 @@ DEFAULT_START_DATE = "20180101"
|
||||
TODAY: str = datetime.now().strftime("%Y%m%d")
|
||||
|
||||
|
||||
def get_today_date() -> str:
|
||||
def get_today_date(cutoff_hour: int = 9) -> str:
|
||||
"""获取今日日期(YYYYMMDD 格式)。
|
||||
|
||||
考虑数据生成时间的逻辑:在 cutoff_hour 点之前,返回前一天的日期,
|
||||
因为当天的数据还未生成。A股数据通常在交易日收盘后(约 15:00-19:00)
|
||||
生成,但为了保险起见,默认使用早上 9 点作为分界。
|
||||
|
||||
Args:
|
||||
cutoff_hour: 时间分界点(小时,24小时制),默认为 9。
|
||||
当前时间小于此值时,返回前一天日期。
|
||||
|
||||
Returns:
|
||||
今日日期字符串,格式为 YYYYMMDD
|
||||
日期字符串,格式为 YYYYMMDD
|
||||
|
||||
Example:
|
||||
>>> # 假设当前是 2024-01-15 08:30
|
||||
>>> get_today_date() # 返回 '20240114'(前一天)
|
||||
>>>
|
||||
>>> # 假设当前是 2024-01-15 10:00
|
||||
>>> get_today_date() # 返回 '20240115'(当天)
|
||||
>>>
|
||||
>>> # 使用自定义分界点
|
||||
>>> get_today_date(cutoff_hour=15) # 15点前返回前一天
|
||||
"""
|
||||
return TODAY
|
||||
now = datetime.now()
|
||||
if now.hour < cutoff_hour:
|
||||
# 在分界点之前,返回前一天
|
||||
prev_dt = now - timedelta(days=1)
|
||||
return prev_dt.strftime("%Y%m%d")
|
||||
return now.strftime("%Y%m%d")
|
||||
|
||||
|
||||
def get_next_date(date_str: str) -> str:
|
||||
|
||||
Reference in New Issue
Block a user