feat(data): 添加每日筹码及胜率数据接口 (cyq_perf)
- 新增 api_cyq_perf 模块,支持筹码分布数据获取和同步 - 在 sync_registry 中注册 cyq_perf 同步器
This commit is contained in:
@@ -684,4 +684,69 @@ df = pro.stk_limit(ts_code='002149.SZ', start_date='20190115', end_date='2019061
|
|||||||
17 20190625 000021.SZ 9.30 7.61
|
17 20190625 000021.SZ 9.30 7.61
|
||||||
18 20190625 000023.SZ 14.61 11.95
|
18 20190625 000023.SZ 14.61 11.95
|
||||||
19 20190625 000025.SZ 23.08 18.88
|
19 20190625 000025.SZ 23.08 18.88
|
||||||
20 20190625 000026.SZ 8.66 7.08
|
20 20190625 000026.SZ 8.66 7.08
|
||||||
|
|
||||||
|
|
||||||
|
每日筹码及胜率
|
||||||
|
接口:cyq_perf
|
||||||
|
描述:获取A股每日筹码平均成本和胜率情况,每天18~19点左右更新,数据从2018年开始
|
||||||
|
来源:Tushare社区
|
||||||
|
限量:单次最大5000条,可以分页或者循环提取
|
||||||
|
积分:5000积分每天20000次,10000积分每天200000次,15000积分每天不限总量
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
输入参数
|
||||||
|
|
||||||
|
名称 类型 必选 描述
|
||||||
|
ts_code str Y 股票代码
|
||||||
|
trade_date str N 交易日期(YYYYMMDD)
|
||||||
|
start_date str N 开始日期
|
||||||
|
end_date str N 结束日期
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
输出参数
|
||||||
|
|
||||||
|
名称 类型 默认显示 描述
|
||||||
|
ts_code str Y 股票代码
|
||||||
|
trade_date str Y 交易日期
|
||||||
|
his_low float Y 历史最低价
|
||||||
|
his_high float Y 历史最高价
|
||||||
|
cost_5pct float Y 5分位成本
|
||||||
|
cost_15pct float Y 15分位成本
|
||||||
|
cost_50pct float Y 50分位成本
|
||||||
|
cost_85pct float Y 85分位成本
|
||||||
|
cost_95pct float Y 95分位成本
|
||||||
|
weight_avg float Y 加权平均成本
|
||||||
|
winner_rate float Y 胜率
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
接口用法
|
||||||
|
|
||||||
|
pro = ts.pro_api()
|
||||||
|
|
||||||
|
df = pro.cyq_perf(ts_code='600000.SH', start_date='20220101', end_date='20220429')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
数据样例
|
||||||
|
|
||||||
|
ts_code trade_date his_low his_high cost_5pct cost_95pct weight_avg winner_rate
|
||||||
|
0 600000.SH 20220429 0.72 12.16 8.18 11.34 9.76 3.52
|
||||||
|
1 600000.SH 20220428 0.72 12.16 8.24 11.34 9.76 3.08
|
||||||
|
2 600000.SH 20220427 0.72 12.16 8.30 11.34 9.76 1.71
|
||||||
|
3 600000.SH 20220426 0.72 12.16 8.34 11.34 9.76 2.02
|
||||||
|
4 600000.SH 20220425 0.72 12.16 8.36 11.34 9.77 1.44
|
||||||
|
.. ... ... ... ... ... ... ... ...
|
||||||
|
72 600000.SH 20220110 0.72 12.16 8.60 11.36 9.89 7.62
|
||||||
|
73 600000.SH 20220107 0.72 12.16 8.60 11.36 9.89 7.59
|
||||||
|
74 600000.SH 20220106 0.72 12.16 8.60 11.36 9.89 3.92
|
||||||
|
75 600000.SH 20220105 0.72 12.16 8.60 11.36 9.89 5.65
|
||||||
|
76 600000.SH 20220104 0.72 12.16 8.60 11.36 9.89 3.93
|
||||||
@@ -13,12 +13,14 @@ Available APIs:
|
|||||||
- api_bak_basic: Stock historical list (股票历史列表)
|
- api_bak_basic: Stock historical list (股票历史列表)
|
||||||
- api_stock_st: ST stock list (ST股票列表)
|
- api_stock_st: ST stock list (ST股票列表)
|
||||||
- api_stk_limit: Stock limit price (每日涨跌停价格)
|
- api_stk_limit: Stock limit price (每日涨跌停价格)
|
||||||
|
- api_cyq_perf: CYQ performance (每日筹码及胜率)
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> from src.data.api_wrappers import get_daily, get_stock_basic, get_trade_cal, get_bak_basic
|
>>> from src.data.api_wrappers import get_daily, get_stock_basic, get_trade_cal, get_bak_basic
|
||||||
>>> from src.data.api_wrappers import get_pro_bar, sync_pro_bar, get_daily_basic, sync_daily_basic
|
>>> from src.data.api_wrappers import get_pro_bar, sync_pro_bar, get_daily_basic, sync_daily_basic
|
||||||
>>> from src.data.api_wrappers import get_stock_st, sync_stock_st
|
>>> from src.data.api_wrappers import get_stock_st, sync_stock_st
|
||||||
>>> from src.data.api_wrappers import get_stk_limit, sync_stk_limit
|
>>> from src.data.api_wrappers import get_stk_limit, sync_stk_limit
|
||||||
|
>>> from src.data.api_wrappers import get_cyq_perf, sync_cyq_perf
|
||||||
>>> data = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
|
>>> data = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
|
||||||
>>> pro_data = get_pro_bar('000001.SZ', start_date='20240101', end_date='20240131')
|
>>> pro_data = get_pro_bar('000001.SZ', start_date='20240101', end_date='20240131')
|
||||||
>>> daily_basic = get_daily_basic(trade_date='20240101')
|
>>> daily_basic = get_daily_basic(trade_date='20240101')
|
||||||
@@ -27,6 +29,7 @@ Example:
|
|||||||
>>> bak_basic = get_bak_basic(trade_date='20240101')
|
>>> bak_basic = get_bak_basic(trade_date='20240101')
|
||||||
>>> stock_st = get_stock_st(trade_date='20240101')
|
>>> stock_st = get_stock_st(trade_date='20240101')
|
||||||
>>> stk_limit = get_stk_limit(trade_date='20240101')
|
>>> stk_limit = get_stk_limit(trade_date='20240101')
|
||||||
|
>>> cyq_perf = get_cyq_perf('000001.SZ', start_date='20240101', end_date='20240131')
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from src.data.api_wrappers.api_daily_basic import (
|
from src.data.api_wrappers.api_daily_basic import (
|
||||||
@@ -68,6 +71,12 @@ from src.data.api_wrappers.api_trade_cal import (
|
|||||||
get_last_trading_day,
|
get_last_trading_day,
|
||||||
sync_trade_cal_cache,
|
sync_trade_cal_cache,
|
||||||
)
|
)
|
||||||
|
from src.data.api_wrappers.api_cyq_perf import (
|
||||||
|
get_cyq_perf,
|
||||||
|
sync_cyq_perf,
|
||||||
|
preview_cyq_perf_sync,
|
||||||
|
CyqPerfSync,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Daily market data
|
# Daily market data
|
||||||
@@ -115,6 +124,11 @@ __all__ = [
|
|||||||
"sync_stk_limit",
|
"sync_stk_limit",
|
||||||
"preview_stk_limit_sync",
|
"preview_stk_limit_sync",
|
||||||
"StkLimitSync",
|
"StkLimitSync",
|
||||||
|
# CYQ Performance (筹码分布)
|
||||||
|
"get_cyq_perf",
|
||||||
|
"sync_cyq_perf",
|
||||||
|
"preview_cyq_perf_sync",
|
||||||
|
"CyqPerfSync",
|
||||||
]
|
]
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -198,6 +212,17 @@ try:
|
|||||||
order=50,
|
order=50,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 8. CYQ Performance - 每日筹码及胜率
|
||||||
|
from src.data.api_wrappers.api_cyq_perf import CyqPerfSync
|
||||||
|
|
||||||
|
sync_registry.register_class(
|
||||||
|
name="cyq_perf",
|
||||||
|
sync_class=CyqPerfSync,
|
||||||
|
display_name="每日筹码及胜率",
|
||||||
|
description="A股每日筹码平均成本和胜率情况(2018年开始)",
|
||||||
|
order=60,
|
||||||
|
)
|
||||||
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# sync_registry 可能不存在(首次导入),忽略
|
# sync_registry 可能不存在(首次导入),忽略
|
||||||
pass
|
pass
|
||||||
|
|||||||
233
src/data/api_wrappers/api_cyq_perf.py
Normal file
233
src/data/api_wrappers/api_cyq_perf.py
Normal file
@@ -0,0 +1,233 @@
|
|||||||
|
"""CYQ Performance (筹码分布) interface.
|
||||||
|
|
||||||
|
Fetch A-share stock chip distribution data (cost distribution and win rate) from Tushare.
|
||||||
|
This interface retrieves daily chip average cost and win rate information.
|
||||||
|
Data starts from 2018.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from src.data.client import TushareClient
|
||||||
|
from src.data.api_wrappers.base_sync import StockBasedSync
|
||||||
|
|
||||||
|
|
||||||
|
def get_cyq_perf(
|
||||||
|
ts_code: str,
|
||||||
|
start_date: Optional[str] = None,
|
||||||
|
end_date: Optional[str] = None,
|
||||||
|
client: Optional[TushareClient] = None,
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
"""Fetch chip distribution (CYQ) performance data from Tushare.
|
||||||
|
|
||||||
|
This interface retrieves daily chip average cost and win rate information
|
||||||
|
for A-share stocks. Data starts from 2018.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ts_code: Stock code (e.g., '000001.SZ', '600000.SH')
|
||||||
|
start_date: Start date in YYYYMMDD format
|
||||||
|
end_date: End date in YYYYMMDD format
|
||||||
|
client: Optional TushareClient instance for shared rate limiting.
|
||||||
|
If None, creates a new client. For concurrent sync operations,
|
||||||
|
pass a shared client to ensure proper rate limiting.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
pd.DataFrame with columns:
|
||||||
|
- ts_code: Stock code
|
||||||
|
- trade_date: Trade date (YYYYMMDD)
|
||||||
|
- his_low: Historical lowest price
|
||||||
|
- his_high: Historical highest price
|
||||||
|
- cost_5pct: 5th percentile cost
|
||||||
|
- cost_15pct: 15th percentile cost
|
||||||
|
- cost_50pct: 50th percentile cost (median)
|
||||||
|
- cost_85pct: 85th percentile cost
|
||||||
|
- cost_95pct: 95th percentile cost
|
||||||
|
- weight_avg: Weighted average cost
|
||||||
|
- winner_rate: Win rate (percentage)
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> # Get chip distribution data for a stock
|
||||||
|
>>> data = get_cyq_perf('000001.SZ', start_date='20240101', end_date='20240131')
|
||||||
|
>>>
|
||||||
|
>>> # Get data with shared client for rate limiting
|
||||||
|
>>> from src.data.client import TushareClient
|
||||||
|
>>> client = TushareClient()
|
||||||
|
>>> data = get_cyq_perf('000001.SZ', start_date='20240101', end_date='20240131', client=client)
|
||||||
|
"""
|
||||||
|
client = client or TushareClient()
|
||||||
|
|
||||||
|
# Build parameters
|
||||||
|
params = {"ts_code": ts_code}
|
||||||
|
|
||||||
|
if start_date:
|
||||||
|
params["start_date"] = start_date
|
||||||
|
if end_date:
|
||||||
|
params["end_date"] = end_date
|
||||||
|
|
||||||
|
# Fetch data using cyq_perf API
|
||||||
|
data = client.query("cyq_perf", **params)
|
||||||
|
|
||||||
|
# Rename date column if needed
|
||||||
|
if "date" in data.columns:
|
||||||
|
data = data.rename(columns={"date": "trade_date"})
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class CyqPerfSync(StockBasedSync):
|
||||||
|
"""筹码分布数据批量同步管理器,支持全量/增量同步。
|
||||||
|
|
||||||
|
继承自 StockBasedSync,使用多线程按股票并发获取数据。
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> sync = CyqPerfSync()
|
||||||
|
>>> results = sync.sync_all() # 增量同步
|
||||||
|
>>> results = sync.sync_all(force_full=True) # 全量同步
|
||||||
|
>>> preview = sync.preview_sync() # 预览
|
||||||
|
"""
|
||||||
|
|
||||||
|
table_name = "cyq_perf"
|
||||||
|
|
||||||
|
# 表结构定义
|
||||||
|
TABLE_SCHEMA = {
|
||||||
|
"ts_code": "VARCHAR(16) NOT NULL",
|
||||||
|
"trade_date": "DATE NOT NULL",
|
||||||
|
"his_low": "DOUBLE",
|
||||||
|
"his_high": "DOUBLE",
|
||||||
|
"cost_5pct": "DOUBLE",
|
||||||
|
"cost_15pct": "DOUBLE",
|
||||||
|
"cost_50pct": "DOUBLE",
|
||||||
|
"cost_85pct": "DOUBLE",
|
||||||
|
"cost_95pct": "DOUBLE",
|
||||||
|
"weight_avg": "DOUBLE",
|
||||||
|
"winner_rate": "DOUBLE",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 索引定义
|
||||||
|
TABLE_INDEXES = [
|
||||||
|
("idx_cyq_perf_date_code", ["trade_date", "ts_code"]),
|
||||||
|
]
|
||||||
|
|
||||||
|
# 主键定义
|
||||||
|
PRIMARY_KEY = ("ts_code", "trade_date")
|
||||||
|
|
||||||
|
def fetch_single_stock(
|
||||||
|
self,
|
||||||
|
ts_code: str,
|
||||||
|
start_date: str,
|
||||||
|
end_date: str,
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
"""获取单只股票的筹码分布数据。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ts_code: 股票代码
|
||||||
|
start_date: 起始日期(YYYYMMDD)
|
||||||
|
end_date: 结束日期(YYYYMMDD)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含筹码分布数据的 DataFrame
|
||||||
|
"""
|
||||||
|
# 使用 get_cyq_perf 获取数据(传递共享 client)
|
||||||
|
data = get_cyq_perf(
|
||||||
|
ts_code=ts_code,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
client=self.client, # 传递共享客户端以确保限流
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def sync_cyq_perf(
|
||||||
|
force_full: bool = False,
|
||||||
|
start_date: Optional[str] = None,
|
||||||
|
end_date: Optional[str] = None,
|
||||||
|
max_workers: Optional[int] = None,
|
||||||
|
dry_run: bool = False,
|
||||||
|
) -> dict[str, pd.DataFrame]:
|
||||||
|
"""同步所有股票的筹码分布数据。
|
||||||
|
|
||||||
|
这是筹码分布数据同步的主要入口点。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
force_full: 若为 True,强制从 20180101 完整重载
|
||||||
|
start_date: 手动指定起始日期(YYYYMMDD)
|
||||||
|
end_date: 手动指定结束日期(默认为今天)
|
||||||
|
max_workers: 工作线程数(默认: 10)
|
||||||
|
dry_run: 若为 True,仅预览将要同步的内容,不写入数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
映射 ts_code 到 DataFrame 的字典
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> # 首次同步(从 20180101 全量加载)
|
||||||
|
>>> result = sync_cyq_perf()
|
||||||
|
>>>
|
||||||
|
>>> # 后续同步(增量 - 仅新数据)
|
||||||
|
>>> result = sync_cyq_perf()
|
||||||
|
>>>
|
||||||
|
>>> # 强制完整重载
|
||||||
|
>>> result = sync_cyq_perf(force_full=True)
|
||||||
|
>>>
|
||||||
|
>>> # 手动指定日期范围
|
||||||
|
>>> result = sync_cyq_perf(start_date='20240101', end_date='20240131')
|
||||||
|
>>>
|
||||||
|
>>> # 自定义线程数
|
||||||
|
>>> result = sync_cyq_perf(max_workers=20)
|
||||||
|
>>>
|
||||||
|
>>> # Dry run(仅预览)
|
||||||
|
>>> result = sync_cyq_perf(dry_run=True)
|
||||||
|
"""
|
||||||
|
sync_manager = CyqPerfSync(max_workers=max_workers)
|
||||||
|
return sync_manager.sync_all(
|
||||||
|
force_full=force_full,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
dry_run=dry_run,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def preview_cyq_perf_sync(
|
||||||
|
force_full: bool = False,
|
||||||
|
start_date: Optional[str] = None,
|
||||||
|
end_date: Optional[str] = None,
|
||||||
|
sample_size: int = 3,
|
||||||
|
) -> dict:
|
||||||
|
"""预览筹码分布数据同步数据量和样本(不实际同步)。
|
||||||
|
|
||||||
|
这是推荐的方式,可在实际同步前检查将要同步的内容。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
force_full: 若为 True,预览全量同步(从 20180101)
|
||||||
|
start_date: 手动指定起始日期(覆盖自动检测)
|
||||||
|
end_date: 手动指定结束日期(默认为今天)
|
||||||
|
sample_size: 预览用样本股票数量(默认: 3)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含预览信息的字典:
|
||||||
|
{
|
||||||
|
'sync_needed': bool,
|
||||||
|
'stock_count': int,
|
||||||
|
'start_date': str,
|
||||||
|
'end_date': str,
|
||||||
|
'estimated_records': int,
|
||||||
|
'sample_data': pd.DataFrame,
|
||||||
|
'mode': str, # 'full', 'incremental', 'partial', 或 'none'
|
||||||
|
}
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> # 预览将要同步的内容
|
||||||
|
>>> preview = preview_cyq_perf_sync()
|
||||||
|
>>>
|
||||||
|
>>> # 预览全量同步
|
||||||
|
>>> preview = preview_cyq_perf_sync(force_full=True)
|
||||||
|
>>>
|
||||||
|
>>> # 预览更多样本
|
||||||
|
>>> preview = preview_cyq_perf_sync(sample_size=5)
|
||||||
|
"""
|
||||||
|
sync_manager = CyqPerfSync()
|
||||||
|
return sync_manager.preview_sync(
|
||||||
|
force_full=force_full,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
sample_size=sample_size,
|
||||||
|
)
|
||||||
@@ -10,6 +10,9 @@
|
|||||||
- api_daily_basic.py: 每日指标数据同步 (DailyBasicSync 类)
|
- api_daily_basic.py: 每日指标数据同步 (DailyBasicSync 类)
|
||||||
- api_bak_basic.py: 历史股票列表同步 (BakBasicSync 类)
|
- api_bak_basic.py: 历史股票列表同步 (BakBasicSync 类)
|
||||||
- api_pro_bar.py: Pro Bar 数据同步 (ProBarSync 类)
|
- api_pro_bar.py: Pro Bar 数据同步 (ProBarSync 类)
|
||||||
|
- api_stock_st.py: ST股票列表同步 (StockSTSync 类)
|
||||||
|
- api_stk_limit.py: 涨跌停价格同步 (StkLimitSync 类)
|
||||||
|
- api_cyq_perf.py: 筹码分布数据同步 (CyqPerfSync 类)
|
||||||
- api_stock_basic.py: 股票基本信息同步
|
- api_stock_basic.py: 股票基本信息同步
|
||||||
- api_trade_cal.py: 交易日历同步
|
- api_trade_cal.py: 交易日历同步
|
||||||
|
|
||||||
@@ -77,6 +80,8 @@ def sync_all_data(
|
|||||||
4. daily_basic: 每日指标(PE、PB、换手率、市值)
|
4. daily_basic: 每日指标(PE、PB、换手率、市值)
|
||||||
5. bak_basic: 历史股票列表
|
5. bak_basic: 历史股票列表
|
||||||
6. stock_st: ST股票列表
|
6. stock_st: ST股票列表
|
||||||
|
7. stk_limit: 每日涨跌停价格
|
||||||
|
8. cyq_perf: 每日筹码及胜率
|
||||||
|
|
||||||
新增接口时,只需在 api_wrappers/__init__.py 中添加注册代码,
|
新增接口时,只需在 api_wrappers/__init__.py 中添加注册代码,
|
||||||
无需修改本函数。
|
无需修改本函数。
|
||||||
|
|||||||
@@ -6,13 +6,6 @@
|
|||||||
# ## 1. 导入依赖
|
# ## 1. 导入依赖
|
||||||
# %%
|
# %%
|
||||||
import os
|
import os
|
||||||
from datetime import datetime
|
|
||||||
from typing import List, Tuple, Optional
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import polars as pl
|
|
||||||
import pandas as pd
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
from src.factors import FactorEngine
|
from src.factors import FactorEngine
|
||||||
from src.training import (
|
from src.training import (
|
||||||
@@ -23,7 +16,7 @@ from src.training import (
|
|||||||
Winsorizer,
|
Winsorizer,
|
||||||
CrossSectionalStandardScaler,
|
CrossSectionalStandardScaler,
|
||||||
)
|
)
|
||||||
from src.training.trainer_v2 import Trainer
|
from src.training.core.trainer_v2 import Trainer
|
||||||
from src.training.components.filters import STFilter
|
from src.training.components.filters import STFilter
|
||||||
from src.experiment.common import (
|
from src.experiment.common import (
|
||||||
SELECTED_FACTORS,
|
SELECTED_FACTORS,
|
||||||
|
|||||||
@@ -6,9 +6,6 @@
|
|||||||
# ## 1. 导入依赖
|
# ## 1. 导入依赖
|
||||||
# %%
|
# %%
|
||||||
import os
|
import os
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
import polars as pl
|
|
||||||
|
|
||||||
from src.factors import FactorEngine
|
from src.factors import FactorEngine
|
||||||
from src.training import (
|
from src.training import (
|
||||||
@@ -19,7 +16,7 @@ from src.training import (
|
|||||||
Winsorizer,
|
Winsorizer,
|
||||||
StandardScaler,
|
StandardScaler,
|
||||||
)
|
)
|
||||||
from src.training.trainer_v2 import Trainer
|
from src.training.core.trainer_v2 import Trainer
|
||||||
from src.training.components.filters import STFilter
|
from src.training.components.filters import STFilter
|
||||||
from src.experiment.common import (
|
from src.experiment.common import (
|
||||||
SELECTED_FACTORS,
|
SELECTED_FACTORS,
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ from src.training.result_analyzer import ResultAnalyzer
|
|||||||
from src.training.tasks import BaseTask, RegressionTask, RankTask
|
from src.training.tasks import BaseTask, RegressionTask, RankTask
|
||||||
|
|
||||||
# 从 trainer_v2 导入新 Trainer(推荐)
|
# 从 trainer_v2 导入新 Trainer(推荐)
|
||||||
from src.training.trainer_v2 import Trainer as TrainerV2
|
from src.training.core.trainer_v2 import Trainer as TrainerV2
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# 基础抽象类
|
# 基础抽象类
|
||||||
|
|||||||
@@ -5,5 +5,6 @@
|
|||||||
|
|
||||||
from src.training.core.stock_pool_manager import StockPoolManager
|
from src.training.core.stock_pool_manager import StockPoolManager
|
||||||
from src.training.core.trainer import Trainer
|
from src.training.core.trainer import Trainer
|
||||||
|
from src.training.core.trainer_v2 import Trainer as TrainerV2
|
||||||
|
|
||||||
__all__ = ["StockPoolManager", "Trainer"]
|
__all__ = ["StockPoolManager", "Trainer", "TrainerV2"]
|
||||||
|
|||||||
278
tests/test_cyq_perf.py
Normal file
278
tests/test_cyq_perf.py
Normal file
@@ -0,0 +1,278 @@
|
|||||||
|
"""Tests for cyq_perf API wrapper.
|
||||||
|
|
||||||
|
Tests for src.data.api_wrappers.api_cyq_perf module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pandas as pd
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
from src.data.api_wrappers.api_cyq_perf import (
|
||||||
|
get_cyq_perf,
|
||||||
|
sync_cyq_perf,
|
||||||
|
preview_cyq_perf_sync,
|
||||||
|
CyqPerfSync,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCyqPerf:
|
||||||
|
"""Test suite for cyq_perf API wrapper."""
|
||||||
|
|
||||||
|
@patch("src.data.api_wrappers.api_cyq_perf.TushareClient")
|
||||||
|
def test_get_cyq_perf_by_stock(self, mock_client_class):
|
||||||
|
"""Test fetching chip distribution data by stock code."""
|
||||||
|
# Setup mock
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client_class.return_value = mock_client
|
||||||
|
mock_client.query.return_value = pd.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["000001.SZ"],
|
||||||
|
"trade_date": ["20240101"],
|
||||||
|
"his_low": [8.50],
|
||||||
|
"his_high": [12.00],
|
||||||
|
"cost_5pct": [8.80],
|
||||||
|
"cost_15pct": [9.20],
|
||||||
|
"cost_50pct": [10.00],
|
||||||
|
"cost_85pct": [10.80],
|
||||||
|
"cost_95pct": [11.20],
|
||||||
|
"weight_avg": [10.00],
|
||||||
|
"winner_rate": [5.50],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test
|
||||||
|
result = get_cyq_perf(
|
||||||
|
ts_code="000001.SZ",
|
||||||
|
start_date="20240101",
|
||||||
|
end_date="20240131",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert not result.empty
|
||||||
|
assert "ts_code" in result.columns
|
||||||
|
assert "trade_date" in result.columns
|
||||||
|
assert "cost_50pct" in result.columns
|
||||||
|
assert "winner_rate" in result.columns
|
||||||
|
assert result["ts_code"].iloc[0] == "000001.SZ"
|
||||||
|
mock_client.query.assert_called_once()
|
||||||
|
|
||||||
|
# Verify parameters
|
||||||
|
call_args = mock_client.query.call_args
|
||||||
|
assert call_args[0][0] == "cyq_perf"
|
||||||
|
assert call_args[1]["ts_code"] == "000001.SZ"
|
||||||
|
assert call_args[1]["start_date"] == "20240101"
|
||||||
|
assert call_args[1]["end_date"] == "20240131"
|
||||||
|
|
||||||
|
@patch("src.data.api_wrappers.api_cyq_perf.TushareClient")
|
||||||
|
def test_get_cyq_perf_with_shared_client(self, mock_client_class):
|
||||||
|
"""Test fetching data with shared client for rate limiting."""
|
||||||
|
# Setup mock for shared client
|
||||||
|
shared_client = MagicMock()
|
||||||
|
shared_client.query.return_value = pd.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["600000.SH"],
|
||||||
|
"trade_date": ["20240115"],
|
||||||
|
"his_low": [5.00],
|
||||||
|
"his_high": [8.00],
|
||||||
|
"cost_5pct": [5.50],
|
||||||
|
"cost_15pct": [6.00],
|
||||||
|
"cost_50pct": [6.50],
|
||||||
|
"cost_85pct": [7.00],
|
||||||
|
"cost_95pct": [7.50],
|
||||||
|
"weight_avg": [6.50],
|
||||||
|
"winner_rate": [3.20],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test with shared client
|
||||||
|
result = get_cyq_perf(
|
||||||
|
ts_code="600000.SH",
|
||||||
|
start_date="20240101",
|
||||||
|
end_date="20240131",
|
||||||
|
client=shared_client,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert shared client was used, not new instance
|
||||||
|
mock_client_class.assert_not_called()
|
||||||
|
shared_client.query.assert_called_once()
|
||||||
|
assert not result.empty
|
||||||
|
assert result["ts_code"].iloc[0] == "600000.SH"
|
||||||
|
|
||||||
|
@patch("src.data.api_wrappers.api_cyq_perf.TushareClient")
|
||||||
|
def test_get_cyq_perf_empty_response(self, mock_client_class):
|
||||||
|
"""Test handling empty response."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client_class.return_value = mock_client
|
||||||
|
mock_client.query.return_value = pd.DataFrame()
|
||||||
|
|
||||||
|
result = get_cyq_perf(
|
||||||
|
ts_code="000001.SZ",
|
||||||
|
start_date="20240101",
|
||||||
|
end_date="20240131",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.empty
|
||||||
|
|
||||||
|
@patch("src.data.api_wrappers.api_cyq_perf.TushareClient")
|
||||||
|
def test_get_cyq_perf_date_column_rename(self, mock_client_class):
|
||||||
|
"""Test that 'date' column is renamed to 'trade_date'."""
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client_class.return_value = mock_client
|
||||||
|
# Return data with 'date' column instead of 'trade_date'
|
||||||
|
mock_client.query.return_value = pd.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["000001.SZ"],
|
||||||
|
"date": ["20240101"], # Note: 'date' not 'trade_date'
|
||||||
|
"cost_50pct": [10.00],
|
||||||
|
"winner_rate": [5.50],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
result = get_cyq_perf(ts_code="000001.SZ")
|
||||||
|
|
||||||
|
# Assert 'date' was renamed to 'trade_date'
|
||||||
|
assert "trade_date" in result.columns
|
||||||
|
assert "date" not in result.columns
|
||||||
|
assert result["trade_date"].iloc[0] == "20240101"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCyqPerfSync:
|
||||||
|
"""Test suite for CyqPerfSync class."""
|
||||||
|
|
||||||
|
@patch("src.data.api_wrappers.api_cyq_perf.TushareClient")
|
||||||
|
@patch("src.data.api_wrappers.base_sync.Storage")
|
||||||
|
@patch("src.data.api_wrappers.base_sync.ThreadSafeStorage")
|
||||||
|
@patch("src.data.api_wrappers.base_sync.sync_trade_cal_cache")
|
||||||
|
@patch("src.data.api_wrappers.base_sync.sync_all_stocks")
|
||||||
|
def test_cyq_perf_sync_class_structure(
|
||||||
|
self,
|
||||||
|
mock_sync_stocks,
|
||||||
|
mock_sync_cal,
|
||||||
|
mock_storage_class,
|
||||||
|
mock_base_storage_class,
|
||||||
|
mock_client_class,
|
||||||
|
):
|
||||||
|
"""Test CyqPerfSync class structure and attributes."""
|
||||||
|
# Verify class attributes
|
||||||
|
assert CyqPerfSync.table_name == "cyq_perf"
|
||||||
|
assert "ts_code" in CyqPerfSync.TABLE_SCHEMA
|
||||||
|
assert "trade_date" in CyqPerfSync.TABLE_SCHEMA
|
||||||
|
assert "cost_5pct" in CyqPerfSync.TABLE_SCHEMA
|
||||||
|
assert "cost_95pct" in CyqPerfSync.TABLE_SCHEMA
|
||||||
|
assert "winner_rate" in CyqPerfSync.TABLE_SCHEMA
|
||||||
|
assert CyqPerfSync.PRIMARY_KEY == ("ts_code", "trade_date")
|
||||||
|
|
||||||
|
@patch("src.data.api_wrappers.api_cyq_perf.get_cyq_perf")
|
||||||
|
def test_fetch_single_stock(self, mock_get_cyq_perf):
|
||||||
|
"""Test fetch_single_stock method."""
|
||||||
|
mock_get_cyq_perf.return_value = pd.DataFrame(
|
||||||
|
{
|
||||||
|
"ts_code": ["000001.SZ", "000001.SZ"],
|
||||||
|
"trade_date": ["20240101", "20240102"],
|
||||||
|
"cost_50pct": [10.00, 10.10],
|
||||||
|
"winner_rate": [5.50, 5.60],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
sync = CyqPerfSync()
|
||||||
|
# Mock the client to avoid real initialization
|
||||||
|
sync.client = MagicMock()
|
||||||
|
|
||||||
|
result = sync.fetch_single_stock(
|
||||||
|
ts_code="000001.SZ",
|
||||||
|
start_date="20240101",
|
||||||
|
end_date="20240102",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert not result.empty
|
||||||
|
assert len(result) == 2
|
||||||
|
mock_get_cyq_perf.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
class TestSyncCyqPerf:
|
||||||
|
"""Test suite for sync_cyq_perf convenience function."""
|
||||||
|
|
||||||
|
@patch("src.data.api_wrappers.api_cyq_perf.CyqPerfSync")
|
||||||
|
def test_sync_cyq_perf_calls_sync_all(self, mock_sync_class):
|
||||||
|
"""Test that sync_cyq_perf calls sync_all on CyqPerfSync."""
|
||||||
|
mock_sync_instance = MagicMock()
|
||||||
|
mock_sync_class.return_value = mock_sync_instance
|
||||||
|
mock_sync_instance.sync_all.return_value = {
|
||||||
|
"000001.SZ": pd.DataFrame({"ts_code": ["000001.SZ"]})
|
||||||
|
}
|
||||||
|
|
||||||
|
result = sync_cyq_perf()
|
||||||
|
|
||||||
|
mock_sync_class.assert_called_once_with(max_workers=None)
|
||||||
|
mock_sync_instance.sync_all.assert_called_once()
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
|
||||||
|
@patch("src.data.api_wrappers.api_cyq_perf.CyqPerfSync")
|
||||||
|
def test_sync_cyq_perf_with_params(self, mock_sync_class):
|
||||||
|
"""Test sync_cyq_perf with parameters."""
|
||||||
|
mock_sync_instance = MagicMock()
|
||||||
|
mock_sync_class.return_value = mock_sync_instance
|
||||||
|
mock_sync_instance.sync_all.return_value = {}
|
||||||
|
|
||||||
|
result = sync_cyq_perf(
|
||||||
|
force_full=True,
|
||||||
|
start_date="20240101",
|
||||||
|
end_date="20240131",
|
||||||
|
max_workers=20,
|
||||||
|
dry_run=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_sync_class.assert_called_once_with(max_workers=20)
|
||||||
|
mock_sync_instance.sync_all.assert_called_once_with(
|
||||||
|
force_full=True,
|
||||||
|
start_date="20240101",
|
||||||
|
end_date="20240131",
|
||||||
|
dry_run=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPreviewCyqPerfSync:
|
||||||
|
"""Test suite for preview_cyq_perf_sync convenience function."""
|
||||||
|
|
||||||
|
@patch("src.data.api_wrappers.api_cyq_perf.CyqPerfSync")
|
||||||
|
def test_preview_cyq_perf_sync(self, mock_sync_class):
|
||||||
|
"""Test preview_cyq_perf_sync function."""
|
||||||
|
mock_sync_instance = MagicMock()
|
||||||
|
mock_sync_class.return_value = mock_sync_instance
|
||||||
|
mock_sync_instance.preview_sync.return_value = {
|
||||||
|
"sync_needed": True,
|
||||||
|
"stock_count": 5000,
|
||||||
|
"start_date": "20240101",
|
||||||
|
"end_date": "20240131",
|
||||||
|
"estimated_records": 100000,
|
||||||
|
"sample_data": pd.DataFrame(),
|
||||||
|
"mode": "incremental",
|
||||||
|
}
|
||||||
|
|
||||||
|
result = preview_cyq_perf_sync()
|
||||||
|
|
||||||
|
mock_sync_class.assert_called_once_with()
|
||||||
|
mock_sync_instance.preview_sync.assert_called_once()
|
||||||
|
assert result["sync_needed"] is True
|
||||||
|
assert result["stock_count"] == 5000
|
||||||
|
|
||||||
|
@patch("src.data.api_wrappers.api_cyq_perf.CyqPerfSync")
|
||||||
|
def test_preview_cyq_perf_sync_with_params(self, mock_sync_class):
|
||||||
|
"""Test preview with custom parameters."""
|
||||||
|
mock_sync_instance = MagicMock()
|
||||||
|
mock_sync_class.return_value = mock_sync_instance
|
||||||
|
mock_sync_instance.preview_sync.return_value = {}
|
||||||
|
|
||||||
|
preview_cyq_perf_sync(
|
||||||
|
force_full=True,
|
||||||
|
start_date="20240101",
|
||||||
|
end_date="20240131",
|
||||||
|
sample_size=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_sync_instance.preview_sync.assert_called_once_with(
|
||||||
|
force_full=True,
|
||||||
|
start_date="20240101",
|
||||||
|
end_date="20240131",
|
||||||
|
sample_size=5,
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user