90 lines
2.8 KiB
Python
90 lines
2.8 KiB
Python
|
|
"""Simplified Tushare client with rate limiting and retry logic."""
|
||
|
|
import time
|
||
|
|
import pandas as pd
|
||
|
|
from typing import Optional
|
||
|
|
from src.data.config import get_config
|
||
|
|
from src.data.rate_limiter import TokenBucketRateLimiter
|
||
|
|
|
||
|
|
|
||
|
|
class TushareClient:
    """Tushare API client with rate limiting and retry.

    Every call to the remote API -- including each retry -- first takes a
    token from a ``TokenBucketRateLimiter`` sized from ``cfg.rate_limit``
    (calls per minute), so retries cannot push the client past the
    configured quota.

    The retry policy is exposed as class attributes so subclasses or
    individual instances can tune it without changing call sites.
    """

    # Total number of attempts per query (first call + retries); values < 1
    # are treated as 1.
    max_retries: int = 3
    # Seconds to sleep after failed attempt ``i`` (0-based). Attempts beyond
    # the end of the tuple reuse its last entry; an empty tuple means no sleep.
    retry_delays: tuple = (1, 3, 10)

    def __init__(self, token: Optional[str] = None):
        """Initialize client.

        Args:
            token: Tushare API token (auto-loaded from config if not provided)

        Raises:
            ValueError: If no token is given and the config provides none.
        """
        cfg = get_config()
        token = token or cfg.tushare_token

        if not token:
            raise ValueError("Tushare token is required")

        self.token = token
        self.config = cfg

        # Token bucket: capacity = rate_limit, refill = rate_limit / 60 tokens
        # per second, i.e. on average ``rate_limit`` calls per minute.
        rate_per_second = cfg.rate_limit / 60.0
        self.rate_limiter = TokenBucketRateLimiter(
            capacity=cfg.rate_limit,
            refill_rate_per_second=rate_per_second,
        )

        # Created lazily by _get_api() so tushare is only imported on first use.
        self._api = None

    def _get_api(self):
        """Return the cached Tushare pro API handle, creating it on first use."""
        if self._api is None:
            # Deferred import: tushare is only needed once a query is issued.
            import tushare as ts

            self._api = ts.pro_api(self.token)
        return self._api

    def _acquire_token(self, timeout: float) -> None:
        """Request one token from the rate limiter, allowing up to ``timeout`` seconds.

        Raises:
            RuntimeError: If the limiter reports failure within ``timeout``.
        """
        success, wait_time = self.rate_limiter.acquire(timeout=timeout)

        if not success:
            raise RuntimeError(f"Rate limit exceeded after {timeout}s timeout")

        if wait_time > 0:
            print(f"[RateLimit] Waited {wait_time:.2f}s for token")

    def query(self, api_name: str, timeout: float = 30.0, **params) -> pd.DataFrame:
        """Execute API query with rate limiting and retry.

        Each attempt (including retries) acquires its own rate-limit token
        before calling the API. Up to ``max_retries`` attempts are made,
        sleeping per ``retry_delays`` between failed attempts.

        Args:
            api_name: API name (e.g., 'daily')
            timeout: Maximum seconds to wait for a rate-limit token, per attempt
            **params: API parameters forwarded to the Tushare query

        Returns:
            DataFrame with query results

        Raises:
            RuntimeError: If a rate-limit token cannot be acquired within
                ``timeout``, or if every attempt fails. In the latter case the
                last API exception is chained as ``__cause__``.
        """
        attempts = max(1, int(self.max_retries))
        delays = self.retry_delays
        last_error: Optional[BaseException] = None

        for attempt in range(attempts):
            # Outside the try: a rate-limit timeout should surface immediately
            # rather than be retried as if it were an API failure.
            self._acquire_token(timeout)

            try:
                data = self._get_api().query(api_name, **params)
            except Exception as exc:
                # NOTE(review): every exception is treated as transient, as in
                # the original; permanent errors (bad params) are retried too.
                last_error = exc
                if attempt == attempts - 1:
                    break
                delay = delays[min(attempt, len(delays) - 1)] if delays else 0
                print(f"[Retry] {api_name} failed (attempt {attempt + 1}): {exc}, retry in {delay}s")
                time.sleep(delay)
                continue

            available = self.rate_limiter.get_available_tokens()
            print(f"[Tushare] {api_name} | tokens: {available:.0f}/{self.rate_limiter.capacity}")
            return data

        raise RuntimeError(
            f"API call failed after {attempts} attempts: {last_error}"
        ) from last_error

    def close(self):
        """Close client by dropping the cached API handle.

        A later ``query`` lazily creates a fresh handle.
        """
        self._api = None

    def __enter__(self) -> "TushareClient":
        """Support ``with TushareClient() as client:`` usage."""
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        """Close the client on exit; never suppresses exceptions."""
        self.close()
        return False
|