"""DuckDB storage for data persistence.""" import pandas as pd import polars as pl import duckdb from pathlib import Path from typing import Optional, List, Dict, Any, Tuple from collections import defaultdict from datetime import datetime from src.config.settings import get_settings # Default column type mapping for automatic schema inference DEFAULT_TYPE_MAPPING = { "ts_code": "VARCHAR(16)", "trade_date": "DATE", "open": "DOUBLE", "high": "DOUBLE", "low": "DOUBLE", "close": "DOUBLE", "pre_close": "DOUBLE", "change": "DOUBLE", "pct_chg": "DOUBLE", "vol": "DOUBLE", "amount": "DOUBLE", "turnover_rate": "DOUBLE", "volume_ratio": "DOUBLE", "adj_factor": "DOUBLE", "suspend_flag": "INTEGER", } class Storage: """DuckDB storage manager for saving and loading data. 迁移说明: - 保持 API 完全兼容,调用方无需修改 - 新增 load_polars() 方法支持 Polars 零拷贝导出 - 使用单例模式管理数据库连接 - 并发写入通过队列管理(见 ThreadSafeStorage) """ _instance = None _connection = None def __new__(cls, *args, **kwargs): """Singleton to ensure single connection.""" if cls._instance is None: cls._instance = super().__new__(cls) return cls._instance def __init__(self, path: Optional[Path] = None): """Initialize storage.""" if hasattr(self, "_initialized"): return cfg = get_settings() self.base_path = path or cfg.data_path_resolved self.base_path.mkdir(parents=True, exist_ok=True) self.db_path = self.base_path / "prostock.db" self._init_db() self._initialized = True def _init_db(self): """Initialize database connection and schema.""" self._connection = duckdb.connect(str(self.db_path)) # Create tables with schema validation self._connection.execute(""" CREATE TABLE IF NOT EXISTS daily ( ts_code VARCHAR(16) NOT NULL, trade_date DATE NOT NULL, open DOUBLE, high DOUBLE, low DOUBLE, close DOUBLE, pre_close DOUBLE, change DOUBLE, pct_chg DOUBLE, vol DOUBLE, amount DOUBLE, turnover_rate DOUBLE, volume_ratio DOUBLE, PRIMARY KEY (ts_code, trade_date) ) """) # Create composite index for query optimization (trade_date, ts_code) self._connection.execute(""" CREATE INDEX IF NOT EXISTS idx_daily_date_code ON daily(trade_date, ts_code) """) # Create financial_income table for income statement data # 完整的利润表字段(94列全部) self._connection.execute(""" CREATE TABLE IF NOT EXISTS financial_income ( ts_code VARCHAR(16) NOT NULL, ann_date DATE, f_ann_date DATE, end_date DATE NOT NULL, report_type INTEGER, comp_type INTEGER, end_type VARCHAR(10), basic_eps DOUBLE, diluted_eps DOUBLE, total_revenue DOUBLE, revenue DOUBLE, int_income DOUBLE, prem_earned DOUBLE, comm_income DOUBLE, n_commis_income DOUBLE, n_oth_income DOUBLE, n_oth_b_income DOUBLE, prem_income DOUBLE, out_prem DOUBLE, une_prem_reser DOUBLE, reins_income DOUBLE, n_sec_tb_income DOUBLE, n_sec_uw_income DOUBLE, n_asset_mg_income DOUBLE, oth_b_income DOUBLE, fv_value_chg_gain DOUBLE, invest_income DOUBLE, ass_invest_income DOUBLE, forex_gain DOUBLE, total_cogs DOUBLE, oper_cost DOUBLE, int_exp DOUBLE, comm_exp DOUBLE, biz_tax_surchg DOUBLE, sell_exp DOUBLE, admin_exp DOUBLE, fin_exp DOUBLE, assets_impair_loss DOUBLE, prem_refund DOUBLE, compens_payout DOUBLE, reser_insur_liab DOUBLE, div_payt DOUBLE, reins_exp DOUBLE, oper_exp DOUBLE, compens_payout_refu DOUBLE, insur_reser_refu DOUBLE, reins_cost_refund DOUBLE, other_bus_cost DOUBLE, operate_profit DOUBLE, non_oper_income DOUBLE, non_oper_exp DOUBLE, nca_disploss DOUBLE, total_profit DOUBLE, income_tax DOUBLE, n_income DOUBLE, n_income_attr_p DOUBLE, minority_gain DOUBLE, oth_compr_income DOUBLE, t_compr_income DOUBLE, compr_inc_attr_p DOUBLE, compr_inc_attr_m_s DOUBLE, ebit DOUBLE, ebitda DOUBLE, insurance_exp DOUBLE, undist_profit DOUBLE, distable_profit DOUBLE, rd_exp DOUBLE, fin_exp_int_exp DOUBLE, fin_exp_int_inc DOUBLE, transfer_surplus_rese DOUBLE, transfer_housing_imprest DOUBLE, transfer_oth DOUBLE, adj_lossgain DOUBLE, withdra_legal_surplus DOUBLE, withdra_legal_pubfund DOUBLE, withdra_biz_devfund DOUBLE, withdra_rese_fund DOUBLE, withdra_oth_ersu DOUBLE, workers_welfare DOUBLE, distr_profit_shrhder DOUBLE, prfshare_payable_dvd DOUBLE, comshare_payable_dvd DOUBLE, capit_comstock_div DOUBLE, net_after_nr_lp_correct DOUBLE, credit_impa_loss DOUBLE, net_expo_hedging_benefits DOUBLE, oth_impair_loss_assets DOUBLE, total_opcost DOUBLE, amodcost_fin_assets DOUBLE, oth_income DOUBLE, asset_disp_income DOUBLE, continued_net_profit DOUBLE, end_net_profit DOUBLE, update_flag VARCHAR(1), PRIMARY KEY (ts_code, end_date) ) # Create pro_bar table for pro bar data (with adj, tor, vr) self._connection.execute(""" CREATE TABLE IF NOT EXISTS pro_bar ( ts_code VARCHAR(16) NOT NULL, trade_date DATE NOT NULL, open DOUBLE, high DOUBLE, low DOUBLE, close DOUBLE, pre_close DOUBLE, change DOUBLE, pct_chg DOUBLE, vol DOUBLE, amount DOUBLE, tor DOUBLE, vr DOUBLE, adj_factor DOUBLE, PRIMARY KEY (ts_code, trade_date) ) """) # Create index for financial_income self._connection.execute(""" CREATE INDEX IF NOT EXISTS idx_financial_ann ON financial_income(ts_code, ann_date) """) def save(self, name: str, data: pd.DataFrame, mode: str = "append") -> dict: """Save data to DuckDB. Args: name: Table name data: DataFrame to save mode: 'append' (UPSERT) or 'replace' (DELETE + INSERT) Returns: Dict with save result """ if data.empty: return {"status": "skipped", "rows": 0} # 确保日期列是正确的类型 (YYYYMMDD -> date) # trade_date: 日线数据日期 if "trade_date" in data.columns: data = data.copy() data["trade_date"] = pd.to_datetime( data["trade_date"], format="%Y%m%d" ).dt.date # ann_date: 公告日期 if "ann_date" in data.columns: data = data.copy() data["ann_date"] = pd.to_datetime( data["ann_date"], format="%Y%m%d", errors="coerce" ).dt.date # f_ann_date: 最终公告日期 if "f_ann_date" in data.columns: data = data.copy() data["f_ann_date"] = pd.to_datetime( data["f_ann_date"], format="%Y%m%d", errors="coerce" ).dt.date # end_date: 报告期/期末日期 if "end_date" in data.columns: data = data.copy() data["end_date"] = pd.to_datetime( data["end_date"], format="%Y%m%d", errors="coerce" ).dt.date # Register DataFrame as temporary view self._connection.register("temp_data", data) try: if mode == "replace": self._connection.execute(f"DELETE FROM {name}") # UPSERT: INSERT OR REPLACE columns = ", ".join(data.columns) self._connection.execute(f""" INSERT OR REPLACE INTO {name} ({columns}) SELECT {columns} FROM temp_data """) row_count = len(data) print(f"[Storage] Saved {row_count} rows to DuckDB ({name})") return {"status": "success", "rows": row_count} except Exception as e: print(f"[Storage] Error saving {name}: {e}") return {"status": "error", "error": str(e)} finally: self._connection.unregister("temp_data") def load( self, name: str, start_date: Optional[str] = None, end_date: Optional[str] = None, ts_code: Optional[str] = None, ) -> pd.DataFrame: """Load data from DuckDB with query pushdown. 关键优化: - WHERE 条件在数据库层过滤,无需加载全表 - 只返回匹配条件的行,大幅减少内存占用 Args: name: Table name start_date: Start date filter (YYYYMMDD) end_date: End date filter (YYYYMMDD) ts_code: Stock code filter Returns: Filtered DataFrame """ # Build WHERE clause with parameterized queries conditions = [] params = [] if start_date and end_date: conditions.append("trade_date BETWEEN ? AND ?") # Convert to DATE type start = pd.to_datetime(start_date, format="%Y%m%d").date() end = pd.to_datetime(end_date, format="%Y%m%d").date() params.extend([start, end]) elif start_date: conditions.append("trade_date >= ?") params.append(pd.to_datetime(start_date, format="%Y%m%d").date()) elif end_date: conditions.append("trade_date <= ?") params.append(pd.to_datetime(end_date, format="%Y%m%d").date()) if ts_code: conditions.append("ts_code = ?") params.append(ts_code) where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else "" query = f"SELECT * FROM {name} {where_clause} ORDER BY trade_date" try: # Execute query with parameters (SQL injection safe) result = self._connection.execute(query, params).fetchdf() # Convert trade_date back to string format for compatibility if "trade_date" in result.columns: result["trade_date"] = result["trade_date"].dt.strftime("%Y%m%d") return result except Exception as e: print(f"[Storage] Error loading {name}: {e}") return pd.DataFrame() def load_polars( self, name: str, start_date: Optional[str] = None, end_date: Optional[str] = None, ts_code: Optional[str] = None, ) -> pl.DataFrame: """Load data as Polars DataFrame (for DataLoader). 性能优势: - 零拷贝导出(DuckDB → Polars via PyArrow) - 需要 pyarrow 支持 """ # Build query conditions = [] if start_date and end_date: start = pd.to_datetime(start_date, format='%Y%m%d').date() end = pd.to_datetime(end_date, format='%Y%m%d').date() conditions.append(f"trade_date BETWEEN '{start}' AND '{end}'") if ts_code: conditions.append(f"ts_code = '{ts_code}'") where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else "" query = f"SELECT * FROM {name} {where_clause} ORDER BY trade_date" # 使用 DuckDB 的 Polars 导出(需要 pyarrow) df = self._connection.sql(query).pl() # 将 trade_date 转换为字符串格式,保持兼容性 if "trade_date" in df.columns: df = df.with_columns( pl.col("trade_date").dt.strftime("%Y%m%d").alias("trade_date") ) return df def exists(self, name: str) -> bool: """Check if table exists.""" result = self._connection.execute( """ SELECT COUNT(*) FROM information_schema.tables WHERE table_name = ? """, [name], ).fetchone() return result[0] > 0 def delete(self, name: str) -> bool: """Delete a table.""" try: self._connection.execute(f"DROP TABLE IF EXISTS {name}") print(f"[Storage] Deleted table {name}") return True except Exception as e: print(f"[Storage] Error deleting {name}: {e}") return False def get_last_date(self, name: str) -> Optional[str]: """Get the latest date in storage.""" try: result = self._connection.execute(f""" SELECT MAX(trade_date) FROM {name} """).fetchone() if result[0]: # Convert date back to string format return ( result[0].strftime("%Y%m%d") if hasattr(result[0], "strftime") else str(result[0]) ) return None except: return None def close(self): """Close database connection.""" if self._connection: self._connection.close() Storage._connection = None Storage._instance = None class ThreadSafeStorage: """线程安全的 DuckDB 写入包装器。 DuckDB 写入时不支持并发,使用队列收集写入请求, 在 sync 结束时统一批量写入。 """ def __init__(self): self.storage = Storage() self._pending_writes: List[tuple] = [] # [(name, data), ...] def queue_save(self, name: str, data: pd.DataFrame): """将数据放入写入队列(不立即写入)""" if not data.empty: self._pending_writes.append((name, data)) def flush(self): """批量写入所有队列数据。 调用时机:在 sync 结束时统一调用,避免并发写入冲突。 """ if not self._pending_writes: return # 合并相同表的数据 table_data = defaultdict(list) for name, data in self._pending_writes: table_data[name].append(data) # 批量写入每个表 for name, data_list in table_data.items(): combined = pd.concat(data_list, ignore_index=True) # 在批量数据中先去重 if "ts_code" in combined.columns and "trade_date" in combined.columns: combined = combined.drop_duplicates( subset=["ts_code", "trade_date"], keep="last" ) self.storage.save(name, combined, mode="append") self._pending_writes.clear() def __getattr__(self, name): """代理其他方法到 Storage 实例""" return getattr(self.storage, name)