feat: initialize ProStock project base structure and configuration

- Add project rule documents (development standards, security rules, configuration management)
- Implement the data module core (API client, rate limiter, storage manager, config loader)
- Add .gitignore and .kilocodeignore
- Add environment variable template
- Add unit tests for the daily module
src/data/__init__.py (new file, 15 lines)
@@ -0,0 +1,15 @@
"""Data collection module for Tushare.

Provides simplified interfaces for fetching and storing Tushare data.
"""

from src.data.config import Config, get_config
from src.data.client import TushareClient
from src.data.storage import Storage

__all__ = [
    "Config",
    "get_config",
    "TushareClient",
    "Storage",
]
src/data/api.md (new file, 47 lines)
@@ -0,0 +1,47 @@
1. General market data interface: https://tushare.pro/document/2?doc_id=109. The available fields are listed at https://tushare.pro/document/2?doc_id=27. Requirement: persist every output field of the A-share daily quotes plus tor (turnover rate) and vr (volume ratio).

Input parameters of the general interface:

| Name | Type | Required | Description |
|------|------|----------|-------------|
| ts_code | str | Y | Security code; multiple codes are not supported (multi-value input returns duplicate records) |
| start_date | str | N | Start date (daily format: YYYYMMDD; for minute data use a format like 2019-09-01 09:00:00) |
| end_date | str | N | End date (daily format: YYYYMMDD) |
| asset | str | Y | Asset class: E stocks, I SSE/SZSE indexes, C cryptocurrencies, FT futures, FD funds, O options, CB convertible bonds (v1.2.39); default E |
| adj | str | N | Adjustment type (stocks only): None unadjusted, qfq forward-adjusted, hfq backward-adjusted; default None. Only daily data can be adjusted; adjustment is computed dynamically against the given end_date using the dividend-reinvestment model (see the FAQ for details) |
| freq | str | Y | Data frequency: minute (min) / daily (D) / weekly (W) / monthly (M) bars, where 1min means 1 minute (likewise 1/5/15/30/60 minutes); default D. Users with 600 points can trial minute data (2 requests); see the permission list and the stock minute-data guide for details |
| ma | list | N | Moving averages; any reasonable int is accepted. Note: MAs are computed dynamically, so the date range must exceed the window (e.g. a 5-day MA needs a start/end span longer than 5 days). Currently only single-stock extraction is supported, so ts_code is required. e.g. ma_5 is the 5-day average price, ma_v_5 the 5-day average volume |
| factors | list | N | Stock factors (valid when asset='E'): tor turnover rate, vr volume ratio |
| adjfactor | str | N | Adjustment factor; when True, adjusted data includes the adjustment factor column; default False. Available since v1.2.33 |

Output fields

See the per-asset daily documentation for the exact output columns:

Stock daily: https://tushare.pro/document/2?doc_id=27
Fund daily: https://tushare.pro/document/2?doc_id=127
Futures daily: https://tushare.pro/document/2?doc_id=138
Option daily: https://tushare.pro/document/2?doc_id=159
Index daily: https://tushare.pro/document/2?doc_id=95

A-share daily quotes

Interface: daily (can be inspected and debugged with the data tool)
Data notes: loaded between 15:00 and 16:00 on each trading day. This interface returns unadjusted quotes and provides no data while a stock is suspended.
Rate limits: with basic points, up to 500 calls per minute and 6,000 rows per call; one request covers roughly 23 years of history for a single stock.
Description: fetches stock quotes; forward/backward-adjusted data is available through the general interface above.

Input parameters

| Name | Type | Required | Description |
|------|------|----------|-------------|
| ts_code | str | N | Stock code (multiple codes supported, comma-separated) |
| trade_date | str | N | Trade date (YYYYMMDD) |
| start_date | str | N | Start date (YYYYMMDD) |
| end_date | str | N | End date (YYYYMMDD) |

Note: all dates use the YYYYMMDD format, e.g. 20181010.

Output parameters

| Name | Type | Description |
|------|------|-------------|
| ts_code | str | Stock code |
| trade_date | str | Trade date |
| open | float | Open price |
| high | float | High price |
| low | float | Low price |
| close | float | Close price |
| pre_close | float | Previous close (ex-rights price) |
| change | float | Price change |
| pct_chg | float | Percent change, computed against the ex-rights previous close: (close - pre_close) / pre_close |
| vol | float | Volume (lots; 1 lot = 100 shares) |
| amount | float | Turnover (thousand CNY) |
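Illustrative only (not part of this commit): a minimal sketch of fetching daily bars with the tor/vr factors through the general interface documented above, assuming the tushare package is installed and "YOUR_TUSHARE_TOKEN" is replaced with a real token.

import tushare as ts

# The general market data interface (doc_id=109); ts_code, adj, and factors
# follow the parameter table above. freq defaults to daily ('D').
ts.set_token("YOUR_TUSHARE_TOKEN")  # placeholder, not a real credential
df = ts.pro_bar(
    ts_code="000001.SZ",        # single code only for this interface
    start_date="20240101",
    end_date="20240131",
    asset="E",                  # E = equities
    adj=None,                   # unadjusted daily bars
    factors=["tor", "vr"],      # turnover rate and volume ratio columns
)
print(df[["ts_code", "trade_date", "close", "tor", "vr"]].head())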
src/data/client.py (new file, 89 lines)
@@ -0,0 +1,89 @@
"""Simplified Tushare client with rate limiting and retry logic."""
import time

import pandas as pd

from typing import Optional

from src.data.config import get_config
from src.data.rate_limiter import TokenBucketRateLimiter


class TushareClient:
    """Tushare API client with rate limiting and retry."""

    def __init__(self, token: Optional[str] = None):
        """Initialize client.

        Args:
            token: Tushare API token (auto-loaded from config if not provided)
        """
        cfg = get_config()
        token = token or cfg.tushare_token

        if not token:
            raise ValueError("Tushare token is required")

        self.token = token
        self.config = cfg

        # Initialize rate limiter: capacity = rate_limit, refill_rate = rate_limit/60 per second
        rate_per_second = cfg.rate_limit / 60.0
        self.rate_limiter = TokenBucketRateLimiter(
            capacity=cfg.rate_limit,
            refill_rate_per_second=rate_per_second,
        )

        self._api = None

    def _get_api(self):
        """Get Tushare API instance."""
        if self._api is None:
            import tushare as ts
            self._api = ts.pro_api(self.token)
        return self._api

    def query(self, api_name: str, timeout: float = 30.0, **params) -> pd.DataFrame:
        """Execute API query with rate limiting and retry.

        Args:
            api_name: API name (e.g., 'daily')
            timeout: Timeout for rate limiting
            **params: API parameters

        Returns:
            DataFrame with query results
        """
        # Acquire rate limit token
        success, wait_time = self.rate_limiter.acquire(timeout=timeout)

        if not success:
            raise RuntimeError(f"Rate limit exceeded after {timeout}s timeout")

        if wait_time > 0:
            print(f"[RateLimit] Waited {wait_time:.2f}s for token")

        # Execute with retry
        max_retries = 3
        retry_delays = [1, 3, 10]

        for attempt in range(max_retries):
            try:
                api = self._get_api()
                data = api.query(api_name, **params)

                available = self.rate_limiter.get_available_tokens()
                print(f"[Tushare] {api_name} | tokens: {available:.0f}/{self.rate_limiter.capacity}")

                return data

            except Exception as e:
                if attempt < max_retries - 1:
                    delay = retry_delays[attempt]
                    print(f"[Retry] {api_name} failed (attempt {attempt + 1}): {e}, retry in {delay}s")
                    time.sleep(delay)
                else:
                    raise RuntimeError(f"API call failed after {max_retries} attempts: {e}")

        return pd.DataFrame()

    def close(self):
        """Close client."""
        pass
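Illustrative usage sketch for TushareClient (assumes a token is available through the configuration described in config.py below):

from src.data.client import TushareClient

client = TushareClient()  # token auto-loaded from config

# A single rate-limited, retried call to the 'daily' endpoint
df = client.query("daily", ts_code="000001.SZ",
                  start_date="20240101", end_date="20240131")
print(len(df), "rows")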
src/data/config.py (new file, 33 lines)
@@ -0,0 +1,33 @@
"""Configuration management for data collection module."""
from pathlib import Path

from pydantic_settings import BaseSettings


class Config(BaseSettings):
    """Application configuration loaded from environment variables."""

    # Tushare API token
    tushare_token: str = ""

    # Data storage path
    data_path: Path = Path("./data")

    # Rate limit: requests per minute
    rate_limit: int = 100

    # Thread pool size
    threads: int = 2

    class Config:
        env_file = ".env.local"
        env_file_encoding = "utf-8"
        case_sensitive = False


# Global config instance
config = Config()


def get_config() -> Config:
    """Get configuration instance."""
    return config
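Since Config is a pydantic BaseSettings class with case_sensitive = False, each field can be supplied via .env.local. An illustrative template (placeholder values only, not real credentials):

# .env.local (example values)
TUSHARE_TOKEN=your-token-here
DATA_PATH=./data
RATE_LIMIT=100
THREADS=2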
src/data/daily.py (new file, 70 lines)
@@ -0,0 +1,70 @@
"""Simplified daily market data interface.

A single function to fetch A-share daily quotes (A股日线行情) from Tushare.
Supports all output fields, including tor (turnover rate) and vr (volume ratio).
"""
import pandas as pd

from typing import Optional, List, Literal

from src.data.client import TushareClient


def get_daily(
    ts_code: str,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    trade_date: Optional[str] = None,
    adj: Literal[None, "qfq", "hfq"] = None,
    factors: Optional[List[Literal["tor", "vr"]]] = None,
    adjfactor: bool = False,
) -> pd.DataFrame:
    """Fetch daily market data for A-share stocks.

    This is a simplified interface that combines rate limiting, API calls,
    and error handling into a single function.

    Args:
        ts_code: Stock code (e.g., '000001.SZ', '600000.SH')
        start_date: Start date in YYYYMMDD format
        end_date: End date in YYYYMMDD format
        trade_date: Specific trade date in YYYYMMDD format
        adj: Adjustment type - None, 'qfq' (forward), 'hfq' (backward)
        factors: List of factors to include - 'tor' (turnover rate), 'vr' (volume ratio)
        adjfactor: Whether to include adjustment factor

    Returns:
        pd.DataFrame with daily market data containing:
        - Base fields: ts_code, trade_date, open, high, low, close, pre_close,
          change, pct_chg, vol, amount
        - Factor fields (if requested): tor, vr
        - Adjustment factor (if adjfactor=True): adjfactor

    Example:
        >>> data = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
        >>> data = get_daily('600000.SH', factors=['tor', 'vr'])
    """
    # Initialize client
    client = TushareClient()

    # Build parameters
    params = {"ts_code": ts_code}

    if start_date:
        params["start_date"] = start_date
    if end_date:
        params["end_date"] = end_date
    if trade_date:
        params["trade_date"] = trade_date
    if adj:
        params["adj"] = adj
    if factors:
        params["factors"] = factors
    if adjfactor:
        params["adjfactor"] = "True"

    # Fetch data
    data = client.query("daily", **params)

    if data.empty:
        print(f"[get_daily] No data for ts_code={ts_code}")

    return data
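An illustrative sketch of the intended fetch-and-persist flow, combining get_daily with the Storage class defined further below (column set depends on the requested factors):

from src.data.daily import get_daily
from src.data.storage import Storage

df = get_daily("000001.SZ", start_date="20240101", end_date="20240131",
               factors=["tor", "vr"])

storage = Storage()
result = storage.save("daily", df)   # de-duplicates on (ts_code, trade_date)
print(result, storage.get_last_date("daily"))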
src/data/rate_limiter.py (new file, 167 lines)
@@ -0,0 +1,167 @@
"""Token bucket rate limiter implementation.

This module provides a thread-safe token bucket algorithm for rate limiting.
"""
import time
import threading

from typing import Optional
from dataclasses import dataclass, field


@dataclass
class RateLimiterStats:
    """Statistics for rate limiter."""
    total_requests: int = 0
    successful_requests: int = 0
    denied_requests: int = 0
    total_wait_time: float = 0.0
    current_tokens: float = field(default=0.0, init=False)


class TokenBucketRateLimiter:
    """Thread-safe token bucket rate limiter.

    Implements a token bucket algorithm for controlling request rate.
    Tokens are added at a fixed rate up to the bucket capacity.

    Attributes:
        capacity: Maximum number of tokens in the bucket
        refill_rate: Number of tokens added per second
        initial_tokens: Initial number of tokens (default: capacity)
    """

    def __init__(
        self,
        capacity: int = 100,
        refill_rate_per_second: float = 1.67,
        initial_tokens: Optional[int] = None,
    ) -> None:
        """Initialize the token bucket rate limiter.

        Args:
            capacity: Maximum token capacity
            refill_rate_per_second: Token refill rate per second
            initial_tokens: Initial token count (default: capacity)
        """
        self.capacity = capacity
        self.refill_rate = refill_rate_per_second
        self.tokens = float(initial_tokens if initial_tokens is not None else capacity)
        self.last_refill_time = time.monotonic()
        self._lock = threading.RLock()
        self._stats = RateLimiterStats()
        self._stats.current_tokens = self.tokens

    def acquire(self, timeout: float = 30.0) -> tuple[bool, float]:
        """Acquire a token from the bucket.

        Blocks until a token is available or timeout expires.

        Args:
            timeout: Maximum time to wait for a token in seconds

        Returns:
            Tuple of (success, wait_time):
            - success: True if token was acquired, False if timed out
            - wait_time: Time spent waiting for token
        """
        start_time = time.monotonic()
        wait_time = 0.0

        with self._lock:
            self._refill()

            if self.tokens >= 1:
                self.tokens -= 1
                self._stats.total_requests += 1
                self._stats.successful_requests += 1
                self._stats.current_tokens = self.tokens
                return True, 0.0

            # Calculate time to wait for next token
            tokens_needed = 1 - self.tokens
            time_to_refill = tokens_needed / self.refill_rate

            if time_to_refill > timeout:
                self._stats.total_requests += 1
                self._stats.denied_requests += 1
                return False, timeout

        # Wait for tokens outside the lock so other threads are not blocked
        time.sleep(time_to_refill)

        wait_time = time.monotonic() - start_time

        with self._lock:
            self._refill()
            if self.tokens >= 1:
                self.tokens -= 1
                self._stats.total_requests += 1
                self._stats.successful_requests += 1
                self._stats.total_wait_time += wait_time
                self._stats.current_tokens = self.tokens
                return True, wait_time

            self._stats.total_requests += 1
            self._stats.denied_requests += 1
            return False, wait_time

    def acquire_nonblocking(self) -> tuple[bool, float]:
        """Try to acquire a token without blocking.

        Returns:
            Tuple of (success, wait_time):
            - success: True if token was acquired, False otherwise
            - wait_time: 0 for non-blocking, or required wait time if failed
        """
        with self._lock:
            self._refill()

            if self.tokens >= 1:
                self.tokens -= 1
                self._stats.total_requests += 1
                self._stats.successful_requests += 1
                self._stats.current_tokens = self.tokens
                return True, 0.0

            # Calculate time needed
            tokens_needed = 1 - self.tokens
            time_to_refill = tokens_needed / self.refill_rate

            self._stats.total_requests += 1
            self._stats.denied_requests += 1
            return False, time_to_refill

    def _refill(self) -> None:
        """Refill tokens based on elapsed time."""
        current_time = time.monotonic()
        elapsed = current_time - self.last_refill_time
        self.last_refill_time = current_time

        tokens_to_add = elapsed * self.refill_rate
        self.tokens = min(self.capacity, self.tokens + tokens_to_add)

    def get_available_tokens(self) -> float:
        """Get the current number of available tokens.

        Returns:
            Current token count
        """
        with self._lock:
            self._refill()
            return self.tokens

    def get_stats(self) -> RateLimiterStats:
        """Get rate limiter statistics.

        Returns:
            RateLimiterStats instance
        """
        with self._lock:
            self._refill()
            self._stats.current_tokens = self.tokens
            return self._stats
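A small, self-contained sketch of the limiter's behaviour (illustrative; with a capacity of 2 and a refill rate of 1 token/s, the third acquire should wait roughly one second):

import time
from src.data.rate_limiter import TokenBucketRateLimiter

limiter = TokenBucketRateLimiter(capacity=2, refill_rate_per_second=1.0)

for i in range(3):
    # First two requests use the initial tokens; the third waits for a refill
    ok, waited = limiter.acquire(timeout=5.0)
    print(f"request {i}: ok={ok}, waited={waited:.2f}s")

print(limiter.get_stats())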
src/data/storage.py (new file, 133 lines)
@@ -0,0 +1,133 @@
"""Simplified HDF5 storage for data persistence."""
import pandas as pd

from pathlib import Path
from typing import Optional

from src.data.config import get_config


class Storage:
    """HDF5 storage manager for saving and loading data."""

    def __init__(self, path: Optional[Path] = None):
        """Initialize storage.

        Args:
            path: Base path for data storage (auto-loaded from config if not provided)
        """
        cfg = get_config()
        self.base_path = path or cfg.data_path
        self.base_path.mkdir(parents=True, exist_ok=True)

    def _get_file_path(self, name: str) -> Path:
        """Get full path for an HDF5 file."""
        return self.base_path / f"{name}.h5"

    def save(self, name: str, data: pd.DataFrame, mode: str = "append") -> dict:
        """Save data to HDF5 file.

        Args:
            name: Dataset name (also used as filename)
            data: DataFrame to save
            mode: 'append' or 'replace'

        Returns:
            Dict with save result
        """
        if data.empty:
            return {"status": "skipped", "rows": 0}

        file_path = self._get_file_path(name)

        try:
            with pd.HDFStore(file_path, mode="a") as store:
                # HDFStore keys carry a leading '/', so test membership with `in store`
                if mode == "replace" or name not in store:
                    store.put(name, data, format="table")
                else:
                    # Merge with existing data
                    existing = store[name]
                    combined = pd.concat([existing, data], ignore_index=True)
                    combined = combined.drop_duplicates(subset=["ts_code", "trade_date"], keep="last")
                    store.put(name, combined, format="table")

            print(f"[Storage] Saved {len(data)} rows to {file_path}")
            return {"status": "success", "rows": len(data), "path": str(file_path)}

        except Exception as e:
            print(f"[Storage] Error saving {name}: {e}")
            return {"status": "error", "error": str(e)}

    def load(self, name: str,
             start_date: Optional[str] = None,
             end_date: Optional[str] = None,
             ts_code: Optional[str] = None) -> pd.DataFrame:
        """Load data from HDF5 file.

        Args:
            name: Dataset name
            start_date: Start date filter (YYYYMMDD)
            end_date: End date filter (YYYYMMDD)
            ts_code: Stock code filter

        Returns:
            DataFrame with loaded data
        """
        file_path = self._get_file_path(name)

        if not file_path.exists():
            print(f"[Storage] File not found: {file_path}")
            return pd.DataFrame()

        try:
            with pd.HDFStore(file_path, mode="r") as store:
                if name not in store:
                    return pd.DataFrame()

                data = store[name]

                # Apply filters
                if start_date and end_date and "trade_date" in data.columns:
                    data = data[(data["trade_date"] >= start_date) & (data["trade_date"] <= end_date)]

                if ts_code and "ts_code" in data.columns:
                    data = data[data["ts_code"] == ts_code]

                return data

        except Exception as e:
            print(f"[Storage] Error loading {name}: {e}")
            return pd.DataFrame()

    def get_last_date(self, name: str) -> Optional[str]:
        """Get the latest date in storage.

        Args:
            name: Dataset name

        Returns:
            Latest date string or None
        """
        data = self.load(name)
        if data.empty or "trade_date" not in data.columns:
            return None
        return str(data["trade_date"].max())

    def exists(self, name: str) -> bool:
        """Check if dataset exists."""
        return self._get_file_path(name).exists()

    def delete(self, name: str) -> bool:
        """Delete a dataset.

        Args:
            name: Dataset name

        Returns:
            True if deleted
        """
        file_path = self._get_file_path(name)
        if file_path.exists():
            file_path.unlink()
            print(f"[Storage] Deleted {file_path}")
            return True
        return False
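An illustrative load-side sketch (assumes a 'daily' dataset has previously been saved with Storage.save):

from src.data.storage import Storage

storage = Storage()

# Filtered read-back from <data_path>/daily.h5
jan = storage.load("daily", start_date="20240101", end_date="20240131",
                   ts_code="000001.SZ")

# Latest stored trade_date, useful as the start_date of the next incremental fetch
print(storage.get_last_date("daily"), len(jan))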