feat(data): 添加每日筹码及胜率数据接口 (cyq_perf)

- 新增 api_cyq_perf 模块,支持筹码分布数据获取和同步
- 在 sync_registry 中注册 cyq_perf 同步器
This commit is contained in:
2026-03-26 22:22:43 +08:00
parent 6730acbae1
commit d4e0e2a0b6
9 changed files with 261 additions and 230 deletions

View File

@@ -531,6 +531,7 @@ def get_{data_type}(
start_date: Optional[str] = None, start_date: Optional[str] = None,
end_date: Optional[str] = None, end_date: Optional[str] = None,
ts_code: Optional[str] = None, ts_code: Optional[str] = None,
client: Optional[TushareClient] = None, # 关键:可选客户端参数,用于共享速率限制
) -> pd.DataFrame: ) -> pd.DataFrame:
"""Fetch {数据描述} from Tushare. """Fetch {数据描述} from Tushare.
@@ -541,6 +542,9 @@ def get_{data_type}(
start_date: Start date (YYYYMMDD format) start_date: Start date (YYYYMMDD format)
end_date: End date (YYYYMMDD format) end_date: End date (YYYYMMDD format)
ts_code: Stock code filter (optional) ts_code: Stock code filter (optional)
client: Optional TushareClient instance for shared rate limiting.
If None, creates a new client. For concurrent sync operations,
pass a shared client to ensure proper rate limiting.
Returns: Returns:
pd.DataFrame with columns: pd.DataFrame with columns:
@@ -556,7 +560,7 @@ def get_{data_type}(
>>> # Get date range data >>> # Get date range data
>>> data = get_{data_type}(start_date='20240101', end_date='20240131') >>> data = get_{data_type}(start_date='20240101', end_date='20240131')
""" """
client = TushareClient() client = client or TushareClient() # 如果没有提供则创建新实例
# Build parameters # Build parameters
params = {} params = {}
@@ -596,6 +600,7 @@ def get_{data_type}(
ts_code: str, ts_code: str,
start_date: Optional[str] = None, start_date: Optional[str] = None,
end_date: Optional[str] = None, end_date: Optional[str] = None,
client: Optional[TushareClient] = None, # 关键:可选客户端参数,用于共享速率限制
) -> pd.DataFrame: ) -> pd.DataFrame:
"""Fetch {数据描述} for a specific stock. """Fetch {数据描述} for a specific stock.
@@ -603,11 +608,14 @@ def get_{data_type}(
ts_code: Stock code (e.g., '000001.SZ') ts_code: Stock code (e.g., '000001.SZ')
start_date: Start date (YYYYMMDD format) start_date: Start date (YYYYMMDD format)
end_date: End date (YYYYMMDD format) end_date: End date (YYYYMMDD format)
client: Optional TushareClient instance for shared rate limiting.
If None, creates a new client. For concurrent sync operations,
pass a shared client to ensure proper rate limiting.
Returns: Returns:
pd.DataFrame with {数据描述} data pd.DataFrame with {数据描述} data
""" """
client = TushareClient() client = client or TushareClient() # 如果没有提供则创建新实例
params = {"ts_code": ts_code} params = {"ts_code": ts_code}
if start_date: if start_date:
@@ -751,6 +759,8 @@ Skill 会自动:
- [ ] 已创建 `tests/test_{data_type}.py` 测试文件 - [ ] 已创建 `tests/test_{data_type}.py` 测试文件
### 10.2 接口实现 ### 10.2 接口实现
- [ ] 数据获取函数使用 `TushareClient` - [ ] 数据获取函数使用 `TushareClient`
- [ ] **关键**:数据获取函数接受 `client: Optional[TushareClient] = None` 参数用于共享速率限制
- [ ] **关键**:Sync 类在 `fetch_single_date()` / `fetch_single_stock()` 中传递 `self.client`
- [ ] 函数包含完整的 Google 风格文档字符串 - [ ] 函数包含完整的 Google 风格文档字符串
- [ ] 日期参数使用 `YYYYMMDD` 格式 - [ ] 日期参数使用 `YYYYMMDD` 格式
- [ ] 返回的 DataFrame 包含 `ts_code` 和 `trade_date` 字段 - [ ] 返回的 DataFrame 包含 `ts_code` 和 `trade_date` 字段
@@ -790,6 +800,6 @@ Skill 会自动:
--- ---
**最后更新**: 2026-02-23 **最后更新**: 2026-03-26
**版本**: v2.0 - 更新 DuckDB 存储规范,添加 Skill 自动化说明 **版本**: v2.1 - 更新速率限制规范,强调多线程场景下 client 参数传递

View File

@@ -184,6 +184,57 @@ def get_xxx(period: str, fields: Optional[str] = None) -> pd.DataFrame:
--- ---
## 速率限制规范(关键)
### 问题背景
财务数据同步使用 VIP 接口(如 `income_vip`、`balancesheet_vip`)按季度获取全市场数据。在并发场景下,如果每个线程创建独立的 `TushareClient` 实例,每个实例会有独立的令牌桶限流器,导致**限流失效**。
**实际案例**
- 配置 `RATE_LIMIT=150`,理论上每分钟最多 150 次请求
- 如果 10 个线程各自创建独立客户端,实际并发数 = 10 × 150 = 1500 次/分钟
- 结果:触发 Tushare API 限流,请求失败
### 解决方案
**必须**在数据获取函数中接受可选的 `client` 参数,并在同步类中传递共享实例:
```python
from src.data.client import TushareClient
from typing import Optional
# 1. 数据获取函数必须支持 client 参数
def get_{data_type}(
period: str,
client: Optional[TushareClient] = None, # 关键参数
) -> pd.DataFrame:
"""Fetch financial data.
Args:
period: 报告期YYYYMMDD
client: Optional TushareClient for shared rate limiting
"""
client = client or TushareClient() # 如果没有提供则创建新实例
return client.query("{api_name}", period=period)
# 2. 同步类中传递共享 client
class XXXQuarterSync(QuarterBasedSync):
def fetch_single_quarter(self, period: str) -> pd.DataFrame:
# 使用 self.client基类创建的共享实例
return get_{data_type}(period=period, client=self.client)
```
### 关键规则
1. **数据获取函数**:必须接受 `client: Optional[TushareClient] = None` 参数
2. **同步类实现**:必须在 `fetch_single_quarter()` 中传递 `self.client`
3. **基类保证**`QuarterBasedSync` 基类在 `__init__` 中创建 `self.client = TushareClient()`
4. **使用模式**:数据获取函数使用 `client = client or TushareClient()` 模式
**注意**`TushareClient` 内部使用**类级别共享限流器**`_shared_limiter`),确保所有实例共享同一个令牌桶,但前提是必须复用同一个客户端实例。
---
## 类设计规范 ## 类设计规范
### 类命名规范 ### 类命名规范
@@ -1284,6 +1335,7 @@ self.storage.flush()
| 日期 | 版本 | 变更内容 | | 日期 | 版本 | 变更内容 |
|------|------|----------| |------|------|----------|
| 2026-03-26 | v1.4 | 添加速率限制规范:<br>- 强调多线程场景下 client 参数传递<br>- 添加实际案例分析<br>- 说明 TushareClient 共享限流器机制 |
| 2026-03-08 | v1.3 | 现金流量表接口实现:<br>- 完成 `api_cashflow.py` 封装<br>- 添加 95 个现金流量表完整字段<br>- 更新调度中心注册<br>- 更新文档标记现金流为已实现 | | 2026-03-08 | v1.3 | 现金流量表接口实现:<br>- 完成 `api_cashflow.py` 封装<br>- 添加 95 个现金流量表完整字段<br>- 更新调度中心注册<br>- 更新文档标记现金流为已实现 |
| 2026-03-08 | v1.2 | 资产负债表接口实现:<br>- 完成 `api_balance.py` 封装<br>- 添加 157 个资产负债表完整字段<br>- 更新调度中心注册<br>- 更新文档中的资产负债表示例为完整实现 | | 2026-03-08 | v1.2 | 资产负债表接口实现:<br>- 完成 `api_balance.py` 封装<br>- 添加 157 个资产负债表完整字段<br>- 更新调度中心注册<br>- 更新文档中的资产负债表示例为完整实现 |
| 2026-03-08 | v1.1 | 完善实际编码细节:<br>- 添加首次同步优化说明<br>- 添加日期格式转换规范<br>- 添加存储层 UPSERT 禁用说明<br>- 添加删除计数处理说明<br>- 扩充常见问题Q7-Q9 | | 2026-03-08 | v1.1 | 完善实际编码细节:<br>- 添加首次同步优化说明<br>- 添加日期格式转换规范<br>- 添加存储层 UPSERT 禁用说明<br>- 添加删除计数处理说明<br>- 扩充常见问题Q7-Q9 |

View File

@@ -29,7 +29,7 @@ Example:
>>> bak_basic = get_bak_basic(trade_date='20240101') >>> bak_basic = get_bak_basic(trade_date='20240101')
>>> stock_st = get_stock_st(trade_date='20240101') >>> stock_st = get_stock_st(trade_date='20240101')
>>> stk_limit = get_stk_limit(trade_date='20240101') >>> stk_limit = get_stk_limit(trade_date='20240101')
>>> cyq_perf = get_cyq_perf('000001.SZ', start_date='20240101', end_date='20240131') >>> cyq_perf = get_cyq_perf(trade_date='20240115')
""" """
from src.data.api_wrappers.api_daily_basic import ( from src.data.api_wrappers.api_daily_basic import (

View File

@@ -9,11 +9,12 @@ import pandas as pd
from typing import Optional from typing import Optional
from src.data.client import TushareClient from src.data.client import TushareClient
from src.data.api_wrappers.base_sync import StockBasedSync from src.data.api_wrappers.base_sync import DateBasedSync
def get_cyq_perf( def get_cyq_perf(
ts_code: str, trade_date: Optional[str] = None,
ts_code: Optional[str] = None,
start_date: Optional[str] = None, start_date: Optional[str] = None,
end_date: Optional[str] = None, end_date: Optional[str] = None,
client: Optional[TushareClient] = None, client: Optional[TushareClient] = None,
@@ -24,9 +25,10 @@ def get_cyq_perf(
for A-share stocks. Data starts from 2018. for A-share stocks. Data starts from 2018.
Args: Args:
ts_code: Stock code (e.g., '000001.SZ', '600000.SH') trade_date: Specific trade date in YYYYMMDD format
start_date: Start date in YYYYMMDD format ts_code: Stock code filter (optional, e.g., '000001.SZ')
end_date: End date in YYYYMMDD format start_date: Start date for date range query (YYYYMMDD format)
end_date: End date for date range query (YYYYMMDD format)
client: Optional TushareClient instance for shared rate limiting. client: Optional TushareClient instance for shared rate limiting.
If None, creates a new client. For concurrent sync operations, If None, creates a new client. For concurrent sync operations,
pass a shared client to ensure proper rate limiting. pass a shared client to ensure proper rate limiting.
@@ -46,19 +48,23 @@ def get_cyq_perf(
- winner_rate: Win rate (percentage) - winner_rate: Win rate (percentage)
Example: Example:
>>> # Get chip distribution data for a stock >>> # Get all stocks' chip distribution for a single date
>>> data = get_cyq_perf('000001.SZ', start_date='20240101', end_date='20240131') >>> data = get_cyq_perf(trade_date='20240115')
>>> >>>
>>> # Get data with shared client for rate limiting >>> # Get date range data for a specific stock
>>> from src.data.client import TushareClient >>> data = get_cyq_perf(ts_code='000001.SZ', start_date='20240101', end_date='20240131')
>>> client = TushareClient() >>>
>>> data = get_cyq_perf('000001.SZ', start_date='20240101', end_date='20240131', client=client) >>> # Get specific stock on specific date
>>> data = get_cyq_perf(ts_code='000001.SZ', trade_date='20240115')
""" """
client = client or TushareClient() client = client or TushareClient()
# Build parameters # Build parameters
params = {"ts_code": ts_code} params = {}
if trade_date:
params["trade_date"] = trade_date
if ts_code:
params["ts_code"] = ts_code
if start_date: if start_date:
params["start_date"] = start_date params["start_date"] = start_date
if end_date: if end_date:
@@ -74,10 +80,10 @@ def get_cyq_perf(
return data return data
class CyqPerfSync(StockBasedSync): class CyqPerfSync(DateBasedSync):
"""筹码分布数据批量同步管理器,支持全量/增量同步。 """筹码分布数据批量同步管理器,支持全量/增量同步。
继承自 StockBasedSync使用多线程按股票并发获取数据。 继承自 DateBasedSync使用按日期并发获取数据。
Example: Example:
>>> sync = CyqPerfSync() >>> sync = CyqPerfSync()
@@ -87,6 +93,7 @@ class CyqPerfSync(StockBasedSync):
""" """
table_name = "cyq_perf" table_name = "cyq_perf"
default_start_date = "20180101"
# 表结构定义 # 表结构定义
TABLE_SCHEMA = { TABLE_SCHEMA = {
@@ -111,52 +118,36 @@ class CyqPerfSync(StockBasedSync):
# 主键定义 # 主键定义
PRIMARY_KEY = ("ts_code", "trade_date") PRIMARY_KEY = ("ts_code", "trade_date")
def fetch_single_stock( def fetch_single_date(self, trade_date: str) -> pd.DataFrame:
self, """获取单日所有股票的筹码分布数据。
ts_code: str,
start_date: str,
end_date: str,
) -> pd.DataFrame:
"""获取单只股票的筹码分布数据。
Args: Args:
ts_code: 股票代码 trade_date: 交易日期YYYYMMDD
start_date: 起始日期YYYYMMDD
end_date: 结束日期YYYYMMDD
Returns: Returns:
包含筹码分布数据的 DataFrame 包含当日所有股票筹码分布数据的 DataFrame
""" """
# 使用 get_cyq_perf 获取数据(传递共享 client return get_cyq_perf(trade_date=trade_date, client=self.client)
data = get_cyq_perf(
ts_code=ts_code,
start_date=start_date,
end_date=end_date,
client=self.client, # 传递共享客户端以确保限流
)
return data
def sync_cyq_perf( def sync_cyq_perf(
force_full: bool = False,
start_date: Optional[str] = None, start_date: Optional[str] = None,
end_date: Optional[str] = None, end_date: Optional[str] = None,
max_workers: Optional[int] = None, force_full: bool = False,
dry_run: bool = False, ) -> pd.DataFrame:
) -> dict[str, pd.DataFrame]: """同步筹码分布数据到 DuckDB支持智能增量同步。
"""同步所有股票的筹码分布数据。
这是筹码分布数据同步的主要入口点。 逻辑:
- 若表不存在:创建表 + 复合索引 (trade_date, ts_code) + 全量同步
- 若表存在:从 last_date + 1 开始增量同步
Args: Args:
start_date: 起始日期YYYYMMDD 格式,默认全量从 20180101增量从 last_date+1
end_date: 结束日期YYYYMMDD 格式,默认为今天)
force_full: 若为 True强制从 20180101 完整重载 force_full: 若为 True强制从 20180101 完整重载
start_date: 手动指定起始日期YYYYMMDD
end_date: 手动指定结束日期(默认为今天)
max_workers: 工作线程数(默认: 10
dry_run: 若为 True仅预览将要同步的内容不写入数据
Returns: Returns:
映射 ts_code 到 DataFrame 的字典 包含同步数据的 pd.DataFrame
Example: Example:
>>> # 首次同步(从 20180101 全量加载) >>> # 首次同步(从 20180101 全量加载)
@@ -170,49 +161,31 @@ def sync_cyq_perf(
>>> >>>
>>> # 手动指定日期范围 >>> # 手动指定日期范围
>>> result = sync_cyq_perf(start_date='20240101', end_date='20240131') >>> result = sync_cyq_perf(start_date='20240101', end_date='20240131')
>>>
>>> # 自定义线程数
>>> result = sync_cyq_perf(max_workers=20)
>>>
>>> # Dry run仅预览
>>> result = sync_cyq_perf(dry_run=True)
""" """
sync_manager = CyqPerfSync(max_workers=max_workers) sync_manager = CyqPerfSync()
return sync_manager.sync_all( return sync_manager.sync_all(
force_full=force_full,
start_date=start_date, start_date=start_date,
end_date=end_date, end_date=end_date,
dry_run=dry_run, force_full=force_full,
) )
def preview_cyq_perf_sync( def preview_cyq_perf_sync(
force_full: bool = False,
start_date: Optional[str] = None, start_date: Optional[str] = None,
end_date: Optional[str] = None, end_date: Optional[str] = None,
force_full: bool = False,
sample_size: int = 3, sample_size: int = 3,
) -> dict: ) -> dict:
"""预览筹码分布数据同步数据量和样本(不实际同步)。 """预览筹码分布数据同步数据量和样本(不实际同步)。
这是推荐的方式,可在实际同步前检查将要同步的内容。
Args: Args:
force_full: 若为 True预览全量同步从 20180101
start_date: 手动指定起始日期(覆盖自动检测) start_date: 手动指定起始日期(覆盖自动检测)
end_date: 手动指定结束日期(默认为今天) end_date: 手动指定结束日期(默认为今天)
sample_size: 预览用样本股票数量(默认: 3 force_full: 若为 True预览全量同步从 20180101
sample_size: 预览天数(默认: 3
Returns: Returns:
包含预览信息的字典 包含预览信息的字典
{
'sync_needed': bool,
'stock_count': int,
'start_date': str,
'end_date': str,
'estimated_records': int,
'sample_data': pd.DataFrame,
'mode': str, # 'full', 'incremental', 'partial', 或 'none'
}
Example: Example:
>>> # 预览将要同步的内容 >>> # 预览将要同步的内容
@@ -220,14 +193,11 @@ def preview_cyq_perf_sync(
>>> >>>
>>> # 预览全量同步 >>> # 预览全量同步
>>> preview = preview_cyq_perf_sync(force_full=True) >>> preview = preview_cyq_perf_sync(force_full=True)
>>>
>>> # 预览更多样本
>>> preview = preview_cyq_perf_sync(sample_size=5)
""" """
sync_manager = CyqPerfSync() sync_manager = CyqPerfSync()
return sync_manager.preview_sync( return sync_manager.preview_sync(
force_full=force_full,
start_date=start_date, start_date=start_date,
end_date=end_date, end_date=end_date,
force_full=force_full,
sample_size=sample_size, sample_size=sample_size,
) )

View File

@@ -16,6 +16,7 @@ def get_stock_st(
start_date: Optional[str] = None, start_date: Optional[str] = None,
end_date: Optional[str] = None, end_date: Optional[str] = None,
ts_code: Optional[str] = None, ts_code: Optional[str] = None,
client: Optional[TushareClient] = None,
) -> pd.DataFrame: ) -> pd.DataFrame:
"""Fetch ST stock list from Tushare. """Fetch ST stock list from Tushare.
@@ -28,6 +29,9 @@ def get_stock_st(
start_date: Start date for date range query (YYYYMMDD format) start_date: Start date for date range query (YYYYMMDD format)
end_date: End date for date range query (YYYYMMDD format) end_date: End date for date range query (YYYYMMDD format)
ts_code: Stock code filter (optional, e.g., '000001.SZ') ts_code: Stock code filter (optional, e.g., '000001.SZ')
client: Optional TushareClient instance for shared rate limiting.
If None, creates a new client. For concurrent sync operations,
pass a shared client to ensure proper rate limiting.
Returns: Returns:
pd.DataFrame with columns: pd.DataFrame with columns:
@@ -47,7 +51,7 @@ def get_stock_st(
>>> # Get specific stock ST history >>> # Get specific stock ST history
>>> data = get_stock_st(ts_code='000001.SZ') >>> data = get_stock_st(ts_code='000001.SZ')
""" """
client = TushareClient() client = client or TushareClient()
# Build parameters # Build parameters
params = {} params = {}
@@ -108,7 +112,7 @@ class StockSTSync(DateBasedSync):
Returns: Returns:
包含当日ST股票列表的 DataFrame 包含当日ST股票列表的 DataFrame
""" """
return get_stock_st(trade_date=trade_date) return get_stock_st(trade_date=trade_date, client=self.client)
def sync_stock_st( def sync_stock_st(

View File

@@ -1058,9 +1058,9 @@ class DateBasedSync(BaseDataSync):
class_name = self.__class__.__name__ class_name = self.__class__.__name__
storage = Storage() storage = Storage()
# 默认结束日期 # 默认结束日期(使用带时间逻辑的 get_today_date9点前返回前一天
if end_date is None: if end_date is None:
end_date = datetime.now().strftime("%Y%m%d") end_date = get_today_date()
# 检查表是否存在 # 检查表是否存在
table_exists = storage.exists(self.table_name) table_exists = storage.exists(self.table_name)

View File

@@ -12,6 +12,7 @@ class TushareClient:
# 类级别共享限流器(确保所有实例共享同一个限流器) # 类级别共享限流器(确保所有实例共享同一个限流器)
_shared_limiter: Optional[TokenBucketRateLimiter] = None _shared_limiter: Optional[TokenBucketRateLimiter] = None
_cached_rate_limit: int = 0 # 缓存上次使用的 rate_limit
def __init__(self, token: Optional[str] = None): def __init__(self, token: Optional[str] = None):
"""Initialize client. """Initialize client.
@@ -29,17 +30,19 @@ class TushareClient:
self.config = cfg self.config = cfg
# 初始化共享限流器(确保所有 TushareClient 实例共享同一个限流器) # 初始化共享限流器(确保所有 TushareClient 实例共享同一个限流器)
rate_per_second = cfg.rate_limit / 60.0 # 检查是否需要重新创建限流器(配置发生变化时)
capacity = cfg.rate_limit if (
TushareClient._shared_limiter is None
if TushareClient._shared_limiter is None: or TushareClient._cached_rate_limit != cfg.rate_limit
# 首次创建:初始化共享限流器 ):
# 首次创建或配置变更:重新初始化共享限流器
TushareClient._shared_limiter = TokenBucketRateLimiter( TushareClient._shared_limiter = TokenBucketRateLimiter(
capacity=capacity, rate_limit=cfg.rate_limit,
refill_rate_per_second=rate_per_second,
) )
TushareClient._cached_rate_limit = cfg.rate_limit
min_interval = 60.0 / cfg.rate_limit
print( print(
f"[TushareClient] Initialized shared rate limiter: capacity={capacity}, window=60s" f"[TushareClient] Initialized shared rate limiter: rate={cfg.rate_limit}/min, interval={min_interval:.2f}s"
) )
# 复用共享限流器 # 复用共享限流器
self.rate_limiter = TushareClient._shared_limiter self.rate_limiter = TushareClient._shared_limiter
@@ -65,21 +68,17 @@ class TushareClient:
Returns: Returns:
DataFrame with query results DataFrame with query results
""" """
# Acquire rate limit token (None = wait indefinitely)
timeout = timeout if timeout is not None else float("inf") timeout = timeout if timeout is not None else float("inf")
max_retries = 3
retry_delays = [1, 3, 10]
for attempt in range(max_retries):
# Acquire rate limit token before each attempt (including retries)
success, wait_time = self.rate_limiter.acquire(timeout=timeout) success, wait_time = self.rate_limiter.acquire(timeout=timeout)
if not success: if not success:
raise RuntimeError(f"Rate limit exceeded after {timeout}s timeout") raise RuntimeError(f"Rate limit exceeded after {timeout}s timeout")
if wait_time > 0:
pass # Silent wait
# Execute with retry
max_retries = 3
retry_delays = [1, 3, 10]
for attempt in range(max_retries):
try: try:
import tushare as ts import tushare as ts
@@ -108,10 +107,18 @@ class TushareClient:
return data return data
except Exception as e: except Exception as e:
error_msg = str(e)
if attempt < max_retries - 1: if attempt < max_retries - 1:
delay = retry_delays[attempt] delay = retry_delays[attempt]
# 如果触发 Tushare 限流,增加等待时间避开惩罚期
if "最多访问该接口" in error_msg:
delay = max(delay, 60)
print( print(
f"[Retry] {api_name} failed (attempt {attempt + 1}): {e}, retry in {delay}s" f"[RateLimit] {api_name} hit Tushare limit, waiting {delay}s..."
)
print(
f"[Retry] {api_name} failed (attempt {attempt + 1}): {error_msg}, retry in {delay}s"
) )
time.sleep(delay) time.sleep(delay)
else: else:

View File

@@ -1,12 +1,13 @@
"""API 速率限制器实现。 """API 速率限制器实现。
提供基于固定时间窗口的速率限制,适合 Tushare 等按分钟计费的 API 提供基于固定时间间隔的速率限制,强制两次请求之间保持最小时间间隔
适合 Tushare 等需要严格控制请求频率的 API。
""" """
import time import time
import threading import threading
from typing import Optional from typing import Optional
from dataclasses import dataclass from dataclasses import dataclass, field
@dataclass @dataclass
@@ -17,178 +18,142 @@ class RateLimiterStats:
successful_requests: int = 0 successful_requests: int = 0
denied_requests: int = 0 denied_requests: int = 0
total_wait_time: float = 0.0 total_wait_time: float = 0.0
current_window_requests: int = 0 last_request_time: Optional[float] = None # 上次请求开始时间
window_start_time: float = 0.0
class TokenBucketRateLimiter: class TokenBucketRateLimiter:
"""基于固定时间窗口的速率限制器。 """基于固定时间间隔的速率限制器。
适合 Tushare 等按时间窗口(如每分钟)限制请求数的 API 场景 强制两次请求之间保持最小时间间隔,无论请求处理耗时多久
在窗口期内,请求数达到上限后将阻塞或等待下一个窗口 适合需要严格控制请求频率、避免触发服务端限流的场景
Attributes: Attributes:
capacity: 每个时间窗口内允许的最大请求数 rate_limit: 每分钟允许的请求数
window_seconds: 时间窗口长度(秒) min_interval: 两次请求之间的最小时间间隔(秒)
""" """
def __init__( def __init__(
self, self,
capacity: int = 100, rate_limit: int = 150,
refill_rate_per_second: float = 1.67, **kwargs,
initial_tokens: Optional[int] = None,
) -> None: ) -> None:
"""初始化速率限制器。 """初始化速率限制器。
Args: Args:
capacity: 每个时间窗口内允许的最大请求数 rate_limit: 每分钟允许的请求数(默认 150
refill_rate_per_second: 保留参数(向后兼容),实际使用 window_seconds=60
initial_tokens: 保留参数(向后兼容)
""" """
self.capacity = capacity self.rate_limit = rate_limit
# Tushare 通常按分钟限制,所以固定使用 60 秒窗口 # 计算最小间隔60秒 / 每分钟请求数
self.window_seconds = 60.0 self.min_interval = 60.0 / rate_limit
self._requests_in_window = 0
self._window_start = time.monotonic()
self._lock = threading.RLock() self._lock = threading.RLock()
self._stats = RateLimiterStats() self._stats = RateLimiterStats()
self._stats.window_start_time = self._window_start
def _is_new_window(self) -> bool:
"""检查是否已进入新的时间窗口。"""
current_time = time.monotonic()
elapsed = current_time - self._window_start
return elapsed >= self.window_seconds
def _reset_window(self) -> None:
"""重置时间窗口。"""
self._window_start = time.monotonic()
self._requests_in_window = 0
self._stats.window_start_time = self._window_start
def acquire(self, timeout: float = float("inf")) -> tuple[bool, float]: def acquire(self, timeout: float = float("inf")) -> tuple[bool, float]:
"""获取请求许可。 """获取请求许可,确保与上次请求间隔足够时间
如果在当前窗口内请求数已达上限,则等待到下一个窗口 会等待直到距离上次请求的时间 >= min_interval
注意:全程加锁,确保多线程下严格串行执行。
Args: Args:
timeout: 最大等待时间(秒),默认无限等待 timeout: 最大等待时间(秒),默认无限等待
Returns: Returns:
(success, wait_time): 是否成功获取许可,以及等待时间 (success, wait_time): 是否成功,以及实际等待时间
""" """
start_time = time.monotonic() start_time = time.monotonic()
with self._lock: with self._lock:
# 检查是否需要进入新窗口 now = time.monotonic()
if self._is_new_window():
self._reset_window()
# 如果当前窗口还有余量,直接通过 # 计算距离上次请求的时间
if self._requests_in_window < self.capacity: if self._stats.last_request_time is not None:
self._requests_in_window += 1 elapsed = now - self._stats.last_request_time
self._stats.total_requests += 1 time_to_wait = self.min_interval - elapsed
self._stats.successful_requests += 1
self._stats.current_window_requests = self._requests_in_window
return True, 0.0
# 当前窗口已满,计算需要等待的时间 if time_to_wait > 0:
current_time = time.monotonic() # 需要等待
time_to_next_window = self.window_seconds - ( if timeout != float("inf") and time_to_wait > timeout:
current_time - self._window_start # 超过最大等待时间
)
if time_to_next_window <= 0:
# 刚好进入新窗口
self._reset_window()
self._requests_in_window = 1
self._stats.total_requests += 1
self._stats.successful_requests += 1
self._stats.current_window_requests = 1
return True, 0.0
# 检查是否能在超时时间内等待
if timeout != float("inf") and time_to_next_window > timeout:
self._stats.total_requests += 1 self._stats.total_requests += 1
self._stats.denied_requests += 1 self._stats.denied_requests += 1
return False, timeout return False, time_to_wait
# 需要等待到下一个窗口
if timeout != float("inf"):
time_to_wait = min(time_to_next_window, timeout)
else:
time_to_wait = time_to_next_window
# 在锁内等待(全程加锁,确保多线程严格串行)
time.sleep(time_to_wait) time.sleep(time_to_wait)
now = time.monotonic()
# 重新尝试获取许可 # 更新上次请求时间(请求开始前)
with self._lock: self._stats.last_request_time = now
# 再次检查窗口状态(可能其他线程已经重置了窗口) wait_time = now - start_time
if self._is_new_window():
self._reset_window()
if self._requests_in_window < self.capacity:
self._requests_in_window += 1
wait_time = time.monotonic() - start_time
self._stats.total_requests += 1 self._stats.total_requests += 1
self._stats.successful_requests += 1 self._stats.successful_requests += 1
self._stats.total_wait_time += wait_time self._stats.total_wait_time += wait_time
self._stats.current_window_requests = self._requests_in_window
return True, wait_time return True, wait_time
else:
# 在极端情况下,等待后仍然无法获取(其他线程抢先)
wait_time = time.monotonic() - start_time
self._stats.total_requests += 1
self._stats.denied_requests += 1
return False, wait_time
def acquire_nonblocking(self) -> tuple[bool, float]: def acquire_nonblocking(self) -> tuple[bool, float]:
"""尝试非阻塞地获取请求许可。 """尝试非阻塞地获取请求许可。
Returns: Returns:
(success, wait_time): 是否成功获取许可,以及需要等待的时间 (success, wait_time): 是否成功,以及需要等待的时间
""" """
with self._lock: with self._lock:
# 检查是否需要进入新窗口 now = time.monotonic()
if self._is_new_window():
self._reset_window()
# 如果当前窗口还有余量,直接通过 if self._stats.last_request_time is not None:
if self._requests_in_window < self.capacity: elapsed = now - self._stats.last_request_time
self._requests_in_window += 1 time_to_wait = self.min_interval - elapsed
self._stats.total_requests += 1
self._stats.successful_requests += 1
self._stats.current_window_requests = self._requests_in_window
return True, 0.0
# 当前窗口已满,计算需要等待的时间
current_time = time.monotonic()
time_to_next_window = self.window_seconds - (
current_time - self._window_start
)
if time_to_wait > 0:
self._stats.total_requests += 1 self._stats.total_requests += 1
self._stats.denied_requests += 1 self._stats.denied_requests += 1
return False, max(0.0, time_to_next_window) return False, time_to_wait
def get_available_tokens(self) -> float: # 立即获得许可
"""获取当前窗口剩余可用请求数。 self._stats.last_request_time = now
self._stats.total_requests += 1
self._stats.successful_requests += 1
return True, 0.0
def get_min_interval(self) -> float:
"""获取最小请求间隔。
Returns: Returns:
当前窗口剩余可用请求数 两次请求之间的最小时间间隔(秒)
"""
return self.min_interval
def get_time_until_next_request(self) -> float:
"""获取距离下次允许请求的时间。
Returns:
距离下次请求还需要等待的时间0 表示可以立即请求
""" """
with self._lock: with self._lock:
if self._is_new_window(): if self._stats.last_request_time is None:
return float(self.capacity) return 0.0
return float(self.capacity - self._requests_in_window)
elapsed = time.monotonic() - self._stats.last_request_time
return max(0.0, self.min_interval - elapsed)
def get_stats(self) -> RateLimiterStats: def get_stats(self) -> RateLimiterStats:
"""获取速率限制器统计信息。 """获取速率限制器统计信息。
Returns: Returns:
RateLimiterStats 实例 RateLimiterStats 实例的副本
""" """
with self._lock: with self._lock:
self._stats.current_window_requests = self._requests_in_window return RateLimiterStats(
return self._stats total_requests=self._stats.total_requests,
successful_requests=self._stats.successful_requests,
denied_requests=self._stats.denied_requests,
total_wait_time=self._stats.total_wait_time,
last_request_time=self._stats.last_request_time,
)
def reset(self) -> None:
"""重置限流器状态(调试用)。"""
with self._lock:
self._stats = RateLimiterStats()

View File

@@ -14,13 +14,36 @@ DEFAULT_START_DATE = "20180101"
TODAY: str = datetime.now().strftime("%Y%m%d") TODAY: str = datetime.now().strftime("%Y%m%d")
def get_today_date() -> str: def get_today_date(cutoff_hour: int = 9) -> str:
"""获取今日日期YYYYMMDD 格式)。 """获取今日日期YYYYMMDD 格式)。
考虑数据生成时间的逻辑:在 cutoff_hour 点之前,返回前一天的日期,
因为当天的数据还未生成。A股数据通常在交易日收盘后约 15:00-19:00
生成,但为了保险起见,默认使用早上 9 点作为分界。
Args:
cutoff_hour: 时间分界点小时24小时制默认为 9。
当前时间小于此值时,返回前一天日期。
Returns: Returns:
今日日期字符串,格式为 YYYYMMDD 日期字符串,格式为 YYYYMMDD
Example:
>>> # 假设当前是 2024-01-15 08:30
>>> get_today_date() # 返回 '20240114'(前一天)
>>>
>>> # 假设当前是 2024-01-15 10:00
>>> get_today_date() # 返回 '20240115'(当天)
>>>
>>> # 使用自定义分界点
>>> get_today_date(cutoff_hour=15) # 15点前返回前一天
""" """
return TODAY now = datetime.now()
if now.hour < cutoff_hour:
# 在分界点之前,返回前一天
prev_dt = now - timedelta(days=1)
return prev_dt.strftime("%Y%m%d")
return now.strftime("%Y%m%d")
def get_next_date(date_str: str) -> str: def get_next_date(date_str: str) -> str: