"""Tests for sync transactions and sync logging.

Covered areas:
1. Transaction support - BEGIN/COMMIT/ROLLBACK
2. Sync log recording - SyncLogManager
3. ThreadSafeStorage batched writes inside a transaction
"""

import pytest
import pandas as pd
import os
import tempfile
from datetime import datetime, timedelta
from pathlib import Path

# Set the test environment variables. This happens before the project imports
# below, which is why those imports are not at the very top of the file.
os.environ["DATA_PATH"] = tempfile.mkdtemp()
os.environ["TUSHARE_TOKEN"] = "test_token"

from src.data.storage import Storage, ThreadSafeStorage  # noqa: E402
from src.data.sync_logger import SyncLogManager, SyncLogEntry  # noqa: E402


def _reset_storage_singleton():
    """Clear the class-level Storage singleton state.

    Sets ``Storage._instance`` and ``Storage._connection`` to ``None`` so the
    next ``Storage(...)`` call starts without a previously stored instance.
    """
    Storage._instance = None
    Storage._connection = None


@pytest.fixture
def temp_storage(tmp_path, monkeypatch):
    """Yield a writable Storage created with DATA_PATH set to a temp directory.

    ``tmp_path`` and ``monkeypatch`` are pytest built-ins: the directory is
    managed by pytest and the ``DATA_PATH`` override is undone after the test,
    so nothing leaks into later tests.
    """
    monkeypatch.setenv("DATA_PATH", str(tmp_path))

    # Reset the Storage singleton before creating the instance under test.
    _reset_storage_singleton()
    storage = Storage(read_only=False)

    yield storage

    # Cleanup: reset the singleton even if closing the connection fails.
    try:
        storage.close()
    finally:
        _reset_storage_singleton()


class TestTransactionSupport:
    """Tests for Storage transaction support."""

    def test_begin_commit_transaction(self, temp_storage):
        """Rows inserted between begin and commit are visible afterwards."""
        # Create the test table.
        temp_storage._connection.execute("""
            CREATE TABLE test_table (id INTEGER PRIMARY KEY, value VARCHAR(50))
        """)

        # Begin, insert, commit.
        temp_storage.begin_transaction()
        temp_storage._connection.execute(
            "INSERT INTO test_table VALUES (1, 'test1'), (2, 'test2')"
        )
        temp_storage.commit_transaction()

        # Verify the data was committed.
        result = temp_storage._connection.execute(
            "SELECT COUNT(*) FROM test_table"
        ).fetchone()
        assert result[0] == 2

    def test_rollback_transaction(self, temp_storage):
        """Rows inserted inside a rolled-back transaction are discarded."""
        # Create the test table and insert the initial row.
        temp_storage._connection.execute("""
            CREATE TABLE test_table2 (id INTEGER PRIMARY KEY, value VARCHAR(50))
        """)
        temp_storage._connection.execute(
            "INSERT INTO test_table2 VALUES (1, 'initial')"
        )

        # Begin a transaction, insert another row, then roll back.
        temp_storage.begin_transaction()
        temp_storage._connection.execute("INSERT INTO test_table2 VALUES (2, 'temp')")
        temp_storage.rollback_transaction()

        # Only the initial row should remain.
        result = temp_storage._connection.execute(
            "SELECT COUNT(*) FROM test_table2"
        ).fetchone()
        assert result[0] == 1

    def test_transaction_context_manager(self, temp_storage):
        """The transaction context manager commits on normal exit."""
        # Create the test table.
        temp_storage._connection.execute("""
            CREATE TABLE test_table3 (id INTEGER PRIMARY KEY, value VARCHAR(50))
        """)

        # Use the context manager (normal completion).
        with temp_storage.transaction():
            temp_storage._connection.execute(
                "INSERT INTO test_table3 VALUES (1, 'committed')"
            )

        # Verify the data was committed.
        result = temp_storage._connection.execute(
            "SELECT COUNT(*) FROM test_table3"
        ).fetchone()
        assert result[0] == 1

    def test_transaction_context_manager_rollback(self, temp_storage):
        """The transaction context manager rolls back when the body raises."""
        # Create the test table.
        temp_storage._connection.execute("""
            CREATE TABLE test_table4 (id INTEGER PRIMARY KEY, value VARCHAR(50))
        """)

        # Use the context manager (exception raised inside the body).
        # pytest.raises also fails the test if the error is silently swallowed.
        with pytest.raises(ValueError, match="Test error"):
            with temp_storage.transaction():
                temp_storage._connection.execute(
                    "INSERT INTO test_table4 VALUES (1, 'temp')"
                )
                raise ValueError("Test error")

        # Verify the data was not committed.
        result = temp_storage._connection.execute(
            "SELECT COUNT(*) FROM test_table4"
        ).fetchone()
        assert result[0] == 0


class TestSyncLogManager:
    """Tests for the sync log manager."""

    def test_log_table_creation(self, temp_storage):
        """Constructing a SyncLogManager creates the ``_sync_logs`` table."""
        # Creating the log manager is expected to create the table.
        SyncLogManager(temp_storage)

        # Verify the table exists.
        result = temp_storage._connection.execute("""
            SELECT COUNT(*) FROM information_schema.tables
            WHERE table_name = '_sync_logs'
        """).fetchone()
        assert result[0] == 1

    def test_start_sync(self, temp_storage):
        """start_sync returns an entry in the ``running`` state."""
        log_manager = SyncLogManager(temp_storage)

        # Record the start of a sync.
        entry = log_manager.start_sync(
            table_name="test_table",
            sync_type="incremental",
            date_range_start="20240101",
            date_range_end="20240131",
            metadata={"test": True},
        )

        assert entry.table_name == "test_table"
        assert entry.sync_type == "incremental"
        assert entry.status == "running"
        assert entry.date_range_start == "20240101"

    def test_complete_sync(self, temp_storage):
        """complete_sync stores the status, record counts and end time."""
        log_manager = SyncLogManager(temp_storage)

        # Start the sync.
        entry = log_manager.start_sync(table_name="test_table", sync_type="full")

        # Complete the sync.
        log_manager.complete_sync(
            entry,
            status="success",
            records_inserted=1000,
            records_updated=100,
            records_deleted=10,
        )

        assert entry.status == "success"
        assert entry.records_inserted == 1000
        assert entry.records_updated == 100
        assert entry.records_deleted == 10
        assert entry.end_time is not None

    def test_complete_sync_with_error(self, temp_storage):
        """complete_sync records a failed status together with its message."""
        log_manager = SyncLogManager(temp_storage)

        entry = log_manager.start_sync(table_name="test_table", sync_type="incremental")
        log_manager.complete_sync(
            entry, status="failed", error_message="Connection timeout"
        )

        assert entry.status == "failed"
        assert entry.error_message == "Connection timeout"

    def test_get_sync_history(self, temp_storage):
        """get_sync_history returns every logged sync for the table."""
        log_manager = SyncLogManager(temp_storage)

        # Create a few records.
        for _ in range(3):
            entry = log_manager.start_sync(
                table_name="test_table", sync_type="incremental"
            )
            log_manager.complete_sync(entry, status="success", records_inserted=100)

        # Query the history.
        history = log_manager.get_sync_history(table_name="test_table", limit=10)

        assert len(history) == 3
        # Column-wise comparison instead of iterating row by row.
        assert (history["table_name"] == "test_table").all()

    def test_get_last_sync(self, temp_storage):
        """get_last_sync returns the most recently logged sync."""
        log_manager = SyncLogManager(temp_storage)

        # Create two records.
        entry1 = log_manager.start_sync(table_name="table1", sync_type="full")
        log_manager.complete_sync(entry1, status="success")

        entry2 = log_manager.start_sync(table_name="table1", sync_type="incremental")
        log_manager.complete_sync(entry2, status="success")

        # Fetch the most recent one.
        last_sync = log_manager.get_last_sync("table1")

        assert last_sync is not None
        assert last_sync["sync_type"] == "incremental"

    def test_get_sync_summary(self, temp_storage):
        """get_sync_summary aggregates counts across logged syncs."""
        log_manager = SyncLogManager(temp_storage)

        # Create several successful records.
        for _ in range(5):
            entry = log_manager.start_sync(
                table_name="test_table", sync_type="incremental"
            )
            log_manager.complete_sync(entry, status="success", records_inserted=100)

        # Add one failed record.
        entry = log_manager.start_sync(table_name="test_table", sync_type="full")
        log_manager.complete_sync(entry, status="failed", error_message="error")

        # Fetch the summary.
        summary = log_manager.get_sync_summary("test_table", days=30)

        assert summary["total_syncs"] == 6
        assert summary["success_count"] == 5
        assert summary["failed_count"] == 1
        assert summary["total_inserted"] == 500


class TestThreadSafeStorageTransaction:
    """Tests for ThreadSafeStorage transaction support."""

    def test_flush_with_transaction(self, temp_storage):
        """Queued frames are all written when flushing inside a transaction."""
        # Reset the Storage singleton.
        # NOTE(review): this drops the fixture's connection reference without
        # closing it first - confirm Storage tolerates that.
        _reset_storage_singleton()

        ts_storage = ThreadSafeStorage()

        # Create the test table.
        ts_storage.storage._connection.execute("""
            CREATE TABLE test_data (ts_code VARCHAR(16), trade_date DATE, value DOUBLE)
        """)

        # Prepare the test data.
        df1 = pd.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ"],
                "trade_date": ["20240101", "20240101"],
                "value": [100.0, 200.0],
            }
        )
        df2 = pd.DataFrame(
            {
                "ts_code": ["000003.SZ", "000004.SZ"],
                "trade_date": ["20240102", "20240102"],
                "value": [300.0, 400.0],
            }
        )

        # Queue both frames.
        ts_storage.queue_save("test_data", df1)
        ts_storage.queue_save("test_data", df2)

        # Flush using a transaction.
        ts_storage.flush(use_transaction=True)

        # Verify the data.
        result = ts_storage.storage._connection.execute(
            "SELECT COUNT(*) FROM test_data"
        ).fetchone()
        assert result[0] == 4

    def test_flush_rollback_on_error(self, temp_storage):
        """Simplified check that a transactional flush leaves consistent data.

        Simulating a real failure is more involved; this version only verifies
        that a transactional flush of one row ends with exactly that row stored.
        """
        # NOTE(review): same unclosed-connection caveat as the test above.
        _reset_storage_singleton()

        ts_storage = ThreadSafeStorage()

        # Create the test table (with a unique table name).
        ts_storage.storage._connection.execute("""
            CREATE TABLE test_data2 (ts_code VARCHAR(16) PRIMARY KEY, value DOUBLE)
        """)

        # Insert the initial data.
        df = pd.DataFrame({"ts_code": ["000001.SZ"], "value": [100.0]})
        ts_storage.queue_save("test_data2", df, use_upsert=False)
        ts_storage.flush(use_transaction=True)

        # Verify.
        result = ts_storage.storage._connection.execute(
            "SELECT COUNT(*) FROM test_data2"
        ).fetchone()
        assert result[0] == 1


class TestIntegration:
    """Integration tests."""

    def test_full_sync_workflow(self, temp_storage):
        """Run a full sync: log start, write in a transaction, log the result."""
        # 1. Initialise the log manager.
        log_manager = SyncLogManager(temp_storage)

        # 2. Create the test table.
        temp_storage._connection.execute("""
            CREATE TABLE sync_test_table (
                ts_code VARCHAR(16),
                trade_date DATE,
                value DOUBLE,
                PRIMARY KEY (ts_code, trade_date)
            )
        """)

        # 3. Start the sync.
        log_entry = log_manager.start_sync(
            table_name="sync_test_table",
            sync_type="full",
            date_range_start="20240101",
            date_range_end="20240131",
        )

        # Simulated sync payload.
        df = pd.DataFrame(
            {
                "ts_code": ["000001.SZ", "000002.SZ"],
                "trade_date": ["20240115", "20240115"],
                "value": [100.0, 200.0],
            }
        )
        # Convert the date strings to date objects.
        df["trade_date"] = pd.to_datetime(df["trade_date"], format="%Y%m%d").dt.date

        # 4. Run the sync inside a transaction.
        temp_storage.begin_transaction()
        try:
            temp_storage.save("sync_test_table", df, mode="append")
            temp_storage.commit_transaction()
        except Exception as exc:
            temp_storage.rollback_transaction()
            log_manager.complete_sync(
                log_entry, status="failed", error_message=str(exc)
            )
            raise
        else:
            # 5. Record success only after the commit went through, so a
            # logging failure never triggers a rollback of committed work.
            log_manager.complete_sync(
                log_entry, status="success", records_inserted=len(df)
            )

        # 6. Verify.
        assert log_entry.status == "success"
        assert log_entry.records_inserted == 2

        # 7. Query the log history.
        history = log_manager.get_sync_history(table_name="sync_test_table")
        assert len(history) == 1


if __name__ == "__main__":
    # Run the tests.
    pytest.main([__file__, "-v"])