refactor(factor): Rewrite the factor computation framework around a DSL expression system

- Remove the old factor framework: delete base.py, composite.py, data_loader.py, data_spec.py
  and all submodules (momentum, financial, quality, sentiment, etc.)
- Add a DSL expression system: implement the factor DSL compiler and translator
  - dsl.py: domain-specific language definitions
  - compiler.py: AST compilation and optimization
  - translator.py: translation to Polars expressions
  - api.py: unified API surface
- Add a data routing layer: data_router.py dynamically routes fields to tables
- Add an API wrapper: api_pro_bar.py exposes the pro_bar data interface
- Update the execution engine: engine.py adapted to the new DSL architecture
- Rebuild the test suite: remove old tests; add test_dsl_promotion.py,
  test_factor_integration.py, test_pro_bar.py
- Clean up docs: delete 8 obsolete documents (factor_design, db_sync_guide, etc.)
2026-02-27 22:22:23 +08:00
parent c3c20ed7ea
commit a56433e440
51 changed files with 667 additions and 11287 deletions

View File

@@ -3,7 +3,7 @@
Provides simplified interfaces for fetching and storing Tushare data.
"""
from src.data.config import Config, get_config
from src.config.settings import Settings, get_settings, settings
from src.data.client import TushareClient
from src.data.storage import Storage, ThreadSafeStorage, DEFAULT_TYPE_MAPPING
from src.data.api_wrappers import get_stock_basic, sync_all_stocks

View File

@@ -169,6 +169,120 @@ if "date" in data.columns:
### 4.5 Token-Bucket Rate-Limiting Requirements
All API calls must go through `TushareClient`, which enforces the token-bucket rate limit automatically.
#### 4.5.1 Basic usage (single-threaded)
```python
import pandas as pd

from src.data.client import TushareClient


def get_{data_type}(...) -> pd.DataFrame:
    client = TushareClient()

    # Build parameters
    params = {}
    if trade_date:
        params["trade_date"] = trade_date
    if ts_code:
        params["ts_code"] = ts_code
    # ...

    # Fetch data (rate limiting handled automatically)
    data = client.query("{api_name}", **params)
    return data
```
**Note**: `TushareClient` automatically handles:
- token-bucket rate limiting
- API retry logic (exponential backoff)
- configuration loading
#### 4.5.2 Multi-threaded / concurrent usage (important)
**Problem**: under concurrent calls, if every thread creates its own `TushareClient` instance, each instance gets its own rate limiter, so the effective request rate becomes thread count × per-limiter rate, and **rate limiting silently fails**.
**Solution**: data-fetching functions must accept an optional `client` parameter, and the Sync class must pass in a shared client instance.
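A back-of-the-envelope check shows why per-thread limiters defeat the cap (the figures here are hypothetical; the real quota comes from `cfg.rate_limit`):

```python
API_LIMIT_PER_MIN = 500  # hypothetical per-minute quota granted by the API
THREADS = 10

# Each thread builds its own TushareClient, so each limiter grants the full quota:
effective_rate = THREADS * API_LIMIT_PER_MIN  # 5000 requests/min -> limit broken

# With one shared client, all threads draw from a single limiter:
shared_rate = API_LIMIT_PER_MIN  # 500 requests/min -> limit respected
```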
**Fetch-function signature** (the `client` parameter is required):
```python
from typing import Optional

import pandas as pd

from src.data.client import TushareClient


def get_{data_type}(
    ts_code: str,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    client: Optional[TushareClient] = None,  # new: optional client parameter
) -> pd.DataFrame:
    """Fetch {data description} from Tushare.

    Args:
        ts_code: Stock code
        start_date: Start date (YYYYMMDD)
        end_date: End date (YYYYMMDD)
        client: Optional TushareClient instance for shared rate limiting.
            If None, creates a new client. For concurrent sync operations,
            pass a shared client to ensure proper rate limiting.

    Returns:
        pd.DataFrame with data
    """
    client = client or TushareClient()  # create a new instance only if none was provided

    params = {"ts_code": ts_code}
    if start_date:
        params["start_date"] = start_date
    if end_date:
        params["end_date"] = end_date

    data = client.query("{api_name}", **params)
    return data
```
**Sync class implementation** (must pass the shared client):
```python
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

import pandas as pd

from src.data.client import TushareClient
from src.data.storage import ThreadSafeStorage


class {DataType}Sync:
    def __init__(self, max_workers: Optional[int] = None):
        self.storage = ThreadSafeStorage()
        self.client = TushareClient()  # shared client instance
        self.max_workers = max_workers or 10

    def sync_single_stock(
        self,
        ts_code: str,
        start_date: str,
        end_date: str,
    ) -> pd.DataFrame:
        """Sync data for a single stock."""
        # Pass the shared client so rate limiting works across threads
        data = get_{data_type}(
            ts_code=ts_code,
            start_date=start_date,
            end_date=end_date,
            client=self.client,  # key point: pass the shared client
        )
        return data

    def sync_all(self, ...):
        # Run concurrently with ThreadPoolExecutor
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # All threads share self.client, so the limiter works correctly
            ...
```
**Key rules**:
1. Every per-stock fetch function must accept a `client: Optional[TushareClient] = None` parameter
2. The Sync class creates `self.client = TushareClient()` in `__init__`
3. The Sync class's sync methods must pass `self.client` to the fetch functions
4. Fetch functions use the `client = client or TushareClient()` pattern
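The `client = client or TushareClient()` pattern from rule 4 can be checked in isolation with a stand-in client; `FakeClient` and `get_data` here are hypothetical stubs for illustration, not project classes:

```python
from typing import Optional


class FakeClient:
    """Stand-in for TushareClient, for illustration only."""


def get_data(ts_code: str, client: Optional[FakeClient] = None) -> FakeClient:
    # Rule 4: create a fresh client only when the caller did not supply one
    client = client or FakeClient()
    return client


shared = FakeClient()
assert get_data("000001.SZ", client=shared) is shared  # shared instance is reused
assert get_data("000001.SZ") is not shared             # standalone call gets its own
```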
All API calls must go through `TushareClient`, which enforces the token-bucket rate limit automatically:
```python
@@ -198,6 +312,26 @@ def get_{data_type}(...) -> pd.DataFrame:
## 5. DuckDB Storage Requirements
### 5.0 Mandatory Persistence (critical)
**Every wrapped API interface must persist its data to DuckDB.**
This is the core principle of data syncing; it guarantees:
- Persistence: avoid repeated API calls and save tokens
- Incremental updates: smart syncing based on local data state
- Consistency: all data shares one storage and access path
- Offline availability: data can be queried without network access
**Persistence checklist**:
- [ ] Create the corresponding table in the `_init_db()` method of `storage.py`
- [ ] The table schema must use `ts_code` and `trade_date` as the primary key
- [ ] Implement a `sync_{data_type}()` function that saves data via `Storage` or `ThreadSafeStorage`
- [ ] Make sure the sync logic handles incremental updates correctly
**Cautionary example**: an early version of `api_pro_bar.py` implemented `sync_pro_bar()` but forgot to create the `pro_bar` table in `storage.py`, so the synced data was never persisted, wasting tokens and losing data.
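As a sketch of the first two checklist items, a small helper that emits the DDL for a per-stock daily table with the mandatory composite primary key (the helper and its column set are illustrative, not project code):

```python
def daily_table_ddl(table: str, value_columns: dict) -> str:
    """Build CREATE TABLE DDL with the mandatory (ts_code, trade_date) primary key."""
    cols = ["ts_code VARCHAR(16) NOT NULL", "trade_date DATE NOT NULL"]
    cols += [f"{name} {sql_type}" for name, sql_type in value_columns.items()]
    cols.append("PRIMARY KEY (ts_code, trade_date)")
    return f"CREATE TABLE IF NOT EXISTS {table} (\n  " + ",\n  ".join(cols) + "\n)"


ddl = daily_table_ddl("pro_bar", {"close": "DOUBLE", "tor": "DOUBLE", "vr": "DOUBLE"})
assert "PRIMARY KEY (ts_code, trade_date)" in ddl
```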
### 5.1 Storage Architecture
The project uses **DuckDB** for persistent storage:

View File

@@ -5,6 +5,7 @@ All wrapper files follow the naming convention: api_{data_type}.py
Available APIs:
- api_daily: Daily market data (日线行情)
- api_pro_bar: Pro Bar universal market data (通用行情,后复权)
- api_stock_basic: Stock basic information (股票基本信息)
- api_trade_cal: Trading calendar (交易日历)
- api_namechange: Stock name change history (股票曾用名)
@@ -12,15 +13,31 @@ Available APIs:
Example:
>>> from src.data.api_wrappers import get_daily, get_stock_basic, get_trade_cal, get_bak_basic
>>> from src.data.api_wrappers import get_bak_basic, sync_bak_basic
>>> from src.data.api_wrappers import get_pro_bar, sync_pro_bar
>>> data = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
>>> pro_data = get_pro_bar('000001.SZ', start_date='20240101', end_date='20240131')
>>> stocks = get_stock_basic()
>>> calendar = get_trade_cal('20240101', '20240131')
>>> bak_basic = get_bak_basic(trade_date='20240101')
"""
from src.data.api_wrappers.api_daily import get_daily, sync_daily, preview_daily_sync, DailySync
from src.data.api_wrappers.financial_data.api_income import get_income, sync_income, IncomeSync
from src.data.api_wrappers.api_daily import (
get_daily,
sync_daily,
preview_daily_sync,
DailySync,
)
from src.data.api_wrappers.api_pro_bar import (
get_pro_bar,
sync_pro_bar,
preview_pro_bar_sync,
ProBarSync,
)
from src.data.api_wrappers.financial_data.api_income import (
get_income,
sync_income,
IncomeSync,
)
from src.data.api_wrappers.api_bak_basic import get_bak_basic, sync_bak_basic
from src.data.api_wrappers.api_namechange import get_namechange, sync_namechange
from src.data.api_wrappers.api_stock_basic import get_stock_basic, sync_all_stocks
@@ -38,6 +55,11 @@ __all__ = [
"sync_daily",
"preview_daily_sync",
"DailySync",
# Pro Bar (universal market data)
"get_pro_bar",
"sync_pro_bar",
"preview_pro_bar_sync",
"ProBarSync",
# Income statement
"get_income",
"sync_income",

View File

@@ -345,4 +345,154 @@ df = pro.bak_basic(trade_date='20211012', fields='trade_date,ts_code,name,indust
4530 20211012 688255.SH 凯尔达 机械基件 0.0000
4531 20211012 688211.SH 中科微至 专用机械 0.0000
4532 20211012 605567.SH 春雪食品 食品 0.0000
4533 20211012 605566.SH 福莱蒽特 染料涂料 0.0000
Universal market data interface
Interface name: pro_bar. This is an integrated interface, and some indicators are computed on the fly.
Update time: stocks and indexes are usually updated between 15:00 and 17:00; crypto data updates in real time. See each underlying interface's documentation for details.
Description: currently integrates market data for stocks (unadjusted, forward-adjusted, backward-adjusted), indexes, cryptocurrencies, ETFs, futures, and options; all remaining instruments, including FX, will be integrated later, along with minute-level data. Different data sets require different credit levels; see the documentation for each data category.
Other notes: because this integrated interface implements extra logic in the SDK, it currently cannot be called over plain HTTP. You can consult the Tushare source on GitHub to build equivalent functionality.
Input parameters
Name  Type  Required  Description
ts_code  str  Y  Security code; multiple values are not supported (passing several codes yields duplicate records)
start_date  str  N  Start date (daily format: YYYYMMDD; for minute data use a format like 2019-09-01 09:00:00)
end_date  str  N  End date (daily format: YYYYMMDD)
asset  str  Y  Asset class: E stock, I index, C crypto, FT futures, FD fund, O option, CB convertible bond (added in v1.2.39). Default: E
adj  str  N  Adjustment type (stocks only): None unadjusted, qfq forward-adjusted, hfq backward-adjusted. Default: None. Only daily bars support adjustment, which is computed dynamically against the given end_date using a dividend-reinvestment model; see the FAQ for details
freq  str  Y  Bar frequency: minute (min), daily (D), weekly (W), or monthly (M); 1min means 1-minute bars, likewise 1/5/15/30/60 minutes. Default: D. Users with 600 credits can trial minute data with 2 requests; see the permissions list and the minute-data usage guide
ma  list  N  Moving averages; any reasonable int window. MAs are computed dynamically, so the date range must exceed the longest window (e.g. a 5-day MA needs start and end dates more than 5 days apart). Currently only single-stock extraction supports MAs, so ts_code is required. e.g. ma_5 is the 5-day average price, ma_v_5 the 5-day average volume
factors  list  N  Stock factors (only when asset='E'): tor turnover rate, vr volume ratio
adjfactor  str  N  If True, adjusted data includes the adjustment factor column. Default: False. Effective since v1.2.33
Output fields
See each market's own documentation for the exact output fields:
Stock daily: https://tushare.pro/document/2?doc_id=27
Contents: A-share daily bars
Interface: daily (can be debugged and previewed with the data tool)
Data notes: loaded between 15:00 and 16:00 on each trading day. This interface returns unadjusted bars and provides no data during trading suspensions.
Quota: with base credits, up to 500 calls per minute and 6,000 rows per call; one request covers roughly 23 years of history for a single stock.
Description: fetch daily stock bars, or use the universal interface, which also provides forward/backward-adjusted data.
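The "one request covers roughly 23 years" claim follows directly from the row cap and the approximate number of A-share trading days per year:

```python
ROWS_PER_CALL = 6000
TRADING_DAYS_PER_YEAR = 250  # approximate figure for A-shares

years_per_call = ROWS_PER_CALL / TRADING_DAYS_PER_YEAR
assert years_per_call == 24.0  # comfortably covers ~23 years of daily bars
```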
Input parameters
Name  Type  Required  Description
ts_code  str  N  Stock code (multiple codes supported, comma-separated)
trade_date  str  N  Trade date (YYYYMMDD)
start_date  str  N  Start date (YYYYMMDD)
end_date  str  N  End date (YYYYMMDD)
All dates use the YYYYMMDD format, e.g. 20181010.
Output parameters
Name  Type  Description
ts_code  str  Stock code
trade_date  str  Trade date
open  float  Open price
high  float  High price
low  float  Low price
close  float  Close price
pre_close  float  Previous close (ex-rights)
change  float  Price change
pct_chg  float  Percent change (based on the ex-rights previous close: (close - ex-rights previous close) / ex-rights previous close)
vol  float  Volume (lots)
amount  float  Turnover (thousand CNY)
Interface examples
pro = ts.pro_api()
df = pro.daily(ts_code='000001.SZ', start_date='20180701', end_date='20180718')
# multiple stocks
df = pro.daily(ts_code='000001.SZ,600000.SH', start_date='20180701', end_date='20180718')
or
df = pro.query('daily', ts_code='000001.SZ', start_date='20180701', end_date='20180718')
You can also fetch the whole market for a single historical date:
df = pro.daily(trade_date='20180810')
Sample data
ts_code trade_date open high low close pre_close change pct_chg vol amount
0 000001.SZ 20180718 8.75 8.85 8.69 8.70 8.72 -0.02 -0.23 525152.77 460697.377
1 000001.SZ 20180717 8.74 8.75 8.66 8.72 8.73 -0.01 -0.11 375356.33 326396.994
2 000001.SZ 20180716 8.85 8.90 8.69 8.73 8.88 -0.15 -1.69 689845.58 603427.713
3 000001.SZ 20180713 8.92 8.94 8.82 8.88 8.88 0.00 0.00 603378.21 535401.175
4 000001.SZ 20180712 8.60 8.97 8.58 8.88 8.64 0.24 2.78 1140492.31 1008658.828
5 000001.SZ 20180711 8.76 8.83 8.68 8.78 8.98 -0.20 -2.23 851296.70 744765.824
6 000001.SZ 20180710 9.02 9.02 8.89 8.98 9.03 -0.05 -0.55 896862.02 803038.965
7 000001.SZ 20180709 8.69 9.03 8.68 9.03 8.66 0.37 4.27 1409954.60 1255007.609
8 000001.SZ 20180706 8.61 8.78 8.45 8.66 8.60 0.06 0.70 988282.69 852071.526
9 000001.SZ 20180705 8.62 8.73 8.55 8.60 8.61 -0.01 -0.12 835768.77 722169.579
Fund daily: https://tushare.pro/document/2?doc_id=127
Futures daily: https://tushare.pro/document/2?doc_id=138
Options daily: https://tushare.pro/document/2?doc_id=159
Index daily: https://tushare.pro/document/2?doc_id=95
Usage examples
# forward-adjusted (qfq) bars for 000001
df = ts.pro_bar(ts_code='000001.SZ', adj='qfq', start_date='20180101', end_date='20181011')
ts_code trade_date open high low close \
trade_date
20181011 000001.SZ 20181011 1085.71 1097.59 1047.90 1065.19
20181010 000001.SZ 20181010 1138.65 1151.61 1121.36 1128.92
20181009 000001.SZ 20181009 1130.00 1155.93 1122.44 1140.81
20181008 000001.SZ 20181008 1155.93 1165.65 1128.92 1128.92
20180928 000001.SZ 20180928 1164.57 1217.51 1164.57 1193.74
# SSE Composite Index bars
df = ts.pro_bar(ts_code='000001.SH', asset='I', start_date='20180101', end_date='20181011')
In [10]: df.head()
Out[10]:
ts_code trade_date close open high low \
0 000001.SH 20181011 2583.4575 2643.0740 2661.2859 2560.3164
1 000001.SH 20181010 2725.8367 2723.7242 2743.5480 2703.0626
2 000001.SH 20181009 2721.0130 2713.7319 2734.3142 2711.1971
3 000001.SH 20181008 2716.5104 2768.2075 2771.9384 2710.1781
4 000001.SH 20180928 2821.3501 2794.2644 2821.7553 2791.8363
pre_close change pct_chg vol amount
0 2725.8367 -142.3792 -5.2233 197150702.0 170057762.5
1 2721.0130 4.8237 0.1773 113485736.0 111312455.3
2 2716.5104 4.5026 0.1657 116771899.0 110292457.8
3 2821.3501 -104.8397 -3.7159 149501388.0 141531551.8
4 2791.7748 29.5753 1.0594 134290456.0 125369989.4
# moving averages
df = ts.pro_bar(ts_code='000001.SZ', start_date='20180101', end_date='20181011', ma=[5, 20, 50])
pro_bar's average price and volume values are computed dynamically: to obtain a moving average for a given period, set start_date earlier than the longest window requires, then trim the extra rows. For example, to get the 3-day MA from 20190801 onward, set start_date='20190729' and drop the rows before 20190801.
# turnover rate (tor) and volume ratio (vr)
df = ts.pro_bar(ts_code='000001.SZ', start_date='20180101', end_date='20181011', factors=['tor', 'vr'])
Notes
If a token has already been set once via ts.set_token('xxxx'), the pro_api argument is not required.
For example:
df = ts.pro_bar(ts_code='000001.SH', asset='I', start_date='20180101', end_date='20181011')

View File

@@ -129,7 +129,9 @@ def sync_bak_basic(
columns = []
for col in sample.columns:
dtype = str(sample[col].dtype)
if "int" in dtype:
if col == "trade_date":
col_type = "DATE"
elif "int" in dtype:
col_type = "INTEGER"
elif "float" in dtype:
col_type = "DOUBLE"
@@ -223,10 +225,16 @@ def sync_bak_basic(
# Combine and save
combined = pd.concat(all_data, ignore_index=True)
# Convert trade_date to datetime for proper DATE type storage
combined["trade_date"] = pd.to_datetime(combined["trade_date"], format="%Y%m%d")
print(f"[sync_bak_basic] Total records: {len(combined)}")
# Delete existing data for the date range and append new data
storage._connection.execute(f'DELETE FROM "{TABLE_NAME}" WHERE "trade_date" >= ?', [sync_start])
# Convert sync_start to date format for comparison with DATE column
sync_start_date = pd.to_datetime(sync_start, format="%Y%m%d").date()
storage._connection.execute(f'DELETE FROM "{TABLE_NAME}" WHERE "trade_date" >= ?', [sync_start_date])
thread_storage.queue_save(TABLE_NAME, combined)
thread_storage.flush()

View File

@@ -17,6 +17,7 @@ import threading
from src.data.client import TushareClient
from src.data.storage import ThreadSafeStorage, Storage
from src.data.utils import get_today_date, get_next_date, DEFAULT_START_DATE
from src.config.settings import get_settings
from src.data.api_wrappers.api_trade_cal import (
get_first_trading_day,
get_last_trading_day,
@@ -105,16 +106,15 @@ class DailySync:
- 预览模式(预览同步数据量,不实际写入)
"""
# Default number of worker threads
DEFAULT_MAX_WORKERS = 10
# Default number of worker threads, read from settings (defaults to 10)
DEFAULT_MAX_WORKERS = get_settings().threads
def __init__(self, max_workers: Optional[int] = None):
"""Initialize the sync manager.
Args:
max_workers: number of worker threads (default: 10)
max_workers: number of worker threads (default: read from settings)
"""
self.storage = ThreadSafeStorage()
self.client = TushareClient()
self.max_workers = max_workers or self.DEFAULT_MAX_WORKERS
self._stop_flag = threading.Event()

View File

@@ -8,13 +8,13 @@ import pandas as pd
from pathlib import Path
from typing import Optional, List
from src.data.client import TushareClient
from src.data.config import get_config
from src.config.settings import get_settings
# CSV file path for namechange data
def _get_csv_path() -> Path:
"""Get the CSV file path for namechange data."""
cfg = get_config()
cfg = get_settings()
return cfg.data_path_resolved / "namechange.csv"

View File

@@ -9,13 +9,13 @@ import pandas as pd
from pathlib import Path
from typing import Optional, Literal, List
from src.data.client import TushareClient
from src.data.config import get_config
from src.config.settings import get_settings
# CSV file path for stock basic data
def _get_csv_path() -> Path:
"""Get the CSV file path for stock basic data."""
cfg = get_config()
cfg = get_settings()
return cfg.data_path_resolved / "stock_basic.csv"

View File

@@ -8,7 +8,7 @@ import pandas as pd
from typing import Optional, Literal
from pathlib import Path
from src.data.client import TushareClient
from src.data.config import get_config
from src.config.settings import get_settings
# Module-level flag to track if cache has been synced in this session
_cache_synced = False
@@ -18,7 +18,7 @@ _cache_synced = False
# Trading calendar cache file path
def _get_cache_path() -> Path:
"""Get the cache file path for trade calendar."""
cfg = get_config()
cfg = get_settings()
return cfg.data_path_resolved / "trade_cal.h5"
@@ -296,8 +296,8 @@ def get_first_trading_day(
trading_days = get_trading_days(start_date, end_date, exchange)
if not trading_days:
return None
# Trading days are sorted in descending order (newest first) from cache
return trading_days[-1]
# Return the earliest trading day
return min(trading_days)
def get_last_trading_day(
@@ -318,8 +318,8 @@ def get_last_trading_day(
trading_days = get_trading_days(start_date, end_date, exchange)
if not trading_days:
return None
# Trading days are sorted in descending order (newest first) from cache
return trading_days[0]
# Return the latest trading day
return max(trading_days)
if __name__ == "__main__":

View File

@@ -1,21 +1,25 @@
"""Simplified Tushare client with rate limiting and retry logic."""
import time
import pandas as pd
from typing import Optional
from src.data.config import get_config
from src.data.rate_limiter import TokenBucketRateLimiter
from src.config.settings import get_settings
class TushareClient:
"""Tushare API client with rate limiting and retry."""
# Class-level shared limiter (ensures every instance shares the same limiter)
_shared_limiter: Optional[TokenBucketRateLimiter] = None
def __init__(self, token: Optional[str] = None):
"""Initialize client.
Args:
token: Tushare API token (auto-loaded from config if not provided)
"""
cfg = get_config()
cfg = get_settings()
token = token or cfg.tushare_token
if not token:
@@ -24,12 +28,21 @@ class TushareClient:
self.token = token
self.config = cfg
# Initialize rate limiter: capacity = rate_limit, refill_rate = rate_limit/60 per second
# Initialize the shared limiter (all TushareClient instances share the same one)
rate_per_second = cfg.rate_limit / 60.0
self.rate_limiter = TokenBucketRateLimiter(
capacity=cfg.rate_limit,
refill_rate_per_second=rate_per_second,
)
capacity = cfg.rate_limit
if TushareClient._shared_limiter is None:
# First construction: initialize the shared limiter
TushareClient._shared_limiter = TokenBucketRateLimiter(
capacity=capacity,
refill_rate_per_second=rate_per_second,
)
print(
f"[TushareClient] Initialized shared rate limiter: capacity={capacity}, window=60s"
)
# Reuse the shared limiter
self.rate_limiter = TushareClient._shared_limiter
self._api = None
@@ -37,6 +50,7 @@ class TushareClient:
"""Get Tushare API instance."""
if self._api is None:
import tushare as ts
self._api = ts.pro_api(self.token)
return self._api
@@ -52,7 +66,7 @@ class TushareClient:
DataFrame with query results
"""
# Acquire rate limit token (None = wait indefinitely)
timeout = timeout if timeout is not None else float('inf')
timeout = timeout if timeout is not None else float("inf")
success, wait_time = self.rate_limiter.acquire(timeout=timeout)
if not success:
@@ -72,14 +86,21 @@ class TushareClient:
# pro_bar uses ts.pro_bar() instead of api.query()
if api_name == "pro_bar":
# pro_bar parameters: ts_code, start_date, end_date, adj, freq, factors, ma, adjfactor
data = ts.pro_bar(ts_code=params.get("ts_code"),
start_date=params.get("start_date"),
end_date=params.get("end_date"),
adj=params.get("adj"),
freq=params.get("freq", "D"),
factors=params.get("factors"), # factors should be a list like ['tor', 'vr']
ma=params.get("ma"),
adjfactor=params.get("adjfactor"))
data = ts.pro_bar(
ts_code=params.get("ts_code"),
start_date=params.get("start_date"),
end_date=params.get("end_date"),
adj=params.get("adj"),
freq=params.get("freq", "D"),
factors=params.get(
"factors"
), # factors should be a list like ['tor', 'vr']
ma=params.get("ma"),
adjfactor=params.get("adjfactor"),
)
# Handle None response (e.g., delisted stock)
if data is None:
data = pd.DataFrame()
else:
api = self._get_api()
data = api.query(api_name, **params)
@@ -89,10 +110,14 @@ class TushareClient:
except Exception as e:
if attempt < max_retries - 1:
delay = retry_delays[attempt]
print(f"[Retry] {api_name} failed (attempt {attempt + 1}): {e}, retry in {delay}s")
print(
f"[Retry] {api_name} failed (attempt {attempt + 1}): {e}, retry in {delay}s"
)
time.sleep(delay)
else:
raise RuntimeError(f"API call failed after {max_retries} attempts: {e}")
raise RuntimeError(
f"API call failed after {max_retries} attempts: {e}"
)
return pd.DataFrame()
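The class-level `_shared_limiter` idiom above can be sketched independently of Tushare; `_Limiter` and `Client` below are placeholders for `TokenBucketRateLimiter` and `TushareClient`, for illustration only:

```python
from typing import Optional


class _Limiter:
    """Placeholder for TokenBucketRateLimiter (illustration only)."""


class Client:
    # Class-level attribute: created once, then reused by every instance
    _shared_limiter: Optional[_Limiter] = None

    def __init__(self) -> None:
        if Client._shared_limiter is None:
            Client._shared_limiter = _Limiter()  # first construction initializes it
        self.rate_limiter = Client._shared_limiter


a, b = Client(), Client()
assert a.rate_limiter is b.rate_limiter  # both instances share one limiter
```

Because the limiter lives on the class rather than the instance, threads that each construct their own client still draw from a single request budget.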

View File

@@ -1,80 +0,0 @@
"""Configuration management for data collection module."""
import os
from pathlib import Path
from pydantic_settings import BaseSettings
# Config directory path - used for loading .env.local
# Static detection for pydantic-settings to find .env.local
CONFIG_DIR = Path(__file__).parent.parent.parent / "config"
def _get_project_root() -> Path:
"""Get project root path from ROOT_PATH env var or auto-detect."""
# Try to read from environment variable first
root_path = os.environ.get("ROOT_PATH") or os.environ.get("DATA_ROOT")
if root_path:
return Path(root_path)
# Fallback to auto-detection
return Path(__file__).parent.parent.parent
class Config(BaseSettings):
"""Application configuration loaded from environment variables."""
# Tushare API token
tushare_token: str = ""
# Root path - loaded from environment variable ROOT_PATH
# If not set, uses auto-detected path
root_path: str = ""
# Data storage path - can be set via DATA_PATH environment variable
# If relative path, it will be resolved relative to root_path
data_path: str = "data"
# Rate limit: requests per minute
rate_limit: int = 100
# Thread pool size
threads: int = 2
@property
def project_root(self) -> Path:
"""Get project root path."""
if self.root_path:
return Path(self.root_path)
return _get_project_root()
@property
def data_path_resolved(self) -> Path:
"""Get resolved data path (absolute)."""
path = Path(self.data_path)
if path.is_absolute():
return path
return self.project_root / path
class Config:
# 从 config/ 目录读取 .env.local 文件
env_file = str(CONFIG_DIR / ".env.local")
env_file_encoding = "utf-8"
case_sensitive = False
extra = "ignore" # 忽略 .env.local 中的额外变量
# pydantic-settings 默认会将字段名转换为大写作为环境变量名
# 所以 tushare_token 会映射到 TUSHARE_TOKEN
# root_path 会映射到 ROOT_PATH
# data_path 会映射到 DATA_PATH
# Global config instance
config = Config()
def get_config() -> Config:
"""Get configuration instance."""
return config
def get_project_root() -> Path:
"""Get project root path (convenience function)."""
return get_config().project_root

View File

@@ -32,9 +32,12 @@ def get_db_info(db_path: Optional[Path] = None):
# Get database path
if db_path is None:
from src.data.config import get_config
from src.config.settings import get_settings
cfg = get_config()
cfg = get_settings()
db_path = cfg.data_path_resolved / "prostock.db"
else:
db_path = Path(db_path)
@@ -231,9 +234,12 @@ def get_table_sample(table_name: str, limit: int = 5, db_path: Optional[Path] =
db_path: Path to database file
"""
if db_path is None:
from src.data.config import get_config
from src.config.settings import get_settings
cfg = get_config()
cfg = get_settings()
db_path = cfg.data_path_resolved / "prostock.db"
else:
db_path = Path(db_path)

View File

@@ -1,35 +1,35 @@
"""Token bucket rate limiter implementation.
"""API rate limiter implementation.
This module provides a thread-safe token bucket algorithm for rate limiting.
Provides fixed-time-window rate limiting, suited to per-minute-quota APIs such as Tushare.
"""
import time
import threading
from typing import Optional
from dataclasses import dataclass, field
from dataclasses import dataclass
@dataclass
class RateLimiterStats:
"""Statistics for rate limiter."""
"""Rate limiter statistics."""
total_requests: int = 0
successful_requests: int = 0
denied_requests: int = 0
total_wait_time: float = 0.0
current_tokens: Optional[float] = None
current_window_requests: int = 0
window_start_time: float = 0.0
class TokenBucketRateLimiter:
"""Thread-safe token bucket rate limiter.
"""Fixed-time-window rate limiter.
Implements a token bucket algorithm for controlling request rate.
Tokens are added at a fixed rate up to the bucket capacity.
Suited to APIs such as Tushare that cap requests per time window (e.g. per minute).
Once the request count in a window reaches the cap, callers block or wait for the next window.
Attributes:
capacity: Maximum number of tokens in the bucket
refill_rate: Number of tokens added per second
initial_tokens: Initial number of tokens (default: capacity)
capacity: maximum number of requests allowed per time window
window_seconds: window length in seconds
"""
def __init__(
@@ -38,155 +38,157 @@ class TokenBucketRateLimiter:
refill_rate_per_second: float = 1.67,
initial_tokens: Optional[int] = None,
) -> None:
"""Initialize the token bucket rate limiter.
"""Initialize the rate limiter.
Args:
capacity: Maximum token capacity
refill_rate_per_second: Token refill rate per second
initial_tokens: Initial token count (default: capacity)
capacity: maximum number of requests allowed per time window
refill_rate_per_second: retained for backward compatibility; a fixed 60-second window is used instead
initial_tokens: retained for backward compatibility
"""
self.capacity = capacity
self.refill_rate = refill_rate_per_second
self.tokens = float(initial_tokens if initial_tokens is not None else capacity)
self.last_refill_time = time.monotonic()
# Tushare quotas are usually per minute, so a fixed 60-second window is used
self.window_seconds = 60.0
self._requests_in_window = 0
self._window_start = time.monotonic()
self._lock = threading.RLock()
self._stats = RateLimiterStats()
self._stats.current_tokens = self.tokens
self._stats.window_start_time = self._window_start
def _is_new_window(self) -> bool:
"""Check whether a new time window has started."""
current_time = time.monotonic()
elapsed = current_time - self._window_start
return elapsed >= self.window_seconds
def _reset_window(self) -> None:
"""Reset the time window."""
self._window_start = time.monotonic()
self._requests_in_window = 0
self._stats.window_start_time = self._window_start
def acquire(self, timeout: float = float("inf")) -> tuple[bool, float]:
"""Acquire a token from the bucket.
"""Acquire permission for one request.
Blocks until a token is available or timeout expires.
If the current window has reached its request cap, wait for the next window.
Args:
timeout: Maximum time to wait for a token in seconds (default: inf)
timeout: maximum wait time in seconds (default: wait forever)
Returns:
Tuple of (success, wait_time):
- success: True if token was acquired, False if timed out
- wait_time: Time spent waiting for token
(success, wait_time): whether permission was granted, and the time spent waiting
"""
start_time = time.monotonic()
wait_time = 0.0
with self._lock:
self._refill()
# Check whether a new window has started
if self._is_new_window():
self._reset_window()
if self.tokens >= 1:
self.tokens -= 1
# If the current window still has headroom, pass immediately
if self._requests_in_window < self.capacity:
self._requests_in_window += 1
self._stats.total_requests += 1
self._stats.successful_requests += 1
self._stats.current_tokens = self.tokens
self._stats.current_window_requests = self._requests_in_window
return True, 0.0
# Calculate time to wait for next token
tokens_needed = 1 - self.tokens
time_to_refill = tokens_needed / self.refill_rate
# Current window is full; compute the wait time
current_time = time.monotonic()
time_to_next_window = self.window_seconds - (
current_time - self._window_start
)
# Check if we can wait for the token within timeout
# Handle infinite timeout specially
is_infinite_timeout = timeout == float("inf")
if not is_infinite_timeout and time_to_refill > timeout:
if time_to_next_window <= 0:
# We just crossed into a new window
self._reset_window()
self._requests_in_window = 1
self._stats.total_requests += 1
self._stats.successful_requests += 1
self._stats.current_window_requests = 1
return True, 0.0
# Check whether we can wait within the timeout
if timeout != float("inf") and time_to_next_window > timeout:
self._stats.total_requests += 1
self._stats.denied_requests += 1
return False, timeout
# Wait for tokens - loop until we get one or timeout
while True:
# Calculate remaining time we can wait
elapsed = time.monotonic() - start_time
remaining_timeout = (
timeout - elapsed if not is_infinite_timeout else float("inf")
)
# Need to wait until the next window
if timeout != float("inf"):
time_to_wait = min(time_to_next_window, timeout)
else:
time_to_wait = time_to_next_window
# Check if we've exceeded timeout
if not is_infinite_timeout and remaining_timeout <= 0:
self._stats.total_requests += 1
self._stats.denied_requests += 1
return False, elapsed
time.sleep(time_to_wait)
# Calculate wait time for next token
tokens_needed = max(0, 1 - self.tokens)
time_to_wait = (
tokens_needed / self.refill_rate if tokens_needed > 0 else 0.1
)
# If we can't wait long enough, fail
if not is_infinite_timeout and time_to_wait > remaining_timeout:
self._stats.total_requests += 1
self._stats.denied_requests += 1
return False, elapsed
# Wait outside the lock to allow other threads to refill
self._lock.release()
time.sleep(
min(time_to_wait, 0.1)
) # Cap wait to 100ms to check frequently
self._lock.acquire()
# Refill and check again
self._refill()
if self.tokens >= 1:
self.tokens -= 1
wait_time = time.monotonic() - start_time
self._stats.total_requests += 1
self._stats.successful_requests += 1
self._stats.total_wait_time += wait_time
self._stats.current_tokens = self.tokens
return True, wait_time
def acquire_nonblocking(self) -> tuple[bool, float]:
"""Try to acquire a token without blocking.
Returns:
Tuple of (success, wait_time):
- success: True if token was acquired, False otherwise
- wait_time: 0 for non-blocking, or required wait time if failed
"""
# Retry acquiring permission
with self._lock:
self._refill()
# Re-check the window state (another thread may have reset it)
if self._is_new_window():
self._reset_window()
if self.tokens >= 1:
self.tokens -= 1
if self._requests_in_window < self.capacity:
self._requests_in_window += 1
wait_time = time.monotonic() - start_time
self._stats.total_requests += 1
self._stats.successful_requests += 1
self._stats.current_tokens = self.tokens
self._stats.total_wait_time += wait_time
self._stats.current_window_requests = self._requests_in_window
return True, wait_time
else:
# Edge case: still unable to acquire after waiting (another thread won the race)
wait_time = time.monotonic() - start_time
self._stats.total_requests += 1
self._stats.denied_requests += 1
return False, wait_time
def acquire_nonblocking(self) -> tuple[bool, float]:
"""Try to acquire request permission without blocking.
Returns:
(success, wait_time): whether permission was granted, and the wait time required on failure
"""
with self._lock:
# Check whether a new window has started
if self._is_new_window():
self._reset_window()
# If the current window still has headroom, pass immediately
if self._requests_in_window < self.capacity:
self._requests_in_window += 1
self._stats.total_requests += 1
self._stats.successful_requests += 1
self._stats.current_window_requests = self._requests_in_window
return True, 0.0
# Calculate time needed
tokens_needed = 1 - self.tokens
time_to_refill = tokens_needed / self.refill_rate
# Current window is full; compute the wait time
current_time = time.monotonic()
time_to_next_window = self.window_seconds - (
current_time - self._window_start
)
self._stats.total_requests += 1
self._stats.denied_requests += 1
return False, time_to_refill
def _refill(self) -> None:
"""Refill tokens based on elapsed time."""
current_time = time.monotonic()
elapsed = current_time - self.last_refill_time
self.last_refill_time = current_time
tokens_to_add = elapsed * self.refill_rate
self.tokens = min(self.capacity, self.tokens + tokens_to_add)
return False, max(0.0, time_to_next_window)
def get_available_tokens(self) -> float:
"""Get the current number of available tokens.
"""Get the remaining request allowance in the current window.
Returns:
Current token count
Remaining request allowance in the current window
"""
with self._lock:
self._refill()
return self.tokens
if self._is_new_window():
return float(self.capacity)
return float(self.capacity - self._requests_in_window)
def get_stats(self) -> RateLimiterStats:
"""Get rate limiter statistics.
"""获取速率限制器统计信息。
Returns:
RateLimiterStats instance
RateLimiterStats 实例
"""
with self._lock:
self._refill()
self._stats.current_tokens = self.tokens
self._stats.current_window_requests = self._requests_in_window
return self._stats
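The fixed-window behavior described above can be captured in a minimal stand-alone sketch. An injectable clock replaces `time.monotonic` so the window rollover can be exercised without sleeping; this is an illustration of the idea, not the project class:

```python
import threading
import time


class FixedWindowLimiter:
    """Minimal fixed-time-window limiter (sketch)."""

    def __init__(self, capacity: int, window_seconds: float = 60.0, clock=time.monotonic):
        self.capacity = capacity
        self.window_seconds = window_seconds
        self._clock = clock
        self._window_start = clock()
        self._count = 0
        self._lock = threading.Lock()

    def try_acquire(self) -> bool:
        """Non-blocking acquire: True if the current window still has headroom."""
        with self._lock:
            now = self._clock()
            if now - self._window_start >= self.window_seconds:
                # New window: reset the counter
                self._window_start = now
                self._count = 0
            if self._count < self.capacity:
                self._count += 1
                return True
            return False


t = [0.0]  # fake clock, advanced manually
lim = FixedWindowLimiter(capacity=2, window_seconds=60.0, clock=lambda: t[0])
assert lim.try_acquire() and lim.try_acquire()  # two requests fit in the window
assert not lim.try_acquire()                    # third is denied: window full
t[0] = 61.0                                     # advance past the window boundary
assert lim.try_acquire()                        # new window, allowed again
```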

View File

@@ -6,7 +6,7 @@ from pathlib import Path
from typing import Optional, List, Dict, Any, Tuple
from collections import defaultdict
from datetime import datetime
from src.data.config import get_config
from src.config.settings import get_settings
# Default column type mapping for automatic schema inference
@@ -53,7 +53,7 @@ class Storage:
if hasattr(self, "_initialized"):
return
cfg = get_config()
cfg = get_settings()
self.base_path = path or cfg.data_path_resolved
self.base_path.mkdir(parents=True, exist_ok=True)
self.db_path = self.base_path / "prostock.db"
@@ -190,6 +190,26 @@ class Storage:
update_flag VARCHAR(1),
PRIMARY KEY (ts_code, end_date)
)
# Create pro_bar table for pro bar data (with adj, tor, vr)
self._connection.execute("""
CREATE TABLE IF NOT EXISTS pro_bar (
ts_code VARCHAR(16) NOT NULL,
trade_date DATE NOT NULL,
open DOUBLE,
high DOUBLE,
low DOUBLE,
close DOUBLE,
pre_close DOUBLE,
change DOUBLE,
pct_chg DOUBLE,
vol DOUBLE,
amount DOUBLE,
tor DOUBLE,
vr DOUBLE,
adj_factor DOUBLE,
PRIMARY KEY (ts_code, trade_date)
)
""")
# Create index for financial_income

View File

@@ -29,6 +29,7 @@ import pandas as pd
from src.data.api_wrappers import sync_all_stocks
from src.data.api_wrappers.api_daily import sync_daily, preview_daily_sync
from src.data.api_wrappers.api_pro_bar import sync_pro_bar
def preview_sync(
@@ -134,7 +135,6 @@ def sync_all_data(
dry_run: bool = False,
) -> Dict[str, pd.DataFrame]:
"""Sync all data types (the daily sync).
This function syncs every available data type in order:
1. Trading calendar (sync_trade_cal_cache)
2. Stock basic info (sync_all_stocks)
@@ -146,13 +146,12 @@ def sync_all_data(
Args:
force_full: if True, force a full reload of every data type
max_workers: worker-thread count for daily-data sync (default: 10)
dry_run: if True, only show what would be synced, without writing data
Returns:
Dict mapping each data type to its sync result
Example:
>>> # sync all data (incremental)
>>> result = sync_all_data()
>>>
>>> # force a full reload
@@ -167,6 +166,92 @@ def sync_all_data(
print("[sync_all_data] Starting full data synchronization...")
print("=" * 60)
# 1. Sync trade calendar (always needed first)
print("\n[1/6] Syncing trade calendar cache...")
try:
from src.data.api_wrappers import sync_trade_cal_cache
sync_trade_cal_cache()
results["trade_cal"] = pd.DataFrame()
print("[1/6] Trade calendar: OK")
except Exception as e:
print(f"[1/6] Trade calendar: FAILED - {e}")
results["trade_cal"] = pd.DataFrame()
# 2. Sync stock basic info
print("\n[2/6] Syncing stock basic info...")
try:
sync_all_stocks()
results["stock_basic"] = pd.DataFrame()
print("[2/6] Stock basic: OK")
except Exception as e:
print(f"[2/6] Stock basic: FAILED - {e}")
results["stock_basic"] = pd.DataFrame()
# # 3. Sync daily market data
# print("\n[3/6] Syncing daily market data...")
# try:
# daily_result = sync_daily(
# force_full=force_full,
# max_workers=max_workers,
# dry_run=dry_run,
# )
# results["daily"] = (
# pd.concat(daily_result.values(), ignore_index=True)
# if daily_result
# else pd.DataFrame()
# )
# print("[3/6] Daily data: OK")
# except Exception as e:
# print(f"[3/6] Daily data: FAILED - {e}")
# results["daily"] = pd.DataFrame()
# 4. Sync Pro Bar data
print("\n[4/6] Syncing Pro Bar data (with adj, tor, vr)...")
try:
pro_bar_result = sync_pro_bar(
force_full=force_full,
max_workers=max_workers,
dry_run=dry_run,
)
results["pro_bar"] = (
pd.concat(pro_bar_result.values(), ignore_index=True)
if pro_bar_result
else pd.DataFrame()
)
print(f"[4/6] Pro Bar data: OK ({len(results['pro_bar'])} records)")
except Exception as e:
print(f"[4/6] Pro Bar data: FAILED - {e}")
results["pro_bar"] = pd.DataFrame()
# 5. Sync stock historical list (bak_basic)
print("\n[5/6] Syncing stock historical list (bak_basic)...")
try:
bak_basic_result = sync_bak_basic(force_full=force_full)
results["bak_basic"] = bak_basic_result
print(f"[5/6] Bak basic: OK ({len(bak_basic_result)} records)")
except Exception as e:
print(f"[5/6] Bak basic: FAILED - {e}")
results["bak_basic"] = pd.DataFrame()
# Summary
print("\n" + "=" * 60)
print("[sync_all_data] Sync Summary")
print("=" * 60)
for data_type, df in results.items():
print(f" {data_type}: {len(df)} records")
print("=" * 60)
print("\nNote: namechange is NOT in auto-sync. To sync manually:")
print(" from src.data.api_wrappers import sync_namechange")
print(" sync_namechange(force=True)")
return results
results: Dict[str, pd.DataFrame] = {}
print("\n" + "=" * 60)
print("[sync_all_data] Starting full data synchronization...")
print("=" * 60)
# 1. Sync trade calendar (always needed first)
print("\n[1/5] Syncing trade calendar cache...")
try: