"""Token bucket rate limiter implementation.

This module provides a thread-safe token bucket algorithm for rate limiting.
"""

import threading
import time
from dataclasses import dataclass, field, replace
from typing import Optional


@dataclass
class RateLimiterStats:
    """Statistics for rate limiter."""

    total_requests: int = 0
    successful_requests: int = 0
    denied_requests: int = 0
    # Sum of wait times of successful acquisitions only (denials add nothing).
    total_wait_time: float = 0.0
    # Token count as of the last recorded request / stats query.
    current_tokens: Optional[float] = None


class TokenBucketRateLimiter:
    """Thread-safe token bucket rate limiter.

    Implements a token bucket algorithm for controlling request rate.
    Tokens are added at a fixed rate up to the bucket capacity.

    Attributes:
        capacity: Maximum number of tokens in the bucket
        refill_rate: Number of tokens added per second
        tokens: Current (fractional) token count
        last_refill_time: ``time.monotonic()`` timestamp of the last refill
    """

    # Upper bound on a single sleep inside acquire(), so waiting threads
    # re-check the bucket (and their deadline) at least this often.
    _MAX_SLEEP_SECONDS = 0.1

    def __init__(
        self,
        capacity: int = 100,
        refill_rate_per_second: float = 1.67,
        initial_tokens: Optional[int] = None,
    ) -> None:
        """Initialize the token bucket rate limiter.

        Args:
            capacity: Maximum token capacity; must be at least 1, otherwise
                a whole token could never be acquired.
            refill_rate_per_second: Token refill rate per second; must be
                non-negative. A rate of 0 makes the bucket a fixed budget
                that is never replenished.
            initial_tokens: Initial token count (default: capacity)

        Raises:
            ValueError: If ``capacity`` < 1 or ``refill_rate_per_second`` < 0.
        """
        if capacity < 1:
            raise ValueError(f"capacity must be at least 1, got {capacity!r}")
        if refill_rate_per_second < 0:
            raise ValueError(
                "refill_rate_per_second must be non-negative, "
                f"got {refill_rate_per_second!r}"
            )
        self.capacity = capacity
        self.refill_rate = refill_rate_per_second
        self.tokens = float(
            initial_tokens if initial_tokens is not None else capacity
        )
        self.last_refill_time = time.monotonic()
        self._lock = threading.RLock()
        self._stats = RateLimiterStats()
        self._stats.current_tokens = self.tokens

    def acquire(self, timeout: float = float("inf")) -> tuple[bool, float]:
        """Acquire a token from the bucket.

        Blocks until a token is available or timeout expires. The request is
        denied as soon as it is clear the next token cannot arrive before the
        deadline (or can never arrive, e.g. with a zero refill rate).

        Args:
            timeout: Maximum time to wait for a token in seconds (default: inf)

        Returns:
            Tuple of (success, wait_time):
            - success: True if token was acquired, False if timed out
            - wait_time: Time actually spent in this call; exactly 0.0 when a
              token was available immediately
        """
        start_time = time.monotonic()
        has_slept = False
        while True:
            with self._lock:
                self._refill()
                if self.tokens >= 1:
                    self.tokens -= 1
                    # An immediately granted token reports exactly 0.0 wait.
                    wait_time = (
                        time.monotonic() - start_time if has_slept else 0.0
                    )
                    self._record(True, wait_time)
                    return True, wait_time

                elapsed = time.monotonic() - start_time
                time_to_wait = self._time_until_next_token()
                # inf - elapsed stays inf, so an infinite timeout needs no
                # special casing here.
                remaining = timeout - elapsed
                if (
                    time_to_wait == float("inf")
                    or remaining <= 0
                    or time_to_wait > remaining
                ):
                    self._record(False)
                    return False, elapsed

            # Sleep without holding the lock so other threads can use the
            # bucket; cap the sleep so we re-check frequently.
            time.sleep(min(time_to_wait, self._MAX_SLEEP_SECONDS))
            has_slept = True

    def acquire_nonblocking(self) -> tuple[bool, float]:
        """Try to acquire a token without blocking.

        Returns:
            Tuple of (success, wait_time):
            - success: True if token was acquired, False otherwise
            - wait_time: 0.0 on success; on failure, the time in seconds until
              a token would be available (``inf`` if it never will be)
        """
        with self._lock:
            self._refill()
            if self.tokens >= 1:
                self.tokens -= 1
                self._record(True, 0.0)
                return True, 0.0

            time_to_refill = self._time_until_next_token()
            self._record(False)
            return False, time_to_refill

    def _time_until_next_token(self) -> float:
        """Return seconds until one whole token is available (caller holds the lock)."""
        tokens_needed = max(0.0, 1 - self.tokens)
        if tokens_needed == 0:
            return 0.0
        if self.refill_rate <= 0:
            # Nothing ever replenishes the bucket.
            return float("inf")
        return tokens_needed / self.refill_rate

    def _record(self, success: bool, wait_time: float = 0.0) -> None:
        """Update request statistics (caller holds the lock)."""
        stats = self._stats
        stats.total_requests += 1
        if success:
            stats.successful_requests += 1
            stats.total_wait_time += wait_time
        else:
            stats.denied_requests += 1
        stats.current_tokens = self.tokens

    def _refill(self) -> None:
        """Refill tokens based on elapsed time (caller holds the lock)."""
        current_time = time.monotonic()
        elapsed = current_time - self.last_refill_time
        self.last_refill_time = current_time
        tokens_to_add = elapsed * self.refill_rate
        self.tokens = min(self.capacity, self.tokens + tokens_to_add)

    def get_available_tokens(self) -> float:
        """Get the current number of available tokens.

        Returns:
            Current token count
        """
        with self._lock:
            self._refill()
            return self.tokens

    def get_stats(self) -> RateLimiterStats:
        """Get rate limiter statistics.

        Returns:
            A snapshot copy of the statistics taken under the lock; later
            requests do not mutate the returned object.
        """
        with self._lock:
            self._refill()
            self._stats.current_tokens = self.tokens
            return replace(self._stats)