2026-02-23 00:07:21 +08:00
|
|
|
|
"""DuckDB storage for data persistence."""
|
2026-01-31 03:04:51 +08:00
|
|
|
|
import re
import threading
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import Optional, List, Dict, Any, Tuple

import duckdb
import pandas as pd
import polars as pl

from src.data.config import get_config
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-02-23 00:07:21 +08:00
|
|
|
|
# Default column type mapping for automatic schema inference
|
|
|
|
|
|
DEFAULT_TYPE_MAPPING = {
|
|
|
|
|
|
"ts_code": "VARCHAR(16)",
|
|
|
|
|
|
"trade_date": "DATE",
|
|
|
|
|
|
"open": "DOUBLE",
|
|
|
|
|
|
"high": "DOUBLE",
|
|
|
|
|
|
"low": "DOUBLE",
|
|
|
|
|
|
"close": "DOUBLE",
|
|
|
|
|
|
"pre_close": "DOUBLE",
|
|
|
|
|
|
"change": "DOUBLE",
|
|
|
|
|
|
"pct_chg": "DOUBLE",
|
|
|
|
|
|
"vol": "DOUBLE",
|
|
|
|
|
|
"amount": "DOUBLE",
|
|
|
|
|
|
"turnover_rate": "DOUBLE",
|
|
|
|
|
|
"volume_ratio": "DOUBLE",
|
|
|
|
|
|
"adj_factor": "DOUBLE",
|
|
|
|
|
|
"suspend_flag": "INTEGER",
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-01-31 03:04:51 +08:00
|
|
|
|
class Storage:
|
2026-02-23 00:07:21 +08:00
|
|
|
|
"""DuckDB storage manager for saving and loading data.
|
|
|
|
|
|
|
|
|
|
|
|
迁移说明:
|
|
|
|
|
|
- 保持 API 完全兼容,调用方无需修改
|
|
|
|
|
|
- 新增 load_polars() 方法支持 Polars 零拷贝导出
|
|
|
|
|
|
- 使用单例模式管理数据库连接
|
|
|
|
|
|
- 并发写入通过队列管理(见 ThreadSafeStorage)
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
_instance = None
|
|
|
|
|
|
_connection = None
|
|
|
|
|
|
|
|
|
|
|
|
def __new__(cls, *args, **kwargs):
|
|
|
|
|
|
"""Singleton to ensure single connection."""
|
|
|
|
|
|
if cls._instance is None:
|
|
|
|
|
|
cls._instance = super().__new__(cls)
|
|
|
|
|
|
return cls._instance
|
2026-01-31 03:04:51 +08:00
|
|
|
|
|
|
|
|
|
|
def __init__(self, path: Optional[Path] = None):
|
2026-02-23 00:07:21 +08:00
|
|
|
|
"""Initialize storage."""
|
|
|
|
|
|
if hasattr(self, "_initialized"):
|
|
|
|
|
|
return
|
2026-01-31 03:04:51 +08:00
|
|
|
|
|
|
|
|
|
|
cfg = get_config()
|
2026-01-31 04:30:29 +08:00
|
|
|
|
self.base_path = path or cfg.data_path_resolved
|
2026-01-31 03:04:51 +08:00
|
|
|
|
self.base_path.mkdir(parents=True, exist_ok=True)
|
2026-02-23 00:07:21 +08:00
|
|
|
|
self.db_path = self.base_path / "prostock.db"
|
|
|
|
|
|
|
|
|
|
|
|
self._init_db()
|
|
|
|
|
|
self._initialized = True
|
|
|
|
|
|
|
|
|
|
|
|
def _init_db(self):
|
|
|
|
|
|
"""Initialize database connection and schema."""
|
|
|
|
|
|
self._connection = duckdb.connect(str(self.db_path))
|
|
|
|
|
|
|
|
|
|
|
|
# Create tables with schema validation
|
|
|
|
|
|
self._connection.execute("""
|
|
|
|
|
|
CREATE TABLE IF NOT EXISTS daily (
|
|
|
|
|
|
ts_code VARCHAR(16) NOT NULL,
|
|
|
|
|
|
trade_date DATE NOT NULL,
|
|
|
|
|
|
open DOUBLE,
|
|
|
|
|
|
high DOUBLE,
|
|
|
|
|
|
low DOUBLE,
|
|
|
|
|
|
close DOUBLE,
|
|
|
|
|
|
pre_close DOUBLE,
|
|
|
|
|
|
change DOUBLE,
|
|
|
|
|
|
pct_chg DOUBLE,
|
|
|
|
|
|
vol DOUBLE,
|
|
|
|
|
|
amount DOUBLE,
|
|
|
|
|
|
turnover_rate DOUBLE,
|
|
|
|
|
|
volume_ratio DOUBLE,
|
|
|
|
|
|
PRIMARY KEY (ts_code, trade_date)
|
|
|
|
|
|
)
|
|
|
|
|
|
""")
|
2026-01-31 03:04:51 +08:00
|
|
|
|
|
2026-02-23 00:07:21 +08:00
|
|
|
|
# Create composite index for query optimization (trade_date, ts_code)
|
|
|
|
|
|
self._connection.execute("""
|
|
|
|
|
|
CREATE INDEX IF NOT EXISTS idx_daily_date_code ON daily(trade_date, ts_code)
|
|
|
|
|
|
""")
|
2026-01-31 03:04:51 +08:00
|
|
|
|
|
|
|
|
|
|
def save(self, name: str, data: pd.DataFrame, mode: str = "append") -> dict:
|
2026-02-23 00:07:21 +08:00
|
|
|
|
"""Save data to DuckDB.
|
2026-01-31 03:04:51 +08:00
|
|
|
|
|
|
|
|
|
|
Args:
|
2026-02-23 00:07:21 +08:00
|
|
|
|
name: Table name
|
2026-01-31 03:04:51 +08:00
|
|
|
|
data: DataFrame to save
|
2026-02-23 00:07:21 +08:00
|
|
|
|
mode: 'append' (UPSERT) or 'replace' (DELETE + INSERT)
|
2026-01-31 03:04:51 +08:00
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
Dict with save result
|
|
|
|
|
|
"""
|
|
|
|
|
|
if data.empty:
|
|
|
|
|
|
return {"status": "skipped", "rows": 0}
|
|
|
|
|
|
|
2026-02-23 00:07:21 +08:00
|
|
|
|
# Ensure date column is proper type
|
|
|
|
|
|
if "trade_date" in data.columns:
|
|
|
|
|
|
data = data.copy()
|
|
|
|
|
|
data["trade_date"] = pd.to_datetime(
|
|
|
|
|
|
data["trade_date"], format="%Y%m%d"
|
|
|
|
|
|
).dt.date
|
|
|
|
|
|
|
|
|
|
|
|
# Register DataFrame as temporary view
|
|
|
|
|
|
self._connection.register("temp_data", data)
|
2026-01-31 03:04:51 +08:00
|
|
|
|
|
|
|
|
|
|
try:
|
2026-02-23 00:07:21 +08:00
|
|
|
|
if mode == "replace":
|
|
|
|
|
|
self._connection.execute(f"DELETE FROM {name}")
|
|
|
|
|
|
|
|
|
|
|
|
# UPSERT: INSERT OR REPLACE
|
|
|
|
|
|
columns = ", ".join(data.columns)
|
|
|
|
|
|
self._connection.execute(f"""
|
|
|
|
|
|
INSERT OR REPLACE INTO {name} ({columns})
|
|
|
|
|
|
SELECT {columns} FROM temp_data
|
|
|
|
|
|
""")
|
|
|
|
|
|
|
|
|
|
|
|
row_count = len(data)
|
|
|
|
|
|
print(f"[Storage] Saved {row_count} rows to DuckDB ({name})")
|
|
|
|
|
|
return {"status": "success", "rows": row_count}
|
2026-01-31 03:04:51 +08:00
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
print(f"[Storage] Error saving {name}: {e}")
|
|
|
|
|
|
return {"status": "error", "error": str(e)}
|
2026-02-23 00:07:21 +08:00
|
|
|
|
finally:
|
|
|
|
|
|
self._connection.unregister("temp_data")
|
2026-01-31 03:04:51 +08:00
|
|
|
|
|
2026-02-21 03:43:30 +08:00
|
|
|
|
def load(
|
|
|
|
|
|
self,
|
|
|
|
|
|
name: str,
|
|
|
|
|
|
start_date: Optional[str] = None,
|
|
|
|
|
|
end_date: Optional[str] = None,
|
|
|
|
|
|
ts_code: Optional[str] = None,
|
|
|
|
|
|
) -> pd.DataFrame:
|
2026-02-23 00:07:21 +08:00
|
|
|
|
"""Load data from DuckDB with query pushdown.
|
|
|
|
|
|
|
|
|
|
|
|
关键优化:
|
|
|
|
|
|
- WHERE 条件在数据库层过滤,无需加载全表
|
|
|
|
|
|
- 只返回匹配条件的行,大幅减少内存占用
|
2026-01-31 03:04:51 +08:00
|
|
|
|
|
|
|
|
|
|
Args:
|
2026-02-23 00:07:21 +08:00
|
|
|
|
name: Table name
|
2026-01-31 03:04:51 +08:00
|
|
|
|
start_date: Start date filter (YYYYMMDD)
|
|
|
|
|
|
end_date: End date filter (YYYYMMDD)
|
|
|
|
|
|
ts_code: Stock code filter
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
2026-02-23 00:07:21 +08:00
|
|
|
|
Filtered DataFrame
|
2026-01-31 03:04:51 +08:00
|
|
|
|
"""
|
2026-02-23 00:07:21 +08:00
|
|
|
|
# Build WHERE clause with parameterized queries
|
|
|
|
|
|
conditions = []
|
|
|
|
|
|
params = []
|
2026-01-31 03:04:51 +08:00
|
|
|
|
|
2026-02-23 00:07:21 +08:00
|
|
|
|
if start_date and end_date:
|
|
|
|
|
|
conditions.append("trade_date BETWEEN ? AND ?")
|
|
|
|
|
|
# Convert to DATE type
|
|
|
|
|
|
start = pd.to_datetime(start_date, format="%Y%m%d").date()
|
|
|
|
|
|
end = pd.to_datetime(end_date, format="%Y%m%d").date()
|
|
|
|
|
|
params.extend([start, end])
|
|
|
|
|
|
elif start_date:
|
|
|
|
|
|
conditions.append("trade_date >= ?")
|
|
|
|
|
|
params.append(pd.to_datetime(start_date, format="%Y%m%d").date())
|
|
|
|
|
|
elif end_date:
|
|
|
|
|
|
conditions.append("trade_date <= ?")
|
|
|
|
|
|
params.append(pd.to_datetime(end_date, format="%Y%m%d").date())
|
2026-02-21 03:43:30 +08:00
|
|
|
|
|
2026-02-23 00:07:21 +08:00
|
|
|
|
if ts_code:
|
|
|
|
|
|
conditions.append("ts_code = ?")
|
|
|
|
|
|
params.append(ts_code)
|
2026-01-31 03:04:51 +08:00
|
|
|
|
|
2026-02-23 00:07:21 +08:00
|
|
|
|
where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
|
|
|
|
|
|
query = f"SELECT * FROM {name} {where_clause} ORDER BY trade_date"
|
2026-01-31 03:04:51 +08:00
|
|
|
|
|
2026-02-23 00:07:21 +08:00
|
|
|
|
try:
|
|
|
|
|
|
# Execute query with parameters (SQL injection safe)
|
|
|
|
|
|
result = self._connection.execute(query, params).fetchdf()
|
2026-01-31 03:04:51 +08:00
|
|
|
|
|
2026-02-23 00:07:21 +08:00
|
|
|
|
# Convert trade_date back to string format for compatibility
|
|
|
|
|
|
if "trade_date" in result.columns:
|
|
|
|
|
|
result["trade_date"] = result["trade_date"].dt.strftime("%Y%m%d")
|
2026-01-31 03:04:51 +08:00
|
|
|
|
|
2026-02-23 00:07:21 +08:00
|
|
|
|
return result
|
2026-01-31 03:04:51 +08:00
|
|
|
|
except Exception as e:
|
|
|
|
|
|
print(f"[Storage] Error loading {name}: {e}")
|
|
|
|
|
|
return pd.DataFrame()
|
|
|
|
|
|
|
2026-02-23 00:07:21 +08:00
|
|
|
|
def load_polars(
|
|
|
|
|
|
self,
|
|
|
|
|
|
name: str,
|
|
|
|
|
|
start_date: Optional[str] = None,
|
|
|
|
|
|
end_date: Optional[str] = None,
|
|
|
|
|
|
ts_code: Optional[str] = None,
|
|
|
|
|
|
) -> pl.DataFrame:
|
|
|
|
|
|
"""Load data as Polars DataFrame (for DataLoader).
|
2026-01-31 03:04:51 +08:00
|
|
|
|
|
2026-02-23 00:07:21 +08:00
|
|
|
|
性能优势:
|
|
|
|
|
|
- 零拷贝导出(DuckDB → Polars via PyArrow)
|
|
|
|
|
|
- 需要 pyarrow 支持
|
2026-01-31 03:04:51 +08:00
|
|
|
|
"""
|
2026-02-23 00:07:21 +08:00
|
|
|
|
# Build query
|
|
|
|
|
|
conditions = []
|
|
|
|
|
|
if start_date and end_date:
|
|
|
|
|
|
start = pd.to_datetime(start_date, format='%Y%m%d').date()
|
|
|
|
|
|
end = pd.to_datetime(end_date, format='%Y%m%d').date()
|
|
|
|
|
|
conditions.append(f"trade_date BETWEEN '{start}' AND '{end}'")
|
|
|
|
|
|
if ts_code:
|
|
|
|
|
|
conditions.append(f"ts_code = '{ts_code}'")
|
|
|
|
|
|
|
|
|
|
|
|
where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
|
|
|
|
|
|
query = f"SELECT * FROM {name} {where_clause} ORDER BY trade_date"
|
|
|
|
|
|
|
|
|
|
|
|
# 使用 DuckDB 的 Polars 导出(需要 pyarrow)
|
|
|
|
|
|
df = self._connection.sql(query).pl()
|
|
|
|
|
|
|
|
|
|
|
|
# 将 trade_date 转换为字符串格式,保持兼容性
|
|
|
|
|
|
if "trade_date" in df.columns:
|
|
|
|
|
|
df = df.with_columns(
|
|
|
|
|
|
pl.col("trade_date").dt.strftime("%Y%m%d").alias("trade_date")
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
return df
|
2026-01-31 03:04:51 +08:00
|
|
|
|
|
|
|
|
|
|
def exists(self, name: str) -> bool:
|
2026-02-23 00:07:21 +08:00
|
|
|
|
"""Check if table exists."""
|
|
|
|
|
|
result = self._connection.execute(
|
|
|
|
|
|
"""
|
|
|
|
|
|
SELECT COUNT(*) FROM information_schema.tables
|
|
|
|
|
|
WHERE table_name = ?
|
|
|
|
|
|
""",
|
|
|
|
|
|
[name],
|
|
|
|
|
|
).fetchone()
|
|
|
|
|
|
return result[0] > 0
|
2026-01-31 03:04:51 +08:00
|
|
|
|
|
|
|
|
|
|
def delete(self, name: str) -> bool:
|
2026-02-23 00:07:21 +08:00
|
|
|
|
"""Delete a table."""
|
|
|
|
|
|
try:
|
|
|
|
|
|
self._connection.execute(f"DROP TABLE IF EXISTS {name}")
|
|
|
|
|
|
print(f"[Storage] Deleted table {name}")
|
|
|
|
|
|
return True
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
print(f"[Storage] Error deleting {name}: {e}")
|
|
|
|
|
|
return False
|
2026-01-31 03:04:51 +08:00
|
|
|
|
|
2026-02-23 00:07:21 +08:00
|
|
|
|
def get_last_date(self, name: str) -> Optional[str]:
|
|
|
|
|
|
"""Get the latest date in storage."""
|
|
|
|
|
|
try:
|
|
|
|
|
|
result = self._connection.execute(f"""
|
|
|
|
|
|
SELECT MAX(trade_date) FROM {name}
|
|
|
|
|
|
""").fetchone()
|
|
|
|
|
|
if result[0]:
|
|
|
|
|
|
# Convert date back to string format
|
|
|
|
|
|
return (
|
|
|
|
|
|
result[0].strftime("%Y%m%d")
|
|
|
|
|
|
if hasattr(result[0], "strftime")
|
|
|
|
|
|
else str(result[0])
|
|
|
|
|
|
)
|
|
|
|
|
|
return None
|
|
|
|
|
|
except:
|
|
|
|
|
|
return None
|
2026-01-31 03:04:51 +08:00
|
|
|
|
|
2026-02-23 00:07:21 +08:00
|
|
|
|
def close(self):
|
|
|
|
|
|
"""Close database connection."""
|
|
|
|
|
|
if self._connection:
|
|
|
|
|
|
self._connection.close()
|
|
|
|
|
|
Storage._connection = None
|
|
|
|
|
|
Storage._instance = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ThreadSafeStorage:
|
|
|
|
|
|
"""线程安全的 DuckDB 写入包装器。
|
|
|
|
|
|
|
|
|
|
|
|
DuckDB 写入时不支持并发,使用队列收集写入请求,
|
|
|
|
|
|
在 sync 结束时统一批量写入。
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self):
|
|
|
|
|
|
self.storage = Storage()
|
|
|
|
|
|
self._pending_writes: List[tuple] = [] # [(name, data), ...]
|
|
|
|
|
|
|
|
|
|
|
|
def queue_save(self, name: str, data: pd.DataFrame):
|
|
|
|
|
|
"""将数据放入写入队列(不立即写入)"""
|
|
|
|
|
|
if not data.empty:
|
|
|
|
|
|
self._pending_writes.append((name, data))
|
|
|
|
|
|
|
|
|
|
|
|
def flush(self):
|
|
|
|
|
|
"""批量写入所有队列数据。
|
|
|
|
|
|
|
|
|
|
|
|
调用时机:在 sync 结束时统一调用,避免并发写入冲突。
|
2026-01-31 03:04:51 +08:00
|
|
|
|
"""
|
2026-02-23 00:07:21 +08:00
|
|
|
|
if not self._pending_writes:
|
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
# 合并相同表的数据
|
|
|
|
|
|
table_data = defaultdict(list)
|
|
|
|
|
|
|
|
|
|
|
|
for name, data in self._pending_writes:
|
|
|
|
|
|
table_data[name].append(data)
|
|
|
|
|
|
|
|
|
|
|
|
# 批量写入每个表
|
|
|
|
|
|
for name, data_list in table_data.items():
|
|
|
|
|
|
combined = pd.concat(data_list, ignore_index=True)
|
|
|
|
|
|
# 在批量数据中先去重
|
|
|
|
|
|
if "ts_code" in combined.columns and "trade_date" in combined.columns:
|
|
|
|
|
|
combined = combined.drop_duplicates(
|
|
|
|
|
|
subset=["ts_code", "trade_date"], keep="last"
|
|
|
|
|
|
)
|
|
|
|
|
|
self.storage.save(name, combined, mode="append")
|
|
|
|
|
|
|
|
|
|
|
|
self._pending_writes.clear()
|
|
|
|
|
|
|
|
|
|
|
|
def __getattr__(self, name):
|
|
|
|
|
|
"""代理其他方法到 Storage 实例"""
|
|
|
|
|
|
return getattr(self.storage, name)
|