- 移除 storage.py 集中式建表逻辑,改为各 API 文件自管理 - base_sync.py 新增 ensure_table_exists() 和表探测机制 - api_daily/api_pro_bar/api_bak_basic 添加 TABLE_SCHEMA 定义 - api_financial_sync 添加完整利润表字段定义 - sync.py 更新职责文档,明确仅同步每日更新数据 - AGENTS.md 添加 v2.1 架构变更历史和 AI 行为准则
334 lines
11 KiB
Python
"""DuckDB storage for data persistence."""
|
||
import pandas as pd
|
||
import polars as pl
|
||
import duckdb
|
||
from pathlib import Path
|
||
from typing import Optional, List, Dict, Any, Tuple
|
||
from collections import defaultdict
|
||
from datetime import datetime
|
||
from src.config.settings import get_settings
|
||
|
||
|
||
# Default column type mapping for automatic schema inference
|
||
# Column-name -> DuckDB column type used when inferring a table schema.
# NOTE(review): consumers of this mapping live outside this chunk (schema
# creation moved into the api_* sync files) — confirm usage there.
DEFAULT_TYPE_MAPPING = {
    # Instrument code filter key (see Storage.load's ts_code parameter).
    "ts_code": "VARCHAR(16)",
    # Trading date; Storage.save() parses YYYYMMDD strings into dates.
    "trade_date": "DATE",
    # Daily OHLC prices.
    "open": "DOUBLE",
    "high": "DOUBLE",
    "low": "DOUBLE",
    "close": "DOUBLE",
    # Previous close plus change metrics.
    "pre_close": "DOUBLE",
    "change": "DOUBLE",
    "pct_chg": "DOUBLE",
    # Volume / turnover related figures.
    "vol": "DOUBLE",
    "amount": "DOUBLE",
    "turnover_rate": "DOUBLE",
    "volume_ratio": "DOUBLE",
    # Price adjustment factor.
    "adj_factor": "DOUBLE",
    # Suspension marker stored as an integer flag.
    "suspend_flag": "INTEGER",
}
|
||
|
||
|
||
class Storage:
    """DuckDB storage manager for saving and loading data.

    Migration notes:
    - Public API is fully compatible; callers need no changes
    - load_polars() supports zero-copy export to Polars
    - A singleton manages the single database connection
    - Concurrent writes are queued (see ThreadSafeStorage)
    """

    _instance = None    # singleton instance
    _connection = None  # shared DuckDB connection handle

    # Date-like columns arriving as YYYYMMDD strings, paired with the
    # pandas ``errors`` mode used when parsing them.  trade_date keeps the
    # original strict parsing; announcement / report-period dates may be
    # absent and are coerced to NaT.
    _DATE_COLUMNS = (
        ("trade_date", "raise"),   # daily bar date
        ("ann_date", "coerce"),    # announcement date
        ("f_ann_date", "coerce"),  # final announcement date
        ("end_date", "coerce"),    # report period / period-end date
    )

    def __new__(cls, *args, **kwargs):
        """Singleton to ensure a single connection."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, path: Optional[Path] = None):
        """Initialize storage.

        Args:
            path: Optional base directory for the database file; defaults
                to the configured data path from settings.
        """
        # Re-entry guard: the singleton may be constructed many times.
        if hasattr(self, "_initialized"):
            return

        cfg = get_settings()
        self.base_path = path or cfg.data_path_resolved
        self.base_path.mkdir(parents=True, exist_ok=True)
        self.db_path = self.base_path / "prostock.db"

        self._init_db()
        self._initialized = True

    def _init_db(self):
        """Initialize the database connection.

        Note: table creation has moved into the corresponding API files;
        each sync class owns its table schema and creation.
        See:
        - api_daily.py: DailySync.TABLE_SCHEMA
        - api_pro_bar.py: ProBarSync.TABLE_SCHEMA
        - api_bak_basic.py: BakBasicSync.TABLE_SCHEMA
        - api_financial_sync.py: FinancialSync.TABLE_SCHEMAS
        """
        self._connection = duckdb.connect(str(self.db_path))

    def _normalize_dates(self, data: pd.DataFrame) -> pd.DataFrame:
        """Return *data* with YYYYMMDD date columns parsed to ``date``.

        Copies the frame at most once regardless of how many date columns
        are present (previously one copy was made per date column).
        """
        present = [(c, m) for c, m in self._DATE_COLUMNS if c in data.columns]
        if not present:
            return data
        data = data.copy()
        for col, errors_mode in present:
            data[col] = pd.to_datetime(
                data[col], format="%Y%m%d", errors=errors_mode
            ).dt.date
        return data

    def save(self, name: str, data: pd.DataFrame, mode: str = "append") -> dict:
        """Save data to DuckDB.

        Args:
            name: Table name
            data: DataFrame to save
            mode: 'append' (UPSERT) or 'replace' (DELETE + INSERT)

        Returns:
            Dict with save result
        """
        if data.empty:
            return {"status": "skipped", "rows": 0}

        data = self._normalize_dates(data)

        # Register DataFrame as a temporary view
        self._connection.register("temp_data", data)

        try:
            if mode == "replace":
                self._connection.execute(f"DELETE FROM {name}")

            # UPSERT via INSERT OR REPLACE.  Column names are quoted so
            # reserved words such as "change" do not break the statement.
            # BUG FIX: this statement was previously executed a second
            # time with *unquoted* column names, which double-inserted the
            # batch and failed outright on reserved-word columns.
            columns = ", ".join(f'"{col}"' for col in data.columns)
            self._connection.execute(f"""
                INSERT OR REPLACE INTO {name} ({columns})
                SELECT {columns} FROM temp_data
            """)

            row_count = len(data)
            print(f"[Storage] Saved {row_count} rows to DuckDB ({name})")
            return {"status": "success", "rows": row_count}

        except Exception as e:
            print(f"[Storage] Error saving {name}: {e}")
            return {"status": "error", "error": str(e)}
        finally:
            # Always drop the temporary view, even on failure.
            self._connection.unregister("temp_data")

    def load(
        self,
        name: str,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        ts_code: Optional[str] = None,
    ) -> pd.DataFrame:
        """Load data from DuckDB with query pushdown.

        Key optimization:
        - WHERE conditions are applied at the database layer, so the full
          table is never loaded
        - Only matching rows are returned, greatly reducing memory usage

        Args:
            name: Table name
            start_date: Start date filter (YYYYMMDD)
            end_date: End date filter (YYYYMMDD)
            ts_code: Stock code filter

        Returns:
            Filtered DataFrame (empty DataFrame on error)
        """
        # Build WHERE clause with parameterized queries
        conditions = []
        params = []

        if start_date and end_date:
            conditions.append("trade_date BETWEEN ? AND ?")
            # Convert YYYYMMDD strings to DATE values for the comparison
            start = pd.to_datetime(start_date, format="%Y%m%d").date()
            end = pd.to_datetime(end_date, format="%Y%m%d").date()
            params.extend([start, end])
        elif start_date:
            conditions.append("trade_date >= ?")
            params.append(pd.to_datetime(start_date, format="%Y%m%d").date())
        elif end_date:
            conditions.append("trade_date <= ?")
            params.append(pd.to_datetime(end_date, format="%Y%m%d").date())

        if ts_code:
            conditions.append("ts_code = ?")
            params.append(ts_code)

        where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
        query = f"SELECT * FROM {name} {where_clause} ORDER BY trade_date"

        try:
            # Execute query with parameters (SQL-injection safe)
            result = self._connection.execute(query, params).fetchdf()

            # Convert trade_date back to string format for compatibility
            if "trade_date" in result.columns:
                result["trade_date"] = result["trade_date"].dt.strftime("%Y%m%d")

            return result
        except Exception as e:
            print(f"[Storage] Error loading {name}: {e}")
            return pd.DataFrame()

    def load_polars(
        self,
        name: str,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        ts_code: Optional[str] = None,
    ) -> pl.DataFrame:
        """Load data as a Polars DataFrame (for DataLoader).

        Performance advantages:
        - Zero-copy export (DuckDB -> Polars via PyArrow)
        - Requires pyarrow

        Note: as in the original implementation, the date filter only
        applies when *both* start_date and end_date are given (unlike
        ``load()``, which also accepts one-sided ranges).
        """
        # SECURITY FIX: values were previously interpolated into the SQL
        # text with f-strings; use ? placeholders like load() instead so
        # ts_code/date inputs cannot inject SQL or break quoting.
        conditions = []
        params = []
        if start_date and end_date:
            start = pd.to_datetime(start_date, format="%Y%m%d").date()
            end = pd.to_datetime(end_date, format="%Y%m%d").date()
            conditions.append("trade_date BETWEEN ? AND ?")
            params.extend([start, end])
        if ts_code:
            conditions.append("ts_code = ?")
            params.append(ts_code)

        where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
        query = f"SELECT * FROM {name} {where_clause} ORDER BY trade_date"

        # Zero-copy export through DuckDB's Polars integration (pyarrow)
        df = self._connection.execute(query, params).pl()

        # Convert trade_date back to string format for compatibility
        if "trade_date" in df.columns:
            df = df.with_columns(
                pl.col("trade_date").dt.strftime("%Y%m%d").alias("trade_date")
            )

        return df

    def exists(self, name: str) -> bool:
        """Check whether table *name* exists in the database."""
        result = self._connection.execute(
            """
            SELECT COUNT(*) FROM information_schema.tables
            WHERE table_name = ?
            """,
            [name],
        ).fetchone()
        return result[0] > 0

    def delete(self, name: str) -> bool:
        """Drop table *name*; return True on success."""
        try:
            self._connection.execute(f"DROP TABLE IF EXISTS {name}")
            print(f"[Storage] Deleted table {name}")
            return True
        except Exception as e:
            print(f"[Storage] Error deleting {name}: {e}")
            return False

    def get_last_date(self, name: str) -> Optional[str]:
        """Return the latest trade_date in *name* as YYYYMMDD, or None.

        Returns None when the table is missing or empty.
        """
        try:
            result = self._connection.execute(f"""
                SELECT MAX(trade_date) FROM {name}
            """).fetchone()
            if result[0]:
                # Convert date back to the YYYYMMDD string format
                return (
                    result[0].strftime("%Y%m%d")
                    if hasattr(result[0], "strftime")
                    else str(result[0])
                )
            return None
        except Exception:
            # Narrowed from a bare ``except:`` so KeyboardInterrupt and
            # SystemExit are no longer swallowed; a missing table still
            # yields None as before.
            return None

    def close(self):
        """Close the database connection and reset the singleton."""
        if self._connection:
            self._connection.close()
            Storage._connection = None
            Storage._instance = None
|
||
|
||
|
||
class ThreadSafeStorage:
    """Thread-safe wrapper around DuckDB writes.

    DuckDB does not support concurrent writers, so write requests are
    collected in a queue and written in one batch when the sync finishes.
    """

    def __init__(self):
        self.storage = Storage()
        # Queued (table_name, dataframe) pairs awaiting flush().
        self._pending_writes: List[tuple] = []

    def queue_save(self, name: str, data: pd.DataFrame):
        """Queue *data* for a later batched write (nothing is written yet)."""
        if data.empty:
            return
        self._pending_writes.append((name, data))

    def flush(self):
        """Write every queued batch to storage.

        Call once at the end of a sync run to avoid concurrent-write
        conflicts.
        """
        if not self._pending_writes:
            return

        # Group queued frames by destination table.
        grouped = defaultdict(list)
        for table_name, frame in self._pending_writes:
            grouped[table_name].append(frame)

        # Write each table's combined batch, de-duplicated on the
        # (ts_code, trade_date) key when both columns are present.
        for table_name, frames in grouped.items():
            merged = pd.concat(frames, ignore_index=True)
            if {"ts_code", "trade_date"}.issubset(merged.columns):
                merged = merged.drop_duplicates(
                    subset=["ts_code", "trade_date"], keep="last"
                )
            self.storage.save(table_name, merged, mode="append")

        self._pending_writes.clear()

    def __getattr__(self, name):
        """Delegate any other attribute access to the Storage instance."""
        return getattr(self.storage, name)
|