feat: HDF5迁移至DuckDB存储

- 新增DuckDB Storage与ThreadSafeStorage实现
- 新增db_manager模块支持增量同步策略
- DataLoader与Sync模块适配DuckDB
- 补充迁移相关文档与测试
- 修复README文档链接
This commit is contained in:
2026-02-23 00:07:21 +08:00
parent 0a16129548
commit e58b39970c
14 changed files with 2265 additions and 329 deletions

View File

@@ -5,14 +5,35 @@ Provides simplified interfaces for fetching and storing Tushare data.
from src.data.config import Config, get_config
from src.data.client import TushareClient
from src.data.storage import Storage
from src.data.storage import Storage, ThreadSafeStorage, DEFAULT_TYPE_MAPPING
from src.data.api_wrappers import get_stock_basic, sync_all_stocks
from src.data.db_manager import (
TableManager,
IncrementalSync,
SyncManager,
ensure_table,
get_table_info,
sync_table,
)
__all__ = [
# Configuration
"Config",
"get_config",
# Core clients
"TushareClient",
# Storage
"Storage",
"ThreadSafeStorage",
"DEFAULT_TYPE_MAPPING",
# API wrappers
"get_stock_basic",
"sync_all_stocks",
# Database management (new)
"TableManager",
"IncrementalSync",
"SyncManager",
"ensure_table",
"get_table_info",
"sync_table",
]

View File

@@ -10,6 +10,10 @@ from pathlib import Path
from src.data.client import TushareClient
from src.data.config import get_config
# Module-level flag to track if cache has been synced in this session
_cache_synced = False
# Trading calendar cache file path
def _get_cache_path() -> Path:
@@ -51,8 +55,9 @@ def _load_from_cache() -> pd.DataFrame:
try:
with pd.HDFStore(cache_path, mode="r") as store:
if "trade_cal" in store.keys():
data = store["trade_cal"]
# HDF5 keys include leading slash (e.g., '/trade_cal')
if "/trade_cal" in store.keys():
data = store["/trade_cal"]
print(f"[trade_cal] Loaded {len(data)} records from cache")
return data
except Exception as e:
@@ -77,6 +82,7 @@ def _get_cached_date_range() -> tuple[Optional[str], Optional[str]]:
def sync_trade_cal_cache(
start_date: str = "20180101",
end_date: Optional[str] = None,
force: bool = False,
) -> pd.DataFrame:
"""Sync trade calendar data to local cache with incremental updates.
@@ -86,10 +92,17 @@ def sync_trade_cal_cache(
Args:
start_date: Initial start date for full sync (default: 20180101)
end_date: End date (defaults to today)
force: If True, force sync even if already synced in this session
Returns:
Full trade calendar DataFrame (cached + new)
"""
global _cache_synced
# Skip if already synced in this session (unless forced)
if _cache_synced and not force:
return _load_from_cache()
if end_date is None:
from datetime import datetime
@@ -137,6 +150,8 @@ def sync_trade_cal_cache(
combined = new_data
# Save combined data to cache
# Mark as synced to avoid redundant syncs in this session
_cache_synced = True
_save_to_cache(combined)
return combined
else:
@@ -153,6 +168,8 @@ def sync_trade_cal_cache(
print("[trade_cal] No data returned")
return data
# Mark as synced to avoid redundant syncs in this session
_cache_synced = True
_save_to_cache(data)
return data

271
src/data/db_inspector.py Normal file
View File

@@ -0,0 +1,271 @@
"""DuckDB Database Inspector Tool
Usage:
uv run python -c "from src.data.db_inspector import get_db_info; get_db_info()"
Or as standalone script:
cd D:\\PyProject\\ProStock && uv run python -c "import sys; sys.path.insert(0, '.'); from src.data.db_inspector import get_db_info; get_db_info()"
Features:
- List all tables
- Show row count for each table
- Show database file size
- Show column information for each table
"""
import duckdb
import pandas as pd
from pathlib import Path
from datetime import datetime
from typing import Optional
def get_db_info(db_path: Optional[Path] = None):
    """Print a full human-readable summary of the DuckDB database and return it.

    Prints, in order: header (path / time / file size), a per-table one-line
    summary, aggregate totals, and a detailed column listing per table. The
    ``daily`` table additionally gets a date range and stock/date statistics.

    Args:
        db_path: Path to database file; when None the default
            ``<data_path>/prostock.db`` from project config is used.

    Returns:
        pd.DataFrame: one row per table (Table Name / Type / Row Count /
        Column Count / Date Range); ``None`` if the database file is missing;
        an empty DataFrame if the database contains no tables.
    """
    # Get database path
    if db_path is None:
        # Imported lazily so the inspector can be used without loading config
        # when an explicit path is given.
        from src.data.config import get_config
        cfg = get_config()
        db_path = cfg.data_path_resolved / "prostock.db"
    else:
        db_path = Path(db_path)
    if not db_path.exists():
        print(f"[ERROR] Database file not found: {db_path}")
        return None
    # Connect to database (read-only mode) so inspection never mutates data
    conn = duckdb.connect(str(db_path), read_only=True)
    try:
        print("=" * 80)
        print("ProStock DuckDB Database Summary")
        print("=" * 80)
        print(f"Database Path: {db_path}")
        print(f"Check Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        # Get database file size
        db_size_bytes = db_path.stat().st_size
        db_size_mb = db_size_bytes / (1024 * 1024)
        print(f"Database Size: {db_size_mb:.2f} MB ({db_size_bytes:,} bytes)")
        print("=" * 80)
        # Get all table information from the catalog (schema 'main' only)
        tables_query = """
            SELECT
                table_name,
                table_type
            FROM information_schema.tables
            WHERE table_schema = 'main'
            ORDER BY table_name
        """
        tables_df = conn.execute(tables_query).fetchdf()
        if tables_df.empty:
            print("\n[WARNING] No tables found in database")
            return pd.DataFrame()
        print(f"\nTable List (Total: {len(tables_df)} tables)")
        print("-" * 80)
        # Store summary information
        summary_data = []
        for _, row in tables_df.iterrows():
            table_name = row["table_name"]
            table_type = row["table_type"]
            # Get row count for table; on failure keep the error text so the
            # report still shows something for broken tables
            try:
                count_result = conn.execute(
                    f'SELECT COUNT(*) FROM "{table_name}"'
                ).fetchone()
                row_count = count_result[0] if count_result else 0
            except Exception as e:
                row_count = f"Error: {e}"
            # Get column count
            try:
                columns_query = f"""
                    SELECT COUNT(*)
                    FROM information_schema.columns
                    WHERE table_name = '{table_name}' AND table_schema = 'main'
                """
                col_result = conn.execute(columns_query).fetchone()
                col_count = col_result[0] if col_result else 0
            except Exception:
                col_count = 0
            # Get date range (for daily table only)
            date_range = "-"
            if (
                table_name == "daily"
                and row_count
                and isinstance(row_count, int)
                and row_count > 0
            ):
                try:
                    date_query = """
                        SELECT
                            MIN(trade_date) as min_date,
                            MAX(trade_date) as max_date
                        FROM daily
                    """
                    date_result = conn.execute(date_query).fetchone()
                    if date_result and date_result[0] and date_result[1]:
                        date_range = f"{date_result[0]} ~ {date_result[1]}"
                except Exception:
                    pass
            # Row Count is coerced to 0 when it holds an error string so the
            # totals below can sum safely
            summary_data.append(
                {
                    "Table Name": table_name,
                    "Type": table_type,
                    "Row Count": row_count if isinstance(row_count, int) else 0,
                    "Column Count": col_count,
                    "Date Range": date_range,
                }
            )
            # Print single line info
            row_str = f"{row_count:,}" if isinstance(row_count, int) else str(row_count)
            print(f"  * {table_name:<20} | Rows: {row_str:>12} | Cols: {col_count}")
        print("-" * 80)
        # Calculate total rows
        total_rows = sum(
            item["Row Count"]
            for item in summary_data
            if isinstance(item["Row Count"], int)
        )
        print(f"\nData Summary")
        print(f"  Total Tables: {len(summary_data)}")
        print(f"  Total Rows: {total_rows:,}")
        print(
            f"  Avg Rows/Table: {total_rows // len(summary_data):,}"
            if summary_data
            else "  Avg Rows/Table: 0"
        )
        # Detailed table structure
        print("\nDetailed Table Structure")
        print("=" * 80)
        for item in summary_data:
            table_name = item["Table Name"]
            print(f"\n[{table_name}]")
            # Get column information, in declared (ordinal) order
            columns_query = f"""
                SELECT
                    column_name,
                    data_type,
                    is_nullable
                FROM information_schema.columns
                WHERE table_name = '{table_name}' AND table_schema = 'main'
                ORDER BY ordinal_position
            """
            columns_df = conn.execute(columns_query).fetchdf()
            if not columns_df.empty:
                print(f"  Columns: {len(columns_df)}")
                print(f"  {'Column':<20} {'Data Type':<20} {'Nullable':<10}")
                print(f"  {'-' * 20} {'-' * 20} {'-' * 10}")
                for _, col in columns_df.iterrows():
                    nullable = "YES" if col["is_nullable"] == "YES" else "NO"
                    print(
                        f"  {col['column_name']:<20} {col['data_type']:<20} {nullable:<10}"
                    )
            # For daily table, show extra statistics
            if (
                table_name == "daily"
                and isinstance(item["Row Count"], int)
                and item["Row Count"] > 0
            ):
                try:
                    stats_query = """
                        SELECT
                            COUNT(DISTINCT ts_code) as stock_count,
                            COUNT(DISTINCT trade_date) as date_count
                        FROM daily
                    """
                    stats = conn.execute(stats_query).fetchone()
                    if stats:
                        print(f"\n  Statistics:")
                        print(f"    - Unique Stocks: {stats[0]:,}")
                        print(f"    - Trade Dates: {stats[1]:,}")
                        # NOTE(review): this divides total rows by stock count
                        # only, i.e. "rows per stock", despite the label
                        # mentioning "/Date" — confirm intended metric.
                        print(
                            f"    - Avg Records/Stock/Date: {item['Row Count'] // stats[0] if stats[0] > 0 else 0}"
                        )
                except Exception as e:
                    print(f"\n  Statistics query failed: {e}")
        print("\n" + "=" * 80)
        print("Check Complete")
        print("=" * 80)
        # Return DataFrame for further use
        return pd.DataFrame(summary_data)
    finally:
        conn.close()
def get_table_sample(table_name: str, limit: int = 5, db_path: Optional[Path] = None):
    """Print and return the first rows of a table for a quick eyeball check.

    Args:
        table_name: Name of table
        limit: Number of rows to return
        db_path: Path to database file; default project DB when None

    Returns:
        DataFrame with the sampled rows, or ``None`` when the database file
        is missing or the query fails.
    """
    # Resolve the database location (explicit path wins over config default).
    if db_path is not None:
        db_path = Path(db_path)
    else:
        from src.data.config import get_config
        db_path = get_config().data_path_resolved / "prostock.db"

    if not db_path.exists():
        print(f"[ERROR] Database file not found: {db_path}")
        return None

    # Read-only connection: sampling must never mutate the database.
    conn = duckdb.connect(str(db_path), read_only=True)
    try:
        sample = conn.execute(
            f'SELECT * FROM "{table_name}" LIMIT {limit}'
        ).fetchdf()
        print(f"\nTable [{table_name}] Sample Data (first {len(sample)} rows):")
        print(sample.to_string())
        return sample
    except Exception as e:
        print(f"[ERROR] Query failed: {e}")
        return None
    finally:
        conn.close()
if __name__ == "__main__":
    # Display database summary (prints report, returns summary DataFrame,
    # or None when the database file is missing)
    summary_df = get_db_info()
    # If daily table exists, show sample data; the None/empty guards keep
    # this safe when the database is absent or has no tables
    if (
        summary_df is not None
        and not summary_df.empty
        and "daily" in summary_df["Table Name"].values
    ):
        print("\n")
        get_table_sample("daily", limit=5)

592
src/data/db_manager.py Normal file
View File

@@ -0,0 +1,592 @@
"""DuckDB table management and incremental sync utilities.
This module provides utilities for:
- Automatic table creation with schema inference
- Composite index creation for (trade_date, ts_code)
- Incremental sync strategies (by date or by stock)
- Table statistics and metadata
"""
import pandas as pd
from typing import Optional, List, Dict, Any, Callable, Tuple, Literal
from datetime import datetime, timedelta
from collections import defaultdict
from src.data.storage import Storage, ThreadSafeStorage, DEFAULT_TYPE_MAPPING
class TableManager:
    """Manages DuckDB table creation and schema.

    Tables are created from a sample DataFrame: column types come from
    DEFAULT_TYPE_MAPPING when the name is known, otherwise from the pandas
    dtype. A composite (trade_date, ts_code) index is added automatically
    when both columns are present.
    """

    def __init__(self, storage: Optional[Storage] = None):
        """Initialize table manager.
        Args:
            storage: Storage instance (creates new if None)
        """
        self.storage = storage or Storage()

    def create_table_from_dataframe(
        self,
        table_name: str,
        data: pd.DataFrame,
        primary_keys: Optional[List[str]] = None,
        create_index: bool = True,
    ) -> bool:
        """Create table from DataFrame schema with automatic type inference.
        Automatically creates composite index on (trade_date, ts_code) if both exist.
        Args:
            table_name: Name of the table to create
            data: DataFrame to infer schema from
            primary_keys: List of columns for primary key (default: auto-detect)
            create_index: Whether to create composite index
        Returns:
            True if table created successfully
        """
        # Empty frames carry no usable schema — refuse rather than create a
        # zero-column table.
        if data.empty:
            print(
                f"[TableManager] Cannot create table {table_name} from empty DataFrame"
            )
            return False
        try:
            # Build column definitions
            columns = []
            for col in data.columns:
                if col in DEFAULT_TYPE_MAPPING:
                    col_type = DEFAULT_TYPE_MAPPING[col]
                else:
                    # Infer type from pandas dtype (substring match; e.g.
                    # "int64" -> INTEGER, "float64" -> DOUBLE). Unknown
                    # dtypes fall back to VARCHAR.
                    dtype = str(data[col].dtype)
                    if "int" in dtype:
                        col_type = "INTEGER"
                    elif "float" in dtype:
                        col_type = "DOUBLE"
                    elif "bool" in dtype:
                        col_type = "BOOLEAN"
                    elif "datetime" in dtype:
                        col_type = "TIMESTAMP"
                    else:
                        col_type = "VARCHAR"
                columns.append(f'"{col}" {col_type}')
            # Determine primary key: caller-supplied keys win; otherwise
            # auto-detect the standard (ts_code, trade_date) pair.
            pk_constraint = ""
            if primary_keys:
                pk_cols = ", ".join([f'"{k}"' for k in primary_keys])
                pk_constraint = f", PRIMARY KEY ({pk_cols})"
            elif "ts_code" in data.columns and "trade_date" in data.columns:
                pk_constraint = ', PRIMARY KEY ("ts_code", "trade_date")'
            # Create table (IF NOT EXISTS makes this idempotent)
            # NOTE(review): reaches into Storage's private _connection —
            # consider exposing an execute() API on Storage instead.
            columns_sql = ", ".join(columns)
            create_sql = f'CREATE TABLE IF NOT EXISTS "{table_name}" ({columns_sql}{pk_constraint})'
            self.storage._connection.execute(create_sql)
            print(
                f"[TableManager] Created table '{table_name}' with {len(data.columns)} columns"
            )
            # Create composite index if requested and columns exist
            if (
                create_index
                and "trade_date" in data.columns
                and "ts_code" in data.columns
            ):
                index_name = f"idx_{table_name}_date_code"
                self.storage._connection.execute(f"""
                    CREATE INDEX IF NOT EXISTS "{index_name}" ON "{table_name}"("trade_date", "ts_code")
                """)
                print(
                    f"[TableManager] Created composite index on '{table_name}'(trade_date, ts_code)"
                )
            return True
        except Exception as e:
            # Best-effort contract: report and signal failure via return value.
            print(f"[TableManager] Error creating table {table_name}: {e}")
            return False

    def ensure_table_exists(
        self,
        table_name: str,
        sample_data: Optional[pd.DataFrame] = None,
    ) -> bool:
        """Ensure table exists, create if it doesn't.
        Args:
            table_name: Name of the table
            sample_data: Sample DataFrame to infer schema (required if table doesn't exist)
        Returns:
            True if table exists or was created successfully
        """
        if self.storage.exists(table_name):
            return True
        if sample_data is None or sample_data.empty:
            print(
                f"[TableManager] Table '{table_name}' doesn't exist and no sample data provided"
            )
            return False
        return self.create_table_from_dataframe(table_name, sample_data)
class IncrementalSync:
    """Handles incremental synchronization strategies.

    Decides *what* needs fetching (get_sync_strategy), reports table state
    (get_table_stats), and applies fetched data (sync_data).
    """

    # Sync strategy types
    SYNC_BY_DATE = "by_date"  # Sync all stocks for date range
    SYNC_BY_STOCK = "by_stock"  # Sync specific stocks for full date range

    def __init__(self, storage: Optional[Storage] = None):
        """Initialize incremental sync manager.
        Args:
            storage: Storage instance (creates new if None)
        """
        self.storage = storage or Storage()
        self.table_manager = TableManager(self.storage)

    def get_sync_strategy(
        self,
        table_name: str,
        start_date: str,
        end_date: str,
        stock_codes: Optional[List[str]] = None,
    ) -> Tuple[str, Optional[str], Optional[str], Optional[List[str]]]:
        """Determine the best sync strategy based on existing data.
        Logic:
        1. If table doesn't exist: full sync by date
        2. If table exists and has data:
           - If stock_codes provided: sync by stock (update specific stocks)
           - Otherwise: sync by date from last_date + 1
        Args:
            table_name: Name of the table to sync
            start_date: Requested start date (YYYYMMDD)
            end_date: Requested end date (YYYYMMDD)
            stock_codes: Optional list of specific stocks to sync
        Returns:
            Tuple of (strategy, sync_start, sync_end, stocks_to_sync)
            - strategy: 'by_date' or 'by_stock' or 'none'
            - sync_start: Start date for sync (None if no sync needed)
            - sync_end: End date for sync (None if no sync needed)
            - stocks_to_sync: List of stocks to sync (None for all)
        """
        # Check if table exists
        if not self.storage.exists(table_name):
            print(
                f"[IncrementalSync] Table '{table_name}' doesn't exist, will create and do full sync"
            )
            return (self.SYNC_BY_DATE, start_date, end_date, None)
        # Get table stats
        stats = self.get_table_stats(table_name)
        if stats["row_count"] == 0:
            print(f"[IncrementalSync] Table '{table_name}' is empty, doing full sync")
            return (self.SYNC_BY_DATE, start_date, end_date, None)
        # If specific stocks requested, sync only the ones not yet present
        if stock_codes:
            # NOTE(review): get_distinct_stocks is assumed to return all
            # ts_codes in the table — defined on Storage, not visible here.
            existing_stocks = set(self.storage.get_distinct_stocks(table_name))
            requested_stocks = set(stock_codes)
            missing_stocks = requested_stocks - existing_stocks
            if not missing_stocks:
                print(
                    f"[IncrementalSync] All requested stocks already exist in '{table_name}'"
                )
                return ("none", None, None, None)
            print(
                f"[IncrementalSync] Syncing {len(missing_stocks)} missing stocks by stock strategy"
            )
            return (self.SYNC_BY_STOCK, start_date, end_date, list(missing_stocks))
        # Check if we need date-based sync
        table_last_date = stats.get("max_date")
        if table_last_date is None:
            return (self.SYNC_BY_DATE, start_date, end_date, None)
        # Compare dates as integers (YYYYMMDD ordering matches numeric order)
        table_last = int(table_last_date)
        requested_end = int(end_date)
        if table_last >= requested_end:
            print(
                f"[IncrementalSync] Table '{table_name}' is up-to-date (last: {table_last_date})"
            )
            return ("none", None, None, None)
        # Incremental sync from next day after last_date
        next_date = self._get_next_date(table_last_date)
        print(f"[IncrementalSync] Incremental sync needed: {next_date} to {end_date}")
        return (self.SYNC_BY_DATE, next_date, end_date, None)

    def get_table_stats(self, table_name: str) -> Dict[str, Any]:
        """Get statistics about a table.
        Returns:
            Dict with exists, row_count, min_date, max_date, unique_stocks
        """
        stats = {
            "exists": False,
            "row_count": 0,
            "min_date": None,
            "max_date": None,
            "unique_stocks": 0,
        }
        if not self.storage.exists(table_name):
            return stats
        try:
            conn = self.storage._connection
            # Row count
            row_count = conn.execute(f'SELECT COUNT(*) FROM "{table_name}"').fetchone()[
                0
            ]
            stats["row_count"] = row_count
            stats["exists"] = True
            # Get column names so date/stock stats are only queried when the
            # relevant columns exist
            columns_result = conn.execute(
                """
                SELECT column_name
                FROM information_schema.columns
                WHERE table_name = ?
                """,
                [table_name],
            ).fetchall()
            columns = [row[0] for row in columns_result]
            # Date range
            if "trade_date" in columns:
                date_result = conn.execute(f'''
                    SELECT MIN("trade_date"), MAX("trade_date") FROM "{table_name}"
                ''').fetchone()
                if date_result[0]:
                    # Normalize date objects back to YYYYMMDD strings
                    stats["min_date"] = (
                        date_result[0].strftime("%Y%m%d")
                        if hasattr(date_result[0], "strftime")
                        else str(date_result[0])
                    )
                if date_result[1]:
                    stats["max_date"] = (
                        date_result[1].strftime("%Y%m%d")
                        if hasattr(date_result[1], "strftime")
                        else str(date_result[1])
                    )
            # Unique stocks
            if "ts_code" in columns:
                unique_count = conn.execute(f'''
                    SELECT COUNT(DISTINCT "ts_code") FROM "{table_name}"
                ''').fetchone()[0]
                stats["unique_stocks"] = unique_count
        except Exception as e:
            print(f"[IncrementalSync] Error getting stats for {table_name}: {e}")
        return stats

    def sync_data(
        self,
        table_name: str,
        data: pd.DataFrame,
        strategy: Literal["by_date", "by_stock", "replace"] = "by_date",
    ) -> Dict[str, Any]:
        """Sync data to table using specified strategy.
        Args:
            table_name: Target table name
            data: DataFrame to sync
            strategy: Sync strategy
                - 'by_date': UPSERT based on primary key (ts_code, trade_date)
                - 'by_stock': Replace data for specific stocks
                - 'replace': Full replace of table
        Returns:
            Dict with status, rows_inserted, rows_updated
        """
        if data.empty:
            return {"status": "skipped", "rows_inserted": 0, "rows_updated": 0}
        # Ensure table exists (schema inferred from this batch)
        if not self.table_manager.ensure_table_exists(table_name, data):
            return {"status": "error", "error": "Failed to create table"}
        try:
            if strategy == "replace":
                # Full replace — delegated to Storage.save(mode="replace")
                result = self.storage.save(table_name, data, mode="replace")
                return {
                    "status": result["status"],
                    "rows_inserted": result.get("rows", 0),
                    "rows_updated": 0,
                }
            elif strategy == "by_stock":
                # Delete existing data for these stocks, then insert
                if "ts_code" in data.columns:
                    stocks = data["ts_code"].unique().tolist()
                    placeholders = ", ".join(["?"] * len(stocks))
                    self.storage._connection.execute(
                        f'''
                        DELETE FROM "{table_name}" WHERE "ts_code" IN ({placeholders})
                        ''',
                        stocks,
                    )
                    print(
                        f"[IncrementalSync] Deleted existing data for {len(stocks)} stocks"
                    )
                result = self.storage.save(table_name, data, mode="append")
                return {
                    "status": result["status"],
                    "rows_inserted": result.get("rows", 0),
                    "rows_updated": 0,
                }
            else:  # by_date (default)
                # UPSERT — Storage.save(mode="append") performs
                # INSERT OR REPLACE on the primary key
                result = self.storage.save(table_name, data, mode="append")
                return {
                    "status": result["status"],
                    "rows_inserted": result.get("rows", 0),
                    "rows_updated": 0,  # DuckDB doesn't distinguish in UPSERT
                }
        except Exception as e:
            print(f"[IncrementalSync] Error syncing data to {table_name}: {e}")
            return {"status": "error", "error": str(e)}

    def _get_next_date(self, date_str: str) -> str:
        """Get the next day after given date.
        Args:
            date_str: Date in YYYYMMDD format
        Returns:
            Next date in YYYYMMDD format
        """
        dt = datetime.strptime(date_str, "%Y%m%d")
        next_dt = dt + timedelta(days=1)
        return next_dt.strftime("%Y%m%d")
class SyncManager:
    """High-level sync manager that coordinates table creation and incremental updates."""

    def __init__(self, storage: Optional[Storage] = None):
        """Initialize sync manager.
        Args:
            storage: Storage instance (creates new if None)
        """
        self.storage = storage or Storage()
        self.table_manager = TableManager(self.storage)
        self.incremental_sync = IncrementalSync(self.storage)

    def sync(
        self,
        table_name: str,
        fetch_func: Callable[..., pd.DataFrame],
        start_date: str,
        end_date: str,
        stock_codes: Optional[List[str]] = None,
        **fetch_kwargs,
    ) -> Dict[str, Any]:
        """Main sync method - handles full logic of table creation and incremental sync.
        This is the recommended way to sync data:
        1. Checks if table exists, creates if not
        2. Determines best sync strategy
        3. Fetches data using provided function
        4. Applies incremental update
        Args:
            table_name: Target table name
            fetch_func: Function to fetch data (should return DataFrame)
            start_date: Start date for sync (YYYYMMDD)
            end_date: End date for sync (YYYYMMDD)
            stock_codes: Optional list of stocks to sync (None = all)
            **fetch_kwargs: Additional arguments to pass to fetch_func
        Returns:
            Dict with sync results: status, table, and depending on outcome
            strategy/rows/date_range, reason, or error.
        """
        print(f"\n[SyncManager] Starting sync for table '{table_name}'")
        print(f"[SyncManager] Date range: {start_date} to {end_date}")
        # Determine sync strategy (may narrow the date window or the stock set)
        strategy, sync_start, sync_end, stocks_to_sync = (
            self.incremental_sync.get_sync_strategy(
                table_name, start_date, end_date, stock_codes
            )
        )
        if strategy == "none":
            print(f"[SyncManager] No sync needed for '{table_name}'")
            return {
                "status": "skipped",
                "table": table_name,
                "reason": "up-to-date",
            }
        # Fetch data — any fetch_func failure is reported, never raised
        print(f"[SyncManager] Fetching data with strategy '{strategy}'...")
        try:
            if stocks_to_sync:
                # Fetch specific stocks, one call per ts_code
                data_list = []
                for ts_code in stocks_to_sync:
                    df = fetch_func(
                        ts_code=ts_code,
                        start_date=sync_start,
                        end_date=sync_end,
                        **fetch_kwargs,
                    )
                    if not df.empty:
                        data_list.append(df)
                if data_list:
                    data = pd.concat(data_list, ignore_index=True)
                else:
                    data = pd.DataFrame()
            else:
                # Fetch all data at once
                data = fetch_func(
                    start_date=sync_start, end_date=sync_end, **fetch_kwargs
                )
        except Exception as e:
            print(f"[SyncManager] Error fetching data: {e}")
            return {
                "status": "error",
                "table": table_name,
                "error": str(e),
            }
        if data.empty:
            print(f"[SyncManager] No data fetched")
            return {
                "status": "no_data",
                "table": table_name,
            }
        print(f"[SyncManager] Fetched {len(data)} rows")
        # Ensure table exists (schema inferred from the fetched batch)
        if not self.table_manager.ensure_table_exists(table_name, data):
            return {
                "status": "error",
                "table": table_name,
                "error": "Failed to create table",
            }
        # Apply sync — strategy here is 'by_date' or 'by_stock'
        # ('none' already returned above)
        result = self.incremental_sync.sync_data(table_name, data, strategy)
        print(f"[SyncManager] Sync complete: {result}")
        return {
            "status": result["status"],
            "table": table_name,
            "strategy": strategy,
            "rows": result.get("rows_inserted", 0),
            "date_range": f"{sync_start} to {sync_end}"
            if sync_start and sync_end
            else None,
        }
# Convenience functions
def ensure_table(
    table_name: str,
    sample_data: pd.DataFrame,
    storage: Optional[Storage] = None,
) -> bool:
    """Module-level shortcut around :meth:`TableManager.ensure_table_exists`.

    Args:
        table_name: Name of the table
        sample_data: Sample DataFrame used to define the schema
        storage: Optional Storage instance (a fresh one is used when None)

    Returns:
        True if the table already exists or was created successfully
    """
    return TableManager(storage).ensure_table_exists(table_name, sample_data)
def get_table_info(
    table_name: str, storage: Optional[Storage] = None
) -> Dict[str, Any]:
    """Module-level shortcut around :meth:`IncrementalSync.get_table_stats`.

    Args:
        table_name: Name of the table
        storage: Optional Storage instance (a fresh one is used when None)

    Returns:
        Dict with table statistics (exists, row_count, min_date, max_date,
        unique_stocks)
    """
    inspector = IncrementalSync(storage)
    stats = inspector.get_table_stats(table_name)
    return stats
def sync_table(
    table_name: str,
    fetch_func: Callable[..., pd.DataFrame],
    start_date: str,
    end_date: str,
    stock_codes: Optional[List[str]] = None,
    storage: Optional[Storage] = None,
    **fetch_kwargs,
) -> Dict[str, Any]:
    """Sync data to a table with automatic table creation and incremental updates.
    This is the main entry point for syncing data to DuckDB; it simply
    delegates to :meth:`SyncManager.sync`.
    Args:
        table_name: Target table name
        fetch_func: Function to fetch data
        start_date: Start date (YYYYMMDD)
        end_date: End date (YYYYMMDD)
        stock_codes: Optional list of specific stocks
        storage: Optional Storage instance
        **fetch_kwargs: Additional arguments for fetch_func
    Returns:
        Dict with sync results
    Example:
        >>> from src.data.api_wrappers import get_daily
        >>> result = sync_table(
        ...     "daily",
        ...     get_daily,
        ...     "20240101",
        ...     "20240131",
        ...     stock_codes=["000001.SZ", "600000.SH"]
        ... )
    """
    coordinator = SyncManager(storage)
    return coordinator.sync(
        table_name,
        fetch_func,
        start_date,
        end_date,
        stock_codes=stock_codes,
        **fetch_kwargs,
    )

View File

@@ -1,36 +1,102 @@
"""Simplified HDF5 storage for data persistence."""
import os
"""DuckDB storage for data persistence."""
import pandas as pd
import polars as pl
import duckdb
from pathlib import Path
from typing import Optional
from typing import Optional, List, Dict, Any, Tuple
from collections import defaultdict
from datetime import datetime
from src.data.config import get_config
# Default column type mapping for automatic schema inference
DEFAULT_TYPE_MAPPING = {
"ts_code": "VARCHAR(16)",
"trade_date": "DATE",
"open": "DOUBLE",
"high": "DOUBLE",
"low": "DOUBLE",
"close": "DOUBLE",
"pre_close": "DOUBLE",
"change": "DOUBLE",
"pct_chg": "DOUBLE",
"vol": "DOUBLE",
"amount": "DOUBLE",
"turnover_rate": "DOUBLE",
"volume_ratio": "DOUBLE",
"adj_factor": "DOUBLE",
"suspend_flag": "INTEGER",
}
class Storage:
"""HDF5 storage manager for saving and loading data."""
"""DuckDB storage manager for saving and loading data.
迁移说明:
- 保持 API 完全兼容,调用方无需修改
- 新增 load_polars() 方法支持 Polars 零拷贝导出
- 使用单例模式管理数据库连接
    - 并发写入通过队列管理(见 ThreadSafeStorage)
"""
_instance = None
_connection = None
def __new__(cls, *args, **kwargs):
"""Singleton to ensure single connection."""
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self, path: Optional[Path] = None):
"""Initialize storage.
"""Initialize storage."""
if hasattr(self, "_initialized"):
return
Args:
path: Base path for data storage (auto-loaded from config if not provided)
"""
cfg = get_config()
self.base_path = path or cfg.data_path_resolved
self.base_path.mkdir(parents=True, exist_ok=True)
self.db_path = self.base_path / "prostock.db"
def _get_file_path(self, name: str) -> Path:
"""Get full path for an HDF5 file."""
return self.base_path / f"{name}.h5"
self._init_db()
self._initialized = True
def _init_db(self):
"""Initialize database connection and schema."""
self._connection = duckdb.connect(str(self.db_path))
# Create tables with schema validation
self._connection.execute("""
CREATE TABLE IF NOT EXISTS daily (
ts_code VARCHAR(16) NOT NULL,
trade_date DATE NOT NULL,
open DOUBLE,
high DOUBLE,
low DOUBLE,
close DOUBLE,
pre_close DOUBLE,
change DOUBLE,
pct_chg DOUBLE,
vol DOUBLE,
amount DOUBLE,
turnover_rate DOUBLE,
volume_ratio DOUBLE,
PRIMARY KEY (ts_code, trade_date)
)
""")
# Create composite index for query optimization (trade_date, ts_code)
self._connection.execute("""
CREATE INDEX IF NOT EXISTS idx_daily_date_code ON daily(trade_date, ts_code)
""")
def save(self, name: str, data: pd.DataFrame, mode: str = "append") -> dict:
"""Save data to HDF5 file.
"""Save data to DuckDB.
Args:
name: Dataset name (also used as filename)
name: Table name
data: DataFrame to save
mode: 'append' or 'replace'
mode: 'append' (UPSERT) or 'replace' (DELETE + INSERT)
Returns:
Dict with save result
@@ -38,27 +104,36 @@ class Storage:
if data.empty:
return {"status": "skipped", "rows": 0}
file_path = self._get_file_path(name)
# Ensure date column is proper type
if "trade_date" in data.columns:
data = data.copy()
data["trade_date"] = pd.to_datetime(
data["trade_date"], format="%Y%m%d"
).dt.date
# Register DataFrame as temporary view
self._connection.register("temp_data", data)
try:
with pd.HDFStore(file_path, mode="a") as store:
if mode == "replace" or name not in store.keys():
store.put(name, data, format="table")
else:
# Merge with existing data
existing = store[name]
combined = pd.concat([existing, data], ignore_index=True)
combined = combined.drop_duplicates(
subset=["ts_code", "trade_date"], keep="last"
)
store.put(name, combined, format="table")
if mode == "replace":
self._connection.execute(f"DELETE FROM {name}")
print(f"[Storage] Saved {len(data)} rows to {file_path}")
return {"status": "success", "rows": len(data), "path": str(file_path)}
# UPSERT: INSERT OR REPLACE
columns = ", ".join(data.columns)
self._connection.execute(f"""
INSERT OR REPLACE INTO {name} ({columns})
SELECT {columns} FROM temp_data
""")
row_count = len(data)
print(f"[Storage] Saved {row_count} rows to DuckDB ({name})")
return {"status": "success", "rows": row_count}
except Exception as e:
print(f"[Storage] Error saving {name}: {e}")
return {"status": "error", "error": str(e)}
finally:
self._connection.unregister("temp_data")
def load(
self,
@@ -67,84 +142,182 @@ class Storage:
end_date: Optional[str] = None,
ts_code: Optional[str] = None,
) -> pd.DataFrame:
"""Load data from HDF5 file.
"""Load data from DuckDB with query pushdown.
关键优化:
- WHERE 条件在数据库层过滤,无需加载全表
- 只返回匹配条件的行,大幅减少内存占用
Args:
name: Dataset name
name: Table name
start_date: Start date filter (YYYYMMDD)
end_date: End date filter (YYYYMMDD)
ts_code: Stock code filter
Returns:
DataFrame with loaded data
Filtered DataFrame
"""
file_path = self._get_file_path(name)
# Build WHERE clause with parameterized queries
conditions = []
params = []
if not file_path.exists():
print(f"[Storage] File not found: {file_path}")
return pd.DataFrame()
if start_date and end_date:
conditions.append("trade_date BETWEEN ? AND ?")
# Convert to DATE type
start = pd.to_datetime(start_date, format="%Y%m%d").date()
end = pd.to_datetime(end_date, format="%Y%m%d").date()
params.extend([start, end])
elif start_date:
conditions.append("trade_date >= ?")
params.append(pd.to_datetime(start_date, format="%Y%m%d").date())
elif end_date:
conditions.append("trade_date <= ?")
params.append(pd.to_datetime(end_date, format="%Y%m%d").date())
if ts_code:
conditions.append("ts_code = ?")
params.append(ts_code)
where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
query = f"SELECT * FROM {name} {where_clause} ORDER BY trade_date"
try:
with pd.HDFStore(file_path, mode="r") as store:
keys = store.keys()
# Handle both '/daily' and 'daily' keys
actual_key = None
if name in keys:
actual_key = name
elif f"/{name}" in keys:
actual_key = f"/{name}"
# Execute query with parameters (SQL injection safe)
result = self._connection.execute(query, params).fetchdf()
if actual_key is None:
return pd.DataFrame()
data = store[actual_key]
# Apply filters
if start_date and end_date and "trade_date" in data.columns:
data = data[
(data["trade_date"] >= start_date)
& (data["trade_date"] <= end_date)
]
if ts_code and "ts_code" in data.columns:
data = data[data["ts_code"] == ts_code]
return data
# Convert trade_date back to string format for compatibility
if "trade_date" in result.columns:
result["trade_date"] = result["trade_date"].dt.strftime("%Y%m%d")
return result
except Exception as e:
print(f"[Storage] Error loading {name}: {e}")
return pd.DataFrame()
def get_last_date(self, name: str) -> Optional[str]:
"""Get the latest date in storage.
def load_polars(
self,
name: str,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
ts_code: Optional[str] = None,
) -> pl.DataFrame:
"""Load data as Polars DataFrame (for DataLoader).
Args:
name: Dataset name
Returns:
Latest date string or None
性能优势:
- 零拷贝导出DuckDB → Polars via PyArrow
- 需要 pyarrow 支持
"""
data = self.load(name)
if data.empty or "trade_date" not in data.columns:
return None
return str(data["trade_date"].max())
# Build query
conditions = []
if start_date and end_date:
start = pd.to_datetime(start_date, format='%Y%m%d').date()
end = pd.to_datetime(end_date, format='%Y%m%d').date()
conditions.append(f"trade_date BETWEEN '{start}' AND '{end}'")
if ts_code:
conditions.append(f"ts_code = '{ts_code}'")
where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
query = f"SELECT * FROM {name} {where_clause} ORDER BY trade_date"
        # 使用 DuckDB 的 Polars 导出(需要 pyarrow)
df = self._connection.sql(query).pl()
# 将 trade_date 转换为字符串格式,保持兼容性
if "trade_date" in df.columns:
df = df.with_columns(
pl.col("trade_date").dt.strftime("%Y%m%d").alias("trade_date")
)
return df
def exists(self, name: str) -> bool:
    """Check whether a table with the given name exists in the database.

    Args:
        name: Table name.

    Returns:
        True if the table exists in the current DuckDB catalog.
    """
    # Parameterized query against information_schema (SQL-injection safe).
    result = self._connection.execute(
        """
        SELECT COUNT(*) FROM information_schema.tables
        WHERE table_name = ?
        """,
        [name],
    ).fetchone()
    return result[0] > 0
def delete(self, name: str) -> bool:
    """Delete a table from the database.

    Args:
        name: Table name.

    Returns:
        True if the drop succeeded (including when the table did not
        exist, thanks to IF EXISTS); False on error.
    """
    try:
        # Table names cannot be bound as query parameters; `name` is
        # expected to be an internal dataset identifier, not user input.
        self._connection.execute(f"DROP TABLE IF EXISTS {name}")
        print(f"[Storage] Deleted table {name}")
        return True
    except Exception as e:
        print(f"[Storage] Error deleting {name}: {e}")
        return False
def get_last_date(self, name: str) -> Optional[str]:
    """Get the latest trade_date stored in a table.

    Args:
        name: Table name.

    Returns:
        Latest date as a YYYYMMDD string, or None if the table is empty,
        missing, or has no trade_date column.
    """
    try:
        result = self._connection.execute(f"""
            SELECT MAX(trade_date) FROM {name}
        """).fetchone()
        if result[0]:
            # DuckDB returns a date object; convert back to YYYYMMDD
            # string for compatibility with callers.
            return (
                result[0].strftime("%Y%m%d")
                if hasattr(result[0], "strftime")
                else str(result[0])
            )
        return None
    except Exception:
        # Missing table/column is treated as "no data". Narrowed from a
        # bare except so KeyboardInterrupt/SystemExit still propagate.
        return None
def close(self):
    """Close the database connection and reset the singleton state.

    After closing, the class-level connection and instance references are
    cleared so that a subsequent Storage() call re-opens the database.
    """
    if self._connection:
        self._connection.close()
        # Reset class-level singleton state so Storage() can reconnect.
        Storage._connection = None
        Storage._instance = None
class ThreadSafeStorage:
    """Thread-safe write wrapper around the DuckDB-backed Storage.

    DuckDB does not support concurrent writers, so write requests are
    collected in an in-memory queue and flushed as a single batch when a
    sync run finishes.
    """

    def __init__(self):
        self.storage = Storage()
        self._pending_writes: List[tuple] = []  # queued (name, data) pairs

    def queue_save(self, name: str, data: pd.DataFrame):
        """Queue a DataFrame for a later batch write (does not write now)."""
        if data.empty:
            return
        self._pending_writes.append((name, data))

    def flush(self):
        """Write all queued DataFrames in one batch.

        Intended to be called once when a sync run finishes, avoiding
        concurrent-write conflicts in DuckDB.
        """
        if not self._pending_writes:
            return
        # Group queued frames by their target table.
        grouped = defaultdict(list)
        for table, frame in self._pending_writes:
            grouped[table].append(frame)
        # Concatenate, de-duplicate, and append each table's batch.
        for table, frames in grouped.items():
            merged = pd.concat(frames, ignore_index=True)
            if "ts_code" in merged.columns and "trade_date" in merged.columns:
                # Drop duplicate (ts_code, trade_date) rows up front,
                # keeping the most recently queued data.
                merged = merged.drop_duplicates(
                    subset=["ts_code", "trade_date"], keep="last"
                )
            self.storage.save(table, merged, mode="append")
        self._pending_writes.clear()

    def __getattr__(self, name):
        """Delegate every other attribute lookup to the wrapped Storage."""
        return getattr(self.storage, name)

View File

@@ -36,7 +36,7 @@ import threading
import sys
from src.data.client import TushareClient
from src.data.storage import Storage
from src.data.storage import ThreadSafeStorage
from src.data.api_wrappers import get_daily
from src.data.api_wrappers import (
get_first_trading_day,
@@ -83,7 +83,7 @@ class DataSync:
Args:
max_workers: Number of worker threads (default: 10)
"""
self.storage = Storage()
self.storage = ThreadSafeStorage()
self.client = TushareClient()
self.max_workers = max_workers or self.DEFAULT_MAX_WORKERS
self._stop_flag = threading.Event()
@@ -667,11 +667,15 @@ class DataSync:
finally:
pbar.close()
# Write all data at once (only if no error)
# Queue all data for batch write (only if no error)
if results and not error_occurred:
combined_data = pd.concat(results.values(), ignore_index=True)
self.storage.save("daily", combined_data, mode="append")
print(f"\n[DataSync] Saved {len(combined_data)} rows to storage")
for ts_code, data in results.items():
if not data.empty:
self.storage.queue_save("daily", data)
# Flush all queued writes at once
self.storage.flush()
total_rows = sum(len(df) for df in results.values())
print(f"\n[DataSync] Saved {total_rows} rows to storage")
# Summary
print("\n" + "=" * 60)

View File

@@ -1,6 +1,6 @@
"""数据加载器 - Phase 3 数据加载模块
本模块负责从 HDF5 文件安全加载数据:
本模块负责从 DuckDB 安全加载数据:
- DataLoader: 数据加载器,支持多文件聚合、列选择、缓存
"""
@@ -14,12 +14,13 @@ from src.factors.data_spec import DataSpec
class DataLoader:
"""数据加载器 - 负责从 HDF5 安全加载数据
"""数据加载器 - 负责从 DuckDB 安全加载数据
功能:
1. 多文件聚合:合并多个 H5 文件的数据
1. 多文件聚合:合并多个表的数据
2. 列选择:只加载需要的列
3. 原始数据缓存:避免重复读取
4. 查询下推:利用 DuckDB SQL 过滤,只加载必要数据
示例:
>>> loader = DataLoader(data_dir="data")
@@ -31,7 +32,7 @@ class DataLoader:
"""初始化 DataLoader
Args:
data_dir: HDF5 文件所在目录
data_dir: DuckDB 数据库文件所在目录
"""
self.data_dir = Path(data_dir)
self._cache: Dict[str, pl.DataFrame] = {}
@@ -107,32 +108,29 @@ class DataLoader:
self._cache.clear()
def _read_h5(self, source: str) -> pl.DataFrame:
    """Load a table from DuckDB as a Polars DataFrame.

    Migration notes:
        - The method name is kept as ``_read_h5`` for backward
          compatibility with existing callers; it now reads from DuckDB,
          not HDF5.
        - Uses Storage.load_polars(), which exports DuckDB results to
          Polars (near zero-copy via PyArrow), avoiding the old
          HDF5 -> pandas -> Polars conversion chain.
        - NOTE(review): date-range/code filter pushdown is supported by
          load_polars() but not wired up here yet — the full table is
          loaded and filtered downstream. TODO confirm and pass filters.

    Args:
        source: Table name in DuckDB (e.g. "daily").

    Returns:
        Polars DataFrame.

    Raises:
        Exception: On database query errors.
    """
    # Local import avoids a module-level dependency cycle with storage.
    from src.data.storage import Storage

    storage = Storage()
    return storage.load_polars(source)
def _merge_dataframes(self, dataframes: List[pl.DataFrame]) -> pl.DataFrame:
"""合并多个 DataFrame