2026-01-31 03:04:51 +08:00
|
|
|
"""Token bucket rate limiter implementation.
|
|
|
|
|
|
|
|
|
|
This module provides a thread-safe token bucket algorithm for rate limiting.
|
|
|
|
|
"""
|
2026-02-01 02:29:54 +08:00
|
|
|
|
2026-01-31 03:04:51 +08:00
|
|
|
import time
|
|
|
|
|
import threading
|
|
|
|
|
from typing import Optional
|
|
|
|
|
from dataclasses import dataclass, field
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class RateLimiterStats:
    """Running counters describing the activity of a rate limiter.

    Attributes:
        total_requests: Number of acquire attempts recorded so far.
        successful_requests: Attempts that obtained a token.
        denied_requests: Attempts that did not obtain a token.
        total_wait_time: Accumulated seconds spent waiting by successful
            blocking attempts.
        current_tokens: Token count at the time of the last update, or
            ``None`` if it has never been recorded.
    """

    total_requests: int = 0
    successful_requests: int = 0
    denied_requests: int = 0
    total_wait_time: float = 0.0
    current_tokens: Optional[float] = None
|
2026-01-31 03:04:51 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class TokenBucketRateLimiter:
    """Thread-safe token bucket rate limiter.

    Implements a token bucket algorithm for controlling request rate.
    Tokens are added at a fixed rate up to the bucket capacity, and each
    successful acquire removes exactly one token.

    Attributes:
        capacity: Maximum number of tokens in the bucket.
        refill_rate: Number of tokens added per second.
        tokens: Current (fractional) token count as of the last refill.
        last_refill_time: ``time.monotonic()`` reading taken at the last
            refill.
    """

    # Upper bound on a single sleep while blocked in ``acquire``; the bucket
    # is re-checked at least this often.
    _POLL_INTERVAL = 0.1

    def __init__(
        self,
        capacity: int = 100,
        refill_rate_per_second: float = 1.67,
        initial_tokens: Optional[int] = None,
    ) -> None:
        """Initialize the token bucket rate limiter.

        Args:
            capacity: Maximum token capacity.
            refill_rate_per_second: Token refill rate per second.
            initial_tokens: Initial token count (default: capacity). Values
                above ``capacity`` are clamped to ``capacity``.
        """
        self.capacity = capacity
        self.refill_rate = refill_rate_per_second
        starting = capacity if initial_tokens is None else initial_tokens
        # The bucket can never hold more than its capacity, so clamp up front
        # instead of exposing an over-full count until the first refill.
        self.tokens = float(min(starting, capacity))
        self.last_refill_time = time.monotonic()
        self._lock = threading.RLock()
        self._stats = RateLimiterStats()
        self._stats.current_tokens = self.tokens

    def acquire(self, timeout: float = float("inf")) -> tuple[bool, float]:
        """Acquire a token from the bucket.

        Blocks until a token is available or the timeout expires. While
        blocked, the lock is released and the bucket is polled at most every
        ``_POLL_INTERVAL`` seconds.

        Args:
            timeout: Maximum time to wait for a token in seconds
                (default: inf, i.e. wait indefinitely).

        Returns:
            Tuple of (success, wait_time):
            - success: True if token was acquired, False if timed out
            - wait_time: Time spent waiting for token. When the request is
              rejected up front because the next token cannot arrive within
              ``timeout``, ``timeout`` itself is returned.
        """
        start_time = time.monotonic()
        wait_forever = timeout == float("inf")

        with self._lock:
            if self._try_consume():
                self._record_success(0.0)
                return True, 0.0

            time_to_refill = self._seconds_until_token()
            if not wait_forever and time_to_refill > timeout:
                self._record_denied()
                return False, timeout
            if time_to_refill == float("inf"):
                # Only reachable with an infinite timeout: a token can never
                # become available, so fail now instead of blocking forever.
                self._record_denied()
                return False, time.monotonic() - start_time

            while True:
                elapsed = time.monotonic() - start_time
                time_to_wait = self._seconds_until_token()
                if not wait_forever:
                    remaining = timeout - elapsed
                    if remaining <= 0 or time_to_wait > remaining:
                        self._record_denied()
                        return False, elapsed

                # Sleep without holding the lock so other threads can use the
                # limiter. Re-acquire in ``finally`` so that an exception
                # raised during the sleep (e.g. KeyboardInterrupt) does not
                # make the enclosing ``with`` release a lock we do not own.
                self._lock.release()
                try:
                    time.sleep(min(time_to_wait, self._POLL_INTERVAL))
                finally:
                    self._lock.acquire()

                if self._try_consume():
                    wait_time = time.monotonic() - start_time
                    self._record_success(wait_time)
                    return True, wait_time

    def acquire_nonblocking(self) -> tuple[bool, float]:
        """Try to acquire a token without blocking.

        Returns:
            Tuple of (success, wait_time):
            - success: True if token was acquired, False otherwise
            - wait_time: 0 on success; on failure, the number of seconds
              until a token is expected (``inf`` if one can never arrive)
        """
        with self._lock:
            if self._try_consume():
                self._record_success(0.0)
                return True, 0.0

            self._record_denied()
            return False, self._seconds_until_token()

    def _refill(self) -> None:
        """Refill tokens based on elapsed time. Caller must hold the lock."""
        current_time = time.monotonic()
        elapsed = current_time - self.last_refill_time
        self.last_refill_time = current_time

        tokens_to_add = elapsed * self.refill_rate
        # Keep the count a float even when it is capped at an int capacity.
        self.tokens = min(float(self.capacity), self.tokens + tokens_to_add)

    def _try_consume(self) -> bool:
        """Refill, then take one token if available. Caller must hold the lock."""
        self._refill()
        if self.tokens >= 1:
            self.tokens -= 1
            return True
        return False

    def _seconds_until_token(self) -> float:
        """Return seconds until one whole token exists. Caller must hold the lock.

        Returns ``inf`` when a token can never become available: the refill
        rate is not positive, or the capacity is below one token.
        """
        if self.refill_rate <= 0 or self.capacity < 1:
            return float("inf")
        return max(0.0, 1 - self.tokens) / self.refill_rate

    def _record_success(self, wait_time: float) -> None:
        """Update statistics for a granted request. Caller must hold the lock."""
        self._stats.total_requests += 1
        self._stats.successful_requests += 1
        self._stats.total_wait_time += wait_time
        self._stats.current_tokens = self.tokens

    def _record_denied(self) -> None:
        """Update statistics for a rejected request. Caller must hold the lock."""
        self._stats.total_requests += 1
        self._stats.denied_requests += 1
        self._stats.current_tokens = self.tokens

    def get_available_tokens(self) -> float:
        """Get the current number of available tokens.

        Returns:
            Current token count, after applying any pending refill.
        """
        with self._lock:
            self._refill()
            return self.tokens

    def get_stats(self) -> RateLimiterStats:
        """Get rate limiter statistics.

        Returns:
            The limiter's RateLimiterStats instance with ``current_tokens``
            refreshed. This is the live object, not a copy, so later
            requests keep updating it.
        """
        with self._lock:
            self._refill()
            self._stats.current_tokens = self.tokens
            return self._stats
|