feat: HDF5迁移至DuckDB存储
- 新增DuckDB Storage与ThreadSafeStorage实现
- 新增db_manager模块支持增量同步策略
- DataLoader与Sync模块适配DuckDB
- 补充迁移相关文档与测试
- 修复README文档链接
This commit is contained in:
@@ -5,14 +5,35 @@ Provides simplified interfaces for fetching and storing Tushare data.
|
||||
|
||||
from src.data.config import Config, get_config
|
||||
from src.data.client import TushareClient
|
||||
from src.data.storage import Storage
|
||||
from src.data.storage import Storage, ThreadSafeStorage, DEFAULT_TYPE_MAPPING
|
||||
from src.data.api_wrappers import get_stock_basic, sync_all_stocks
|
||||
from src.data.db_manager import (
|
||||
TableManager,
|
||||
IncrementalSync,
|
||||
SyncManager,
|
||||
ensure_table,
|
||||
get_table_info,
|
||||
sync_table,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Configuration
|
||||
"Config",
|
||||
"get_config",
|
||||
# Core clients
|
||||
"TushareClient",
|
||||
# Storage
|
||||
"Storage",
|
||||
"ThreadSafeStorage",
|
||||
"DEFAULT_TYPE_MAPPING",
|
||||
# API wrappers
|
||||
"get_stock_basic",
|
||||
"sync_all_stocks",
|
||||
# Database management (new)
|
||||
"TableManager",
|
||||
"IncrementalSync",
|
||||
"SyncManager",
|
||||
"ensure_table",
|
||||
"get_table_info",
|
||||
"sync_table",
|
||||
]
|
||||
|
||||
@@ -10,6 +10,10 @@ from pathlib import Path
|
||||
from src.data.client import TushareClient
|
||||
from src.data.config import get_config
|
||||
|
||||
# Module-level flag to track if cache has been synced in this session
|
||||
_cache_synced = False
|
||||
|
||||
|
||||
|
||||
# Trading calendar cache file path
|
||||
def _get_cache_path() -> Path:
|
||||
@@ -51,8 +55,9 @@ def _load_from_cache() -> pd.DataFrame:
|
||||
|
||||
try:
|
||||
with pd.HDFStore(cache_path, mode="r") as store:
|
||||
if "trade_cal" in store.keys():
|
||||
data = store["trade_cal"]
|
||||
# HDF5 keys include leading slash (e.g., '/trade_cal')
|
||||
if "/trade_cal" in store.keys():
|
||||
data = store["/trade_cal"]
|
||||
print(f"[trade_cal] Loaded {len(data)} records from cache")
|
||||
return data
|
||||
except Exception as e:
|
||||
@@ -77,6 +82,7 @@ def _get_cached_date_range() -> tuple[Optional[str], Optional[str]]:
|
||||
def sync_trade_cal_cache(
|
||||
start_date: str = "20180101",
|
||||
end_date: Optional[str] = None,
|
||||
force: bool = False,
|
||||
) -> pd.DataFrame:
|
||||
"""Sync trade calendar data to local cache with incremental updates.
|
||||
|
||||
@@ -86,10 +92,17 @@ def sync_trade_cal_cache(
|
||||
Args:
|
||||
start_date: Initial start date for full sync (default: 20180101)
|
||||
end_date: End date (defaults to today)
|
||||
force: If True, force sync even if already synced in this session
|
||||
|
||||
Returns:
|
||||
Full trade calendar DataFrame (cached + new)
|
||||
"""
|
||||
global _cache_synced
|
||||
|
||||
# Skip if already synced in this session (unless forced)
|
||||
if _cache_synced and not force:
|
||||
return _load_from_cache()
|
||||
|
||||
if end_date is None:
|
||||
from datetime import datetime
|
||||
|
||||
@@ -137,6 +150,8 @@ def sync_trade_cal_cache(
|
||||
combined = new_data
|
||||
|
||||
# Save combined data to cache
|
||||
# Mark as synced to avoid redundant syncs in this session
|
||||
_cache_synced = True
|
||||
_save_to_cache(combined)
|
||||
return combined
|
||||
else:
|
||||
@@ -153,6 +168,8 @@ def sync_trade_cal_cache(
|
||||
print("[trade_cal] No data returned")
|
||||
return data
|
||||
|
||||
# Mark as synced to avoid redundant syncs in this session
|
||||
_cache_synced = True
|
||||
_save_to_cache(data)
|
||||
return data
|
||||
|
||||
|
||||
271
src/data/db_inspector.py
Normal file
271
src/data/db_inspector.py
Normal file
@@ -0,0 +1,271 @@
|
||||
"""DuckDB Database Inspector Tool
|
||||
|
||||
Usage:
|
||||
uv run python -c "from src.data.db_inspector import get_db_info; get_db_info()"
|
||||
|
||||
Or as standalone script:
|
||||
cd D:\\PyProject\\ProStock && uv run python -c "import sys; sys.path.insert(0, '.'); from src.data.db_inspector import get_db_info; get_db_info()"
|
||||
|
||||
Features:
|
||||
- List all tables
|
||||
- Show row count for each table
|
||||
- Show database file size
|
||||
- Show column information for each table
|
||||
"""
|
||||
|
||||
import duckdb
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def _resolve_db_path(db_path: Optional[Path]) -> Path:
    """Return *db_path* as a Path, or the configured default database file."""
    if db_path is None:
        # Deferred import keeps module import light and avoids import cycles.
        from src.data.config import get_config

        cfg = get_config()
        return cfg.data_path_resolved / "prostock.db"
    return Path(db_path)


def get_db_info(db_path: Optional[Path] = None) -> Optional[pd.DataFrame]:
    """Print a complete summary of the DuckDB database and return it.

    Prints the database file size, per-table row/column counts, the date
    range of the ``daily`` table, and detailed column information for every
    table in the ``main`` schema.

    Args:
        db_path: Path to database file; uses the configured default
            (``<data_path>/prostock.db``) if None.

    Returns:
        DataFrame with one row per table (Table Name, Type, Row Count,
        Column Count, Date Range); an empty DataFrame if the database has
        no tables; None if the database file does not exist.
    """
    db_path = _resolve_db_path(db_path)

    if not db_path.exists():
        print(f"[ERROR] Database file not found: {db_path}")
        return None

    # Read-only connection: inspection must never mutate the database.
    conn = duckdb.connect(str(db_path), read_only=True)

    try:
        print("=" * 80)
        print("ProStock DuckDB Database Summary")
        print("=" * 80)
        print(f"Database Path: {db_path}")
        print(f"Check Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

        # Database file size
        db_size_bytes = db_path.stat().st_size
        db_size_mb = db_size_bytes / (1024 * 1024)
        print(f"Database Size: {db_size_mb:.2f} MB ({db_size_bytes:,} bytes)")
        print("=" * 80)

        # All tables in the 'main' schema
        tables_query = """
            SELECT
                table_name,
                table_type
            FROM information_schema.tables
            WHERE table_schema = 'main'
            ORDER BY table_name
        """
        tables_df = conn.execute(tables_query).fetchdf()

        if tables_df.empty:
            print("\n[WARNING] No tables found in database")
            return pd.DataFrame()

        print(f"\nTable List (Total: {len(tables_df)} tables)")
        print("-" * 80)

        summary_data = []

        for _, row in tables_df.iterrows():
            table_name = row["table_name"]
            table_type = row["table_type"]

            # Row count. The identifier comes from the catalog (trusted),
            # and is double-quoted to survive unusual table names.
            try:
                count_result = conn.execute(
                    f'SELECT COUNT(*) FROM "{table_name}"'
                ).fetchone()
                row_count = count_result[0] if count_result else 0
            except Exception as e:
                row_count = f"Error: {e}"

            # Column count via a parameterized catalog query (was an
            # f-string interpolation before).
            try:
                col_result = conn.execute(
                    """
                    SELECT COUNT(*)
                    FROM information_schema.columns
                    WHERE table_name = ? AND table_schema = 'main'
                    """,
                    [table_name],
                ).fetchone()
                col_count = col_result[0] if col_result else 0
            except Exception:
                col_count = 0

            # Date range is only computed for the daily quotes table.
            date_range = "-"
            if table_name == "daily" and isinstance(row_count, int) and row_count > 0:
                try:
                    date_result = conn.execute(
                        """
                        SELECT
                            MIN(trade_date) as min_date,
                            MAX(trade_date) as max_date
                        FROM daily
                        """
                    ).fetchone()
                    if date_result and date_result[0] and date_result[1]:
                        date_range = f"{date_result[0]} ~ {date_result[1]}"
                except Exception:
                    pass

            summary_data.append(
                {
                    "Table Name": table_name,
                    "Type": table_type,
                    "Row Count": row_count if isinstance(row_count, int) else 0,
                    "Column Count": col_count,
                    "Date Range": date_range,
                }
            )

            row_str = f"{row_count:,}" if isinstance(row_count, int) else str(row_count)
            print(f" * {table_name:<20} | Rows: {row_str:>12} | Cols: {col_count}")

        print("-" * 80)

        # Totals over all successfully counted tables
        total_rows = sum(
            item["Row Count"]
            for item in summary_data
            if isinstance(item["Row Count"], int)
        )
        print("\nData Summary")
        print(f" Total Tables: {len(summary_data)}")
        print(f" Total Rows: {total_rows:,}")
        print(
            f" Avg Rows/Table: {total_rows // len(summary_data):,}"
            if summary_data
            else " Avg Rows/Table: 0"
        )

        # Detailed table structure
        print("\nDetailed Table Structure")
        print("=" * 80)

        for item in summary_data:
            table_name = item["Table Name"]
            print(f"\n[{table_name}]")

            # Column information (parameterized, was f-string interpolation)
            columns_df = conn.execute(
                """
                SELECT
                    column_name,
                    data_type,
                    is_nullable
                FROM information_schema.columns
                WHERE table_name = ? AND table_schema = 'main'
                ORDER BY ordinal_position
                """,
                [table_name],
            ).fetchdf()

            if not columns_df.empty:
                print(f" Columns: {len(columns_df)}")
                print(f" {'Column':<20} {'Data Type':<20} {'Nullable':<10}")
                print(f" {'-' * 20} {'-' * 20} {'-' * 10}")
                for _, col in columns_df.iterrows():
                    nullable = "YES" if col["is_nullable"] == "YES" else "NO"
                    print(
                        f" {col['column_name']:<20} {col['data_type']:<20} {nullable:<10}"
                    )

            # Extra statistics for the daily table
            if (
                table_name == "daily"
                and isinstance(item["Row Count"], int)
                and item["Row Count"] > 0
            ):
                try:
                    stats = conn.execute(
                        """
                        SELECT
                            COUNT(DISTINCT ts_code) as stock_count,
                            COUNT(DISTINCT trade_date) as date_count
                        FROM daily
                        """
                    ).fetchone()
                    if stats:
                        print("\n Statistics:")
                        print(f" - Unique Stocks: {stats[0]:,}")
                        print(f" - Trade Dates: {stats[1]:,}")
                        # Label fixed: row_count // stock_count is records
                        # per stock, not per stock per date.
                        print(
                            f" - Avg Records/Stock: {item['Row Count'] // stats[0] if stats[0] > 0 else 0}"
                        )
                except Exception as e:
                    print(f"\n Statistics query failed: {e}")

        print("\n" + "=" * 80)
        print("Check Complete")
        print("=" * 80)

        return pd.DataFrame(summary_data)

    finally:
        conn.close()
|
||||
|
||||
|
||||
def get_table_sample(table_name: str, limit: int = 5, db_path: Optional[Path] = None):
    """Print and return the first *limit* rows of a table.

    Args:
        table_name: Name of the table to sample.
        limit: Maximum number of rows to fetch.
        db_path: Database file path; the configured default is used when None.

    Returns:
        DataFrame with the sampled rows, or None when the database file is
        missing or the query fails.
    """
    if db_path is not None:
        db_path = Path(db_path)
    else:
        from src.data.config import get_config

        cfg = get_config()
        db_path = cfg.data_path_resolved / "prostock.db"

    if not db_path.exists():
        print(f"[ERROR] Database file not found: {db_path}")
        return None

    # Read-only: sampling must never mutate the database.
    conn = duckdb.connect(str(db_path), read_only=True)

    try:
        sample = conn.execute(
            f'SELECT * FROM "{table_name}" LIMIT {limit}'
        ).fetchdf()
        print(f"\nTable [{table_name}] Sample Data (first {len(sample)} rows):")
        print(sample.to_string())
        return sample
    except Exception as e:
        print(f"[ERROR] Query failed: {e}")
        return None
    finally:
        conn.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Print the overall database summary first.
    overview = get_db_info()

    # Follow up with a peek at the daily quotes table when it is present.
    has_daily = (
        overview is not None
        and not overview.empty
        and "daily" in overview["Table Name"].values
    )
    if has_daily:
        print("\n")
        get_table_sample("daily", limit=5)
|
||||
592
src/data/db_manager.py
Normal file
592
src/data/db_manager.py
Normal file
@@ -0,0 +1,592 @@
|
||||
"""DuckDB table management and incremental sync utilities.
|
||||
|
||||
This module provides utilities for:
|
||||
- Automatic table creation with schema inference
|
||||
- Composite index creation for (trade_date, ts_code)
|
||||
- Incremental sync strategies (by date or by stock)
|
||||
- Table statistics and metadata
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from typing import Optional, List, Dict, Any, Callable, Tuple, Literal
|
||||
from datetime import datetime, timedelta
|
||||
from collections import defaultdict
|
||||
from src.data.storage import Storage, ThreadSafeStorage, DEFAULT_TYPE_MAPPING
|
||||
|
||||
|
||||
class TableManager:
    """Creates DuckDB tables and maintains their schema/index conventions."""

    def __init__(self, storage: Optional[Storage] = None):
        """Initialize the table manager.

        Args:
            storage: Storage instance (a new one is created when None)
        """
        self.storage = storage or Storage()

    @staticmethod
    def _sql_type_for(data: pd.DataFrame, column: str) -> str:
        """Map one DataFrame column to a DuckDB SQL type.

        Known columns use DEFAULT_TYPE_MAPPING; anything else is inferred
        from the pandas dtype string, falling back to VARCHAR.
        """
        if column in DEFAULT_TYPE_MAPPING:
            return DEFAULT_TYPE_MAPPING[column]
        dtype = str(data[column].dtype)
        for marker, sql_type in (
            ("int", "INTEGER"),
            ("float", "DOUBLE"),
            ("bool", "BOOLEAN"),
            ("datetime", "TIMESTAMP"),
        ):
            if marker in dtype:
                return sql_type
        return "VARCHAR"

    def create_table_from_dataframe(
        self,
        table_name: str,
        data: pd.DataFrame,
        primary_keys: Optional[List[str]] = None,
        create_index: bool = True,
    ) -> bool:
        """Create a table whose schema is inferred from *data*.

        A composite index on (trade_date, ts_code) is created automatically
        when both columns are present and *create_index* is True.

        Args:
            table_name: Name of the table to create
            data: DataFrame to infer the schema from
            primary_keys: Columns for the primary key (auto-detected when None)
            create_index: Whether to create the composite index

        Returns:
            True if the table was created successfully
        """
        if data.empty:
            print(
                f"[TableManager] Cannot create table {table_name} from empty DataFrame"
            )
            return False

        try:
            column_defs = [
                f'"{column}" {self._sql_type_for(data, column)}'
                for column in data.columns
            ]

            # Primary key: an explicit list wins; otherwise fall back to the
            # conventional (ts_code, trade_date) pair when both columns exist.
            if primary_keys:
                quoted_keys = ", ".join(f'"{key}"' for key in primary_keys)
                pk_constraint = f", PRIMARY KEY ({quoted_keys})"
            elif "ts_code" in data.columns and "trade_date" in data.columns:
                pk_constraint = ', PRIMARY KEY ("ts_code", "trade_date")'
            else:
                pk_constraint = ""

            columns_sql = ", ".join(column_defs)
            create_sql = (
                f'CREATE TABLE IF NOT EXISTS "{table_name}" ({columns_sql}{pk_constraint})'
            )
            self.storage._connection.execute(create_sql)
            print(
                f"[TableManager] Created table '{table_name}' with {len(data.columns)} columns"
            )

            wants_index = (
                create_index
                and "trade_date" in data.columns
                and "ts_code" in data.columns
            )
            if wants_index:
                index_name = f"idx_{table_name}_date_code"
                self.storage._connection.execute(f"""
                    CREATE INDEX IF NOT EXISTS "{index_name}" ON "{table_name}"("trade_date", "ts_code")
                """)
                print(
                    f"[TableManager] Created composite index on '{table_name}'(trade_date, ts_code)"
                )

            return True

        except Exception as e:
            print(f"[TableManager] Error creating table {table_name}: {e}")
            return False

    def ensure_table_exists(
        self,
        table_name: str,
        sample_data: Optional[pd.DataFrame] = None,
    ) -> bool:
        """Ensure *table_name* exists, creating it from *sample_data* if needed.

        Args:
            table_name: Name of the table
            sample_data: Sample DataFrame used to infer the schema when the
                table has to be created

        Returns:
            True if the table exists or was created successfully
        """
        if self.storage.exists(table_name):
            return True

        if sample_data is None or sample_data.empty:
            print(
                f"[TableManager] Table '{table_name}' doesn't exist and no sample data provided"
            )
            return False

        return self.create_table_from_dataframe(table_name, sample_data)
|
||||
|
||||
|
||||
class IncrementalSync:
    """Handles incremental synchronization strategies.

    Decides, per table, whether a sync is needed at all and whether it
    should run date-wise (append days after the table's max trade_date) or
    stock-wise (backfill specific ts_codes missing from the table).
    """

    # Sync strategy types
    SYNC_BY_DATE = "by_date"  # Sync all stocks for date range
    SYNC_BY_STOCK = "by_stock"  # Sync specific stocks for full date range

    def __init__(self, storage: Optional[Storage] = None):
        """Initialize incremental sync manager.

        Args:
            storage: Storage instance (creates new if None)
        """
        self.storage = storage or Storage()
        # Share the same storage (and hence the same DuckDB connection)
        # with the table manager used for auto-creation.
        self.table_manager = TableManager(self.storage)

    def get_sync_strategy(
        self,
        table_name: str,
        start_date: str,
        end_date: str,
        stock_codes: Optional[List[str]] = None,
    ) -> Tuple[str, Optional[str], Optional[str], Optional[List[str]]]:
        """Determine the best sync strategy based on existing data.

        Logic:
        1. If table doesn't exist: full sync by date
        2. If table exists and has data:
           - If stock_codes provided: sync by stock (update specific stocks)
           - Otherwise: sync by date from last_date + 1

        Args:
            table_name: Name of the table to sync
            start_date: Requested start date (YYYYMMDD)
            end_date: Requested end date (YYYYMMDD)
            stock_codes: Optional list of specific stocks to sync

        Returns:
            Tuple of (strategy, sync_start, sync_end, stocks_to_sync)
            - strategy: 'by_date' or 'by_stock' or 'none'
            - sync_start: Start date for sync (None if no sync needed)
            - sync_end: End date for sync (None if no sync needed)
            - stocks_to_sync: List of stocks to sync (None for all)
        """
        # Check if table exists
        if not self.storage.exists(table_name):
            print(
                f"[IncrementalSync] Table '{table_name}' doesn't exist, will create and do full sync"
            )
            return (self.SYNC_BY_DATE, start_date, end_date, None)

        # Get table stats
        stats = self.get_table_stats(table_name)

        if stats["row_count"] == 0:
            print(f"[IncrementalSync] Table '{table_name}' is empty, doing full sync")
            return (self.SYNC_BY_DATE, start_date, end_date, None)

        # If specific stocks requested, sync by stock.
        # NOTE(review): only stocks entirely absent from the table are
        # backfilled; stocks present but stale fall through to no-op here.
        if stock_codes:
            existing_stocks = set(self.storage.get_distinct_stocks(table_name))
            requested_stocks = set(stock_codes)
            missing_stocks = requested_stocks - existing_stocks

            if not missing_stocks:
                print(
                    f"[IncrementalSync] All requested stocks already exist in '{table_name}'"
                )
                return ("none", None, None, None)

            print(
                f"[IncrementalSync] Syncing {len(missing_stocks)} missing stocks by stock strategy"
            )
            return (self.SYNC_BY_STOCK, start_date, end_date, list(missing_stocks))

        # Check if we need date-based sync
        table_last_date = stats.get("max_date")

        if table_last_date is None:
            return (self.SYNC_BY_DATE, start_date, end_date, None)

        # Compare dates numerically; YYYYMMDD strings order correctly as ints.
        table_last = int(table_last_date)
        requested_end = int(end_date)

        if table_last >= requested_end:
            print(
                f"[IncrementalSync] Table '{table_name}' is up-to-date (last: {table_last_date})"
            )
            return ("none", None, None, None)

        # Incremental sync from next day after last_date
        next_date = self._get_next_date(table_last_date)
        print(f"[IncrementalSync] Incremental sync needed: {next_date} to {end_date}")
        return (self.SYNC_BY_DATE, next_date, end_date, None)

    def get_table_stats(self, table_name: str) -> Dict[str, Any]:
        """Get statistics about a table.

        Returns:
            Dict with exists, row_count, min_date, max_date, unique_stocks
        """
        stats = {
            "exists": False,
            "row_count": 0,
            "min_date": None,
            "max_date": None,
            "unique_stocks": 0,
        }

        if not self.storage.exists(table_name):
            return stats

        try:
            conn = self.storage._connection

            # Row count
            row_count = conn.execute(f'SELECT COUNT(*) FROM "{table_name}"').fetchone()[
                0
            ]
            stats["row_count"] = row_count
            stats["exists"] = True

            # Get column names so the date/stock queries only run when the
            # relevant columns actually exist in this table.
            columns_result = conn.execute(
                """
                SELECT column_name
                FROM information_schema.columns
                WHERE table_name = ?
                """,
                [table_name],
            ).fetchall()
            columns = [row[0] for row in columns_result]

            # Date range; DATE columns come back as date objects, so they
            # are normalized to YYYYMMDD strings (plain strings pass through).
            if "trade_date" in columns:
                date_result = conn.execute(f'''
                    SELECT MIN("trade_date"), MAX("trade_date") FROM "{table_name}"
                ''').fetchone()
                if date_result[0]:
                    stats["min_date"] = (
                        date_result[0].strftime("%Y%m%d")
                        if hasattr(date_result[0], "strftime")
                        else str(date_result[0])
                    )
                if date_result[1]:
                    stats["max_date"] = (
                        date_result[1].strftime("%Y%m%d")
                        if hasattr(date_result[1], "strftime")
                        else str(date_result[1])
                    )

            # Unique stocks
            if "ts_code" in columns:
                unique_count = conn.execute(f'''
                    SELECT COUNT(DISTINCT "ts_code") FROM "{table_name}"
                ''').fetchone()[0]
                stats["unique_stocks"] = unique_count

        except Exception as e:
            # Best-effort: report the partial stats gathered so far.
            print(f"[IncrementalSync] Error getting stats for {table_name}: {e}")

        return stats

    def sync_data(
        self,
        table_name: str,
        data: pd.DataFrame,
        strategy: Literal["by_date", "by_stock", "replace"] = "by_date",
    ) -> Dict[str, Any]:
        """Sync data to table using specified strategy.

        Args:
            table_name: Target table name
            data: DataFrame to sync
            strategy: Sync strategy
                - 'by_date': UPSERT based on primary key (ts_code, trade_date)
                - 'by_stock': Replace data for specific stocks
                - 'replace': Full replace of table

        Returns:
            Dict with status, rows_inserted, rows_updated
        """
        if data.empty:
            return {"status": "skipped", "rows_inserted": 0, "rows_updated": 0}

        # Ensure table exists (schema inferred from the data itself)
        if not self.table_manager.ensure_table_exists(table_name, data):
            return {"status": "error", "error": "Failed to create table"}

        try:
            if strategy == "replace":
                # Full replace
                result = self.storage.save(table_name, data, mode="replace")
                return {
                    "status": result["status"],
                    "rows_inserted": result.get("rows", 0),
                    "rows_updated": 0,
                }

            elif strategy == "by_stock":
                # Delete existing data for these stocks, then insert.
                # NOTE(review): delete + append are separate statements; not
                # wrapped in an explicit transaction — confirm Storage.save
                # semantics if atomicity matters.
                if "ts_code" in data.columns:
                    stocks = data["ts_code"].unique().tolist()
                    placeholders = ", ".join(["?"] * len(stocks))
                    self.storage._connection.execute(
                        f'''
                        DELETE FROM "{table_name}" WHERE "ts_code" IN ({placeholders})
                        ''',
                        stocks,
                    )
                    print(
                        f"[IncrementalSync] Deleted existing data for {len(stocks)} stocks"
                    )

                result = self.storage.save(table_name, data, mode="append")
                return {
                    "status": result["status"],
                    "rows_inserted": result.get("rows", 0),
                    "rows_updated": 0,
                }

            else:  # by_date (default)
                # UPSERT using INSERT OR REPLACE
                result = self.storage.save(table_name, data, mode="append")
                return {
                    "status": result["status"],
                    "rows_inserted": result.get("rows", 0),
                    "rows_updated": 0,  # DuckDB doesn't distinguish in UPSERT
                }

        except Exception as e:
            print(f"[IncrementalSync] Error syncing data to {table_name}: {e}")
            return {"status": "error", "error": str(e)}

    def _get_next_date(self, date_str: str) -> str:
        """Get the next day after given date.

        Calendar-day increment; the result may fall on a non-trading day —
        downstream fetches are assumed to tolerate that.

        Args:
            date_str: Date in YYYYMMDD format

        Returns:
            Next date in YYYYMMDD format
        """
        dt = datetime.strptime(date_str, "%Y%m%d")
        next_dt = dt + timedelta(days=1)
        return next_dt.strftime("%Y%m%d")
|
||||
|
||||
|
||||
class SyncManager:
    """High-level coordinator for table creation plus incremental updates."""

    def __init__(self, storage: Optional[Storage] = None):
        """Initialize the sync manager.

        Args:
            storage: Storage instance (a new one is created when None)
        """
        self.storage = storage or Storage()
        self.table_manager = TableManager(self.storage)
        self.incremental_sync = IncrementalSync(self.storage)

    @staticmethod
    def _fetch_data(fetch_func, sync_start, sync_end, stocks_to_sync, fetch_kwargs):
        """Fetch per-stock (concatenated) or in a single bulk call."""
        if not stocks_to_sync:
            return fetch_func(start_date=sync_start, end_date=sync_end, **fetch_kwargs)

        frames = [
            frame
            for frame in (
                fetch_func(
                    ts_code=code,
                    start_date=sync_start,
                    end_date=sync_end,
                    **fetch_kwargs,
                )
                for code in stocks_to_sync
            )
            if not frame.empty
        ]
        if frames:
            return pd.concat(frames, ignore_index=True)
        return pd.DataFrame()

    def sync(
        self,
        table_name: str,
        fetch_func: Callable[..., pd.DataFrame],
        start_date: str,
        end_date: str,
        stock_codes: Optional[List[str]] = None,
        **fetch_kwargs,
    ) -> Dict[str, Any]:
        """Sync *table_name*: choose a strategy, fetch, and apply the update.

        The recommended way to sync data:
        1. Checks if the table exists, creates it if not
        2. Determines the best sync strategy
        3. Fetches data using the provided function
        4. Applies the incremental update

        Args:
            table_name: Target table name
            fetch_func: Function returning a DataFrame of rows to store
            start_date: Start date for sync (YYYYMMDD)
            end_date: End date for sync (YYYYMMDD)
            stock_codes: Optional list of stocks to sync (None = all)
            **fetch_kwargs: Extra keyword arguments forwarded to fetch_func

        Returns:
            Dict describing the sync outcome
        """
        print(f"\n[SyncManager] Starting sync for table '{table_name}'")
        print(f"[SyncManager] Date range: {start_date} to {end_date}")

        # Decide how (and whether) to sync
        strategy, sync_start, sync_end, stocks_to_sync = (
            self.incremental_sync.get_sync_strategy(
                table_name, start_date, end_date, stock_codes
            )
        )

        if strategy == "none":
            print(f"[SyncManager] No sync needed for '{table_name}'")
            return {
                "status": "skipped",
                "table": table_name,
                "reason": "up-to-date",
            }

        print(f"[SyncManager] Fetching data with strategy '{strategy}'...")
        try:
            data = self._fetch_data(
                fetch_func, sync_start, sync_end, stocks_to_sync, fetch_kwargs
            )
        except Exception as e:
            print(f"[SyncManager] Error fetching data: {e}")
            return {
                "status": "error",
                "table": table_name,
                "error": str(e),
            }

        if data.empty:
            print("[SyncManager] No data fetched")
            return {
                "status": "no_data",
                "table": table_name,
            }

        print(f"[SyncManager] Fetched {len(data)} rows")

        # Auto-create the table from the fetched schema when missing
        if not self.table_manager.ensure_table_exists(table_name, data):
            return {
                "status": "error",
                "table": table_name,
                "error": "Failed to create table",
            }

        # Apply the incremental update
        result = self.incremental_sync.sync_data(table_name, data, strategy)
        print(f"[SyncManager] Sync complete: {result}")

        span = f"{sync_start} to {sync_end}" if sync_start and sync_end else None
        return {
            "status": result["status"],
            "table": table_name,
            "strategy": strategy,
            "rows": result.get("rows_inserted", 0),
            "date_range": span,
        }
|
||||
|
||||
|
||||
# Convenience functions
|
||||
|
||||
|
||||
def ensure_table(
    table_name: str,
    sample_data: pd.DataFrame,
    storage: Optional[Storage] = None,
) -> bool:
    """Ensure a table exists, creating it from *sample_data* if necessary.

    Args:
        table_name: Name of the table
        sample_data: Sample DataFrame that defines the schema
        storage: Optional Storage instance

    Returns:
        True if the table exists or was created
    """
    return TableManager(storage).ensure_table_exists(table_name, sample_data)
|
||||
|
||||
|
||||
def get_table_info(
    table_name: str, storage: Optional[Storage] = None
) -> Dict[str, Any]:
    """Return statistics for *table_name*.

    Args:
        table_name: Name of the table
        storage: Optional Storage instance

    Returns:
        Dict with table statistics (exists, row_count, min_date, max_date,
        unique_stocks)
    """
    return IncrementalSync(storage).get_table_stats(table_name)
|
||||
|
||||
|
||||
def sync_table(
    table_name: str,
    fetch_func: Callable[..., pd.DataFrame],
    start_date: str,
    end_date: str,
    stock_codes: Optional[List[str]] = None,
    storage: Optional[Storage] = None,
    **fetch_kwargs,
) -> Dict[str, Any]:
    """Sync data into a table with auto-creation and incremental updates.

    Main entry point for syncing data to DuckDB; delegates the full
    create/strategize/fetch/apply cycle to SyncManager.

    Args:
        table_name: Target table name
        fetch_func: Function to fetch data
        start_date: Start date (YYYYMMDD)
        end_date: End date (YYYYMMDD)
        stock_codes: Optional list of specific stocks
        storage: Optional Storage instance
        **fetch_kwargs: Additional arguments for fetch_func

    Returns:
        Dict with sync results

    Example:
        >>> from src.data.api_wrappers import get_daily
        >>> result = sync_table(
        ...     "daily",
        ...     get_daily,
        ...     "20240101",
        ...     "20240131",
        ...     stock_codes=["000001.SZ", "600000.SH"]
        ... )
    """
    coordinator = SyncManager(storage)
    return coordinator.sync(
        table_name=table_name,
        fetch_func=fetch_func,
        start_date=start_date,
        end_date=end_date,
        stock_codes=stock_codes,
        **fetch_kwargs,
    )
|
||||
@@ -1,36 +1,102 @@
|
||||
"""Simplified HDF5 storage for data persistence."""
|
||||
|
||||
import os
|
||||
"""DuckDB storage for data persistence."""
|
||||
import pandas as pd
|
||||
import polars as pl
|
||||
import duckdb
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Optional, List, Dict, Any, Tuple
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from src.data.config import get_config
|
||||
|
||||
|
||||
# Default column type mapping for automatic schema inference
|
||||
DEFAULT_TYPE_MAPPING = {
|
||||
"ts_code": "VARCHAR(16)",
|
||||
"trade_date": "DATE",
|
||||
"open": "DOUBLE",
|
||||
"high": "DOUBLE",
|
||||
"low": "DOUBLE",
|
||||
"close": "DOUBLE",
|
||||
"pre_close": "DOUBLE",
|
||||
"change": "DOUBLE",
|
||||
"pct_chg": "DOUBLE",
|
||||
"vol": "DOUBLE",
|
||||
"amount": "DOUBLE",
|
||||
"turnover_rate": "DOUBLE",
|
||||
"volume_ratio": "DOUBLE",
|
||||
"adj_factor": "DOUBLE",
|
||||
"suspend_flag": "INTEGER",
|
||||
}
|
||||
|
||||
|
||||
class Storage:
|
||||
"""HDF5 storage manager for saving and loading data."""
|
||||
"""DuckDB storage manager for saving and loading data.
|
||||
|
||||
迁移说明:
|
||||
- 保持 API 完全兼容,调用方无需修改
|
||||
- 新增 load_polars() 方法支持 Polars 零拷贝导出
|
||||
- 使用单例模式管理数据库连接
|
||||
- 并发写入通过队列管理(见 ThreadSafeStorage)
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_connection = None
|
||||
|
||||
def __new__(cls, *args, **kwargs):
    """Singleton constructor: every call yields the same shared instance."""
    instance = cls._instance
    if instance is None:
        instance = super().__new__(cls)
        cls._instance = instance
    return instance
|
||||
|
||||
def __init__(self, path: Optional[Path] = None):
    """Initialize storage (idempotent for the singleton instance).

    Args:
        path: Base path for data storage (resolved from config when omitted)
    """
    # The singleton's __init__ runs on every Storage() call; set up once.
    if hasattr(self, "_initialized"):
        return

    cfg = get_config()
    self.base_path = path or cfg.data_path_resolved
    self.base_path.mkdir(parents=True, exist_ok=True)
    # One DuckDB database file holds every table.
    self.db_path = self.base_path / "prostock.db"
    self._init_db()
    self._initialized = True
|
||||
|
||||
def _init_db(self):
    """Open the DuckDB database file and ensure the base schema exists."""
    self._connection = duckdb.connect(str(self.db_path))

    # Daily-bars table. (ts_code, trade_date) is the natural primary key
    # and is what backs the INSERT OR REPLACE upsert in save().
    self._connection.execute("""
        CREATE TABLE IF NOT EXISTS daily (
            ts_code VARCHAR(16) NOT NULL,
            trade_date DATE NOT NULL,
            open DOUBLE,
            high DOUBLE,
            low DOUBLE,
            close DOUBLE,
            pre_close DOUBLE,
            change DOUBLE,
            pct_chg DOUBLE,
            vol DOUBLE,
            amount DOUBLE,
            turnover_rate DOUBLE,
            volume_ratio DOUBLE,
            PRIMARY KEY (ts_code, trade_date)
        )
    """)

    # Composite index (trade_date, ts_code) speeds up date-range queries
    # that optionally filter by stock code.
    self._connection.execute("""
        CREATE INDEX IF NOT EXISTS idx_daily_date_code ON daily(trade_date, ts_code)
    """)
|
||||
|
||||
def save(self, name: str, data: pd.DataFrame, mode: str = "append") -> dict:
    """Save data to DuckDB.

    Args:
        name: Table name
        data: DataFrame to save
        mode: 'append' (UPSERT) or 'replace' (DELETE + INSERT)

    Returns:
        Dict with save result: {"status", "rows"} or {"status", "error"}
    """
    if data.empty:
        return {"status": "skipped", "rows": 0}

    # Normalise 'YYYYMMDD' strings to DATE before inserting; work on a
    # copy so the caller's frame is not mutated.
    if "trade_date" in data.columns:
        data = data.copy()
        data["trade_date"] = pd.to_datetime(
            data["trade_date"], format="%Y%m%d"
        ).dt.date

    # Expose the frame to SQL as a temporary view.
    self._connection.register("temp_data", data)

    try:
        if mode == "replace":
            self._connection.execute(f"DELETE FROM {name}")

        # UPSERT on the table's primary key via INSERT OR REPLACE.
        columns = ", ".join(data.columns)
        self._connection.execute(f"""
            INSERT OR REPLACE INTO {name} ({columns})
            SELECT {columns} FROM temp_data
        """)

        row_count = len(data)
        print(f"[Storage] Saved {row_count} rows to DuckDB ({name})")
        return {"status": "success", "rows": row_count}

    except Exception as e:
        print(f"[Storage] Error saving {name}: {e}")
        return {"status": "error", "error": str(e)}
    finally:
        # Always drop the temp view, even on failure.
        self._connection.unregister("temp_data")
|
||||
|
||||
def load(
    self,
    name: str,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    ts_code: Optional[str] = None,
) -> pd.DataFrame:
    """Load data from DuckDB with query pushdown.

    Filters are applied as WHERE clauses inside the database, so only
    matching rows are materialised — no full-table load into memory.

    Args:
        name: Table name
        start_date: Start date filter (YYYYMMDD)
        end_date: End date filter (YYYYMMDD)
        ts_code: Stock code filter

    Returns:
        Filtered DataFrame
    """
    # Build the WHERE clause with bound parameters (SQL-injection safe).
    conditions = []
    params = []

    if start_date and end_date:
        conditions.append("trade_date BETWEEN ? AND ?")
        params.extend([
            pd.to_datetime(start_date, format="%Y%m%d").date(),
            pd.to_datetime(end_date, format="%Y%m%d").date(),
        ])
    elif start_date:
        conditions.append("trade_date >= ?")
        params.append(pd.to_datetime(start_date, format="%Y%m%d").date())
    elif end_date:
        conditions.append("trade_date <= ?")
        params.append(pd.to_datetime(end_date, format="%Y%m%d").date())

    if ts_code:
        conditions.append("ts_code = ?")
        params.append(ts_code)

    where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
    query = f"SELECT * FROM {name} {where_clause} ORDER BY trade_date"

    try:
        result = self._connection.execute(query, params).fetchdf()

        # Render trade_date back to 'YYYYMMDD' strings for API
        # compatibility with the previous storage layer.
        if "trade_date" in result.columns:
            result["trade_date"] = result["trade_date"].dt.strftime("%Y%m%d")

        return result
    except Exception as e:
        print(f"[Storage] Error loading {name}: {e}")
        return pd.DataFrame()
|
||||
|
||||
def get_last_date(self, name: str) -> Optional[str]:
|
||||
"""Get the latest date in storage.
|
||||
def load_polars(
    self,
    name: str,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    ts_code: Optional[str] = None,
) -> pl.DataFrame:
    """Load data as a Polars DataFrame (for DataLoader).

    Uses DuckDB's Arrow-backed Polars export (requires pyarrow) for a
    near zero-copy handoff.

    Args:
        name: Table name
        start_date: Start date filter (YYYYMMDD)
        end_date: End date filter (YYYYMMDD)
        ts_code: Stock code filter

    Returns:
        Filtered Polars DataFrame
    """
    # Build the WHERE clause with bound parameters. The previous version
    # interpolated filter values directly into the SQL string, which was
    # inconsistent with load() and open to SQL injection.
    conditions = []
    params = []
    if start_date and end_date:
        conditions.append("trade_date BETWEEN ? AND ?")
        params.append(pd.to_datetime(start_date, format="%Y%m%d").date())
        params.append(pd.to_datetime(end_date, format="%Y%m%d").date())
    if ts_code:
        conditions.append("ts_code = ?")
        params.append(ts_code)

    where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
    query = f"SELECT * FROM {name} {where_clause} ORDER BY trade_date"

    # execute(...).pl() exports the result via Arrow (needs pyarrow).
    df = self._connection.execute(query, params).pl()

    # Render trade_date as 'YYYYMMDD' strings to stay compatible with
    # the pandas-returning load().
    if "trade_date" in df.columns:
        df = df.with_columns(
            pl.col("trade_date").dt.strftime("%Y%m%d").alias("trade_date")
        )

    return df
|
||||
|
||||
def exists(self, name: str) -> bool:
    """Return True if a table called *name* exists in the database."""
    hits = self._connection.execute(
        """
        SELECT COUNT(*) FROM information_schema.tables
        WHERE table_name = ?
        """,
        [name],
    ).fetchone()
    return hits[0] > 0
|
||||
|
||||
def delete(self, name: str) -> bool:
    """Drop table *name*; return True on success, False on any error."""
    try:
        self._connection.execute(f"DROP TABLE IF EXISTS {name}")
    except Exception as e:
        print(f"[Storage] Error deleting {name}: {e}")
        return False
    print(f"[Storage] Deleted table {name}")
    return True
|
||||
|
||||
def get_last_date(self, name: str) -> Optional[str]:
    """Return the latest trade_date in table *name* as 'YYYYMMDD'.

    Best-effort: returns None when the table is missing or empty.

    Args:
        name: Table name

    Returns:
        Latest date string or None
    """
    try:
        result = self._connection.execute(f"""
            SELECT MAX(trade_date) FROM {name}
        """).fetchone()
        if result[0]:
            # MAX() comes back as a date object; format it, falling back
            # to str() for non-date values.
            return (
                result[0].strftime("%Y%m%d")
                if hasattr(result[0], "strftime")
                else str(result[0])
            )
        return None
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # are no longer swallowed; a missing table still yields None.
        return None
|
||||
|
||||
def close(self):
    """Close the shared DuckDB connection and reset the singleton state."""
    conn = self._connection
    if conn:
        conn.close()
        # Clear class-level state so the next Storage() re-opens the DB.
        Storage._connection = None
        Storage._instance = None
|
||||
|
||||
|
||||
class ThreadSafeStorage:
    """Thread-safe write wrapper around the DuckDB ``Storage`` singleton.

    DuckDB does not support concurrent writers, so write requests are
    collected in a queue and flushed in a single batch at the end of a
    sync run.
    """

    def __init__(self):
        self.storage = Storage()
        # Pending (table_name, DataFrame) pairs awaiting flush().
        self._pending_writes: List[Tuple[str, pd.DataFrame]] = []

    def queue_save(self, name: str, data: pd.DataFrame):
        """Queue *data* for a later batched write (no immediate I/O)."""
        if not data.empty:
            self._pending_writes.append((name, data))

    def flush(self):
        """Write all queued data in one batch.

        Call once at the end of a sync run so writes never happen
        concurrently.
        """
        if not self._pending_writes:
            return

        # Group queued frames by target table.
        table_data = defaultdict(list)
        for name, data in self._pending_writes:
            table_data[name].append(data)

        # Batch-write each table, de-duplicating on the primary key first.
        for name, data_list in table_data.items():
            combined = pd.concat(data_list, ignore_index=True)
            if "ts_code" in combined.columns and "trade_date" in combined.columns:
                combined = combined.drop_duplicates(
                    subset=["ts_code", "trade_date"], keep="last"
                )
            self.storage.save(name, combined, mode="append")

        self._pending_writes.clear()

    def __getattr__(self, name):
        """Proxy unknown attributes to the wrapped Storage instance."""
        # Guard: if 'storage' itself is missing (e.g. during unpickling,
        # before __init__ runs), delegating would recurse infinitely.
        if name == "storage":
            raise AttributeError(name)
        return getattr(self.storage, name)
|
||||
|
||||
@@ -36,7 +36,7 @@ import threading
|
||||
import sys
|
||||
|
||||
from src.data.client import TushareClient
|
||||
from src.data.storage import Storage
|
||||
from src.data.storage import ThreadSafeStorage
|
||||
from src.data.api_wrappers import get_daily
|
||||
from src.data.api_wrappers import (
|
||||
get_first_trading_day,
|
||||
@@ -83,7 +83,7 @@ class DataSync:
|
||||
Args:
|
||||
max_workers: Number of worker threads (default: 10)
|
||||
"""
|
||||
self.storage = Storage()
|
||||
self.storage = ThreadSafeStorage()
|
||||
self.client = TushareClient()
|
||||
self.max_workers = max_workers or self.DEFAULT_MAX_WORKERS
|
||||
self._stop_flag = threading.Event()
|
||||
@@ -667,11 +667,15 @@ class DataSync:
|
||||
finally:
|
||||
pbar.close()
|
||||
|
||||
# Write all data at once (only if no error)
|
||||
# Queue all data for batch write (only if no error)
|
||||
if results and not error_occurred:
|
||||
combined_data = pd.concat(results.values(), ignore_index=True)
|
||||
self.storage.save("daily", combined_data, mode="append")
|
||||
print(f"\n[DataSync] Saved {len(combined_data)} rows to storage")
|
||||
for ts_code, data in results.items():
|
||||
if not data.empty:
|
||||
self.storage.queue_save("daily", data)
|
||||
# Flush all queued writes at once
|
||||
self.storage.flush()
|
||||
total_rows = sum(len(df) for df in results.values())
|
||||
print(f"\n[DataSync] Saved {total_rows} rows to storage")
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 60)
|
||||
|
||||
Reference in New Issue
Block a user