diff --git a/docs/api/api.md b/docs/api/api.md index 851d137..d34f294 100644 --- a/docs/api/api.md +++ b/docs/api/api.md @@ -684,4 +684,69 @@ df = pro.stk_limit(ts_code='002149.SZ', start_date='20190115', end_date='2019061 17 20190625 000021.SZ 9.30 7.61 18 20190625 000023.SZ 14.61 11.95 19 20190625 000025.SZ 23.08 18.88 -20 20190625 000026.SZ 8.66 7.08 \ No newline at end of file +20 20190625 000026.SZ 8.66 7.08 + + +每日筹码及胜率 +接口:cyq_perf +描述:获取A股每日筹码平均成本和胜率情况,每天18~19点左右更新,数据从2018年开始 +来源:Tushare社区 +限量:单次最大5000条,可以分页或者循环提取 +积分:5000积分每天20000次,10000积分每天200000次,15000积分每天不限总量 + + + + + +输入参数 + +名称 类型 必选 描述 +ts_code str Y 股票代码 +trade_date str N 交易日期(YYYYMMDD) +start_date str N 开始日期 +end_date str N 结束日期 + + + + +输出参数 + +名称 类型 默认显示 描述 +ts_code str Y 股票代码 +trade_date str Y 交易日期 +his_low float Y 历史最低价 +his_high float Y 历史最高价 +cost_5pct float Y 5分位成本 +cost_15pct float Y 15分位成本 +cost_50pct float Y 50分位成本 +cost_85pct float Y 85分位成本 +cost_95pct float Y 95分位成本 +weight_avg float Y 加权平均成本 +winner_rate float Y 胜率 + + + + +接口用法 + +pro = ts.pro_api() + +df = pro.cyq_perf(ts_code='600000.SH', start_date='20220101', end_date='20220429') + + + + +数据样例 + + ts_code trade_date his_low his_high cost_5pct cost_95pct weight_avg winner_rate +0 600000.SH 20220429 0.72 12.16 8.18 11.34 9.76 3.52 +1 600000.SH 20220428 0.72 12.16 8.24 11.34 9.76 3.08 +2 600000.SH 20220427 0.72 12.16 8.30 11.34 9.76 1.71 +3 600000.SH 20220426 0.72 12.16 8.34 11.34 9.76 2.02 +4 600000.SH 20220425 0.72 12.16 8.36 11.34 9.77 1.44 +.. ... ... ... ... ... ... ... ... +72 600000.SH 20220110 0.72 12.16 8.60 11.36 9.89 7.62 +73 600000.SH 20220107 0.72 12.16 8.60 11.36 9.89 7.59 +74 600000.SH 20220106 0.72 12.16 8.60 11.36 9.89 3.92 +75 600000.SH 20220105 0.72 12.16 8.60 11.36 9.89 5.65 +76 600000.SH 20220104 0.72 12.16 8.60 11.36 9.89 3.93 \ No newline at end of file diff --git a/src/data/api_wrappers/__init__.py b/src/data/api_wrappers/__init__.py index c14ed87..b7ab77e 100644 --- a/src/data/api_wrappers/__init__.py +++ b/src/data/api_wrappers/__init__.py @@ -13,12 +13,14 @@ Available APIs: - api_bak_basic: Stock historical list (股票历史列表) - api_stock_st: ST stock list (ST股票列表) - api_stk_limit: Stock limit price (每日涨跌停价格) + - api_cyq_perf: CYQ performance (每日筹码及胜率) Example: >>> from src.data.api_wrappers import get_daily, get_stock_basic, get_trade_cal, get_bak_basic >>> from src.data.api_wrappers import get_pro_bar, sync_pro_bar, get_daily_basic, sync_daily_basic >>> from src.data.api_wrappers import get_stock_st, sync_stock_st >>> from src.data.api_wrappers import get_stk_limit, sync_stk_limit + >>> from src.data.api_wrappers import get_cyq_perf, sync_cyq_perf >>> data = get_daily('000001.SZ', start_date='20240101', end_date='20240131') >>> pro_data = get_pro_bar('000001.SZ', start_date='20240101', end_date='20240131') >>> daily_basic = get_daily_basic(trade_date='20240101') @@ -27,6 +29,7 @@ Example: >>> bak_basic = get_bak_basic(trade_date='20240101') >>> stock_st = get_stock_st(trade_date='20240101') >>> stk_limit = get_stk_limit(trade_date='20240101') + >>> cyq_perf = get_cyq_perf('000001.SZ', start_date='20240101', end_date='20240131') """ from src.data.api_wrappers.api_daily_basic import ( @@ -68,6 +71,12 @@ from src.data.api_wrappers.api_trade_cal import ( get_last_trading_day, sync_trade_cal_cache, ) +from src.data.api_wrappers.api_cyq_perf import ( + get_cyq_perf, + sync_cyq_perf, + preview_cyq_perf_sync, + CyqPerfSync, +) __all__ = [ # Daily market data @@ -115,6 +124,11 @@ __all__ = [ "sync_stk_limit", "preview_stk_limit_sync", "StkLimitSync", + # CYQ Performance (筹码分布) + "get_cyq_perf", + "sync_cyq_perf", + "preview_cyq_perf_sync", + "CyqPerfSync", ] # ============================================================================= @@ -198,6 +212,17 @@ try: order=50, ) + # 8. CYQ Performance - 每日筹码及胜率 + from src.data.api_wrappers.api_cyq_perf import CyqPerfSync + + sync_registry.register_class( + name="cyq_perf", + sync_class=CyqPerfSync, + display_name="每日筹码及胜率", + description="A股每日筹码平均成本和胜率情况(2018年开始)", + order=60, + ) + except ImportError: # sync_registry 可能不存在(首次导入),忽略 pass diff --git a/src/data/api_wrappers/api_cyq_perf.py b/src/data/api_wrappers/api_cyq_perf.py new file mode 100644 index 0000000..b2f7a96 --- /dev/null +++ b/src/data/api_wrappers/api_cyq_perf.py @@ -0,0 +1,233 @@ +"""CYQ Performance (筹码分布) interface. + +Fetch A-share stock chip distribution data (cost distribution and win rate) from Tushare. +This interface retrieves daily chip average cost and win rate information. +Data starts from 2018. +""" + +import pandas as pd +from typing import Optional + +from src.data.client import TushareClient +from src.data.api_wrappers.base_sync import StockBasedSync + + +def get_cyq_perf( + ts_code: str, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + client: Optional[TushareClient] = None, +) -> pd.DataFrame: + """Fetch chip distribution (CYQ) performance data from Tushare. + + This interface retrieves daily chip average cost and win rate information + for A-share stocks. Data starts from 2018. + + Args: + ts_code: Stock code (e.g., '000001.SZ', '600000.SH') + start_date: Start date in YYYYMMDD format + end_date: End date in YYYYMMDD format + client: Optional TushareClient instance for shared rate limiting. + If None, creates a new client. For concurrent sync operations, + pass a shared client to ensure proper rate limiting. + + Returns: + pd.DataFrame with columns: + - ts_code: Stock code + - trade_date: Trade date (YYYYMMDD) + - his_low: Historical lowest price + - his_high: Historical highest price + - cost_5pct: 5th percentile cost + - cost_15pct: 15th percentile cost + - cost_50pct: 50th percentile cost (median) + - cost_85pct: 85th percentile cost + - cost_95pct: 95th percentile cost + - weight_avg: Weighted average cost + - winner_rate: Win rate (percentage) + + Example: + >>> # Get chip distribution data for a stock + >>> data = get_cyq_perf('000001.SZ', start_date='20240101', end_date='20240131') + >>> + >>> # Get data with shared client for rate limiting + >>> from src.data.client import TushareClient + >>> client = TushareClient() + >>> data = get_cyq_perf('000001.SZ', start_date='20240101', end_date='20240131', client=client) + """ + client = client or TushareClient() + + # Build parameters + params = {"ts_code": ts_code} + + if start_date: + params["start_date"] = start_date + if end_date: + params["end_date"] = end_date + + # Fetch data using cyq_perf API + data = client.query("cyq_perf", **params) + + # Rename date column if needed + if "date" in data.columns: + data = data.rename(columns={"date": "trade_date"}) + + return data + + +class CyqPerfSync(StockBasedSync): + """筹码分布数据批量同步管理器,支持全量/增量同步。 + + 继承自 StockBasedSync,使用多线程按股票并发获取数据。 + + Example: + >>> sync = CyqPerfSync() + >>> results = sync.sync_all() # 增量同步 + >>> results = sync.sync_all(force_full=True) # 全量同步 + >>> preview = sync.preview_sync() # 预览 + """ + + table_name = "cyq_perf" + + # 表结构定义 + TABLE_SCHEMA = { + "ts_code": "VARCHAR(16) NOT NULL", + "trade_date": "DATE NOT NULL", + "his_low": "DOUBLE", + "his_high": "DOUBLE", + "cost_5pct": "DOUBLE", + "cost_15pct": "DOUBLE", + "cost_50pct": "DOUBLE", + "cost_85pct": "DOUBLE", + "cost_95pct": "DOUBLE", + "weight_avg": "DOUBLE", + "winner_rate": "DOUBLE", + } + + # 索引定义 + TABLE_INDEXES = [ + ("idx_cyq_perf_date_code", ["trade_date", "ts_code"]), + ] + + # 主键定义 + PRIMARY_KEY = ("ts_code", "trade_date") + + def fetch_single_stock( + self, + ts_code: str, + start_date: str, + end_date: str, + ) -> pd.DataFrame: + """获取单只股票的筹码分布数据。 + + Args: + ts_code: 股票代码 + start_date: 起始日期(YYYYMMDD) + end_date: 结束日期(YYYYMMDD) + + Returns: + 包含筹码分布数据的 DataFrame + """ + # 使用 get_cyq_perf 获取数据(传递共享 client) + data = get_cyq_perf( + ts_code=ts_code, + start_date=start_date, + end_date=end_date, + client=self.client, # 传递共享客户端以确保限流 + ) + return data + + +def sync_cyq_perf( + force_full: bool = False, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + max_workers: Optional[int] = None, + dry_run: bool = False, +) -> dict[str, pd.DataFrame]: + """同步所有股票的筹码分布数据。 + + 这是筹码分布数据同步的主要入口点。 + + Args: + force_full: 若为 True,强制从 20180101 完整重载 + start_date: 手动指定起始日期(YYYYMMDD) + end_date: 手动指定结束日期(默认为今天) + max_workers: 工作线程数(默认: 10) + dry_run: 若为 True,仅预览将要同步的内容,不写入数据 + + Returns: + 映射 ts_code 到 DataFrame 的字典 + + Example: + >>> # 首次同步(从 20180101 全量加载) + >>> result = sync_cyq_perf() + >>> + >>> # 后续同步(增量 - 仅新数据) + >>> result = sync_cyq_perf() + >>> + >>> # 强制完整重载 + >>> result = sync_cyq_perf(force_full=True) + >>> + >>> # 手动指定日期范围 + >>> result = sync_cyq_perf(start_date='20240101', end_date='20240131') + >>> + >>> # 自定义线程数 + >>> result = sync_cyq_perf(max_workers=20) + >>> + >>> # Dry run(仅预览) + >>> result = sync_cyq_perf(dry_run=True) + """ + sync_manager = CyqPerfSync(max_workers=max_workers) + return sync_manager.sync_all( + force_full=force_full, + start_date=start_date, + end_date=end_date, + dry_run=dry_run, + ) + + +def preview_cyq_perf_sync( + force_full: bool = False, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + sample_size: int = 3, +) -> dict: + """预览筹码分布数据同步数据量和样本(不实际同步)。 + + 这是推荐的方式,可在实际同步前检查将要同步的内容。 + + Args: + force_full: 若为 True,预览全量同步(从 20180101) + start_date: 手动指定起始日期(覆盖自动检测) + end_date: 手动指定结束日期(默认为今天) + sample_size: 预览用样本股票数量(默认: 3) + + Returns: + 包含预览信息的字典: + { + 'sync_needed': bool, + 'stock_count': int, + 'start_date': str, + 'end_date': str, + 'estimated_records': int, + 'sample_data': pd.DataFrame, + 'mode': str, # 'full', 'incremental', 'partial', 或 'none' + } + + Example: + >>> # 预览将要同步的内容 + >>> preview = preview_cyq_perf_sync() + >>> + >>> # 预览全量同步 + >>> preview = preview_cyq_perf_sync(force_full=True) + >>> + >>> # 预览更多样本 + >>> preview = preview_cyq_perf_sync(sample_size=5) + """ + sync_manager = CyqPerfSync() + return sync_manager.preview_sync( + force_full=force_full, + start_date=start_date, + end_date=end_date, + sample_size=sample_size, + ) diff --git a/src/data/sync.py b/src/data/sync.py index dad3085..62708ab 100644 --- a/src/data/sync.py +++ b/src/data/sync.py @@ -10,6 +10,9 @@ - api_daily_basic.py: 每日指标数据同步 (DailyBasicSync 类) - api_bak_basic.py: 历史股票列表同步 (BakBasicSync 类) - api_pro_bar.py: Pro Bar 数据同步 (ProBarSync 类) + - api_stock_st.py: ST股票列表同步 (StockSTSync 类) + - api_stk_limit.py: 涨跌停价格同步 (StkLimitSync 类) + - api_cyq_perf.py: 筹码分布数据同步 (CyqPerfSync 类) - api_stock_basic.py: 股票基本信息同步 - api_trade_cal.py: 交易日历同步 @@ -77,6 +80,8 @@ def sync_all_data( 4. daily_basic: 每日指标(PE、PB、换手率、市值) 5. bak_basic: 历史股票列表 6. stock_st: ST股票列表 + 7. stk_limit: 每日涨跌停价格 + 8. cyq_perf: 每日筹码及胜率 新增接口时,只需在 api_wrappers/__init__.py 中添加注册代码, 无需修改本函数。 diff --git a/src/experiment/learn_to_rank.py b/src/experiment/learn_to_rank.py index f97ad5b..25a1416 100644 --- a/src/experiment/learn_to_rank.py +++ b/src/experiment/learn_to_rank.py @@ -6,13 +6,6 @@ # ## 1. 导入依赖 # %% import os -from datetime import datetime -from typing import List, Tuple, Optional - -import numpy as np -import polars as pl -import pandas as pd -import matplotlib.pyplot as plt from src.factors import FactorEngine from src.training import ( @@ -23,7 +16,7 @@ from src.training import ( Winsorizer, CrossSectionalStandardScaler, ) -from src.training.trainer_v2 import Trainer +from src.training.core.trainer_v2 import Trainer from src.training.components.filters import STFilter from src.experiment.common import ( SELECTED_FACTORS, diff --git a/src/experiment/regression.py b/src/experiment/regression.py index 2dd05f6..225eebf 100644 --- a/src/experiment/regression.py +++ b/src/experiment/regression.py @@ -6,9 +6,6 @@ # ## 1. 导入依赖 # %% import os -from datetime import datetime - -import polars as pl from src.factors import FactorEngine from src.training import ( @@ -19,7 +16,7 @@ from src.training import ( Winsorizer, StandardScaler, ) -from src.training.trainer_v2 import Trainer +from src.training.core.trainer_v2 import Trainer from src.training.components.filters import STFilter from src.experiment.common import ( SELECTED_FACTORS, diff --git a/src/training/__init__.py b/src/training/__init__.py index 9e043ca..0d04219 100644 --- a/src/training/__init__.py +++ b/src/training/__init__.py @@ -53,7 +53,7 @@ from src.training.result_analyzer import ResultAnalyzer from src.training.tasks import BaseTask, RegressionTask, RankTask # 从 trainer_v2 导入新 Trainer(推荐) -from src.training.trainer_v2 import Trainer as TrainerV2 +from src.training.core.trainer_v2 import Trainer as TrainerV2 __all__ = [ # 基础抽象类 diff --git a/src/training/core/__init__.py b/src/training/core/__init__.py index 4aa568d..953918c 100644 --- a/src/training/core/__init__.py +++ b/src/training/core/__init__.py @@ -5,5 +5,6 @@ from src.training.core.stock_pool_manager import StockPoolManager from src.training.core.trainer import Trainer +from src.training.core.trainer_v2 import Trainer as TrainerV2 -__all__ = ["StockPoolManager", "Trainer"] +__all__ = ["StockPoolManager", "Trainer", "TrainerV2"] diff --git a/src/training/trainer_v2.py b/src/training/core/trainer_v2.py similarity index 100% rename from src/training/trainer_v2.py rename to src/training/core/trainer_v2.py diff --git a/tests/test_cyq_perf.py b/tests/test_cyq_perf.py new file mode 100644 index 0000000..57b6e3b --- /dev/null +++ b/tests/test_cyq_perf.py @@ -0,0 +1,278 @@ +"""Tests for cyq_perf API wrapper. + +Tests for src.data.api_wrappers.api_cyq_perf module. +""" + +import pytest +import pandas as pd +from unittest.mock import patch, MagicMock + +from src.data.api_wrappers.api_cyq_perf import ( + get_cyq_perf, + sync_cyq_perf, + preview_cyq_perf_sync, + CyqPerfSync, +) + + +class TestCyqPerf: + """Test suite for cyq_perf API wrapper.""" + + @patch("src.data.api_wrappers.api_cyq_perf.TushareClient") + def test_get_cyq_perf_by_stock(self, mock_client_class): + """Test fetching chip distribution data by stock code.""" + # Setup mock + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_client.query.return_value = pd.DataFrame( + { + "ts_code": ["000001.SZ"], + "trade_date": ["20240101"], + "his_low": [8.50], + "his_high": [12.00], + "cost_5pct": [8.80], + "cost_15pct": [9.20], + "cost_50pct": [10.00], + "cost_85pct": [10.80], + "cost_95pct": [11.20], + "weight_avg": [10.00], + "winner_rate": [5.50], + } + ) + + # Test + result = get_cyq_perf( + ts_code="000001.SZ", + start_date="20240101", + end_date="20240131", + ) + + # Assert + assert not result.empty + assert "ts_code" in result.columns + assert "trade_date" in result.columns + assert "cost_50pct" in result.columns + assert "winner_rate" in result.columns + assert result["ts_code"].iloc[0] == "000001.SZ" + mock_client.query.assert_called_once() + + # Verify parameters + call_args = mock_client.query.call_args + assert call_args[0][0] == "cyq_perf" + assert call_args[1]["ts_code"] == "000001.SZ" + assert call_args[1]["start_date"] == "20240101" + assert call_args[1]["end_date"] == "20240131" + + @patch("src.data.api_wrappers.api_cyq_perf.TushareClient") + def test_get_cyq_perf_with_shared_client(self, mock_client_class): + """Test fetching data with shared client for rate limiting.""" + # Setup mock for shared client + shared_client = MagicMock() + shared_client.query.return_value = pd.DataFrame( + { + "ts_code": ["600000.SH"], + "trade_date": ["20240115"], + "his_low": [5.00], + "his_high": [8.00], + "cost_5pct": [5.50], + "cost_15pct": [6.00], + "cost_50pct": [6.50], + "cost_85pct": [7.00], + "cost_95pct": [7.50], + "weight_avg": [6.50], + "winner_rate": [3.20], + } + ) + + # Test with shared client + result = get_cyq_perf( + ts_code="600000.SH", + start_date="20240101", + end_date="20240131", + client=shared_client, + ) + + # Assert shared client was used, not new instance + mock_client_class.assert_not_called() + shared_client.query.assert_called_once() + assert not result.empty + assert result["ts_code"].iloc[0] == "600000.SH" + + @patch("src.data.api_wrappers.api_cyq_perf.TushareClient") + def test_get_cyq_perf_empty_response(self, mock_client_class): + """Test handling empty response.""" + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_client.query.return_value = pd.DataFrame() + + result = get_cyq_perf( + ts_code="000001.SZ", + start_date="20240101", + end_date="20240131", + ) + + assert result.empty + + @patch("src.data.api_wrappers.api_cyq_perf.TushareClient") + def test_get_cyq_perf_date_column_rename(self, mock_client_class): + """Test that 'date' column is renamed to 'trade_date'.""" + mock_client = MagicMock() + mock_client_class.return_value = mock_client + # Return data with 'date' column instead of 'trade_date' + mock_client.query.return_value = pd.DataFrame( + { + "ts_code": ["000001.SZ"], + "date": ["20240101"], # Note: 'date' not 'trade_date' + "cost_50pct": [10.00], + "winner_rate": [5.50], + } + ) + + result = get_cyq_perf(ts_code="000001.SZ") + + # Assert 'date' was renamed to 'trade_date' + assert "trade_date" in result.columns + assert "date" not in result.columns + assert result["trade_date"].iloc[0] == "20240101" + + +class TestCyqPerfSync: + """Test suite for CyqPerfSync class.""" + + @patch("src.data.api_wrappers.api_cyq_perf.TushareClient") + @patch("src.data.api_wrappers.base_sync.Storage") + @patch("src.data.api_wrappers.base_sync.ThreadSafeStorage") + @patch("src.data.api_wrappers.base_sync.sync_trade_cal_cache") + @patch("src.data.api_wrappers.base_sync.sync_all_stocks") + def test_cyq_perf_sync_class_structure( + self, + mock_sync_stocks, + mock_sync_cal, + mock_storage_class, + mock_base_storage_class, + mock_client_class, + ): + """Test CyqPerfSync class structure and attributes.""" + # Verify class attributes + assert CyqPerfSync.table_name == "cyq_perf" + assert "ts_code" in CyqPerfSync.TABLE_SCHEMA + assert "trade_date" in CyqPerfSync.TABLE_SCHEMA + assert "cost_5pct" in CyqPerfSync.TABLE_SCHEMA + assert "cost_95pct" in CyqPerfSync.TABLE_SCHEMA + assert "winner_rate" in CyqPerfSync.TABLE_SCHEMA + assert CyqPerfSync.PRIMARY_KEY == ("ts_code", "trade_date") + + @patch("src.data.api_wrappers.api_cyq_perf.get_cyq_perf") + def test_fetch_single_stock(self, mock_get_cyq_perf): + """Test fetch_single_stock method.""" + mock_get_cyq_perf.return_value = pd.DataFrame( + { + "ts_code": ["000001.SZ", "000001.SZ"], + "trade_date": ["20240101", "20240102"], + "cost_50pct": [10.00, 10.10], + "winner_rate": [5.50, 5.60], + } + ) + + sync = CyqPerfSync() + # Mock the client to avoid real initialization + sync.client = MagicMock() + + result = sync.fetch_single_stock( + ts_code="000001.SZ", + start_date="20240101", + end_date="20240102", + ) + + assert not result.empty + assert len(result) == 2 + mock_get_cyq_perf.assert_called_once() + + +class TestSyncCyqPerf: + """Test suite for sync_cyq_perf convenience function.""" + + @patch("src.data.api_wrappers.api_cyq_perf.CyqPerfSync") + def test_sync_cyq_perf_calls_sync_all(self, mock_sync_class): + """Test that sync_cyq_perf calls sync_all on CyqPerfSync.""" + mock_sync_instance = MagicMock() + mock_sync_class.return_value = mock_sync_instance + mock_sync_instance.sync_all.return_value = { + "000001.SZ": pd.DataFrame({"ts_code": ["000001.SZ"]}) + } + + result = sync_cyq_perf() + + mock_sync_class.assert_called_once_with(max_workers=None) + mock_sync_instance.sync_all.assert_called_once() + assert isinstance(result, dict) + + @patch("src.data.api_wrappers.api_cyq_perf.CyqPerfSync") + def test_sync_cyq_perf_with_params(self, mock_sync_class): + """Test sync_cyq_perf with parameters.""" + mock_sync_instance = MagicMock() + mock_sync_class.return_value = mock_sync_instance + mock_sync_instance.sync_all.return_value = {} + + result = sync_cyq_perf( + force_full=True, + start_date="20240101", + end_date="20240131", + max_workers=20, + dry_run=True, + ) + + mock_sync_class.assert_called_once_with(max_workers=20) + mock_sync_instance.sync_all.assert_called_once_with( + force_full=True, + start_date="20240101", + end_date="20240131", + dry_run=True, + ) + + +class TestPreviewCyqPerfSync: + """Test suite for preview_cyq_perf_sync convenience function.""" + + @patch("src.data.api_wrappers.api_cyq_perf.CyqPerfSync") + def test_preview_cyq_perf_sync(self, mock_sync_class): + """Test preview_cyq_perf_sync function.""" + mock_sync_instance = MagicMock() + mock_sync_class.return_value = mock_sync_instance + mock_sync_instance.preview_sync.return_value = { + "sync_needed": True, + "stock_count": 5000, + "start_date": "20240101", + "end_date": "20240131", + "estimated_records": 100000, + "sample_data": pd.DataFrame(), + "mode": "incremental", + } + + result = preview_cyq_perf_sync() + + mock_sync_class.assert_called_once_with() + mock_sync_instance.preview_sync.assert_called_once() + assert result["sync_needed"] is True + assert result["stock_count"] == 5000 + + @patch("src.data.api_wrappers.api_cyq_perf.CyqPerfSync") + def test_preview_cyq_perf_sync_with_params(self, mock_sync_class): + """Test preview with custom parameters.""" + mock_sync_instance = MagicMock() + mock_sync_class.return_value = mock_sync_instance + mock_sync_instance.preview_sync.return_value = {} + + preview_cyq_perf_sync( + force_full=True, + start_date="20240101", + end_date="20240131", + sample_size=5, + ) + + mock_sync_instance.preview_sync.assert_called_once_with( + force_full=True, + start_date="20240101", + end_date="20240131", + sample_size=5, + )