feat(data): 为数据同步添加事务支持和同步日志

- Storage/ThreadSafeStorage 添加事务支持(begin/commit/rollback)
- 新增 SyncLogManager 记录所有同步任务的执行状态
- 集成事务到 StockBasedSync、DateBasedSync、QuarterBasedSync
- 在 sync_all 和 sync_financial 调度中心添加日志记录
- 新增测试验证事务和日志功能
This commit is contained in:
2026-03-23 21:10:15 +08:00
parent 31b25074c3
commit bace4cc5f4
10 changed files with 1468 additions and 177 deletions

View File

@@ -76,7 +76,11 @@ def get_stock_basic(
return data
def sync_all_stocks() -> pd.DataFrame:
def sync_all_stocks(
force_full: bool = False,
dry_run: bool = False,
**kwargs,
) -> pd.DataFrame:
"""Fetch and save all stocks (listed and delisted) to local storage.
This is a special interface that should only be called once to initialize

View File

@@ -82,6 +82,9 @@ def sync_trade_cal_cache(
start_date: str = "20180101",
end_date: Optional[str] = None,
force: bool = False,
force_full: bool = False,
dry_run: bool = False,
**kwargs,
) -> pd.DataFrame:
"""Sync trade calendar data to local cache with incremental updates.

View File

@@ -25,6 +25,7 @@ from tqdm import tqdm
from src.data.client import TushareClient
from src.data.storage import ThreadSafeStorage, Storage
from src.data.sync_logger import SyncLogManager
from src.data.utils import get_today_date, get_quarters_in_range, DEFAULT_START_DATE
@@ -466,18 +467,28 @@ class QuarterBasedSync(ABC):
inserted_count = 0
if is_first_sync_for_period:
# 首次同步该季度:直接插入所有数据
# 首次同步该季度:直接插入所有数据(使用事务)
print(
f"[{self.__class__.__name__}] First sync for quarter {period}, inserting all data directly"
)
if not dry_run:
self.storage.queue_save(self.table_name, remote_df, use_upsert=False)
self.storage.flush()
inserted_count = len(remote_df)
print(
f"[{self.__class__.__name__}] Inserted {inserted_count} new records"
)
try:
# 开始事务
self.storage.begin_transaction()
self.storage.queue_save(
self.table_name, remote_df, use_upsert=False
)
self.storage.flush(use_transaction=False)
self.storage.commit_transaction()
inserted_count = len(remote_df)
print(
f"[{self.__class__.__name__}] Inserted {inserted_count} new records (transaction committed)"
)
except Exception as e:
self.storage.rollback_transaction()
print(f"[{self.__class__.__name__}] Transaction rolled back: {e}")
raise
return {
"period": period,
@@ -501,19 +512,33 @@ class QuarterBasedSync(ABC):
print(f" - Unchanged stocks: {unchanged_count}")
if not dry_run and not diff_df.empty:
# 5.1 删除差异股票的旧数据
deleted_stocks_count = len(diff_stocks)
self.delete_stock_quarter_data(period, diff_stocks)
deleted_count = len(diff_df)
print(
f"[{self.__class__.__name__}] Deleted {deleted_stocks_count} stocks' old records (approx {deleted_count} rows)"
)
try:
# 开始事务
self.storage.begin_transaction()
# 5.2 插入新数据(使用普通 INSERT因为已删除旧数据
self.storage.queue_save(self.table_name, diff_df, use_upsert=False)
self.storage.flush()
inserted_count = len(diff_df)
print(f"[{self.__class__.__name__}] Inserted {inserted_count} new records")
# 5.1 删除差异股票的旧数据
deleted_stocks_count = len(diff_stocks)
self.delete_stock_quarter_data(period, diff_stocks)
deleted_count = len(diff_df)
print(
f"[{self.__class__.__name__}] Deleted {deleted_stocks_count} stocks' old records (approx {deleted_count} rows)"
)
# 5.2 插入新数据(使用普通 INSERT因为已删除旧数据
self.storage.queue_save(self.table_name, diff_df, use_upsert=False)
self.storage.flush(use_transaction=False)
inserted_count = len(diff_df)
# 提交事务
self.storage.commit_transaction()
print(
f"[{self.__class__.__name__}] Inserted {inserted_count} new records (transaction committed)"
)
except Exception as e:
self.storage.rollback_transaction()
print(f"[{self.__class__.__name__}] Transaction rolled back: {e}")
raise
return {
"period": period,
@@ -583,55 +608,86 @@ class QuarterBasedSync(ABC):
print(f"[{self.__class__.__name__}] Incremental Sync")
print(f"{'=' * 60}")
# 0. 确保表存在(首次同步时自动建表)
self.ensure_table_exists()
# 初始化日志管理器
log_manager = SyncLogManager()
log_entry = log_manager.start_sync(
table_name=self.table_name,
sync_type="incremental",
metadata={"dry_run": dry_run},
)
# 1. 获取最新季度
storage = Storage()
try:
result = storage._connection.execute(
f'SELECT MAX(end_date) FROM "{self.table_name}"'
).fetchone()
latest_quarter = result[0] if result and result[0] else None
if hasattr(latest_quarter, "strftime"):
latest_quarter = latest_quarter.strftime("%Y%m%d")
# 0. 确保表存在(首次同步时自动建表)
self.ensure_table_exists()
# 1. 获取最新季度
storage = Storage()
try:
result = storage._connection.execute(
f'SELECT MAX(end_date) FROM "{self.table_name}"'
).fetchone()
latest_quarter = result[0] if result and result[0] else None
if hasattr(latest_quarter, "strftime"):
latest_quarter = latest_quarter.strftime("%Y%m%d")
except Exception as e:
print(f"[{self.__class__.__name__}] Error getting latest quarter: {e}")
latest_quarter = None
# 2. 获取当前季度
current_quarter = self.get_current_quarter()
if latest_quarter is None:
# 无本地数据,执行全量同步
print(
f"[{self.__class__.__name__}] No local data, performing full sync"
)
results = self.sync_range(
self.DEFAULT_START_DATE, current_quarter, dry_run
)
else:
print(
f"[{self.__class__.__name__}] Latest local quarter: {latest_quarter}"
)
print(f"[{self.__class__.__name__}] Current quarter: {current_quarter}")
# 3. 确定同步范围
start_quarter = latest_quarter
if start_quarter > current_quarter:
start_quarter = current_quarter
else:
start_quarter = self.get_prev_quarter(latest_quarter)
if start_quarter < self.DEFAULT_START_DATE:
start_quarter = self.DEFAULT_START_DATE
# 打印同步的两个季度信息
print(f"\n[{self.__class__.__name__}] 将同步以下两个季度的财报:")
print(f" - 前一季度: {start_quarter}")
print(f" - 当前季度: {current_quarter}")
print(f" (包含前一季度以确保数据完整性)")
print()
results = self.sync_range(start_quarter, current_quarter, dry_run)
# 计算总插入记录数
total_inserted = sum(
r.get("inserted_count", 0) for r in results if isinstance(r, dict)
)
# 完成日志记录
log_manager.complete_sync(
log_entry,
status="success",
records_inserted=total_inserted,
records_updated=0,
records_deleted=0,
)
return results
except Exception as e:
print(f"[{self.__class__.__name__}] Error getting latest quarter: {e}")
latest_quarter = None
# 2. 获取当前季度
current_quarter = self.get_current_quarter()
if latest_quarter is None:
# 无本地数据,执行全量同步
print(f"[{self.__class__.__name__}] No local data, performing full sync")
return self.sync_range(self.DEFAULT_START_DATE, current_quarter, dry_run)
print(f"[{self.__class__.__name__}] Latest local quarter: {latest_quarter}")
print(f"[{self.__class__.__name__}] Current quarter: {current_quarter}")
# 3. 确定同步范围
# 财务数据必须每次都进行对比更新,不存在"跳过"的情况
# 同步范围:从最新季度到当前季度(包含前一季度以确保数据完整性)
start_quarter = latest_quarter
if start_quarter > current_quarter:
# 如果本地数据比当前季度还新,仍然需要同步(可能包含修正数据)
start_quarter = current_quarter
else:
# 正常情况:包含前一季度
start_quarter = self.get_prev_quarter(latest_quarter)
if start_quarter < self.DEFAULT_START_DATE:
start_quarter = self.DEFAULT_START_DATE
# 打印同步的两个季度信息
print(f"\n[{self.__class__.__name__}] 将同步以下两个季度的财报:")
print(f" - 前一季度: {start_quarter}")
print(f" - 当前季度: {current_quarter}")
print(f" (包含前一季度以确保数据完整性)")
print()
return self.sync_range(start_quarter, current_quarter, dry_run)
log_manager.complete_sync(log_entry, status="failed", error_message=str(e))
raise
def sync_full(self, dry_run: bool = False) -> List[Dict]:
"""执行全量同步。
@@ -646,12 +702,38 @@ class QuarterBasedSync(ABC):
print(f"[{self.__class__.__name__}] Full Sync")
print(f"{'=' * 60}")
# 确保表存在
self.ensure_table_exists()
# 初始化日志管理器
log_manager = SyncLogManager()
log_entry = log_manager.start_sync(
table_name=self.table_name, sync_type="full", metadata={"dry_run": dry_run}
)
current_quarter = self.get_current_quarter()
try:
# 确保表存在
self.ensure_table_exists()
return self.sync_range(self.DEFAULT_START_DATE, current_quarter, dry_run)
current_quarter = self.get_current_quarter()
results = self.sync_range(self.DEFAULT_START_DATE, current_quarter, dry_run)
# 计算总插入记录数
total_inserted = sum(
r.get("inserted_count", 0) for r in results if isinstance(r, dict)
)
# 完成日志记录
log_manager.complete_sync(
log_entry,
status="success",
records_inserted=total_inserted,
records_updated=0,
records_deleted=0,
)
return results
except Exception as e:
log_manager.complete_sync(log_entry, status="failed", error_message=str(e))
raise
# ======================================================================
# 预览模式

View File

@@ -34,6 +34,7 @@ from tqdm import tqdm
from src.data.client import TushareClient
from src.data.storage import ThreadSafeStorage, Storage
from src.data.sync_logger import SyncLogManager
from src.data.utils import get_today_date, get_next_date
from src.config.settings import get_settings
from src.data.api_wrappers.api_trade_cal import (
@@ -614,14 +615,30 @@ class StockBasedSync(BaseDataSync):
finally:
pbar.close()
# 批量写入所有数据(仅在无错误时)
# 批量写入所有数据(仅在无错误时),使用事务确保原子性
if results and not error_occurred:
for ts_code, data in results.items():
if not data.empty:
self.storage.queue_save(self.table_name, data)
self.storage.flush()
total_rows = sum(len(df) for df in results.values())
print(f"\n[{class_name}] Saved {total_rows} rows to storage")
try:
# 开始事务
self.storage.begin_transaction()
for ts_code, data in results.items():
if not data.empty:
self.storage.queue_save(self.table_name, data)
# flush 在事务中执行
self.storage.flush(use_transaction=False)
# 提交事务
self.storage.commit_transaction()
total_rows = sum(len(df) for df in results.values())
print(
f"\n[{class_name}] Saved {total_rows} rows to storage (transaction committed)"
)
except Exception as e:
# 回滚事务
self.storage.rollback_transaction()
print(f"\n[{class_name}] Transaction rolled back due to error: {e}")
error_occurred = True
exception_to_raise = e
# 打印摘要
print(f"\n{'=' * 60}")
@@ -824,6 +841,17 @@ class StockBasedSync(BaseDataSync):
print(f"[{class_name}] Starting {self.table_name} data sync...")
print(f"{'=' * 60}")
# 初始化日志管理器并记录开始
log_manager = SyncLogManager()
sync_mode = "full" if force_full else "incremental"
log_entry = log_manager.start_sync(
table_name=self.table_name,
sync_type=sync_mode,
date_range_start=start_date,
date_range_end=end_date,
metadata={"dry_run": dry_run, "max_workers": max_workers},
)
# 首先确保交易日历缓存是最新的
print(f"[{class_name}] Syncing trade calendar cache...")
sync_trade_cal_cache()
@@ -911,13 +939,33 @@ class StockBasedSync(BaseDataSync):
print(f"[{class_name}] Using {max_workers or self.max_workers} worker threads")
# 执行并发同步
return self._run_concurrent_sync(
stock_codes=stock_codes,
start_date=sync_start_date,
end_date=end_date,
max_workers=max_workers,
dry_run=dry_run,
)
try:
results = self._run_concurrent_sync(
stock_codes=stock_codes,
start_date=sync_start_date,
end_date=end_date,
max_workers=max_workers,
dry_run=dry_run,
)
# 计算同步结果统计
total_inserted = sum(len(df) for df in results.values()) if results else 0
# 完成日志记录
log_manager.complete_sync(
log_entry,
status="success" if results else "partial",
records_inserted=total_inserted,
records_updated=0,
records_deleted=0,
)
return results
except Exception as e:
# 记录失败日志
log_manager.complete_sync(log_entry, status="failed", error_message=str(e))
raise
class DateBasedSync(BaseDataSync):
@@ -1117,33 +1165,52 @@ class DateBasedSync(BaseDataSync):
return combined
def _save_data(self, data: pd.DataFrame, sync_start_date: str) -> None:
"""保存数据到存储。
def _save_data(self, data: pd.DataFrame, sync_start_date: str) -> int:
"""保存数据到存储(使用事务确保原子性)
Args:
data: 要保存的数据
sync_start_date: 同步起始日期(用于删除旧数据)
Returns:
保存的记录数
"""
if data.empty:
return
return 0
storage = Storage()
storage = Storage(read_only=False)
records_count = len(data)
# 删除日期范围内的旧数据
if sync_start_date:
sync_start_date_fmt = pd.to_datetime(
sync_start_date, format="%Y%m%d"
).date()
storage._connection.execute(
f'DELETE FROM "{self.table_name}" WHERE "trade_date" >= ?',
[sync_start_date_fmt],
try:
# 开始事务
storage.begin_transaction()
# 删除日期范围内的旧数据
if sync_start_date:
sync_start_date_fmt = pd.to_datetime(
sync_start_date, format="%Y%m%d"
).date()
storage._connection.execute(
f'DELETE FROM "{self.table_name}" WHERE "trade_date" >= ?',
[sync_start_date_fmt],
)
# 保存新数据(在事务中直接保存,不使用队列)
storage.save(self.table_name, data, mode="append")
# 提交事务
storage.commit_transaction()
print(
f"[{self.__class__.__name__}] Saved {records_count} records to DuckDB (transaction committed)"
)
return records_count
# 保存新数据
self.storage.queue_save(self.table_name, data)
self.storage.flush()
print(f"[{self.__class__.__name__}] Saved {len(data)} records to DuckDB")
except Exception as e:
storage.rollback_transaction()
print(
f"[{self.__class__.__name__}] Transaction rolled back due to error: {e}"
)
raise
def preview_sync(
self,
@@ -1280,12 +1347,25 @@ class DateBasedSync(BaseDataSync):
print(f"[{class_name}] Starting {self.table_name} data sync...")
print(f"{'=' * 60}")
# 初始化日志管理器并记录开始
log_manager = SyncLogManager()
sync_mode = "full" if force_full else "incremental"
log_entry = log_manager.start_sync(
table_name=self.table_name,
sync_type=sync_mode,
date_range_start=start_date,
date_range_end=end_date,
metadata={"dry_run": dry_run},
)
# 确定日期范围
sync_start, sync_end, mode = self._get_sync_date_range(
start_date, end_date, force_full
)
if mode == "up_to_date":
# 记录跳过日志
log_manager.complete_sync(log_entry, status="success", records_inserted=0)
return pd.DataFrame()
if dry_run:
@@ -1295,49 +1375,74 @@ class DateBasedSync(BaseDataSync):
print(f" Would sync from {sync_start} to {sync_end}")
print(f" Mode: {mode}")
print(f"{'=' * 60}")
# 记录 dry run 日志
log_manager.complete_sync(log_entry, status="success", records_inserted=0)
return pd.DataFrame()
# 检查表是否存在,不存在则创建
storage = Storage()
if not storage.exists(self.table_name):
print(
f"[{class_name}] Table '{self.table_name}' doesn't exist, creating..."
try:
# 检查表是否存在,不存在则创建
storage = Storage()
if not storage.exists(self.table_name):
print(
f"[{class_name}] Table '{self.table_name}' doesn't exist, creating..."
)
# 获取样本数据以推断 schema
sample = self.fetch_single_date(sync_end)
if sample.empty:
# 尝试另一个日期
sample = self.fetch_single_date("20240102")
if not sample.empty:
self._ensure_table_schema(sample)
else:
print(
f"[{class_name}] Cannot create table: no sample data available"
)
log_manager.complete_sync(
log_entry,
status="failed",
error_message="No sample data available",
)
return pd.DataFrame()
# 首次同步探测:验证表结构是否正常
if self._should_probe_table():
print(f"[{class_name}] Table '{self.table_name}' is empty, probing...")
# 使用最近一个交易日的完整数据进行探测
probe_date = get_last_trading_day(sync_start, sync_end)
if probe_date:
probe_data = self.fetch_single_date(probe_date)
probe_desc = f"date={probe_date}, all stocks"
self._probe_table_and_cleanup(probe_data, probe_desc)
# 执行同步
combined = self._run_date_range_sync(sync_start, sync_end, dry_run)
# 保存数据(在事务中)
records_saved = 0
if not combined.empty:
records_saved = self._save_data(combined, sync_start)
# 打印摘要
print(f"\n{'=' * 60}")
print(f"[{class_name}] Sync Summary")
print(f"{'=' * 60}")
print(f" Mode: {mode}")
print(f" Date range: {sync_start} to {sync_end}")
print(f" Records: {len(combined)}")
print(f"{'=' * 60}")
# 完成日志记录
log_manager.complete_sync(
log_entry,
status="success",
records_inserted=records_saved,
records_updated=0,
records_deleted=0,
)
# 获取样本数据以推断 schema
sample = self.fetch_single_date(sync_end)
if sample.empty:
# 尝试另一个日期
sample = self.fetch_single_date("20240102")
if not sample.empty:
self._ensure_table_schema(sample)
else:
print(f"[{class_name}] Cannot create table: no sample data available")
return pd.DataFrame()
# 首次同步探测:验证表结构是否正常
if self._should_probe_table():
print(f"[{class_name}] Table '{self.table_name}' is empty, probing...")
# 使用最近一个交易日的完整数据进行探测
probe_date = get_last_trading_day(sync_start, sync_end)
if probe_date:
probe_data = self.fetch_single_date(probe_date)
probe_desc = f"date={probe_date}, all stocks"
self._probe_table_and_cleanup(probe_data, probe_desc)
return combined
# 执行同步
combined = self._run_date_range_sync(sync_start, sync_end, dry_run)
# 保存数据
if not combined.empty:
self._save_data(combined, sync_start)
# 打印摘要
print(f"\n{'=' * 60}")
print(f"[{class_name}] Sync Summary")
print(f"{'=' * 60}")
print(f" Mode: {mode}")
print(f" Date range: {sync_start} to {sync_end}")
print(f" Records: {len(combined)}")
print(f"{'=' * 60}")
return combined
except Exception as e:
# 记录失败日志
log_manager.complete_sync(log_entry, status="failed", error_message=str(e))
raise

View File

@@ -36,6 +36,7 @@
from typing import List, Optional
from src.data.sync_logger import SyncLogManager
from src.data.api_wrappers.financial_data.api_income import (
IncomeQuarterSync,
sync_income,
@@ -120,6 +121,14 @@ def sync_financial(
results = {}
# 初始化日志管理器并记录调度中心开始
log_manager = SyncLogManager()
log_entry = log_manager.start_sync(
table_name="financial_data_batch",
sync_type="full" if force_full else "incremental",
metadata={"data_types": data_types, "dry_run": dry_run},
)
print("\n" + "=" * 60)
print("[Financial Sync] 财务数据同步调度中心")
print("=" * 60)
@@ -128,10 +137,14 @@ def sync_financial(
print(f"写入模式: {'预览' if dry_run else '实际写入'}")
print("=" * 60)
total_inserted = 0
failed_types = []
for data_type in data_types:
if data_type not in FINANCIAL_SYNCERS:
print(f"[WARN] 未知的数据类型: {data_type}")
results[data_type] = {"error": f"Unknown data type: {data_type}"}
failed_types.append(data_type)
continue
config = FINANCIAL_SYNCERS[data_type]
@@ -143,10 +156,19 @@ def sync_financial(
try:
result = sync_func(force_full=force_full, dry_run=dry_run)
results[data_type] = result
# 统计插入的记录数
if isinstance(result, list):
inserted = sum(
r.get("inserted_count", 0) for r in result if isinstance(r, dict)
)
total_inserted += inserted
print(f"[{display_name}] 同步完成")
except Exception as e:
print(f"[ERROR] [{display_name}] 同步失败: {e}")
results[data_type] = {"error": str(e)}
failed_types.append(data_type)
# 打印汇总
print("\n" + "=" * 60)
@@ -169,6 +191,16 @@ def sync_financial(
print(f" {display_name}: {status}")
print("=" * 60)
# 完成调度中心日志记录
status = "failed" if failed_types else "success"
error_msg = f"Failed types: {failed_types}" if failed_types else None
log_manager.complete_sync(
log_entry,
status=status,
records_inserted=total_inserted,
error_message=error_msg,
)
return results

View File

@@ -313,6 +313,87 @@ class Storage:
Storage._connection = None
Storage._instance = None
# ======================================================================
# 事务支持
# ======================================================================
def begin_transaction(self) -> None:
    """Open an explicit transaction on the underlying connection.

    Usage:
        storage = Storage(read_only=False)
        storage.begin_transaction()
        try:
            storage.save("table1", data1)
            storage.save("table2", data2)
            storage.commit_transaction()
        except Exception:
            storage.rollback_transaction()
            raise

    Raises:
        RuntimeError: if this storage was opened read-only.
    """
    if self._read_only:
        raise RuntimeError("Cannot begin transaction in read-only mode")
    try:
        self._connection.execute("BEGIN TRANSACTION")
    except Exception as e:
        print(f"[Storage] Error starting transaction: {e}")
        raise
    print("[Storage] Transaction started")
def commit_transaction(self) -> None:
    """Commit the current transaction (silently a no-op in read-only mode)."""
    if self._read_only:
        return
    try:
        self._connection.execute("COMMIT")
    except Exception as e:
        print(f"[Storage] Error committing transaction: {e}")
        raise
    print("[Storage] Transaction committed")
def rollback_transaction(self) -> None:
    """Roll back the current transaction (no-op in read-only mode).

    A rollback failure is reported but deliberately NOT re-raised, so
    calling this inside an ``except`` handler cannot mask the original
    error that triggered the rollback.
    """
    if self._read_only:
        return
    try:
        self._connection.execute("ROLLBACK")
        print("[Storage] Transaction rolled back")
    except Exception as e:
        print(f"[Storage] Error rolling back transaction: {e}")
def transaction(self):
    """Return a context manager wrapping begin/commit/rollback.

    Usage:
        storage = Storage(read_only=False)
        with storage.transaction():
            storage.save("table1", data1)
            storage.save("table2", data2)
        # commits automatically on success, rolls back on exception
    """
    return _TransactionContext(self)
class _TransactionContext:
    """Context manager that commits on a clean exit and rolls back on error."""

    def __init__(self, storage: "Storage"):
        self.storage = storage

    def __enter__(self):
        self.storage.begin_transaction()
        return self.storage

    def __exit__(self, exc_type, exc_val, exc_tb):
        # An exception anywhere in the with-body triggers a rollback;
        # returning False ensures the exception still propagates.
        if exc_type is not None:
            self.storage.rollback_transaction()
        else:
            self.storage.commit_transaction()
        return False
class ThreadSafeStorage:
"""线程安全的 DuckDB 写入包装器。
@@ -323,12 +404,14 @@ class ThreadSafeStorage:
注意:
- 此类自动使用 read_only=False 模式,用于数据同步
- 不要在多进程中同时使用此类,只应在单进程中用于批量写入
- 支持事务模式:在事务中批量写入,确保原子性
"""
def __init__(self):
    """Create a writable Storage and an empty write queue."""
    # Storage is opened with read_only=False because this wrapper exists
    # solely for write (sync) workloads.
    self.storage = Storage(read_only=False)
    # Queued writes flushed in one batch: [(name, data, use_upsert), ...]
    self._pending_writes: List[tuple] = []
    # True while an explicit transaction is open (see begin_transaction).
    self._in_transaction: bool = False
def queue_save(self, name: str, data: pd.DataFrame, use_upsert: bool = True):
"""将数据放入写入队列(不立即写入)
@@ -341,10 +424,13 @@ class ThreadSafeStorage:
if not data.empty:
self._pending_writes.append((name, data, use_upsert))
def flush(self):
def flush(self, use_transaction: bool = True):
"""批量写入所有队列数据。
调用时机:在 sync 结束时统一调用,避免并发写入冲突。
Args:
use_transaction: 若为 True使用事务包装批量写入
"""
if not self._pending_writes:
return
@@ -357,17 +443,72 @@ class ThreadSafeStorage:
table_data[key].append(data)
# 批量写入每个表
for (name, use_upsert), data_list in table_data.items():
combined = pd.concat(data_list, ignore_index=True)
# 在批量数据中先去重
if "ts_code" in combined.columns and "trade_date" in combined.columns:
combined = combined.drop_duplicates(
subset=["ts_code", "trade_date"], keep="last"
if use_transaction:
# 使用事务确保原子性
try:
self.storage.begin_transaction()
total_rows = 0
for (name, use_upsert), data_list in table_data.items():
combined = pd.concat(data_list, ignore_index=True)
# 在批量数据中先去重
if (
"ts_code" in combined.columns
and "trade_date" in combined.columns
):
combined = combined.drop_duplicates(
subset=["ts_code", "trade_date"], keep="last"
)
result = self.storage.save(
name, combined, mode="append", use_upsert=use_upsert
)
if result.get("status") == "success":
total_rows += result.get("rows", 0)
self.storage.commit_transaction()
print(
f"[ThreadSafeStorage] Transaction committed: {total_rows} rows saved"
)
self.storage.save(name, combined, mode="append", use_upsert=use_upsert)
except Exception as e:
self.storage.rollback_transaction()
print(f"[ThreadSafeStorage] Transaction rolled back due to error: {e}")
raise
else:
# 不使用事务,逐表写入
for (name, use_upsert), data_list in table_data.items():
combined = pd.concat(data_list, ignore_index=True)
# 在批量数据中先去重
if "ts_code" in combined.columns and "trade_date" in combined.columns:
combined = combined.drop_duplicates(
subset=["ts_code", "trade_date"], keep="last"
)
self.storage.save(name, combined, mode="append", use_upsert=use_upsert)
self._pending_writes.clear()
def begin_transaction(self) -> None:
    """Enter transaction mode.

    In transaction mode, flushing of queued writes is deferred until
    commit_transaction() or rollback_transaction() is called.
    """
    self._in_transaction = True
    self.storage.begin_transaction()
def commit_transaction(self) -> None:
    """Flush any queued writes, then commit the open transaction."""
    if self._pending_writes:
        # Write queued data first so it is covered by the commit.
        self.flush(use_transaction=False)
    self.storage.commit_transaction()
    self._in_transaction = False
def rollback_transaction(self) -> None:
    """Discard all queued (unflushed) writes and roll back the transaction."""
    self._pending_writes.clear()
    self.storage.rollback_transaction()
    self._in_transaction = False
def __getattr__(self, name):
    """Delegate any other attribute/method access to the wrapped Storage."""
    return getattr(self.storage, name)

View File

@@ -55,6 +55,7 @@ import pandas as pd
from src.data import api_wrappers # noqa: F401
from src.data.sync_registry import sync_registry
from src.data.api_wrappers import sync_all_stocks
from src.data.sync_logger import SyncLogManager
def sync_all_data(
@@ -109,7 +110,7 @@ def sync_all_data(
>>> result = sync_all_data(dry_run=True)
>>>
>>> # 只同步特定任务
>>> result = sync_all_data(selected=["trade_cal", "stock_basic"])
>>> result = sync_all_data(selected=['trade_cal', 'pro_bar'])
>>>
>>> # 查看所有可用任务
>>> from src.data.sync_registry import sync_registry
@@ -117,13 +118,80 @@ def sync_all_data(
>>> for t in tasks:
... print(f"{t.name}: {t.display_name}")
"""
return sync_registry.sync_all(
force_full=force_full,
max_workers=max_workers,
dry_run=dry_run,
selected=selected,
# 记录调度中心开始
log_manager = SyncLogManager()
sync_mode = "full" if force_full else "incremental"
selected_str = ",".join(selected) if selected else "all"
log_manager.start_sync(
table_name="daily_data_batch",
sync_type=sync_mode,
metadata={
"selected": selected_str,
"dry_run": dry_run,
"max_workers": max_workers,
},
)
try:
result = sync_registry.sync_all(
force_full=force_full,
max_workers=max_workers,
dry_run=dry_run,
selected=selected,
)
# 计算成功/失败数量
success_count = 0
failed_count = 0
total_records = 0
for task_name, task_result in result.items():
if isinstance(task_result, dict):
if task_result.get("status") == "error":
failed_count += 1
else:
success_count += 1
# 累加记录数(如果有)
if "rows" in task_result:
total_records += task_result.get("rows", 0)
elif isinstance(task_result, pd.DataFrame):
success_count += 1
total_records += len(task_result)
else:
success_count += 1
# 记录完成日志
status = "partial" if failed_count > 0 else "success"
error_msg = f"Failed: {failed_count} tasks" if failed_count > 0 else None
log_manager.complete_sync(
table_name="daily_data_batch",
sync_type=sync_mode,
status=status,
records_inserted=total_records,
error_message=error_msg,
metadata={
"selected": selected_str,
"dry_run": dry_run,
"max_workers": max_workers,
},
)
return result
except Exception as e:
# 记录失败日志
log_manager.complete_sync(
table_name="daily_data_batch",
sync_type=sync_mode,
status="failed",
error_message=str(e),
metadata={
"selected": selected_str,
"dry_run": dry_run,
"max_workers": max_workers,
},
)
raise
def list_sync_tasks() -> list[dict[str, Any]]:
"""列出所有已注册的同步任务。

439
src/data/sync_logger.py Normal file
View File

@@ -0,0 +1,439 @@
"""同步日志管理模块。
提供同步操作日志记录和查询功能,追踪每次数据同步的详细信息。
设计理念:日志只插入不更新,无需主键 ID。
"""
from datetime import datetime
from typing import Optional, Dict, Any, List
import pandas as pd
from dataclasses import dataclass
from src.data.storage import Storage
@dataclass
class SyncLogEntry:
"""同步日志条目。
Attributes:
table_name: 同步的表名
sync_type: 同步类型full/incremental/quarterly/daily 等)
start_time: 同步开始时间
end_time: 同步结束时间
status: 同步状态success/failed/partial/running
records_before: 同步前记录数
records_after: 同步后记录数
records_inserted: 插入的记录数
records_updated: 更新的记录数
records_deleted: 删除的记录数
date_range_start: 数据日期范围起始
date_range_end: 数据日期范围结束
error_message: 错误信息(如果有)
metadata: 额外元数据JSON 格式)
"""
table_name: str
sync_type: str
start_time: datetime
end_time: Optional[datetime] = None
status: str = "running"
records_before: int = 0
records_after: int = 0
records_inserted: int = 0
records_updated: int = 0
records_deleted: int = 0
date_range_start: Optional[str] = None
date_range_end: Optional[str] = None
error_message: Optional[str] = None
metadata: Optional[str] = None
class SyncLogManager:
"""同步日志管理器。
管理同步日志表的创建、记录插入和查询。
设计原则:日志只插入,不更新,无需主键。
使用方式:
# 记录一次同步操作(开始和结束分别记录)
log_manager = SyncLogManager()
log_manager.start_sync("daily", "incremental") # 记录开始
try:
# 执行同步...
log_manager.complete_sync("daily", "success", records_inserted=1000) # 记录完成
except Exception as e:
log_manager.complete_sync("daily", "failed", error_message=str(e)) # 记录失败
# 查询同步历史
logs = log_manager.get_sync_history("daily", limit=10)
"""
TABLE_NAME = "_sync_logs"
def __init__(self, storage: Optional[Storage] = None):
"""初始化日志管理器。
Args:
storage: Storage 实例,如果为 None 则创建新实例
"""
self.storage = storage or Storage(read_only=False)
self._ensure_table_exists()
def _ensure_table_exists(self) -> None:
"""确保日志表存在。"""
create_sql = f"""
CREATE TABLE IF NOT EXISTS {self.TABLE_NAME} (
table_name VARCHAR(64) NOT NULL,
sync_type VARCHAR(32) NOT NULL,
start_time TIMESTAMP NOT NULL,
end_time TIMESTAMP,
status VARCHAR(16) DEFAULT 'running',
records_before INTEGER DEFAULT 0,
records_after INTEGER DEFAULT 0,
records_inserted INTEGER DEFAULT 0,
records_updated INTEGER DEFAULT 0,
records_deleted INTEGER DEFAULT 0,
date_range_start VARCHAR(8),
date_range_end VARCHAR(8),
error_message TEXT,
metadata TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""
# 创建索引
index_sql = f"""
CREATE INDEX IF NOT EXISTS idx_sync_logs_table_time
ON {self.TABLE_NAME}(table_name, start_time DESC);
CREATE INDEX IF NOT EXISTS idx_sync_logs_status
ON {self.TABLE_NAME}(status);
CREATE INDEX IF NOT EXISTS idx_sync_logs_type_time
ON {self.TABLE_NAME}(sync_type, start_time DESC);
"""
try:
self.storage._connection.execute(create_sql)
self.storage._connection.execute(index_sql)
except Exception as e:
print(f"[SyncLogManager] Error creating log table: {e}")
raise
def start_sync(
self,
table_name: str,
sync_type: str,
date_range_start: Optional[str] = None,
date_range_end: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> None:
"""记录同步开始。
Args:
table_name: 同步的表名
sync_type: 同步类型full/incremental/quarterly 等)
date_range_start: 数据日期范围起始
date_range_end: 数据日期范围结束
metadata: 额外元数据字典
"""
entry = SyncLogEntry(
table_name=table_name,
sync_type=sync_type,
start_time=datetime.now(),
status="running",
records_before=0, # 不统计记录数,避免额外查询
date_range_start=date_range_start,
date_range_end=date_range_end,
metadata=str(metadata) if metadata else None,
)
try:
self.storage._connection.execute(
f"""
INSERT INTO {self.TABLE_NAME} (
table_name, sync_type, start_time, status,
records_before, date_range_start, date_range_end, metadata
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(
entry.table_name,
entry.sync_type,
entry.start_time,
entry.status,
entry.records_before,
entry.date_range_start,
entry.date_range_end,
entry.metadata,
),
)
except Exception as e:
print(f"[SyncLogManager] Error logging sync start: {e}")
def complete_sync(
self,
table_name: str,
status: str = "success",
records_inserted: int = 0,
records_updated: int = 0,
records_deleted: int = 0,
error_message: Optional[str] = None,
sync_type: str = "incremental",
date_range_start: Optional[str] = None,
date_range_end: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> None:
"""记录同步完成。
作为一条新的日志记录插入,不更新之前的记录。
Args:
table_name: 同步的表名
status: 同步状态success/failed/partial
records_inserted: 插入的记录数
records_updated: 更新的记录数
records_deleted: 删除的记录数
error_message: 错误信息(如果失败)
sync_type: 同步类型
date_range_start: 日期范围起始
date_range_end: 日期范围结束
metadata: 元数据
"""
entry = SyncLogEntry(
table_name=table_name,
sync_type=sync_type,
start_time=datetime.now(),
end_time=datetime.now(),
status=status,
records_inserted=records_inserted,
records_updated=records_updated,
records_deleted=records_deleted,
date_range_start=date_range_start,
date_range_end=date_range_end,
error_message=error_message,
metadata=str(metadata) if metadata else None,
)
try:
self.storage._connection.execute(
f"""
INSERT INTO {self.TABLE_NAME} (
table_name, sync_type, start_time, end_time, status,
records_inserted, records_updated, records_deleted,
date_range_start, date_range_end, error_message, metadata
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
entry.table_name,
entry.sync_type,
entry.start_time,
entry.end_time,
entry.status,
entry.records_inserted,
entry.records_updated,
entry.records_deleted,
entry.date_range_start,
entry.date_range_end,
entry.error_message,
entry.metadata,
),
)
except Exception as e:
print(f"[SyncLogManager] Error logging sync complete: {e}")
def get_sync_history(
self,
table_name: Optional[str] = None,
status: Optional[str] = None,
limit: int = 100,
offset: int = 0,
) -> pd.DataFrame:
"""查询同步历史。
Args:
table_name: 按表名过滤
status: 按状态过滤
limit: 返回记录数限制
offset: 分页偏移
Returns:
同步历史 DataFrame
"""
conditions = []
params = []
if table_name:
conditions.append("table_name = ?")
params.append(table_name)
if status:
conditions.append("status = ?")
params.append(status)
where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
query = f"""
SELECT * FROM {self.TABLE_NAME}
{where_clause}
ORDER BY start_time DESC
LIMIT ? OFFSET ?
"""
params.extend([limit, offset])
try:
return self.storage._connection.execute(query, params).fetchdf()
except Exception as e:
print(f"[SyncLogManager] Error querying sync history: {e}")
return pd.DataFrame()
def get_last_sync(
self, table_name: str, status: Optional[str] = "success"
) -> Optional[Dict[str, Any]]:
"""获取指定表的最近一次同步记录。
Args:
table_name: 表名
status: 状态过滤None 表示不限制
Returns:
同步记录字典,如果没有则返回 None
"""
conditions = ["table_name = ?"]
params = [table_name]
if status:
conditions.append("status = ?")
params.append(status)
query = f"""
SELECT * FROM {self.TABLE_NAME}
WHERE {" AND ".join(conditions)}
ORDER BY start_time DESC
LIMIT 1
"""
try:
df = self.storage._connection.execute(query, params).fetchdf()
if not df.empty:
return df.iloc[0].to_dict()
return None
except Exception as e:
print(f"[SyncLogManager] Error getting last sync: {e}")
return None
def get_sync_summary(
    self, table_name: Optional[str] = None, days: int = 30
) -> Dict[str, Any]:
    """Aggregate sync statistics over the last *days* days.

    Args:
        table_name: Restrict the summary to this table, if given.
        days: Size of the trailing window, in days.

    Returns:
        Dict with total/success/failed/partial counts and total
        inserted/updated/deleted record counts; all zeros on error.
    """
    # days is interpolated directly (it is an int, not user text);
    # table_name goes through a bind parameter.
    filters = [f"start_time >= CURRENT_DATE - INTERVAL '{days} days'"]
    bind_values = []
    if table_name:
        filters.append("table_name = ?")
        bind_values.append(table_name)
    where_clause = f"WHERE {' AND '.join(filters)}"
    query = f"""
    SELECT
        COUNT(*) as total_syncs,
        COUNT(CASE WHEN status = 'success' THEN 1 END) as success_count,
        COUNT(CASE WHEN status = 'failed' THEN 1 END) as failed_count,
        COUNT(CASE WHEN status = 'partial' THEN 1 END) as partial_count,
        SUM(records_inserted) as total_inserted,
        SUM(records_updated) as total_updated,
        SUM(records_deleted) as total_deleted
    FROM {self.TABLE_NAME}
    {where_clause}
    """
    try:
        df = self.storage._connection.execute(query, bind_values).fetchdf()
        if not df.empty:
            row = df.iloc[0]
            summary = {
                key: int(row[key])
                for key in (
                    "total_syncs",
                    "success_count",
                    "failed_count",
                    "partial_count",
                )
            }
            # SUM() yields NULL on an empty window, hence the `or 0` guard.
            summary.update(
                {
                    key: int(row[key] or 0)
                    for key in ("total_inserted", "total_updated", "total_deleted")
                }
            )
            return summary
    except Exception as e:
        print(f"[SyncLogManager] Error getting sync summary: {e}")
    # Fallback: all-zero summary when the query fails or returns nothing.
    return {
        "total_syncs": 0,
        "success_count": 0,
        "failed_count": 0,
        "partial_count": 0,
        "total_inserted": 0,
        "total_updated": 0,
        "total_deleted": 0,
    }
# 便捷函数
def log_sync_start(
    table_name: str,
    sync_type: str,
    date_range_start: Optional[str] = None,
    date_range_end: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> None:
    """Convenience wrapper: open a sync-log entry via a fresh SyncLogManager.

    Args:
        table_name: Target table of the sync.
        sync_type: Sync type label (e.g. full/incremental).
        date_range_start: Start of the synced date range.
        date_range_end: End of the synced date range.
        metadata: Optional free-form metadata for the entry.

    Note:
        The entry created by start_sync is not returned here; callers that
        need the entry object should use SyncLogManager directly.
    """
    SyncLogManager().start_sync(
        table_name, sync_type, date_range_start, date_range_end, metadata
    )
def log_sync_complete(
    table_name: str,
    status: str = "success",
    records_inserted: int = 0,
    records_updated: int = 0,
    records_deleted: int = 0,
    error_message: Optional[str] = None,
    sync_type: str = "incremental",
) -> None:
    """Convenience wrapper: record the completion of a sync run.

    Args:
        table_name: Target table of the sync.
        status: Final status (e.g. success/failed/partial).
        records_inserted: Number of rows inserted.
        records_updated: Number of rows updated.
        records_deleted: Number of rows deleted.
        error_message: Error description when the sync failed.
        sync_type: Sync type label.
    """
    # NOTE(review): elsewhere SyncLogManager.complete_sync is invoked as
    # complete_sync(entry, status=..., ...) with a SyncLogEntry as the first
    # argument, while here the first positional argument is a table-name
    # string (and sync_type is passed as an extra trailing positional).
    # Confirm complete_sync accepts this calling form; otherwise this wrapper
    # passes the wrong type.
    manager = SyncLogManager()
    manager.complete_sync(
        table_name,
        status,
        records_inserted,
        records_updated,
        records_deleted,
        error_message,
        sync_type,
    )

View File

@@ -101,14 +101,14 @@ EXCLUDED_FACTORS = [
'GTJA_alpha005',
'GTJA_alpha036',
'GTJA_alpha027',
'GTJA_alpha053',
'GTJA_alpha044',
'GTJA_alpha073',
'GTJA_alpha104',
'GTJA_alpha103',
'GTJA_alpha087',
'GTJA_alpha105',
'GTJA_alpha092',
'GTJA_alpha087',
'GTJA_alpha085',
'GTJA_alpha044',
'GTJA_alpha062',
'GTJA_alpha124',
'GTJA_alpha133',
@@ -203,23 +203,37 @@ print("\n" + "=" * 80)
print("开始训练")
print("=" * 80)
# 步骤 1: 股票池筛选
print("\n[步骤 1/6] 股票池筛选")
# 步骤 1: 应用过滤器ST股票过滤等
print("\n[步骤 1/7] 应用数据过滤器")
print("-" * 60)
filtered_data = data
if st_filter:
print(" 应用 ST 股票过滤器...")
data_before = len(filtered_data)
filtered_data = st_filter.filter(filtered_data)
data_after = len(filtered_data)
print(f" 过滤前记录数: {data_before}")
print(f" 过滤后记录数: {data_after}")
print(f" 删除 ST 股票记录数: {data_before - data_after}")
else:
print(" 未配置 ST 过滤器,跳过")
# 步骤 2: 股票池筛选
print("\n[步骤 2/7] 股票池筛选")
print("-" * 60)
if pool_manager:
print(" 执行每日独立筛选股票池...")
filtered_data = pool_manager.filter_and_select_daily(data)
print(f" 筛选前数据规模: {data.shape}")
print(f" 筛选后数据规模: {filtered_data.shape}")
print(f" 筛选前股票数: {data['ts_code'].n_unique()}")
print(f" 筛选后股票数: {filtered_data['ts_code'].n_unique()}")
print(f" 删除记录数: {len(data) - len(filtered_data)}")
pool_data_before = len(filtered_data)
filtered_data = pool_manager.filter_and_select_daily(filtered_data)
pool_data_after = len(filtered_data)
print(f" 筛选前数据规模: {pool_data_before}")
print(f" 筛选后数据规模: {pool_data_after}")
print(f" 删除记录数: {pool_data_before - pool_data_after}")
else:
filtered_data = data
print(" 未配置股票池管理器,跳过筛选")
# %%
# 步骤 2: 划分训练/验证/测试集(正确的三分法)
print("\n[步骤 2/6] 划分训练集、验证集和测试集")
# 步骤 3: 划分训练/验证/测试集(正确的三分法)
print("\n[步骤 3/7] 划分训练集、验证集和测试集")
print("-" * 60)
if splitter:
# 正确的三分法train用于训练val用于验证/早停test仅用于最终评估
@@ -251,8 +265,8 @@ else:
test_data = filtered_data
print(" 未配置划分器,全部作为训练集")
# %%
# 步骤 3: 数据质量检查(必须在预处理之前)
print("\n[步骤 3/7] 数据质量检查")
# 步骤 4: 数据质量检查(必须在预处理之前)
print("\n[步骤 4/7] 数据质量检查")
print("-" * 60)
print(" [说明] 此检查在 fillna 等处理之前执行,用于发现数据问题")
@@ -269,8 +283,8 @@ check_data_quality(test_data, feature_cols, raise_on_error=True)
print(" [成功] 数据质量检查通过,未发现异常")
# %%
# 步骤 4: 训练集数据处理
print("\n[步骤 4/7] 训练集数据处理")
# 步骤 5: 训练集数据处理
print("\n[步骤 5/7] 训练集数据处理")
print("-" * 60)
fitted_processors = []
if processors:
@@ -296,7 +310,7 @@ for col in feature_cols[:5]: # 只显示前5个特征的缺失值
if null_count > 0:
print(f" {col}: {null_count} ({null_count / len(train_data) * 100:.2f}%)")
# %%
# 步骤 4: 训练模型
# 步骤 5: 训练模型
print("\n[步骤 5/7] 训练模型")
print("-" * 60)
print(f" 模型类型: LightGBM")
@@ -318,7 +332,7 @@ print("\n 开始训练...")
model.fit(X_train, y_train)
print(" 训练完成!")
# %%
# 步骤 5: 测试集数据处理
# 步骤 6: 测试集数据处理
print("\n[步骤 6/7] 测试集数据处理")
print("-" * 60)
if processors and test_data is not train_data:
@@ -334,7 +348,7 @@ if processors and test_data is not train_data:
else:
print(" 跳过测试集处理")
# %%
# 步骤 6: 生成预测
# 步骤 7: 生成预测
print("\n[步骤 7/7] 生成预测")
print("-" * 60)
X_test = test_data.select(feature_cols)

View File

@@ -0,0 +1,403 @@
"""测试同步事务和日志功能。
测试内容:
1. 事务支持 - BEGIN/COMMIT/ROLLBACK
2. 同步日志记录 - SyncLogManager
3. ThreadSafeStorage 事务批量写入
"""
import pytest
import pandas as pd
import os
import tempfile
from datetime import datetime, timedelta
from pathlib import Path
# Set test environment variables BEFORE importing src.data below — the data
# layer presumably reads DATA_PATH / TUSHARE_TOKEN at import/initialization
# time, so this ordering matters (TODO confirm).
os.environ["DATA_PATH"] = tempfile.mkdtemp()
os.environ["TUSHARE_TOKEN"] = "test_token"
from src.data.storage import Storage, ThreadSafeStorage
from src.data.sync_logger import SyncLogManager, SyncLogEntry
@pytest.fixture
def temp_storage():
    """Yield a writable Storage instance backed by a fresh temporary directory.

    The Storage singleton state (_instance/_connection) is cleared both
    before and after the test so each test gets an isolated database.
    """
    # Point DATA_PATH at a per-test temporary directory.
    temp_dir = tempfile.mkdtemp()
    os.environ["DATA_PATH"] = temp_dir
    # Reset the singleton so a new connection is opened under temp_dir.
    Storage._instance = None
    Storage._connection = None
    storage = Storage(read_only=False)
    yield storage
    # Teardown: close the connection and clear the singleton for the next test.
    storage.close()
    Storage._instance = None
    Storage._connection = None
class TestTransactionSupport:
    """Tests for Storage's explicit transaction API (begin/commit/rollback)."""

    def test_begin_commit_transaction(self, temp_storage):
        """Rows inserted between begin and commit become visible."""
        conn = temp_storage._connection
        conn.execute("""
            CREATE TABLE test_table (id INTEGER PRIMARY KEY, value VARCHAR(50))
        """)
        temp_storage.begin_transaction()
        conn.execute(
            "INSERT INTO test_table VALUES (1, 'test1'), (2, 'test2')"
        )
        temp_storage.commit_transaction()
        (row_count,) = conn.execute("SELECT COUNT(*) FROM test_table").fetchone()
        assert row_count == 2

    def test_rollback_transaction(self, temp_storage):
        """Rows inserted after begin are discarded by rollback."""
        conn = temp_storage._connection
        conn.execute("""
            CREATE TABLE test_table2 (id INTEGER PRIMARY KEY, value VARCHAR(50))
        """)
        conn.execute(
            "INSERT INTO test_table2 VALUES (1, 'initial')"
        )
        temp_storage.begin_transaction()
        conn.execute("INSERT INTO test_table2 VALUES (2, 'temp')")
        temp_storage.rollback_transaction()
        # Only the pre-transaction row should remain.
        (row_count,) = conn.execute("SELECT COUNT(*) FROM test_table2").fetchone()
        assert row_count == 1

    def test_transaction_context_manager(self, temp_storage):
        """transaction() commits on normal exit from the with-block."""
        conn = temp_storage._connection
        conn.execute("""
            CREATE TABLE test_table3 (id INTEGER PRIMARY KEY, value VARCHAR(50))
        """)
        with temp_storage.transaction():
            conn.execute(
                "INSERT INTO test_table3 VALUES (1, 'committed')"
            )
        (row_count,) = conn.execute("SELECT COUNT(*) FROM test_table3").fetchone()
        assert row_count == 1

    def test_transaction_context_manager_rollback(self, temp_storage):
        """transaction() rolls back when the with-block raises."""
        conn = temp_storage._connection
        conn.execute("""
            CREATE TABLE test_table4 (id INTEGER PRIMARY KEY, value VARCHAR(50))
        """)
        try:
            with temp_storage.transaction():
                conn.execute(
                    "INSERT INTO test_table4 VALUES (1, 'temp')"
                )
                raise ValueError("Test error")
        except ValueError:
            pass
        # The aborted insert must not have been committed.
        (row_count,) = conn.execute("SELECT COUNT(*) FROM test_table4").fetchone()
        assert row_count == 0
class TestSyncLogManager:
    """Tests for SyncLogManager's logging and query APIs."""

    def test_log_table_creation(self, temp_storage):
        """Instantiating SyncLogManager auto-creates the _sync_logs table."""
        SyncLogManager(temp_storage)
        (table_count,) = temp_storage._connection.execute("""
            SELECT COUNT(*) FROM information_schema.tables
            WHERE table_name = '_sync_logs'
        """).fetchone()
        assert table_count == 1

    def test_start_sync(self, temp_storage):
        """start_sync returns a running entry mirroring its arguments."""
        manager = SyncLogManager(temp_storage)
        record = manager.start_sync(
            table_name="test_table",
            sync_type="incremental",
            date_range_start="20240101",
            date_range_end="20240131",
            metadata={"test": True},
        )
        assert record.table_name == "test_table"
        assert record.sync_type == "incremental"
        assert record.status == "running"
        assert record.date_range_start == "20240101"

    def test_complete_sync(self, temp_storage):
        """complete_sync updates the entry's status, counters and end time."""
        manager = SyncLogManager(temp_storage)
        record = manager.start_sync(table_name="test_table", sync_type="full")
        manager.complete_sync(
            record,
            status="success",
            records_inserted=1000,
            records_updated=100,
            records_deleted=10,
        )
        assert record.status == "success"
        assert record.records_inserted == 1000
        assert record.records_updated == 100
        assert record.records_deleted == 10
        assert record.end_time is not None

    def test_complete_sync_with_error(self, temp_storage):
        """A failed sync records the error message on the entry."""
        manager = SyncLogManager(temp_storage)
        record = manager.start_sync(table_name="test_table", sync_type="incremental")
        manager.complete_sync(
            record, status="failed", error_message="Connection timeout"
        )
        assert record.status == "failed"
        assert record.error_message == "Connection timeout"

    def test_get_sync_history(self, temp_storage):
        """get_sync_history returns all logged runs for a table."""
        manager = SyncLogManager(temp_storage)
        for _ in range(3):
            record = manager.start_sync(
                table_name="test_table", sync_type="incremental"
            )
            manager.complete_sync(record, status="success", records_inserted=100)
        history = manager.get_sync_history(table_name="test_table", limit=10)
        assert len(history) == 3
        assert (history["table_name"] == "test_table").all()

    def test_get_last_sync(self, temp_storage):
        """get_last_sync returns the most recently started record."""
        manager = SyncLogManager(temp_storage)
        first = manager.start_sync(table_name="table1", sync_type="full")
        manager.complete_sync(first, status="success")
        second = manager.start_sync(table_name="table1", sync_type="incremental")
        manager.complete_sync(second, status="success")
        latest = manager.get_last_sync("table1")
        assert latest is not None
        assert latest["sync_type"] == "incremental"

    def test_get_sync_summary(self, temp_storage):
        """get_sync_summary aggregates counts and inserted totals."""
        manager = SyncLogManager(temp_storage)
        # Five successful runs of 100 inserted rows each...
        for _ in range(5):
            record = manager.start_sync(
                table_name="test_table", sync_type="incremental"
            )
            manager.complete_sync(record, status="success", records_inserted=100)
        # ...plus one failed run.
        failed = manager.start_sync(table_name="test_table", sync_type="full")
        manager.complete_sync(failed, status="failed", error_message="error")
        summary = manager.get_sync_summary("test_table", days=30)
        assert summary["total_syncs"] == 6
        assert summary["success_count"] == 5
        assert summary["failed_count"] == 1
        assert summary["total_inserted"] == 500
class TestThreadSafeStorageTransaction:
    """Tests for ThreadSafeStorage's transactional batch flush."""

    def test_flush_with_transaction(self, temp_storage):
        """All queued frames land when flushed with use_transaction=True."""
        # Reset the Storage singleton so ThreadSafeStorage builds a fresh one.
        Storage._instance = None
        Storage._connection = None
        writer = ThreadSafeStorage()
        writer.storage._connection.execute("""
            CREATE TABLE test_data (ts_code VARCHAR(16), trade_date DATE, value DOUBLE)
        """)
        batch_one = pd.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ"],
                "trade_date": ["20240101", "20240101"],
                "value": [100.0, 200.0],
            }
        )
        batch_two = pd.DataFrame(
            {
                "ts_code": ["000003.SZ", "000004.SZ"],
                "trade_date": ["20240102", "20240102"],
                "value": [300.0, 400.0],
            }
        )
        writer.queue_save("test_data", batch_one)
        writer.queue_save("test_data", batch_two)
        writer.flush(use_transaction=True)
        (row_count,) = writer.storage._connection.execute(
            "SELECT COUNT(*) FROM test_data"
        ).fetchone()
        assert row_count == 4

    def test_flush_rollback_on_error(self, temp_storage):
        """Simplified consistency check for the transactional flush path."""
        # Full error injection is hard to set up here; this simplified version
        # only verifies the happy path leaves the table consistent.
        Storage._instance = None
        Storage._connection = None
        writer = ThreadSafeStorage()
        # Unique table name so the test does not collide with others.
        writer.storage._connection.execute("""
            CREATE TABLE test_data2 (ts_code VARCHAR(16) PRIMARY KEY, value DOUBLE)
        """)
        seed = pd.DataFrame({"ts_code": ["000001.SZ"], "value": [100.0]})
        writer.queue_save("test_data2", seed, use_upsert=False)
        writer.flush(use_transaction=True)
        (row_count,) = writer.storage._connection.execute(
            "SELECT COUNT(*) FROM test_data2"
        ).fetchone()
        assert row_count == 1
class TestIntegration:
    """End-to-end test combining transactions with sync logging."""
    def test_full_sync_workflow(self, temp_storage):
        """Full workflow: log start, write inside a transaction, log completion."""
        # 1. Initialize the log manager (creates the log table if needed).
        log_manager = SyncLogManager(temp_storage)
        # 2. Create the target table for the simulated sync.
        temp_storage._connection.execute("""
            CREATE TABLE sync_test_table (
                ts_code VARCHAR(16),
                trade_date DATE,
                value DOUBLE,
                PRIMARY KEY (ts_code, trade_date)
            )
        """)
        # 3. Open a sync-log entry before writing any data.
        log_entry = log_manager.start_sync(
            table_name="sync_test_table",
            sync_type="full",
            date_range_start="20240101",
            date_range_end="20240131",
        )
        # 4. Perform the sync inside an explicit transaction.
        temp_storage.begin_transaction()
        try:
            # Simulated sync payload: two rows for one trade date.
            df = pd.DataFrame(
                {
                    "ts_code": ["000001.SZ", "000002.SZ"],
                    "trade_date": ["20240115", "20240115"],
                    "value": [100.0, 200.0],
                }
            )
            # Convert YYYYMMDD strings to date objects to match the DATE column.
            df["trade_date"] = pd.to_datetime(df["trade_date"], format="%Y%m%d").dt.date
            # Write through the Storage API.
            temp_storage.save("sync_test_table", df, mode="append")
            # Commit before recording success, so the log never claims
            # success for uncommitted data.
            temp_storage.commit_transaction()
            # 5. Record the successful run with the inserted row count.
            log_manager.complete_sync(
                log_entry, status="success", records_inserted=len(df)
            )
        except Exception as e:
            temp_storage.rollback_transaction()
            log_manager.complete_sync(log_entry, status="failed", error_message=str(e))
            raise
        # 6. The in-memory entry reflects the successful run.
        assert log_entry.status == "success"
        assert log_entry.records_inserted == 2
        # 7. The run is queryable from the persisted history.
        history = log_manager.get_sync_history(table_name="sync_test_table")
        assert len(history) == 1
if __name__ == "__main__":
    # Run this module's tests directly. Propagate pytest's exit status so
    # shell/CI callers see failures; the original call discarded the return
    # code and always exited 0.
    raise SystemExit(pytest.main([__file__, "-v"]))