refactor: 存储层迁移DuckDB + 模块重构
- 存储层重构: HDF5 → DuckDB(UPSERT模式、线程安全存储)
- Sync类迁移: DataSync从sync.py迁移到api_daily.py(职责分离)
- 模型模块重构: src/models → src/pipeline(更清晰的命名)
- 新增因子模块: factors/momentum (MA、收益率排名)、factors/financial
- 新增API接口: api_namechange、api_bak_basic
- 新增训练入口: training模块(main.py、pipeline配置)
- 工具函数统一: get_today_date等移至utils.py
- 文档更新: AGENTS.md添加架构变更历史
This commit is contained in:
243
src/data/api_wrappers/api_bak_basic.py
Normal file
243
src/data/api_wrappers/api_bak_basic.py
Normal file
@@ -0,0 +1,243 @@
|
||||
"""Stock historical list interface.
|
||||
|
||||
Fetch daily stock list from Tushare bak_basic API.
|
||||
Data available from 2016 onwards.
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from typing import Optional, List
|
||||
from datetime import datetime, timedelta
|
||||
from tqdm import tqdm
|
||||
from src.data.client import TushareClient
|
||||
from src.data.storage import ThreadSafeStorage, Storage
|
||||
from src.data.db_manager import ensure_table
|
||||
|
||||
|
||||
def get_bak_basic(
    trade_date: Optional[str] = None,
    ts_code: Optional[str] = None,
) -> pd.DataFrame:
    """Fetch historical stock list from Tushare.

    This interface retrieves the daily stock list including basic information
    for all stocks on a specific trade date. Data is available from 2016 onwards.

    Args:
        trade_date: Specific trade date in YYYYMMDD format
        ts_code: Stock code filter (optional, e.g., '000001.SZ')

    Returns:
        pd.DataFrame with columns:
            - trade_date: Trade date (YYYYMMDD)
            - ts_code: TS stock code
            - name: Stock name
            - industry: Industry
            - area: Region
            - pe: P/E ratio (dynamic)
            - float_share: Float shares (100 million)
            - total_share: Total shares (100 million)
            - total_assets: Total assets (100 million)
            - liquid_assets: Liquid assets (100 million)
            - fixed_assets: Fixed assets (100 million)
            - reserved: Reserve fund
            - reserved_pershare: Reserve per share
            - eps: Earnings per share
            - bvps: Book value per share
            - pb: P/B ratio
            - list_date: Listing date
            - undp: Undistributed profit
            - per_undp: Undistributed profit per share
            - rev_yoy: Revenue YoY (%)
            - profit_yoy: Profit YoY (%)
            - gpr: Gross profit ratio (%)
            - npr: Net profit ratio (%)
            - holder_num: Number of shareholders

    Example:
        >>> # Get all stocks for a single date
        >>> data = get_bak_basic(trade_date='20240101')
        >>>
        >>> # Get specific stock data
        >>> data = get_bak_basic(ts_code='000001.SZ', trade_date='20240101')
    """
    client = TushareClient()

    # Forward only the filters the caller actually supplied (truthy values).
    query_params = {
        name: value
        for name, value in (("trade_date", trade_date), ("ts_code", ts_code))
        if value
    }

    return client.query("bak_basic", **query_params)
|
||||
|
||||
|
||||
def sync_bak_basic(
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    force_full: bool = False,
) -> pd.DataFrame:
    """Sync historical stock list to DuckDB with intelligent incremental sync.

    Logic:
        - If table doesn't exist: create table + composite index (trade_date, ts_code) + full sync
        - If table exists: incremental sync from last_date + 1

    Args:
        start_date: Start date for sync (YYYYMMDD format, default: 20160101 for full, last_date+1 for incremental)
        end_date: End date for sync (YYYYMMDD format, default: today)
        force_full: If True, force full reload from 20160101

    Returns:
        pd.DataFrame with synced data (empty DataFrame when nothing was fetched
        or the table is already up to date)
    """
    # NOTE: removed an unused function-local import of ensure_table here; it
    # shadowed the (also unused) module-level import and was never called.
    TABLE_NAME = "bak_basic"
    storage = Storage()
    thread_storage = ThreadSafeStorage()

    # Default end date: today (YYYYMMDD)
    if end_date is None:
        end_date = datetime.now().strftime("%Y%m%d")

    # Check if table exists
    table_exists = storage.exists(TABLE_NAME)

    if not table_exists or force_full:
        # ===== FULL SYNC =====
        # 1. Create table with schema inferred from a sample API response
        # 2. Create composite index (trade_date, ts_code)
        # 3. Full sync from start_date

        if not table_exists:
            print(f"[sync_bak_basic] Table '{TABLE_NAME}' doesn't exist, creating...")

            # Fetch a sample row set to derive the column schema; fall back to a
            # known trading day if end_date happens to be a non-trading day.
            sample = get_bak_basic(trade_date=end_date)
            if sample.empty:
                sample = get_bak_basic(trade_date="20240102")

            if sample.empty:
                print("[sync_bak_basic] Cannot create table: no sample data available")
                return pd.DataFrame()

            # Map pandas dtypes to DuckDB column types (default to VARCHAR).
            columns = []
            for col in sample.columns:
                dtype = str(sample[col].dtype)
                if "int" in dtype:
                    col_type = "INTEGER"
                elif "float" in dtype:
                    col_type = "DOUBLE"
                else:
                    col_type = "VARCHAR"
                columns.append(f'"{col}" {col_type}')

            columns_sql = ", ".join(columns)
            # Composite primary key dedupes rows per (trade_date, ts_code).
            create_sql = f'CREATE TABLE IF NOT EXISTS "{TABLE_NAME}" ({columns_sql}, PRIMARY KEY ("trade_date", "ts_code"))'

            try:
                storage._connection.execute(create_sql)
                print(f"[sync_bak_basic] Created table '{TABLE_NAME}'")
            except Exception as e:
                print(f"[sync_bak_basic] Error creating table: {e}")

            # Create composite index to speed up date/code range queries.
            try:
                storage._connection.execute(f"""
                    CREATE INDEX IF NOT EXISTS "idx_bak_basic_date_code"
                    ON "{TABLE_NAME}"("trade_date", "ts_code")
                """)
                print(f"[sync_bak_basic] Created composite index on (trade_date, ts_code)")
            except Exception as e:
                print(f"[sync_bak_basic] Error creating index: {e}")

        # Determine sync dates (bak_basic data starts in 2016)
        sync_start = start_date or "20160101"
        mode = "FULL"
        print(f"[sync_bak_basic] Mode: {mode} SYNC from {sync_start} to {end_date}")

    else:
        # ===== INCREMENTAL SYNC =====
        # Check last date in table, sync from last_date + 1

        try:
            result = storage._connection.execute(
                f'SELECT MAX("trade_date") FROM "{TABLE_NAME}"'
            ).fetchone()
            last_date = result[0] if result and result[0] else None
        except Exception as e:
            print(f"[sync_bak_basic] Error getting last date: {e}")
            last_date = None

        if last_date is None:
            # Table exists but is empty, do a full sync instead
            sync_start = start_date or "20160101"
            mode = "FULL (empty table)"
        else:
            # Incremental from last_date + 1.
            # Handle both YYYYMMDD and YYYY-MM-DD formats stored in the table.
            last_date_str = str(last_date).replace("-", "")
            last_dt = datetime.strptime(last_date_str, "%Y%m%d")
            next_dt = last_dt + timedelta(days=1)
            sync_start = next_dt.strftime("%Y%m%d")
            mode = "INCREMENTAL"

            # Skip if already up to date (lexicographic compare is safe for YYYYMMDD)
            if sync_start > end_date:
                print(f"[sync_bak_basic] Data is up-to-date (last: {last_date}), skipping sync")
                return pd.DataFrame()

        print(f"[sync_bak_basic] Mode: {mode} from {sync_start} to {end_date} (last: {last_date})")

    # ===== FETCH AND SAVE DATA =====
    all_data: List[pd.DataFrame] = []
    current = datetime.strptime(sync_start, "%Y%m%d")
    end_dt = datetime.strptime(end_date, "%Y%m%d")

    # Calculate total days for the progress bar (calendar days, incl. non-trading)
    total_days = (end_dt - current).days + 1
    print(f"[sync_bak_basic] Fetching data for {total_days} days...")

    with tqdm(total=total_days, desc="Syncing dates") as pbar:
        while current <= end_dt:
            date_str = current.strftime("%Y%m%d")
            try:
                data = get_bak_basic(trade_date=date_str)
                # Non-trading days legitimately return empty frames; skip them.
                if not data.empty:
                    all_data.append(data)
                    pbar.set_postfix({"date": date_str, "records": len(data)})
            except Exception as e:
                # Best-effort per-day fetch: log and continue with the next day.
                print(f"  {date_str}: ERROR - {e}")

            current += timedelta(days=1)
            pbar.update(1)

    if not all_data:
        print("[sync_bak_basic] No data fetched")
        return pd.DataFrame()

    # Combine and save
    combined = pd.concat(all_data, ignore_index=True)
    print(f"[sync_bak_basic] Total records: {len(combined)}")

    # Delete existing data for the date range and append new data (avoids
    # primary-key conflicts when re-syncing an overlapping range).
    storage._connection.execute(f'DELETE FROM "{TABLE_NAME}" WHERE "trade_date" >= ?', [sync_start])
    thread_storage.queue_save(TABLE_NAME, combined)
    thread_storage.flush()

    print(f"[sync_bak_basic] Saved {len(combined)} records to DuckDB")
    return combined
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: run a short sync and preview the first few rows.
    result = sync_bak_basic(end_date="20240102")
    print(f"Synced {len(result)} records")
    if not result.empty:
        print("\nSample data:")
        print(result.head())
||||
Reference in New Issue
Block a user