- Storage/ThreadSafeStorage 添加事务支持(begin/commit/rollback) - 新增 SyncLogManager 记录所有同步任务的执行状态 - 集成事务到 StockBasedSync、DateBasedSync、QuarterBasedSync - 在 sync_all 和 sync_financial 调度中心添加日志记录 - 新增测试验证事务和日志功能
839 lines
30 KiB
Python
839 lines
30 KiB
Python
"""财务数据同步基础抽象模块。

提供专门用于按季度同步财务数据的基类 QuarterBasedSync。

财务数据特点:
- 按季度发布(period: 20231231, 20230930, 20230630, 20230331)
- 使用 VIP 接口一次性获取某季度的全部上市公司数据
- 数据可能会修正,增量同步需获取当前季度+前一季度
- 主键为 (ts_code, end_date)

使用方式:
    class IncomeQuarterSync(QuarterBasedSync):
        table_name = "financial_income"
        api_name = "income_vip"

        def fetch_single_quarter(self, period: str) -> pd.DataFrame:
            # 实现单季度数据获取
            ...
"""
|
||
|
||
from abc import ABC, abstractmethod
|
||
from typing import Optional, Dict, List, Tuple, Set
|
||
from datetime import datetime
|
||
import pandas as pd
|
||
from tqdm import tqdm
|
||
|
||
from src.data.client import TushareClient
|
||
from src.data.storage import ThreadSafeStorage, Storage
|
||
from src.data.sync_logger import SyncLogManager
|
||
from src.data.utils import get_today_date, get_quarters_in_range, DEFAULT_START_DATE
|
||
|
||
|
||
class QuarterBasedSync(ABC):
    """Abstract base class for syncing financial statements quarter by quarter.

    How a sync works, as implemented below:
        1. ``fetch_single_quarter`` returns every row for one reporting period.
        2. Rows are optionally filtered by ``TARGET_REPORT_TYPE``.
        3. Per-stock row counts are compared with the local table. Stocks that
           are missing locally, or whose row count differs, are rewritten
           (delete, then insert) inside one storage transaction.
        4. Incremental runs start one quarter before the latest local quarter
           so that revised figures for that quarter are picked up as well.

    Subclasses must provide:
        - ``table_name``: name of the target table.
        - ``api_name``: name of the remote API endpoint (reported by
          ``preview_sync``).
        - ``TABLE_SCHEMA``: column definitions used to create the table.
        - ``fetch_single_quarter(period)``: fetch one quarter as a DataFrame.

    Attributes:
        table_name: Target table name (must be overridden).
        api_name: Remote API name (must be overridden).
        DEFAULT_START_DATE: First quarter synced by a full sync (2018 Q1).
        TARGET_REPORT_TYPE: Value of the ``report_type`` column to keep, or
            ``None`` to keep every report type.
        TABLE_SCHEMA: Mapping ``{column name: SQL type}``.
        TABLE_INDEXES: ``[(index name, [column, ...]), ...]``. Only plain
            indexes are created: the same stock/quarter may legitimately have
            several rows because statements can be revised.
    """

    table_name: str = ""  # Must be overridden by subclasses.
    api_name: str = ""  # Must be overridden by subclasses.
    DEFAULT_START_DATE = "20180331"  # 2018 Q1

    # Report type to keep (subclasses may override). The default "1" keeps
    # only consolidated statements; set to None to keep every report type.
    TARGET_REPORT_TYPE: Optional[str] = "1"

    # Table structure (must be overridden by subclasses).
    TABLE_SCHEMA: Dict[str, str] = {}

    # Index definitions (subclasses may override).
    # Format: [("index_name", ["col1", "col2"]), ...]
    # Do not declare unique indexes: revised statements produce several rows
    # for the same stock and quarter.
    TABLE_INDEXES: List[Tuple[str, List[str]]] = []

    # Quarter-end suffixes (MMDD) in calendar order.
    _QUARTER_ENDS: Tuple[str, ...] = ("0331", "0630", "0930", "1231")

    def __init__(self):
        """Create the storage handle and API client used by the sync methods."""
        self.storage = ThreadSafeStorage()
        self.client = TushareClient()
        self._cached_data: Optional[pd.DataFrame] = None

    # ======================================================================
    # Abstract methods - subclasses must implement
    # ======================================================================

    @abstractmethod
    def fetch_single_quarter(self, period: str) -> pd.DataFrame:
        """Fetch the data of every listed company for one quarter.

        Args:
            period: Reporting period, i.e. the last day of the quarter in
                ``YYYYMMDD`` form (for example ``'20231231'``).

        Returns:
            DataFrame holding that quarter's rows.
        """

    # ======================================================================
    # Small private helpers
    # ======================================================================

    @staticmethod
    def _format_period(period: str) -> str:
        """Convert ``YYYYMMDD`` into the ``YYYY-MM-DD`` form used in queries."""
        return f"{period[:4]}-{period[4:6]}-{period[6:]}"

    @staticmethod
    def _quarter_result(
        period: str,
        dry_run: bool,
        remote_total: int = 0,
        diff_count: int = 0,
        deleted_count: int = 0,
        inserted_count: int = 0,
    ) -> Dict:
        """Build the result dictionary returned by ``sync_quarter``."""
        return {
            "period": period,
            "remote_total": remote_total,
            "diff_count": diff_count,
            "deleted_count": deleted_count,
            "inserted_count": inserted_count,
            "dry_run": dry_run,
        }

    def _print_banner(self, title: str) -> None:
        """Print the framed heading shown at the start of a sync or preview."""
        print(f"\n{'=' * 60}")
        print(f"[{self.__class__.__name__}] {title}")
        print(f"{'=' * 60}")

    # ======================================================================
    # Quarter arithmetic
    # ======================================================================

    def get_current_quarter(self) -> str:
        """Return the most recent quarter whose last day has been reached.

        If today is before the end of the calendar quarter it falls in, the
        previous quarter is returned instead, so a quarter that has not ended
        yet is never requested.

        Returns:
            Quarter end date as ``YYYYMMDD``, for example ``'20231231'``.
        """
        today = get_today_date()
        year = int(today[:4])
        month = int(today[4:6])

        # Months 1-3 -> index 0, 4-6 -> 1, 7-9 -> 2, 10-12 -> 3.
        current_q = f"{year}{self._QUARTER_ENDS[(month - 1) // 3]}"

        # YYYYMMDD strings order the same way as the dates they encode.
        if today < current_q:
            return self.get_prev_quarter(current_q)
        return current_q

    def get_prev_quarter(self, quarter: str) -> str:
        """Return the quarter preceding ``quarter``.

        Args:
            quarter: Quarter end date as ``YYYYMMDD``, e.g. ``'20231231'``.

        Returns:
            The previous quarter end date as ``YYYYMMDD``. Any suffix other
            than 0331/0630/0930 is treated as 1231.
        """
        year = int(quarter[:4])
        year_delta, suffix = {
            "0331": (-1, "1231"),
            "0630": (0, "0331"),
            "0930": (0, "0630"),
        }.get(quarter[4:], (0, "0930"))
        return f"{year + year_delta}{suffix}"

    def get_next_quarter(self, quarter: str) -> str:
        """Return the quarter following ``quarter``.

        Args:
            quarter: Quarter end date as ``YYYYMMDD``.

        Returns:
            The next quarter end date as ``YYYYMMDD``. Any suffix other than
            0331/0630/0930 is treated as 1231.
        """
        year = int(quarter[:4])
        year_delta, suffix = {
            "0331": (0, "0630"),
            "0630": (0, "0930"),
            "0930": (0, "1231"),
        }.get(quarter[4:], (1, "0331"))
        return f"{year + year_delta}{suffix}"

    # ======================================================================
    # Table management
    # ======================================================================

    def ensure_table_exists(self) -> None:
        """Create the table and its indexes if the table does not exist yet.

        No primary key or unique index is declared, because one stock can
        have several versions of the same quarter (revised statements).

        Raises:
            Exception: Whatever the storage connection raises when the
                ``CREATE TABLE`` statement fails. Index creation errors are
                only printed.
        """
        storage = Storage()
        cls_name = self.__class__.__name__

        if storage.exists(self.table_name):
            return

        if not self.TABLE_SCHEMA:
            print(f"[{cls_name}] TABLE_SCHEMA not defined, skipping table creation")
            return

        # Column definitions only; deliberately no PRIMARY KEY clause.
        columns_sql = ", ".join(
            f'"{col_name}" {col_type}'
            for col_name, col_type in self.TABLE_SCHEMA.items()
        )
        create_sql = f'CREATE TABLE IF NOT EXISTS "{self.table_name}" ({columns_sql})'

        try:
            storage._connection.execute(create_sql)
            print(f"[{cls_name}] Created table '{self.table_name}'")
        except Exception as e:
            print(f"[{cls_name}] Error creating table: {e}")
            raise

        # Plain (non-unique) indexes; a failing index does not abort the rest.
        for idx_name, idx_cols in self.TABLE_INDEXES:
            try:
                idx_cols_sql = ", ".join(f'"{col}"' for col in idx_cols)
                storage._connection.execute(
                    f'CREATE INDEX IF NOT EXISTS "{idx_name}" '
                    f'ON "{self.table_name}"({idx_cols_sql})'
                )
                print(f"[{cls_name}] Created index '{idx_name}' on {idx_cols}")
            except Exception as e:
                print(f"[{cls_name}] Error creating index {idx_name}: {e}")

    # ======================================================================
    # Difference detection (core logic)
    # ======================================================================

    def get_local_data_count_by_stock(self, period: str) -> Dict[str, int]:
        """Count the locally stored rows of one quarter, per stock.

        Args:
            period: Quarter end date as ``YYYYMMDD``.

        Returns:
            ``{ts_code: row count}``. An empty dict is returned both when the
            quarter has no local rows and when the query fails (the error is
            printed).
        """
        storage = Storage()

        try:
            query = f'''
                SELECT ts_code, COUNT(*) as cnt
                FROM "{self.table_name}"
                WHERE end_date = ?
                GROUP BY ts_code
            '''
            result = storage._connection.execute(
                query, [self._format_period(period)]
            ).fetchdf()

            if result.empty:
                return {}
            return dict(zip(result["ts_code"], result["cnt"]))
        except Exception as e:
            print(f"[{self.__class__.__name__}] Error querying local data count: {e}")
            return {}

    def get_local_records_by_key(self, period: str) -> Dict[tuple, int]:
        """Count the locally stored rows of one quarter, per record key.

        Groups by ``(ts_code, end_date, report_type)`` for a finer-grained
        comparison than ``get_local_data_count_by_stock``.

        Args:
            period: Quarter end date as ``YYYYMMDD``.

        Returns:
            ``{(ts_code, end_date, report_type): row count}``. An empty dict
            is returned when nothing matches or when the query fails (the
            error is printed).
        """
        storage = Storage()

        try:
            query = f'''
                SELECT ts_code, end_date, report_type, COUNT(*) as cnt
                FROM "{self.table_name}"
                WHERE end_date = ?
                GROUP BY ts_code, end_date, report_type
            '''
            result = storage._connection.execute(
                query, [self._format_period(period)]
            ).fetchdf()

            if result.empty:
                return {}
            return {
                (row["ts_code"], row["end_date"], row["report_type"]): row["cnt"]
                for _, row in result.iterrows()
            }
        except Exception as e:
            print(f"[{self.__class__.__name__}] Error querying local records: {e}")
            return {}

    def compare_and_find_differences(
        self, remote_df: pd.DataFrame, period: str
    ) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """Compare remote rows with the local table and extract what differs.

        The comparison is made per stock on row counts only:
            - a stock with no local rows is ``"new"``;
            - a stock whose local and remote counts differ is ``"modified"``;
            - otherwise it is ``"same"`` and is skipped.

        All remote rows of every new or modified stock are returned, so the
        caller can replace that stock's rows for the quarter as a whole. Only
        stocks present in ``remote_df`` are examined.

        Args:
            remote_df: Rows fetched from the API; needs a ``ts_code`` column.
            period: Quarter end date as ``YYYYMMDD``.

        Returns:
            ``(diff_df, stats_df)``. ``diff_df`` holds the rows to write and
            is an empty DataFrame when nothing differs. ``stats_df`` has the
            columns ``ts_code``, ``remote_count``, ``local_count`` and
            ``status``. Both are empty when ``remote_df`` is empty.
        """
        if remote_df.empty:
            return pd.DataFrame(), pd.DataFrame()

        remote_counts = remote_df.groupby("ts_code").size().to_dict()
        local_counts = self.get_local_data_count_by_stock(period)

        diff_stocks = []  # Stocks whose rows must be rewritten.
        stats = []

        for ts_code, remote_count in remote_counts.items():
            local_count = local_counts.get(ts_code, 0)

            if local_count == 0:
                status = "new"
            elif local_count != remote_count:
                status = "modified"
            else:
                status = "same"

            if status != "same":
                diff_stocks.append(ts_code)

            stats.append(
                {
                    "ts_code": ts_code,
                    "remote_count": remote_count,
                    "local_count": local_count,
                    "status": status,
                }
            )

        if diff_stocks:
            diff_df = remote_df[remote_df["ts_code"].isin(diff_stocks)].copy()
        else:
            diff_df = pd.DataFrame()

        return diff_df, pd.DataFrame(stats)

    # ======================================================================
    # Sync core
    # ======================================================================

    def delete_stock_quarter_data(
        self,
        period: str,
        ts_codes: Optional[List[str]] = None,
        raise_on_error: bool = False,
    ) -> int:
        """Delete the stored rows of one quarter, optionally for given stocks.

        Used for the delete-then-insert strategy: stale rows are removed
        before the fresh ones are written.

        Args:
            period: Quarter end date as ``YYYYMMDD``.
            ts_codes: Stock codes to delete. ``None`` deletes the whole
                quarter; an empty list deletes nothing.
            raise_on_error: When true, a failure is re-raised after being
                printed so that an enclosing transaction can roll back. The
                default keeps the best-effort behaviour of returning 0.

        Returns:
            Number of rows that matched the delete condition (counted just
            before deleting), or 0 on a swallowed error.

        Raises:
            Exception: The underlying storage error, only when
                ``raise_on_error`` is true.
        """
        # An explicit empty selection must never widen into "whole quarter".
        if ts_codes is not None and len(ts_codes) == 0:
            return 0

        storage = Storage()

        try:
            where_sql = "end_date = ?"
            params = [self._format_period(period)]
            if ts_codes is not None:
                placeholders = ", ".join("?" for _ in ts_codes)
                where_sql += f" AND ts_code IN ({placeholders})"
                params.extend(ts_codes)

            # Count first: the original code notes that the connection's
            # rowcount is unreliable for DELETE statements.
            count_row = storage._connection.execute(
                f'SELECT COUNT(*) FROM "{self.table_name}" WHERE {where_sql}',
                params,
            ).fetchone()
            deleted_count = int(count_row[0]) if count_row else 0

            storage._connection.execute(
                f'DELETE FROM "{self.table_name}" WHERE {where_sql}', params
            )
            return deleted_count
        except Exception as e:
            print(f"[{self.__class__.__name__}] Error deleting data: {e}")
            if raise_on_error:
                raise
            return 0

    def _write_in_transaction(
        self,
        period: str,
        new_df: pd.DataFrame,
        replace_codes: Optional[List[str]] = None,
    ) -> Tuple[int, int]:
        """Write ``new_df`` in one transaction, first deleting stale rows.

        Args:
            period: Quarter end date as ``YYYYMMDD``.
            new_df: Rows to insert.
            replace_codes: Stocks whose existing rows for ``period`` are
                deleted before inserting; ``None``/empty means insert only.

        Returns:
            ``(deleted_count, inserted_count)``.

        Raises:
            Exception: Any storage error, after the transaction was rolled
                back.
        """
        cls_name = self.__class__.__name__
        deleted_count = 0

        try:
            self.storage.begin_transaction()

            if replace_codes:
                # A failed delete must abort the transaction; otherwise the
                # insert below would duplicate the rows that were kept.
                deleted_count = self.delete_stock_quarter_data(
                    period, replace_codes, raise_on_error=True
                )
                print(
                    f"[{cls_name}] Deleted {len(replace_codes)} stocks' old records "
                    f"({deleted_count} rows)"
                )

            # Plain INSERT: the rows being replaced were deleted above.
            self.storage.queue_save(self.table_name, new_df, use_upsert=False)
            self.storage.flush(use_transaction=False)
            self.storage.commit_transaction()

            inserted_count = len(new_df)
            print(
                f"[{cls_name}] Inserted {inserted_count} new records "
                f"(transaction committed)"
            )
            return deleted_count, inserted_count
        except Exception as e:
            self.storage.rollback_transaction()
            print(f"[{cls_name}] Transaction rolled back: {e}")
            raise

    def sync_quarter(self, period: str, dry_run: bool = False) -> Dict:
        """Sync a single quarter.

        Steps:
            1. Fetch the remote rows for ``period``.
            2. Keep only ``TARGET_REPORT_TYPE`` rows when that is set and the
               data has a ``report_type`` column.
            3. If the quarter has no local rows, insert everything.
            4. Otherwise find the stocks whose row count differs, delete
               their old rows and insert the fresh ones in one transaction.

        Args:
            period: Quarter end date as ``YYYYMMDD``.
            dry_run: When true, only compare and report; nothing is written.

        Returns:
            Dict with the keys ``period``, ``remote_total``, ``diff_count``,
            ``deleted_count``, ``inserted_count`` and ``dry_run``.

        Raises:
            Exception: Any storage error raised while writing (after the
                transaction was rolled back), or any error raised by
                ``fetch_single_quarter``.
        """
        cls_name = self.__class__.__name__
        print(f"[{cls_name}] Syncing quarter {period}...")

        remote_df = self.fetch_single_quarter(period)
        if remote_df is None or remote_df.empty:
            print(f"[{cls_name}] No data for quarter {period}")
            return self._quarter_result(period, dry_run)

        if self.TARGET_REPORT_TYPE and "report_type" in remote_df.columns:
            remote_df = remote_df[remote_df["report_type"] == self.TARGET_REPORT_TYPE]

        remote_total = len(remote_df)
        print(f"[{cls_name}] Fetched {remote_total} records from API")

        # The filter may have removed every row; there is nothing to write.
        if remote_total == 0:
            return self._quarter_result(period, dry_run)

        # No local rows at all for this quarter: insert without comparing.
        if not self.get_local_data_count_by_stock(period):
            print(
                f"[{cls_name}] First sync for quarter {period}, "
                f"inserting all data directly"
            )
            inserted_count = 0
            if not dry_run:
                _, inserted_count = self._write_in_transaction(period, remote_df)
            return self._quarter_result(
                period,
                dry_run,
                remote_total=remote_total,
                diff_count=remote_total,
                inserted_count=inserted_count,
            )

        diff_df, stats_df = self.compare_and_find_differences(remote_df, period)

        diff_stocks = list(diff_df["ts_code"].unique()) if not diff_df.empty else []
        unchanged_count = (
            int((stats_df["status"] == "same").sum()) if not stats_df.empty else 0
        )

        print(f"[{cls_name}] Comparison result:")
        print(f" - Stocks with differences: {len(diff_stocks)}")
        print(f" - Unchanged stocks: {unchanged_count}")

        deleted_count = 0
        inserted_count = 0
        if not dry_run and not diff_df.empty:
            deleted_count, inserted_count = self._write_in_transaction(
                period, diff_df, replace_codes=diff_stocks
            )

        return self._quarter_result(
            period,
            dry_run,
            remote_total=remote_total,
            diff_count=len(diff_df),
            deleted_count=deleted_count,
            inserted_count=inserted_count,
        )

    def sync_range(
        self, start_quarter: str, end_quarter: str, dry_run: bool = False
    ) -> List[Dict]:
        """Sync every quarter between two quarters.

        A failure in one quarter is recorded and does not stop the others.

        Args:
            start_quarter: First quarter (``YYYYMMDD``).
            end_quarter: Last quarter (``YYYYMMDD``).
            dry_run: When true, nothing is written.

        Returns:
            One entry per quarter: the ``sync_quarter`` result, or
            ``{"period": ..., "error": ...}`` when that quarter failed.
        """
        cls_name = self.__class__.__name__
        quarters = get_quarters_in_range(start_quarter, end_quarter)

        if not quarters:
            print(f"[{cls_name}] No quarters to sync")
            return []

        print(f"[{cls_name}] Syncing {len(quarters)} quarters: {quarters}")

        results = []
        for period in tqdm(quarters, desc=f"Syncing {self.table_name}"):
            try:
                results.append(self.sync_quarter(period, dry_run=dry_run))
            except Exception as e:
                print(f"[{cls_name}] Error syncing {period}: {e}")
                results.append({"period": period, "error": str(e)})

        return results

    def _get_latest_local_quarter(self, log_errors: bool = True) -> Optional[str]:
        """Return the newest ``end_date`` in the local table as ``YYYYMMDD``.

        Args:
            log_errors: Whether a failing query is printed.

        Returns:
            The latest quarter, or ``None`` when the table is empty or the
            query fails (for example because the table does not exist).
        """
        storage = Storage()
        try:
            row = storage._connection.execute(
                f'SELECT MAX(end_date) FROM "{self.table_name}"'
            ).fetchone()
        except Exception as e:
            if log_errors:
                print(f"[{self.__class__.__name__}] Error getting latest quarter: {e}")
            return None

        latest = row[0] if row and row[0] else None
        if latest is None:
            return None
        if hasattr(latest, "strftime"):
            return latest.strftime("%Y%m%d")
        # A textual 'YYYY-MM-DD...' value must be reduced to YYYYMMDD, because
        # every comparison below is a plain string comparison.
        return str(latest).replace("-", "")[:8]

    def _resolve_incremental_start(
        self, latest_quarter: str, current_quarter: str
    ) -> str:
        """Return the first quarter an incremental sync has to cover.

        The sync starts one quarter before the latest local quarter. A latest
        quarter beyond the current one is clamped to the current quarter, and
        the result never precedes ``DEFAULT_START_DATE``.
        """
        if latest_quarter > current_quarter:
            start_quarter = current_quarter
        else:
            start_quarter = self.get_prev_quarter(latest_quarter)
        return max(start_quarter, self.DEFAULT_START_DATE)

    @staticmethod
    def _log_sync_outcome(log_manager, log_entry, results: List[Dict]) -> None:
        """Close the sync-log entry with totals aggregated from ``results``.

        The run is recorded as ``"failed"`` when at least one quarter carries
        an ``"error"`` entry, with the failing periods listed in the message.
        """
        entries = [r for r in results if isinstance(r, dict)]
        failures = [r for r in entries if "error" in r]
        total_inserted = sum(int(r.get("inserted_count", 0)) for r in entries)
        total_deleted = sum(int(r.get("deleted_count", 0)) for r in entries)

        if failures:
            log_manager.complete_sync(
                log_entry,
                status="failed",
                records_inserted=total_inserted,
                records_updated=0,
                records_deleted=total_deleted,
                error_message="; ".join(
                    f"{r.get('period')}: {r['error']}" for r in failures
                ),
            )
        else:
            log_manager.complete_sync(
                log_entry,
                status="success",
                records_inserted=total_inserted,
                records_updated=0,
                records_deleted=total_deleted,
            )

    def sync_incremental(self, dry_run: bool = False) -> List[Dict]:
        """Run an incremental sync.

        Strategy:
            1. Make sure the table exists.
            2. Read the latest quarter stored locally.
            3. Work out the current quarter.
            4. With no local data, sync from ``DEFAULT_START_DATE``;
               otherwise sync from the quarter before the latest local one
               up to the current quarter.

        Stored quarters are always re-checked, because published statements
        may be revised.

        Args:
            dry_run: When true, nothing is written.

        Returns:
            One result entry per quarter (see ``sync_range``).

        Raises:
            Exception: Any error outside the per-quarter handling; it is
                recorded in the sync log first.
        """
        cls_name = self.__class__.__name__
        self._print_banner("Incremental Sync")

        log_manager = SyncLogManager()
        log_entry = log_manager.start_sync(
            table_name=self.table_name,
            sync_type="incremental",
            metadata={"dry_run": dry_run},
        )

        try:
            self.ensure_table_exists()

            latest_quarter = self._get_latest_local_quarter()
            current_quarter = self.get_current_quarter()

            if latest_quarter is None:
                print(f"[{cls_name}] No local data, performing full sync")
                results = self.sync_range(
                    self.DEFAULT_START_DATE, current_quarter, dry_run
                )
            else:
                print(f"[{cls_name}] Latest local quarter: {latest_quarter}")
                print(f"[{cls_name}] Current quarter: {current_quarter}")

                start_quarter = self._resolve_incremental_start(
                    latest_quarter, current_quarter
                )

                print(f"\n[{cls_name}] 将同步以下两个季度的财报:")
                print(f" - 前一季度: {start_quarter}")
                print(f" - 当前季度: {current_quarter}")
                print(" (包含前一季度以确保数据完整性)")
                print()

                results = self.sync_range(start_quarter, current_quarter, dry_run)

            self._log_sync_outcome(log_manager, log_entry, results)
            return results

        except Exception as e:
            log_manager.complete_sync(log_entry, status="failed", error_message=str(e))
            raise

    def sync_full(self, dry_run: bool = False) -> List[Dict]:
        """Run a full sync from ``DEFAULT_START_DATE`` to the current quarter.

        Args:
            dry_run: When true, nothing is written.

        Returns:
            One result entry per quarter (see ``sync_range``).

        Raises:
            Exception: Any error outside the per-quarter handling; it is
                recorded in the sync log first.
        """
        self._print_banner("Full Sync")

        log_manager = SyncLogManager()
        log_entry = log_manager.start_sync(
            table_name=self.table_name, sync_type="full", metadata={"dry_run": dry_run}
        )

        try:
            self.ensure_table_exists()

            current_quarter = self.get_current_quarter()
            results = self.sync_range(self.DEFAULT_START_DATE, current_quarter, dry_run)

            self._log_sync_outcome(log_manager, log_entry, results)
            return results

        except Exception as e:
            log_manager.complete_sync(log_entry, status="failed", error_message=str(e))
            raise

    # ======================================================================
    # Preview mode
    # ======================================================================

    def preview_sync(self) -> Dict:
        """Report which quarters an incremental sync would cover.

        Nothing is fetched or written. The range is computed with the same
        rule ``sync_incremental`` uses.

        Returns:
            Dict with the keys ``table_name``, ``api_name``,
            ``latest_quarter``, ``current_quarter``, ``start_quarter``,
            ``end_quarter`` and ``message``.
        """
        self._print_banner("Preview Mode")

        # A missing table is an expected state for a preview, so stay quiet.
        latest_quarter = self._get_latest_local_quarter(log_errors=False)
        current_quarter = self.get_current_quarter()

        if latest_quarter is None:
            start_quarter = self.DEFAULT_START_DATE
            message = "No local data, full sync required"
        else:
            start_quarter = self._resolve_incremental_start(
                latest_quarter, current_quarter
            )
            message = f"Incremental sync from {start_quarter} to {current_quarter}"

        preview = {
            "table_name": self.table_name,
            "api_name": self.api_name,
            "latest_quarter": latest_quarter,
            "current_quarter": current_quarter,
            "start_quarter": start_quarter,
            "end_quarter": current_quarter,
            "message": message,
        }

        print(f"Table: {self.table_name}")
        print(f"API: {self.api_name}")
        print(f"Latest local: {latest_quarter}")
        print(f"Current quarter: {current_quarter}")
        print(f"Sync range: {start_quarter} -> {current_quarter}")
        print(f"Message: {message}")
        print(f"{'=' * 60}")

        return preview
|
||
|
||
|
||
# ======================================================================
|
||
# 便捷函数
|
||
# ======================================================================
|
||
|
||
|
||
def sync_financial_data(
    syncer_class: type, force_full: bool = False, dry_run: bool = False
) -> List[Dict]:
    """Instantiate a syncer class and run a full or incremental sync.

    Args:
        syncer_class: Class to instantiate with no arguments; it must offer
            ``sync_full(dry_run)`` and ``sync_incremental(dry_run)`` (a
            ``QuarterBasedSync`` subclass is the intended argument).
        force_full: Run ``sync_full`` instead of ``sync_incremental``.
        dry_run: Passed through to the chosen sync method.

    Returns:
        Whatever the chosen sync method returns (a list of result entries).
    """
    syncer = syncer_class()

    if force_full:
        return syncer.sync_full(dry_run)
    return syncer.sync_incremental(dry_run)
|
||
|
||
|
||
def preview_financial_sync(syncer_class: type) -> Dict:
    """Instantiate a syncer class and return its sync preview.

    Args:
        syncer_class: Class to instantiate with no arguments; it must offer
            ``preview_sync()`` (a ``QuarterBasedSync`` subclass is the
            intended argument).

    Returns:
        The dictionary returned by ``preview_sync()``.
    """
    return syncer_class().preview_sync()
|