refactor: 清理代码日志、重构速率限制器、切换存储方案
- 移除 client.py 和 daily.py 中的调试日志
- 重构 rate_limiter 支持无限超时和更精确的令牌获取
- 变更 stock_basic 存储方案 HDF5 → CSV
- 更新项目规则:强制使用 uv、禁止读取 config/ 目录
- 新增数据同步模块 sync.py 和测试
- .gitignore 添加 !data/ 允许跟踪数据文件
This commit is contained in:
@@ -40,25 +40,26 @@ class TushareClient:
|
||||
self._api = ts.pro_api(self.token)
|
||||
return self._api
|
||||
|
||||
def query(self, api_name: str, timeout: float = 30.0, **params) -> pd.DataFrame:
|
||||
def query(self, api_name: str, timeout: float = None, **params) -> pd.DataFrame:
|
||||
"""Execute API query with rate limiting and retry.
|
||||
|
||||
Args:
|
||||
api_name: API name ('daily', 'pro_bar', etc.)
|
||||
timeout: Timeout for rate limiting
|
||||
timeout: Timeout for rate limiting (None = wait indefinitely)
|
||||
**params: API parameters
|
||||
|
||||
Returns:
|
||||
DataFrame with query results
|
||||
"""
|
||||
# Acquire rate limit token
|
||||
# Acquire rate limit token (None = wait indefinitely)
|
||||
timeout = timeout if timeout is not None else float('inf')
|
||||
success, wait_time = self.rate_limiter.acquire(timeout=timeout)
|
||||
|
||||
if not success:
|
||||
raise RuntimeError(f"Rate limit exceeded after {timeout}s timeout")
|
||||
|
||||
if wait_time > 0:
|
||||
print(f"[RateLimit] Waited {wait_time:.2f}s for token")
|
||||
pass # Silent wait
|
||||
|
||||
# Execute with retry
|
||||
max_retries = 3
|
||||
@@ -83,9 +84,6 @@ class TushareClient:
|
||||
api = self._get_api()
|
||||
data = api.query(api_name, **params)
|
||||
|
||||
available = self.rate_limiter.get_available_tokens()
|
||||
print(f"[Tushare] {api_name} | tokens: {available:.0f}/{self.rate_limiter.capacity}")
|
||||
|
||||
return data
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -63,18 +63,10 @@ def get_daily(
|
||||
else:
|
||||
factors_str = factors
|
||||
params["factors"] = factors_str
|
||||
print(f"[get_daily] factors param: '{factors_str}'")
|
||||
if adjfactor:
|
||||
params["adjfactor"] = "True"
|
||||
|
||||
# Fetch data using pro_bar (supports factors like tor, vr)
|
||||
print(f"[get_daily] Query params: {params}")
|
||||
data = client.query("pro_bar", **params)
|
||||
|
||||
if not data.empty:
|
||||
print(f"[get_daily] Returned columns: {data.columns.tolist()}")
|
||||
print(f"[get_daily] Sample row: {data.iloc[0].to_dict()}")
|
||||
else:
|
||||
print(f"[get_daily] No data for ts_code={ts_code}")
|
||||
|
||||
return data
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
This module provides a thread-safe token bucket algorithm for rate limiting.
|
||||
"""
|
||||
|
||||
import time
|
||||
import threading
|
||||
from typing import Optional
|
||||
@@ -11,14 +12,12 @@ from dataclasses import dataclass, field
|
||||
@dataclass
class RateLimiterStats:
    """Statistics for rate limiter.

    All counters start at zero. ``current_tokens`` is ``None`` until the
    code that owns the stats object assigns a token count to it.
    """

    # Number of acquire attempts recorded (successful or denied).
    total_requests: int = 0
    # Number of attempts that obtained a token.
    successful_requests: int = 0
    # Number of attempts that were refused.
    denied_requests: int = 0
    # Accumulated seconds spent waiting by successful attempts.
    total_wait_time: float = 0.0
    # Last token count stored by the owner; a plain default is used here
    # because assigning ``field(default=None)`` at runtime would store a
    # ``dataclasses.Field`` object instead of ``None``.
    current_tokens: Optional[float] = None
|
||||
|
||||
|
||||
class TokenBucketRateLimiter:
|
||||
@@ -54,13 +53,13 @@ class TokenBucketRateLimiter:
|
||||
self._stats = RateLimiterStats()
|
||||
self._stats.current_tokens = self.tokens
|
||||
|
||||
def acquire(self, timeout: float = 30.0) -> tuple[bool, float]:
|
||||
def acquire(self, timeout: float = float("inf")) -> tuple[bool, float]:
|
||||
"""Acquire a token from the bucket.
|
||||
|
||||
Blocks until a token is available or timeout expires.
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait for a token in seconds
|
||||
timeout: Maximum time to wait for a token in seconds (default: inf)
|
||||
|
||||
Returns:
|
||||
Tuple of (success, wait_time):
|
||||
@@ -84,32 +83,58 @@ class TokenBucketRateLimiter:
|
||||
tokens_needed = 1 - self.tokens
|
||||
time_to_refill = tokens_needed / self.refill_rate
|
||||
|
||||
if time_to_refill > timeout:
|
||||
# Check if we can wait for the token within timeout
|
||||
# Handle infinite timeout specially
|
||||
is_infinite_timeout = timeout == float("inf")
|
||||
if not is_infinite_timeout and time_to_refill > timeout:
|
||||
self._stats.total_requests += 1
|
||||
self._stats.denied_requests += 1
|
||||
return False, timeout
|
||||
|
||||
# Wait for tokens
|
||||
self._lock.release()
|
||||
time.sleep(time_to_refill)
|
||||
self._lock.acquire()
|
||||
# Wait for tokens - loop until we get one or timeout
|
||||
while True:
|
||||
# Calculate remaining time we can wait
|
||||
elapsed = time.monotonic() - start_time
|
||||
remaining_timeout = (
|
||||
timeout - elapsed if not is_infinite_timeout else float("inf")
|
||||
)
|
||||
|
||||
wait_time = time.monotonic() - start_time
|
||||
# Check if we've exceeded timeout
|
||||
if not is_infinite_timeout and remaining_timeout <= 0:
|
||||
self._stats.total_requests += 1
|
||||
self._stats.denied_requests += 1
|
||||
return False, elapsed
|
||||
|
||||
with self._lock:
|
||||
# Calculate wait time for next token
|
||||
tokens_needed = max(0, 1 - self.tokens)
|
||||
time_to_wait = (
|
||||
tokens_needed / self.refill_rate if tokens_needed > 0 else 0.1
|
||||
)
|
||||
|
||||
# If we can't wait long enough, fail
|
||||
if not is_infinite_timeout and time_to_wait > remaining_timeout:
|
||||
self._stats.total_requests += 1
|
||||
self._stats.denied_requests += 1
|
||||
return False, elapsed
|
||||
|
||||
# Wait outside the lock to allow other threads to refill
|
||||
self._lock.release()
|
||||
time.sleep(
|
||||
min(time_to_wait, 0.1)
|
||||
) # Cap wait to 100ms to check frequently
|
||||
self._lock.acquire()
|
||||
|
||||
# Refill and check again
|
||||
self._refill()
|
||||
if self.tokens >= 1:
|
||||
self.tokens -= 1
|
||||
wait_time = time.monotonic() - start_time
|
||||
self._stats.total_requests += 1
|
||||
self._stats.successful_requests += 1
|
||||
self._stats.total_wait_time += wait_time
|
||||
self._stats.current_tokens = self.tokens
|
||||
return True, wait_time
|
||||
|
||||
self._stats.total_requests += 1
|
||||
self._stats.denied_requests += 1
|
||||
return False, wait_time
|
||||
|
||||
def acquire_nonblocking(self) -> tuple[bool, float]:
|
||||
"""Try to acquire a token without blocking.
|
||||
|
||||
|
||||
@@ -3,10 +3,19 @@
|
||||
Fetch basic stock information including code, name, listing date, etc.
|
||||
This is a special interface - call once to get all stocks (listed and delisted).
|
||||
"""
|
||||
import os
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import Optional, Literal, List
|
||||
from src.data.client import TushareClient
|
||||
from src.data.storage import Storage
|
||||
from src.data.config import get_config
|
||||
|
||||
|
||||
# CSV file path for stock basic data
|
||||
def _get_csv_path() -> Path:
    """Build the path to ``stock_basic.csv`` under the config's ``data_path_resolved``."""
    data_dir = get_config().data_path_resolved
    return data_dir / "stock_basic.csv"
|
||||
|
||||
|
||||
def get_stock_basic(
|
||||
@@ -75,20 +84,19 @@ def sync_all_stocks() -> pd.DataFrame:
|
||||
Returns:
|
||||
pd.DataFrame with all stock information
|
||||
"""
|
||||
# Initialize storage
|
||||
storage = Storage()
|
||||
csv_path = _get_csv_path()
|
||||
|
||||
# Check if already exists
|
||||
if storage.exists("stock_basic"):
|
||||
print("[sync_all_stocks] stock_basic data already exists, skipping...")
|
||||
return storage.load("stock_basic")
|
||||
# Check if CSV file already exists
|
||||
if csv_path.exists():
|
||||
print("[sync_all_stocks] stock_basic.csv already exists, skipping...")
|
||||
return pd.read_csv(csv_path)
|
||||
|
||||
print("[sync_all_stocks] Fetching all stocks (listed and delisted)...")
|
||||
|
||||
# Fetch all stocks - explicitly get all list_status values
|
||||
# API default is L (listed), so we need to fetch all statuses
|
||||
client = TushareClient()
|
||||
|
||||
|
||||
all_data = []
|
||||
for status in ["L", "D", "P", "G"]:
|
||||
print(f"[sync_all_stocks] Fetching stocks with status: {status}")
|
||||
@@ -96,21 +104,20 @@ def sync_all_stocks() -> pd.DataFrame:
|
||||
print(f"[sync_all_stocks] Fetched {len(data)} stocks with status {status}")
|
||||
if not data.empty:
|
||||
all_data.append(data)
|
||||
|
||||
|
||||
if not all_data:
|
||||
print("[sync_all_stocks] No stock data fetched")
|
||||
return pd.DataFrame()
|
||||
|
||||
|
||||
# Combine all data
|
||||
data = pd.concat(all_data, ignore_index=True)
|
||||
# Remove duplicates if any
|
||||
data = data.drop_duplicates(subset=["ts_code"], keep="first")
|
||||
print(f"[sync_all_stocks] Total unique stocks: {len(data)}")
|
||||
|
||||
# Save to storage
|
||||
storage.save("stock_basic", data, mode="replace")
|
||||
|
||||
print(f"[sync_all_stocks] Saved {len(data)} stocks to local storage")
|
||||
# Save to CSV
|
||||
data.to_csv(csv_path, index=False, encoding="utf-8-sig")
|
||||
print(f"[sync_all_stocks] Saved {len(data)} stocks to {csv_path}")
|
||||
return data
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user