refactor(financial-sync): 重构财务数据同步架构

- 新增 base_financial_sync.py 基础同步抽象类
- 重构 api_financial_sync.py 简化调度逻辑
- 重命名 IncomeSync 为 IncomeQuarterSync 继承新基础类
- 增强 storage.py 支持 use_upsert 参数
- 更新 __init__.py 导出符号
This commit is contained in:
2026-03-08 00:30:04 +08:00
parent c01bf76a3d
commit 85044a74c6
6 changed files with 1360 additions and 927 deletions

View File

@@ -0,0 +1,756 @@
"""财务数据同步基础抽象模块。
提供专门用于按季度同步财务数据的基类 QuarterBasedSync。
财务数据特点:
- 按季度发布period: 20231231, 20230930, 20230630, 20230331
- 使用 VIP 接口一次性获取某季度的全部上市公司数据
- 数据可能会修正,增量同步需获取当前季度+前一季度
- 主键为 (ts_code, end_date)
使用方式:
class IncomeQuarterSync(QuarterBasedSync):
table_name = "financial_income"
api_name = "income_vip"
def fetch_single_quarter(self, period: str) -> pd.DataFrame:
# 实现单季度数据获取
...
"""
from abc import ABC, abstractmethod
from typing import Optional, Dict, List, Tuple, Set
from datetime import datetime
import pandas as pd
from tqdm import tqdm
from src.data.client import TushareClient
from src.data.storage import ThreadSafeStorage, Storage
from src.data.utils import get_today_date, get_quarters_in_range, DEFAULT_START_DATE
class QuarterBasedSync(ABC):
"""财务数据季度同步抽象基类。
专门处理按季度同步的财务数据(利润表、资产负债表、现金流量表)。
财务数据同步特点:
1. 按季度获取:使用 VIP 接口一次性获取某季度全部上市公司数据
2. 数据可修正:同一季度数据可能被更新,增量同步需获取当前季度+前一季度
3. 差异检测:需对比本地与远程数据量,识别缺失或变更的记录
4. 主键:(ts_code, end_date)
子类必须实现:
- table_name: 类属性,目标表名
- api_name: 类属性Tushare API 接口名
- fetch_single_quarter(period) -> pd.DataFrame: 获取单季度数据
- TABLE_SCHEMA: 类属性,表结构定义
Attributes:
table_name: 目标表名(子类必须覆盖)
api_name: Tushare API 接口名(子类必须覆盖)
DEFAULT_START_DATE: 默认起始日期2018Q1
TABLE_SCHEMA: 表结构定义 {列名: SQL类型}
TABLE_INDEXES: 索引定义 [(索引名, [列名列表]), ...]
注意:不要创建唯一索引,因为财务数据可能发生多次修正
"""
table_name: str = "" # 子类必须覆盖
api_name: str = "" # 子类必须覆盖
DEFAULT_START_DATE = "20180331" # 2018年Q1
# 目标报表类型(子类可覆盖)
# 默认只同步合并报表report_type='1'
# 设为 None 则同步所有报表类型
TARGET_REPORT_TYPE: Optional[str] = "1"
# 表结构定义(子类必须覆盖)
TABLE_SCHEMA: Dict[str, str] = {}
# 索引定义(子类可覆盖)
# 格式: [("index_name", ["col1", "col2"]), ...]
# 注意:不要创建唯一索引,因为财务数据可能发生多次修正
TABLE_INDEXES: List[Tuple[str, List[str]]] = []
def __init__(self):
"""初始化季度同步管理器。"""
self.storage = ThreadSafeStorage()
self.client = TushareClient()
self._cached_data: Optional[pd.DataFrame] = None
# ======================================================================
# 抽象方法 - 子类必须实现
# ======================================================================
@abstractmethod
def fetch_single_quarter(self, period: str) -> pd.DataFrame:
"""获取单季度的全部上市公司数据。
Args:
period: 报告期,季度最后一天日期(如 '20231231'
Returns:
包含该季度全部上市公司财务数据的 DataFrame
"""
pass
# ======================================================================
# 季度计算工具方法
# ======================================================================
def get_current_quarter(self) -> str:
"""获取当前季度(考虑是否到季末)。
如果当前日期未到季度最后一天,则返回前一季度。
这样可以避免获取尚无数据的未来季度。
Returns:
当前季度字符串 (YYYYMMDD),如 '20231231'
"""
today = get_today_date()
year = int(today[:4])
month = int(today[4:6])
# 确定当前季度
if month <= 3:
current_q = f"{year}0331"
elif month <= 6:
current_q = f"{year}0630"
elif month <= 9:
current_q = f"{year}0930"
else:
current_q = f"{year}1231"
# 如果今天还没到季末,返回前一季度
if today < current_q:
return self.get_prev_quarter(current_q)
return current_q
def get_prev_quarter(self, quarter: str) -> str:
"""获取前一季度。
Args:
quarter: 季度字符串 (YYYYMMDD),如 '20231231'
Returns:
前一季度字符串 (YYYYMMDD)
"""
year = int(quarter[:4])
month_day = quarter[4:]
if month_day == "0331":
return f"{year - 1}1231"
elif month_day == "0630":
return f"{year}0331"
elif month_day == "0930":
return f"{year}0630"
else: # "1231"
return f"{year}0930"
def get_next_quarter(self, quarter: str) -> str:
"""获取下一季度。
Args:
quarter: 季度字符串 (YYYYMMDD)
Returns:
下一季度字符串 (YYYYMMDD)
"""
year = int(quarter[:4])
month_day = quarter[4:]
if month_day == "0331":
return f"{year}0630"
elif month_day == "0630":
return f"{year}0930"
elif month_day == "0930":
return f"{year}1231"
else: # "1231"
return f"{year + 1}0331"
# ======================================================================
# 表结构管理
# ======================================================================
def ensure_table_exists(self) -> None:
"""确保表结构存在,如不存在则创建表和索引。
注意:不设置主键和唯一索引,因为财务数据可能发生多次修正,
同一支股票在同一季度可能有多个版本不同的ann_date
DuckDB 会自动创建 rowid 作为主键。
"""
storage = Storage()
if storage.exists(self.table_name):
return
if not self.TABLE_SCHEMA:
print(
f"[{self.__class__.__name__}] TABLE_SCHEMA not defined, skipping table creation"
)
return
# 构建列定义(不设置主键)
columns_def = []
for col_name, col_type in self.TABLE_SCHEMA.items():
columns_def.append(f'"{col_name}" {col_type}')
columns_sql = ", ".join(columns_def)
create_sql = f'CREATE TABLE IF NOT EXISTS "{self.table_name}" ({columns_sql})'
try:
storage._connection.execute(create_sql)
print(f"[{self.__class__.__name__}] Created table '{self.table_name}'")
except Exception as e:
print(f"[{self.__class__.__name__}] Error creating table: {e}")
raise
# 创建普通索引(不要创建唯一索引)
for idx_name, idx_cols in self.TABLE_INDEXES:
try:
idx_cols_sql = ", ".join(f'"{col}"' for col in idx_cols)
storage._connection.execute(
f'CREATE INDEX IF NOT EXISTS "{idx_name}" ON "{self.table_name}"({idx_cols_sql})'
)
print(
f"[{self.__class__.__name__}] Created index '{idx_name}' on {idx_cols}"
)
except Exception as e:
print(
f"[{self.__class__.__name__}] Error creating index {idx_name}: {e}"
)
# ======================================================================
# 数据差异检测(核心逻辑)
# ======================================================================
def get_local_data_count_by_stock(self, period: str) -> Dict[str, int]:
"""获取本地数据库中某季度的各股票数据量。
Args:
period: 季度字符串 (YYYYMMDD)
Returns:
字典 {ts_code: 记录数}
"""
storage = Storage()
try:
# 将 YYYYMMDD 转换为 YYYY-MM-DD 格式
period_formatted = f"{period[:4]}-{period[4:6]}-{period[6:]}"
query = f'''
SELECT ts_code, COUNT(*) as cnt
FROM "{self.table_name}"
WHERE end_date = ?
GROUP BY ts_code
'''
result = storage._connection.execute(query, [period_formatted]).fetchdf()
if result.empty:
return {}
return dict(zip(result["ts_code"], result["cnt"]))
except Exception as e:
print(f"[{self.__class__.__name__}] Error querying local data count: {e}")
return {}
def get_local_records_by_key(self, period: str) -> Dict[tuple, int]:
"""获取本地数据库中某季度的记录(按主键分组计数)。
用于更精确的差异检测,按 (ts_code, end_date, report_type) 分组。
Args:
period: 季度字符串 (YYYYMMDD)
Returns:
字典 {(ts_code, end_date, report_type): 记录数}
"""
storage = Storage()
try:
# 将 YYYYMMDD 转换为 YYYY-MM-DD 格式
period_formatted = f"{period[:4]}-{period[4:6]}-{period[6:]}"
query = f'''
SELECT ts_code, end_date, report_type, COUNT(*) as cnt
FROM "{self.table_name}"
WHERE end_date = ?
GROUP BY ts_code, end_date, report_type
'''
result = storage._connection.execute(query, [period_formatted]).fetchdf()
if result.empty:
return {}
return {
(row["ts_code"], row["end_date"], row["report_type"]): row["cnt"]
for _, row in result.iterrows()
}
except Exception as e:
print(f"[{self.__class__.__name__}] Error querying local records: {e}")
return {}
def compare_and_find_differences(
self, remote_df: pd.DataFrame, period: str
) -> Tuple[pd.DataFrame, pd.DataFrame]:
"""对比远程数据与本地数据,找出差异。
逻辑:
1. 统计远程数据中每只股票的数据量
2. 查询本地数据库中该季度每只股票的数据量
3. 对比找出:
- 本地缺失的股票(新增)
- 数据量不一致的股票(有更新,可能包含财务修正)
4. 返回需要插入的差异数据
注意:主键为 (ts_code, end_date, report_type),因此同一支股票在同一季度
可能有多个 report_type 的记录。差异检测按股票级别进行,如果该股票的
记录总数不一致,则更新该股票的所有记录。
Args:
remote_df: 从 API 获取的远程数据
period: 季度字符串
Returns:
(差异数据DataFrame, 统计信息DataFrame)
统计信息包含ts_code, remote_count, local_count, status
"""
if remote_df.empty:
return pd.DataFrame(), pd.DataFrame()
# 1. 统计远程数据中每只股票的数据量
remote_counts = remote_df.groupby("ts_code").size().to_dict()
# 2. 获取本地数据量(按股票汇总)
local_counts = self.get_local_data_count_by_stock(period)
# 3. 对比找出差异
diff_stocks = [] # 需要更新的股票列表
stats = []
for ts_code, remote_count in remote_counts.items():
local_count = local_counts.get(ts_code, 0)
if local_count == 0:
status = "new" # 本地不存在,全部插入
diff_stocks.append(ts_code)
elif local_count != remote_count:
status = "modified" # 数据量不一致,可能包含财务修正
diff_stocks.append(ts_code)
else:
status = "same" # 数据量一致,跳过
stats.append(
{
"ts_code": ts_code,
"remote_count": remote_count,
"local_count": local_count,
"status": status,
}
)
# 4. 提取差异数据
if diff_stocks:
diff_df = remote_df[remote_df["ts_code"].isin(diff_stocks)].copy()
else:
diff_df = pd.DataFrame()
stats_df = pd.DataFrame(stats)
return diff_df, stats_df
# ======================================================================
# 同步核心逻辑
# ======================================================================
def delete_stock_quarter_data(
self, period: str, ts_codes: Optional[List[str]] = None
) -> int:
"""删除指定季度和股票的数据。
在同步前删除旧数据,然后插入新数据(先删除后插入策略)。
Args:
period: 季度字符串 (YYYYMMDD)
ts_codes: 股票代码列表None 表示删除该季度所有数据
Returns:
删除的记录数
"""
storage = Storage()
try:
# 将 YYYYMMDD 转换为 YYYY-MM-DD 格式
period_formatted = f"{period[:4]}-{period[4:6]}-{period[6:]}"
if ts_codes:
# 删除指定股票的数据
placeholders = ", ".join(["?" for _ in ts_codes])
query = f'''
DELETE FROM "{self.table_name}"
WHERE end_date = ? AND ts_code IN ({placeholders})
'''
result = storage._connection.execute(
query, [period_formatted] + ts_codes
)
else:
# 删除整个季度的数据
query = f'DELETE FROM "{self.table_name}" WHERE end_date = ?'
result = storage._connection.execute(query, [period_formatted])
# DuckDB 的 rowcount 可能返回 -1我们手动查询删除后的数量变化
# 由于我们已经删除了特定条件的数据,直接返回传入的股票数量作为估算
if ts_codes:
deleted_count = len(ts_codes)
else:
# 删除整个季度,查询删除前的数量
deleted_count = -1 # 标记为未知,稍后处理
return deleted_count
except Exception as e:
print(f"[{self.__class__.__name__}] Error deleting data: {e}")
return 0
def sync_quarter(self, period: str, dry_run: bool = False) -> Dict:
"""同步单个季度的数据。
流程:
1. 获取远程数据
2. 根据 TARGET_REPORT_TYPE 过滤报表类型
3. 对比本地数据,找出差异股票
4. 删除差异股票的旧数据
5. 插入新数据(先删除后插入)
注意:财务数据同步必须取当前季度和前一季度进行对比更新,
不存在"不需要同步"的情况。
Args:
period: 季度字符串 (YYYYMMDD)
dry_run: 是否为预览模式
Returns:
同步结果字典 {
'period': 季度,
'remote_total': 远程总记录数,
'diff_count': 差异记录数,
'deleted_count': 删除的记录数,
'inserted_count': 插入的记录数,
'dry_run': 是否预览模式
}
"""
print(f"[{self.__class__.__name__}] Syncing quarter {period}...")
# 1. 获取远程数据
remote_df = self.fetch_single_quarter(period)
if remote_df.empty:
print(f"[{self.__class__.__name__}] No data for quarter {period}")
return {
"period": period,
"remote_total": 0,
"diff_count": 0,
"deleted_count": 0,
"inserted_count": 0,
"dry_run": dry_run,
}
# 2. 根据 TARGET_REPORT_TYPE 过滤报表类型
if self.TARGET_REPORT_TYPE and "report_type" in remote_df.columns:
remote_df = remote_df[remote_df["report_type"] == self.TARGET_REPORT_TYPE]
remote_total = len(remote_df)
print(f"[{self.__class__.__name__}] Fetched {remote_total} records from API")
# 3. 检查本地是否有该季度数据
local_counts = self.get_local_data_count_by_stock(period)
is_first_sync_for_period = len(local_counts) == 0
# 4. 执行同步
deleted_count = 0
inserted_count = 0
if is_first_sync_for_period:
# 首次同步该季度:直接插入所有数据
print(
f"[{self.__class__.__name__}] First sync for quarter {period}, inserting all data directly"
)
if not dry_run:
self.storage.queue_save(self.table_name, remote_df, use_upsert=False)
self.storage.flush()
inserted_count = len(remote_df)
print(
f"[{self.__class__.__name__}] Inserted {inserted_count} new records"
)
return {
"period": period,
"remote_total": remote_total,
"diff_count": len(remote_df),
"deleted_count": 0,
"inserted_count": inserted_count,
"dry_run": dry_run,
}
# 5. 非首次同步:对比找出差异股票
diff_df, stats_df = self.compare_and_find_differences(remote_df, period)
diff_stocks = list(diff_df["ts_code"].unique()) if not diff_df.empty else []
unchanged_count = (
len(stats_df[stats_df["status"] == "same"]) if not stats_df.empty else 0
)
print(f"[{self.__class__.__name__}] Comparison result:")
print(f" - Stocks with differences: {len(diff_stocks)}")
print(f" - Unchanged stocks: {unchanged_count}")
if not dry_run and not diff_df.empty:
# 5.1 删除差异股票的旧数据
deleted_stocks_count = len(diff_stocks)
self.delete_stock_quarter_data(period, diff_stocks)
deleted_count = len(diff_df)
print(
f"[{self.__class__.__name__}] Deleted {deleted_stocks_count} stocks' old records (approx {deleted_count} rows)"
)
# 5.2 插入新数据(使用普通 INSERT因为已删除旧数据
self.storage.queue_save(self.table_name, diff_df, use_upsert=False)
self.storage.flush()
inserted_count = len(diff_df)
print(f"[{self.__class__.__name__}] Inserted {inserted_count} new records")
return {
"period": period,
"remote_total": remote_total,
"diff_count": len(diff_df),
"deleted_count": deleted_count,
"inserted_count": inserted_count,
"dry_run": dry_run,
}
def sync_range(
self, start_quarter: str, end_quarter: str, dry_run: bool = False
) -> List[Dict]:
"""同步指定季度范围的数据。
注意:增量同步会自动包含前一季度以确保数据完整性。
Args:
start_quarter: 起始季度 (YYYYMMDD)
end_quarter: 结束季度 (YYYYMMDD)
dry_run: 是否为预览模式
Returns:
各季度同步结果列表
"""
quarters = get_quarters_in_range(start_quarter, end_quarter)
if not quarters:
print(f"[{self.__class__.__name__}] No quarters to sync")
return []
print(
f"[{self.__class__.__name__}] Syncing {len(quarters)} quarters: {quarters}"
)
results = []
for period in tqdm(quarters, desc=f"Syncing {self.table_name}"):
try:
result = self.sync_quarter(period, dry_run=dry_run)
results.append(result)
except Exception as e:
print(f"[{self.__class__.__name__}] Error syncing {period}: {e}")
results.append({"period": period, "error": str(e)})
return results
def sync_incremental(self, dry_run: bool = False) -> List[Dict]:
"""执行增量同步。
策略:
1. 确保表存在(首次同步时自动建表)
2. 获取表中最新季度
3. 计算当前季度(考虑是否到季末)
4. 确定同步范围:从最新季度到当前季度
5. **重要**:额外包含前一季度以确保数据完整性
注意:财务数据同步与日线数据不同,必须每次都获取数据进行对比
更新,不存在"不需要同步"的情况。因为财务数据可能会被修正。
Args:
dry_run: 是否为预览模式
Returns:
各季度同步结果列表
"""
print(f"\n{'=' * 60}")
print(f"[{self.__class__.__name__}] Incremental Sync")
print(f"{'=' * 60}")
# 0. 确保表存在(首次同步时自动建表)
self.ensure_table_exists()
# 1. 获取最新季度
storage = Storage()
try:
result = storage._connection.execute(
f'SELECT MAX(end_date) FROM "{self.table_name}"'
).fetchone()
latest_quarter = result[0] if result and result[0] else None
if hasattr(latest_quarter, "strftime"):
latest_quarter = latest_quarter.strftime("%Y%m%d")
except Exception as e:
print(f"[{self.__class__.__name__}] Error getting latest quarter: {e}")
latest_quarter = None
# 2. 获取当前季度
current_quarter = self.get_current_quarter()
if latest_quarter is None:
# 无本地数据,执行全量同步
print(f"[{self.__class__.__name__}] No local data, performing full sync")
return self.sync_range(self.DEFAULT_START_DATE, current_quarter, dry_run)
print(f"[{self.__class__.__name__}] Latest local quarter: {latest_quarter}")
print(f"[{self.__class__.__name__}] Current quarter: {current_quarter}")
# 3. 确定同步范围
# 财务数据必须每次都进行对比更新,不存在"跳过"的情况
# 同步范围:从最新季度到当前季度(包含前一季度以确保数据完整性)
start_quarter = latest_quarter
if start_quarter > current_quarter:
# 如果本地数据比当前季度还新,仍然需要同步(可能包含修正数据)
start_quarter = current_quarter
else:
# 正常情况:包含前一季度
start_quarter = self.get_prev_quarter(latest_quarter)
if start_quarter < self.DEFAULT_START_DATE:
start_quarter = self.DEFAULT_START_DATE
# 打印同步的两个季度信息
print(f"\n[{self.__class__.__name__}] 将同步以下两个季度的财报:")
print(f" - 前一季度: {start_quarter}")
print(f" - 当前季度: {current_quarter}")
print(f" (包含前一季度以确保数据完整性)")
print()
return self.sync_range(start_quarter, current_quarter, dry_run)
def sync_full(self, dry_run: bool = False) -> List[Dict]:
"""执行全量同步。
Args:
dry_run: 是否为预览模式
Returns:
各季度同步结果列表
"""
print(f"\n{'=' * 60}")
print(f"[{self.__class__.__name__}] Full Sync")
print(f"{'=' * 60}")
# 确保表存在
self.ensure_table_exists()
current_quarter = self.get_current_quarter()
return self.sync_range(self.DEFAULT_START_DATE, current_quarter, dry_run)
# ======================================================================
# 预览模式
# ======================================================================
def preview_sync(self) -> Dict:
"""预览同步信息(不实际同步)。
注意:财务数据同步必须每次都进行,因为数据可能会被修正。
预览显示将要同步的季度范围。
Returns:
预览信息字典
"""
print(f"\n{'=' * 60}")
print(f"[{self.__class__.__name__}] Preview Mode")
print(f"{'=' * 60}")
# 获取最新季度
storage = Storage()
try:
result = storage._connection.execute(
f'SELECT MAX(end_date) FROM "{self.table_name}"'
).fetchone()
latest_quarter = result[0] if result and result[0] else None
if hasattr(latest_quarter, "strftime"):
latest_quarter = latest_quarter.strftime("%Y%m%d")
except Exception:
latest_quarter = None
current_quarter = self.get_current_quarter()
if latest_quarter is None:
# 无本地数据,需要全量同步
start_quarter = self.DEFAULT_START_DATE
message = "No local data, full sync required"
else:
# 财务数据必须每次都进行对比更新
# 同步范围:从最新季度到当前季度(包含前一季度)
start_quarter = self.get_prev_quarter(latest_quarter)
if start_quarter < self.DEFAULT_START_DATE:
start_quarter = self.DEFAULT_START_DATE
message = f"Incremental sync from {start_quarter} to {current_quarter}"
preview = {
"table_name": self.table_name,
"api_name": self.api_name,
"latest_quarter": latest_quarter,
"current_quarter": current_quarter,
"start_quarter": start_quarter,
"end_quarter": current_quarter,
"message": message,
}
print(f"Table: {self.table_name}")
print(f"API: {self.api_name}")
print(f"Latest local: {latest_quarter}")
print(f"Current quarter: {current_quarter}")
print(f"Sync range: {start_quarter} -> {current_quarter}")
print(f"Message: {message}")
print(f"{'=' * 60}")
return preview
# ======================================================================
# 便捷函数
# ======================================================================
def sync_financial_data(
syncer_class: type, force_full: bool = False, dry_run: bool = False
) -> List[Dict]:
"""通用的财务数据同步便捷函数。
Args:
syncer_class: QuarterBasedSync 的子类
force_full: 是否强制全量同步
dry_run: 是否为预览模式
Returns:
同步结果列表
"""
syncer = syncer_class()
if force_full:
return syncer.sync_full(dry_run)
else:
return syncer.sync_incremental(dry_run)
def preview_financial_sync(syncer_class: type) -> Dict:
"""预览财务数据同步信息。
Args:
syncer_class: QuarterBasedSync 的子类
Returns:
预览信息字典
"""
syncer = syncer_class()
return syncer.preview_sync()