From bace4cc5f43d44b4065aaad2d51a0c1585590e3d Mon Sep 17 00:00:00 2001 From: liaozhaorun <1300336796@qq.com> Date: Mon, 23 Mar 2026 21:10:15 +0800 Subject: [PATCH] =?UTF-8?q?feat(data):=20=E4=B8=BA=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E5=90=8C=E6=AD=A5=E6=B7=BB=E5=8A=A0=E4=BA=8B=E5=8A=A1=E6=94=AF?= =?UTF-8?q?=E6=8C=81=E5=92=8C=E5=90=8C=E6=AD=A5=E6=97=A5=E5=BF=97=20-=20St?= =?UTF-8?q?orage/ThreadSafeStorage=20=E6=B7=BB=E5=8A=A0=E4=BA=8B=E5=8A=A1?= =?UTF-8?q?=E6=94=AF=E6=8C=81=EF=BC=88begin/commit/rollback=EF=BC=89=20-?= =?UTF-8?q?=20=E6=96=B0=E5=A2=9E=20SyncLogManager=20=E8=AE=B0=E5=BD=95?= =?UTF-8?q?=E6=89=80=E6=9C=89=E5=90=8C=E6=AD=A5=E4=BB=BB=E5=8A=A1=E7=9A=84?= =?UTF-8?q?=E6=89=A7=E8=A1=8C=E7=8A=B6=E6=80=81=20-=20=E9=9B=86=E6=88=90?= =?UTF-8?q?=E4=BA=8B=E5=8A=A1=E5=88=B0=20StockBasedSync=E3=80=81DateBasedS?= =?UTF-8?q?ync=E3=80=81QuarterBasedSync=20-=20=E5=9C=A8=20sync=5Fall=20?= =?UTF-8?q?=E5=92=8C=20sync=5Ffinancial=20=E8=B0=83=E5=BA=A6=E4=B8=AD?= =?UTF-8?q?=E5=BF=83=E6=B7=BB=E5=8A=A0=E6=97=A5=E5=BF=97=E8=AE=B0=E5=BD=95?= =?UTF-8?q?=20-=20=E6=96=B0=E5=A2=9E=E6=B5=8B=E8=AF=95=E9=AA=8C=E8=AF=81?= =?UTF-8?q?=E4=BA=8B=E5=8A=A1=E5=92=8C=E6=97=A5=E5=BF=97=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/data/api_wrappers/api_stock_basic.py | 6 +- src/data/api_wrappers/api_trade_cal.py | 3 + src/data/api_wrappers/base_financial_sync.py | 220 ++++++--- src/data/api_wrappers/base_sync.py | 249 +++++++--- .../financial_data/api_financial_sync.py | 32 ++ src/data/storage.py | 157 ++++++- src/data/sync.py | 80 +++- src/data/sync_logger.py | 439 ++++++++++++++++++ src/experiment/regression.py | 56 ++- tests/test_sync_transaction_and_logs.py | 403 ++++++++++++++++ 10 files changed, 1468 insertions(+), 177 deletions(-) create mode 100644 src/data/sync_logger.py create mode 100644 tests/test_sync_transaction_and_logs.py diff --git a/src/data/api_wrappers/api_stock_basic.py 
b/src/data/api_wrappers/api_stock_basic.py index a695e2d..09613f5 100644 --- a/src/data/api_wrappers/api_stock_basic.py +++ b/src/data/api_wrappers/api_stock_basic.py @@ -76,7 +76,11 @@ def get_stock_basic( return data -def sync_all_stocks() -> pd.DataFrame: +def sync_all_stocks( + force_full: bool = False, + dry_run: bool = False, + **kwargs, +) -> pd.DataFrame: """Fetch and save all stocks (listed and delisted) to local storage. This is a special interface that should only be called once to initialize diff --git a/src/data/api_wrappers/api_trade_cal.py b/src/data/api_wrappers/api_trade_cal.py index 9839a2a..fe7381b 100644 --- a/src/data/api_wrappers/api_trade_cal.py +++ b/src/data/api_wrappers/api_trade_cal.py @@ -82,6 +82,9 @@ def sync_trade_cal_cache( start_date: str = "20180101", end_date: Optional[str] = None, force: bool = False, + force_full: bool = False, + dry_run: bool = False, + **kwargs, ) -> pd.DataFrame: """Sync trade calendar data to local cache with incremental updates. 
diff --git a/src/data/api_wrappers/base_financial_sync.py b/src/data/api_wrappers/base_financial_sync.py index adcba1a..0978b30 100644 --- a/src/data/api_wrappers/base_financial_sync.py +++ b/src/data/api_wrappers/base_financial_sync.py @@ -25,6 +25,7 @@ from tqdm import tqdm from src.data.client import TushareClient from src.data.storage import ThreadSafeStorage, Storage +from src.data.sync_logger import SyncLogManager from src.data.utils import get_today_date, get_quarters_in_range, DEFAULT_START_DATE @@ -466,18 +467,28 @@ class QuarterBasedSync(ABC): inserted_count = 0 if is_first_sync_for_period: - # 首次同步该季度:直接插入所有数据 + # 首次同步该季度:直接插入所有数据(使用事务) print( f"[{self.__class__.__name__}] First sync for quarter {period}, inserting all data directly" ) if not dry_run: - self.storage.queue_save(self.table_name, remote_df, use_upsert=False) - self.storage.flush() - inserted_count = len(remote_df) - print( - f"[{self.__class__.__name__}] Inserted {inserted_count} new records" - ) + try: + # 开始事务 + self.storage.begin_transaction() + self.storage.queue_save( + self.table_name, remote_df, use_upsert=False + ) + self.storage.flush(use_transaction=False) + self.storage.commit_transaction() + inserted_count = len(remote_df) + print( + f"[{self.__class__.__name__}] Inserted {inserted_count} new records (transaction committed)" + ) + except Exception as e: + self.storage.rollback_transaction() + print(f"[{self.__class__.__name__}] Transaction rolled back: {e}") + raise return { "period": period, @@ -501,19 +512,33 @@ class QuarterBasedSync(ABC): print(f" - Unchanged stocks: {unchanged_count}") if not dry_run and not diff_df.empty: - # 5.1 删除差异股票的旧数据 - deleted_stocks_count = len(diff_stocks) - self.delete_stock_quarter_data(period, diff_stocks) - deleted_count = len(diff_df) - print( - f"[{self.__class__.__name__}] Deleted {deleted_stocks_count} stocks' old records (approx {deleted_count} rows)" - ) + try: + # 开始事务 + self.storage.begin_transaction() - # 5.2 插入新数据(使用普通 
INSERT,因为已删除旧数据) - self.storage.queue_save(self.table_name, diff_df, use_upsert=False) - self.storage.flush() - inserted_count = len(diff_df) - print(f"[{self.__class__.__name__}] Inserted {inserted_count} new records") + # 5.1 删除差异股票的旧数据 + deleted_stocks_count = len(diff_stocks) + self.delete_stock_quarter_data(period, diff_stocks) + deleted_count = len(diff_df) + print( + f"[{self.__class__.__name__}] Deleted {deleted_stocks_count} stocks' old records (approx {deleted_count} rows)" + ) + + # 5.2 插入新数据(使用普通 INSERT,因为已删除旧数据) + self.storage.queue_save(self.table_name, diff_df, use_upsert=False) + self.storage.flush(use_transaction=False) + inserted_count = len(diff_df) + + # 提交事务 + self.storage.commit_transaction() + print( + f"[{self.__class__.__name__}] Inserted {inserted_count} new records (transaction committed)" + ) + + except Exception as e: + self.storage.rollback_transaction() + print(f"[{self.__class__.__name__}] Transaction rolled back: {e}") + raise return { "period": period, @@ -583,55 +608,86 @@ class QuarterBasedSync(ABC): print(f"[{self.__class__.__name__}] Incremental Sync") print(f"{'=' * 60}") - # 0. 确保表存在(首次同步时自动建表) - self.ensure_table_exists() + # 初始化日志管理器 + log_manager = SyncLogManager() + log_entry = log_manager.start_sync( + table_name=self.table_name, + sync_type="incremental", + metadata={"dry_run": dry_run}, + ) - # 1. 获取最新季度 - storage = Storage() try: - result = storage._connection.execute( - f'SELECT MAX(end_date) FROM "{self.table_name}"' - ).fetchone() - latest_quarter = result[0] if result and result[0] else None - if hasattr(latest_quarter, "strftime"): - latest_quarter = latest_quarter.strftime("%Y%m%d") + # 0. 确保表存在(首次同步时自动建表) + self.ensure_table_exists() + + # 1. 
获取最新季度 + storage = Storage() + try: + result = storage._connection.execute( + f'SELECT MAX(end_date) FROM "{self.table_name}"' + ).fetchone() + latest_quarter = result[0] if result and result[0] else None + if hasattr(latest_quarter, "strftime"): + latest_quarter = latest_quarter.strftime("%Y%m%d") + except Exception as e: + print(f"[{self.__class__.__name__}] Error getting latest quarter: {e}") + latest_quarter = None + + # 2. 获取当前季度 + current_quarter = self.get_current_quarter() + + if latest_quarter is None: + # 无本地数据,执行全量同步 + print( + f"[{self.__class__.__name__}] No local data, performing full sync" + ) + results = self.sync_range( + self.DEFAULT_START_DATE, current_quarter, dry_run + ) + else: + print( + f"[{self.__class__.__name__}] Latest local quarter: {latest_quarter}" + ) + print(f"[{self.__class__.__name__}] Current quarter: {current_quarter}") + + # 3. 确定同步范围 + start_quarter = latest_quarter + if start_quarter > current_quarter: + start_quarter = current_quarter + else: + start_quarter = self.get_prev_quarter(latest_quarter) + + if start_quarter < self.DEFAULT_START_DATE: + start_quarter = self.DEFAULT_START_DATE + + # 打印同步的两个季度信息 + print(f"\n[{self.__class__.__name__}] 将同步以下两个季度的财报:") + print(f" - 前一季度: {start_quarter}") + print(f" - 当前季度: {current_quarter}") + print(f" (包含前一季度以确保数据完整性)") + print() + + results = self.sync_range(start_quarter, current_quarter, dry_run) + + # 计算总插入记录数 + total_inserted = sum( + r.get("inserted_count", 0) for r in results if isinstance(r, dict) + ) + + # 完成日志记录 + log_manager.complete_sync( + log_entry, + status="success", + records_inserted=total_inserted, + records_updated=0, + records_deleted=0, + ) + + return results + except Exception as e: - print(f"[{self.__class__.__name__}] Error getting latest quarter: {e}") - latest_quarter = None - - # 2. 
获取当前季度 - current_quarter = self.get_current_quarter() - - if latest_quarter is None: - # 无本地数据,执行全量同步 - print(f"[{self.__class__.__name__}] No local data, performing full sync") - return self.sync_range(self.DEFAULT_START_DATE, current_quarter, dry_run) - - print(f"[{self.__class__.__name__}] Latest local quarter: {latest_quarter}") - print(f"[{self.__class__.__name__}] Current quarter: {current_quarter}") - - # 3. 确定同步范围 - # 财务数据必须每次都进行对比更新,不存在"跳过"的情况 - # 同步范围:从最新季度到当前季度(包含前一季度以确保数据完整性) - start_quarter = latest_quarter - if start_quarter > current_quarter: - # 如果本地数据比当前季度还新,仍然需要同步(可能包含修正数据) - start_quarter = current_quarter - else: - # 正常情况:包含前一季度 - start_quarter = self.get_prev_quarter(latest_quarter) - - if start_quarter < self.DEFAULT_START_DATE: - start_quarter = self.DEFAULT_START_DATE - - # 打印同步的两个季度信息 - print(f"\n[{self.__class__.__name__}] 将同步以下两个季度的财报:") - print(f" - 前一季度: {start_quarter}") - print(f" - 当前季度: {current_quarter}") - print(f" (包含前一季度以确保数据完整性)") - print() - - return self.sync_range(start_quarter, current_quarter, dry_run) + log_manager.complete_sync(log_entry, status="failed", error_message=str(e)) + raise def sync_full(self, dry_run: bool = False) -> List[Dict]: """执行全量同步。 @@ -646,12 +702,38 @@ class QuarterBasedSync(ABC): print(f"[{self.__class__.__name__}] Full Sync") print(f"{'=' * 60}") - # 确保表存在 - self.ensure_table_exists() + # 初始化日志管理器 + log_manager = SyncLogManager() + log_entry = log_manager.start_sync( + table_name=self.table_name, sync_type="full", metadata={"dry_run": dry_run} + ) - current_quarter = self.get_current_quarter() + try: + # 确保表存在 + self.ensure_table_exists() - return self.sync_range(self.DEFAULT_START_DATE, current_quarter, dry_run) + current_quarter = self.get_current_quarter() + results = self.sync_range(self.DEFAULT_START_DATE, current_quarter, dry_run) + + # 计算总插入记录数 + total_inserted = sum( + r.get("inserted_count", 0) for r in results if isinstance(r, dict) + ) + + # 完成日志记录 + log_manager.complete_sync( + 
log_entry, + status="success", + records_inserted=total_inserted, + records_updated=0, + records_deleted=0, + ) + + return results + + except Exception as e: + log_manager.complete_sync(log_entry, status="failed", error_message=str(e)) + raise # ====================================================================== # 预览模式 diff --git a/src/data/api_wrappers/base_sync.py b/src/data/api_wrappers/base_sync.py index b67c006..0bbd1ec 100644 --- a/src/data/api_wrappers/base_sync.py +++ b/src/data/api_wrappers/base_sync.py @@ -34,6 +34,7 @@ from tqdm import tqdm from src.data.client import TushareClient from src.data.storage import ThreadSafeStorage, Storage +from src.data.sync_logger import SyncLogManager from src.data.utils import get_today_date, get_next_date from src.config.settings import get_settings from src.data.api_wrappers.api_trade_cal import ( @@ -614,14 +615,30 @@ class StockBasedSync(BaseDataSync): finally: pbar.close() - # 批量写入所有数据(仅在无错误时) + # 批量写入所有数据(仅在无错误时),使用事务确保原子性 if results and not error_occurred: - for ts_code, data in results.items(): - if not data.empty: - self.storage.queue_save(self.table_name, data) - self.storage.flush() - total_rows = sum(len(df) for df in results.values()) - print(f"\n[{class_name}] Saved {total_rows} rows to storage") + try: + # 开始事务 + self.storage.begin_transaction() + + for ts_code, data in results.items(): + if not data.empty: + self.storage.queue_save(self.table_name, data) + # flush 在事务中执行 + self.storage.flush(use_transaction=False) + + # 提交事务 + self.storage.commit_transaction() + total_rows = sum(len(df) for df in results.values()) + print( + f"\n[{class_name}] Saved {total_rows} rows to storage (transaction committed)" + ) + except Exception as e: + # 回滚事务 + self.storage.rollback_transaction() + print(f"\n[{class_name}] Transaction rolled back due to error: {e}") + error_occurred = True + exception_to_raise = e # 打印摘要 print(f"\n{'=' * 60}") @@ -824,6 +841,17 @@ class StockBasedSync(BaseDataSync): 
print(f"[{class_name}] Starting {self.table_name} data sync...") print(f"{'=' * 60}") + # 初始化日志管理器并记录开始 + log_manager = SyncLogManager() + sync_mode = "full" if force_full else "incremental" + log_entry = log_manager.start_sync( + table_name=self.table_name, + sync_type=sync_mode, + date_range_start=start_date, + date_range_end=end_date, + metadata={"dry_run": dry_run, "max_workers": max_workers}, + ) + # 首先确保交易日历缓存是最新的 print(f"[{class_name}] Syncing trade calendar cache...") sync_trade_cal_cache() @@ -911,13 +939,33 @@ class StockBasedSync(BaseDataSync): print(f"[{class_name}] Using {max_workers or self.max_workers} worker threads") # 执行并发同步 - return self._run_concurrent_sync( - stock_codes=stock_codes, - start_date=sync_start_date, - end_date=end_date, - max_workers=max_workers, - dry_run=dry_run, - ) + try: + results = self._run_concurrent_sync( + stock_codes=stock_codes, + start_date=sync_start_date, + end_date=end_date, + max_workers=max_workers, + dry_run=dry_run, + ) + + # 计算同步结果统计 + total_inserted = sum(len(df) for df in results.values()) if results else 0 + + # 完成日志记录 + log_manager.complete_sync( + log_entry, + status="success" if results else "partial", + records_inserted=total_inserted, + records_updated=0, + records_deleted=0, + ) + + return results + + except Exception as e: + # 记录失败日志 + log_manager.complete_sync(log_entry, status="failed", error_message=str(e)) + raise class DateBasedSync(BaseDataSync): @@ -1117,33 +1165,52 @@ class DateBasedSync(BaseDataSync): return combined - def _save_data(self, data: pd.DataFrame, sync_start_date: str) -> None: - """保存数据到存储。 + def _save_data(self, data: pd.DataFrame, sync_start_date: str) -> int: + """保存数据到存储(使用事务确保原子性)。 Args: data: 要保存的数据 sync_start_date: 同步起始日期(用于删除旧数据) + + Returns: + 保存的记录数 """ if data.empty: - return + return 0 - storage = Storage() + storage = Storage(read_only=False) + records_count = len(data) - # 删除日期范围内的旧数据 - if sync_start_date: - sync_start_date_fmt = pd.to_datetime( - sync_start_date, 
format="%Y%m%d" - ).date() - storage._connection.execute( - f'DELETE FROM "{self.table_name}" WHERE "trade_date" >= ?', - [sync_start_date_fmt], + try: + # 开始事务 + storage.begin_transaction() + + # 删除日期范围内的旧数据 + if sync_start_date: + sync_start_date_fmt = pd.to_datetime( + sync_start_date, format="%Y%m%d" + ).date() + storage._connection.execute( + f'DELETE FROM "{self.table_name}" WHERE "trade_date" >= ?', + [sync_start_date_fmt], + ) + + # 保存新数据(在事务中直接保存,不使用队列) + storage.save(self.table_name, data, mode="append") + + # 提交事务 + storage.commit_transaction() + print( + f"[{self.__class__.__name__}] Saved {records_count} records to DuckDB (transaction committed)" ) + return records_count - # 保存新数据 - self.storage.queue_save(self.table_name, data) - self.storage.flush() - - print(f"[{self.__class__.__name__}] Saved {len(data)} records to DuckDB") + except Exception as e: + storage.rollback_transaction() + print( + f"[{self.__class__.__name__}] Transaction rolled back due to error: {e}" + ) + raise def preview_sync( self, @@ -1280,12 +1347,25 @@ class DateBasedSync(BaseDataSync): print(f"[{class_name}] Starting {self.table_name} data sync...") print(f"{'=' * 60}") + # 初始化日志管理器并记录开始 + log_manager = SyncLogManager() + sync_mode = "full" if force_full else "incremental" + log_entry = log_manager.start_sync( + table_name=self.table_name, + sync_type=sync_mode, + date_range_start=start_date, + date_range_end=end_date, + metadata={"dry_run": dry_run}, + ) + # 确定日期范围 sync_start, sync_end, mode = self._get_sync_date_range( start_date, end_date, force_full ) if mode == "up_to_date": + # 记录跳过日志 + log_manager.complete_sync(log_entry, status="success", records_inserted=0) return pd.DataFrame() if dry_run: @@ -1295,49 +1375,74 @@ class DateBasedSync(BaseDataSync): print(f" Would sync from {sync_start} to {sync_end}") print(f" Mode: {mode}") print(f"{'=' * 60}") + # 记录 dry run 日志 + log_manager.complete_sync(log_entry, status="success", records_inserted=0) return pd.DataFrame() - # 
检查表是否存在,不存在则创建 - storage = Storage() - if not storage.exists(self.table_name): - print( - f"[{class_name}] Table '{self.table_name}' doesn't exist, creating..." + try: + # 检查表是否存在,不存在则创建 + storage = Storage() + if not storage.exists(self.table_name): + print( + f"[{class_name}] Table '{self.table_name}' doesn't exist, creating..." + ) + # 获取样本数据以推断 schema + sample = self.fetch_single_date(sync_end) + if sample.empty: + # 尝试另一个日期 + sample = self.fetch_single_date("20240102") + if not sample.empty: + self._ensure_table_schema(sample) + else: + print( + f"[{class_name}] Cannot create table: no sample data available" + ) + log_manager.complete_sync( + log_entry, + status="failed", + error_message="No sample data available", + ) + return pd.DataFrame() + + # 首次同步探测:验证表结构是否正常 + if self._should_probe_table(): + print(f"[{class_name}] Table '{self.table_name}' is empty, probing...") + # 使用最近一个交易日的完整数据进行探测 + probe_date = get_last_trading_day(sync_start, sync_end) + if probe_date: + probe_data = self.fetch_single_date(probe_date) + probe_desc = f"date={probe_date}, all stocks" + self._probe_table_and_cleanup(probe_data, probe_desc) + + # 执行同步 + combined = self._run_date_range_sync(sync_start, sync_end, dry_run) + + # 保存数据(在事务中) + records_saved = 0 + if not combined.empty: + records_saved = self._save_data(combined, sync_start) + + # 打印摘要 + print(f"\n{'=' * 60}") + print(f"[{class_name}] Sync Summary") + print(f"{'=' * 60}") + print(f" Mode: {mode}") + print(f" Date range: {sync_start} to {sync_end}") + print(f" Records: {len(combined)}") + print(f"{'=' * 60}") + + # 完成日志记录 + log_manager.complete_sync( + log_entry, + status="success", + records_inserted=records_saved, + records_updated=0, + records_deleted=0, ) - # 获取样本数据以推断 schema - sample = self.fetch_single_date(sync_end) - if sample.empty: - # 尝试另一个日期 - sample = self.fetch_single_date("20240102") - if not sample.empty: - self._ensure_table_schema(sample) - else: - print(f"[{class_name}] Cannot create table: no sample data 
available") - return pd.DataFrame() - # 首次同步探测:验证表结构是否正常 - if self._should_probe_table(): - print(f"[{class_name}] Table '{self.table_name}' is empty, probing...") - # 使用最近一个交易日的完整数据进行探测 - probe_date = get_last_trading_day(sync_start, sync_end) - if probe_date: - probe_data = self.fetch_single_date(probe_date) - probe_desc = f"date={probe_date}, all stocks" - self._probe_table_and_cleanup(probe_data, probe_desc) + return combined - # 执行同步 - combined = self._run_date_range_sync(sync_start, sync_end, dry_run) - - # 保存数据 - if not combined.empty: - self._save_data(combined, sync_start) - - # 打印摘要 - print(f"\n{'=' * 60}") - print(f"[{class_name}] Sync Summary") - print(f"{'=' * 60}") - print(f" Mode: {mode}") - print(f" Date range: {sync_start} to {sync_end}") - print(f" Records: {len(combined)}") - print(f"{'=' * 60}") - - return combined + except Exception as e: + # 记录失败日志 + log_manager.complete_sync(log_entry, status="failed", error_message=str(e)) + raise diff --git a/src/data/api_wrappers/financial_data/api_financial_sync.py b/src/data/api_wrappers/financial_data/api_financial_sync.py index 38e91ac..7f9145f 100644 --- a/src/data/api_wrappers/financial_data/api_financial_sync.py +++ b/src/data/api_wrappers/financial_data/api_financial_sync.py @@ -36,6 +36,7 @@ from typing import List, Optional +from src.data.sync_logger import SyncLogManager from src.data.api_wrappers.financial_data.api_income import ( IncomeQuarterSync, sync_income, @@ -120,6 +121,14 @@ def sync_financial( results = {} + # 初始化日志管理器并记录调度中心开始 + log_manager = SyncLogManager() + log_entry = log_manager.start_sync( + table_name="financial_data_batch", + sync_type="full" if force_full else "incremental", + metadata={"data_types": data_types, "dry_run": dry_run}, + ) + print("\n" + "=" * 60) print("[Financial Sync] 财务数据同步调度中心") print("=" * 60) @@ -128,10 +137,14 @@ def sync_financial( print(f"写入模式: {'预览' if dry_run else '实际写入'}") print("=" * 60) + total_inserted = 0 + failed_types = [] + for data_type in 
data_types: if data_type not in FINANCIAL_SYNCERS: print(f"[WARN] 未知的数据类型: {data_type}") results[data_type] = {"error": f"Unknown data type: {data_type}"} + failed_types.append(data_type) continue config = FINANCIAL_SYNCERS[data_type] @@ -143,10 +156,19 @@ def sync_financial( try: result = sync_func(force_full=force_full, dry_run=dry_run) results[data_type] = result + + # 统计插入的记录数 + if isinstance(result, list): + inserted = sum( + r.get("inserted_count", 0) for r in result if isinstance(r, dict) + ) + total_inserted += inserted + print(f"[{display_name}] 同步完成") except Exception as e: print(f"[ERROR] [{display_name}] 同步失败: {e}") results[data_type] = {"error": str(e)} + failed_types.append(data_type) # 打印汇总 print("\n" + "=" * 60) @@ -169,6 +191,16 @@ def sync_financial( print(f" {display_name}: {status}") print("=" * 60) + # 完成调度中心日志记录 + status = "failed" if failed_types else "success" + error_msg = f"Failed types: {failed_types}" if failed_types else None + log_manager.complete_sync( + log_entry, + status=status, + records_inserted=total_inserted, + error_message=error_msg, + ) + return results diff --git a/src/data/storage.py b/src/data/storage.py index 5d6aaca..f6d4473 100644 --- a/src/data/storage.py +++ b/src/data/storage.py @@ -313,6 +313,87 @@ class Storage: Storage._connection = None Storage._instance = None + # ====================================================================== + # 事务支持 + # ====================================================================== + + def begin_transaction(self) -> None: + """开始事务。 + + 使用方式: + storage = Storage(read_only=False) + storage.begin_transaction() + try: + storage.save("table1", data1) + storage.save("table2", data2) + storage.commit_transaction() + except Exception: + storage.rollback_transaction() + raise + """ + if self._read_only: + raise RuntimeError("Cannot begin transaction in read-only mode") + + try: + self._connection.execute("BEGIN TRANSACTION") + print("[Storage] Transaction started") + except Exception 
as e: + print(f"[Storage] Error starting transaction: {e}") + raise + + def commit_transaction(self) -> None: + """提交事务。""" + if self._read_only: + return + + try: + self._connection.execute("COMMIT") + print("[Storage] Transaction committed") + except Exception as e: + print(f"[Storage] Error committing transaction: {e}") + raise + + def rollback_transaction(self) -> None: + """回滚事务。""" + if self._read_only: + return + + try: + self._connection.execute("ROLLBACK") + print("[Storage] Transaction rolled back") + except Exception as e: + print(f"[Storage] Error rolling back transaction: {e}") + + def transaction(self): + """事务上下文管理器。 + + 使用方式: + storage = Storage(read_only=False) + with storage.transaction(): + storage.save("table1", data1) + storage.save("table2", data2) + # 自动提交或回滚 + """ + return _TransactionContext(self) + + +class _TransactionContext: + """事务上下文管理器。""" + + def __init__(self, storage: Storage): + self.storage = storage + + def __enter__(self): + self.storage.begin_transaction() + return self.storage + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is None: + self.storage.commit_transaction() + else: + self.storage.rollback_transaction() + return False + class ThreadSafeStorage: """线程安全的 DuckDB 写入包装器。 @@ -323,12 +404,14 @@ class ThreadSafeStorage: 注意: - 此类自动使用 read_only=False 模式,用于数据同步 - 不要在多进程中同时使用此类,只应在单进程中用于批量写入 + - 支持事务模式:在事务中批量写入,确保原子性 """ def __init__(self): # 使用 read_only=False 模式创建 Storage,用于写入操作 self.storage = Storage(read_only=False) self._pending_writes: List[tuple] = [] # [(name, data, use_upsert), ...] 
+ self._in_transaction: bool = False def queue_save(self, name: str, data: pd.DataFrame, use_upsert: bool = True): """将数据放入写入队列(不立即写入) @@ -341,10 +424,13 @@ class ThreadSafeStorage: if not data.empty: self._pending_writes.append((name, data, use_upsert)) - def flush(self): + def flush(self, use_transaction: bool = True): """批量写入所有队列数据。 调用时机:在 sync 结束时统一调用,避免并发写入冲突。 + + Args: + use_transaction: 若为 True,使用事务包装批量写入 """ if not self._pending_writes: return @@ -357,17 +443,72 @@ class ThreadSafeStorage: table_data[key].append(data) # 批量写入每个表 - for (name, use_upsert), data_list in table_data.items(): - combined = pd.concat(data_list, ignore_index=True) - # 在批量数据中先去重 - if "ts_code" in combined.columns and "trade_date" in combined.columns: - combined = combined.drop_duplicates( - subset=["ts_code", "trade_date"], keep="last" + if use_transaction: + # 使用事务确保原子性 + try: + self.storage.begin_transaction() + total_rows = 0 + + for (name, use_upsert), data_list in table_data.items(): + combined = pd.concat(data_list, ignore_index=True) + # 在批量数据中先去重 + if ( + "ts_code" in combined.columns + and "trade_date" in combined.columns + ): + combined = combined.drop_duplicates( + subset=["ts_code", "trade_date"], keep="last" + ) + result = self.storage.save( + name, combined, mode="append", use_upsert=use_upsert + ) + if result.get("status") == "success": + total_rows += result.get("rows", 0) + + self.storage.commit_transaction() + print( + f"[ThreadSafeStorage] Transaction committed: {total_rows} rows saved" ) - self.storage.save(name, combined, mode="append", use_upsert=use_upsert) + + except Exception as e: + self.storage.rollback_transaction() + print(f"[ThreadSafeStorage] Transaction rolled back due to error: {e}") + raise + else: + # 不使用事务,逐表写入 + for (name, use_upsert), data_list in table_data.items(): + combined = pd.concat(data_list, ignore_index=True) + # 在批量数据中先去重 + if "ts_code" in combined.columns and "trade_date" in combined.columns: + combined = combined.drop_duplicates( + 
subset=["ts_code", "trade_date"], keep="last" + ) + self.storage.save(name, combined, mode="append", use_upsert=use_upsert) self._pending_writes.clear() + def begin_transaction(self) -> None: + """开始事务模式。 + + 在事务模式下,flush 会被延迟到 commit_transaction 或 rollback_transaction。 + """ + self._in_transaction = True + self.storage.begin_transaction() + + def commit_transaction(self) -> None: + """提交事务。""" + if self._pending_writes: + # 先写入队列中的数据 + self.flush(use_transaction=False) + self.storage.commit_transaction() + self._in_transaction = False + + def rollback_transaction(self) -> None: + """回滚事务。""" + self._pending_writes.clear() + self.storage.rollback_transaction() + self._in_transaction = False + def __getattr__(self, name): """代理其他方法到 Storage 实例""" return getattr(self.storage, name) diff --git a/src/data/sync.py b/src/data/sync.py index d199d2a..dad3085 100644 --- a/src/data/sync.py +++ b/src/data/sync.py @@ -55,6 +55,7 @@ import pandas as pd from src.data import api_wrappers # noqa: F401 from src.data.sync_registry import sync_registry from src.data.api_wrappers import sync_all_stocks +from src.data.sync_logger import SyncLogManager def sync_all_data( @@ -109,7 +110,7 @@ def sync_all_data( >>> result = sync_all_data(dry_run=True) >>> >>> # 只同步特定任务 - >>> result = sync_all_data(selected=["trade_cal", "stock_basic"]) + >>> result = sync_all_data(selected=['trade_cal', 'pro_bar']) >>> >>> # 查看所有可用任务 >>> from src.data.sync_registry import sync_registry @@ -117,13 +118,80 @@ def sync_all_data( >>> for t in tasks: ... 
print(f"{t.name}: {t.display_name}") """ - return sync_registry.sync_all( - force_full=force_full, - max_workers=max_workers, - dry_run=dry_run, - selected=selected, + # 记录调度中心开始 + log_manager = SyncLogManager() + sync_mode = "full" if force_full else "incremental" + selected_str = ",".join(selected) if selected else "all" + log_manager.start_sync( + table_name="daily_data_batch", + sync_type=sync_mode, + metadata={ + "selected": selected_str, + "dry_run": dry_run, + "max_workers": max_workers, + }, ) + try: + result = sync_registry.sync_all( + force_full=force_full, + max_workers=max_workers, + dry_run=dry_run, + selected=selected, + ) + + # 计算成功/失败数量 + success_count = 0 + failed_count = 0 + total_records = 0 + for task_name, task_result in result.items(): + if isinstance(task_result, dict): + if task_result.get("status") == "error": + failed_count += 1 + else: + success_count += 1 + # 累加记录数(如果有) + if "rows" in task_result: + total_records += task_result.get("rows", 0) + elif isinstance(task_result, pd.DataFrame): + success_count += 1 + total_records += len(task_result) + else: + success_count += 1 + + # 记录完成日志 + status = "partial" if failed_count > 0 else "success" + error_msg = f"Failed: {failed_count} tasks" if failed_count > 0 else None + log_manager.complete_sync( + table_name="daily_data_batch", + sync_type=sync_mode, + status=status, + records_inserted=total_records, + error_message=error_msg, + metadata={ + "selected": selected_str, + "dry_run": dry_run, + "max_workers": max_workers, + }, + ) + + return result + + except Exception as e: + # 记录失败日志 + log_manager.complete_sync( + table_name="daily_data_batch", + sync_type=sync_mode, + status="failed", + error_message=str(e), + metadata={ + "selected": selected_str, + "dry_run": dry_run, + "max_workers": max_workers, + }, + ) + raise + def list_sync_tasks() -> list[dict[str, Any]]: """列出所有已注册的同步任务。 diff --git a/src/data/sync_logger.py b/src/data/sync_logger.py new file mode 100644 index 0000000..af1d1f9 --- 
"""Sync log management module.

Records and queries data-synchronisation operations, tracking the details
of every sync run.

Two logging styles are supported:

- entry-based (preferred): ``entry = start_sync(...)`` then
  ``complete_sync(entry, ...)`` finalises the entry in place and updates the
  corresponding "running" row, so each sync run occupies exactly one row;
- legacy name-based: ``complete_sync("table", ...)`` inserts a separate
  completion row (insert-only, no primary key required).
"""

from datetime import datetime
from typing import Optional, Dict, Any, Union
import pandas as pd
from dataclasses import dataclass

# Guarded import: the manager only needs a Storage-like object exposing
# `_connection`, so the module stays importable (and unit-testable) with an
# injected storage even when the project package is unavailable.
try:
    from src.data.storage import Storage
except ImportError:  # pragma: no cover - standalone / injected-storage usage
    Storage = None


@dataclass
class SyncLogEntry:
    """A single sync-log record.

    Attributes:
        table_name: Name of the synced table.
        sync_type: Kind of sync (full/incremental/quarterly/daily, ...).
        start_time: When the sync started.
        end_time: When the sync finished (None while still running).
        status: Sync status (success/failed/partial/running).
        records_before: Row count before the sync.
        records_after: Row count after the sync.
        records_inserted: Number of rows inserted.
        records_updated: Number of rows updated.
        records_deleted: Number of rows deleted.
        date_range_start: Start of the data date range (YYYYMMDD).
        date_range_end: End of the data date range (YYYYMMDD).
        error_message: Error message, if any.
        metadata: Extra metadata (stringified dict).
    """

    table_name: str
    sync_type: str
    start_time: datetime
    end_time: Optional[datetime] = None
    status: str = "running"
    records_before: int = 0
    records_after: int = 0
    records_inserted: int = 0
    records_updated: int = 0
    records_deleted: int = 0
    date_range_start: Optional[str] = None
    date_range_end: Optional[str] = None
    error_message: Optional[str] = None
    metadata: Optional[str] = None


class SyncLogManager:
    """Manager for the sync-log table.

    Handles creation of the log table, inserting/updating records, and
    querying sync history.

    Preferred usage (entry-based; one row per sync run)::

        log_manager = SyncLogManager()
        entry = log_manager.start_sync("daily", "incremental")
        try:
            # run the sync ...
            log_manager.complete_sync(entry, "success", records_inserted=1000)
        except Exception as e:
            log_manager.complete_sync(entry, "failed", error_message=str(e))

        logs = log_manager.get_sync_history("daily", limit=10)

    The legacy name-based form ``complete_sync("daily", "success", ...)`` is
    still accepted and inserts a separate completion row (insert-only).
    """

    TABLE_NAME = "_sync_logs"

    def __init__(self, storage: Optional["Storage"] = None):
        """Initialise the log manager and ensure the log table exists.

        Args:
            storage: A Storage-like object exposing ``_connection``. When
                None, the project's Storage singleton is created.
        """
        if storage is None:
            if Storage is None:
                # Import failed at module load and nothing was injected.
                raise ImportError(
                    "src.data.storage is unavailable; pass a storage instance"
                )
            storage = Storage(read_only=False)
        self.storage = storage
        self._ensure_table_exists()

    def _ensure_table_exists(self) -> None:
        """Create the log table and its indexes if they do not exist."""
        create_sql = f"""
            CREATE TABLE IF NOT EXISTS {self.TABLE_NAME} (
                table_name VARCHAR(64) NOT NULL,
                sync_type VARCHAR(32) NOT NULL,
                start_time TIMESTAMP NOT NULL,
                end_time TIMESTAMP,
                status VARCHAR(16) DEFAULT 'running',
                records_before INTEGER DEFAULT 0,
                records_after INTEGER DEFAULT 0,
                records_inserted INTEGER DEFAULT 0,
                records_updated INTEGER DEFAULT 0,
                records_deleted INTEGER DEFAULT 0,
                date_range_start VARCHAR(8),
                date_range_end VARCHAR(8),
                error_message TEXT,
                metadata TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """

        # Indexes covering the three query paths: per-table history,
        # status filtering, and per-type history.
        index_sql = f"""
            CREATE INDEX IF NOT EXISTS idx_sync_logs_table_time
            ON {self.TABLE_NAME}(table_name, start_time DESC);

            CREATE INDEX IF NOT EXISTS idx_sync_logs_status
            ON {self.TABLE_NAME}(status);

            CREATE INDEX IF NOT EXISTS idx_sync_logs_type_time
            ON {self.TABLE_NAME}(sync_type, start_time DESC);
        """

        try:
            self.storage._connection.execute(create_sql)
            self.storage._connection.execute(index_sql)
        except Exception as e:
            print(f"[SyncLogManager] Error creating log table: {e}")
            raise

    def start_sync(
        self,
        table_name: str,
        sync_type: str,
        date_range_start: Optional[str] = None,
        date_range_end: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> SyncLogEntry:
        """Record the start of a sync run.

        Inserts a "running" row and returns the in-memory entry so the
        caller can hand it back to :meth:`complete_sync` later.

        Args:
            table_name: Name of the table being synced.
            sync_type: Kind of sync (full/incremental/quarterly, ...).
            date_range_start: Start of the data date range.
            date_range_end: End of the data date range.
            metadata: Extra metadata dict (stored stringified).

        Returns:
            The newly created SyncLogEntry with status "running".
        """
        entry = SyncLogEntry(
            table_name=table_name,
            sync_type=sync_type,
            start_time=datetime.now(),
            status="running",
            records_before=0,  # row counts are not collected, avoiding extra queries
            date_range_start=date_range_start,
            date_range_end=date_range_end,
            metadata=str(metadata) if metadata else None,
        )

        try:
            self.storage._connection.execute(
                f"""
                INSERT INTO {self.TABLE_NAME} (
                    table_name, sync_type, start_time, status,
                    records_before, date_range_start, date_range_end, metadata
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    entry.table_name,
                    entry.sync_type,
                    entry.start_time,
                    entry.status,
                    entry.records_before,
                    entry.date_range_start,
                    entry.date_range_end,
                    entry.metadata,
                ),
            )
        except Exception as e:
            # Logging must never break the sync itself.
            print(f"[SyncLogManager] Error logging sync start: {e}")
        return entry

    def complete_sync(
        self,
        table_name: Union[str, SyncLogEntry],
        status: str = "success",
        records_inserted: int = 0,
        records_updated: int = 0,
        records_deleted: int = 0,
        error_message: Optional[str] = None,
        sync_type: str = "incremental",
        date_range_start: Optional[str] = None,
        date_range_end: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Record the completion of a sync run.

        When *table_name* is a :class:`SyncLogEntry` (as returned by
        :meth:`start_sync`), the entry is finalised in place and the matching
        "running" row is updated, so the run occupies a single row.
        When it is a table-name string (legacy), a separate completion row
        is inserted and no existing row is touched.

        Args:
            table_name: The SyncLogEntry from start_sync, or a table name.
            status: Sync status (success/failed/partial).
            records_inserted: Number of rows inserted.
            records_updated: Number of rows updated.
            records_deleted: Number of rows deleted.
            error_message: Error message if the sync failed.
            sync_type: Kind of sync (legacy string form only).
            date_range_start: Start of the date range (legacy string form only).
            date_range_end: End of the date range (legacy string form only).
            metadata: Extra metadata dict.
        """
        if isinstance(table_name, SyncLogEntry):
            self._finalize_entry(
                table_name,
                status,
                records_inserted,
                records_updated,
                records_deleted,
                error_message,
                metadata,
            )
            return

        # Legacy path: insert a standalone completion record.
        entry = SyncLogEntry(
            table_name=table_name,
            sync_type=sync_type,
            start_time=datetime.now(),
            end_time=datetime.now(),
            status=status,
            records_inserted=records_inserted,
            records_updated=records_updated,
            records_deleted=records_deleted,
            date_range_start=date_range_start,
            date_range_end=date_range_end,
            error_message=error_message,
            metadata=str(metadata) if metadata else None,
        )

        try:
            self.storage._connection.execute(
                f"""
                INSERT INTO {self.TABLE_NAME} (
                    table_name, sync_type, start_time, end_time, status,
                    records_inserted, records_updated, records_deleted,
                    date_range_start, date_range_end, error_message, metadata
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    entry.table_name,
                    entry.sync_type,
                    entry.start_time,
                    entry.end_time,
                    entry.status,
                    entry.records_inserted,
                    entry.records_updated,
                    entry.records_deleted,
                    entry.date_range_start,
                    entry.date_range_end,
                    entry.error_message,
                    entry.metadata,
                ),
            )
        except Exception as e:
            print(f"[SyncLogManager] Error logging sync complete: {e}")

    def _finalize_entry(
        self,
        entry: SyncLogEntry,
        status: str,
        records_inserted: int,
        records_updated: int,
        records_deleted: int,
        error_message: Optional[str],
        metadata: Optional[Dict[str, Any]],
    ) -> None:
        """Mutate *entry* to its final state and update its "running" row."""
        entry.end_time = datetime.now()
        entry.status = status
        entry.records_inserted = records_inserted
        entry.records_updated = records_updated
        entry.records_deleted = records_deleted
        entry.error_message = error_message
        if metadata is not None:
            entry.metadata = str(metadata)

        try:
            # (table_name, start_time, status='running') identifies the row
            # inserted by start_sync without needing a primary key.
            self.storage._connection.execute(
                f"""
                UPDATE {self.TABLE_NAME}
                SET end_time = ?, status = ?, records_inserted = ?,
                    records_updated = ?, records_deleted = ?,
                    error_message = ?, metadata = ?
                WHERE table_name = ? AND start_time = ? AND status = 'running'
                """,
                (
                    entry.end_time,
                    entry.status,
                    entry.records_inserted,
                    entry.records_updated,
                    entry.records_deleted,
                    entry.error_message,
                    entry.metadata,
                    entry.table_name,
                    entry.start_time,
                ),
            )
        except Exception as e:
            print(f"[SyncLogManager] Error logging sync complete: {e}")

    def get_sync_history(
        self,
        table_name: Optional[str] = None,
        status: Optional[str] = None,
        limit: int = 100,
        offset: int = 0,
    ) -> pd.DataFrame:
        """Query sync history.

        Args:
            table_name: Filter by table name.
            status: Filter by status.
            limit: Maximum number of rows to return.
            offset: Pagination offset.

        Returns:
            DataFrame of history rows, newest first (empty on error).
        """
        conditions = []
        params = []

        if table_name:
            conditions.append("table_name = ?")
            params.append(table_name)

        if status:
            conditions.append("status = ?")
            params.append(status)

        where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""

        query = f"""
            SELECT * FROM {self.TABLE_NAME}
            {where_clause}
            ORDER BY start_time DESC
            LIMIT ? OFFSET ?
        """
        params.extend([limit, offset])

        try:
            return self.storage._connection.execute(query, params).fetchdf()
        except Exception as e:
            print(f"[SyncLogManager] Error querying sync history: {e}")
            return pd.DataFrame()

    def get_last_sync(
        self, table_name: str, status: Optional[str] = "success"
    ) -> Optional[Dict[str, Any]]:
        """Get the most recent sync record for a table.

        Args:
            table_name: Table name.
            status: Status filter; None means no status restriction.

        Returns:
            The record as a dict, or None if there is no match.
        """
        conditions = ["table_name = ?"]
        params = [table_name]

        if status:
            conditions.append("status = ?")
            params.append(status)

        query = f"""
            SELECT * FROM {self.TABLE_NAME}
            WHERE {" AND ".join(conditions)}
            ORDER BY start_time DESC
            LIMIT 1
        """

        try:
            df = self.storage._connection.execute(query, params).fetchdf()
            if not df.empty:
                return df.iloc[0].to_dict()
            return None
        except Exception as e:
            print(f"[SyncLogManager] Error getting last sync: {e}")
            return None

    def get_sync_summary(
        self, table_name: Optional[str] = None, days: int = 30
    ) -> Dict[str, Any]:
        """Get aggregate sync statistics over a recent window.

        Args:
            table_name: Filter by table name.
            days: Size of the look-back window in days.

        Returns:
            Dict of counters (all zero on error or when the table is empty).
        """
        # `days` is interpolated into the SQL text, so force it to an int
        # to keep the f-string injection-safe.
        days = int(days)
        conditions = [f"start_time >= CURRENT_DATE - INTERVAL '{days} days'"]
        params = []

        if table_name:
            conditions.append("table_name = ?")
            params.append(table_name)

        where_clause = f"WHERE {' AND '.join(conditions)}"

        query = f"""
            SELECT
                COUNT(*) as total_syncs,
                COUNT(CASE WHEN status = 'success' THEN 1 END) as success_count,
                COUNT(CASE WHEN status = 'failed' THEN 1 END) as failed_count,
                COUNT(CASE WHEN status = 'partial' THEN 1 END) as partial_count,
                SUM(records_inserted) as total_inserted,
                SUM(records_updated) as total_updated,
                SUM(records_deleted) as total_deleted
            FROM {self.TABLE_NAME}
            {where_clause}
        """

        try:
            df = self.storage._connection.execute(query, params).fetchdf()
            if not df.empty:
                return {
                    "total_syncs": int(df.iloc[0]["total_syncs"]),
                    "success_count": int(df.iloc[0]["success_count"]),
                    "failed_count": int(df.iloc[0]["failed_count"]),
                    "partial_count": int(df.iloc[0]["partial_count"]),
                    "total_inserted": int(df.iloc[0]["total_inserted"] or 0),
                    "total_updated": int(df.iloc[0]["total_updated"] or 0),
                    "total_deleted": int(df.iloc[0]["total_deleted"] or 0),
                }
        except Exception as e:
            print(f"[SyncLogManager] Error getting sync summary: {e}")

        return {
            "total_syncs": 0,
            "success_count": 0,
            "failed_count": 0,
            "partial_count": 0,
            "total_inserted": 0,
            "total_updated": 0,
            "total_deleted": 0,
        }


# Convenience functions


def log_sync_start(
    table_name: str,
    sync_type: str,
    date_range_start: Optional[str] = None,
    date_range_end: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> SyncLogEntry:
    """Convenience wrapper: record the start of a sync run.

    Args:
        table_name: Table name.
        sync_type: Kind of sync.
        date_range_start: Start of the date range.
        date_range_end: End of the date range.
        metadata: Extra metadata.

    Returns:
        The created SyncLogEntry.
    """
    manager = SyncLogManager()
    return manager.start_sync(
        table_name, sync_type, date_range_start, date_range_end, metadata
    )


def log_sync_complete(
    table_name: str,
    status: str = "success",
    records_inserted: int = 0,
    records_updated: int = 0,
    records_deleted: int = 0,
    error_message: Optional[str] = None,
    sync_type: str = "incremental",
) -> None:
    """Convenience wrapper: record the completion of a sync run (legacy path).

    Args:
        table_name: Table name.
        status: Status.
        records_inserted: Rows inserted.
        records_updated: Rows updated.
        records_deleted: Rows deleted.
        error_message: Error message.
        sync_type: Kind of sync.
    """
    manager = SyncLogManager()
    manager.complete_sync(
        table_name,
        status,
        records_inserted,
        records_updated,
        records_deleted,
        error_message,
        sync_type,
    )
'GTJA_alpha087', + 'GTJA_alpha105', 'GTJA_alpha092', + 'GTJA_alpha087', 'GTJA_alpha085', - 'GTJA_alpha044', 'GTJA_alpha062', 'GTJA_alpha124', 'GTJA_alpha133', @@ -203,23 +203,37 @@ print("\n" + "=" * 80) print("开始训练") print("=" * 80) -# 步骤 1: 股票池筛选 -print("\n[步骤 1/6] 股票池筛选") +# 步骤 1: 应用过滤器(ST股票过滤等) +print("\n[步骤 1/7] 应用数据过滤器") +print("-" * 60) +filtered_data = data +if st_filter: + print(" 应用 ST 股票过滤器...") + data_before = len(filtered_data) + filtered_data = st_filter.filter(filtered_data) + data_after = len(filtered_data) + print(f" 过滤前记录数: {data_before}") + print(f" 过滤后记录数: {data_after}") + print(f" 删除 ST 股票记录数: {data_before - data_after}") +else: + print(" 未配置 ST 过滤器,跳过") + +# 步骤 2: 股票池筛选 +print("\n[步骤 2/7] 股票池筛选") print("-" * 60) if pool_manager: print(" 执行每日独立筛选股票池...") - filtered_data = pool_manager.filter_and_select_daily(data) - print(f" 筛选前数据规模: {data.shape}") - print(f" 筛选后数据规模: {filtered_data.shape}") - print(f" 筛选前股票数: {data['ts_code'].n_unique()}") - print(f" 筛选后股票数: {filtered_data['ts_code'].n_unique()}") - print(f" 删除记录数: {len(data) - len(filtered_data)}") + pool_data_before = len(filtered_data) + filtered_data = pool_manager.filter_and_select_daily(filtered_data) + pool_data_after = len(filtered_data) + print(f" 筛选前数据规模: {pool_data_before}") + print(f" 筛选后数据规模: {pool_data_after}") + print(f" 删除记录数: {pool_data_before - pool_data_after}") else: - filtered_data = data print(" 未配置股票池管理器,跳过筛选") # %% -# 步骤 2: 划分训练/验证/测试集(正确的三分法) -print("\n[步骤 2/6] 划分训练集、验证集和测试集") +# 步骤 3: 划分训练/验证/测试集(正确的三分法) +print("\n[步骤 3/7] 划分训练集、验证集和测试集") print("-" * 60) if splitter: # 正确的三分法:train用于训练,val用于验证/早停,test仅用于最终评估 @@ -251,8 +265,8 @@ else: test_data = filtered_data print(" 未配置划分器,全部作为训练集") # %% -# 步骤 3: 数据质量检查(必须在预处理之前) -print("\n[步骤 3/7] 数据质量检查") +# 步骤 4: 数据质量检查(必须在预处理之前) +print("\n[步骤 4/7] 数据质量检查") print("-" * 60) print(" [说明] 此检查在 fillna 等处理之前执行,用于发现数据问题") @@ -269,8 +283,8 @@ check_data_quality(test_data, feature_cols, raise_on_error=True) print(" [成功] 数据质量检查通过,未发现异常") # 
%% -# 步骤 4: 训练集数据处理 -print("\n[步骤 4/7] 训练集数据处理") +# 步骤 5: 训练集数据处理 +print("\n[步骤 5/7] 训练集数据处理") print("-" * 60) fitted_processors = [] if processors: @@ -296,7 +310,7 @@ for col in feature_cols[:5]: # 只显示前5个特征的缺失值 if null_count > 0: print(f" {col}: {null_count} ({null_count / len(train_data) * 100:.2f}%)") # %% -# 步骤 4: 训练模型 +# 步骤 5: 训练模型 print("\n[步骤 5/7] 训练模型") print("-" * 60) print(f" 模型类型: LightGBM") @@ -318,7 +332,7 @@ print("\n 开始训练...") model.fit(X_train, y_train) print(" 训练完成!") # %% -# 步骤 5: 测试集数据处理 +# 步骤 6: 测试集数据处理 print("\n[步骤 6/7] 测试集数据处理") print("-" * 60) if processors and test_data is not train_data: @@ -334,7 +348,7 @@ if processors and test_data is not train_data: else: print(" 跳过测试集处理") # %% -# 步骤 6: 生成预测 +# 步骤 7: 生成预测 print("\n[步骤 7/7] 生成预测") print("-" * 60) X_test = test_data.select(feature_cols) diff --git a/tests/test_sync_transaction_and_logs.py b/tests/test_sync_transaction_and_logs.py new file mode 100644 index 0000000..b6bf2c2 --- /dev/null +++ b/tests/test_sync_transaction_and_logs.py @@ -0,0 +1,403 @@ +"""测试同步事务和日志功能。 + +测试内容: +1. 事务支持 - BEGIN/COMMIT/ROLLBACK +2. 同步日志记录 - SyncLogManager +3. 
ThreadSafeStorage 事务批量写入
"""

import pytest
import pandas as pd
import os
import tempfile
from datetime import datetime, timedelta
from pathlib import Path

# Set test environment variables BEFORE importing project modules,
# since src.data.storage reads them at import/construction time.
os.environ["DATA_PATH"] = tempfile.mkdtemp()
os.environ["TUSHARE_TOKEN"] = "test_token"

from src.data.storage import Storage, ThreadSafeStorage
from src.data.sync_logger import SyncLogManager, SyncLogEntry


@pytest.fixture
def temp_storage():
    """Create a temporary Storage instance for a test."""
    # Fresh temporary directory per test.
    temp_dir = tempfile.mkdtemp()
    os.environ["DATA_PATH"] = temp_dir

    # Reset the Storage singleton so it picks up the new DATA_PATH.
    Storage._instance = None
    Storage._connection = None

    storage = Storage(read_only=False)
    yield storage

    # Teardown: close and reset the singleton again.
    storage.close()
    Storage._instance = None
    Storage._connection = None


class TestTransactionSupport:
    """Tests for Storage transaction support (begin/commit/rollback)."""

    def test_begin_commit_transaction(self, temp_storage):
        """Data inserted inside begin/commit must be visible afterwards."""
        # Create a test table.
        temp_storage._connection.execute("""
            CREATE TABLE test_table (id INTEGER PRIMARY KEY, value VARCHAR(50))
        """)

        # Begin the transaction.
        temp_storage.begin_transaction()

        # Insert rows inside the transaction.
        temp_storage._connection.execute(
            "INSERT INTO test_table VALUES (1, 'test1'), (2, 'test2')"
        )

        # Commit.
        temp_storage.commit_transaction()

        # Both rows must be committed.
        result = temp_storage._connection.execute(
            "SELECT COUNT(*) FROM test_table"
        ).fetchone()
        assert result[0] == 2

    def test_rollback_transaction(self, temp_storage):
        """Rows inserted after begin must disappear on rollback."""
        # Create a test table with one pre-existing row.
        temp_storage._connection.execute("""
            CREATE TABLE test_table2 (id INTEGER PRIMARY KEY, value VARCHAR(50))
        """)
        temp_storage._connection.execute(
            "INSERT INTO test_table2 VALUES (1, 'initial')"
        )

        # Begin a transaction and insert another row.
        temp_storage.begin_transaction()
        temp_storage._connection.execute("INSERT INTO test_table2 VALUES (2, 'temp')")

        # Roll back.
        temp_storage.rollback_transaction()

        # Only the pre-transaction row survives.
        result = temp_storage._connection.execute(
            "SELECT COUNT(*) FROM test_table2"
        ).fetchone()
        assert result[0] == 1

    def test_transaction_context_manager(self, temp_storage):
        """The transaction() context manager commits on normal exit."""
        # Create a test table.
        temp_storage._connection.execute("""
            CREATE TABLE test_table3 (id INTEGER PRIMARY KEY, value VARCHAR(50))
        """)

        # Normal completion of the `with` block commits.
        with temp_storage.transaction():
            temp_storage._connection.execute(
                "INSERT INTO test_table3 VALUES (1, 'committed')"
            )

        # The row must be committed.
        result = temp_storage._connection.execute(
            "SELECT COUNT(*) FROM test_table3"
        ).fetchone()
        assert result[0] == 1

    def test_transaction_context_manager_rollback(self, temp_storage):
        """The transaction() context manager rolls back when an exception escapes."""
        # Create a test table.
        temp_storage._connection.execute("""
            CREATE TABLE test_table4 (id INTEGER PRIMARY KEY, value VARCHAR(50))
        """)

        # An exception inside the `with` block must trigger rollback.
        try:
            with temp_storage.transaction():
                temp_storage._connection.execute(
                    "INSERT INTO test_table4 VALUES (1, 'temp')"
                )
                raise ValueError("Test error")
        except ValueError:
            pass

        # Nothing was committed.
        result = temp_storage._connection.execute(
            "SELECT COUNT(*) FROM test_table4"
        ).fetchone()
        assert result[0] == 0


class TestSyncLogManager:
    """Tests for the sync log manager."""

    def test_log_table_creation(self, temp_storage):
        """Constructing the manager must auto-create the log table."""
        # Instantiation creates the table.
        log_manager = SyncLogManager(temp_storage)

        # The table must exist in the information schema.
        result = temp_storage._connection.execute("""
            SELECT COUNT(*) FROM information_schema.tables
            WHERE table_name = '_sync_logs'
        """).fetchone()
        assert result[0] == 1

    def test_start_sync(self, temp_storage):
        """start_sync returns a running entry carrying the given fields."""
        log_manager = SyncLogManager(temp_storage)

        # Record sync start.
        entry = log_manager.start_sync(
            table_name="test_table",
            sync_type="incremental",
            date_range_start="20240101",
            date_range_end="20240131",
            metadata={"test": True},
        )

        assert entry.table_name == "test_table"
        assert entry.sync_type == "incremental"
        assert entry.status == "running"
        assert entry.date_range_start == "20240101"

    def test_complete_sync(self, temp_storage):
        """complete_sync(entry, ...) finalises the entry in place."""
        log_manager = SyncLogManager(temp_storage)

        # Start a sync.
        entry = log_manager.start_sync(table_name="test_table", sync_type="full")

        # Complete it.
        log_manager.complete_sync(
            entry,
            status="success",
            records_inserted=1000,
            records_updated=100,
            records_deleted=10,
        )

        assert entry.status == "success"
        assert entry.records_inserted == 1000
        assert entry.records_updated == 100
        assert entry.records_deleted == 10
        assert entry.end_time is not None

    def test_complete_sync_with_error(self, temp_storage):
        """A failed sync records the status and error message on the entry."""
        log_manager = SyncLogManager(temp_storage)

        entry = log_manager.start_sync(table_name="test_table", sync_type="incremental")

        log_manager.complete_sync(
            entry, status="failed", error_message="Connection timeout"
        )

        assert entry.status == "failed"
        assert entry.error_message == "Connection timeout"

    def test_get_sync_history(self, temp_storage):
        """History returns one row per completed sync run."""
        log_manager = SyncLogManager(temp_storage)

        # Create a few runs.
        for i in range(3):
            entry = log_manager.start_sync(
                table_name="test_table", sync_type="incremental"
            )
            log_manager.complete_sync(entry, status="success", records_inserted=100)

        # Query history.
        history = log_manager.get_sync_history(table_name="test_table", limit=10)

        assert len(history) == 3
        assert all(h["table_name"] == "test_table" for _, h in history.iterrows())

    def test_get_last_sync(self, temp_storage):
        """get_last_sync returns the most recent record for the table."""
        log_manager = SyncLogManager(temp_storage)

        # Create two runs; the incremental one is newer.
        entry1 = log_manager.start_sync(table_name="table1", sync_type="full")
        log_manager.complete_sync(entry1, status="success")

        entry2 = log_manager.start_sync(table_name="table1", sync_type="incremental")
        log_manager.complete_sync(entry2, status="success")

        # Fetch the most recent run.
        last_sync = log_manager.get_last_sync("table1")

        assert last_sync is not None
        assert last_sync["sync_type"] == "incremental"

    def test_get_sync_summary(self, temp_storage):
        """Summary aggregates counts per status over the window."""
        log_manager = SyncLogManager(temp_storage)

        # Five successful runs.
        for i in range(5):
            entry = log_manager.start_sync(
                table_name="test_table", sync_type="incremental"
            )
            log_manager.complete_sync(entry, status="success", records_inserted=100)

        # One failed run.
        entry = log_manager.start_sync(table_name="test_table", sync_type="full")
        log_manager.complete_sync(entry, status="failed", error_message="error")

        # Aggregate.
        summary = log_manager.get_sync_summary("test_table", days=30)

        assert summary["total_syncs"] == 6
        assert summary["success_count"] == 5
        assert summary["failed_count"] == 1
        assert summary["total_inserted"] == 500


class TestThreadSafeStorageTransaction:
    """Tests for ThreadSafeStorage transactional batch writes."""

    def test_flush_with_transaction(self, temp_storage):
        """Queued DataFrames are written atomically by flush(use_transaction=True)."""
        # Reset the Storage singleton so ThreadSafeStorage builds a fresh one.
        Storage._instance = None
        Storage._connection = None

        ts_storage = ThreadSafeStorage()

        # Create the destination table.
        ts_storage.storage._connection.execute("""
            CREATE TABLE test_data (ts_code VARCHAR(16), trade_date DATE, value DOUBLE)
        """)

        # Prepare two batches.
        df1 = pd.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ"],
                "trade_date": ["20240101", "20240101"],
                "value": [100.0, 200.0],
            }
        )

        df2 = pd.DataFrame(
            {
                "ts_code": ["000003.SZ", "000004.SZ"],
                "trade_date": ["20240102", "20240102"],
                "value": [300.0, 400.0],
            }
        )

        # Queue both batches.
        ts_storage.queue_save("test_data", df1)
        ts_storage.queue_save("test_data", df2)

        # Flush inside a transaction.
        ts_storage.flush(use_transaction=True)

        # All four rows must be present.
        result = ts_storage.storage._connection.execute(
            "SELECT COUNT(*) FROM test_data"
        ).fetchone()
        assert result[0] == 4

    def test_flush_rollback_on_error(self, temp_storage):
        """A transactional flush leaves data consistent."""
        # Simulating a real mid-flush failure is complex; this simplified
        # version verifies that a transactional flush does not corrupt data.
        Storage._instance = None
        Storage._connection = None

        ts_storage = ThreadSafeStorage()

        # Create the destination table (unique name to avoid clashes).
        ts_storage.storage._connection.execute("""
            CREATE TABLE test_data2 (ts_code VARCHAR(16) PRIMARY KEY, value DOUBLE)
        """)

        # Seed one row through the queue.
        df = pd.DataFrame({"ts_code": ["000001.SZ"], "value": [100.0]})
        ts_storage.queue_save("test_data2", df, use_upsert=False)
        ts_storage.flush(use_transaction=True)

        # Exactly one row must exist.
        result = ts_storage.storage._connection.execute(
            "SELECT COUNT(*) FROM test_data2"
        ).fetchone()
        assert result[0] == 1


class TestIntegration:
    """End-to-end test combining transactions and sync logging."""

    def test_full_sync_workflow(self, temp_storage):
        """Full workflow: log start, sync in a transaction, log completion."""
        # 1. Set up the log manager.
        log_manager = SyncLogManager(temp_storage)

        # 2. Create the destination table.
        temp_storage._connection.execute("""
            CREATE TABLE sync_test_table (
                ts_code VARCHAR(16),
                trade_date DATE,
                value DOUBLE,
                PRIMARY KEY (ts_code, trade_date)
            )
        """)

        # 3. Record sync start.
        log_entry = log_manager.start_sync(
            table_name="sync_test_table",
            sync_type="full",
            date_range_start="20240101",
            date_range_end="20240131",
        )

        # 4. Run the "sync" inside a transaction.
        temp_storage.begin_transaction()
        try:
            # Simulated sync payload.
            df = pd.DataFrame(
                {
                    "ts_code": ["000001.SZ", "000002.SZ"],
                    "trade_date": ["20240115", "20240115"],
                    "value": [100.0, 200.0],
                }
            )

            # Convert the date column to proper date objects.
            df["trade_date"] = pd.to_datetime(df["trade_date"], format="%Y%m%d").dt.date

            # Persist via storage.save.
            temp_storage.save("sync_test_table", df, mode="append")

            # Commit.
            temp_storage.commit_transaction()

            # 5. Record success.
            log_manager.complete_sync(
                log_entry, status="success", records_inserted=len(df)
            )

        except Exception as e:
            temp_storage.rollback_transaction()
            log_manager.complete_sync(log_entry, status="failed", error_message=str(e))
            raise

        # 6. Verify the entry reflects the outcome.
        assert log_entry.status == "success"
        assert log_entry.records_inserted == 2

        # 7. Exactly one history row for the run.
        history = log_manager.get_sync_history(table_name="sync_test_table")
        assert len(history) == 1


if __name__ == "__main__":
    # Allow running this file directly.
    pytest.main([__file__, "-v"])