feat: HDF5迁移至DuckDB存储
- 新增DuckDB Storage与ThreadSafeStorage实现 - 新增db_manager模块支持增量同步策略 - DataLoader与Sync模块适配DuckDB - 补充迁移相关文档与测试 - 修复README文档链接
This commit is contained in:
@@ -1,36 +1,102 @@
|
||||
"""Simplified HDF5 storage for data persistence."""
|
||||
|
||||
import os
|
||||
"""DuckDB storage for data persistence."""
|
||||
import pandas as pd
|
||||
import polars as pl
|
||||
import duckdb
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Optional, List, Dict, Any, Tuple
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from src.data.config import get_config
|
||||
|
||||
|
||||
# Default column type mapping for automatic schema inference.
# Key columns (ts_code, trade_date) and suspend_flag have dedicated types;
# every price/volume style field is stored as DOUBLE.
_DOUBLE_COLUMNS = (
    "open",
    "high",
    "low",
    "close",
    "pre_close",
    "change",
    "pct_chg",
    "vol",
    "amount",
    "turnover_rate",
    "volume_ratio",
    "adj_factor",
)

DEFAULT_TYPE_MAPPING = {
    "ts_code": "VARCHAR(16)",
    "trade_date": "DATE",
    **{column: "DOUBLE" for column in _DOUBLE_COLUMNS},
    "suspend_flag": "INTEGER",
}
||||
|
||||
|
||||
class Storage:
    """DuckDB storage manager for saving and loading data.

    Migration notes (HDF5 -> DuckDB):
    - The public API is kept fully compatible; callers need no changes.
    - New load_polars() supports zero-copy export to Polars.
    - A singleton manages the single database connection.
    - Concurrent writes are funneled through a queue (see ThreadSafeStorage).
    """

    _instance = None
    _connection = None

    def __new__(cls, *args, **kwargs):
        """Singleton to ensure a single connection."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, path: Optional[Path] = None):
        """Initialize storage.

        Args:
            path: Base path for data storage (auto-loaded from config if not provided)
        """
        # __new__ always returns the same instance; skip re-initialization.
        if hasattr(self, "_initialized"):
            return
        cfg = get_config()
        self.base_path = path or cfg.data_path_resolved
        self.base_path.mkdir(parents=True, exist_ok=True)
        self.db_path = self.base_path / "prostock.db"
        self._init_db()
        self._initialized = True

    @staticmethod
    def _safe_table(name: str) -> str:
        """Validate *name* as a plain SQL identifier and return it.

        Table names cannot be bound as query parameters, so they must be
        validated before being interpolated into SQL text (prevents SQL
        injection through the *name* argument).

        Raises:
            ValueError: if *name* is not a valid identifier.
        """
        if not name.isidentifier():
            raise ValueError(f"Invalid table name: {name!r}")
        return name

    @staticmethod
    def _build_where(
        start_date: Optional[str],
        end_date: Optional[str],
        ts_code: Optional[str],
    ) -> Tuple[str, List[Any]]:
        """Build a parameterized WHERE clause from the optional filters.

        YYYYMMDD date strings are converted to ``date`` objects so they
        compare correctly against the DATE column.

        Returns:
            (where_clause, params) — where_clause is "" when no filter is set.
        """
        conditions: List[str] = []
        params: List[Any] = []

        if start_date and end_date:
            conditions.append("trade_date BETWEEN ? AND ?")
            params.append(pd.to_datetime(start_date, format="%Y%m%d").date())
            params.append(pd.to_datetime(end_date, format="%Y%m%d").date())
        elif start_date:
            conditions.append("trade_date >= ?")
            params.append(pd.to_datetime(start_date, format="%Y%m%d").date())
        elif end_date:
            conditions.append("trade_date <= ?")
            params.append(pd.to_datetime(end_date, format="%Y%m%d").date())

        if ts_code:
            conditions.append("ts_code = ?")
            params.append(ts_code)

        where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
        return where_clause, params

    def _init_db(self):
        """Initialize database connection and schema."""
        self._connection = duckdb.connect(str(self.db_path))

        # Create tables with schema validation
        self._connection.execute("""
            CREATE TABLE IF NOT EXISTS daily (
                ts_code VARCHAR(16) NOT NULL,
                trade_date DATE NOT NULL,
                open DOUBLE,
                high DOUBLE,
                low DOUBLE,
                close DOUBLE,
                pre_close DOUBLE,
                change DOUBLE,
                pct_chg DOUBLE,
                vol DOUBLE,
                amount DOUBLE,
                turnover_rate DOUBLE,
                volume_ratio DOUBLE,
                PRIMARY KEY (ts_code, trade_date)
            )
        """)

        # Create composite index for query optimization (trade_date, ts_code)
        self._connection.execute("""
            CREATE INDEX IF NOT EXISTS idx_daily_date_code ON daily(trade_date, ts_code)
        """)

    def save(self, name: str, data: pd.DataFrame, mode: str = "append") -> dict:
        """Save data to DuckDB.

        Args:
            name: Table name
            data: DataFrame to save
            mode: 'append' (UPSERT) or 'replace' (DELETE + INSERT)

        Returns:
            Dict with save result
        """
        if data.empty:
            return {"status": "skipped", "rows": 0}

        # Reject names that cannot be safely interpolated into SQL.
        if not name.isidentifier():
            return {"status": "error", "error": f"Invalid table name: {name!r}"}

        # Ensure date column is proper type
        if "trade_date" in data.columns:
            data = data.copy()
            data["trade_date"] = pd.to_datetime(
                data["trade_date"], format="%Y%m%d"
            ).dt.date

        # Register DataFrame as temporary view
        self._connection.register("temp_data", data)

        try:
            if mode == "replace":
                self._connection.execute(f"DELETE FROM {name}")

            # UPSERT: INSERT OR REPLACE on the (ts_code, trade_date) primary key.
            # Column names come from the internal DataFrame, not user input.
            columns = ", ".join(data.columns)
            self._connection.execute(f"""
                INSERT OR REPLACE INTO {name} ({columns})
                SELECT {columns} FROM temp_data
            """)

            row_count = len(data)
            print(f"[Storage] Saved {row_count} rows to DuckDB ({name})")
            return {"status": "success", "rows": row_count}

        except Exception as e:
            print(f"[Storage] Error saving {name}: {e}")
            return {"status": "error", "error": str(e)}
        finally:
            self._connection.unregister("temp_data")

    def load(
        self,
        name: str,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        ts_code: Optional[str] = None,
    ) -> pd.DataFrame:
        """Load data from DuckDB with query pushdown.

        Key optimization: the WHERE filters run inside the database, so
        only matching rows are materialized — no full-table load.

        Args:
            name: Table name
            start_date: Start date filter (YYYYMMDD)
            end_date: End date filter (YYYYMMDD)
            ts_code: Stock code filter

        Returns:
            Filtered DataFrame
        """
        try:
            table = self._safe_table(name)
            # Build WHERE clause with parameterized queries
            where_clause, params = self._build_where(start_date, end_date, ts_code)
            query = f"SELECT * FROM {table} {where_clause} ORDER BY trade_date"

            # Execute query with parameters (SQL injection safe)
            result = self._connection.execute(query, params).fetchdf()

            # Convert trade_date back to string format for compatibility
            if "trade_date" in result.columns:
                result["trade_date"] = result["trade_date"].dt.strftime("%Y%m%d")

            return result
        except Exception as e:
            print(f"[Storage] Error loading {name}: {e}")
            return pd.DataFrame()

    def load_polars(
        self,
        name: str,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        ts_code: Optional[str] = None,
    ) -> "pl.DataFrame":
        """Load data as a Polars DataFrame (for DataLoader).

        Performance:
        - Zero-copy export (DuckDB -> Polars via PyArrow)
        - Requires pyarrow

        Args:
            name: Table name
            start_date: Start date filter (YYYYMMDD)
            end_date: End date filter (YYYYMMDD)
            ts_code: Stock code filter
        """
        table = self._safe_table(name)
        # Same parameterized filters as load(): fixes the previous
        # string-interpolated WHERE clause and adds one-sided date ranges.
        where_clause, params = self._build_where(start_date, end_date, ts_code)
        query = f"SELECT * FROM {table} {where_clause} ORDER BY trade_date"

        # DuckDB Polars export (requires pyarrow)
        df = self._connection.execute(query, params).pl()

        # Convert trade_date back to string format for compatibility
        if "trade_date" in df.columns:
            df = df.with_columns(
                pl.col("trade_date").dt.strftime("%Y%m%d").alias("trade_date")
            )

        return df

    def exists(self, name: str) -> bool:
        """Check if table exists."""
        result = self._connection.execute(
            """
            SELECT COUNT(*) FROM information_schema.tables
            WHERE table_name = ?
            """,
            [name],
        ).fetchone()
        return result[0] > 0

    def delete(self, name: str) -> bool:
        """Delete a table.

        Args:
            name: Table name

        Returns:
            True if deleted
        """
        try:
            table = self._safe_table(name)
            self._connection.execute(f"DROP TABLE IF EXISTS {table}")
            print(f"[Storage] Deleted table {name}")
            return True
        except Exception as e:
            print(f"[Storage] Error deleting {name}: {e}")
            return False

    def get_last_date(self, name: str) -> Optional[str]:
        """Get the latest trade_date in a table as YYYYMMDD, or None."""
        try:
            table = self._safe_table(name)
            result = self._connection.execute(
                f"SELECT MAX(trade_date) FROM {table}"
            ).fetchone()
            if result[0]:
                # Convert date back to string format
                return (
                    result[0].strftime("%Y%m%d")
                    if hasattr(result[0], "strftime")
                    else str(result[0])
                )
            return None
        # Narrowed from a bare `except:`; still best-effort (missing table -> None).
        except Exception:
            return None

    def close(self):
        """Close database connection and reset the singleton."""
        if self._connection:
            self._connection.close()
            Storage._connection = None
            Storage._instance = None
||||
|
||||
|
||||
class ThreadSafeStorage:
    """Thread-safe write wrapper around the DuckDB Storage.

    DuckDB does not support concurrent writers, so write requests are
    collected in a queue and written out in a single batch when the
    sync run finishes.
    """

    def __init__(self):
        self.storage = Storage()
        # Queued (table_name, DataFrame) pairs awaiting flush().
        self._pending_writes: List[tuple] = []

    def queue_save(self, name: str, data: pd.DataFrame):
        """Queue *data* for a later batched write (nothing is written yet)."""
        if data.empty:
            return
        self._pending_writes.append((name, data))

    def flush(self):
        """Write all queued data in one batch.

        Call once at the end of a sync run so that no two threads ever
        write to DuckDB concurrently.
        """
        if not self._pending_writes:
            return

        # Group the queued frames by their target table.
        grouped = defaultdict(list)
        for table, frame in self._pending_writes:
            grouped[table].append(frame)

        # One combined write per table.
        for table, frames in grouped.items():
            merged = pd.concat(frames, ignore_index=True)
            # Deduplicate on the primary key before the UPSERT.
            if {"ts_code", "trade_date"}.issubset(merged.columns):
                merged = merged.drop_duplicates(
                    subset=["ts_code", "trade_date"], keep="last"
                )
            self.storage.save(table, merged, mode="append")

        self._pending_writes.clear()

    def __getattr__(self, name):
        """Delegate every other attribute to the wrapped Storage instance."""
        return getattr(self.storage, name)
|
||||
|
||||
Reference in New Issue
Block a user