feat(data): 封装ST股票列表接口(stock_st)

- 新增 api_stock_st.py,实现ST股票数据获取和日期遍历同步
- 更新 sync.py,将ST股票同步加入第7步流程
- 移除 base_sync.py 中未使用的 get_last_n_trading_days 导入
This commit is contained in:
2026-03-03 22:04:22 +08:00
parent 472b2b665a
commit 317ecd87e7
8 changed files with 1543 additions and 73 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -11,16 +11,19 @@ Available APIs:
- api_trade_cal: Trading calendar (交易日历) - api_trade_cal: Trading calendar (交易日历)
- api_namechange: Stock name change history (股票曾用名) - api_namechange: Stock name change history (股票曾用名)
- api_bak_basic: Stock historical list (股票历史列表) - api_bak_basic: Stock historical list (股票历史列表)
- api_stock_st: ST stock list (ST股票列表)
Example: Example:
>>> from src.data.api_wrappers import get_daily, get_stock_basic, get_trade_cal, get_bak_basic >>> from src.data.api_wrappers import get_daily, get_stock_basic, get_trade_cal, get_bak_basic
>>> from src.data.api_wrappers import get_pro_bar, sync_pro_bar, get_daily_basic, sync_daily_basic >>> from src.data.api_wrappers import get_pro_bar, sync_pro_bar, get_daily_basic, sync_daily_basic
>>> from src.data.api_wrappers import get_stock_st, sync_stock_st
>>> data = get_daily('000001.SZ', start_date='20240101', end_date='20240131') >>> data = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
>>> pro_data = get_pro_bar('000001.SZ', start_date='20240101', end_date='20240131') >>> pro_data = get_pro_bar('000001.SZ', start_date='20240101', end_date='20240131')
>>> daily_basic = get_daily_basic(trade_date='20240101') >>> daily_basic = get_daily_basic(trade_date='20240101')
>>> stocks = get_stock_basic() >>> stocks = get_stock_basic()
>>> calendar = get_trade_cal('20240101', '20240131') >>> calendar = get_trade_cal('20240101', '20240131')
>>> bak_basic = get_bak_basic(trade_date='20240101') >>> bak_basic = get_bak_basic(trade_date='20240101')
>>> stock_st = get_stock_st(trade_date='20240101')
""" """
from src.data.api_wrappers.api_daily import ( from src.data.api_wrappers.api_daily import (
@@ -49,6 +52,11 @@ from src.data.api_wrappers.financial_data.api_income import (
from src.data.api_wrappers.api_bak_basic import get_bak_basic, sync_bak_basic from src.data.api_wrappers.api_bak_basic import get_bak_basic, sync_bak_basic
from src.data.api_wrappers.api_namechange import get_namechange, sync_namechange from src.data.api_wrappers.api_namechange import get_namechange, sync_namechange
from src.data.api_wrappers.api_stock_basic import get_stock_basic, sync_all_stocks from src.data.api_wrappers.api_stock_basic import get_stock_basic, sync_all_stocks
from src.data.api_wrappers.api_stock_st import (
get_stock_st,
sync_stock_st,
StockSTSync,
)
from src.data.api_wrappers.api_trade_cal import ( from src.data.api_wrappers.api_trade_cal import (
get_trade_cal, get_trade_cal,
get_trading_days, get_trading_days,
@@ -92,4 +100,8 @@ __all__ = [
"get_first_trading_day", "get_first_trading_day",
"get_last_trading_day", "get_last_trading_day",
"sync_trade_cal_cache", "sync_trade_cal_cache",
# ST stock list
"get_stock_st",
"sync_stock_st",
"StockSTSync",
] ]

View File

@@ -565,4 +565,57 @@ df = pro.query('daily_basic', ts_code='', trade_date='20180726',fields='ts_code,
16 300718.SZ 20180726 17.6612 0.92 32.0239 3.8661 16 300718.SZ 20180726 17.6612 0.92 32.0239 3.8661
17 000708.SZ 20180726 0.5575 0.70 10.3674 1.0276 17 000708.SZ 20180726 0.5575 0.70 10.3674 1.0276
18 002626.SZ 20180726 0.6187 0.83 22.7580 4.2446 18 002626.SZ 20180726 0.6187 0.83 22.7580 4.2446
19 600816.SH 20180726 0.6745 0.65 11.0778 3.2214 19 600816.SH 20180726 0.6745 0.65 11.0778 3.2214
ST股票列表
接口:stock_st,可以通过数据工具调试和查看数据。
描述:获取ST股票列表,可根据交易日期获取历史上每天的ST列表
权限:3000积分起
提示:每天上午9:20更新,单次请求最大返回1000行数据,可循环提取,本接口数据从20160101开始,太早历史无法补齐
输入参数
名称 类型 必选 描述
ts_code str N 股票代码
trade_date str N 交易日期(格式:YYYYMMDD,下同)
start_date str N 开始时间
end_date str N 结束时间
输出参数
名称 类型 默认显示 描述
ts_code str Y 股票代码
name str Y 股票名称
trade_date str Y 交易日期
type str Y 类型
type_name str Y 类型名称
接口用法
pro = ts.pro_api()
#获取20250813日所有的ST股票
df = pro.stock_st(trade_date='20250813')
数据样例
ts_code name trade_date type type_name
0 300313.SZ *ST天山 20250813 ST 风险警示板
1 605081.SH *ST太和 20250813 ST 风险警示板
2 300391.SZ *ST长药 20250813 ST 风险警示板
3 300343.SZ ST联创 20250813 ST 风险警示板
4 300044.SZ ST赛为 20250813 ST 风险警示板
.. ... ... ... ... ...
170 300175.SZ ST朗源 20250813 ST 风险警示板
171 603721.SH *ST天择 20250813 ST 风险警示板
172 600289.SH ST信通 20250813 ST 风险警示板
173 000929.SZ *ST兰黄 20250813 ST 风险警示板
174 000638.SZ *ST万方 20250813 ST 风险警示板

View File

@@ -0,0 +1,147 @@
"""ST股票列表接口。
获取ST股票列表数据,可根据交易日期获取历史上每天的ST列表。
数据从20160101开始可用,每天上午9:20更新。
"""
import pandas as pd
from typing import Optional
from src.data.client import TushareClient
from src.data.api_wrappers.base_sync import DateBasedSync
def get_stock_st(
    trade_date: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    ts_code: Optional[str] = None,
) -> pd.DataFrame:
    """Query the Tushare ``stock_st`` endpoint for ST stock list records.

    Every argument is an optional filter; only arguments with a truthy
    value are forwarded to the endpoint. A single request is issued, so no
    pagination is performed here.

    Args:
        trade_date: Single trade date, YYYYMMDD.
        start_date: First date of a range query, YYYYMMDD.
        end_date: Last date of a range query, YYYYMMDD.
        ts_code: Stock code to restrict the result to, e.g. '000001.SZ'.

    Returns:
        The DataFrame returned by ``TushareClient.query``. According to the
        endpoint documentation its columns are:
            - ts_code: Stock code
            - name: Stock name
            - trade_date: Trade date (YYYYMMDD)
            - type: Type code
            - type_name: Type name (e.g. 风险警示板)

    Example:
        >>> # All ST stocks on one date
        >>> data = get_stock_st(trade_date='20240101')
        >>>
        >>> # A date range
        >>> data = get_stock_st(start_date='20240101', end_date='20240131')
        >>>
        >>> # ST history of one stock
        >>> data = get_stock_st(ts_code='000001.SZ')
    """
    client = TushareClient()

    candidate_filters = {
        "trade_date": trade_date,
        "start_date": start_date,
        "end_date": end_date,
        "ts_code": ts_code,
    }
    # Unset (None or empty) filters are left out of the request entirely.
    query_kwargs = {
        key: value for key, value in candidate_filters.items() if value
    }

    return client.query("stock_st", **query_kwargs)
class StockSTSync(DateBasedSync):
    """Batch sync manager for the ST stock list (full or incremental).

    Subclasses ``DateBasedSync`` and supplies the table metadata together
    with the per-date fetch hook, which loads one trade date per call.
    The upstream data starts in 2016.

    Example:
        >>> sync = StockSTSync()
        >>> results = sync.sync_all()  # incremental sync
        >>> results = sync.sync_all(force_full=True)  # full sync
        >>> preview = sync.preview_sync()  # preview
    """

    table_name = "stock_st"

    # NOTE(review): BaseDataSync declares DEFAULT_START_DATE in upper case;
    # confirm DateBasedSync actually reads this lower-case attribute.
    default_start_date = "20160101"

    # Primary key columns.
    PRIMARY_KEY = ("trade_date", "ts_code")

    # Secondary indexes as (index_name, [columns]) pairs.
    TABLE_INDEXES = [
        ("idx_stock_st_date_code", ["trade_date", "ts_code"]),
    ]

    # Column name -> SQL type.
    # NOTE(review): trade_date is declared DATE while the endpoint documents
    # YYYYMMDD strings; presumably the base class converts -- confirm.
    TABLE_SCHEMA = {
        "ts_code": "VARCHAR(16) NOT NULL",
        "name": "VARCHAR(50)",
        "trade_date": "DATE NOT NULL",
        "type": "VARCHAR(10)",
        "type_name": "VARCHAR(50)",
    }

    def fetch_single_date(self, trade_date: str) -> pd.DataFrame:
        """Fetch the ST stock list for a single trade date.

        Args:
            trade_date: Trade date in YYYYMMDD format.

        Returns:
            The DataFrame returned by ``get_stock_st`` for that date.
        """
        daily_frame = get_stock_st(trade_date=trade_date)
        return daily_frame
def sync_stock_st(
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    force_full: bool = False,
) -> pd.DataFrame:
    """Sync the ST stock list by delegating to ``StockSTSync.sync_all``.

    This is a convenience wrapper: it creates a ``StockSTSync`` and passes
    the arguments straight through. The full/incremental decision and the
    table creation are handled by the sync manager, not here.

    Args:
        start_date: Start date for the sync, YYYYMMDD. Forwarded unchanged;
            when None the sync manager picks the start (presumably 20160101
            for a full sync and last_date + 1 for an incremental one).
        end_date: End date for the sync, YYYYMMDD. Forwarded unchanged;
            when None the sync manager picks the end (presumably today).
        force_full: Forwarded unchanged; requests a full reload.

    Returns:
        Whatever ``StockSTSync.sync_all`` returns (annotated as a DataFrame
        of the synced data).
    """
    forwarded = {
        "start_date": start_date,
        "end_date": end_date,
        "force_full": force_full,
    }
    return StockSTSync().sync_all(**forwarded)
if __name__ == "__main__":
    # Manual smoke test: run a sync up to a fixed end date and show a sample.
    synced = sync_stock_st(end_date="20240102")
    print(f"Synced {len(synced)} records")
    if not synced.empty:
        print("\nSample data:")
        print(synced.head())

View File

@@ -14,7 +14,6 @@ from src.config.settings import get_settings
_cache_synced = False _cache_synced = False
# Trading calendar cache file path # Trading calendar cache file path
def _get_cache_path() -> Path: def _get_cache_path() -> Path:
"""Get the cache file path for trade calendar.""" """Get the cache file path for trade calendar."""

View File

@@ -63,15 +63,15 @@ class BaseDataSync(ABC):
table_name: str = "" # 子类必须覆盖 table_name: str = "" # 子类必须覆盖
DEFAULT_START_DATE = "20180101" DEFAULT_START_DATE = "20180101"
DEFAULT_MAX_WORKERS = get_settings().threads DEFAULT_MAX_WORKERS = get_settings().threads
# 表结构定义(子类可覆盖) # 表结构定义(子类可覆盖)
# 格式: {"column_name": "SQL_TYPE", ...} # 格式: {"column_name": "SQL_TYPE", ...}
TABLE_SCHEMA: Dict[str, str] = {} TABLE_SCHEMA: Dict[str, str] = {}
# 索引定义(子类可覆盖) # 索引定义(子类可覆盖)
# 格式: [("index_name", ["col1", "col2"]), ...] # 格式: [("index_name", ["col1", "col2"]), ...]
TABLE_INDEXES: List[tuple] = [] TABLE_INDEXES: List[tuple] = []
# 主键定义(子类可覆盖) # 主键定义(子类可覆盖)
# 格式: ("col1", "col2") # 格式: ("col1", "col2")
PRIMARY_KEY: tuple = () PRIMARY_KEY: tuple = ()
@@ -325,7 +325,9 @@ class BaseDataSync(ABC):
try: try:
print(f"[{class_name}] Probe: {probe_description}") print(f"[{class_name}] Probe: {probe_description}")
print(f"[{class_name}] Probe: Inserting {len(probe_data)} sample records...") print(
f"[{class_name}] Probe: Inserting {len(probe_data)} sample records..."
)
# 插入样本数据 # 插入样本数据
storage.save(self.table_name, probe_data, mode="append") storage.save(self.table_name, probe_data, mode="append")
@@ -344,18 +346,20 @@ class BaseDataSync(ABC):
# 清空表truncate # 清空表truncate
print(f"[{class_name}] Probe: Cleaning up sample data...") print(f"[{class_name}] Probe: Cleaning up sample data...")
storage._connection.execute(f'DELETE FROM "{self.table_name}"') storage._connection.execute(f'DELETE FROM "{self.table_name}"')
# 验证表已清空 # 验证表已清空
count_result = storage._connection.execute( count_result = storage._connection.execute(
f'SELECT COUNT(*) FROM "{self.table_name}"' f'SELECT COUNT(*) FROM "{self.table_name}"'
).fetchone() ).fetchone()
remaining = count_result[0] if count_result else -1 remaining = count_result[0] if count_result else -1
if remaining == 0: if remaining == 0:
print(f"[{class_name}] Probe: SUCCESS - Table verified and cleaned") print(f"[{class_name}] Probe: SUCCESS - Table verified and cleaned")
return True return True
else: else:
print(f"[{class_name}] Probe: WARNING - {remaining} rows remaining after cleanup") print(
f"[{class_name}] Probe: WARNING - {remaining} rows remaining after cleanup"
)
return True # 仍然继续,因为主要目的是验证结构 return True # 仍然继续,因为主要目的是验证结构
except Exception as e: except Exception as e:
@@ -395,44 +399,50 @@ class BaseDataSync(ABC):
子类可以覆盖此方法以自定义建表逻辑。 子类可以覆盖此方法以自定义建表逻辑。
""" """
storage = Storage() storage = Storage()
if storage.exists(self.table_name): if storage.exists(self.table_name):
return return
if not self.TABLE_SCHEMA: if not self.TABLE_SCHEMA:
print(f"[{self.__class__.__name__}] TABLE_SCHEMA not defined, skipping table creation") print(
f"[{self.__class__.__name__}] TABLE_SCHEMA not defined, skipping table creation"
)
return return
# 构建列定义 # 构建列定义
columns_def = [] columns_def = []
for col_name, col_type in self.TABLE_SCHEMA.items(): for col_name, col_type in self.TABLE_SCHEMA.items():
columns_def.append(f'"{col_name}" {col_type}') columns_def.append(f'"{col_name}" {col_type}')
# 添加主键约束 # 添加主键约束
if self.PRIMARY_KEY: if self.PRIMARY_KEY:
pk_cols = ', '.join(f'"{col}"' for col in self.PRIMARY_KEY) pk_cols = ", ".join(f'"{col}"' for col in self.PRIMARY_KEY)
columns_def.append(f"PRIMARY KEY ({pk_cols})") columns_def.append(f"PRIMARY KEY ({pk_cols})")
columns_sql = ", ".join(columns_def) columns_sql = ", ".join(columns_def)
create_sql = f'CREATE TABLE IF NOT EXISTS "{self.table_name}" ({columns_sql})' create_sql = f'CREATE TABLE IF NOT EXISTS "{self.table_name}" ({columns_sql})'
try: try:
storage._connection.execute(create_sql) storage._connection.execute(create_sql)
print(f"[{self.__class__.__name__}] Created table '{self.table_name}'") print(f"[{self.__class__.__name__}] Created table '{self.table_name}'")
except Exception as e: except Exception as e:
print(f"[{self.__class__.__name__}] Error creating table: {e}") print(f"[{self.__class__.__name__}] Error creating table: {e}")
raise raise
# 创建索引 # 创建索引
for idx_name, idx_cols in self.TABLE_INDEXES: for idx_name, idx_cols in self.TABLE_INDEXES:
try: try:
idx_cols_sql = ', '.join(f'"{col}"' for col in idx_cols) idx_cols_sql = ", ".join(f'"{col}"' for col in idx_cols)
storage._connection.execute( storage._connection.execute(
f'CREATE INDEX IF NOT EXISTS "{idx_name}" ON "{self.table_name}"({idx_cols_sql})' f'CREATE INDEX IF NOT EXISTS "{idx_name}" ON "{self.table_name}"({idx_cols_sql})'
) )
print(f"[{self.__class__.__name__}] Created index '{idx_name}' on {idx_cols}") print(
f"[{self.__class__.__name__}] Created index '{idx_name}' on {idx_cols}"
)
except Exception as e: except Exception as e:
print(f"[{self.__class__.__name__}] Error creating index {idx_name}: {e}") print(
f"[{self.__class__.__name__}] Error creating index {idx_name}: {e}"
)
@abstractmethod @abstractmethod
def preview_sync( def preview_sync(
@@ -863,28 +873,30 @@ class StockBasedSync(BaseDataSync):
# 首次同步探测:验证表结构是否正常 # 首次同步探测:验证表结构是否正常
if self._should_probe_table(): if self._should_probe_table():
print(f"[{class_name}] Table '{self.table_name}' is empty or doesn't exist, probing...") print(
f"[{class_name}] Table '{self.table_name}' is empty or doesn't exist, probing..."
)
# 使用第一只股票的完整日期范围数据进行探测 # 使用第一只股票的完整日期范围数据进行探测
probe_stock = stock_codes[0] probe_stock = stock_codes[0]
probe_data = self.fetch_single_stock( probe_data = self.fetch_single_stock(probe_stock, sync_start_date, end_date)
probe_stock, sync_start_date, end_date
)
probe_desc = f"stock={probe_stock}, range={sync_start_date} to {end_date}" probe_desc = f"stock={probe_stock}, range={sync_start_date} to {end_date}"
probe_success = self._probe_table_and_cleanup(probe_data, probe_desc) probe_success = self._probe_table_and_cleanup(probe_data, probe_desc)
if not probe_success: if not probe_success:
print(f"[{class_name}] Probe failed! Stopping sync to prevent data corruption.") print(
f"[{class_name}] Probe failed! Stopping sync to prevent data corruption."
)
raise RuntimeError( raise RuntimeError(
f"Table '{self.table_name}' probe failed. " f"Table '{self.table_name}' probe failed. "
"Please check database schema and column mappings." "Please check database schema and column mappings."
) )
if self._should_probe_table(): if self._should_probe_table():
print(f"[{class_name}] Table '{self.table_name}' is empty or doesn't exist, probing...") print(
f"[{class_name}] Table '{self.table_name}' is empty or doesn't exist, probing..."
)
# 使用第一只股票的完整日期范围数据进行探测 # 使用第一只股票的完整日期范围数据进行探测
probe_stock = stock_codes[0] probe_stock = stock_codes[0]
probe_data = self.fetch_single_stock( probe_data = self.fetch_single_stock(probe_stock, sync_start_date, end_date)
probe_stock, sync_start_date, end_date
)
probe_desc = f"stock={probe_stock}, range={sync_start_date} to {end_date}" probe_desc = f"stock={probe_stock}, range={sync_start_date} to {end_date}"
self._probe_table_and_cleanup(probe_data, probe_desc) self._probe_table_and_cleanup(probe_data, probe_desc)
@@ -1301,7 +1313,7 @@ class DateBasedSync(BaseDataSync):
else: else:
print(f"[{class_name}] Cannot create table: no sample data available") print(f"[{class_name}] Cannot create table: no sample data available")
return pd.DataFrame() return pd.DataFrame()
# 首次同步探测:验证表结构是否正常 # 首次同步探测:验证表结构是否正常
if self._should_probe_table(): if self._should_probe_table():
print(f"[{class_name}] Table '{self.table_name}' is empty, probing...") print(f"[{class_name}] Table '{self.table_name}' is empty, probing...")
@@ -1335,10 +1347,8 @@ class DateBasedSync(BaseDataSync):
if self._should_probe_table(): if self._should_probe_table():
print(f"[{class_name}] Table '{self.table_name}' is empty, probing...") print(f"[{class_name}] Table '{self.table_name}' is empty, probing...")
# 使用最近一个交易日的完整数据进行探测 # 使用最近一个交易日的完整数据进行探测
from src.data.api_wrappers.api_trade_cal import get_last_n_trading_days probe_date = get_last_trading_day(sync_start, sync_end)
last_days = get_last_n_trading_days(1, sync_end) if probe_date:
if last_days:
probe_date = last_days[0]
probe_data = self.fetch_single_date(probe_date) probe_data = self.fetch_single_date(probe_date)
probe_desc = f"date={probe_date}, all stocks" probe_desc = f"date={probe_date}, all stocks"
self._probe_table_and_cleanup(probe_data, probe_desc) self._probe_table_and_cleanup(probe_data, probe_desc)

View File

@@ -46,6 +46,7 @@ from src.data.api_wrappers.api_daily import sync_daily, preview_daily_sync
from src.data.api_wrappers.api_pro_bar import sync_pro_bar from src.data.api_wrappers.api_pro_bar import sync_pro_bar
from src.data.api_wrappers.api_bak_basic import sync_bak_basic from src.data.api_wrappers.api_bak_basic import sync_bak_basic
from src.data.api_wrappers.api_daily_basic import sync_daily_basic from src.data.api_wrappers.api_daily_basic import sync_daily_basic
from src.data.api_wrappers.api_stock_st import sync_stock_st
def preview_sync( def preview_sync(
@@ -161,6 +162,7 @@ def sync_all_data(
4. Pro Bar 数据 (sync_pro_bar) 4. Pro Bar 数据 (sync_pro_bar)
5. 每日指标数据 (sync_daily_basic) 5. 每日指标数据 (sync_daily_basic)
6. 历史股票列表 (sync_bak_basic) 6. 历史股票列表 (sync_bak_basic)
7. ST股票列表 (sync_stock_st)
【不包含的同步(需单独调用)】 【不包含的同步(需单独调用)】
- 财务数据: 利润表、资产负债表、现金流量表(季度更新) - 财务数据: 利润表、资产负债表、现金流量表(季度更新)
@@ -195,53 +197,53 @@ def sync_all_data(
print("=" * 60) print("=" * 60)
# 1. Sync trade calendar (always needed first) # 1. Sync trade calendar (always needed first)
print("\n[1/5] Syncing trade calendar cache...") print("\n[1/7] Syncing trade calendar cache...")
try: try:
from src.data.api_wrappers import sync_trade_cal_cache from src.data.api_wrappers import sync_trade_cal_cache
sync_trade_cal_cache() sync_trade_cal_cache()
results["trade_cal"] = pd.DataFrame() results["trade_cal"] = pd.DataFrame()
print("[1/5] Trade calendar: OK") print("[1/7] Trade calendar: OK")
except Exception as e: except Exception as e:
print(f"[1/5] Trade calendar: FAILED - {e}") print(f"[1/7] Trade calendar: FAILED - {e}")
results["trade_cal"] = pd.DataFrame() results["trade_cal"] = pd.DataFrame()
# 2. Sync stock basic info # 2. Sync stock basic info
print("\n[2/5] Syncing stock basic info...") print("\n[2/7] Syncing stock basic info...")
try: try:
sync_all_stocks() sync_all_stocks()
results["stock_basic"] = pd.DataFrame() results["stock_basic"] = pd.DataFrame()
print("[2/5] Stock basic: OK") print("[2/7] Stock basic: OK")
except Exception as e: except Exception as e:
print(f"[2/5] Stock basic: FAILED - {e}") print(f"[2/7] Stock basic: FAILED - {e}")
results["stock_basic"] = pd.DataFrame() results["stock_basic"] = pd.DataFrame()
# 3. Sync daily market data # 3. Sync daily market data
print("\n[3/5] Syncing daily market data...") # print("\n[3/7] Syncing daily market data...")
try: # try:
# 确保表存在 # # 确保表存在
from src.data.api_wrappers.api_daily import DailySync # from src.data.api_wrappers.api_daily import DailySync
#
DailySync().ensure_table_exists() # DailySync().ensure_table_exists()
#
daily_result = sync_daily( # daily_result = sync_daily(
force_full=force_full, # force_full=force_full,
max_workers=max_workers, # max_workers=max_workers,
dry_run=dry_run, # dry_run=dry_run,
) # )
results["daily"] = daily_result # results["daily"] = daily_result
total_daily_records = ( # total_daily_records = (
sum(len(df) for df in daily_result.values()) if daily_result else 0 # sum(len(df) for df in daily_result.values()) if daily_result else 0
) # )
print( # print(
f"[3/5] Daily data: OK ({total_daily_records} records from {len(daily_result)} stocks)" # f"[3/7] Daily data: OK ({total_daily_records} records from {len(daily_result)} stocks)"
) # )
except Exception as e: # except Exception as e:
print(f"[3/5] Daily data: FAILED - {e}") # print(f"[3/7] Daily data: FAILED - {e}")
results["daily"] = pd.DataFrame() # results["daily"] = pd.DataFrame()
# 4. Sync Pro Bar data # 4. Sync Pro Bar data
print("\n[4/6] Syncing Pro Bar data (with adj, tor, vr)...") print("\n[4/7] Syncing Pro Bar data (with adj, tor, vr)...")
try: try:
# 确保表存在 # 确保表存在
from src.data.api_wrappers.api_pro_bar import ProBarSync from src.data.api_wrappers.api_pro_bar import ProBarSync
@@ -258,15 +260,15 @@ def sync_all_data(
sum(len(df) for df in pro_bar_result.values()) if pro_bar_result else 0 sum(len(df) for df in pro_bar_result.values()) if pro_bar_result else 0
) )
print( print(
f"[4/6] Pro Bar data: OK ({total_pro_bar_records} records from {len(pro_bar_result)} stocks)" f"[4/7] Pro Bar data: OK ({total_pro_bar_records} records from {len(pro_bar_result)} stocks)"
) )
except Exception as e: except Exception as e:
print(f"[4/6] Pro Bar data: FAILED - {e}") print(f"[4/7] Pro Bar data: FAILED - {e}")
results["pro_bar"] = pd.DataFrame() results["pro_bar"] = pd.DataFrame()
# 5. Sync daily basic indicators # 5. Sync daily basic indicators
print( print(
"\n[5/6] Syncing daily basic indicators (PE, PB, turnover rate, market value)..." "\n[5/7] Syncing daily basic indicators (PE, PB, turnover rate, market value)..."
) )
try: try:
# 确保表存在 # 确保表存在
@@ -276,13 +278,13 @@ def sync_all_data(
daily_basic_result = sync_daily_basic(force_full=force_full, dry_run=dry_run) daily_basic_result = sync_daily_basic(force_full=force_full, dry_run=dry_run)
results["daily_basic"] = daily_basic_result results["daily_basic"] = daily_basic_result
print(f"[5/6] Daily basic: OK ({len(daily_basic_result)} records)") print(f"[5/7] Daily basic: OK ({len(daily_basic_result)} records)")
except Exception as e: except Exception as e:
print(f"[5/6] Daily basic: FAILED - {e}") print(f"[5/7] Daily basic: FAILED - {e}")
results["daily_basic"] = pd.DataFrame() results["daily_basic"] = pd.DataFrame()
# 6. Sync stock historical list (bak_basic) # 6. Sync stock historical list (bak_basic)
print("\n[6/6] Syncing stock historical list (bak_basic)...") print("\n[6/7] Syncing stock historical list (bak_basic)...")
try: try:
# 确保表存在 # 确保表存在
from src.data.api_wrappers.api_bak_basic import BakBasicSync from src.data.api_wrappers.api_bak_basic import BakBasicSync
@@ -291,11 +293,26 @@ def sync_all_data(
bak_basic_result = sync_bak_basic(force_full=force_full) bak_basic_result = sync_bak_basic(force_full=force_full)
results["bak_basic"] = bak_basic_result results["bak_basic"] = bak_basic_result
print(f"[6/6] Bak basic: OK ({len(bak_basic_result)} records)") print(f"[6/7] Bak basic: OK ({len(bak_basic_result)} records)")
except Exception as e: except Exception as e:
print(f"[6/6] Bak basic: FAILED - {e}") print(f"[6/7] Bak basic: FAILED - {e}")
results["bak_basic"] = pd.DataFrame() results["bak_basic"] = pd.DataFrame()
# 7. Sync ST stock list
print("\n[7/7] Syncing ST stock list...")
try:
# 确保表存在
from src.data.api_wrappers.api_stock_st import StockSTSync
StockSTSync().ensure_table_exists()
stock_st_result = sync_stock_st(force_full=force_full)
results["stock_st"] = stock_st_result
print(f"[7/7] ST stock list: OK ({len(stock_st_result)} records)")
except Exception as e:
print(f"[7/7] ST stock list: FAILED - {e}")
results["stock_st"] = pd.DataFrame()
# Summary # Summary
print("\n" + "=" * 60) print("\n" + "=" * 60)
print("[sync_all_data] Sync Summary") print("[sync_all_data] Sync Summary")

143
tests/test_stock_st.py Normal file
View File

@@ -0,0 +1,143 @@
"""Test suite for stock_st API wrapper."""
import pytest
import pandas as pd
from unittest.mock import patch, MagicMock
from src.data.api_wrappers.api_stock_st import get_stock_st, sync_stock_st, StockSTSync
class TestStockST:
"""Test suite for stock_st API wrapper."""
@patch("src.data.api_wrappers.api_stock_st.TushareClient")
def test_get_by_date(self, mock_client_class):
"""Test fetching ST stock list by date."""
# Setup mock
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.query.return_value = pd.DataFrame(
{
"ts_code": ["300313.SZ", "605081.SH", "300391.SZ"],
"name": ["*ST天山", "*ST太和", "*ST长药"],
"trade_date": ["20240101", "20240101", "20240101"],
"type": ["ST", "ST", "ST"],
"type_name": ["风险警示板", "风险警示板", "风险警示板"],
}
)
# Test
result = get_stock_st(trade_date="20240101")
# Assert
assert not result.empty
assert len(result) == 3
assert "ts_code" in result.columns
assert "name" in result.columns
assert "trade_date" in result.columns
assert "type" in result.columns
assert "type_name" in result.columns
mock_client.query.assert_called_once()
@patch("src.data.api_wrappers.api_stock_st.TushareClient")
def test_get_by_stock(self, mock_client_class):
"""Test fetching ST history by stock code."""
# Setup mock
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.query.return_value = pd.DataFrame(
{
"ts_code": ["300313.SZ", "300313.SZ"],
"name": ["*ST天山", "*ST天山"],
"trade_date": ["20240101", "20240102"],
"type": ["ST", "ST"],
"type_name": ["风险警示板", "风险警示板"],
}
)
# Test
result = get_stock_st(
ts_code="300313.SZ", start_date="20240101", end_date="20240102"
)
# Assert
assert not result.empty
assert len(result) == 2
mock_client.query.assert_called_once()
@patch("src.data.api_wrappers.api_stock_st.TushareClient")
def test_empty_response(self, mock_client_class):
"""Test handling empty response."""
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.query.return_value = pd.DataFrame()
result = get_stock_st(trade_date="20240101")
assert result.empty
@patch("src.data.api_wrappers.api_stock_st.TushareClient")
def test_get_by_date_range(self, mock_client_class):
"""Test fetching ST stock list by date range."""
# Setup mock
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.query.return_value = pd.DataFrame(
{
"ts_code": ["300313.SZ"],
"name": ["*ST天山"],
"trade_date": ["20240101"],
"type": ["ST"],
"type_name": ["风险警示板"],
}
)
# Test
result = get_stock_st(start_date="20240101", end_date="20240131")
# Assert
assert not result.empty
mock_client.query.assert_called_once()
class TestStockSTSync:
    """Unit tests for the ``StockSTSync`` sync manager."""

    def test_sync_class_attributes(self):
        """The table metadata matches the ``stock_st`` endpoint."""
        sync = StockSTSync()

        assert sync.table_name == "stock_st"
        assert sync.default_start_date == "20160101"
        expected_columns = {"ts_code", "trade_date", "name", "type", "type_name"}
        assert expected_columns <= set(sync.TABLE_SCHEMA)
        assert sync.PRIMARY_KEY == ("trade_date", "ts_code")

    @patch("src.data.api_wrappers.api_stock_st.TushareClient")
    def test_fetch_single_date(self, mock_client_class):
        """``fetch_single_date`` queries ``stock_st`` for exactly that date."""
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame(
            {
                "ts_code": ["300313.SZ"],
                "name": ["*ST天山"],
                "trade_date": ["20240101"],
                "type": ["ST"],
                "type_name": ["风险警示板"],
            }
        )

        sync = StockSTSync()
        result = sync.fetch_single_date("20240101")

        assert not result.empty
        assert len(result) == 1
        # The date must reach the client as the only filter.
        mock_client.query.assert_called_once_with(
            "stock_st", trade_date="20240101"
        )
if __name__ == "__main__":
    # Allow running this test module directly, in verbose mode.
    pytest_args = [__file__, "-v"]
    pytest.main(pytest_args)