feat: migrate HDF5 storage to DuckDB

- Add DuckDB Storage and ThreadSafeStorage implementations
- Add db_manager module with support for incremental sync strategies
- Adapt DataLoader and Sync modules to DuckDB
- Add migration docs and tests
- Fix README documentation links
2026-02-23 00:07:21 +08:00
parent 0a16129548
commit e58b39970c
14 changed files with 2265 additions and 329 deletions
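
The public call pattern is unchanged by the migration; a minimal sketch of the
intended usage (the "daily" table matches the schema created below; the sync
flow itself lives in db_manager and is only hinted at here):

    storage = Storage()
    storage.save("daily", df)                     # UPSERT on (ts_code, trade_date)
    frame = storage.load("daily", start_date="20240101", end_date="20240131")
    last = storage.get_last_date("daily")         # high-water mark for incremental sync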


"""DuckDB storage for data persistence."""
import duckdb
import pandas as pd
import polars as pl
from collections import defaultdict
from pathlib import Path
from typing import List, Optional
from src.data.config import get_config
# Default column type mapping for automatic schema inference
DEFAULT_TYPE_MAPPING = {
"ts_code": "VARCHAR(16)",
"trade_date": "DATE",
"open": "DOUBLE",
"high": "DOUBLE",
"low": "DOUBLE",
"close": "DOUBLE",
"pre_close": "DOUBLE",
"change": "DOUBLE",
"pct_chg": "DOUBLE",
"vol": "DOUBLE",
"amount": "DOUBLE",
"turnover_rate": "DOUBLE",
"volume_ratio": "DOUBLE",
"adj_factor": "DOUBLE",
"suspend_flag": "INTEGER",
}
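
# Illustrative sketch only: a hypothetical helper showing how the mapping above
# could drive schema inference for new datasets (unknown columns fall back to DOUBLE):
#
#     def _create_table_from_frame(con, name: str, df: pd.DataFrame) -> None:
#         cols = ", ".join(
#             f'"{c}" {DEFAULT_TYPE_MAPPING.get(c, "DOUBLE")}' for c in df.columns
#         )
#         con.execute(f"CREATE TABLE IF NOT EXISTS {name} ({cols})")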

class Storage:
    """DuckDB storage manager for saving and loading data.

    Migration notes:
    - The public API is fully compatible; existing callers need no changes.
    - New load_polars() method supports zero-copy export to Polars.
    - A singleton manages the single database connection.
    - Concurrent writes are queued and batched (see ThreadSafeStorage).
    """
_instance = None
_connection = None
def __new__(cls, *args, **kwargs):
"""Singleton to ensure single connection."""
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
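    # Callers can construct Storage() freely: every call returns the same
    # instance and shares one DuckDB connection, i.e. `Storage() is Storage()`.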
    def __init__(self, path: Optional[Path] = None):
        """Initialize storage.

        Args:
            path: Base path for data storage (auto-loaded from config if not provided)
        """
        if hasattr(self, "_initialized"):
            return
        cfg = get_config()
        self.base_path = path or cfg.data_path_resolved
        self.base_path.mkdir(parents=True, exist_ok=True)
        self.db_path = self.base_path / "prostock.db"
        self._init_db()
        self._initialized = True
def _init_db(self):
"""Initialize database connection and schema."""
self._connection = duckdb.connect(str(self.db_path))
# Create tables with schema validation
self._connection.execute("""
CREATE TABLE IF NOT EXISTS daily (
ts_code VARCHAR(16) NOT NULL,
trade_date DATE NOT NULL,
open DOUBLE,
high DOUBLE,
low DOUBLE,
close DOUBLE,
pre_close DOUBLE,
change DOUBLE,
pct_chg DOUBLE,
vol DOUBLE,
amount DOUBLE,
turnover_rate DOUBLE,
volume_ratio DOUBLE,
PRIMARY KEY (ts_code, trade_date)
)
""")
# Create composite index for query optimization (trade_date, ts_code)
self._connection.execute("""
CREATE INDEX IF NOT EXISTS idx_daily_date_code ON daily(trade_date, ts_code)
""")
    def save(self, name: str, data: pd.DataFrame, mode: str = "append") -> dict:
        """Save data to DuckDB.

        Args:
            name: Table name
            data: DataFrame to save
            mode: 'append' (UPSERT) or 'replace' (DELETE + INSERT)

        Returns:
            Dict with save result
        """
        if data.empty:
            return {"status": "skipped", "rows": 0}
        # Ensure date column is proper type
        if "trade_date" in data.columns:
            data = data.copy()
            data["trade_date"] = pd.to_datetime(
                data["trade_date"], format="%Y%m%d"
            ).dt.date
        # Register DataFrame as temporary view
        self._connection.register("temp_data", data)
        try:
            if mode == "replace":
                self._connection.execute(f"DELETE FROM {name}")
            # UPSERT: INSERT OR REPLACE
            columns = ", ".join(data.columns)
            self._connection.execute(f"""
                INSERT OR REPLACE INTO {name} ({columns})
                SELECT {columns} FROM temp_data
            """)
            row_count = len(data)
            print(f"[Storage] Saved {row_count} rows to DuckDB ({name})")
            return {"status": "success", "rows": row_count}
        except Exception as e:
            print(f"[Storage] Error saving {name}: {e}")
            return {"status": "error", "error": str(e)}
        finally:
            self._connection.unregister("temp_data")
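    # Example (hypothetical frame df): both modes are idempotent on the
    # (ts_code, trade_date) key:
    #   Storage().save("daily", df)                   # incremental UPSERT
    #   Storage().save("daily", df, mode="replace")   # full rewrite of the table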
    def load(
        self,
        name: str,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        ts_code: Optional[str] = None,
    ) -> pd.DataFrame:
        """Load data from DuckDB with query pushdown.

        Key optimization:
        - WHERE conditions are evaluated inside the database, so the full
          table never has to be loaded.
        - Only matching rows are returned, greatly reducing memory usage.

        Args:
            name: Table name
            start_date: Start date filter (YYYYMMDD)
            end_date: End date filter (YYYYMMDD)
            ts_code: Stock code filter

        Returns:
            Filtered DataFrame
        """
        # Build WHERE clause with parameterized queries
        conditions = []
        params = []
        if start_date and end_date:
            conditions.append("trade_date BETWEEN ? AND ?")
            # Convert to DATE type
            start = pd.to_datetime(start_date, format="%Y%m%d").date()
            end = pd.to_datetime(end_date, format="%Y%m%d").date()
            params.extend([start, end])
        elif start_date:
            conditions.append("trade_date >= ?")
            params.append(pd.to_datetime(start_date, format="%Y%m%d").date())
        elif end_date:
            conditions.append("trade_date <= ?")
            params.append(pd.to_datetime(end_date, format="%Y%m%d").date())
        if ts_code:
            conditions.append("ts_code = ?")
            params.append(ts_code)
        where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
        query = f"SELECT * FROM {name} {where_clause} ORDER BY trade_date"
        try:
            # Execute query with parameters (SQL injection safe)
            result = self._connection.execute(query, params).fetchdf()
            # Convert trade_date back to string format for compatibility
            if "trade_date" in result.columns:
                result["trade_date"] = result["trade_date"].dt.strftime("%Y%m%d")
            return result
        except Exception as e:
            print(f"[Storage] Error loading {name}: {e}")
            return pd.DataFrame()
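    # The pushdown means a one-stock, one-month request like
    #   Storage().load("daily", start_date="20240101", end_date="20240131",
    #                  ts_code="000001.SZ")
    # transfers only the matching rows instead of materializing the whole table.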
    def load_polars(
        self,
        name: str,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        ts_code: Optional[str] = None,
    ) -> pl.DataFrame:
        """Load data as a Polars DataFrame (for the DataLoader).

        Performance notes:
        - Zero-copy export: DuckDB -> Polars via PyArrow
        - Requires pyarrow to be installed
        """
        # Build WHERE clause (parameterized, mirroring load())
        conditions = []
        params = []
        if start_date and end_date:
            conditions.append("trade_date BETWEEN ? AND ?")
            params.append(pd.to_datetime(start_date, format="%Y%m%d").date())
            params.append(pd.to_datetime(end_date, format="%Y%m%d").date())
        if ts_code:
            conditions.append("ts_code = ?")
            params.append(ts_code)
        where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
        query = f"SELECT * FROM {name} {where_clause} ORDER BY trade_date"
        # Export via DuckDB's Polars integration (requires pyarrow)
        df = self._connection.execute(query, params).pl()
        # Convert trade_date back to string format for compatibility
        if "trade_date" in df.columns:
            df = df.with_columns(
                pl.col("trade_date").dt.strftime("%Y%m%d").alias("trade_date")
            )
        return df
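    # Consumers that feed Polars pipelines (e.g. the DataLoader) can switch from
    # load() to load_polars() with the same filter arguments:
    #   pl_df = Storage().load_polars("daily", ts_code="000001.SZ")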
    def exists(self, name: str) -> bool:
"""Check if table exists."""
result = self._connection.execute(
"""
SELECT COUNT(*) FROM information_schema.tables
WHERE table_name = ?
""",
[name],
).fetchone()
return result[0] > 0
    def delete(self, name: str) -> bool:
        """Delete a table.

        Args:
            name: Table name

        Returns:
            True if deleted
        """
        try:
            self._connection.execute(f"DROP TABLE IF EXISTS {name}")
            print(f"[Storage] Deleted table {name}")
            return True
        except Exception as e:
            print(f"[Storage] Error deleting {name}: {e}")
            return False
    def get_last_date(self, name: str) -> Optional[str]:
        """Get the latest date in storage.

        Args:
            name: Table name

        Returns:
            Latest date string (YYYYMMDD) or None
        """
        try:
            result = self._connection.execute(f"""
                SELECT MAX(trade_date) FROM {name}
            """).fetchone()
            if result[0]:
                # Convert date back to string format
                return (
                    result[0].strftime("%Y%m%d")
                    if hasattr(result[0], "strftime")
                    else str(result[0])
                )
            return None
        except Exception:
            return None
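    # Sketch of the incremental-sync loop this enables (fetch_daily is a
    # hypothetical API client, not defined in this module):
    #   last = storage.get_last_date("daily")
    #   fresh = fetch_daily(start_date=last)   # refetch from the high-water mark
    #   storage.save("daily", fresh)           # UPSERT keeps the overlap idempotent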
def close(self):
"""Close database connection."""
if self._connection:
self._connection.close()
Storage._connection = None
Storage._instance = None

class ThreadSafeStorage:
    """Thread-safe write wrapper for DuckDB.

    DuckDB does not support concurrent writers, so write requests are
    collected in a queue and flushed as one batch when the sync finishes.
    """

    def __init__(self):
        self.storage = Storage()
        self._pending_writes: List[tuple] = []  # [(name, data), ...]
    def queue_save(self, name: str, data: pd.DataFrame):
        """Queue data for writing (does not write immediately)."""
        if not data.empty:
            self._pending_writes.append((name, data))
    def flush(self):
        """Write all queued data in one batch.

        Call once at the end of a sync run to avoid concurrent-write conflicts.
        """
        if not self._pending_writes:
            return
        # Group queued frames by target table
        table_data = defaultdict(list)
        for name, data in self._pending_writes:
            table_data[name].append(data)
        # Batch-write each table
        for name, data_list in table_data.items():
            combined = pd.concat(data_list, ignore_index=True)
            # Deduplicate within the batch first
            if "ts_code" in combined.columns and "trade_date" in combined.columns:
                combined = combined.drop_duplicates(
                    subset=["ts_code", "trade_date"], keep="last"
                )
            self.storage.save(name, combined, mode="append")
        self._pending_writes.clear()
    def __getattr__(self, name):
        """Proxy all other methods to the underlying Storage instance."""
        return getattr(self.storage, name)
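
# Usage sketch (fetch_one and codes are hypothetical; ThreadPoolExecutor is from
# concurrent.futures): workers only buffer, and a single flush() does the writing.
#
#     storage = ThreadSafeStorage()
#     with ThreadPoolExecutor(max_workers=8) as pool:
#         for df in pool.map(fetch_one, codes):
#             storage.queue_save("daily", df)   # no DB contention here
#     storage.flush()                           # one batched, deduplicated write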