"""API 速率限制器实现。 提供基于固定时间窗口的速率限制,适合 Tushare 等按分钟计费的 API。 """ import time import threading from typing import Optional from dataclasses import dataclass @dataclass class RateLimiterStats: """速率限制器统计信息。""" total_requests: int = 0 successful_requests: int = 0 denied_requests: int = 0 total_wait_time: float = 0.0 current_window_requests: int = 0 window_start_time: float = 0.0 class TokenBucketRateLimiter: """基于固定时间窗口的速率限制器。 适合 Tushare 等按时间窗口(如每分钟)限制请求数的 API 场景。 在窗口期内,请求数达到上限后将阻塞或等待下一个窗口。 Attributes: capacity: 每个时间窗口内允许的最大请求数 window_seconds: 时间窗口长度(秒) """ def __init__( self, capacity: int = 100, refill_rate_per_second: float = 1.67, initial_tokens: Optional[int] = None, ) -> None: """初始化速率限制器。 Args: capacity: 每个时间窗口内允许的最大请求数 refill_rate_per_second: 保留参数(向后兼容),实际使用 window_seconds=60 initial_tokens: 保留参数(向后兼容) """ self.capacity = capacity # Tushare 通常按分钟限制,所以固定使用 60 秒窗口 self.window_seconds = 60.0 self._requests_in_window = 0 self._window_start = time.monotonic() self._lock = threading.RLock() self._stats = RateLimiterStats() self._stats.window_start_time = self._window_start def _is_new_window(self) -> bool: """检查是否已进入新的时间窗口。""" current_time = time.monotonic() elapsed = current_time - self._window_start return elapsed >= self.window_seconds def _reset_window(self) -> None: """重置时间窗口。""" self._window_start = time.monotonic() self._requests_in_window = 0 self._stats.window_start_time = self._window_start def acquire(self, timeout: float = float("inf")) -> tuple[bool, float]: """获取请求许可。 如果在当前窗口内请求数已达上限,则等待到下一个窗口。 Args: timeout: 最大等待时间(秒),默认无限等待 Returns: (success, wait_time): 是否成功获取许可,以及等待时间 """ start_time = time.monotonic() with self._lock: # 检查是否需要进入新窗口 if self._is_new_window(): self._reset_window() # 如果当前窗口还有余量,直接通过 if self._requests_in_window < self.capacity: self._requests_in_window += 1 self._stats.total_requests += 1 self._stats.successful_requests += 1 self._stats.current_window_requests = self._requests_in_window return True, 0.0 # 当前窗口已满,计算需要等待的时间 current_time = time.monotonic() time_to_next_window = self.window_seconds - ( current_time - self._window_start ) if time_to_next_window <= 0: # 刚好进入新窗口 self._reset_window() self._requests_in_window = 1 self._stats.total_requests += 1 self._stats.successful_requests += 1 self._stats.current_window_requests = 1 return True, 0.0 # 检查是否能在超时时间内等待 if timeout != float("inf") and time_to_next_window > timeout: self._stats.total_requests += 1 self._stats.denied_requests += 1 return False, timeout # 需要等待到下一个窗口 if timeout != float("inf"): time_to_wait = min(time_to_next_window, timeout) else: time_to_wait = time_to_next_window time.sleep(time_to_wait) # 重新尝试获取许可 with self._lock: # 再次检查窗口状态(可能其他线程已经重置了窗口) if self._is_new_window(): self._reset_window() if self._requests_in_window < self.capacity: self._requests_in_window += 1 wait_time = time.monotonic() - start_time self._stats.total_requests += 1 self._stats.successful_requests += 1 self._stats.total_wait_time += wait_time self._stats.current_window_requests = self._requests_in_window return True, wait_time else: # 在极端情况下,等待后仍然无法获取(其他线程抢先) wait_time = time.monotonic() - start_time self._stats.total_requests += 1 self._stats.denied_requests += 1 return False, wait_time def acquire_nonblocking(self) -> tuple[bool, float]: """尝试非阻塞地获取请求许可。 Returns: (success, wait_time): 是否成功获取许可,以及需要等待的时间 """ with self._lock: # 检查是否需要进入新窗口 if self._is_new_window(): self._reset_window() # 如果当前窗口还有余量,直接通过 if self._requests_in_window < self.capacity: self._requests_in_window += 1 self._stats.total_requests += 1 self._stats.successful_requests += 1 self._stats.current_window_requests = self._requests_in_window return True, 0.0 # 当前窗口已满,计算需要等待的时间 current_time = time.monotonic() time_to_next_window = self.window_seconds - ( current_time - self._window_start ) self._stats.total_requests += 1 self._stats.denied_requests += 1 return False, max(0.0, time_to_next_window) def get_available_tokens(self) -> float: """获取当前窗口剩余可用请求数。 Returns: 当前窗口剩余可用请求数 """ with self._lock: if self._is_new_window(): return float(self.capacity) return float(self.capacity - self._requests_in_window) def get_stats(self) -> RateLimiterStats: """获取速率限制器统计信息。 Returns: RateLimiterStats 实例 """ with self._lock: self._stats.current_window_requests = self._requests_in_window return self._stats