diff --git a/docs/api/API_INTERFACE_SPEC.md b/docs/api/API_INTERFACE_SPEC.md index fb0681c..81e5a11 100644 --- a/docs/api/API_INTERFACE_SPEC.md +++ b/docs/api/API_INTERFACE_SPEC.md @@ -531,6 +531,7 @@ def get_{data_type}( start_date: Optional[str] = None, end_date: Optional[str] = None, ts_code: Optional[str] = None, + client: Optional[TushareClient] = None, # 关键:可选客户端参数,用于共享速率限制 ) -> pd.DataFrame: """Fetch {数据描述} from Tushare. @@ -541,6 +542,9 @@ def get_{data_type}( start_date: Start date (YYYYMMDD format) end_date: End date (YYYYMMDD format) ts_code: Stock code filter (optional) + client: Optional TushareClient instance for shared rate limiting. + If None, creates a new client. For concurrent sync operations, + pass a shared client to ensure proper rate limiting. Returns: pd.DataFrame with columns: @@ -552,12 +556,12 @@ def get_{data_type}( Example: >>> # Get all stocks for a single date >>> data = get_{data_type}(trade_date='20240101') - >>> + >>> >>> # Get date range data >>> data = get_{data_type}(start_date='20240101', end_date='20240131') """ - client = TushareClient() - + client = client or TushareClient() # 如果没有提供则创建新实例 + # Build parameters params = {} if trade_date: @@ -568,14 +572,14 @@ def get_{data_type}( params["end_date"] = end_date if ts_code: params["ts_code"] = ts_code - + # Fetch data data = client.query("{tushare_api_name}", **params) - + # Rename date column if needed if "date" in data.columns: data = data.rename(columns={"date": "trade_date"}) - + return data ``` @@ -596,6 +600,7 @@ def get_{data_type}( ts_code: str, start_date: Optional[str] = None, end_date: Optional[str] = None, + client: Optional[TushareClient] = None, # 关键:可选客户端参数,用于共享速率限制 ) -> pd.DataFrame: """Fetch {数据描述} for a specific stock. @@ -603,20 +608,23 @@ def get_{data_type}( ts_code: Stock code (e.g., '000001.SZ') start_date: Start date (YYYYMMDD format) end_date: End date (YYYYMMDD format) + client: Optional TushareClient instance for shared rate limiting. + If None, creates a new client. For concurrent sync operations, + pass a shared client to ensure proper rate limiting. Returns: pd.DataFrame with {数据描述} data """ - client = TushareClient() - + client = client or TushareClient() # 如果没有提供则创建新实例 + params = {"ts_code": ts_code} if start_date: params["start_date"] = start_date if end_date: params["end_date"] = end_date - + data = client.query("{tushare_api_name}", **params) - + return data ``` @@ -751,6 +759,8 @@ Skill 会自动: - [ ] 已创建 `tests/test_{data_type}.py` 测试文件 ### 10.2 接口实现 - [ ] 数据获取函数使用 `TushareClient` +- [ ] **关键**:数据获取函数接受 `client: Optional[TushareClient] = None` 参数用于共享速率限制 +- [ ] **关键**:Sync 类在 `fetch_single_date()` / `fetch_single_stock()` 中传递 `self.client` - [ ] 函数包含完整的 Google 风格文档字符串 - [ ] 日期参数使用 `YYYYMMDD` 格式 - [ ] 返回的 DataFrame 包含 `ts_code` 和 `trade_date` 字段 @@ -790,6 +800,6 @@ Skill 会自动: --- -**最后更新**: 2026-02-23 +**最后更新**: 2026-03-26 -**版本**: v2.0 - 更新 DuckDB 存储规范,添加 Skill 自动化说明 \ No newline at end of file +**版本**: v2.1 - 更新速率限制规范,强调多线程场景下 client 参数传递 \ No newline at end of file diff --git a/docs/api/FINANCIAL_API_SPEC.md b/docs/api/FINANCIAL_API_SPEC.md index 1340eab..f0fcb4f 100644 --- a/docs/api/FINANCIAL_API_SPEC.md +++ b/docs/api/FINANCIAL_API_SPEC.md @@ -184,6 +184,57 @@ def get_xxx(period: str, fields: Optional[str] = None) -> pd.DataFrame: --- +## 速率限制规范(关键) + +### 问题背景 + +财务数据同步使用 VIP 接口(如 `income_vip`、`balancesheet_vip`)按季度获取全市场数据。在并发场景下,如果每个线程创建独立的 `TushareClient` 实例,每个实例会有独立的令牌桶限流器,导致**限流失效**。 + +**实际案例**: +- 配置 `RATE_LIMIT=150`,理论上每分钟最多 150 次请求 +- 如果 10 个线程各自创建独立客户端,实际并发数 = 10 × 150 = 1500 次/分钟 +- 结果:触发 Tushare API 限流,请求失败 + +### 解决方案 + +**必须**在数据获取函数中接受可选的 `client` 参数,并在同步类中传递共享实例: + +```python +from src.data.client import TushareClient +from typing import Optional + +# 1. 数据获取函数必须支持 client 参数 +def get_{data_type}( + period: str, + client: Optional[TushareClient] = None, # 关键参数 +) -> pd.DataFrame: + """Fetch financial data. + + Args: + period: 报告期(YYYYMMDD) + client: Optional TushareClient for shared rate limiting + """ + client = client or TushareClient() # 如果没有提供则创建新实例 + return client.query("{api_name}", period=period) + +# 2. 同步类中传递共享 client +class XXXQuarterSync(QuarterBasedSync): + def fetch_single_quarter(self, period: str) -> pd.DataFrame: + # 使用 self.client(基类创建的共享实例) + return get_{data_type}(period=period, client=self.client) +``` + +### 关键规则 + +1. **数据获取函数**:必须接受 `client: Optional[TushareClient] = None` 参数 +2. **同步类实现**:必须在 `fetch_single_quarter()` 中传递 `self.client` +3. **基类保证**:`QuarterBasedSync` 基类在 `__init__` 中创建 `self.client = TushareClient()` +4. **使用模式**:数据获取函数使用 `client = client or TushareClient()` 模式 + +**注意**:`TushareClient` 内部使用**类级别共享限流器**(`_shared_limiter`),确保所有实例共享同一个令牌桶,但前提是必须复用同一个客户端实例。 + +--- + ## 类设计规范 ### 类命名规范 @@ -1284,6 +1335,7 @@ self.storage.flush() | 日期 | 版本 | 变更内容 | |------|------|----------| +| 2026-03-26 | v1.4 | 添加速率限制规范:
- 强调多线程场景下 client 参数传递
- 添加实际案例分析
- 说明 TushareClient 共享限流器机制 | | 2026-03-08 | v1.3 | 现金流量表接口实现:
- 完成 `api_cashflow.py` 封装
- 添加 95 个现金流量表完整字段
- 更新调度中心注册
- 更新文档标记现金流为已实现 | | 2026-03-08 | v1.2 | 资产负债表接口实现:
- 完成 `api_balance.py` 封装
- 添加 157 个资产负债表完整字段
- 更新调度中心注册
- 更新文档中的资产负债表示例为完整实现 | | 2026-03-08 | v1.1 | 完善实际编码细节:
- 添加首次同步优化说明
- 添加日期格式转换规范
- 添加存储层 UPSERT 禁用说明
- 添加删除计数处理说明
- 扩充常见问题(Q7-Q9) | diff --git a/src/data/api_wrappers/__init__.py b/src/data/api_wrappers/__init__.py index b7ab77e..1e0c02d 100644 --- a/src/data/api_wrappers/__init__.py +++ b/src/data/api_wrappers/__init__.py @@ -29,7 +29,7 @@ Example: >>> bak_basic = get_bak_basic(trade_date='20240101') >>> stock_st = get_stock_st(trade_date='20240101') >>> stk_limit = get_stk_limit(trade_date='20240101') - >>> cyq_perf = get_cyq_perf('000001.SZ', start_date='20240101', end_date='20240131') + >>> cyq_perf = get_cyq_perf(trade_date='20240115') """ from src.data.api_wrappers.api_daily_basic import ( diff --git a/src/data/api_wrappers/api_cyq_perf.py b/src/data/api_wrappers/api_cyq_perf.py index b2f7a96..7dc267c 100644 --- a/src/data/api_wrappers/api_cyq_perf.py +++ b/src/data/api_wrappers/api_cyq_perf.py @@ -9,11 +9,12 @@ import pandas as pd from typing import Optional from src.data.client import TushareClient -from src.data.api_wrappers.base_sync import StockBasedSync +from src.data.api_wrappers.base_sync import DateBasedSync def get_cyq_perf( - ts_code: str, + trade_date: Optional[str] = None, + ts_code: Optional[str] = None, start_date: Optional[str] = None, end_date: Optional[str] = None, client: Optional[TushareClient] = None, @@ -24,9 +25,10 @@ def get_cyq_perf( for A-share stocks. Data starts from 2018. Args: - ts_code: Stock code (e.g., '000001.SZ', '600000.SH') - start_date: Start date in YYYYMMDD format - end_date: End date in YYYYMMDD format + trade_date: Specific trade date in YYYYMMDD format + ts_code: Stock code filter (optional, e.g., '000001.SZ') + start_date: Start date for date range query (YYYYMMDD format) + end_date: End date for date range query (YYYYMMDD format) client: Optional TushareClient instance for shared rate limiting. If None, creates a new client. For concurrent sync operations, pass a shared client to ensure proper rate limiting. @@ -46,19 +48,23 @@ def get_cyq_perf( - winner_rate: Win rate (percentage) Example: - >>> # Get chip distribution data for a stock - >>> data = get_cyq_perf('000001.SZ', start_date='20240101', end_date='20240131') + >>> # Get all stocks' chip distribution for a single date + >>> data = get_cyq_perf(trade_date='20240115') >>> - >>> # Get data with shared client for rate limiting - >>> from src.data.client import TushareClient - >>> client = TushareClient() - >>> data = get_cyq_perf('000001.SZ', start_date='20240101', end_date='20240131', client=client) + >>> # Get date range data for a specific stock + >>> data = get_cyq_perf(ts_code='000001.SZ', start_date='20240101', end_date='20240131') + >>> + >>> # Get specific stock on specific date + >>> data = get_cyq_perf(ts_code='000001.SZ', trade_date='20240115') """ client = client or TushareClient() # Build parameters - params = {"ts_code": ts_code} - + params = {} + if trade_date: + params["trade_date"] = trade_date + if ts_code: + params["ts_code"] = ts_code if start_date: params["start_date"] = start_date if end_date: @@ -74,10 +80,10 @@ def get_cyq_perf( return data -class CyqPerfSync(StockBasedSync): +class CyqPerfSync(DateBasedSync): """筹码分布数据批量同步管理器,支持全量/增量同步。 - 继承自 StockBasedSync,使用多线程按股票并发获取数据。 + 继承自 DateBasedSync,使用按日期并发获取数据。 Example: >>> sync = CyqPerfSync() @@ -87,6 +93,7 @@ class CyqPerfSync(StockBasedSync): """ table_name = "cyq_perf" + default_start_date = "20180101" # 表结构定义 TABLE_SCHEMA = { @@ -111,52 +118,36 @@ class CyqPerfSync(StockBasedSync): # 主键定义 PRIMARY_KEY = ("ts_code", "trade_date") - def fetch_single_stock( - self, - ts_code: str, - start_date: str, - end_date: str, - ) -> pd.DataFrame: - """获取单只股票的筹码分布数据。 + def fetch_single_date(self, trade_date: str) -> pd.DataFrame: + """获取单日所有股票的筹码分布数据。 Args: - ts_code: 股票代码 - start_date: 起始日期(YYYYMMDD) - end_date: 结束日期(YYYYMMDD) + trade_date: 交易日期(YYYYMMDD) Returns: - 包含筹码分布数据的 DataFrame + 包含当日所有股票筹码分布数据的 DataFrame """ - # 使用 get_cyq_perf 获取数据(传递共享 client) - data = get_cyq_perf( - ts_code=ts_code, - start_date=start_date, - end_date=end_date, - client=self.client, # 传递共享客户端以确保限流 - ) - return data + return get_cyq_perf(trade_date=trade_date, client=self.client) def sync_cyq_perf( - force_full: bool = False, start_date: Optional[str] = None, end_date: Optional[str] = None, - max_workers: Optional[int] = None, - dry_run: bool = False, -) -> dict[str, pd.DataFrame]: - """同步所有股票的筹码分布数据。 + force_full: bool = False, +) -> pd.DataFrame: + """同步筹码分布数据到 DuckDB,支持智能增量同步。 - 这是筹码分布数据同步的主要入口点。 + 逻辑: + - 若表不存在:创建表 + 复合索引 (trade_date, ts_code) + 全量同步 + - 若表存在:从 last_date + 1 开始增量同步 Args: + start_date: 起始日期(YYYYMMDD 格式,默认全量从 20180101,增量从 last_date+1) + end_date: 结束日期(YYYYMMDD 格式,默认为今天) force_full: 若为 True,强制从 20180101 完整重载 - start_date: 手动指定起始日期(YYYYMMDD) - end_date: 手动指定结束日期(默认为今天) - max_workers: 工作线程数(默认: 10) - dry_run: 若为 True,仅预览将要同步的内容,不写入数据 Returns: - 映射 ts_code 到 DataFrame 的字典 + 包含同步数据的 pd.DataFrame Example: >>> # 首次同步(从 20180101 全量加载) @@ -170,49 +161,31 @@ def sync_cyq_perf( >>> >>> # 手动指定日期范围 >>> result = sync_cyq_perf(start_date='20240101', end_date='20240131') - >>> - >>> # 自定义线程数 - >>> result = sync_cyq_perf(max_workers=20) - >>> - >>> # Dry run(仅预览) - >>> result = sync_cyq_perf(dry_run=True) """ - sync_manager = CyqPerfSync(max_workers=max_workers) + sync_manager = CyqPerfSync() return sync_manager.sync_all( - force_full=force_full, start_date=start_date, end_date=end_date, - dry_run=dry_run, + force_full=force_full, ) def preview_cyq_perf_sync( - force_full: bool = False, start_date: Optional[str] = None, end_date: Optional[str] = None, + force_full: bool = False, sample_size: int = 3, ) -> dict: """预览筹码分布数据同步数据量和样本(不实际同步)。 - 这是推荐的方式,可在实际同步前检查将要同步的内容。 - Args: - force_full: 若为 True,预览全量同步(从 20180101) start_date: 手动指定起始日期(覆盖自动检测) end_date: 手动指定结束日期(默认为今天) - sample_size: 预览用样本股票数量(默认: 3) + force_full: 若为 True,预览全量同步(从 20180101) + sample_size: 预览天数(默认: 3) Returns: - 包含预览信息的字典: - { - 'sync_needed': bool, - 'stock_count': int, - 'start_date': str, - 'end_date': str, - 'estimated_records': int, - 'sample_data': pd.DataFrame, - 'mode': str, # 'full', 'incremental', 'partial', 或 'none' - } + 包含预览信息的字典 Example: >>> # 预览将要同步的内容 @@ -220,14 +193,11 @@ def preview_cyq_perf_sync( >>> >>> # 预览全量同步 >>> preview = preview_cyq_perf_sync(force_full=True) - >>> - >>> # 预览更多样本 - >>> preview = preview_cyq_perf_sync(sample_size=5) """ sync_manager = CyqPerfSync() return sync_manager.preview_sync( - force_full=force_full, start_date=start_date, end_date=end_date, + force_full=force_full, sample_size=sample_size, ) diff --git a/src/data/api_wrappers/api_stock_st.py b/src/data/api_wrappers/api_stock_st.py index f2dc02c..a10469c 100644 --- a/src/data/api_wrappers/api_stock_st.py +++ b/src/data/api_wrappers/api_stock_st.py @@ -16,6 +16,7 @@ def get_stock_st( start_date: Optional[str] = None, end_date: Optional[str] = None, ts_code: Optional[str] = None, + client: Optional[TushareClient] = None, ) -> pd.DataFrame: """Fetch ST stock list from Tushare. @@ -28,6 +29,9 @@ def get_stock_st( start_date: Start date for date range query (YYYYMMDD format) end_date: End date for date range query (YYYYMMDD format) ts_code: Stock code filter (optional, e.g., '000001.SZ') + client: Optional TushareClient instance for shared rate limiting. + If None, creates a new client. For concurrent sync operations, + pass a shared client to ensure proper rate limiting. Returns: pd.DataFrame with columns: @@ -47,7 +51,7 @@ def get_stock_st( >>> # Get specific stock ST history >>> data = get_stock_st(ts_code='000001.SZ') """ - client = TushareClient() + client = client or TushareClient() # Build parameters params = {} @@ -108,7 +112,7 @@ class StockSTSync(DateBasedSync): Returns: 包含当日ST股票列表的 DataFrame """ - return get_stock_st(trade_date=trade_date) + return get_stock_st(trade_date=trade_date, client=self.client) def sync_stock_st( diff --git a/src/data/api_wrappers/base_sync.py b/src/data/api_wrappers/base_sync.py index 0bbd1ec..59bd0b5 100644 --- a/src/data/api_wrappers/base_sync.py +++ b/src/data/api_wrappers/base_sync.py @@ -1058,9 +1058,9 @@ class DateBasedSync(BaseDataSync): class_name = self.__class__.__name__ storage = Storage() - # 默认结束日期 + # 默认结束日期(使用带时间逻辑的 get_today_date,9点前返回前一天) if end_date is None: - end_date = datetime.now().strftime("%Y%m%d") + end_date = get_today_date() # 检查表是否存在 table_exists = storage.exists(self.table_name) diff --git a/src/data/client.py b/src/data/client.py index 920e803..36125d8 100644 --- a/src/data/client.py +++ b/src/data/client.py @@ -12,6 +12,7 @@ class TushareClient: # 类级别共享限流器(确保所有实例共享同一个限流器) _shared_limiter: Optional[TokenBucketRateLimiter] = None + _cached_rate_limit: int = 0 # 缓存上次使用的 rate_limit def __init__(self, token: Optional[str] = None): """Initialize client. @@ -29,17 +30,19 @@ class TushareClient: self.config = cfg # 初始化共享限流器(确保所有 TushareClient 实例共享同一个限流器) - rate_per_second = cfg.rate_limit / 60.0 - capacity = cfg.rate_limit - - if TushareClient._shared_limiter is None: - # 首次创建:初始化共享限流器 + # 检查是否需要重新创建限流器(配置发生变化时) + if ( + TushareClient._shared_limiter is None + or TushareClient._cached_rate_limit != cfg.rate_limit + ): + # 首次创建或配置变更:重新初始化共享限流器 TushareClient._shared_limiter = TokenBucketRateLimiter( - capacity=capacity, - refill_rate_per_second=rate_per_second, + rate_limit=cfg.rate_limit, ) + TushareClient._cached_rate_limit = cfg.rate_limit + min_interval = 60.0 / cfg.rate_limit print( - f"[TushareClient] Initialized shared rate limiter: capacity={capacity}, window=60s" + f"[TushareClient] Initialized shared rate limiter: rate={cfg.rate_limit}/min, interval={min_interval:.2f}s" ) # 复用共享限流器 self.rate_limiter = TushareClient._shared_limiter @@ -65,21 +68,17 @@ class TushareClient: Returns: DataFrame with query results """ - # Acquire rate limit token (None = wait indefinitely) timeout = timeout if timeout is not None else float("inf") - success, wait_time = self.rate_limiter.acquire(timeout=timeout) - - if not success: - raise RuntimeError(f"Rate limit exceeded after {timeout}s timeout") - - if wait_time > 0: - pass # Silent wait - - # Execute with retry max_retries = 3 retry_delays = [1, 3, 10] for attempt in range(max_retries): + # Acquire rate limit token before each attempt (including retries) + success, wait_time = self.rate_limiter.acquire(timeout=timeout) + + if not success: + raise RuntimeError(f"Rate limit exceeded after {timeout}s timeout") + try: import tushare as ts @@ -108,10 +107,18 @@ class TushareClient: return data except Exception as e: + error_msg = str(e) if attempt < max_retries - 1: delay = retry_delays[attempt] + # 如果触发 Tushare 限流,增加等待时间避开惩罚期 + if "最多访问该接口" in error_msg: + delay = max(delay, 60) + print( + f"[RateLimit] {api_name} hit Tushare limit, waiting {delay}s..." + ) + print( - f"[Retry] {api_name} failed (attempt {attempt + 1}): {e}, retry in {delay}s" + f"[Retry] {api_name} failed (attempt {attempt + 1}): {error_msg}, retry in {delay}s" ) time.sleep(delay) else: diff --git a/src/data/rate_limiter.py b/src/data/rate_limiter.py index 2b80d0d..d6f06b0 100644 --- a/src/data/rate_limiter.py +++ b/src/data/rate_limiter.py @@ -1,12 +1,13 @@ """API 速率限制器实现。 -提供基于固定时间窗口的速率限制,适合 Tushare 等按分钟计费的 API。 +提供基于固定时间间隔的速率限制,强制两次请求之间保持最小时间间隔。 +适合 Tushare 等需要严格控制请求频率的 API。 """ import time import threading from typing import Optional -from dataclasses import dataclass +from dataclasses import dataclass, field @dataclass @@ -17,178 +18,142 @@ class RateLimiterStats: successful_requests: int = 0 denied_requests: int = 0 total_wait_time: float = 0.0 - current_window_requests: int = 0 - window_start_time: float = 0.0 + last_request_time: Optional[float] = None # 上次请求开始时间 class TokenBucketRateLimiter: - """基于固定时间窗口的速率限制器。 + """基于固定时间间隔的速率限制器。 - 适合 Tushare 等按时间窗口(如每分钟)限制请求数的 API 场景。 - 在窗口期内,请求数达到上限后将阻塞或等待下一个窗口。 + 强制两次请求之间保持最小时间间隔,无论请求处理耗时多久。 + 适合需要严格控制请求频率、避免触发服务端限流的场景。 Attributes: - capacity: 每个时间窗口内允许的最大请求数 - window_seconds: 时间窗口长度(秒) + rate_limit: 每分钟允许的请求数 + min_interval: 两次请求之间的最小时间间隔(秒) """ def __init__( self, - capacity: int = 100, - refill_rate_per_second: float = 1.67, - initial_tokens: Optional[int] = None, + rate_limit: int = 150, + **kwargs, ) -> None: """初始化速率限制器。 Args: - capacity: 每个时间窗口内允许的最大请求数 - refill_rate_per_second: 保留参数(向后兼容),实际使用 window_seconds=60 - initial_tokens: 保留参数(向后兼容) + rate_limit: 每分钟允许的请求数(默认 150) """ - self.capacity = capacity - # Tushare 通常按分钟限制,所以固定使用 60 秒窗口 - self.window_seconds = 60.0 + self.rate_limit = rate_limit + # 计算最小间隔:60秒 / 每分钟请求数 + self.min_interval = 60.0 / rate_limit - self._requests_in_window = 0 - self._window_start = time.monotonic() self._lock = threading.RLock() self._stats = RateLimiterStats() - self._stats.window_start_time = self._window_start - - def _is_new_window(self) -> bool: - """检查是否已进入新的时间窗口。""" - current_time = time.monotonic() - elapsed = current_time - self._window_start - return elapsed >= self.window_seconds - - def _reset_window(self) -> None: - """重置时间窗口。""" - self._window_start = time.monotonic() - self._requests_in_window = 0 - self._stats.window_start_time = self._window_start def acquire(self, timeout: float = float("inf")) -> tuple[bool, float]: - """获取请求许可。 + """获取请求许可,确保与上次请求间隔足够时间。 - 如果在当前窗口内请求数已达上限,则等待到下一个窗口。 + 会等待直到距离上次请求的时间 >= min_interval。 + 注意:全程加锁,确保多线程下严格串行执行。 Args: timeout: 最大等待时间(秒),默认无限等待 Returns: - (success, wait_time): 是否成功获取许可,以及等待时间 + (success, wait_time): 是否成功,以及实际等待时间 """ start_time = time.monotonic() with self._lock: - # 检查是否需要进入新窗口 - if self._is_new_window(): - self._reset_window() + now = time.monotonic() - # 如果当前窗口还有余量,直接通过 - if self._requests_in_window < self.capacity: - self._requests_in_window += 1 - self._stats.total_requests += 1 - self._stats.successful_requests += 1 - self._stats.current_window_requests = self._requests_in_window - return True, 0.0 + # 计算距离上次请求的时间 + if self._stats.last_request_time is not None: + elapsed = now - self._stats.last_request_time + time_to_wait = self.min_interval - elapsed - # 当前窗口已满,计算需要等待的时间 - current_time = time.monotonic() - time_to_next_window = self.window_seconds - ( - current_time - self._window_start - ) + if time_to_wait > 0: + # 需要等待 + if timeout != float("inf") and time_to_wait > timeout: + # 超过最大等待时间 + self._stats.total_requests += 1 + self._stats.denied_requests += 1 + return False, time_to_wait - if time_to_next_window <= 0: - # 刚好进入新窗口 - self._reset_window() - self._requests_in_window = 1 - self._stats.total_requests += 1 - self._stats.successful_requests += 1 - self._stats.current_window_requests = 1 - return True, 0.0 + # 在锁内等待(全程加锁,确保多线程严格串行) + time.sleep(time_to_wait) + now = time.monotonic() - # 检查是否能在超时时间内等待 - if timeout != float("inf") and time_to_next_window > timeout: - self._stats.total_requests += 1 - self._stats.denied_requests += 1 - return False, timeout + # 更新上次请求时间(请求开始前) + self._stats.last_request_time = now + wait_time = now - start_time - # 需要等待到下一个窗口 - if timeout != float("inf"): - time_to_wait = min(time_to_next_window, timeout) - else: - time_to_wait = time_to_next_window + self._stats.total_requests += 1 + self._stats.successful_requests += 1 + self._stats.total_wait_time += wait_time - time.sleep(time_to_wait) - - # 重新尝试获取许可 - with self._lock: - # 再次检查窗口状态(可能其他线程已经重置了窗口) - if self._is_new_window(): - self._reset_window() - - if self._requests_in_window < self.capacity: - self._requests_in_window += 1 - wait_time = time.monotonic() - start_time - self._stats.total_requests += 1 - self._stats.successful_requests += 1 - self._stats.total_wait_time += wait_time - self._stats.current_window_requests = self._requests_in_window - return True, wait_time - else: - # 在极端情况下,等待后仍然无法获取(其他线程抢先) - wait_time = time.monotonic() - start_time - self._stats.total_requests += 1 - self._stats.denied_requests += 1 - return False, wait_time + return True, wait_time def acquire_nonblocking(self) -> tuple[bool, float]: """尝试非阻塞地获取请求许可。 Returns: - (success, wait_time): 是否成功获取许可,以及需要等待的时间 + (success, wait_time): 是否成功,以及需要等待的时间 """ with self._lock: - # 检查是否需要进入新窗口 - if self._is_new_window(): - self._reset_window() + now = time.monotonic() - # 如果当前窗口还有余量,直接通过 - if self._requests_in_window < self.capacity: - self._requests_in_window += 1 - self._stats.total_requests += 1 - self._stats.successful_requests += 1 - self._stats.current_window_requests = self._requests_in_window - return True, 0.0 + if self._stats.last_request_time is not None: + elapsed = now - self._stats.last_request_time + time_to_wait = self.min_interval - elapsed - # 当前窗口已满,计算需要等待的时间 - current_time = time.monotonic() - time_to_next_window = self.window_seconds - ( - current_time - self._window_start - ) + if time_to_wait > 0: + self._stats.total_requests += 1 + self._stats.denied_requests += 1 + return False, time_to_wait + # 立即获得许可 + self._stats.last_request_time = now self._stats.total_requests += 1 - self._stats.denied_requests += 1 - return False, max(0.0, time_to_next_window) + self._stats.successful_requests += 1 + return True, 0.0 - def get_available_tokens(self) -> float: - """获取当前窗口剩余可用请求数。 + def get_min_interval(self) -> float: + """获取最小请求间隔。 Returns: - 当前窗口剩余可用请求数 + 两次请求之间的最小时间间隔(秒) + """ + return self.min_interval + + def get_time_until_next_request(self) -> float: + """获取距离下次允许请求的时间。 + + Returns: + 距离下次请求还需要等待的时间(秒),0 表示可以立即请求 """ with self._lock: - if self._is_new_window(): - return float(self.capacity) - return float(self.capacity - self._requests_in_window) + if self._stats.last_request_time is None: + return 0.0 + + elapsed = time.monotonic() - self._stats.last_request_time + return max(0.0, self.min_interval - elapsed) def get_stats(self) -> RateLimiterStats: """获取速率限制器统计信息。 Returns: - RateLimiterStats 实例 + RateLimiterStats 实例的副本 """ with self._lock: - self._stats.current_window_requests = self._requests_in_window - return self._stats + return RateLimiterStats( + total_requests=self._stats.total_requests, + successful_requests=self._stats.successful_requests, + denied_requests=self._stats.denied_requests, + total_wait_time=self._stats.total_wait_time, + last_request_time=self._stats.last_request_time, + ) + + def reset(self) -> None: + """重置限流器状态(调试用)。""" + with self._lock: + self._stats = RateLimiterStats() diff --git a/src/data/utils.py b/src/data/utils.py index b63bf2c..32da8a1 100644 --- a/src/data/utils.py +++ b/src/data/utils.py @@ -14,13 +14,36 @@ DEFAULT_START_DATE = "20180101" TODAY: str = datetime.now().strftime("%Y%m%d") -def get_today_date() -> str: +def get_today_date(cutoff_hour: int = 9) -> str: """获取今日日期(YYYYMMDD 格式)。 + 考虑数据生成时间的逻辑:在 cutoff_hour 点之前,返回前一天的日期, + 因为当天的数据还未生成。A股数据通常在交易日收盘后(约 15:00-19:00) + 生成,但为了保险起见,默认使用早上 9 点作为分界。 + + Args: + cutoff_hour: 时间分界点(小时,24小时制),默认为 9。 + 当前时间小于此值时,返回前一天日期。 + Returns: - 今日日期字符串,格式为 YYYYMMDD + 日期字符串,格式为 YYYYMMDD + + Example: + >>> # 假设当前是 2024-01-15 08:30 + >>> get_today_date() # 返回 '20240114'(前一天) + >>> + >>> # 假设当前是 2024-01-15 10:00 + >>> get_today_date() # 返回 '20240115'(当天) + >>> + >>> # 使用自定义分界点 + >>> get_today_date(cutoff_hour=15) # 15点前返回前一天 """ - return TODAY + now = datetime.now() + if now.hour < cutoff_hour: + # 在分界点之前,返回前一天 + prev_dt = now - timedelta(days=1) + return prev_dt.strftime("%Y%m%d") + return now.strftime("%Y%m%d") def get_next_date(date_str: str) -> str: