Files
ProStock/src/data/rate_limiter.py

168 lines
5.3 KiB
Python
Raw Normal View History

"""Token bucket rate limiter implementation.
This module provides a thread-safe token bucket algorithm for rate limiting.
"""
import copy
import threading
import time
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class RateLimiterStats:
    """Counters describing how a rate limiter has been used.

    ``current_tokens`` is not a constructor argument; it defaults to ``None``
    and is meant to be filled in by the limiter that owns this object.
    """

    # Number of acquisition attempts, successful or not.
    total_requests: int = 0
    # Attempts that obtained a token.
    successful_requests: int = 0
    # Attempts that were rejected.
    denied_requests: int = 0
    # Accumulated seconds spent waiting by successful acquisitions.
    total_wait_time: float = 0.0
    # Token count at the last update; None until the owner assigns it.
    # (Previously __post_init__ overwrote this with a dataclasses.Field object.)
    current_tokens: Optional[float] = field(default=None, init=False)
class TokenBucketRateLimiter:
"""Thread-safe token bucket rate limiter.
Implements a token bucket algorithm for controlling request rate.
Tokens are added at a fixed rate up to the bucket capacity.
Attributes:
capacity: Maximum number of tokens in the bucket
refill_rate: Number of tokens added per second
initial_tokens: Initial number of tokens (default: capacity)
"""
def __init__(
self,
capacity: int = 100,
refill_rate_per_second: float = 1.67,
initial_tokens: Optional[int] = None,
) -> None:
"""Initialize the token bucket rate limiter.
Args:
capacity: Maximum token capacity
refill_rate_per_second: Token refill rate per second
initial_tokens: Initial token count (default: capacity)
"""
self.capacity = capacity
self.refill_rate = refill_rate_per_second
self.tokens = float(initial_tokens if initial_tokens is not None else capacity)
self.last_refill_time = time.monotonic()
self._lock = threading.RLock()
self._stats = RateLimiterStats()
self._stats.current_tokens = self.tokens
def acquire(self, timeout: float = 30.0) -> tuple[bool, float]:
"""Acquire a token from the bucket.
Blocks until a token is available or timeout expires.
Args:
timeout: Maximum time to wait for a token in seconds
Returns:
Tuple of (success, wait_time):
- success: True if token was acquired, False if timed out
- wait_time: Time spent waiting for token
"""
start_time = time.monotonic()
wait_time = 0.0
with self._lock:
self._refill()
if self.tokens >= 1:
self.tokens -= 1
self._stats.total_requests += 1
self._stats.successful_requests += 1
self._stats.current_tokens = self.tokens
return True, 0.0
# Calculate time to wait for next token
tokens_needed = 1 - self.tokens
time_to_refill = tokens_needed / self.refill_rate
if time_to_refill > timeout:
self._stats.total_requests += 1
self._stats.denied_requests += 1
return False, timeout
# Wait for tokens
self._lock.release()
time.sleep(time_to_refill)
self._lock.acquire()
wait_time = time.monotonic() - start_time
with self._lock:
self._refill()
if self.tokens >= 1:
self.tokens -= 1
self._stats.total_requests += 1
self._stats.successful_requests += 1
self._stats.total_wait_time += wait_time
self._stats.current_tokens = self.tokens
return True, wait_time
self._stats.total_requests += 1
self._stats.denied_requests += 1
return False, wait_time
def acquire_nonblocking(self) -> tuple[bool, float]:
"""Try to acquire a token without blocking.
Returns:
Tuple of (success, wait_time):
- success: True if token was acquired, False otherwise
- wait_time: 0 for non-blocking, or required wait time if failed
"""
with self._lock:
self._refill()
if self.tokens >= 1:
self.tokens -= 1
self._stats.total_requests += 1
self._stats.successful_requests += 1
self._stats.current_tokens = self.tokens
return True, 0.0
# Calculate time needed
tokens_needed = 1 - self.tokens
time_to_refill = tokens_needed / self.refill_rate
self._stats.total_requests += 1
self._stats.denied_requests += 1
return False, time_to_refill
def _refill(self) -> None:
"""Refill tokens based on elapsed time."""
current_time = time.monotonic()
elapsed = current_time - self.last_refill_time
self.last_refill_time = current_time
tokens_to_add = elapsed * self.refill_rate
self.tokens = min(self.capacity, self.tokens + tokens_to_add)
def get_available_tokens(self) -> float:
"""Get the current number of available tokens.
Returns:
Current token count
"""
with self._lock:
self._refill()
return self.tokens
def get_stats(self) -> RateLimiterStats:
"""Get rate limiter statistics.
Returns:
RateLimiterStats instance
"""
with self._lock:
self._refill()
self._stats.current_tokens = self.tokens
return self._stats