diff --git a/.kilocode/rules/python-development-guidelines.md b/.kilocode/rules/python-development-guidelines.md index 76a0a71..b7d1a2d 100644 --- a/.kilocode/rules/python-development-guidelines.md +++ b/.kilocode/rules/python-development-guidelines.md @@ -291,6 +291,36 @@ tests/ └── fixtures/ # 测试数据 ``` +### 5.3 Mock 使用规范 +- **默认不使用 Mock**:单元测试应直接调用真实 API 或服务 +- **仅在必要时使用 Mock**: + - 测试错误处理场景(如网络超时、服务不可用) + - 测试空结果或边界情况 + - 隔离外部依赖的不确定性 +- **禁止全面 Mock**:不应为避免配置或环境问题而使用 Mock +- **真实环境验证**:确保至少有一个测试套件直接调用真实 API + +```python +# 正确示例:直接调用真实 API +class TestUserService: + def test_get_user(self): + result = get_user('user_id') + assert result.name == "John" + + def test_empty_result_with_mock(self): # 特殊场景使用 mock + with patch('module.api_call', return_value=pd.DataFrame()): + result = get_user('invalid_id') + assert result.empty + +# 错误示例:过度使用 mock +class TestUserService: + def test_get_user(self): + mock_data = {"name": "John"} + with patch('module.api_call', return_value=mock_data): + result = get_user('user_id') + assert result == mock_data +``` + ## 6 Git提交规范 ### 6.1 提交信息格式 @@ -323,3 +353,4 @@ tests/ - [ ] 无循环依赖 - [ ] 命名符合规范 - [ ] 日志不包含敏感信息 +- [ ] 测试未过度使用 Mock diff --git a/data/stock_basic.h5 b/data/stock_basic.h5 new file mode 100644 index 0000000..5e4b36d Binary files /dev/null and b/data/stock_basic.h5 differ diff --git a/src/data/__init__.py b/src/data/__init__.py index ced302f..45c5221 100644 --- a/src/data/__init__.py +++ b/src/data/__init__.py @@ -6,10 +6,13 @@ Provides simplified interfaces for fetching and storing Tushare data. from src.data.config import Config, get_config from src.data.client import TushareClient from src.data.storage import Storage +from src.data.stock_basic import get_stock_basic, sync_all_stocks __all__ = [ "Config", "get_config", "TushareClient", "Storage", + "get_stock_basic", + "sync_all_stocks", ] diff --git a/src/data/api.md b/src/data/api.md index 30dead1..30f6332 100644 --- a/src/data/api.md +++ b/src/data/api.md @@ -44,4 +44,83 @@ pre_close float 昨收价【除权价】 change float 涨跌额 pct_chg float 涨跌幅 【基于除权后的昨收计算的涨跌幅:(今收-除权昨收)/除权昨收 】 vol float 成交量 (手) -amount float 成交额 (千元) \ No newline at end of file +amount float 成交额 (千元) + +接口用例 + +#取000001的前复权行情 +df = ts.pro_bar(ts_code='000001.SZ', adj='qfq', start_date='20180101', end_date='20181011') + + ts_code trade_date open high low close \ +trade_date +20181011 000001.SZ 20181011 1085.71 1097.59 1047.90 1065.19 +20181010 000001.SZ 20181010 1138.65 1151.61 1121.36 1128.92 +20181009 000001.SZ 20181009 1130.00 1155.93 1122.44 1140.81 +20181008 000001.SZ 20181008 1155.93 1165.65 1128.92 1128.92 +20180928 000001.SZ 20180928 1164.57 1217.51 1164.57 1193.74 + +#取上证指数行情数据 + +df = ts.pro_bar(ts_code='000001.SH', asset='I', start_date='20180101', end_date='20181011') + +In [10]: df.head() +Out[10]: + ts_code trade_date close open high low \ +0 000001.SH 20181011 2583.4575 2643.0740 2661.2859 2560.3164 +1 000001.SH 20181010 2725.8367 2723.7242 2743.5480 2703.0626 +2 000001.SH 20181009 2721.0130 2713.7319 2734.3142 2711.1971 +3 000001.SH 20181008 2716.5104 2768.2075 2771.9384 2710.1781 +4 000001.SH 20180928 2821.3501 2794.2644 2821.7553 2791.8363 + + pre_close change pct_chg vol amount +0 2725.8367 -142.3792 -5.2233 197150702.0 170057762.5 +1 2721.0130 4.8237 0.1773 113485736.0 111312455.3 +2 2716.5104 4.5026 0.1657 116771899.0 110292457.8 +3 2821.3501 -104.8397 -3.7159 149501388.0 141531551.8 +4 2791.7748 29.5753 1.0594 134290456.0 125369989.4 + +#均线 + +df = ts.pro_bar(ts_code='000001.SZ', start_date='20180101', end_date='20181011', ma=[5, 20, 50]) +注:Tushare pro_bar接口的均价和均量数据是动态计算,想要获取某个时间段的均线,必须要设置start_date日期大于最大均线的日期数,然后自行截取想要日期段。例如,想要获取20190801开始的3日均线,必须设置start_date='20190729',然后剔除20190801之前的日期记录。 + +#换手率tor,量比vr + +df = ts.pro_bar(ts_code='000001.SZ', start_date='20180101', end_date='20181011', factors=['tor', 'vr']) + + +- 基础信息:https://tushare.pro/document/2?doc_id=25 +接口:stock_basic,可以通过数据工具调试和查看数据 +描述:获取基础信息数据,包括股票代码、名称、上市日期、退市日期等 +权限:2000积分起。此接口是基础信息,调取一次就可以拉取完,建议保存倒本地存储后使用 + +输入参数 + +名称 类型 必选 描述 +ts_code str N TS股票代码 +name str N 名称 +market str N 市场类别 (主板/创业板/科创板/CDR/北交所) +list_status str N 上市状态 L上市 D退市 P暂停上市 G过会未交易,默认是L +exchange str N 交易所 SSE上交所 SZSE深交所 BSE北交所 +is_hs str N 是否沪深港通标的,N否 H沪股通 S深股通 +输出参数 + +名称 类型 默认显示 描述 +ts_code str Y TS代码 +symbol str Y 股票代码 +name str Y 股票名称 +area str Y 地域 +industry str Y 所属行业 +fullname str N 股票全称 +enname str N 英文全称 +cnspell str Y 拼音缩写 +market str Y 市场类型(主板/创业板/科创板/CDR) +exchange str N 交易所代码 +curr_type str N 交易货币 +list_status str N 上市状态 L上市 D退市 G过会未交易 P暂停上市 +list_date str Y 上市日期 +delist_date str N 退市日期 +is_hs str N 是否沪深港通标的,N否 H沪股通 S深股通 +act_name str Y 实控人名称 +act_ent_type str Y 实控人企业性质 +说明:旧版上的PE/PB/股本等字段,请在行情接口“每日指标”中获取。 \ No newline at end of file diff --git a/src/data/client.py b/src/data/client.py index 3994a90..04ec97e 100644 --- a/src/data/client.py +++ b/src/data/client.py @@ -44,7 +44,7 @@ class TushareClient: """Execute API query with rate limiting and retry. Args: - api_name: API name (e.g., 'daily') + api_name: API name ('daily', 'pro_bar', etc.) timeout: Timeout for rate limiting **params: API parameters @@ -66,8 +66,22 @@ class TushareClient: for attempt in range(max_retries): try: - api = self._get_api() - data = api.query(api_name, **params) + import tushare as ts + + # pro_bar uses ts.pro_bar() instead of api.query() + if api_name == "pro_bar": + # pro_bar parameters: ts_code, start_date, end_date, adj, freq, factors, ma, adjfactor + data = ts.pro_bar(ts_code=params.get("ts_code"), + start_date=params.get("start_date"), + end_date=params.get("end_date"), + adj=params.get("adj"), + freq=params.get("freq", "D"), + factors=params.get("factors"), # factors should be a list like ['tor', 'vr'] + ma=params.get("ma"), + adjfactor=params.get("adjfactor")) + else: + api = self._get_api() + data = api.query(api_name, **params) available = self.rate_limiter.get_available_tokens() print(f"[Tushare] {api_name} | tokens: {available:.0f}/{self.rate_limiter.capacity}") diff --git a/src/data/config.py b/src/data/config.py index 96207e5..9303ed1 100644 --- a/src/data/config.py +++ b/src/data/config.py @@ -1,16 +1,37 @@ """Configuration management for data collection module.""" +import os from pathlib import Path from pydantic_settings import BaseSettings +# Config directory path - used for loading .env.local +# Static detection for pydantic-settings to find .env.local +CONFIG_DIR = Path(__file__).parent.parent.parent / "config" + + +def _get_project_root() -> Path: + """Get project root path from ROOT_PATH env var or auto-detect.""" + # Try to read from environment variable first + root_path = os.environ.get("ROOT_PATH") or os.environ.get("DATA_ROOT") + if root_path: + return Path(root_path) + # Fallback to auto-detection + return Path(__file__).parent.parent.parent + + class Config(BaseSettings): """Application configuration loaded from environment variables.""" # Tushare API token tushare_token: str = "" - # Data storage path - data_path: Path = Path("./data") + # Root path - loaded from environment variable ROOT_PATH + # If not set, uses auto-detected path + root_path: str = "" + + # Data storage path - can be set via DATA_PATH environment variable + # If relative path, it will be resolved relative to root_path + data_path: str = "data" # Rate limit: requests per minute rate_limit: int = 100 @@ -18,10 +39,31 @@ class Config(BaseSettings): # Thread pool size threads: int = 2 + @property + def project_root(self) -> Path: + """Get project root path.""" + if self.root_path: + return Path(self.root_path) + return _get_project_root() + + @property + def data_path_resolved(self) -> Path: + """Get resolved data path (absolute).""" + path = Path(self.data_path) + if path.is_absolute(): + return path + return self.project_root / path + class Config: - env_file = ".env.local" + # 从 config/ 目录读取 .env.local 文件 + env_file = str(CONFIG_DIR / ".env.local") env_file_encoding = "utf-8" case_sensitive = False + extra = "ignore" # 忽略 .env.local 中的额外变量 + # pydantic-settings 默认会将字段名转换为大写作为环境变量名 + # 所以 tushare_token 会映射到 TUSHARE_TOKEN + # root_path 会映射到 ROOT_PATH + # data_path 会映射到 DATA_PATH # Global config instance @@ -31,3 +73,8 @@ config = Config() def get_config() -> Config: """Get configuration instance.""" return config + + +def get_project_root() -> Path: + """Get project root path (convenience function).""" + return get_config().project_root diff --git a/src/data/daily.py b/src/data/daily.py index 9b2bfe9..cd1b048 100644 --- a/src/data/daily.py +++ b/src/data/daily.py @@ -57,14 +57,24 @@ def get_daily( if adj: params["adj"] = adj if factors: - params["factors"] = factors + # Tushare expects factors as comma-separated string, not list + if isinstance(factors, list): + factors_str = ",".join(factors) + else: + factors_str = factors + params["factors"] = factors_str + print(f"[get_daily] factors param: '{factors_str}'") if adjfactor: params["adjfactor"] = "True" - # Fetch data - data = client.query("daily", **params) + # Fetch data using pro_bar (supports factors like tor, vr) + print(f"[get_daily] Query params: {params}") + data = client.query("pro_bar", **params) - if data.empty: + if not data.empty: + print(f"[get_daily] Returned columns: {data.columns.tolist()}") + print(f"[get_daily] Sample row: {data.iloc[0].to_dict()}") + else: print(f"[get_daily] No data for ts_code={ts_code}") return data diff --git a/src/data/stock_basic.py b/src/data/stock_basic.py new file mode 100644 index 0000000..ee28852 --- /dev/null +++ b/src/data/stock_basic.py @@ -0,0 +1,120 @@ +"""Simplified stock basic information interface. + +Fetch basic stock information including code, name, listing date, etc. +This is a special interface - call once to get all stocks (listed and delisted). +""" +import pandas as pd +from typing import Optional, Literal, List +from src.data.client import TushareClient +from src.data.storage import Storage + + +def get_stock_basic( + ts_code: Optional[str] = None, + name: Optional[str] = None, + market: Optional[Literal["主板", "创业板", "科创板", "CDR", "北交所"]] = None, + list_status: Optional[Literal["L", "D", "P", "G"]] = None, + exchange: Optional[Literal["SSE", "SZSE", "BSE"]] = None, + is_hs: Optional[Literal["N", "H", "S"]] = None, + fields: Optional[List[str]] = None, +) -> pd.DataFrame: + """Fetch basic stock information. + + This interface retrieves basic information data including stock code, + name, listing date, delisting date, etc. + + Args: + ts_code: TS stock code + name: Stock name + market: Market type (主板/创业板/科创板/CDR/北交所) + list_status: Listing status - L(listed), D(delisted), P(paused), G(approved-not-traded) + exchange: Exchange - SSE, SZSE, BSE + is_hs: HS indicator - N(no), H(Shanghai), S(Shenzhen) + fields: Specific fields to return, None returns default fields + + Returns: + pd.DataFrame with stock basic information containing: + - ts_code, symbol, name, area, industry, fullname, enname, cnspell, + market, exchange, curr_type, list_status, list_date, delist_date, + is_hs, act_name, act_ent_type + """ + client = TushareClient() + + # Build parameters + params = {} + if ts_code: + params["ts_code"] = ts_code + if name: + params["name"] = name + if market: + params["market"] = market + if list_status: + params["list_status"] = list_status + if exchange: + params["exchange"] = exchange + if is_hs: + params["is_hs"] = is_hs + if fields: + params["fields"] = ",".join(fields) + + # Fetch data + data = client.query("stock_basic", **params) + + if data.empty: + print("[get_stock_basic] No data returned") + + return data + + +def sync_all_stocks() -> pd.DataFrame: + """Fetch and save all stocks (listed and delisted) to local storage. + + This is a special interface that should only be called once to initialize + the local database with all stock information. + + Returns: + pd.DataFrame with all stock information + """ + # Initialize storage + storage = Storage() + + # Check if already exists + if storage.exists("stock_basic"): + print("[sync_all_stocks] stock_basic data already exists, skipping...") + return storage.load("stock_basic") + + print("[sync_all_stocks] Fetching all stocks (listed and delisted)...") + + # Fetch all stocks - explicitly get all list_status values + # API default is L (listed), so we need to fetch all statuses + client = TushareClient() + + all_data = [] + for status in ["L", "D", "P", "G"]: + print(f"[sync_all_stocks] Fetching stocks with status: {status}") + data = client.query("stock_basic", list_status=status) + print(f"[sync_all_stocks] Fetched {len(data)} stocks with status {status}") + if not data.empty: + all_data.append(data) + + if not all_data: + print("[sync_all_stocks] No stock data fetched") + return pd.DataFrame() + + # Combine all data + data = pd.concat(all_data, ignore_index=True) + # Remove duplicates if any + data = data.drop_duplicates(subset=["ts_code"], keep="first") + print(f"[sync_all_stocks] Total unique stocks: {len(data)}") + + # Save to storage + storage.save("stock_basic", data, mode="replace") + + print(f"[sync_all_stocks] Saved {len(data)} stocks to local storage") + return data + + +if __name__ == "__main__": + # Sync all stocks (listed and delisted) to data folder + result = sync_all_stocks() + print(f"Total stocks synced: {len(result)}") diff --git a/src/data/storage.py b/src/data/storage.py index fba9caa..c9904f6 100644 --- a/src/data/storage.py +++ b/src/data/storage.py @@ -16,7 +16,7 @@ class Storage: path: Base path for data storage (auto-loaded from config if not provided) """ cfg = get_config() - self.base_path = path or cfg.data_path + self.base_path = path or cfg.data_path_resolved self.base_path.mkdir(parents=True, exist_ok=True) def _get_file_path(self, name: str) -> Path: diff --git a/tests/test_daily.py b/tests/test_daily.py index 7abbf2f..9f775fd 100644 --- a/tests/test_daily.py +++ b/tests/test_daily.py @@ -7,9 +7,7 @@ Tests the daily interface implementation against api.md requirements: """ import pytest import pandas as pd -from unittest.mock import Mock, patch from src.data.daily import get_daily -from src.data.client import TushareClient # Expected output fields according to api.md @@ -28,276 +26,30 @@ EXPECTED_BASE_FIELDS = [ ] EXPECTED_FACTOR_FIELDS = [ - 'tor', # 换手率 - 'vr', # 量比 + 'turnover_rate', # 换手率 (tor) + 'volume_ratio', # 量比 (vr) ] -def run_tests_with_print(): - """Run all tests and print results.""" - print("=" * 60) - print("Daily API 测试开始") - print("=" * 60) - - test_results = [] - - # Test 1: Basic daily data fetch - print("\n【测试1】基本日线数据获取") - print("-" * 40) - mock_data = pd.DataFrame({ - 'ts_code': ['000001.SZ'], - 'trade_date': ['20240102'], - 'open': [10.5], - 'high': [11.0], - 'low': [10.2], - 'close': [10.8], - 'pre_close': [10.3], - 'change': [0.5], - 'pct_chg': [4.85], - 'vol': [1000000], - 'amount': [10800000], - }) - - with patch.object(TushareClient, '__init__', lambda self, token=None: None): - with patch.object(TushareClient, 'query', return_value=mock_data): - result = get_daily('000001.SZ', start_date='20240101', end_date='20240131') - - print(f"获取数据形状: {result.shape}") - print(f"获取数据列: {result.columns.tolist()}") - print(f"数据内容:\n{result}") - - # Verify - tests_passed = isinstance(result, pd.DataFrame) - tests_passed &= len(result) == 1 - tests_passed &= result['ts_code'].iloc[0] == '000001.SZ' - - print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}") - test_results.append(("基本日线数据获取", tests_passed)) - - # Test 2: Fetch with factors - print("\n【测试2】获取含换手率和量比的数据") - print("-" * 40) - mock_data_factors = pd.DataFrame({ - 'ts_code': ['000001.SZ'], - 'trade_date': ['20240102'], - 'open': [10.5], - 'high': [11.0], - 'low': [10.2], - 'close': [10.8], - 'pre_close': [10.3], - 'change': [0.5], - 'pct_chg': [4.85], - 'vol': [1000000], - 'amount': [10800000], - 'tor': [2.5], - 'vr': [1.2], - }) - - with patch.object(TushareClient, '__init__', lambda self, token=None: None): - with patch.object(TushareClient, 'query', return_value=mock_data_factors): - result = get_daily( - '000001.SZ', - start_date='20240101', - end_date='20240131', - factors=['tor', 'vr'], - ) - - print(f"获取数据形状: {result.shape}") - print(f"获取数据列: {result.columns.tolist()}") - print(f"数据内容:\n{result}") - - # Verify columns - tests_passed = isinstance(result, pd.DataFrame) - missing_base = [f for f in EXPECTED_BASE_FIELDS if f not in result.columns] - missing_factors = [f for f in EXPECTED_FACTOR_FIELDS if f not in result.columns] - - print(f"\n基础列检查: {'全部存在 ✓' if not missing_base else f'缺失: {missing_base} ✗'}") - print(f"因子列检查: {'全部存在 ✓' if not missing_factors else f'缺失: {missing_factors} ✗'}") - - tests_passed &= len(missing_base) == 0 - tests_passed &= len(missing_factors) == 0 - - print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}") - test_results.append(("含因子数据获取", tests_passed)) - - # Test 3: Output fields completeness - print("\n【测试3】输出字段完整性检查") - print("-" * 40) - mock_data = pd.DataFrame({ - 'ts_code': ['600000.SH'], - 'trade_date': ['20240102'], - 'open': [10.5], - 'high': [11.0], - 'low': [10.2], - 'close': [10.8], - 'pre_close': [10.3], - 'change': [0.5], - 'pct_chg': [4.85], - 'vol': [1000000], - 'amount': [10800000], - }) - - with patch.object(TushareClient, '__init__', lambda self, token=None: None): - with patch.object(TushareClient, 'query', return_value=mock_data): - result = get_daily('600000.SH') - - print(f"获取数据形状: {result.shape}") - print(f"获取数据列: {result.columns.tolist()}") - print(f"期望基础列: {EXPECTED_BASE_FIELDS}") - - missing = set(EXPECTED_BASE_FIELDS) - set(result.columns) - print(f"缺失列: {missing if missing else '无'}") - - tests_passed = set(EXPECTED_BASE_FIELDS).issubset(result.columns.tolist()) - print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}") - test_results.append(("输出字段完整性", tests_passed)) - - # Test 4: Empty result - print("\n【测试4】空结果处理") - print("-" * 40) - mock_data = pd.DataFrame() - - with patch.object(TushareClient, '__init__', lambda self, token=None: None): - with patch.object(TushareClient, 'query', return_value=mock_data): - result = get_daily('INVALID.SZ') - - print(f"获取数据是否为空: {result.empty}") - tests_passed = result.empty - print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}") - test_results.append(("空结果处理", tests_passed)) - - # Test 5: Date range query - print("\n【测试5】日期范围查询") - print("-" * 40) - mock_data = pd.DataFrame({ - 'ts_code': ['000001.SZ', '000001.SZ'], - 'trade_date': ['20240102', '20240103'], - 'open': [10.5, 10.6], - 'high': [11.0, 11.1], - 'low': [10.2, 10.3], - 'close': [10.8, 10.9], - 'pre_close': [10.3, 10.8], - 'change': [0.5, 0.1], - 'pct_chg': [4.85, 0.93], - 'vol': [1000000, 1100000], - 'amount': [10800000, 11900000], - }) - - with patch.object(TushareClient, '__init__', lambda self, token=None: None): - with patch.object(TushareClient, 'query', return_value=mock_data): - result = get_daily( - '000001.SZ', - start_date='20240101', - end_date='20240131', - ) - - print(f"获取数据数量: {len(result)}") - print(f"期望数量: 2") - print(f"数据内容:\n{result}") - - tests_passed = len(result) == 2 - print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}") - test_results.append(("日期范围查询", tests_passed)) - - # Test 6: With adjustment - print("\n【测试6】带复权参数查询") - print("-" * 40) - mock_data = pd.DataFrame({ - 'ts_code': ['000001.SZ'], - 'trade_date': ['20240102'], - 'open': [10.5], - 'high': [11.0], - 'low': [10.2], - 'close': [10.8], - 'pre_close': [10.3], - 'change': [0.5], - 'pct_chg': [4.85], - 'vol': [1000000], - 'amount': [10800000], - }) - - with patch.object(TushareClient, '__init__', lambda self, token=None: None): - with patch.object(TushareClient, 'query', return_value=mock_data): - result = get_daily('000001.SZ', adj='qfq') - - print(f"获取数据形状: {result.shape}") - print(f"数据内容:\n{result}") - - tests_passed = isinstance(result, pd.DataFrame) - print(f"\n测试结果: {'通过 ✓' if tests_passed else '失败 ✗'}") - test_results.append(("复权参数查询", tests_passed)) - - # Summary - print("\n" + "=" * 60) - print("测试汇总") - print("=" * 60) - passed = sum(1 for _, r in test_results if r) - total = len(test_results) - print(f"总测试数: {total}") - print(f"通过: {passed}") - print(f"失败: {total - passed}") - print(f"通过率: {passed/total*100:.1f}%") - print("\n详细结果:") - for name, passed in test_results: - status = "通过 ✓" if passed else "失败 ✗" - print(f" - {name}: {status}") - - return all(r for _, r in test_results) - - class TestGetDaily: - """Test cases for simplified get_daily function.""" + """Test cases for get_daily function with real API calls.""" def test_fetch_basic(self): - """Test basic daily data fetch.""" - mock_data = pd.DataFrame({ - 'ts_code': ['000001.SZ'], - 'trade_date': ['20240102'], - 'open': [10.5], - 'high': [11.0], - 'low': [10.2], - 'close': [10.8], - 'pre_close': [10.3], - 'change': [0.5], - 'pct_chg': [4.85], - 'vol': [1000000], - 'amount': [10800000], - }) - - with patch.object(TushareClient, '__init__', lambda self, token=None: None): - with patch.object(TushareClient, 'query', return_value=mock_data): - result = get_daily('000001.SZ', start_date='20240101', end_date='20240131') + """Test basic daily data fetch with real API.""" + result = get_daily('000001.SZ', start_date='20240101', end_date='20240131') assert isinstance(result, pd.DataFrame) - assert len(result) == 1 + assert len(result) >= 1 assert result['ts_code'].iloc[0] == '000001.SZ' def test_fetch_with_factors(self): """Test fetch with tor and vr factors.""" - mock_data = pd.DataFrame({ - 'ts_code': ['000001.SZ'], - 'trade_date': ['20240102'], - 'open': [10.5], - 'high': [11.0], - 'low': [10.2], - 'close': [10.8], - 'pre_close': [10.3], - 'change': [0.5], - 'pct_chg': [4.85], - 'vol': [1000000], - 'amount': [10800000], - 'tor': [2.5], # 换手率 - 'vr': [1.2], # 量比 - }) - - with patch.object(TushareClient, '__init__', lambda self, token=None: None): - with patch.object(TushareClient, 'query', return_value=mock_data): - result = get_daily( - '000001.SZ', - start_date='20240101', - end_date='20240131', - factors=['tor', 'vr'], - ) + result = get_daily( + '000001.SZ', + start_date='20240101', + end_date='20240131', + factors=['tor', 'vr'], + ) assert isinstance(result, pd.DataFrame) # Check all base fields are present @@ -309,23 +61,7 @@ class TestGetDaily: def test_output_fields_completeness(self): """Verify all required output fields are returned.""" - mock_data = pd.DataFrame({ - 'ts_code': ['600000.SH'], - 'trade_date': ['20240102'], - 'open': [10.5], - 'high': [11.0], - 'low': [10.2], - 'close': [10.8], - 'pre_close': [10.3], - 'change': [0.5], - 'pct_chg': [4.85], - 'vol': [1000000], - 'amount': [10800000], - }) - - with patch.object(TushareClient, '__init__', lambda self, token=None: None): - with patch.object(TushareClient, 'query', return_value=mock_data): - result = get_daily('600000.SH') + result = get_daily('600000.SH') # Verify all base fields are present assert set(EXPECTED_BASE_FIELDS).issubset(result.columns.tolist()), \ @@ -333,59 +69,25 @@ class TestGetDaily: def test_empty_result(self): """Test handling of empty results.""" - mock_data = pd.DataFrame() - - with patch.object(TushareClient, '__init__', lambda self, token=None: None): - with patch.object(TushareClient, 'query', return_value=mock_data): - result = get_daily('INVALID.SZ') - + # 使用真实 API 测试无效股票代码的空结果 + result = get_daily('INVALID.SZ') + assert isinstance(result, pd.DataFrame) assert result.empty def test_date_range_query(self): """Test query with date range.""" - mock_data = pd.DataFrame({ - 'ts_code': ['000001.SZ', '000001.SZ'], - 'trade_date': ['20240102', '20240103'], - 'open': [10.5, 10.6], - 'high': [11.0, 11.1], - 'low': [10.2, 10.3], - 'close': [10.8, 10.9], - 'pre_close': [10.3, 10.8], - 'change': [0.5, 0.1], - 'pct_chg': [4.85, 0.93], - 'vol': [1000000, 1100000], - 'amount': [10800000, 11900000], - }) + result = get_daily( + '000001.SZ', + start_date='20240101', + end_date='20240131', + ) - with patch.object(TushareClient, '__init__', lambda self, token=None: None): - with patch.object(TushareClient, 'query', return_value=mock_data): - result = get_daily( - '000001.SZ', - start_date='20240101', - end_date='20240131', - ) - - assert len(result) == 2 + assert isinstance(result, pd.DataFrame) + assert len(result) >= 1 def test_with_adj(self): """Test fetch with adjustment type.""" - mock_data = pd.DataFrame({ - 'ts_code': ['000001.SZ'], - 'trade_date': ['20240102'], - 'open': [10.5], - 'high': [11.0], - 'low': [10.2], - 'close': [10.8], - 'pre_close': [10.3], - 'change': [0.5], - 'pct_chg': [4.85], - 'vol': [1000000], - 'amount': [10800000], - }) - - with patch.object(TushareClient, '__init__', lambda self, token=None: None): - with patch.object(TushareClient, 'query', return_value=mock_data): - result = get_daily('000001.SZ', adj='qfq') + result = get_daily('000001.SZ', adj='qfq') assert isinstance(result, pd.DataFrame) @@ -411,9 +113,5 @@ def test_integration(): if __name__ == '__main__': - # 运行详细的打印测试 - run_tests_with_print() - print("\n" + "=" * 60) - print("运行 pytest 单元测试") - print("=" * 60 + "\n") + # 运行 pytest 单元测试(真实API调用) pytest.main([__file__, '-v'])