refactor: migrate storage layer to DuckDB + module restructuring

- Storage layer rework: HDF5 → DuckDB (UPSERT pattern, thread-safe storage)
- Sync class migration: DataSync moved from sync.py into api_daily.py (separation of concerns)
- Model module rename: src/models → src/pipeline (clearer naming)
- New factor modules: factors/momentum (MA, return rank), factors/financial
- New API wrappers: api_namechange, api_bak_basic
- New training entry point: training module (main.py, pipeline configuration)
- Utility consolidation: get_today_date and friends moved to utils.py
- Docs: AGENTS.md gains an architecture change history
3  .gitignore  (vendored)
@@ -75,3 +75,6 @@ temp/

# Data directory (track the directory, ignore its contents)
data/*

# AI Agent working directory
/.sisyphus/
117  AGENTS.md
@@ -2,6 +2,15 @@

A-share quantitative investment framework - a Python project for quantitative stock analysis.

## Communication Language Requirement

**⚠️ Mandatory: all communication and reasoning must be in Chinese.**

- All exchanges with the AI agent must be in Chinese
- Comments and docstrings in code must be in Chinese
- Thinking or communicating in English is prohibited

## Build / Lint / Test Commands

**⚠️ Important: this project mandates uv as the Python package manager and runner. Direct use of `python` or `pip` is prohibited.**
@@ -67,25 +76,69 @@ uv run pytest tests/test_sync.py # ✅ correct

```
ProStock/
├── src/                              # Source code
│   ├── config/                       # Global configuration
│   │   ├── __init__.py
│   │   └── settings.py               # Application settings (pydantic-settings)
│   │
│   ├── data/                         # Data fetching and storage
│   │   ├── api_wrappers/             # Tushare API wrappers
│   │   │   ├── API_INTERFACE_SPEC.md # Interface specification
│   │   │   ├── api.md                # API interface definitions
│   │   │   ├── api_daily.py          # Daily market data
│   │   │   ├── api_stock_basic.py    # Stock basic information
│   │   │   ├── api_trade_cal.py      # Trading calendar
│   │   │   └── __init__.py
│   │   ├── __init__.py
│   │   ├── client.py                 # Tushare API client (rate limited)
│   │   ├── config.py                 # Data-module configuration
│   │   ├── db_inspector.py           # Database inspection tool
│   │   ├── db_manager.py             # DuckDB table management and sync
│   │   ├── rate_limiter.py           # Token-bucket rate limiter
│   │   ├── storage.py                # Storage core
│   │   └── sync.py                   # Sync orchestration
│   │
│   ├── factors/                      # Factor computation framework
│   │   ├── __init__.py
│   │   ├── base.py                   # Factor base classes (cross-sectional / time-series)
│   │   ├── composite.py              # Composite factors and scalar operations
│   │   ├── data_loader.py            # Data loader
│   │   ├── data_spec.py              # Data spec definitions
│   │   ├── engine.py                 # Factor execution engine
│   │   ├── momentum/                 # Momentum factors
│   │   │   ├── __init__.py
│   │   │   ├── ma.py                 # Moving average
│   │   │   └── return_rank.py        # Return rank
│   │   └── financial/                # Financial factors
│   │       └── __init__.py
│   │
│   ├── pipeline/                     # Model training pipeline
│   │   ├── __init__.py
│   │   ├── pipeline.py               # Processing pipeline
│   │   ├── registry.py               # Plugin registry
│   │   ├── core/                     # Core abstractions
│   │   │   ├── __init__.py
│   │   │   ├── base.py               # Base class definitions
│   │   │   └── splitter.py           # Time-series split strategies
│   │   ├── models/                   # Model implementations
│   │   │   ├── __init__.py
│   │   │   └── models.py             # LightGBM, CatBoost, etc.
│   │   └── processors/               # Data processors
│   │       ├── __init__.py
│   │       └── processors.py         # Standardization, winsorization, neutralization, etc.
│   │
│   └── training/                     # Training entry point
│       ├── __init__.py
│       ├── main.py                   # Training main program
│       ├── pipeline.py               # Training pipeline configuration
│       └── output/                   # Training output
│           └── top_stocks.tsv        # Recommended stocks
│
├── tests/                            # Tests
│   ├── test_sync.py
│   └── test_daily.py
├── config/                           # Config files
│   └── .env.local                    # Environment variables (not in git)
├── data/                             # Data storage (DuckDB)
├── docs/                             # Documentation
├── pyproject.toml                    # Project config
└── README.md
```
@@ -182,10 +235,10 @@ except Exception as e:
- Use `@lru_cache()` for configuration singletons

### Data Storage
- Persist with **DuckDB**, an embedded OLAP database
- Stored under `data/` (configurable via the `DATA_PATH` environment variable)
- Handle duplicate data with the UPSERT pattern (`INSERT OR REPLACE`)
- In multi-threaded scenarios use the `ThreadSafeStorage.queue_save()` + `flush()` pattern

### Threading and Concurrency
- Use `ThreadPoolExecutor` for I/O-bound tasks (API calls)
@@ -240,3 +293,39 @@ uv run python -c "from src.data.sync import sync_all; sync_all(force_full=True)"

# Custom worker count
uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)"
```

## Architecture Change History

### v2.0 (2026-02-23) - Major Update

#### Storage Layer Rework
**Change**: migrated from HDF5 to DuckDB
**Reason**: DuckDB offers better query performance, SQL pushdown, and concurrency support
**Impact**: all data tables are now stored in DuckDB; legacy HDF5 files can be migrated manually

#### Sync Class Migration
**Change**: the `DataSync` class moved from `sync.py` to `api_daily.py`
**Reason**: separation of concerns - each API file owns its sync logic
**Impact**:
- `sync.py` remains as the scheduling hub
- `api_daily.py` contains the `DailySync` class and the `sync_daily` function

#### New Modules
**pipeline**: machine-learning pipeline components (processors, models, split strategies)
**training**: training entry point
**factors/momentum**: momentum factors (MA, return rank)
**factors/financial**: financial factor framework
**data/utils.py**: centralized date utilities

#### New API Wrappers
`api_namechange.py`: stock name-change history (manual sync)
`api_bak_basic.py`: historical stock list

#### Utility Consolidation
`get_today_date()`, `get_next_date()`, `DEFAULT_START_DATE`, etc. are now managed in `src/data/utils.py`
Other modules should import them from `utils.py` instead of redefining them

### v1.x (historical)

Initial version, HDF5 storage
Sync logic centralized in `sync.py`
@@ -1,12 +1,26 @@

# ProStock Data Interface Wrapper Specification

## 1. Overview

This document defines the standard for wrapping new Tushare API interfaces. All non-special interfaces must follow it to ensure:
- Uniform code style
- Automatic sync support
- Consistent incremental-update logic
- Reduced storage write pressure
- Type safety (mandatory type hints)

### 1.1 Tech Stack

- **Storage**: DuckDB (high-performance embedded OLAP database)
- **Data format**: Pandas DataFrame / Polars DataFrame
- **Rate limiting**: token-bucket algorithm (TokenBucketRateLimiter)
- **Concurrency**: ThreadPoolExecutor multithreading
- **Type system**: Python 3.10+ type hints
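The spec names a `TokenBucketRateLimiter` but does not show its implementation, so here is a minimal, self-contained sketch of the algorithm (class name from the spec; the `acquire` method, constructor parameters, and refill logic are assumptions):

```python
import threading
import time


class TokenBucketRateLimiter:
    """Minimal token bucket: refill at `rate` tokens/sec, burst up to `capacity`."""

    def __init__(self, rate: float, capacity: int) -> None:
        self.rate = rate
        self.capacity = capacity
        self.tokens = float(capacity)   # start with a full bucket
        self.last = time.monotonic()
        self.lock = threading.Lock()

    def acquire(self) -> None:
        """Block until one token is available, then consume it."""
        while True:
            with self.lock:
                now = time.monotonic()
                # Refill proportionally to elapsed time, capped at capacity
                self.tokens = min(self.capacity,
                                  self.tokens + (now - self.last) * self.rate)
                self.last = now
                if self.tokens >= 1:
                    self.tokens -= 1
                    return
            time.sleep(1.0 / self.rate)  # wait roughly one token's worth


limiter = TokenBucketRateLimiter(rate=100, capacity=10)
start = time.monotonic()
for _ in range(15):   # 10 burst tokens, then throttled at ~100 tokens/sec
    limiter.acquire()
elapsed = time.monotonic() - start
```

The burst of 10 calls passes immediately; the remaining 5 are paced by the refill rate, which is the behavior the `TushareClient` relies on to stay under Tushare's per-minute quotas.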
### 1.2 Automation Support

The project provides a `prostock-api-interface` Skill to automate interface wrapping. After defining an interface in `api.md`, invoking the Skill generates:
- the data module file (`src/data/api_wrappers/api_{data_type}.py`)
- database table management configuration
- the test file (`tests/test_{data_type}.py`)

## 2. Interface Categories
@@ -14,37 +28,41 @@

The following interfaces have standalone sync logic and do not participate in the automatic sync mechanism:

| Interface type | File | Notes |
|---------|--------|------|
| Trading calendar | `api_trade_cal.py` | Global data, fetched by date range, cached in HDF5 |
| Stock basic info | `api_stock_basic.py` | One-off full fetch, stored as CSV |
| Auxiliary data | `api_industry`, `api_concept` | Low-frequency updates, managed independently |

### 2.2 Standard Interfaces (must follow this spec)

All factor, market, and financial data fetched per stock or per date must follow this spec:

- Fetch by date: **preferred**, supports market-wide batch fetching
- Fetch by stock: only when the API does not support fetching by date

## 3. File Structure Requirements
### 3.1 File Naming

```
api_{data_type}.py
```

- Examples: `api_daily.py`, `api_moneyflow.py`, `api_limit_list.py`
- **Must** start with the `api_` prefix
- Use lowercase letters and underscores

### 3.2 File Location

All interface files must live under `src/data/api_wrappers/`.

### 3.3 Export Requirements

New interfaces must be exported from `src/data/api_wrappers/__init__.py`:
```python
from src.data.api_wrappers.api_{data_type} import get_{data_type}

__all__ = [
    # ... other exports ...
    "get_{data_type}",
]
```
@@ -59,7 +77,7 @@ __all__ = [

#### 4.1.1 Fetch-by-Date Interfaces (preferred)

Applicable to: limit-up/down lists, dragon-tiger (龙虎榜) lists, chip distribution, daily indicators, etc.

**Function signature requirements**:

@@ -77,6 +95,7 @@ def get_{data_type}(
- Prefer `trade_date` to fetch one day of market-wide data
- Support `start_date + end_date` for range queries
- `ts_code` is an optional filter parameter
- **Performance advantage**: one API call covers the whole market for a single day

#### 4.1.2 Fetch-by-Stock Interfaces
@@ -93,152 +112,504 @@ def get_{data_type}(
) -> pd.DataFrame:
```

**Requirements**:
- `ts_code` is a required parameter
- Fetching the whole market requires iterating over all stocks
### 4.2 Docstring Requirements

Functions must carry a complete **Google-style** docstring:

```python
def get_{data_type}(...) -> pd.DataFrame:
    """Fetch {data description} from Tushare.

    This interface retrieves {details}.

    Args:
        ts_code: Stock code (e.g., '000001.SZ', '600000.SH')
        trade_date: Specific trade date (YYYYMMDD format)
        start_date: Start date (YYYYMMDD format)
        end_date: End date (YYYYMMDD format)
        # other parameters...

    Returns:
        pd.DataFrame with columns:
            - ts_code: Stock code
            - trade_date: Trade date (YYYYMMDD)
            - {field}: {field description}

    Example:
        >>> # Get single date data for all stocks
        >>> data = get_{data_type}(trade_date='20240101')
        >>>
        >>> # Get date range data
        >>> data = get_{data_type}(start_date='20240101', end_date='20240131')
        >>>
        >>> # Get specific stock data
        >>> data = get_{data_type}(ts_code='000001.SZ', trade_date='20240101')
    """
```
### 4.3 Date Format Requirements

- All date parameters use the **YYYYMMDD** string format
- Always use `trade_date` as the date field name
- If the API returns another date field name (e.g., `date`, `end_date`), rename it to `trade_date` before returning:

```python
if "date" in data.columns:
    data = data.rename(columns={"date": "trade_date"})
```

### 4.4 Stock Code Requirements

- Always use `ts_code` as the stock-code field name
- Format: `{code}.{exchange}`, e.g., `000001.SZ`, `600000.SH`
- Ensure the returned DataFrame contains a `ts_code` column
### 4.5 Token-Bucket Rate Limiting

All API calls must go through `TushareClient`, which automatically satisfies the token-bucket rate limits:

```python
from src.data.client import TushareClient


def get_{data_type}(...) -> pd.DataFrame:
    client = TushareClient()

    # Build parameters
    params = {}
    if trade_date:
        params["trade_date"] = trade_date
    if ts_code:
        params["ts_code"] = ts_code
    # ...

    # Fetch data (rate limiting handled automatically)
    data = client.query("{api_name}", **params)

    return data
```

**Note**: `TushareClient` automatically handles:
- token-bucket rate limiting
- API retry logic (exponential backoff)
- configuration loading

## 5. DuckDB Storage Specification
### 5.1 Storage Architecture

The project uses **DuckDB** for persistence:

- **Singleton**: the `Storage` class guarantees a single database connection
- **Thread safety**: `ThreadSafeStorage` supports concurrent writes
- **UPSERT support**: `INSERT OR REPLACE` handles duplicate data automatically
- **Query pushdown**: WHERE conditions are filtered at the database layer

### 5.2 Table Schema Design

Each data type maps to one DuckDB table:

```sql
CREATE TABLE {data_type} (
    ts_code VARCHAR(16) NOT NULL,
    trade_date DATE NOT NULL,
    -- other fields...
    PRIMARY KEY (ts_code, trade_date)
);

CREATE INDEX idx_{data_type}_date_code ON {data_type}(trade_date, ts_code);
```

**Primary key requirements**:
- must include `ts_code` and `trade_date`
- use UPSERT to guarantee idempotency
### 5.3 Write Strategy

**Batched writes** (recommended for multi-threaded scenarios):

```python
from src.data.storage import ThreadSafeStorage


def sync_{data_type}(self, ...):
    storage = ThreadSafeStorage()

    # Queue data chunks (no immediate write)
    for data_chunk in data_generator:
        storage.queue_save("{data_type}", data_chunk)

    # Write everything in one batch
    storage.flush()
```

**Direct writes** (for simple scenarios):

```python
from src.data.storage import Storage

storage = Storage()
storage.save("{data_type}", data, mode="append")
```

### 5.4 Data Type Mapping

Standard field type mapping (`DEFAULT_TYPE_MAPPING`):

```python
DEFAULT_TYPE_MAPPING = {
    "ts_code": "VARCHAR(16)",
    "trade_date": "DATE",
    "open": "DOUBLE",
    "high": "DOUBLE",
    "low": "DOUBLE",
    "close": "DOUBLE",
    "vol": "DOUBLE",
    "amount": "DOUBLE",
    # ... other fields
}
```
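One natural use of the mapping is generating `CREATE TABLE` DDL from a field list, falling back to `VARCHAR` for unmapped fields. A small sketch (`build_create_sql` is a hypothetical helper, not a function the project ships):

```python
DEFAULT_TYPE_MAPPING = {
    "ts_code": "VARCHAR(16)",
    "trade_date": "DATE",
    "close": "DOUBLE",
}


def build_create_sql(table: str, fields: list[str],
                     mapping: dict[str, str] = DEFAULT_TYPE_MAPPING) -> str:
    """Render CREATE TABLE DDL, defaulting unmapped fields to VARCHAR."""
    cols = ", ".join(f'"{f}" {mapping.get(f, "VARCHAR")}' for f in fields)
    return (f'CREATE TABLE IF NOT EXISTS "{table}" '
            f'({cols}, PRIMARY KEY ("ts_code", "trade_date"))')


sql = build_create_sql("daily", ["ts_code", "trade_date", "close"])
print(sql)
```

Centralizing types this way keeps every table's `ts_code`/`trade_date` columns consistent, which the composite primary key depends on.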
## 6. Sync Integration Specification

### 6.1 Syncing via db_manager

The project's `db_manager` module provides high-level sync functionality:

```python
from src.data.db_manager import SyncManager, ensure_table


def sync_{data_type}(force_full: bool = False) -> pd.DataFrame:
    """Sync {data description} to DuckDB."""
    manager = SyncManager()

    # Ensure the table exists
    ensure_table("{data_type}", schema={
        "ts_code": "VARCHAR(16)",
        "trade_date": "DATE",
        # ... other fields
    })

    # Perform the sync
    result = manager.sync(
        table_name="{data_type}",
        fetch_func=get_{data_type},
        start_date=start_date,
        end_date=end_date,
        force_full=force_full,
    )

    return result
```

### 6.2 Incremental Update Logic

`SyncManager` handles incremental updates automatically:

1. **Check the latest local date**: read `MAX(trade_date)` from DuckDB
2. **Fetch the trading calendar**: get the trading-day range from `api_trade_cal`
3. **Compute the dates to sync**: latest local date + 1 through the latest trading day
4. **Fetch in batches**: by date or by stock
5. **Write in batches**: queued writes via `ThreadSafeStorage`
### 6.3 Convenience Functions

Every interface must provide a top-level convenience function:

```python
def sync_{data_type}(force_full: bool = False) -> pd.DataFrame:
    """Sync {data description} to local storage.

    Args:
        force_full: If True, force full reload from 20180101

    Returns:
        DataFrame with synced data
    """
    # Implementation...
```
## 7. Code Templates

### 7.1 Fetch-by-Date Template

```python
"""{Data description} interface.

Fetch {data description} data from Tushare.
"""

import pandas as pd
from typing import Optional
from src.data.client import TushareClient


def get_{data_type}(
    trade_date: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    ts_code: Optional[str] = None,
) -> pd.DataFrame:
    """Fetch {data description} from Tushare.

    This interface retrieves {details}.

    Args:
        trade_date: Specific trade date (YYYYMMDD format)
        start_date: Start date (YYYYMMDD format)
        end_date: End date (YYYYMMDD format)
        ts_code: Stock code filter (optional)

    Returns:
        pd.DataFrame with columns:
            - ts_code: Stock code
            - trade_date: Trade date (YYYYMMDD)
            - {field1}: {description}
            - {field2}: {description}

    Example:
        >>> # Get all stocks for a single date
        >>> data = get_{data_type}(trade_date='20240101')
        >>>
        >>> # Get date range data
        >>> data = get_{data_type}(start_date='20240101', end_date='20240131')
    """
    client = TushareClient()

    # Build parameters
    params = {}
    if trade_date:
        params["trade_date"] = trade_date
    if start_date:
        params["start_date"] = start_date
    if end_date:
        params["end_date"] = end_date
    if ts_code:
        params["ts_code"] = ts_code

    # Fetch data
    data = client.query("{tushare_api_name}", **params)

    # Rename date column if needed
    if "date" in data.columns:
        data = data.rename(columns={"date": "trade_date"})

    return data
```
### 7.2 Fetch-by-Stock Template

```python
"""{Data description} interface.

Fetch {data description} data from Tushare (per stock).
"""

import pandas as pd
from typing import Optional
from src.data.client import TushareClient


def get_{data_type}(
    ts_code: str,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
) -> pd.DataFrame:
    """Fetch {data description} for a specific stock.

    Args:
        ts_code: Stock code (e.g., '000001.SZ')
        start_date: Start date (YYYYMMDD format)
        end_date: End date (YYYYMMDD format)

    Returns:
        pd.DataFrame with {data description} data
    """
    client = TushareClient()

    params = {"ts_code": ts_code}
    if start_date:
        params["start_date"] = start_date
    if end_date:
        params["end_date"] = end_date

    data = client.query("{tushare_api_name}", **params)

    return data
```
### 7.3 Sync Function Template

```python
from src.data.db_manager import SyncManager, ensure_table
from src.data.api_wrappers import get_{data_type}


def sync_{data_type}(force_full: bool = False) -> pd.DataFrame:
    """Sync {data description} to local DuckDB storage.

    Args:
        force_full: If True, force full reload from 20180101

    Returns:
        DataFrame with synced data
    """
    manager = SyncManager()

    # Ensure table exists with proper schema
    ensure_table("{data_type}", schema={
        "ts_code": "VARCHAR(16)",
        "trade_date": "DATE",
        # Add other fields...
    })

    # Perform sync
    result = manager.sync(
        table_name="{data_type}",
        fetch_func=get_{data_type},
        force_full=force_full,
    )

    return result
```
## 8. Testing Specification

### 8.1 Test File Requirements

A matching test file must be created: `tests/test_{data_type}.py`

### 8.2 Test Coverage Requirements

- Test fetching by date
- Test fetching by stock (if supported)
- `TushareClient` must be mocked
- Cover both normal and error cases

```python
import pytest
import pandas as pd
from unittest.mock import patch, MagicMock
from src.data.api_wrappers.api_{data_type} import get_{data_type}


class Test{DataType}:
    """Test suite for {data_type} API wrapper."""

    @patch("src.data.api_wrappers.api_{data_type}.TushareClient")
    def test_get_by_date(self, mock_client_class):
        """Test fetching data by date."""
        # Setup mock
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame({
            "ts_code": ["000001.SZ"],
            "trade_date": ["20240101"],
            # ... other columns
        })

        # Test
        result = get_{data_type}(trade_date="20240101")

        # Assert
        assert not result.empty
        assert "ts_code" in result.columns
        assert "trade_date" in result.columns
        mock_client.query.assert_called_once()

    @patch("src.data.api_wrappers.api_{data_type}.TushareClient")
    def test_get_by_stock(self, mock_client_class):
        """Test fetching data by stock code."""
        # Similar setup...
        pass

    @patch("src.data.api_wrappers.api_{data_type}.TushareClient")
    def test_empty_response(self, mock_client_class):
        """Test handling empty response."""
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame()

        result = get_{data_type}(trade_date="20240101")
        assert result.empty
```

### 8.3 Mock Conventions

- Patch at the import site: `patch('src.data.api_wrappers.api_{data_type}.TushareClient')`
- Test both normal and error cases
- Verify that parameters are passed through correctly
## 9. Generating with the Skill

### 9.1 Preparation

1. Define the interface in `api.md`, including:
   - interface name and description
   - input parameters (name, type, required, description)
   - output parameters (name, type, description)

### 9.2 Invoking the Skill

Tell Claude which interface to wrap:

> "帮我封装 {data_type} 接口"

> "为 {data_type} 接口生成代码"

### 9.3 What Gets Generated

The Skill automatically:
1. parses the interface definition in `api.md`
2. generates `src/data/api_wrappers/api_{data_type}.py`
3. updates the `src/data/api_wrappers/__init__.py` exports
4. generates the `tests/test_{data_type}.py` test file
5. provides a `sync_{data_type}()` function template
## 10. Checklist

### 10.1 File Structure
- [ ] File located at `src/data/api_wrappers/api_{data_type}.py`
- [ ] `src/data/api_wrappers/__init__.py` updated to export the public interface
- [ ] `tests/test_{data_type}.py` created

### 10.2 Interface Implementation
- [ ] Data fetching goes through `TushareClient`
- [ ] Function has a complete Google-style docstring
- [ ] Date parameters use the `YYYYMMDD` format
- [ ] Returned DataFrame contains `ts_code` and `trade_date` columns
- [ ] Fetch-by-date implemented first (if the API supports it)
- [ ] Parameters checked for None before being passed through

### 10.3 Storage Integration
- [ ] Uses `Storage` or `ThreadSafeStorage` for persistence
- [ ] Table schema uses `ts_code` and `trade_date` as the primary key
- [ ] Uses the UPSERT pattern (`INSERT OR REPLACE`)
- [ ] Multi-threaded scenarios use the `queue_save()` + `flush()` pattern

### 10.4 Sync Integration
- [ ] Sync managed via the `db_manager` module
- [ ] `sync_{data_type}()` convenience function implemented
- [ ] `force_full` parameter supported
- [ ] Incremental update logic is correct

### 10.5 Testing
- [ ] Unit tests written
- [ ] `TushareClient` mocked
- [ ] Coverage for fetch-by-date and fetch-by-stock
- [ ] Coverage for normal and error cases
## 11. Reference Examples

### 11.1 Full example: api_daily.py

See `src/data/api_wrappers/api_daily.py` - a complete fetch-by-stock implementation for daily data.

### 11.2 Full example: api_trade_cal.py

See `src/data/api_wrappers/api_trade_cal.py` - a special interface (trading calendar) implementation with HDF5 caching.

### 11.3 Full example: api_stock_basic.py

See `src/data/api_wrappers/api_stock_basic.py` - a special interface (stock basic info) implementation with CSV storage.

---

**Last updated**: 2026-02-23

**Version**: v2.0 - updated DuckDB storage spec, added Skill automation notes
@@ -7,15 +7,21 @@ Available APIs:
- api_daily: Daily market data (日线行情)
- api_stock_basic: Stock basic information (股票基本信息)
- api_trade_cal: Trading calendar (交易日历)
- api_namechange: Stock name change history (股票曾用名)
- api_bak_basic: Stock historical list (股票历史列表)

Example:
    >>> from src.data.api_wrappers import get_daily, get_stock_basic, get_trade_cal, get_bak_basic
    >>> data = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
    >>> stocks = get_stock_basic()
    >>> calendar = get_trade_cal('20240101', '20240131')
    >>> bak_basic = get_bak_basic(trade_date='20240101')
"""

from src.data.api_wrappers.api_daily import get_daily, sync_daily, preview_daily_sync, DailySync
from src.data.api_wrappers.api_bak_basic import get_bak_basic, sync_bak_basic
from src.data.api_wrappers.api_namechange import get_namechange, sync_namechange
from src.data.api_wrappers.api_stock_basic import get_stock_basic, sync_all_stocks
from src.data.api_wrappers.api_trade_cal import (
    get_trade_cal,

@@ -28,6 +34,15 @@ from src.data.api_wrappers.api_trade_cal import (
__all__ = [
    # Daily market data
    "get_daily",
    "sync_daily",
    "preview_daily_sync",
    "DailySync",
    # Historical stock list
    "get_bak_basic",
    "sync_bak_basic",
    # Namechange
    "get_namechange",
    "sync_namechange",
    # Stock basic information
    "get_stock_basic",
    "sync_all_stocks",
@@ -251,3 +251,98 @@ df = pro.query('daily_basic', ts_code='', trade_date='20180726',fields='ts_code,
17  000708.SZ  20180726  0.5575  0.70  10.3674  1.0276
18  002626.SZ  20180726  0.6187  0.83  22.7580  4.2446
19  600816.SH  20180726  0.6745  0.65  11.0778  3.2214


Stock name-change history (股票曾用名)
Interface: namechange
Description: historical name-change records

Input parameters

Name        Type  Required  Description
ts_code     str   N         TS code
start_date  str   N         announcement start date
end_date    str   N         announcement end date

Output parameters

Name           Type  Default  Description
ts_code        str   Y        TS code
name           str   Y        security name
start_date     str   Y        start date
end_date       str   Y        end date
ann_date       str   Y        announcement date
change_reason  str   Y        reason for change

Example

pro = ts.pro_api()

df = pro.namechange(ts_code='600848.SH', fields='ts_code,name,start_date,end_date,change_reason')

Sample data

    ts_code    name      start_date  end_date  change_reason
0   600848.SH  上海临港   20151118    None      改名
1   600848.SH  自仪股份   20070514    20151117  撤销ST
2   600848.SH  ST自仪    20061026    20070513  完成股改
3   600848.SH  SST自仪   20061009    20061025  未股改加S
4   600848.SH  ST自仪    20010508    20061008  ST
5   600848.SH  自仪股份   19940324    20010507  其他


Historical stock list (daily snapshot of all listed stocks)
Interface: bak_basic
Description: backup basic list; data starts from 2016
Limit: at most 7000 rows per call; loop over the date parameter for history; full access requires 5000 credits.

Input parameters

Name        Type  Required  Description
trade_date  str   N         trade date
ts_code     str   N         stock code

Output parameters

Name               Type   Default  Description
trade_date         str    Y        trade date
ts_code            str    Y        TS stock code
name               str    Y        stock name
industry           str    Y        industry
area               str    Y        region
pe                 float  Y        P/E ratio (dynamic)
float_share        float  Y        float shares (100M)
total_share        float  Y        total shares (100M)
total_assets       float  Y        total assets (100M)
liquid_assets      float  Y        liquid assets (100M)
fixed_assets       float  Y        fixed assets (100M)
reserved           float  Y        reserve fund
reserved_pershare  float  Y        reserve per share
eps                float  Y        earnings per share
bvps               float  Y        book value per share
pb                 float  Y        P/B ratio
list_date          str    Y        listing date
undp               float  Y        undistributed profit
per_undp           float  Y        undistributed profit per share
rev_yoy            float  Y        revenue YoY (%)
profit_yoy         float  Y        profit YoY (%)
gpr                float  Y        gross margin (%)
npr                float  Y        net margin (%)
holder_num         int    Y        number of shareholders

Example

pro = ts.pro_api()

df = pro.bak_basic(trade_date='20211012', fields='trade_date,ts_code,name,industry,pe')

Sample data

      trade_date  ts_code    name      industry  pe
0     20211012    300605.SZ  恒锋信息   软件服务   56.4400
1     20211012    301017.SZ  漱玉平民   医药商业   58.7600
2     20211012    300755.SZ  华致酒行   其他商业   23.0000
3     20211012    300255.SZ  常山药业   生物制药   24.9900
4     20211012    688378.SH  奥来德     专用机械   24.9600
...   ...         ...        ...       ...       ...
4529  20211012    688257.SH  新锐股份   机械基件   0.0000
4530  20211012    688255.SH  凯尔达     机械基件   0.0000
4531  20211012    688211.SH  中科微至   专用机械   0.0000
4532  20211012    605567.SH  春雪食品   食品      0.0000
4533  20211012    605566.SH  福莱蒽特   染料涂料   0.0000
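The `namechange` parameters above slot directly into the spec's wrapper pattern. A minimal, self-checking sketch (the real `get_namechange` lives in `src/data/api_wrappers/api_namechange.py` and constructs its own `TushareClient`; here the client is injected and faked so the parameter plumbing can run standalone):

```python
import pandas as pd
from typing import Optional, Protocol


class QueryClient(Protocol):
    """The slice of TushareClient this wrapper needs."""
    def query(self, api_name: str, **params) -> pd.DataFrame: ...


def get_namechange(
    client: QueryClient,
    ts_code: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
) -> pd.DataFrame:
    """Fetch stock name-change history (client injected for the sketch)."""
    # Only forward parameters that were actually provided
    params = {}
    if ts_code:
        params["ts_code"] = ts_code
    if start_date:
        params["start_date"] = start_date
    if end_date:
        params["end_date"] = end_date
    return client.query("namechange", **params)


class FakeClient:
    """Stand-in returning one row shaped like the sample data above."""
    def query(self, api_name, **params):
        assert api_name == "namechange"
        return pd.DataFrame({"ts_code": [params.get("ts_code", "")],
                             "name": ["上海临港"]})


df = get_namechange(FakeClient(), ts_code="600848.SH")
```

Injecting the client here is purely a testing convenience; the project's own tests achieve the same isolation by patching `TushareClient` at the import site, as section 8.3 of the spec describes.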
243  src/data/api_wrappers/api_bak_basic.py  (new file)
@@ -0,0 +1,243 @@
|
||||
"""Stock historical list interface.
|
||||
|
||||
Fetch daily stock list from Tushare bak_basic API.
|
||||
Data available from 2016 onwards.
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from typing import Optional, List
|
||||
from datetime import datetime, timedelta
|
||||
from tqdm import tqdm
|
||||
from src.data.client import TushareClient
|
||||
from src.data.storage import ThreadSafeStorage, Storage
|
||||
from src.data.db_manager import ensure_table
|
||||
|
||||
|
||||
def get_bak_basic(
|
||||
trade_date: Optional[str] = None,
|
||||
ts_code: Optional[str] = None,
|
||||
) -> pd.DataFrame:
|
||||
"""Fetch historical stock list from Tushare.
|
||||
|
||||
This interface retrieves the daily stock list including basic information
|
||||
for all stocks on a specific trade date. Data is available from 2016 onwards.
|
||||
|
||||
Args:
|
||||
trade_date: Specific trade date in YYYYMMDD format
|
||||
ts_code: Stock code filter (optional, e.g., '000001.SZ')
|
||||
|
||||
Returns:
|
||||
pd.DataFrame with columns:
|
||||
- trade_date: Trade date (YYYYMMDD)
|
||||
- ts_code: TS stock code
|
||||
- name: Stock name
|
||||
- industry: Industry
|
||||
- area: Region
|
||||
- pe: P/E ratio (dynamic)
|
||||
- float_share: Float shares (100 million)
|
||||
- total_share: Total shares (100 million)
|
||||
- total_assets: Total assets (100 million)
|
||||
- liquid_assets: Liquid assets (100 million)
|
||||
- fixed_assets: Fixed assets (100 million)
|
||||
- reserved: Reserve fund
|
||||
- reserved_pershare: Reserve per share
|
||||
- eps: Earnings per share
|
||||
- bvps: Book value per share
|
||||
- pb: P/B ratio
|
||||
- list_date: Listing date
|
||||
- undp: Undistributed profit
|
||||
- per_undp: Undistributed profit per share
|
||||
- rev_yoy: Revenue YoY (%)
|
||||
- profit_yoy: Profit YoY (%)
|
||||
- gpr: Gross profit ratio (%)
|
||||
- npr: Net profit ratio (%)
|
||||
- holder_num: Number of shareholders
|
||||
|
||||
Example:
|
||||
>>> # Get all stocks for a single date
|
||||
>>> data = get_bak_basic(trade_date='20240101')
|
||||
>>>
|
||||
>>> # Get specific stock data
|
||||
>>> data = get_bak_basic(ts_code='000001.SZ', trade_date='20240101')
|
||||
"""
|
||||
    client = TushareClient()

    # Build parameters
    params = {}
    if trade_date:
        params["trade_date"] = trade_date
    if ts_code:
        params["ts_code"] = ts_code

    # Fetch data
    data = client.query("bak_basic", **params)

    return data


def sync_bak_basic(
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    force_full: bool = False,
) -> pd.DataFrame:
    """Sync historical stock list to DuckDB with intelligent incremental sync.

    Logic:
        - If table doesn't exist: create table + composite index (trade_date, ts_code) + full sync
        - If table exists: incremental sync from last_date + 1

    Args:
        start_date: Start date for sync (YYYYMMDD format; default: 20160101 for full sync, last_date + 1 for incremental)
        end_date: End date for sync (YYYYMMDD format, default: today)
        force_full: If True, force full reload from 20160101

    Returns:
        pd.DataFrame with synced data
    """
    from src.data.db_manager import ensure_table

    TABLE_NAME = "bak_basic"
    storage = Storage()
    thread_storage = ThreadSafeStorage()

    # Default end date
    if end_date is None:
        end_date = datetime.now().strftime("%Y%m%d")

    # Check if table exists
    table_exists = storage.exists(TABLE_NAME)

    if not table_exists or force_full:
        # ===== FULL SYNC =====
        # 1. Create table with schema
        # 2. Create composite index (trade_date, ts_code)
        # 3. Full sync from start_date

        if not table_exists:
            print(f"[sync_bak_basic] Table '{TABLE_NAME}' doesn't exist, creating...")

            # Fetch sample to get schema
            sample = get_bak_basic(trade_date=end_date)
            if sample.empty:
                sample = get_bak_basic(trade_date="20240102")

            if sample.empty:
                print("[sync_bak_basic] Cannot create table: no sample data available")
                return pd.DataFrame()

            # Create table with schema
            columns = []
            for col in sample.columns:
                dtype = str(sample[col].dtype)
                if "int" in dtype:
                    col_type = "INTEGER"
                elif "float" in dtype:
                    col_type = "DOUBLE"
                else:
                    col_type = "VARCHAR"
                columns.append(f'"{col}" {col_type}')

            columns_sql = ", ".join(columns)
            create_sql = f'CREATE TABLE IF NOT EXISTS "{TABLE_NAME}" ({columns_sql}, PRIMARY KEY ("trade_date", "ts_code"))'

            try:
                storage._connection.execute(create_sql)
                print(f"[sync_bak_basic] Created table '{TABLE_NAME}'")
            except Exception as e:
                print(f"[sync_bak_basic] Error creating table: {e}")

            # Create composite index
            try:
                storage._connection.execute(f"""
                    CREATE INDEX IF NOT EXISTS "idx_bak_basic_date_code"
                    ON "{TABLE_NAME}"("trade_date", "ts_code")
                """)
                print("[sync_bak_basic] Created composite index on (trade_date, ts_code)")
            except Exception as e:
                print(f"[sync_bak_basic] Error creating index: {e}")

        # Determine sync dates
        sync_start = start_date or "20160101"
        mode = "FULL"
        print(f"[sync_bak_basic] Mode: {mode} SYNC from {sync_start} to {end_date}")

    else:
        # ===== INCREMENTAL SYNC =====
        # Check last date in table, sync from last_date + 1

        try:
            result = storage._connection.execute(
                f'SELECT MAX("trade_date") FROM "{TABLE_NAME}"'
            ).fetchone()
            last_date = result[0] if result and result[0] else None
        except Exception as e:
            print(f"[sync_bak_basic] Error getting last date: {e}")
            last_date = None

        if last_date is None:
            # Table exists but is empty, do full sync
            sync_start = start_date or "20160101"
            mode = "FULL (empty table)"
        else:
            # Incremental from last_date + 1
            # Handle both YYYYMMDD and YYYY-MM-DD formats
            last_date_str = str(last_date).replace("-", "")
            last_dt = datetime.strptime(last_date_str, "%Y%m%d")
            next_dt = last_dt + timedelta(days=1)
            sync_start = next_dt.strftime("%Y%m%d")
            mode = "INCREMENTAL"

        # Skip if already up to date
        if sync_start > end_date:
            print(f"[sync_bak_basic] Data is up-to-date (last: {last_date}), skipping sync")
            return pd.DataFrame()

        print(f"[sync_bak_basic] Mode: {mode} from {sync_start} to {end_date} (last: {last_date})")

    # ===== FETCH AND SAVE DATA =====
    all_data: List[pd.DataFrame] = []
    current = datetime.strptime(sync_start, "%Y%m%d")
    end_dt = datetime.strptime(end_date, "%Y%m%d")

    # Calculate total days for progress bar
    total_days = (end_dt - current).days + 1
    print(f"[sync_bak_basic] Fetching data for {total_days} days...")

    with tqdm(total=total_days, desc="Syncing dates") as pbar:
        while current <= end_dt:
            date_str = current.strftime("%Y%m%d")
            try:
                data = get_bak_basic(trade_date=date_str)
                if not data.empty:
                    all_data.append(data)
                    pbar.set_postfix({"date": date_str, "records": len(data)})
            except Exception as e:
                print(f"  {date_str}: ERROR - {e}")

            current += timedelta(days=1)
            pbar.update(1)

    if not all_data:
        print("[sync_bak_basic] No data fetched")
        return pd.DataFrame()

    # Combine and save
    combined = pd.concat(all_data, ignore_index=True)
    print(f"[sync_bak_basic] Total records: {len(combined)}")

    # Delete existing data for the date range and append new data
    storage._connection.execute(f'DELETE FROM "{TABLE_NAME}" WHERE "trade_date" >= ?', [sync_start])
    thread_storage.queue_save(TABLE_NAME, combined)
    thread_storage.flush()

    print(f"[sync_bak_basic] Saved {len(combined)} records to DuckDB")
    return combined


if __name__ == "__main__":
    # Test sync
    result = sync_bak_basic(end_date="20240102")
    print(f"Synced {len(result)} records")
    if not result.empty:
        print("\nSample data:")
        print(result.head())
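The pandas-dtype-to-DuckDB type mapping buried inside `sync_bak_basic` can be pulled out as a pure helper, which makes the schema inference easy to unit-test. A minimal sketch under that assumption (the names `infer_duckdb_type` and `build_create_sql` are illustrative, not part of the module):

```python
import pandas as pd


def infer_duckdb_type(dtype: str) -> str:
    """Map a pandas dtype string to a DuckDB column type (same rule as sync_bak_basic)."""
    if "int" in dtype:
        return "INTEGER"
    if "float" in dtype:
        return "DOUBLE"
    return "VARCHAR"


def build_create_sql(table: str, df: pd.DataFrame, pk: tuple) -> str:
    """Infer a schema from a sample DataFrame and build the CREATE TABLE statement."""
    cols = ", ".join(f'"{c}" {infer_duckdb_type(str(df[c].dtype))}' for c in df.columns)
    pk_sql = ", ".join(f'"{c}"' for c in pk)
    return f'CREATE TABLE IF NOT EXISTS "{table}" ({cols}, PRIMARY KEY ({pk_sql}))'


sample = pd.DataFrame(
    {"trade_date": ["20240102"], "ts_code": ["000001.SZ"], "pe": [10.5], "holder_num": [1000]}
)
print(build_create_sql("bak_basic", sample, ("trade_date", "ts_code")))
```

Isolating the mapping this way also makes it obvious where to extend it (e.g. for datetime columns) without touching the sync flow.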
@@ -2,11 +2,27 @@
 
 A single function to fetch A股日线行情 data from Tushare.
 Supports all output fields including tor (换手率) and vr (量比).
+
+This module provides both single-stock fetching (get_daily) and
+batch synchronization (DailySync class) for daily market data.
 """
 
 import pandas as pd
-from typing import Optional, List, Literal
+from typing import Optional, List, Literal, Dict
 from datetime import datetime, timedelta
+from tqdm import tqdm
+from concurrent.futures import ThreadPoolExecutor, as_completed
+import threading
 
 from src.data.client import TushareClient
+from src.data.storage import ThreadSafeStorage, Storage
+from src.data.utils import get_today_date, get_next_date, DEFAULT_START_DATE
+from src.data.api_wrappers.api_trade_cal import (
+    get_first_trading_day,
+    get_last_trading_day,
+    sync_trade_cal_cache,
+)
+from src.data.api_wrappers.api_stock_basic import _get_csv_path, sync_all_stocks
 
 
 def get_daily(
@@ -71,3 +87,744 @@ def get_daily(
    data = client.query("pro_bar", **params)

    return data

# =============================================================================
# DailySync - 日线数据批量同步类
# =============================================================================


class DailySync:
    """日线数据批量同步管理器,支持全量/增量同步。

    功能特性:
    - 多线程并发获取(ThreadPoolExecutor)
    - 增量同步(自动检测上次同步位置)
    - 内存缓存(避免重复磁盘读取)
    - 异常立即停止(确保数据一致性)
    - 预览模式(预览同步数据量,不实际写入)
    """

    # 默认工作线程数
    DEFAULT_MAX_WORKERS = 10

    def __init__(self, max_workers: Optional[int] = None):
        """初始化同步管理器。

        Args:
            max_workers: 工作线程数(默认: 10)
        """
        self.storage = ThreadSafeStorage()
        self.client = TushareClient()
        self.max_workers = max_workers or self.DEFAULT_MAX_WORKERS
        self._stop_flag = threading.Event()
        self._stop_flag.set()  # 初始为未停止状态
        self._cached_daily_data: Optional[pd.DataFrame] = None  # 日线数据缓存

    def _load_daily_data(self) -> pd.DataFrame:
        """从存储加载日线数据(带缓存)。

        该方法会将数据缓存在内存中以避免重复磁盘读取。
        调用 clear_cache() 可强制重新加载。

        Returns:
            缓存或从存储加载的日线数据 DataFrame
        """
        if self._cached_daily_data is None:
            self._cached_daily_data = self.storage.load("daily")
        return self._cached_daily_data

    def clear_cache(self) -> None:
        """清除缓存的日线数据,强制下次访问时重新加载。"""
        self._cached_daily_data = None

    def get_all_stock_codes(self, only_listed: bool = True) -> list:
        """从本地存储获取所有股票代码。

        优先使用 stock_basic.csv 以确保包含所有股票,
        避免回测中的前视偏差。

        Args:
            only_listed: 若为 True,仅返回当前上市股票(L 状态)。
                设为 False 可包含退市股票(用于完整回测)。

        Returns:
            股票代码列表
        """
        # 确保 stock_basic.csv 是最新的
        print("[DailySync] Ensuring stock_basic.csv is up-to-date...")
        sync_all_stocks()

        # 从 stock_basic.csv 文件获取
        stock_csv_path = _get_csv_path()

        if stock_csv_path.exists():
            print(f"[DailySync] Reading stock_basic from CSV: {stock_csv_path}")
            try:
                stock_df = pd.read_csv(stock_csv_path, encoding="utf-8-sig")
                if not stock_df.empty and "ts_code" in stock_df.columns:
                    # 根据 list_status 过滤
                    if only_listed and "list_status" in stock_df.columns:
                        listed_stocks = stock_df[stock_df["list_status"] == "L"]
                        codes = listed_stocks["ts_code"].unique().tolist()
                        total = len(stock_df["ts_code"].unique())
                        print(
                            f"[DailySync] Found {len(codes)} listed stocks (filtered from {total} total)"
                        )
                    else:
                        codes = stock_df["ts_code"].unique().tolist()
                        print(
                            f"[DailySync] Found {len(codes)} stock codes from stock_basic.csv"
                        )
                    return codes
                else:
                    print(
                        "[DailySync] stock_basic.csv exists but no ts_code column or empty"
                    )
            except Exception as e:
                print(f"[DailySync] Error reading stock_basic.csv: {e}")

        # 回退:从日线存储获取
        print(
            "[DailySync] stock_basic.csv not available, falling back to daily data..."
        )
        daily_data = self._load_daily_data()
        if not daily_data.empty and "ts_code" in daily_data.columns:
            codes = daily_data["ts_code"].unique().tolist()
            print(f"[DailySync] Found {len(codes)} stock codes from daily data")
            return codes

        print("[DailySync] No stock codes found in local storage")
        return []

    def get_global_last_date(self) -> Optional[str]:
        """获取全局最后交易日期。

        Returns:
            最后交易日期字符串,若无数据则返回 None
        """
        daily_data = self._load_daily_data()
        if daily_data.empty or "trade_date" not in daily_data.columns:
            return None
        return str(daily_data["trade_date"].max())

    def get_global_first_date(self) -> Optional[str]:
        """获取全局最早交易日期。

        Returns:
            最早交易日期字符串,若无数据则返回 None
        """
        daily_data = self._load_daily_data()
        if daily_data.empty or "trade_date" not in daily_data.columns:
            return None
        return str(daily_data["trade_date"].min())

    def get_trade_calendar_bounds(
        self, start_date: str, end_date: str
    ) -> tuple[Optional[str], Optional[str]]:
        """从交易日历获取首尾交易日。

        Args:
            start_date: 开始日期(YYYYMMDD 格式)
            end_date: 结束日期(YYYYMMDD 格式)

        Returns:
            (首交易日, 尾交易日) 元组,若出错则返回 (None, None)
        """
        try:
            first_day = get_first_trading_day(start_date, end_date)
            last_day = get_last_trading_day(start_date, end_date)
            return (first_day, last_day)
        except Exception as e:
            print(f"[ERROR] Failed to get trade calendar bounds: {e}")
            return (None, None)

    def check_sync_needed(
        self,
        force_full: bool = False,
        table_name: str = "daily",
    ) -> tuple[bool, Optional[str], Optional[str], Optional[str]]:
        """基于交易日历检查是否需要同步。

        该方法比较本地数据日期范围与交易日历,
        以确定是否需要获取新数据。

        逻辑:
        - 若 force_full:需要同步,返回 (True, 20180101, today)
        - 若无本地数据:需要同步,返回 (True, 20180101, today)
        - 若存在本地数据:
            - 从交易日历获取最后交易日
            - 若本地最后日期 >= 日历最后日期:无需同步
            - 否则:从本地最后日期+1 到最新交易日同步

        Args:
            force_full: 若为 True,始终返回需要同步
            table_name: 要检查的表名(默认: "daily")

        Returns:
            (需要同步, 起始日期, 结束日期, 本地最后日期)
            - 需要同步: True 表示应继续同步
            - 起始日期: 同步起始日期(无需同步时为 None)
            - 结束日期: 同步结束日期(无需同步时为 None)
            - 本地最后日期: 本地数据最后日期(用于增量同步)
        """
        # 若 force_full,始终同步
        if force_full:
            print("[DailySync] Force full sync requested")
            return (True, DEFAULT_START_DATE, get_today_date(), None)

        # 检查特定表的本地数据是否存在
        storage = Storage()
        table_data = (
            storage.load(table_name) if storage.exists(table_name) else pd.DataFrame()
        )

        if table_data.empty or "trade_date" not in table_data.columns:
            print(
                f"[DailySync] No local data found for table '{table_name}', full sync needed"
            )
            return (True, DEFAULT_START_DATE, get_today_date(), None)

        # 获取本地数据最后日期
        local_last_date = str(table_data["trade_date"].max())

        print(f"[DailySync] Local data last date: {local_last_date}")

        # 从交易日历获取最新交易日
        today = get_today_date()
        _, cal_last = self.get_trade_calendar_bounds(DEFAULT_START_DATE, today)

        if cal_last is None:
            print("[DailySync] Failed to get trade calendar, proceeding with sync")
            return (True, DEFAULT_START_DATE, today, local_last_date)

        print(f"[DailySync] Calendar last trading day: {cal_last}")

        # 比较本地最后日期与日历最后日期
        print(
            f"[DailySync] Comparing: local={local_last_date} (type={type(local_last_date).__name__}), "
            f"cal={cal_last} (type={type(cal_last).__name__})"
        )
        try:
            local_last_int = int(local_last_date)
            cal_last_int = int(cal_last)
            print(
                f"[DailySync] Comparing integers: local={local_last_int} >= cal={cal_last_int} = "
                f"{local_last_int >= cal_last_int}"
            )
            if local_last_int >= cal_last_int:
                print(
                    "[DailySync] Local data is up-to-date, SKIPPING sync (no tokens consumed)"
                )
                return (False, None, None, None)
        except (ValueError, TypeError) as e:
            print(f"[ERROR] Date comparison failed: {e}")

        # 需要从本地最后日期+1 同步到最新交易日
        sync_start = get_next_date(local_last_date)
        print(f"[DailySync] Incremental sync needed from {sync_start} to {cal_last}")
        return (True, sync_start, cal_last, local_last_date)
    def preview_sync(
        self,
        force_full: bool = False,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        sample_size: int = 3,
    ) -> dict:
        """预览同步数据量和样本(不实际同步)。

        该方法提供即将同步的数据的预览,包括:
        - 将同步的股票数量
        - 同步日期范围
        - 预估总记录数
        - 前几只股票的样本数据

        Args:
            force_full: 若为 True,预览全量同步(从 20180101)
            start_date: 手动指定起始日期(覆盖自动检测)
            end_date: 手动指定结束日期(默认为今天)
            sample_size: 预览用样本股票数量(默认: 3)

        Returns:
            包含预览信息的字典:
            {
                'sync_needed': bool,
                'stock_count': int,
                'start_date': str,
                'end_date': str,
                'estimated_records': int,
                'sample_data': pd.DataFrame,
                'mode': str,  # 'full' 或 'incremental'
            }
        """
        print("\n" + "=" * 60)
        print("[DailySync] Preview Mode - Analyzing sync requirements...")
        print("=" * 60)

        # 首先确保交易日历缓存是最新的
        print("[DailySync] Syncing trade calendar cache...")
        sync_trade_cal_cache()

        # 确定日期范围
        if end_date is None:
            end_date = get_today_date()

        # 检查是否需要同步
        sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full)

        if not sync_needed:
            print("\n" + "=" * 60)
            print("[DailySync] Preview Result")
            print("=" * 60)
            print("  Sync Status: NOT NEEDED")
            print("  Reason: Local data is up-to-date with trade calendar")
            print("=" * 60)
            return {
                "sync_needed": False,
                "stock_count": 0,
                "start_date": None,
                "end_date": None,
                "estimated_records": 0,
                "sample_data": pd.DataFrame(),
                "mode": "none",
            }

        # 使用 check_sync_needed 返回的日期
        if cal_start and cal_end:
            sync_start_date = cal_start
            end_date = cal_end
        else:
            sync_start_date = start_date or DEFAULT_START_DATE
            if end_date is None:
                end_date = get_today_date()

        # 确定同步模式
        if force_full:
            mode = "full"
            print(f"[DailySync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
        elif local_last and cal_start and sync_start_date == get_next_date(local_last):
            mode = "incremental"
            print("[DailySync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
            print(f"[DailySync] Sync from: {sync_start_date} to {end_date}")
        else:
            mode = "partial"
            print(f"[DailySync] Mode: SYNC from {sync_start_date} to {end_date}")

        # 获取所有股票代码
        stock_codes = self.get_all_stock_codes()
        if not stock_codes:
            print("[DailySync] No stocks found to sync")
            return {
                "sync_needed": False,
                "stock_count": 0,
                "start_date": None,
                "end_date": None,
                "estimated_records": 0,
                "sample_data": pd.DataFrame(),
                "mode": "none",
            }

        stock_count = len(stock_codes)
        print(f"[DailySync] Total stocks to sync: {stock_count}")

        # 从前几只股票获取样本数据
        print(f"[DailySync] Fetching sample data from {sample_size} stocks...")
        sample_data_list = []
        sample_codes = stock_codes[:sample_size]

        for ts_code in sample_codes:
            try:
                data = self.client.query(
                    "pro_bar",
                    ts_code=ts_code,
                    start_date=sync_start_date,
                    end_date=end_date,
                    factors="tor,vr",
                )
                if not data.empty:
                    sample_data_list.append(data)
                    print(f"  - {ts_code}: {len(data)} records")
            except Exception as e:
                print(f"  - {ts_code}: Error fetching - {e}")

        # 合并样本数据
        sample_df = (
            pd.concat(sample_data_list, ignore_index=True)
            if sample_data_list
            else pd.DataFrame()
        )

        # 基于样本估算总记录数
        if not sample_df.empty:
            avg_records_per_stock = len(sample_df) / len(sample_data_list)
            estimated_records = int(avg_records_per_stock * stock_count)
        else:
            estimated_records = 0

        # 显示预览结果
        print("\n" + "=" * 60)
        print("[DailySync] Preview Result")
        print("=" * 60)
        print(f"  Sync Mode: {mode.upper()}")
        print(f"  Date Range: {sync_start_date} to {end_date}")
        print(f"  Stocks to Sync: {stock_count}")
        print(f"  Sample Stocks Checked: {len(sample_data_list)}/{sample_size}")
        print(f"  Estimated Total Records: ~{estimated_records:,}")

        if not sample_df.empty:
            print(f"\n  Sample Data Preview (first {len(sample_df)} rows):")
            print("  " + "-" * 56)
            # 以紧凑格式显示样本数据
            preview_cols = [
                "ts_code",
                "trade_date",
                "open",
                "high",
                "low",
                "close",
                "vol",
            ]
            available_cols = [c for c in preview_cols if c in sample_df.columns]
            sample_display = sample_df[available_cols].head(10)
            for idx, row in sample_display.iterrows():
                print(f"    {row.to_dict()}")
            print("  " + "-" * 56)

        print("=" * 60)

        return {
            "sync_needed": True,
            "stock_count": stock_count,
            "start_date": sync_start_date,
            "end_date": end_date,
            "estimated_records": estimated_records,
            "sample_data": sample_df,
            "mode": mode,
        }
    def sync_single_stock(
        self,
        ts_code: str,
        start_date: str,
        end_date: str,
    ) -> pd.DataFrame:
        """同步单只股票的日线数据。

        Args:
            ts_code: 股票代码
            start_date: 起始日期(YYYYMMDD)
            end_date: 结束日期(YYYYMMDD)

        Returns:
            包含日线市场数据的 DataFrame
        """
        # 检查是否应该停止同步(用于异常处理)
        if not self._stop_flag.is_set():
            return pd.DataFrame()

        try:
            # 使用共享客户端进行跨线程速率限制
            data = self.client.query(
                "pro_bar",
                ts_code=ts_code,
                start_date=start_date,
                end_date=end_date,
                factors="tor,vr",
            )
            return data
        except Exception as e:
            # 设置停止标志以通知其他线程停止
            self._stop_flag.clear()
            print(f"[ERROR] Exception syncing {ts_code}: {e}")
            raise
    def sync_all(
        self,
        force_full: bool = False,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        max_workers: Optional[int] = None,
        dry_run: bool = False,
    ) -> Dict[str, pd.DataFrame]:
        """同步本地存储中所有股票的日线数据。

        该函数:
        1. 从本地存储读取股票代码(daily 或 stock_basic)
        2. 检查交易日历确定是否需要同步:
           - 若本地数据匹配交易日历边界,则跳过同步(节省 token)
           - 否则,从本地最后日期+1 同步到最新交易日(带宽优化)
        3. 使用多线程并发获取(带速率限制)
        4. 跳过返回空数据的股票(退市/不可用)
        5. 遇异常立即停止

        Args:
            force_full: 若为 True,强制从 20180101 完整重载
            start_date: 手动指定起始日期(覆盖自动检测)
            end_date: 手动指定结束日期(默认为今天)
            max_workers: 工作线程数(默认: 10)
            dry_run: 若为 True,仅预览将要同步的内容,不写入数据

        Returns:
            映射 ts_code 到 DataFrame 的字典(若跳过或 dry_run 则为空字典)
        """
        print("\n" + "=" * 60)
        print("[DailySync] Starting daily data sync...")
        print("=" * 60)

        # 首先确保交易日历缓存是最新的(使用增量同步)
        print("[DailySync] Syncing trade calendar cache...")
        sync_trade_cal_cache()

        # 确定日期范围
        if end_date is None:
            end_date = get_today_date()

        # 基于交易日历检查是否需要同步
        sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full)

        if not sync_needed:
            # 跳过同步 - 不消耗 token
            print("\n" + "=" * 60)
            print("[DailySync] Sync Summary")
            print("=" * 60)
            print("  Sync: SKIPPED (local data up-to-date with trade calendar)")
            print("  Tokens saved: 0 consumed")
            print("=" * 60)
            return {}

        # 使用 check_sync_needed 返回的日期(会计算增量起始日期)
        if cal_start and cal_end:
            sync_start_date = cal_start
            end_date = cal_end
        else:
            # 回退到默认逻辑
            sync_start_date = start_date or DEFAULT_START_DATE
            if end_date is None:
                end_date = get_today_date()

        # 确定同步模式
        if force_full:
            mode = "full"
            print(f"[DailySync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
        elif local_last and cal_start and sync_start_date == get_next_date(local_last):
            mode = "incremental"
            print("[DailySync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
            print(f"[DailySync] Sync from: {sync_start_date} to {end_date}")
        else:
            mode = "partial"
            print(f"[DailySync] Mode: SYNC from {sync_start_date} to {end_date}")

        # 获取所有股票代码
        stock_codes = self.get_all_stock_codes()
        if not stock_codes:
            print("[DailySync] No stocks found to sync")
            return {}

        print(f"[DailySync] Total stocks to sync: {len(stock_codes)}")
        print(f"[DailySync] Using {max_workers or self.max_workers} worker threads")

        # 处理 dry run 模式
        if dry_run:
            print("\n" + "=" * 60)
            print("[DailySync] DRY RUN MODE - No data will be written")
            print("=" * 60)
            print(f"  Would sync {len(stock_codes)} stocks")
            print(f"  Date range: {sync_start_date} to {end_date}")
            print(f"  Mode: {mode}")
            print("=" * 60)
            return {}

        # 为新同步重置停止标志
        self._stop_flag.set()

        # 多线程并发获取
        results: Dict[str, pd.DataFrame] = {}
        error_occurred = False
        exception_to_raise = None

        def sync_task(ts_code: str) -> tuple[str, pd.DataFrame]:
            """每只股票的任务函数;异常直接向上传播,由 Future 捕获。"""
            data = self.sync_single_stock(
                ts_code=ts_code,
                start_date=sync_start_date,
                end_date=end_date,
            )
            return (ts_code, data)

        # 使用 ThreadPoolExecutor 进行并发获取
        workers = max_workers or self.max_workers
        with ThreadPoolExecutor(max_workers=workers) as executor:
            # 提交所有任务并跟踪 futures 与股票代码的映射
            future_to_code = {
                executor.submit(sync_task, ts_code): ts_code for ts_code in stock_codes
            }

            # 使用 as_completed 处理结果
            error_count = 0
            empty_count = 0
            success_count = 0

            # 创建进度条
            pbar = tqdm(total=len(stock_codes), desc="Syncing stocks")

            try:
                # 处理完成的 futures
                for future in as_completed(future_to_code):
                    ts_code = future_to_code[future]

                    try:
                        _, data = future.result()
                        if data is not None and not data.empty:
                            results[ts_code] = data
                            success_count += 1
                        else:
                            # 空数据 - 股票可能已退市或不可用
                            empty_count += 1
                            print(
                                f"[DailySync] Stock {ts_code}: empty data (skipped, may be delisted)"
                            )
                    except Exception as e:
                        # 发生异常 - 停止全部并中止
                        error_occurred = True
                        exception_to_raise = e
                        print(f"\n[ERROR] Sync aborted due to exception: {e}")
                        # 关闭 executor 以停止所有待处理任务
                        executor.shutdown(wait=False, cancel_futures=True)
                        raise exception_to_raise

                    # 更新进度条
                    pbar.update(1)

            except Exception:
                error_count = 1
                print("[DailySync] Sync stopped due to exception")
            finally:
                pbar.close()

        # 批量写入所有数据(仅在无错误时)
        if results and not error_occurred:
            for ts_code, data in results.items():
                if not data.empty:
                    self.storage.queue_save("daily", data)
            # 一次性刷新所有排队写入
            self.storage.flush()
            total_rows = sum(len(df) for df in results.values())
            print(f"\n[DailySync] Saved {total_rows} rows to storage")

        # 摘要
        print("\n" + "=" * 60)
        print("[DailySync] Sync Summary")
        print("=" * 60)
        print(f"  Total stocks: {len(stock_codes)}")
        print(f"  Updated: {success_count}")
        print(f"  Skipped (empty/delisted): {empty_count}")
        print(
            f"  Errors: {error_count} (aborted on first error)"
            if error_count
            else "  Errors: 0"
        )
        print(f"  Date range: {sync_start_date} to {end_date}")
        print("=" * 60)

        return results
def sync_daily(
    force_full: bool = False,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    max_workers: Optional[int] = None,
    dry_run: bool = False,
) -> Dict[str, pd.DataFrame]:
    """同步所有股票的日线数据。

    这是日线数据同步的主要入口点。

    Args:
        force_full: 若为 True,强制从 20180101 完整重载
        start_date: 手动指定起始日期(YYYYMMDD)
        end_date: 手动指定结束日期(默认为今天)
        max_workers: 工作线程数(默认: 10)
        dry_run: 若为 True,仅预览将要同步的内容,不写入数据

    Returns:
        映射 ts_code 到 DataFrame 的字典

    Example:
        >>> # 首次同步(从 20180101 全量加载)
        >>> result = sync_daily()
        >>>
        >>> # 后续同步(增量 - 仅新数据)
        >>> result = sync_daily()
        >>>
        >>> # 强制完整重载
        >>> result = sync_daily(force_full=True)
        >>>
        >>> # 手动指定日期范围
        >>> result = sync_daily(start_date='20240101', end_date='20240131')
        >>>
        >>> # 自定义线程数
        >>> result = sync_daily(max_workers=20)
        >>>
        >>> # Dry run(仅预览)
        >>> result = sync_daily(dry_run=True)
    """
    sync_manager = DailySync(max_workers=max_workers)
    return sync_manager.sync_all(
        force_full=force_full,
        start_date=start_date,
        end_date=end_date,
        dry_run=dry_run,
    )


def preview_daily_sync(
    force_full: bool = False,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    sample_size: int = 3,
) -> dict:
    """预览日线同步数据量和样本(不实际同步)。

    这是推荐的方式,可在实际同步前检查将要同步的内容。

    Args:
        force_full: 若为 True,预览全量同步(从 20180101)
        start_date: 手动指定起始日期(覆盖自动检测)
        end_date: 手动指定结束日期(默认为今天)
        sample_size: 预览用样本股票数量(默认: 3)

    Returns:
        包含预览信息的字典:
        {
            'sync_needed': bool,
            'stock_count': int,
            'start_date': str,
            'end_date': str,
            'estimated_records': int,
            'sample_data': pd.DataFrame,
            'mode': str,  # 'full', 'incremental', 'partial', 或 'none'
        }

    Example:
        >>> # 预览将要同步的内容
        >>> preview = preview_daily_sync()
        >>>
        >>> # 预览全量同步
        >>> preview = preview_daily_sync(force_full=True)
        >>>
        >>> # 预览更多样本
        >>> preview = preview_daily_sync(sample_size=5)
    """
    sync_manager = DailySync()
    return sync_manager.preview_sync(
        force_full=force_full,
        start_date=start_date,
        end_date=end_date,
        sample_size=sample_size,
    )
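`DailySync` turns the local last date into an incremental sync start via `get_next_date` from `src.data.utils`. Assuming that helper simply advances a date string by one calendar day (a sketch of the expected behavior, not the actual implementation), it looks like:

```python
from datetime import datetime, timedelta


def next_date(yyyymmdd: str) -> str:
    """将日期字符串加一天;兼容 YYYYMMDD 与 YYYY-MM-DD 两种格式。"""
    d = datetime.strptime(yyyymmdd.replace("-", ""), "%Y%m%d")
    return (d + timedelta(days=1)).strftime("%Y%m%d")


print(next_date("20241231"))  # → 20250101
```

Routing the arithmetic through `datetime` (rather than string manipulation) handles month, year, and leap-day rollovers for free.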
113 src/data/api_wrappers/api_namechange.py Normal file
@@ -0,0 +1,113 @@
"""Stock name change history interface.
|
||||
|
||||
Fetch historical name change records for stocks.
|
||||
This interface retrieves all historical name changes including name, dates, and change reasons.
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
from src.data.client import TushareClient
|
||||
from src.data.config import get_config
|
||||
|
||||
|
||||
# CSV file path for namechange data
|
||||
def _get_csv_path() -> Path:
|
||||
"""Get the CSV file path for namechange data."""
|
||||
cfg = get_config()
|
||||
return cfg.data_path_resolved / "namechange.csv"
|
||||
|
||||
|
||||
def get_namechange(
|
||||
ts_code: Optional[str] = None,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
fields: Optional[List[str]] = None,
|
||||
) -> pd.DataFrame:
|
||||
"""Fetch stock name change history.
|
||||
|
||||
This interface retrieves historical name change records for stocks,
|
||||
including name, start/end dates, announcement date, and change reason.
|
||||
|
||||
Args:
|
||||
ts_code: TS stock code (optional, if not provided, returns all stocks)
|
||||
        start_date: Start date for announcement date range (YYYYMMDD)
        end_date: End date for announcement date range (YYYYMMDD)
        fields: Specific fields to return, None returns all fields

    Returns:
        pd.DataFrame with namechange information containing:
        - ts_code: TS stock code
        - name: Security name
        - start_date: Start date of the name
        - end_date: End date of the name
        - ann_date: Announcement date
        - change_reason: Reason for name change
    """
    client = TushareClient()

    # Build parameters
    params = {}
    if ts_code:
        params["ts_code"] = ts_code
    if start_date:
        params["start_date"] = start_date
    if end_date:
        params["end_date"] = end_date
    if fields:
        params["fields"] = ",".join(fields)

    # Fetch data
    data = client.query("namechange", **params)

    if data.empty:
        print("[get_namechange] No data returned")

    return data


def sync_namechange(force: bool = False) -> pd.DataFrame:
    """Fetch and save all stock name change records to local CSV.

    This is a full load interface - fetches all historical name change records.
    Each call fetches all data (no incremental sync).

    Args:
        force: If True, force re-fetch even if CSV exists (default: False)

    Returns:
        pd.DataFrame with all namechange records
    """
    csv_path = _get_csv_path()

    # Check if CSV file already exists
    if csv_path.exists() and not force:
        print(f"[sync_namechange] namechange.csv already exists at {csv_path}")
        print("[sync_namechange] Use force=True to re-fetch")
        return pd.read_csv(csv_path, encoding="utf-8-sig")

    print("[sync_namechange] Fetching all stock name changes...")

    # Fetch all namechange data (no parameters = all stocks, all history)
    client = TushareClient()
    data = client.query("namechange")

    if data.empty:
        print("[sync_namechange] No namechange data fetched")
        return pd.DataFrame()

    print(f"[sync_namechange] Fetched {len(data)} name change records")

    # Save to CSV
    data.to_csv(csv_path, index=False, encoding="utf-8-sig")
    print(f"[sync_namechange] Saved {len(data)} records to {csv_path}")
    return data


if __name__ == "__main__":
    # Sync all namechange records to data folder
    result = sync_namechange()
    print(f"Total records synced: {len(result)}")
    if not result.empty:
        print("\nSample data:")
        print(result.head(10))
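The cache-or-fetch flow of `sync_namechange` (reuse the local CSV unless `force=True`, otherwise do a full reload and persist it) can be sketched without Tushare or pandas. A minimal standalone version, using the stdlib `csv` module and a fake `fetch` callable in place of the API client:

```python
import csv
from pathlib import Path
from tempfile import TemporaryDirectory


def sync_records(csv_path: Path, fetch, force: bool = False):
    """Illustrative cache-or-fetch sketch (not the project function)."""
    if csv_path.exists() and not force:
        # Reuse the local copy; no API call is made
        with csv_path.open(newline="", encoding="utf-8-sig") as f:
            return list(csv.DictReader(f))
    rows = fetch()  # full load: always fetches all history
    with csv_path.open("w", newline="", encoding="utf-8-sig") as f:
        writer = csv.DictWriter(f, fieldnames=rows[0].keys())
        writer.writeheader()
        writer.writerows(rows)
    return rows


with TemporaryDirectory() as d:
    path = Path(d) / "namechange.csv"
    fetch = lambda: [{"ts_code": "000001.SZ", "name": "平安银行"}]
    first = sync_records(path, fetch)           # fetches and writes the CSV
    second = sync_records(path, lambda: [])     # served from CSV; fetch not called
    print(second[0]["ts_code"])  # → 000001.SZ
```

The `utf-8-sig` encoding mirrors the module's CSV round-trip, which keeps Excel-friendly BOMs intact for the Chinese `name` column.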
910
src/data/sync.py
@@ -1,701 +1,34 @@
"""Data synchronization module.
"""数据同步调度中心模块。

This module provides data fetching functions with intelligent sync logic:
- If local file doesn't exist: fetch all data (full load from 20180101)
- If local file exists: incremental update (fetch from latest date + 1 day)
- Multi-threaded concurrent fetching for improved performance
- Stop immediately on any exception
- Preview mode: check data volume and samples before actual sync
该模块作为数据同步的调度中心,统一管理各类型数据的同步流程。
具体的同步逻辑已迁移到对应的 api_xxx.py 文件中:
- api_daily.py: 日线数据同步 (DailySync 类)
- api_bak_basic.py: 历史股票列表同步
- api_stock_basic.py: 股票基本信息同步
- api_trade_cal.py: 交易日历同步

Currently supported data types:
- daily: Daily market data (with turnover rate and volume ratio)
注意:名称变更 (namechange) 已从自动同步中移除,
因为股票名称变更不频繁,建议手动定期同步。

Usage:
    # Preview sync (check data volume and samples without writing)
    preview_sync()
使用方式:
    # 预览同步(检查数据量,不写入)
    from src.data.sync import preview_sync
    preview = preview_sync()

    # Sync all stocks (full load)
    sync_all()
    # 同步所有数据(不包括 namechange)
    from src.data.sync import sync_all_data
    result = sync_all_data()

    # Sync all stocks (incremental)
    sync_all()

    # Force full reload
    sync_all(force_full=True)

    # Dry run (preview only, no write)
    sync_all(dry_run=True)
    # 强制全量重载
    result = sync_all_data(force_full=True)
"""

from typing import Optional, Dict

import pandas as pd
from typing import Optional, Dict, Callable
from datetime import datetime, timedelta
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
import sys

from src.data.client import TushareClient
from src.data.storage import ThreadSafeStorage
from src.data.api_wrappers import get_daily
from src.data.api_wrappers import (
    get_first_trading_day,
    get_last_trading_day,
    sync_trade_cal_cache,
)


# Default full sync start date
DEFAULT_START_DATE = "20180101"

# Today's date in YYYYMMDD format
TODAY = datetime.now().strftime("%Y%m%d")


def get_today_date() -> str:
    """Get today's date in YYYYMMDD format."""
    return TODAY


def get_next_date(date_str: str) -> str:
    """Get the next day after the given date.

    Args:
        date_str: Date in YYYYMMDD format

    Returns:
        Next date in YYYYMMDD format
    """
    dt = datetime.strptime(date_str, "%Y%m%d")
    next_dt = dt + timedelta(days=1)
    return next_dt.strftime("%Y%m%d")

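These two date helpers are thin wrappers over stdlib `datetime`. A minimal standalone sketch (note one subtlety: the module-level `TODAY` above is frozen at import time, so computing the date at call time is the safer variant for a long-running process):

```python
from datetime import datetime, timedelta


def get_today_date() -> str:
    # Evaluated at call time; a module-level TODAY constant would go
    # stale in a process that runs across midnight.
    return datetime.now().strftime("%Y%m%d")


def get_next_date(date_str: str) -> str:
    # YYYYMMDD string -> next calendar day, also as YYYYMMDD
    dt = datetime.strptime(date_str, "%Y%m%d")
    return (dt + timedelta(days=1)).strftime("%Y%m%d")


print(get_next_date("20231231"))  # → 20240101
print(get_next_date("20240228"))  # → 20240229 (2024 is a leap year)
```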
class DataSync:
    """Data synchronization manager with full/incremental sync support."""

    # Default number of worker threads
    DEFAULT_MAX_WORKERS = 10

    def __init__(self, max_workers: Optional[int] = None):
        """Initialize sync manager.

        Args:
            max_workers: Number of worker threads (default: 10)
        """
        self.storage = ThreadSafeStorage()
        self.client = TushareClient()
        self.max_workers = max_workers or self.DEFAULT_MAX_WORKERS
        self._stop_flag = threading.Event()
        self._stop_flag.set()  # Initially not stopped
        self._cached_daily_data: Optional[pd.DataFrame] = None  # Cache for daily data

    def _load_daily_data(self) -> pd.DataFrame:
        """Load daily data from storage with caching.

        This method caches the daily data in memory to avoid repeated disk reads.
        Call clear_cache() to force reload.

        Returns:
            DataFrame with daily data (cached or loaded from storage)
        """
        if self._cached_daily_data is None:
            self._cached_daily_data = self.storage.load("daily")
        return self._cached_daily_data

    def clear_cache(self) -> None:
        """Clear the cached daily data to force reload on next access."""
        self._cached_daily_data = None

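The `_load_daily_data` / `clear_cache` pair is a load-once, invalidate-on-demand cache. A self-contained sketch of the same pattern with a fake loader in place of `ThreadSafeStorage` (illustrative only, not the project class):

```python
from typing import Callable, Optional


class CachedLoader:
    """Load-once cache: the expensive loader runs at most once per generation."""

    def __init__(self, loader: Callable[[], list]):
        self._loader = loader
        self._cache: Optional[list] = None
        self.load_calls = 0  # instrumentation for the demo

    def load(self) -> list:
        if self._cache is None:       # first access hits the loader
            self.load_calls += 1
            self._cache = self._loader()
        return self._cache            # later accesses are served from memory

    def clear_cache(self) -> None:
        self._cache = None            # force a reload on next access


loader = CachedLoader(lambda: [1, 2, 3])
loader.load()
loader.load()
print(loader.load_calls)  # → 1
loader.clear_cache()
loader.load()
print(loader.load_calls)  # → 2
```

As in the original, `None` doubles as the "not loaded" sentinel, so a loader that legitimately returns `None` (or an object that compares oddly to `None`) would defeat the cache; the project sidesteps this because `storage.load` returns a DataFrame.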
    def get_all_stock_codes(self, only_listed: bool = True) -> list:
        """Get all stock codes from local storage.

        This function prioritizes stock_basic.csv to ensure all stocks
        are included for backtesting to avoid look-ahead bias.

        Args:
            only_listed: If True, only return currently listed stocks (L status).
                Set to False to include delisted stocks (for full backtest).

        Returns:
            List of stock codes
        """
        # Import sync_all_stocks here to avoid circular imports
        from src.data.api_wrappers import sync_all_stocks
        from src.data.api_wrappers.api_stock_basic import _get_csv_path

        # First, ensure stock_basic.csv is up-to-date with all stocks
        print("[DataSync] Ensuring stock_basic.csv is up-to-date...")
        sync_all_stocks()

        # Get from stock_basic.csv file
        stock_csv_path = _get_csv_path()

        if stock_csv_path.exists():
            print(f"[DataSync] Reading stock_basic from CSV: {stock_csv_path}")
            try:
                stock_df = pd.read_csv(stock_csv_path, encoding="utf-8-sig")
                if not stock_df.empty and "ts_code" in stock_df.columns:
                    # Filter by list_status if only_listed is True
                    if only_listed and "list_status" in stock_df.columns:
                        listed_stocks = stock_df[stock_df["list_status"] == "L"]
                        codes = listed_stocks["ts_code"].unique().tolist()
                        total = len(stock_df["ts_code"].unique())
                        print(
                            f"[DataSync] Found {len(codes)} listed stocks (filtered from {total} total)"
                        )
                    else:
                        codes = stock_df["ts_code"].unique().tolist()
                        print(
                            f"[DataSync] Found {len(codes)} stock codes from stock_basic.csv"
                        )
                    return codes
                else:
                    print(
                        f"[DataSync] stock_basic.csv exists but no ts_code column or empty"
                    )
            except Exception as e:
                print(f"[DataSync] Error reading stock_basic.csv: {e}")

        # Fallback: try daily storage if stock_basic not available (using cached data)
        print("[DataSync] stock_basic.csv not available, falling back to daily data...")
        daily_data = self._load_daily_data()
        if not daily_data.empty and "ts_code" in daily_data.columns:
            codes = daily_data["ts_code"].unique().tolist()
            print(f"[DataSync] Found {len(codes)} stock codes from daily data")
            return codes

        print("[DataSync] No stock codes found in local storage")
        return []

    def get_global_last_date(self) -> Optional[str]:
        """Get the global last trade date across all stocks.

        Returns:
            Last trade date string or None
        """
        daily_data = self._load_daily_data()
        if daily_data.empty or "trade_date" not in daily_data.columns:
            return None
        return str(daily_data["trade_date"].max())

    def get_global_first_date(self) -> Optional[str]:
        """Get the global first trade date across all stocks.

        Returns:
            First trade date string or None
        """
        daily_data = self._load_daily_data()
        if daily_data.empty or "trade_date" not in daily_data.columns:
            return None
        return str(daily_data["trade_date"].min())

    def get_trade_calendar_bounds(
        self, start_date: str, end_date: str
    ) -> tuple[Optional[str], Optional[str]]:
        """Get the first and last trading day from trade calendar.

        Args:
            start_date: Start date in YYYYMMDD format
            end_date: End date in YYYYMMDD format

        Returns:
            Tuple of (first_trading_day, last_trading_day) or (None, None) if error
        """
        try:
            first_day = get_first_trading_day(start_date, end_date)
            last_day = get_last_trading_day(start_date, end_date)
            return (first_day, last_day)
        except Exception as e:
            print(f"[ERROR] Failed to get trade calendar bounds: {e}")
            return (None, None)

    def check_sync_needed(
        self, force_full: bool = False
    ) -> tuple[bool, Optional[str], Optional[str], Optional[str]]:
        """Check if sync is needed based on trade calendar.

        This method compares local data date range with trade calendar
        to determine if new data needs to be fetched.

        Logic:
        - If force_full: sync needed, return (True, 20180101, today)
        - If no local data: sync needed, return (True, 20180101, today)
        - If local data exists:
          - Get the last trading day from trade calendar
          - If local last date >= calendar last date: NO sync needed
          - Otherwise: sync needed from local_last_date + 1 to latest trade day

        Args:
            force_full: If True, always return sync needed

        Returns:
            Tuple of (sync_needed, start_date, end_date, local_last_date)
            - sync_needed: True if sync should proceed, False to skip
            - start_date: Sync start date (None if sync not needed)
            - end_date: Sync end date (None if sync not needed)
            - local_last_date: Local data last date (for incremental sync)
        """
        # If force_full, always sync
        if force_full:
            print("[DataSync] Force full sync requested")
            return (True, DEFAULT_START_DATE, get_today_date(), None)

        # Check if local data exists (using cached data)
        daily_data = self._load_daily_data()
        if daily_data.empty or "trade_date" not in daily_data.columns:
            print("[DataSync] No local data found, full sync needed")
            return (True, DEFAULT_START_DATE, get_today_date(), None)

        # Get local data last date (we only care about the latest date, not the first)
        local_last_date = str(daily_data["trade_date"].max())

        print(f"[DataSync] Local data last date: {local_last_date}")

        # Get the latest trading day from trade calendar
        today = get_today_date()
        _, cal_last = self.get_trade_calendar_bounds(DEFAULT_START_DATE, today)

        if cal_last is None:
            print("[DataSync] Failed to get trade calendar, proceeding with sync")
            return (True, DEFAULT_START_DATE, today, local_last_date)

        print(f"[DataSync] Calendar last trading day: {cal_last}")

        # Compare local last date with calendar last date
        # If local data is already up-to-date or newer, no sync needed
        print(
            f"[DataSync] Comparing: local={local_last_date} (type={type(local_last_date).__name__}), cal={cal_last} (type={type(cal_last).__name__})"
        )
        try:
            local_last_int = int(local_last_date)
            cal_last_int = int(cal_last)
            print(
                f"[DataSync] Comparing integers: local={local_last_int} >= cal={cal_last_int} = {local_last_int >= cal_last_int}"
            )
            if local_last_int >= cal_last_int:
                print(
                    "[DataSync] Local data is up-to-date, SKIPPING sync (no tokens consumed)"
                )
                return (False, None, None, None)
        except (ValueError, TypeError) as e:
            print(f"[ERROR] Date comparison failed: {e}")

        # Need to sync from local_last_date + 1 to latest trade day
        sync_start = get_next_date(local_last_date)
        print(f"[DataSync] Incremental sync needed from {sync_start} to {cal_last}")
        return (True, sync_start, cal_last, local_last_date)

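The decision logic in `check_sync_needed` boils down to three branches: full load, skip, or an incremental window starting the day after the local high-water mark. A self-contained sketch of just that decision (no storage or calendar dependencies; dates passed in directly):

```python
from datetime import datetime, timedelta
from typing import Optional, Tuple

DEFAULT_START_DATE = "20180101"


def next_day(yyyymmdd: str) -> str:
    dt = datetime.strptime(yyyymmdd, "%Y%m%d")
    return (dt + timedelta(days=1)).strftime("%Y%m%d")


def plan_sync(
    local_last: Optional[str], cal_last: str, force_full: bool = False
) -> Tuple[bool, Optional[str], Optional[str]]:
    """Simplified sketch of the sync decision: (sync_needed, start, end)."""
    if force_full or local_last is None:
        return (True, DEFAULT_START_DATE, cal_last)  # full load
    if int(local_last) >= int(cal_last):             # YYYYMMDD compares safely as int
        return (False, None, None)                   # up to date: skip, no API tokens spent
    return (True, next_day(local_last), cal_last)    # incremental window


print(plan_sync("20240105", "20240105"))  # → (False, None, None)
print(plan_sync("20240105", "20240110"))  # → (True, '20240106', '20240110')
```

Comparing the dates as integers (rather than lexically) mirrors the defensive `int()` conversion above, which guards against a stored `trade_date` that is numeric rather than string.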
    def preview_sync(
        self,
        force_full: bool = False,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        sample_size: int = 3,
    ) -> dict:
        """Preview sync data volume and samples without actually syncing.

        This method provides a preview of what would be synced, including:
        - Number of stocks to be synced
        - Date range for sync
        - Estimated total records
        - Sample data from first few stocks

        Args:
            force_full: If True, preview full sync from 20180101
            start_date: Manual start date (overrides auto-detection)
            end_date: Manual end date (defaults to today)
            sample_size: Number of sample stocks to fetch for preview (default: 3)

        Returns:
            Dictionary with preview information:
            {
                'sync_needed': bool,
                'stock_count': int,
                'start_date': str,
                'end_date': str,
                'estimated_records': int,
                'sample_data': pd.DataFrame,
                'mode': str,  # 'full' or 'incremental'
            }
        """
        print("\n" + "=" * 60)
        print("[DataSync] Preview Mode - Analyzing sync requirements...")
        print("=" * 60)

        # First, ensure trade calendar cache is up-to-date
        print("[DataSync] Syncing trade calendar cache...")
        sync_trade_cal_cache()

        # Determine date range
        if end_date is None:
            end_date = get_today_date()

        # Check if sync is needed
        sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full)

        if not sync_needed:
            print("\n" + "=" * 60)
            print("[DataSync] Preview Result")
            print("=" * 60)
            print("  Sync Status: NOT NEEDED")
            print("  Reason: Local data is up-to-date with trade calendar")
            print("=" * 60)
            return {
                "sync_needed": False,
                "stock_count": 0,
                "start_date": None,
                "end_date": None,
                "estimated_records": 0,
                "sample_data": pd.DataFrame(),
                "mode": "none",
            }

        # Use dates from check_sync_needed
        if cal_start and cal_end:
            sync_start_date = cal_start
            end_date = cal_end
        else:
            sync_start_date = start_date or DEFAULT_START_DATE
            if end_date is None:
                end_date = get_today_date()

        # Determine sync mode
        if force_full:
            mode = "full"
            print(f"[DataSync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
        elif local_last and cal_start and sync_start_date == get_next_date(local_last):
            mode = "incremental"
            print(f"[DataSync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
            print(f"[DataSync] Sync from: {sync_start_date} to {end_date}")
        else:
            mode = "partial"
            print(f"[DataSync] Mode: SYNC from {sync_start_date} to {end_date}")

        # Get all stock codes
        stock_codes = self.get_all_stock_codes()
        if not stock_codes:
            print("[DataSync] No stocks found to sync")
            return {
                "sync_needed": False,
                "stock_count": 0,
                "start_date": None,
                "end_date": None,
                "estimated_records": 0,
                "sample_data": pd.DataFrame(),
                "mode": "none",
            }

        stock_count = len(stock_codes)
        print(f"[DataSync] Total stocks to sync: {stock_count}")

        # Fetch sample data from first few stocks
        print(f"[DataSync] Fetching sample data from {sample_size} stocks...")
        sample_data_list = []
        sample_codes = stock_codes[:sample_size]

        for ts_code in sample_codes:
            try:
                data = self.client.query(
                    "pro_bar",
                    ts_code=ts_code,
                    start_date=sync_start_date,
                    end_date=end_date,
                    factors="tor,vr",
                )
                if not data.empty:
                    sample_data_list.append(data)
                    print(f"  - {ts_code}: {len(data)} records")
            except Exception as e:
                print(f"  - {ts_code}: Error fetching - {e}")

        # Combine sample data
        sample_df = (
            pd.concat(sample_data_list, ignore_index=True)
            if sample_data_list
            else pd.DataFrame()
        )

        # Estimate total records based on sample
        if not sample_df.empty:
            avg_records_per_stock = len(sample_df) / len(sample_data_list)
            estimated_records = int(avg_records_per_stock * stock_count)
        else:
            estimated_records = 0

        # Display preview results
        print("\n" + "=" * 60)
        print("[DataSync] Preview Result")
        print("=" * 60)
        print(f"  Sync Mode: {mode.upper()}")
        print(f"  Date Range: {sync_start_date} to {end_date}")
        print(f"  Stocks to Sync: {stock_count}")
        print(f"  Sample Stocks Checked: {len(sample_data_list)}/{sample_size}")
        print(f"  Estimated Total Records: ~{estimated_records:,}")

        if not sample_df.empty:
            print(f"\n  Sample Data Preview (first {len(sample_df)} rows):")
            print("  " + "-" * 56)
            # Display sample data in a compact format
            preview_cols = [
                "ts_code",
                "trade_date",
                "open",
                "high",
                "low",
                "close",
                "vol",
            ]
            available_cols = [c for c in preview_cols if c in sample_df.columns]
            sample_display = sample_df[available_cols].head(10)
            for idx, row in sample_display.iterrows():
                print(f"  {row.to_dict()}")
            print("  " + "-" * 56)

        print("=" * 60)

        return {
            "sync_needed": True,
            "stock_count": stock_count,
            "start_date": sync_start_date,
            "end_date": end_date,
            "estimated_records": estimated_records,
            "sample_data": sample_df,
            "mode": mode,
        }

    def sync_single_stock(
        self,
        ts_code: str,
        start_date: str,
        end_date: str,
    ) -> pd.DataFrame:
        """Sync daily data for a single stock.

        Args:
            ts_code: Stock code
            start_date: Start date (YYYYMMDD)
            end_date: End date (YYYYMMDD)

        Returns:
            DataFrame with daily market data
        """
        # Check if sync should stop (for exception handling)
        if not self._stop_flag.is_set():
            return pd.DataFrame()

        try:
            # Use shared client for rate limiting across threads
            data = self.client.query(
                "pro_bar",
                ts_code=ts_code,
                start_date=start_date,
                end_date=end_date,
                factors="tor,vr",
            )
            return data
        except Exception as e:
            # Set stop flag to signal other threads to stop
            self._stop_flag.clear()
            print(f"[ERROR] Exception syncing {ts_code}: {e}")
            raise

    def sync_all(
        self,
        force_full: bool = False,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        max_workers: Optional[int] = None,
        dry_run: bool = False,
    ) -> Dict[str, pd.DataFrame]:
        """Sync daily data for all stocks in local storage.

        This function:
        1. Reads stock codes from local storage (daily or stock_basic)
        2. Checks trade calendar to determine if sync is needed:
           - If local data matches trade calendar bounds, SKIP sync (save tokens)
           - Otherwise, sync from local_last_date + 1 to latest trade day (bandwidth optimized)
        3. Uses multi-threaded concurrent fetching with rate limiting
        4. Skips updating stocks that return empty data (delisted/unavailable)
        5. Stops immediately on any exception

        Args:
            force_full: If True, force full reload from 20180101
            start_date: Manual start date (overrides auto-detection)
            end_date: Manual end date (defaults to today)
            max_workers: Number of worker threads (default: 10)
            dry_run: If True, only preview what would be synced without writing data

        Returns:
            Dict mapping ts_code to DataFrame (empty if sync skipped or dry_run)
        """
        print("\n" + "=" * 60)
        print("[DataSync] Starting daily data sync...")
        print("=" * 60)

        # First, ensure trade calendar cache is up-to-date (uses incremental sync)
        print("[DataSync] Syncing trade calendar cache...")
        sync_trade_cal_cache()

        # Determine date range
        if end_date is None:
            end_date = get_today_date()

        # Check if sync is needed based on trade calendar
        sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full)

        if not sync_needed:
            # Sync skipped - no tokens consumed
            print("\n" + "=" * 60)
            print("[DataSync] Sync Summary")
            print("=" * 60)
            print("  Sync: SKIPPED (local data up-to-date with trade calendar)")
            print("  Tokens saved: 0 consumed")
            print("=" * 60)
            return {}

        # Use dates from check_sync_needed (which calculates incremental start if needed)
        if cal_start and cal_end:
            sync_start_date = cal_start
            end_date = cal_end
        else:
            # Fallback to default logic
            sync_start_date = start_date or DEFAULT_START_DATE
            if end_date is None:
                end_date = get_today_date()

        # Determine sync mode
        if force_full:
            mode = "full"
            print(f"[DataSync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
        elif local_last and cal_start and sync_start_date == get_next_date(local_last):
            mode = "incremental"
            print(f"[DataSync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
            print(f"[DataSync] Sync from: {sync_start_date} to {end_date}")
        else:
            mode = "partial"
            print(f"[DataSync] Mode: SYNC from {sync_start_date} to {end_date}")

        # Get all stock codes
        stock_codes = self.get_all_stock_codes()
        if not stock_codes:
            print("[DataSync] No stocks found to sync")
            return {}

        print(f"[DataSync] Total stocks to sync: {len(stock_codes)}")
        print(f"[DataSync] Using {max_workers or self.max_workers} worker threads")

        # Handle dry run mode
        if dry_run:
            print("\n" + "=" * 60)
            print("[DataSync] DRY RUN MODE - No data will be written")
            print("=" * 60)
            print(f"  Would sync {len(stock_codes)} stocks")
            print(f"  Date range: {sync_start_date} to {end_date}")
            print(f"  Mode: {mode}")
            print("=" * 60)
            return {}

        # Reset stop flag for new sync
        self._stop_flag.set()

        # Multi-threaded concurrent fetching
        results: Dict[str, pd.DataFrame] = {}
        error_occurred = False
        exception_to_raise = None

        def sync_task(ts_code: str) -> tuple[str, pd.DataFrame]:
            """Task function for each stock."""
            try:
                data = self.sync_single_stock(
                    ts_code=ts_code,
                    start_date=sync_start_date,
                    end_date=end_date,
                )
                return (ts_code, data)
            except Exception as e:
                # Re-raise to be caught by Future
                raise

        # Use ThreadPoolExecutor for concurrent fetching
        workers = max_workers or self.max_workers
        with ThreadPoolExecutor(max_workers=workers) as executor:
            # Submit all tasks and track futures with their stock codes
            future_to_code = {
                executor.submit(sync_task, ts_code): ts_code for ts_code in stock_codes
            }

            # Process results using as_completed
            error_count = 0
            empty_count = 0
            success_count = 0

            # Create progress bar
            pbar = tqdm(total=len(stock_codes), desc="Syncing stocks")

            try:
                # Process futures as they complete
                for future in as_completed(future_to_code):
                    ts_code = future_to_code[future]

                    try:
                        _, data = future.result()
                        if data is not None and not data.empty:
                            results[ts_code] = data
                            success_count += 1
                        else:
                            # Empty data - stock may be delisted or unavailable
                            empty_count += 1
                            print(
                                f"[DataSync] Stock {ts_code}: empty data (skipped, may be delisted)"
                            )
                    except Exception as e:
                        # Exception occurred - stop all and abort
                        error_occurred = True
                        exception_to_raise = e
                        print(f"\n[ERROR] Sync aborted due to exception: {e}")
                        # Shutdown executor to stop all pending tasks
                        executor.shutdown(wait=False, cancel_futures=True)
                        raise exception_to_raise

                    # Update progress bar
                    pbar.update(1)

            except Exception:
                error_count = 1
                print("[DataSync] Sync stopped due to exception")
            finally:
                pbar.close()

        # Queue all data for batch write (only if no error)
        if results and not error_occurred:
            for ts_code, data in results.items():
                if not data.empty:
                    self.storage.queue_save("daily", data)
            # Flush all queued writes at once
            self.storage.flush()
            total_rows = sum(len(df) for df in results.values())
            print(f"\n[DataSync] Saved {total_rows} rows to storage")

        # Summary
        print("\n" + "=" * 60)
        print("[DataSync] Sync Summary")
        print("=" * 60)
        print(f"  Total stocks: {len(stock_codes)}")
        print(f"  Updated: {success_count}")
        print(f"  Skipped (empty/delisted): {empty_count}")
        print(
            f"  Errors: {error_count} (aborted on first error)"
            if error_count
            else "  Errors: 0"
        )
        print(f"  Date range: {sync_start_date} to {end_date}")
        print("=" * 60)

        return results

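The fan-out in `sync_all` — submit every stock to a `ThreadPoolExecutor`, harvest results with `as_completed`, and abort everything on the first exception via `shutdown(cancel_futures=True)` — can be demonstrated without Tushare. A minimal sketch with a fake `fetch` standing in for the rate-limited client (names here are illustrative, not the project API):

```python
from concurrent.futures import ThreadPoolExecutor, as_completed


def fetch(code: str) -> str:
    """Stand-in for the per-stock API call."""
    if code == "BAD":
        raise RuntimeError(f"fetch failed for {code}")
    return f"data:{code}"


def fetch_all(codes):
    """Concurrent fan-out with stop-on-first-error, as in DataSync.sync_all."""
    results = {}
    with ThreadPoolExecutor(max_workers=4) as executor:
        future_to_code = {executor.submit(fetch, c): c for c in codes}
        for future in as_completed(future_to_code):
            code = future_to_code[future]
            try:
                results[code] = future.result()
            except Exception:
                # First failure: cancel every still-pending task and abort.
                # cancel_futures requires Python 3.9+.
                executor.shutdown(wait=False, cancel_futures=True)
                raise
    return results


print(sorted(fetch_all(["A", "B", "C"]).items()))
# → [('A', 'data:A'), ('B', 'data:B'), ('C', 'data:C')]
```

Note that `cancel_futures` only prevents *queued* tasks from starting; tasks already running finish on their own, which is why the original also flips a shared `threading.Event` so in-flight workers can bail out early.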
# Convenience functions
from src.data.api_wrappers import sync_all_stocks
from src.data.api_wrappers.api_daily import sync_daily, preview_daily_sync


def preview_sync(
@@ -705,20 +38,19 @@ def preview_sync(
    sample_size: int = 3,
    max_workers: Optional[int] = None,
) -> dict:
    """Preview sync data volume and samples without actually syncing.
    """预览日线同步数据量和样本(不实际同步)。

    This is the recommended way to check what would be synced before
    running the actual synchronization.
    这是推荐的方式,可在实际同步前检查将要同步的内容。

    Args:
        force_full: If True, preview full sync from 20180101
        start_date: Manual start date (overrides auto-detection)
        end_date: Manual end date (defaults to today)
        sample_size: Number of sample stocks to fetch for preview (default: 3)
        max_workers: Number of worker threads (not used in preview, for API compatibility)
        force_full: 若为 True,预览全量同步(从 20180101)
        start_date: 手动指定起始日期(覆盖自动检测)
        end_date: 手动指定结束日期(默认为今天)
        sample_size: 预览用样本股票数量(默认: 3)
        max_workers: 工作线程数(默认: 10)

    Returns:
        Dictionary with preview information:
        包含预览信息的字典:
        {
            'sync_needed': bool,
            'stock_count': int,
@@ -726,21 +58,20 @@ def preview_sync(
            'end_date': str,
            'estimated_records': int,
            'sample_data': pd.DataFrame,
            'mode': str,  # 'full', 'incremental', 'partial', or 'none'
            'mode': str,  # 'full', 'incremental', 'partial', 或 'none'
        }

    Example:
        >>> # Preview what would be synced
        >>> # 预览将要同步的内容
        >>> preview = preview_sync()
        >>>
        >>> # Preview full sync
        >>> # 预览全量同步
        >>> preview = preview_sync(force_full=True)
        >>>
        >>> # Preview with more samples
        >>> # 预览更多样本
        >>> preview = preview_sync(sample_size=5)
    """
    sync_manager = DataSync(max_workers=max_workers)
    return sync_manager.preview_sync(
    return preview_daily_sync(
        force_full=force_full,
        start_date=start_date,
        end_date=end_date,
@@ -755,54 +86,168 @@ def sync_all(
    max_workers: Optional[int] = None,
    dry_run: bool = False,
) -> Dict[str, pd.DataFrame]:
    """Sync daily data for all stocks.
    """同步所有股票的日线数据。

    This is the main entry point for data synchronization.
    这是日线数据同步的主要入口点。

    Args:
        force_full: If True, force full reload from 20180101
        start_date: Manual start date (YYYYMMDD)
        end_date: Manual end date (defaults to today)
        max_workers: Number of worker threads (default: 10)
        dry_run: If True, only preview what would be synced without writing data
        force_full: 若为 True,强制从 20180101 完整重载
        start_date: 手动指定起始日期(YYYYMMDD)
        end_date: 手动指定结束日期(默认为今天)
        max_workers: 工作线程数(默认: 10)
        dry_run: 若为 True,仅预览将要同步的内容,不写入数据

    Returns:
        Dict mapping ts_code to DataFrame
        映射 ts_code 到 DataFrame 的字典

    Example:
        >>> # First time sync (full load from 20180101)
        >>> # 首次同步(从 20180101 全量加载)
        >>> result = sync_all()
        >>>
        >>> # Subsequent sync (incremental - only new data)
        >>> # 后续同步(增量 - 仅新数据)
        >>> result = sync_all()
        >>>
        >>> # Force full reload
        >>> # 强制完整重载
        >>> result = sync_all(force_full=True)
        >>>
        >>> # Manual date range
        >>> # 手动指定日期范围
        >>> result = sync_all(start_date='20240101', end_date='20240131')
        >>>
        >>> # Custom thread count
        >>> # 自定义线程数
        >>> result = sync_all(max_workers=20)
        >>>
        >>> # Dry run (preview only)
        >>> # Dry run(仅预览)
        >>> result = sync_all(dry_run=True)
    """
    sync_manager = DataSync(max_workers=max_workers)
    return sync_manager.sync_all(
    return sync_daily(
        force_full=force_full,
        start_date=start_date,
        end_date=end_date,
        max_workers=max_workers,
        dry_run=dry_run,
    )


def sync_all_data(
    force_full: bool = False,
    max_workers: Optional[int] = None,
    dry_run: bool = False,
) -> Dict[str, pd.DataFrame]:
    """同步所有数据类型(每日同步)。

    该函数按顺序同步所有可用的数据类型:
    1. 交易日历 (sync_trade_cal_cache)
    2. 股票基本信息 (sync_all_stocks)
    3. 日线市场数据 (sync_all)
    4. 历史股票列表 (sync_bak_basic)

    注意:名称变更 (namechange) 不在自动同步中,如需同步请手动调用。

    Args:
        force_full: 若为 True,强制所有数据类型完整重载
        max_workers: 日线数据同步的工作线程数(默认: 10)
        dry_run: 若为 True,仅显示将要同步的内容,不写入数据

    Returns:
        映射数据类型到同步结果的字典

    Example:
        >>> # 同步所有数据(增量)
        >>> result = sync_all_data()
        >>>
        >>> # 强制完整重载
        >>> result = sync_all_data(force_full=True)
        >>>
        >>> # Dry run
        >>> result = sync_all_data(dry_run=True)
    """
    results: Dict[str, pd.DataFrame] = {}

    print("\n" + "=" * 60)
    print("[sync_all_data] Starting full data synchronization...")
    print("=" * 60)

    # 1. Sync trade calendar (always needed first)
    print("\n[1/4] Syncing trade calendar cache...")
    try:
        from src.data.api_wrappers import sync_trade_cal_cache

        sync_trade_cal_cache()
        results["trade_cal"] = pd.DataFrame()
        print("[1/4] Trade calendar: OK")
    except Exception as e:
        print(f"[1/4] Trade calendar: FAILED - {e}")
        results["trade_cal"] = pd.DataFrame()

    # 2. Sync stock basic info
    print("\n[2/4] Syncing stock basic info...")
    try:
        sync_all_stocks()
        results["stock_basic"] = pd.DataFrame()
        print("[2/4] Stock basic: OK")
    except Exception as e:
        print(f"[2/4] Stock basic: FAILED - {e}")
        results["stock_basic"] = pd.DataFrame()

    # 3. Sync daily market data
    print("\n[3/4] Syncing daily market data...")
    try:
        daily_result = sync_daily(
            force_full=force_full,
            max_workers=max_workers,
            dry_run=dry_run,
        )
        results["daily"] = (
            pd.concat(daily_result.values(), ignore_index=True)
            if daily_result
            else pd.DataFrame()
        )
        print("[3/4] Daily data: OK")
    except Exception as e:
        print(f"[3/4] Daily data: FAILED - {e}")
        results["daily"] = pd.DataFrame()

    # 4. Sync stock historical list (bak_basic)
    print("\n[4/4] Syncing stock historical list (bak_basic)...")
    try:
        bak_basic_result = sync_bak_basic(force_full=force_full)
        results["bak_basic"] = bak_basic_result
        print(f"[4/4] Bak basic: OK ({len(bak_basic_result)} records)")
    except Exception as e:
        print(f"[4/4] Bak basic: FAILED - {e}")
        results["bak_basic"] = pd.DataFrame()

    # Summary
    print("\n" + "=" * 60)
    print("[sync_all_data] Sync Summary")
    print("=" * 60)
    for data_type, df in results.items():
        print(f"  {data_type}: {len(df)} records")
    print("=" * 60)
    print("\nNote: namechange is NOT in auto-sync. To sync manually:")
    print("  from src.data.api_wrappers import sync_namechange")
    print("  sync_namechange(force=True)")

    return results

# 保留向后兼容的导入
|
||||
from src.data.api_wrappers import sync_bak_basic
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("=" * 60)
|
||||
print("Data Sync Module")
|
||||
print("=" * 60)
|
||||
print("\nUsage:")
|
||||
print(" # Sync all data types at once (RECOMMENDED)")
|
||||
print(" from src.data.sync import sync_all_data")
|
||||
print(" result = sync_all_data() # Incremental sync all")
|
||||
print(" result = sync_all_data(force_full=True) # Full reload")
|
||||
print("")
|
||||
print(" # Or sync individual data types:")
|
||||
print(" from src.data.sync import sync_all, preview_sync")
|
||||
print(" from src.data.sync import sync_bak_basic")
|
||||
print("")
|
||||
print(" # Preview before sync (recommended)")
|
||||
print(" preview = preview_sync()")
|
||||
@@ -813,21 +258,14 @@ if __name__ == "__main__":
    print("  # Actual sync")
    print("  result = sync_all()  # Incremental sync")
    print("  result = sync_all(force_full=True)  # Full reload")
    print("")
    print("  # bak_basic sync")
    print("  result = sync_bak_basic()  # Incremental sync")
    print("  result = sync_bak_basic(force_full=True)  # Full reload")
    print("\n" + "=" * 60)

    # Run preview first
    print("\n[Main] Running preview first...")
    preview = preview_sync()

    if preview["sync_needed"]:
        # Ask for confirmation
        print("\n" + "=" * 60)
        response = input("Proceed with sync? (y/n): ").strip().lower()
        if response in ("y", "yes"):
            print("\n[Main] Starting actual sync...")
            result = sync_all()
            print(f"\nSynced {len(result)} stocks")
        else:
            print("\n[Main] Sync cancelled by user")
    else:
        print("\n[Main] No sync needed - data is up to date")
    # Run sync_all_data by default
    print("\n[Main] Running sync_all_data()...")
    result = sync_all_data()
    print("\n[Main] Sync completed!")
    print(f"Total data types synced: {len(result)}")
75 src/data/utils.py Normal file
@@ -0,0 +1,75 @@
"""Data module utility functions.

集中管理数据模块中常用的工具函数,避免重复定义。
"""

from datetime import datetime, timedelta
from typing import Optional


# 默认全量同步开始日期
DEFAULT_START_DATE = "20180101"

# 今日日期 (YYYYMMDD 格式)
TODAY: str = datetime.now().strftime("%Y%m%d")


def get_today_date() -> str:
    """获取今日日期(YYYYMMDD 格式)。

    Returns:
        今日日期字符串,格式为 YYYYMMDD
    """
    return TODAY


def get_next_date(date_str: str) -> str:
    """获取给定日期的下一天。

    Args:
        date_str: YYYYMMDD 格式的日期

    Returns:
        YYYYMMDD 格式的下一天日期
    """
    dt = datetime.strptime(date_str, "%Y%m%d")
    next_dt = dt + timedelta(days=1)
    return next_dt.strftime("%Y%m%d")


def get_prev_date(date_str: str) -> str:
    """获取给定日期的前一天。

    Args:
        date_str: YYYYMMDD 格式的日期

    Returns:
        YYYYMMDD 格式的前一天日期
    """
    dt = datetime.strptime(date_str, "%Y%m%d")
    prev_dt = dt - timedelta(days=1)
    return prev_dt.strftime("%Y%m%d")


def parse_date(date_str: str) -> datetime:
    """解析 YYYYMMDD 格式的日期字符串。

    Args:
        date_str: YYYYMMDD 格式的日期

    Returns:
        datetime 对象
    """
    return datetime.strptime(date_str, "%Y%m%d")


def format_date(dt: datetime) -> str:
    """将 datetime 对象格式化为 YYYYMMDD 字符串。

    Args:
        dt: datetime 对象

    Returns:
        YYYYMMDD 格式的日期字符串
    """
    return dt.strftime("%Y%m%d")
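上面这组日期工具的边界行为(跨月、跨年、闰年)可以用一个纯标准库的小例子验证;下面的函数体与提交中的实现逻辑一致,仅作示意:

```python
from datetime import datetime, timedelta


def get_next_date(date_str: str) -> str:
    # 解析 YYYYMMDD、加一天、再格式化回 YYYYMMDD
    dt = datetime.strptime(date_str, "%Y%m%d")
    return (dt + timedelta(days=1)).strftime("%Y%m%d")


def get_prev_date(date_str: str) -> str:
    # 与 get_next_date 对称,二者互为逆操作
    dt = datetime.strptime(date_str, "%Y%m%d")
    return (dt - timedelta(days=1)).strftime("%Y%m%d")


print(get_next_date("20231231"))  # 20240101,跨年
print(get_next_date("20240228"))  # 20240229,2024 为闰年
print(get_prev_date(get_next_date("20240515")))  # 20240515,往返不变
```

由于 datetime 负责进位,月末、年末无需特殊处理;这也是用 strptime/strftime 而非手工字符串运算的原因。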
@@ -18,28 +18,29 @@
- CompositeFactor: 组合因子
- ScalarFactor: 标量运算因子

动量因子(momentum/):
- MovingAverageFactor: 移动平均线(时序因子)
- ReturnRankFactor: 收益率排名(截面因子)

财务因子(financial/):
- (待添加)

数据加载和执行(Phase 3-4):
- DataLoader: 数据加载器
- FactorEngine: 因子执行引擎

使用示例:
    from src.factors import DataSpec, FactorContext, FactorData
    from src.factors import CrossSectionalFactor, TimeSeriesFactor
    # 使用通用因子(参数化)
    from src.factors import MovingAverageFactor, ReturnRankFactor
    from src.factors import DataLoader, FactorEngine

    # 定义数据需求
    spec = DataSpec(
        source="daily",
        columns=["ts_code", "trade_date", "close"],
        lookback_days=20
    )
    ma5 = MovingAverageFactor(period=5)    # 5日MA
    ma10 = MovingAverageFactor(period=10)  # 10日MA
    ret5 = ReturnRankFactor(period=5)      # 5日收益率排名

    # 初始化引擎
    loader = DataLoader(data_dir="data")
    engine = FactorEngine(loader)

    # 计算因子
    result = engine.compute(factor, start_date="20240101", end_date="20240131")
    result = engine.compute(ma5, stock_codes=["000001.SZ"], start_date="20240101", end_date="20240131")
"""

from src.factors.data_spec import DataSpec, FactorContext, FactorData
@@ -48,6 +49,9 @@ from src.factors.composite import CompositeFactor, ScalarFactor
from src.factors.data_loader import DataLoader
from src.factors.engine import FactorEngine

# 动量因子
from src.factors.momentum import MovingAverageFactor, ReturnRankFactor

__all__ = [
    # Phase 1: 数据类型定义
    "DataSpec",
@@ -62,4 +66,7 @@ __all__ = [
    # Phase 3-4: 数据加载和执行引擎
    "DataLoader",
    "FactorEngine",
    # 动量因子
    "MovingAverageFactor",
    "ReturnRankFactor",
]
15 src/factors/financial/__init__.py Normal file
@@ -0,0 +1,15 @@
"""财务因子模块

本模块提供财务类型的因子:

因子分类:
- financial: 财务因子
    - (待添加)

待添加因子:
- PERankFactor: 市盈率排名
- PBFactor: 市净率因子
- DividendFactor: 股息率因子
"""

__all__ = []
19 src/factors/momentum/__init__.py Normal file
@@ -0,0 +1,19 @@
"""动量因子模块

本模块提供动量类型的因子:
- MovingAverageFactor: 移动平均线(时序因子)
- ReturnRankFactor: 收益率排名(截面因子)

因子分类:
- momentum: 动量因子
    - ma: 移动平均线
    - return_rank: 收益率排名
"""

from src.factors.momentum.ma import MovingAverageFactor
from src.factors.momentum.return_rank import ReturnRankFactor

__all__ = [
    "MovingAverageFactor",
    "ReturnRankFactor",
]
78 src/factors/momentum/ma.py Normal file
@@ -0,0 +1,78 @@
"""动量因子 - 移动平均线

本模块提供通用移动平均线因子,支持参数化配置:
- MovingAverageFactor: 移动平均线(时序因子)

使用示例:
    >>> from src.factors.momentum import MovingAverageFactor
    >>> ma5 = MovingAverageFactor(period=5)    # 5日MA
    >>> ma10 = MovingAverageFactor(period=10)  # 10日MA
    >>> ma20 = MovingAverageFactor(period=20)  # 20日MA
"""

from typing import List

import polars as pl

from src.factors.base import TimeSeriesFactor
from src.factors.data_spec import DataSpec, FactorData


class MovingAverageFactor(TimeSeriesFactor):
    """移动平均线因子

    计算逻辑:对每只股票,计算其过去n日收盘价的移动平均值。

    特点:
    - 参数化因子:训练时通过 period 参数指定计算窗口
    - 时序因子:每只股票单独计算,防止股票间数据泄露

    Attributes:
        period: MA计算期(天数),默认5

    Example:
        >>> ma5 = MovingAverageFactor(period=5)
        >>> # 计算过去5日的收盘价均值
    """

    name: str = "ma"
    factor_type: str = "time_series"
    category: str = "momentum"
    description: str = "移动平均线因子,计算过去n日收盘价的均值"
    data_specs: List[DataSpec] = [
        DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=5)
    ]

    def __init__(self, period: int = 5):
        """初始化因子

        Args:
            period: MA计算期(天数),默认5日
        """
        super().__init__(period=period)
        # 重新创建 DataSpec 以设置正确的 lookback_days(DataSpec 是 frozen 的)
        self.data_specs = [
            DataSpec(
                "daily",
                ["ts_code", "trade_date", "close"],
                lookback_days=period,
            )
        ]
        self.name = f"ma_{period}"

    def compute(self, data: FactorData) -> pl.Series:
        """计算移动平均线

        Args:
            data: FactorData,包含单只股票的完整时间序列

        Returns:
            移动平均值序列
        """
        # 获取收盘价序列
        close_prices = data.get_column("close")

        # 计算移动平均
        ma = close_prices.rolling_mean(window_size=self.params["period"])

        return ma
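`rolling_mean` 的关键语义是窗口未满时输出空值,因此每只股票序列的前 period-1 个值为空,这正是 DataSpec 需要 lookback_days=period 的原因。下面用纯 Python 模拟这一语义(示意性草稿,不依赖 polars):

```python
def rolling_mean(values, window):
    # 模拟滚动均值:前 window-1 个位置窗口未满,输出 None
    out = []
    for i in range(len(values)):
        if i + 1 < window:
            out.append(None)
        else:
            out.append(sum(values[i + 1 - window : i + 1]) / window)
    return out


print(rolling_mean([10.0, 11.0, 12.0, 13.0, 14.0], 3))
# [None, None, 11.0, 12.0, 13.0]
```

若起始日期前没有额外加载 period-1 天的历史数据,训练区间开头的 MA 值就会缺失,后续只能靠 FillNA 之类的处理器兜底。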
100 src/factors/momentum/return_rank.py Normal file
@@ -0,0 +1,100 @@
"""动量因子 - 收益率排名

本模块提供收益率排名因子:
- ReturnRankFactor: 过去n日收益率的rank因子(截面因子)

使用示例:
    >>> from src.factors.momentum import ReturnRankFactor
    >>> ret5 = ReturnRankFactor(period=5)    # 5日收益率排名
    >>> ret10 = ReturnRankFactor(period=10)  # 10日收益率排名
"""

from typing import List

import polars as pl

from src.factors.base import CrossSectionalFactor
from src.factors.data_spec import DataSpec, FactorData


class ReturnRankFactor(CrossSectionalFactor):
    """过去n日收益率排名因子

    计算逻辑:每个交易日,计算所有股票过去n日的收益率,然后进行截面排名。

    特点:
    - 参数化因子:训练时通过 period 参数指定计算窗口
    - 截面因子:每天对所有股票进行横向排名,防止日期泄露

    Attributes:
        period: 收益率计算期(默认5日)

    Example:
        >>> ret5 = ReturnRankFactor(period=5)
        >>> # 每个交易日,返回所有股票过去5日收益率的排名
    """

    name: str = "return_rank"
    factor_type: str = "cross_sectional"
    category: str = "momentum"
    description: str = "过去n日收益率的截面排名因子"
    data_specs: List[DataSpec] = [
        DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=5)
    ]

    def __init__(self, period: int = 5):
        """初始化因子

        Args:
            period: 收益率计算期(天数)
        """
        super().__init__(period=period)
        # 重新创建 DataSpec 以设置正确的 lookback_days(DataSpec 是 frozen 的)
        self.data_specs = [
            DataSpec(
                "daily",
                ["ts_code", "trade_date", "close"],
                lookback_days=period + 1,
            )
        ]
        self.name = f"return_{period}_rank"

    def compute(self, data: FactorData) -> pl.Series:
        """计算过去n日收益率排名

        Args:
            data: FactorData,包含过去n+1天的截面数据

        Returns:
            过去n日收益率的截面排名(0-1之间)
        """
        # 获取当前日期的截面数据
        cs = data.to_polars()

        # 获取所有交易日期(已按日期排序)
        trade_dates = cs["trade_date"].unique().sort()

        if len(trade_dates) < 2:
            # 数据不足,返回空排名
            return pl.Series(name=self.name, values=[])

        # 获取最新日期的数据
        latest_date = trade_dates[-1]
        current_data = cs.filter(pl.col("trade_date") == latest_date)

        # 获取n天前的日期
        n_days_ago = trade_dates[-(self.params["period"] + 1)]
        past_data = cs.filter(pl.col("trade_date") == n_days_ago)

        # 通过 ts_code join 计算收益率
        merged = current_data.select(["ts_code", "close"]).join(
            past_data.select(["ts_code", "close"]).rename({"close": "close_past"}),
            on="ts_code",
            how="inner",
        )

        # 计算收益率
        returns = (merged["close"] - merged["close_past"]) / merged["close_past"]

        # 返回排名(0-1之间)
        return returns.rank(method="average") / len(returns)
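compute 的核心流程是"inner join 对齐两期收盘价 → 计算收益率 → 截面排名归一化到 (0, 1]"。下面用纯 Python 字典示意这一流程(草稿:简化为升序名次排名,未处理并列;真实实现用的是 rank(method="average")):

```python
def return_rank(current: dict, past: dict) -> dict:
    # 只保留两期都有收盘价的股票,对应代码中的 inner join
    codes = [c for c in current if c in past]
    returns = {c: (current[c] - past[c]) / past[c] for c in codes}
    # 按收益率升序排名,再除以股票数,得到 (0, 1] 区间的分位值
    ordered = sorted(codes, key=lambda c: returns[c])
    return {c: (i + 1) / len(codes) for i, c in enumerate(ordered)}


ranks = return_rank(
    current={"000001.SZ": 11.0, "600519.SH": 12.0, "000002.SZ": 9.0},
    past={"000001.SZ": 10.0, "600519.SH": 10.0, "000002.SZ": 10.0, "601318.SH": 10.0},
)
print(ranks)  # 收益率最高的 600519.SH 排名为 1.0;601318.SH 当期无数据,被 join 剔除
```

inner join 意味着期初停牌或新上市的股票当天不产生该因子值,这和时序因子开头的空窗口是两类不同的缺失来源。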
@@ -1,9 +1,10 @@
"""ProStock 模型训练框架
"""ProStock ML Pipeline 组件库

组件化、低耦合、插件式的机器学习训练框架。
提供组件化、低耦合、插件式的机器学习流水线组件。
包括处理器、模型、划分策略等可复用组件。

示例:
    >>> from src.models import (
    >>> from src.pipeline import (
    ...     PluginRegistry, ProcessingPipeline,
    ...     PipelineStage, BaseProcessor
    ... )
@@ -21,7 +22,7 @@
"""

# 导入核心抽象类和划分策略
from src.models.core import (
from src.pipeline.core import (
    PipelineStage,
    TaskType,
    BaseProcessor,
@@ -34,13 +35,13 @@ from src.models.core import (
)

# 导入注册中心
from src.models.registry import PluginRegistry
from src.pipeline.registry import PluginRegistry

# 导入处理流水线
from src.models.pipeline import ProcessingPipeline
from src.pipeline.pipeline import ProcessingPipeline

# 导入并注册内置处理器
from src.models.processors.processors import (
from src.pipeline.processors.processors import (
    DropNAProcessor,
    FillNAProcessor,
    Winsorizer,
@@ -51,7 +52,7 @@ from src.models.processors.processors import (
)

# 导入并注册内置模型
from src.models.models.models import (
from src.pipeline.models.models import (
    LightGBMModel,
    CatBoostModel,
)
@@ -1,6 +1,6 @@
"""核心模块导出"""

from src.models.core.base import (
from src.pipeline.core.base import (
    PipelineStage,
    TaskType,
    BaseProcessor,
@@ -9,7 +9,7 @@ from src.models.core.base import (
    BaseMetric,
)

from src.models.core.splitter import (
from src.pipeline.core.splitter import (
    TimeSeriesSplit,
    WalkForwardSplit,
    ExpandingWindowSplit,
@@ -6,7 +6,7 @@
from typing import Iterator, List, Tuple
import polars as pl

from src.models.core.base import BaseSplitter
from src.pipeline.core.base import BaseSplitter


class TimeSeriesSplit(BaseSplitter):
@@ -1,6 +1,6 @@
"""模型模块"""

from src.models.models.models import (
from src.pipeline.models.models import (
    LightGBMModel,
    CatBoostModel,
)
@@ -7,8 +7,8 @@ from typing import Optional, Dict, Any
import polars as pl
import numpy as np

from src.models.core import BaseModel, TaskType
from src.models.registry import PluginRegistry
from src.pipeline.core import BaseModel, TaskType
from src.pipeline.registry import PluginRegistry


@PluginRegistry.register_model("lightgbm")
@@ -6,7 +6,7 @@
from typing import List, Dict
import polars as pl

from src.models.core import BaseProcessor, PipelineStage
from src.pipeline.core import BaseProcessor, PipelineStage


class ProcessingPipeline:
@@ -1,6 +1,6 @@
"""处理器模块"""

from src.models.processors.processors import (
from src.pipeline.processors.processors import (
    DropNAProcessor,
    FillNAProcessor,
    Winsorizer,
@@ -7,8 +7,8 @@ from typing import List, Optional, Dict, Any
import polars as pl
import numpy as np

from src.models.core import BaseProcessor, PipelineStage
from src.models.registry import PluginRegistry
from src.pipeline.core import BaseProcessor, PipelineStage
from src.pipeline.registry import PluginRegistry

# 数值类型列表
FLOAT_TYPES = [pl.Float32, pl.Float64, pl.Int8, pl.Int16, pl.Int32, pl.Int64]
@@ -17,7 +17,7 @@ from functools import wraps
from weakref import WeakValueDictionary
import contextlib

from src.models.core import BaseProcessor, BaseModel, BaseSplitter, BaseMetric
from src.pipeline.core import BaseProcessor, BaseModel, BaseSplitter, BaseMetric

T = TypeVar("T")
46 src/training/__init__.py Normal file
@@ -0,0 +1,46 @@
"""ProStock 训练流程模块

本模块提供完整的模型训练流程:
1. 数据处理:Fillna(0) -> Dropna
2. 模型训练:LightGBM分类模型
3. 预测选股:每日top5股票池

使用示例:
    from src.training import run_training

    # 运行完整训练流程
    result = run_training(
        train_start="20180101",
        train_end="20230101",
        test_start="20230101",
        test_end="20240101",
        top_n=5,
        output_path="output/top_stocks.tsv"
    )

因子使用:
    from src.factors import MovingAverageFactor, ReturnRankFactor

    ma5 = MovingAverageFactor(period=5)    # 5日移动平均
    ma10 = MovingAverageFactor(period=10)  # 10日移动平均
    ret5 = ReturnRankFactor(period=5)      # 5日收益率排名
"""

from src.training.pipeline import (
    create_pipeline,
    predict_top_stocks,
    prepare_data,
    run_training,
    save_top_stocks,
    train_model,
)

__all__ = [
    # 管道函数
    "prepare_data",
    "create_pipeline",
    "train_model",
    "predict_top_stocks",
    "save_top_stocks",
    "run_training",
]
27 src/training/main.py Normal file
@@ -0,0 +1,27 @@
"""训练流程入口脚本

运行方式:
    uv run python -m src.training.main

或:
    uv run python src/training/main.py
"""

from src.training.pipeline import run_training


if __name__ == "__main__":
    # 运行完整训练流程
    # 训练集:20190101 - 20250101
    # 测试集:20250101 - 20260101
    result = run_training(
        train_start="20190101",
        train_end="20250101",
        test_start="20250101",
        test_end="20260101",
        top_n=5,
        output_path="output/top_stocks.tsv",
    )

    print("\n[Result] Top stocks selection:")
    print(result)
1216 src/training/output/top_stocks.tsv Normal file
File diff suppressed because it is too large
448 src/training/pipeline.py Normal file
@@ -0,0 +1,448 @@
"""训练管道 - 包含数据处理、模型训练和预测功能

本模块提供:
1. 数据准备:从因子计算结果中准备训练/测试数据
2. 数据处理:Fillna(0) -> Dropna
3. 模型训练:使用LightGBM训练分类模型
4. 预测和选股:输出每日top5股票池
"""

from datetime import datetime
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np
import polars as pl

from src.factors import DataLoader, FactorEngine
from src.factors.data_spec import DataSpec
from src.pipeline import (
    DropNAProcessor,
    FillNAProcessor,
    LightGBMModel,
    PipelineStage,
    ProcessingPipeline,
    TaskType,
)


def prepare_data(
    data_dir: str = "data",
    train_start: str = "20180101",
    train_end: str = "20230101",
    test_start: str = "20230101",
    test_end: str = "20240101",
) -> Tuple[pl.DataFrame, pl.DataFrame]:
    """准备训练和测试数据

    从DuckDB加载原始日线数据,计算所需因子并生成标签。

    Args:
        data_dir: 数据目录
        train_start: 训练集开始日期
        train_end: 训练集结束日期
        test_start: 测试集开始日期
        test_end: 测试集结束日期

    Returns:
        (train_data, test_data): 训练集和测试集的DataFrame
    """
    from src.data.storage import Storage

    storage = Storage()

    # 加载日线数据(需要更多历史数据用于计算因子)
    # 训练集需要更多历史数据(用于计算因子lookback)
    lookback_days = 20  # 足够计算MA10和5日收益率
    start_with_lookback = str(int(train_start) - 10000)  # 往前取一年

    # 查询训练集数据
    # 注意:DuckDB 中 trade_date 是 DATE 类型,需要转换
    start_dt = f"{start_with_lookback[:4]}-{start_with_lookback[4:6]}-{start_with_lookback[6:8]}"
    end_dt = f"{train_end[:4]}-{train_end[4:6]}-{train_end[6:8]}"
    train_query = f"""
        SELECT ts_code, trade_date, close, pre_close
        FROM daily
        WHERE trade_date >= '{start_dt}' AND trade_date <= '{end_dt}'
        ORDER BY ts_code, trade_date
    """
    train_raw = storage._connection.sql(train_query).pl()
    # 转换 trade_date 为字符串格式
    train_raw = train_raw.with_columns(
        pl.col("trade_date").dt.strftime("%Y-%m-%d").alias("trade_date")
    )

    # 查询测试集数据(也需要历史数据计算因子)
    test_start_dt = f"{test_start[:4]}-{test_start[4:6]}-{test_start[6:8]}"
    test_end_dt = f"{test_end[:4]}-{test_end[4:6]}-{test_end[6:8]}"
    test_query = f"""
        SELECT ts_code, trade_date, close, pre_close
        FROM daily
        WHERE trade_date >= '{test_start_dt}' AND trade_date <= '{test_end_dt}'
        ORDER BY ts_code, trade_date
    """
    test_raw = storage._connection.sql(test_query).pl()
    # 转换 trade_date 为字符串格式
    test_raw = test_raw.with_columns(
        pl.col("trade_date").dt.strftime("%Y-%m-%d").alias("trade_date")
    )

    # 过滤不符合条件的股票
    train_raw = _filter_invalid_stocks(train_raw)
    test_raw = _filter_invalid_stocks(test_raw)
    print(f"[PrepareData] After filtering: train={len(train_raw)}, test={len(test_raw)}")

    # 计算因子和标签
    train_data = _compute_features_and_label(train_raw, train_start, train_end)
    test_data = _compute_features_and_label(test_raw, test_start, test_end)

    return train_data, test_data


def _filter_invalid_stocks(df: pl.DataFrame) -> pl.DataFrame:
    """过滤不符合条件的股票

    过滤规则:
    1. 过滤北交所股票(ts_code 以 BJ 结尾)
    2. 过滤创业板股票(ts_code 以 30 开头)
    3. 过滤科创板股票(ts_code 以 68 开头)
    4. 过滤退市/风险股票(ts_code 以 8 开头)

    Args:
        df: 原始数据

    Returns:
        过滤后的数据
    """
    ts_code_col = pl.col("ts_code")

    return df.filter(
        ~ts_code_col.str.ends_with("BJ")
        & ~ts_code_col.str.starts_with("30")
        & ~ts_code_col.str.starts_with("68")
        & ~ts_code_col.str.starts_with("8")
    )
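这四条过滤规则组合成一个"与"条件,作用在 ts_code 列上。用一个作用于单个代码的纯 Python 谓词可以直观验证同样的逻辑(示意性草稿,样例代码仅用于演示):

```python
def keep_stock(ts_code: str) -> bool:
    # 与 _filter_invalid_stocks 相同的四条规则:
    # 北交所(BJ 结尾)、创业板(30 开头)、科创板(68 开头)、8 开头的代码一律剔除
    return not (
        ts_code.endswith("BJ")
        or ts_code.startswith("30")
        or ts_code.startswith("68")
        or ts_code.startswith("8")
    )


codes = ["000001.SZ", "300750.SZ", "688981.SH", "830799.BJ", "600519.SH"]
print([c for c in codes if keep_stock(c)])  # ['000001.SZ', '600519.SH']
```

注意 830799.BJ 同时命中"8 开头"和"BJ 结尾"两条规则,因此规则间存在重叠;保留多条规则是为了让意图显式。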


def _compute_features_and_label(
    raw_data: pl.DataFrame,
    start_date: str,
    end_date: str,
) -> pl.DataFrame:
    """计算因子和标签

    因子:
    1. return_5_rank: 5日收益率截面排名
    2. ma_5: 5日移动平均
    3. ma_10: 10日移动平均

    标签:未来5日收益率大于0为1,否则为0

    Args:
        raw_data: 原始日线数据
        start_date: 开始日期
        end_date: 结束日期

    Returns:
        包含因子和标签的DataFrame
    """
    # 确保按日期排序
    raw_data = raw_data.sort(["ts_code", "trade_date"])

    # 计算收益率(未来5日)
    raw_data = raw_data.with_columns(
        [
            # 当日收益率
            ((pl.col("close") - pl.col("pre_close")) / pl.col("pre_close")).alias(
                "daily_return"
            ),
        ]
    )

    # 按股票分组计算
    result_list = []

    for ts_code in raw_data["ts_code"].unique():
        stock_data = raw_data.filter(pl.col("ts_code") == ts_code).sort("trade_date")

        if len(stock_data) < 20:
            continue

        # 计算MA5和MA10
        stock_data = stock_data.with_columns(
            [
                pl.col("close").rolling_mean(5).alias("ma_5"),
                pl.col("close").rolling_mean(10).alias("ma_10"),
            ]
        )

        # 计算未来5日收益率(用于标签)
        future_return = stock_data["close"].shift(-5) - stock_data["close"]
        future_return_pct = future_return / stock_data["close"]
        stock_data = stock_data.with_columns(
            [
                future_return_pct.alias("future_return_5"),
            ]
        )

        # 生成标签:收益率>0为1,否则为0
        stock_data = stock_data.with_columns(
            [
                (pl.col("future_return_5") > 0).cast(pl.Int8).alias("label"),
            ]
        )

        result_list.append(stock_data)

    if not result_list:
        return pl.DataFrame()

    result = pl.concat(result_list)

    # 转换日期格式:YYYYMMDD -> YYYY-MM-DD
    start_date_formatted = f"{start_date[:4]}-{start_date[4:6]}-{start_date[6:8]}"
    end_date_formatted = f"{end_date[:4]}-{end_date[4:6]}-{end_date[6:8]}"

    # 过滤有效日期范围
    result = result.filter(
        (pl.col("trade_date") >= start_date_formatted)
        & (pl.col("trade_date") <= end_date_formatted)
    )

    # 计算5日收益率排名(截面)
    result = result.with_columns(
        [
            pl.col("daily_return")
            .rank(method="average")
            .over("trade_date")
            .alias("return_5_rank")
        ]
    )

    # 归一化排名到0-1
    result = result.with_columns(
        [
            (
                pl.col("return_5_rank")
                / pl.col("return_5_rank").max().over("trade_date")
            ).alias("return_5_rank")
        ]
    )

    # 选择需要的列
    feature_cols = ["trade_date", "ts_code", "return_5_rank", "ma_5", "ma_10", "label"]
    result = result.select(feature_cols)

    return result
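标签生成依赖 shift(-5):把 5 天后的收盘价对齐到当前行,再比较大小;序列末尾不足 5 天的行自然得到空标签。用纯 Python 可以把这一步单独示意出来(草稿,与上面 polars 版本逻辑等价):

```python
def make_labels(closes, horizon=5):
    # 未来 horizon 日收盘价高于当前则标 1,否则标 0;
    # 末尾不足 horizon 日的行无法计算未来收益,标 None
    labels = []
    for i, c in enumerate(closes):
        if i + horizon < len(closes):
            labels.append(1 if closes[i + horizon] > c else 0)
        else:
            labels.append(None)
    return labels


print(make_labels([10, 11, 12, 13, 14, 15, 9], horizon=5))
# [1, 0, None, None, None, None, None]
```

这些末尾的空标签正是后面 train_model 中 `y_train.is_in([0, 1])` 过滤掉的样本:特征行保留,但不参与训练。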


def create_pipeline() -> ProcessingPipeline:
    """创建数据处理流水线

    处理流程:
    1. FillNA(0): 将缺失值填充为0

    注意:不使用 Dropna,因为会导致训练和预测时的行数不匹配

    Returns:
        配置好的ProcessingPipeline
    """
    processors = [
        FillNAProcessor(method="zero"),  # 缺失值填充为0
    ]
    return ProcessingPipeline(processors)


def train_model(
    train_data: pl.DataFrame,
    feature_cols: List[str],
    label_col: str = "label",
    model_params: Optional[dict] = None,
) -> Tuple[LightGBMModel, ProcessingPipeline]:
    """训练LightGBM分类模型

    Args:
        train_data: 训练数据
        feature_cols: 特征列名列表
        label_col: 标签列名
        model_params: 模型参数字典

    Returns:
        (训练好的模型, 处理流水线)
    """
    # 创建处理流水线
    pipeline = create_pipeline()
    print("[TrainModel] Pipeline created: FillNA(0)")

    # 准备特征和标签
    X_train = train_data.select(feature_cols)
    y_train = train_data[label_col]
    print(f"[TrainModel] Raw samples: {len(X_train)}, features: {feature_cols}")

    # 处理数据
    X_train_processed = pipeline.fit_transform(X_train, stage=PipelineStage.TRAIN)
    print(f"[TrainModel] After processing: {len(X_train_processed)} samples")

    # 过滤有效标签(排除-1等无效值)
    valid_mask = y_train.is_in([0, 1])
    X_train_processed = X_train_processed.filter(valid_mask)
    y_train = y_train.filter(valid_mask)
    print(f"[TrainModel] After filtering valid labels: {len(X_train_processed)} samples")
    print(f"[TrainModel] Label distribution: {dict(y_train.value_counts().sort('label').iter_rows())}")

    # 创建模型
    params = model_params or {
        "n_estimators": 100,
        "learning_rate": 0.05,
        "max_depth": 5,
        "num_leaves": 31,
    }
    print(f"[TrainModel] Model params: {params}")
    model = LightGBMModel(
        task_type="classification",
        params=params,
    )

    # 训练模型
    print("[TrainModel] Training LightGBM...")
    model.fit(X_train_processed, y_train)
    print("[TrainModel] Training completed!")

    return model, pipeline


def predict_top_stocks(
    model: LightGBMModel,
    pipeline: ProcessingPipeline,
    test_data: pl.DataFrame,
    feature_cols: List[str],
    top_n: int = 5,
) -> pl.DataFrame:
    """预测并选出每日top N股票

    Args:
        model: 训练好的模型
        pipeline: 数据处理流水线
        test_data: 测试数据
        feature_cols: 特征列名
        top_n: 每日选出的股票数量

    Returns:
        包含日期和股票代码的DataFrame
    """
    # 准备特征和必要列
    X_test = test_data.select(feature_cols)
    key_cols = ["trade_date", "ts_code"]
    key_data = test_data.select(key_cols)

    print(f"[Predict] Test samples: {len(X_test)}, top_n: {top_n}")

    # 处理数据(使用训练阶段的参数)
    X_test_processed = pipeline.transform(X_test, stage=PipelineStage.TEST)
    print(f"[Predict] Data processed, shape: {X_test_processed.shape}")

    # 预测概率
    probs = model.predict_proba(X_test_processed)
    print(f"[Predict] Predictions generated, probability shape: {probs.shape}")

    # 使用 key_data 添加预测结果,保持行数一致
    result = key_data.with_columns(
        pl.Series(
            name="pred_prob",
            values=probs[:, 1] if len(probs.shape) > 1 and probs.shape[1] > 1 else probs.flatten(),
        ),
    )

    # 每日选出top N
    top_stocks = []
    for date in result["trade_date"].unique().sort():
        day_data = result.filter(pl.col("trade_date") == date)

        # 按概率降序排序,选出top N
        day_top = day_data.sort("pred_prob", descending=True).head(top_n)

        top_stocks.append(
            day_top.select(["trade_date", "pred_prob", "ts_code"]).rename({"pred_prob": "score"})
        )

    return pl.concat(top_stocks)
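每日 top N 的选取就是"按日分组 → 组内按分数降序 → 取前 n 条"。用纯 Python 元组列表可以最小化地示意这一步(草稿:样例数据与分数均为虚构,仅演示分组取头部的逻辑):

```python
def daily_top_n(rows, n=2):
    # rows: (trade_date, ts_code, score) 三元组;按日分组后取分数最高的 n 条
    by_date = {}
    for date, code, score in rows:
        by_date.setdefault(date, []).append((score, code))
    out = []
    for date in sorted(by_date):
        # 组内按 (score, code) 降序;取前 n 条后还原为 (date, code, score)
        for score, code in sorted(by_date[date], reverse=True)[:n]:
            out.append((date, code, score))
    return out


rows = [
    ("2024-01-02", "000001.SZ", 0.61),
    ("2024-01-02", "600519.SH", 0.83),
    ("2024-01-02", "000002.SZ", 0.47),
    ("2024-01-03", "600519.SH", 0.52),
    ("2024-01-03", "000001.SZ", 0.74),
]
print(daily_top_n(rows, n=2))
```

与 polars 版本一样,某天候选不足 n 只时整天仍然保留,只是输出少于 n 行。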
|
||||
|
||||
|
||||
def save_top_stocks(top_stocks: pl.DataFrame, output_path: str) -> None:
|
||||
"""保存选股结果到TSV文件
    Args:
        top_stocks: 选股结果
        output_path: 输出文件路径
    """
    # 转换为pandas并保存为TSV
    df = top_stocks.to_pandas()
    df.to_csv(output_path, sep="\t", index=False)
    print(f"[Training] Top stocks saved to: {output_path}")


def run_training(
    data_dir: str = "data",
    output_path: str = "output/top_stocks.tsv",
    train_start: str = "20180101",
    train_end: str = "20230101",
    test_start: str = "20230101",
    test_end: str = "20240101",
    top_n: int = 5,
) -> pl.DataFrame:
    """运行完整训练流程

    Args:
        data_dir: 数据目录
        output_path: 输出文件路径
        train_start: 训练集开始日期
        train_end: 训练集结束日期
        test_start: 测试集开始日期
        test_end: 测试集结束日期
        top_n: 每日选股数量

    Returns:
        选股结果DataFrame
    """
    print("[Training] Starting training pipeline...")
    print(f"[Training] Train period: {train_start} -> {train_end}")
    print(f"[Training] Test period: {test_start} -> {test_end}")

    # 1. 准备数据
    print("[Training] Preparing data...")
    train_data, test_data = prepare_data(
        data_dir=data_dir,
        train_start=train_start,
        train_end=train_end,
        test_start=test_start,
        test_end=test_end,
    )
    print(f"[Training] Train samples: {len(train_data)}")
    print(f"[Training] Test samples: {len(test_data)}")

    # 2. 定义特征列
    feature_cols = ["return_5_rank", "ma_5", "ma_10"]
    label_col = "label"

    # 3. 训练模型
    print("[Training] Training model...")
    model, pipeline = train_model(
        train_data=train_data,
        feature_cols=feature_cols,
        label_col=label_col,
    )

    # 4. 测试集预测
    print("[Training] Predicting on test set...")
    top_stocks = predict_top_stocks(
        model=model,
        pipeline=pipeline,
        test_data=test_data,
        feature_cols=feature_cols,
        top_n=top_n,
    )

    # 5. 保存结果
    print(f"[Training] Saving results to {output_path}...")
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    save_top_stocks(top_stocks, output_path)

    print("[Training] Training completed!")

    return top_stocks
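上面 run_training 的最后一步通过 predict_top_stocks 按 top_n 每日选股。下面是该步骤核心思路的一个最小示意(假设其本质是按 trade_date 分组、取预测分数最高的前 N 只;其中 score 列名与示例数据均为假设,与项目实际字段未必一致):

```python
import pandas as pd

# 假设的预测结果:每个交易日若干股票各有一个预测分数
preds = pd.DataFrame({
    "trade_date": ["20230104"] * 3 + ["20230105"] * 3,
    "ts_code": ["000001.SZ", "000002.SZ", "600000.SH"] * 2,
    "score": [0.9, 0.5, 0.7, 0.2, 0.8, 0.6],
})

# 按日期升序、分数降序排序后,每个交易日取前 2 名
top2 = (
    preds.sort_values(["trade_date", "score"], ascending=[True, False])
    .groupby("trade_date", sort=False)
    .head(2)
    .reset_index(drop=True)
)
print(top2)
```

结果与上文 save_top_stocks 的输入形状一致,可直接以 TSV 落盘。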
@@ -1,4 +1,4 @@
"""模型框架核心测试
"""Pipeline 组件库核心测试

测试核心抽象类、插件注册中心、处理器、模型和划分策略。
"""
@@ -9,7 +9,7 @@ import numpy as np
from typing import List, Optional

# 确保导入时注册所有组件
from src.models import (
from src.pipeline import (
    PluginRegistry,
    PipelineStage,
    BaseProcessor,
@@ -17,7 +17,7 @@ from src.models import (
    BaseSplitter,
    ProcessingPipeline,
)
from src.models.core import TaskType
from src.pipeline.core import TaskType


# ========== 测试核心抽象类 ==========
@@ -232,7 +232,7 @@ class TestBuiltInProcessors:

    def test_dropna_processor(self):
        """测试缺失值删除处理器"""
        from src.models.processors import DropNAProcessor
        from src.pipeline.processors import DropNAProcessor

        processor = DropNAProcessor(columns=["a", "b"])
        df = pl.DataFrame({"a": [1, None, 3], "b": [4, 5, None], "c": [7, 8, 9]})
@@ -246,7 +246,7 @@ class TestBuiltInProcessors:

    def test_fillna_processor(self):
        """测试缺失值填充处理器"""
        from src.models.processors import FillNAProcessor
        from src.pipeline.processors import FillNAProcessor

        processor = FillNAProcessor(columns=["a"], method="mean")
        df = pl.DataFrame({"a": [1.0, 2.0, None, 4.0]})
@@ -258,7 +258,7 @@ class TestBuiltInProcessors:

    def test_standard_scaler(self):
        """测试标准化处理器"""
        from src.models.processors import StandardScaler
        from src.pipeline.processors import StandardScaler

        processor = StandardScaler(columns=["value"])
        df = pl.DataFrame({"value": [1.0, 2.0, 3.0, 4.0, 5.0]})
@@ -271,7 +271,7 @@ class TestBuiltInProcessors:

    def test_winsorizer(self):
        """测试缩尾处理器"""
        from src.models.processors import Winsorizer
        from src.pipeline.processors import Winsorizer

        processor = Winsorizer(columns=["value"], lower=0.1, upper=0.9)
        df = pl.DataFrame(
@@ -288,7 +288,7 @@ class TestBuiltInProcessors:

    def test_rank_transformer(self):
        """测试排名转换处理器"""
        from src.models.processors import RankTransformer
        from src.pipeline.processors import RankTransformer

        processor = RankTransformer(columns=["value"])
        df = pl.DataFrame(
@@ -302,7 +302,7 @@ class TestBuiltInProcessors:

    def test_neutralizer(self):
        """测试中性化处理器"""
        from src.models.processors import Neutralizer
        from src.pipeline.processors import Neutralizer

        processor = Neutralizer(columns=["value"], group_col="industry")
        df = pl.DataFrame(
@@ -331,7 +331,7 @@ class TestProcessingPipeline:

    def test_pipeline_fit_transform(self):
        """测试流水线的 fit_transform"""
        from src.models.processors import StandardScaler
        from src.pipeline.processors import StandardScaler

        scaler1 = StandardScaler(columns=["a"])
        scaler2 = StandardScaler(columns=["b"])
@@ -348,7 +348,7 @@ class TestProcessingPipeline:

    def test_pipeline_transform_uses_fitted_params(self):
        """测试 transform 使用已 fit 的参数"""
        from src.models.processors import StandardScaler
        from src.pipeline.processors import StandardScaler

        scaler = StandardScaler(columns=["value"])
        pipeline = ProcessingPipeline([scaler])
@@ -383,7 +383,7 @@ class TestSplitters:

    def test_time_series_split(self):
        """测试时间序列划分"""
        from src.models.core import TimeSeriesSplit
        from src.pipeline.core import TimeSeriesSplit

        splitter = TimeSeriesSplit(n_splits=2, gap=1, min_train_size=3)

@@ -406,7 +406,7 @@ class TestSplitters:

    def test_walk_forward_split(self):
        """测试滚动前向划分"""
        from src.models.core import WalkForwardSplit
        from src.pipeline.core import WalkForwardSplit

        splitter = WalkForwardSplit(train_window=5, test_window=2, gap=1)

@@ -426,7 +426,7 @@ class TestSplitters:

    def test_expanding_window_split(self):
        """测试扩展窗口划分"""
        from src.models.core import ExpandingWindowSplit
        from src.pipeline.core import ExpandingWindowSplit

        splitter = ExpandingWindowSplit(initial_train_size=3, test_window=2, gap=1)

@@ -455,7 +455,7 @@ class TestModels:
    @pytest.mark.skip(reason="需要安装 lightgbm")
    def test_lightgbm_model(self):
        """测试 LightGBM 模型"""
        from src.models.models import LightGBMModel
        from src.pipeline.models import LightGBMModel

        model = LightGBMModel(task_type="regression", params={"n_estimators": 10})
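上面 test_pipeline_transform_uses_fitted_params 验证的是处理器的关键语义:fit 阶段记住训练期的统计参数(均值、标准差),transform 阶段复用这些参数而不重新计算。下面用纯 Python 做一个最小示意(TinyScaler 为假设的演示类,与 src.pipeline 的实际实现无关):

```python
import statistics


class TinyScaler:
    """标准化的最小示意:fit 记住均值和标准差,transform 复用训练期参数"""

    def fit(self, xs):
        self.mean = statistics.fmean(xs)
        self.std = statistics.pstdev(xs)
        return self

    def transform(self, xs):
        return [(x - self.mean) / self.std for x in xs]


scaler = TinyScaler().fit([1.0, 2.0, 3.0, 4.0, 5.0])
train_z = scaler.transform([1.0, 2.0, 3.0, 4.0, 5.0])  # 训练集标准化后均值为 0
test_z = scaler.transform([6.0])  # 测试集复用训练期的 mean/std,避免信息泄露
```

这正是时间序列场景下必须区分 fit/transform 的原因:测试期数据不能参与统计量的计算。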
@@ -1,277 +1,163 @@
"""Tests for data synchronization module.
"""Sync 接口测试规范与实现。

Tests the sync module's full/incremental sync logic for daily data:
- Full sync when local data doesn't exist (from 20180101)
- Incremental sync when local data exists (from last_date + 1)
- Data integrity validation
【测试规范】
1. 所有 sync 测试只使用 2018-01-01 到 2018-04-01 的数据
2. 只测试接口是否能正常返回数据,不测试落库逻辑
3. 对于按股票查询的接口,只测试 000001.SZ、000002.SZ 两支股票
4. 使用真实 API 调用,确保接口可用性

【测试范围】
- get_daily: 日线数据接口(按股票)
- sync_all_stocks: 股票基础信息接口
- sync_trade_cal_cache: 交易日历接口
- sync_namechange: 名称变更接口
- sync_bak_basic: 备用股票基础信息接口
"""

import pytest
import pandas as pd
from unittest.mock import Mock, patch, MagicMock
from datetime import datetime, timedelta
from datetime import datetime

from src.data.sync import (
    DataSync,
    sync_all,
    get_today_date,
    get_next_date,
    DEFAULT_START_DATE,
)
from src.data.storage import ThreadSafeStorage
from src.data.client import TushareClient
# 测试用常量
TEST_START_DATE = "20180101"
TEST_END_DATE = "20180401"
TEST_STOCK_CODES = ["000001.SZ", "000002.SZ"]


@pytest.fixture
def mock_storage():
    """Create a mock storage instance."""
    storage = Mock(spec=ThreadSafeStorage)
    storage.exists = Mock(return_value=False)
    storage.load = Mock(return_value=pd.DataFrame())
    storage.save = Mock(return_value={"status": "success", "rows": 0})
    return storage
class TestGetDaily:
    """测试日线数据 get 接口(按股票查询)."""

    def test_get_daily_single_stock(self):
        """测试 get_daily 获取单只股票数据."""
        from src.data.api_wrappers.api_daily import get_daily


@pytest.fixture
def mock_client():
    """Create a mock client instance."""
    return Mock(spec=TushareClient)


class TestDateUtilities:
    """Test date utility functions."""

    def test_get_today_date_format(self):
        """Test today date is in YYYYMMDD format."""
        result = get_today_date()
        assert len(result) == 8
        assert result.isdigit()

    def test_get_next_date(self):
        """Test getting next date."""
        result = get_next_date("20240101")
        assert result == "20240102"

    def test_get_next_date_year_end(self):
        """Test getting next date across year boundary."""
        result = get_next_date("20241231")
        assert result == "20250101"

    def test_get_next_date_month_end(self):
        """Test getting next date across month boundary."""
        result = get_next_date("20240131")
        assert result == "20240201"


class TestDataSync:
    """Test DataSync class functionality."""

    def test_get_all_stock_codes_from_daily(self, mock_storage):
        """Test getting stock codes from daily data."""
        with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
            sync = DataSync()
            sync.storage = mock_storage

            mock_storage.load.return_value = pd.DataFrame(
                {
                    "ts_code": ["000001.SZ", "000001.SZ", "600000.SH"],
                }
        result = get_daily(
            ts_code=TEST_STOCK_CODES[0],
            start_date=TEST_START_DATE,
            end_date=TEST_END_DATE,
        )

            codes = sync.get_all_stock_codes()
        # 验证返回了数据
        assert isinstance(result, pd.DataFrame), "get_daily 应返回 DataFrame"
        assert not result.empty, "get_daily 应返回非空数据"

            assert len(codes) == 2
            assert "000001.SZ" in codes
            assert "600000.SH" in codes
    def test_get_daily_has_required_columns(self):
        """测试 get_daily 返回的数据包含必要字段."""
        from src.data.api_wrappers.api_daily import get_daily

    def test_get_all_stock_codes_fallback(self, mock_storage):
        """Test fallback to stock_basic when daily is empty."""
        with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
            sync = DataSync()
            sync.storage = mock_storage

            # First call (daily) returns empty, second call (stock_basic) returns data
            mock_storage.load.side_effect = [
                pd.DataFrame(),  # daily empty
                pd.DataFrame({"ts_code": ["000001.SZ", "600000.SH"]}),  # stock_basic
            ]

            codes = sync.get_all_stock_codes()

            assert len(codes) == 2

    def test_get_global_last_date(self, mock_storage):
        """Test getting global last date."""
        with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
            sync = DataSync()
            sync.storage = mock_storage

            mock_storage.load.return_value = pd.DataFrame(
                {
                    "ts_code": ["000001.SZ", "600000.SH"],
                    "trade_date": ["20240102", "20240103"],
                }
        result = get_daily(
            ts_code=TEST_STOCK_CODES[0],
            start_date=TEST_START_DATE,
            end_date=TEST_END_DATE,
        )

            last_date = sync.get_global_last_date()
            assert last_date == "20240103"
        # 验证必要的列存在
        required_columns = ["ts_code", "trade_date", "open", "high", "low", "close"]
        for col in required_columns:
            assert col in result.columns, f"get_daily 返回应包含 {col} 列"

    def test_get_global_last_date_empty(self, mock_storage):
        """Test getting last date from empty storage."""
        with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
            sync = DataSync()
            sync.storage = mock_storage
    def test_get_daily_multiple_stocks(self):
        """测试 get_daily 获取多只股票数据."""
        from src.data.api_wrappers.api_daily import get_daily

            mock_storage.load.return_value = pd.DataFrame()
        results = {}
        for code in TEST_STOCK_CODES:
            result = get_daily(
                ts_code=code,
                start_date=TEST_START_DATE,
                end_date=TEST_END_DATE,
            )
            results[code] = result
            assert isinstance(result, pd.DataFrame), (
                f"get_daily({code}) 应返回 DataFrame"
            )
            assert not result.empty, f"get_daily({code}) 应返回非空数据"

            last_date = sync.get_global_last_date()
            assert last_date is None

    def test_sync_single_stock(self, mock_storage):
        """Test syncing a single stock."""
        with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
            with patch(
                "src.data.sync.get_daily",
                return_value=pd.DataFrame(
                    {
                        "ts_code": ["000001.SZ"],
                        "trade_date": ["20240102"],
                    }
                ),
            ):
                sync = DataSync()
                sync.storage = mock_storage
class TestSyncStockBasic:
    """测试股票基础信息 sync 接口."""

                result = sync.sync_single_stock(
                    ts_code="000001.SZ",
                    start_date="20240101",
                    end_date="20240102",
    def test_sync_all_stocks_returns_data(self):
        """测试 sync_all_stocks 是否能正常返回数据."""
        from src.data.api_wrappers.api_stock_basic import sync_all_stocks

        result = sync_all_stocks()

        # 验证返回了数据
        assert isinstance(result, pd.DataFrame), "sync_all_stocks 应返回 DataFrame"
        assert not result.empty, "sync_all_stocks 应返回非空数据"

    def test_sync_all_stocks_has_required_columns(self):
        """测试 sync_all_stocks 返回的数据包含必要字段."""
        from src.data.api_wrappers.api_stock_basic import sync_all_stocks

        result = sync_all_stocks()

        # 验证必要的列存在
        required_columns = ["ts_code"]
        for col in required_columns:
            assert col in result.columns, f"sync_all_stocks 返回应包含 {col} 列"


class TestSyncTradeCal:
    """测试交易日历 sync 接口."""

    def test_sync_trade_cal_cache_returns_data(self):
        """测试 sync_trade_cal_cache 是否能正常返回数据."""
        from src.data.api_wrappers.api_trade_cal import sync_trade_cal_cache

        result = sync_trade_cal_cache(
            start_date=TEST_START_DATE,
            end_date=TEST_END_DATE,
        )

            assert isinstance(result, pd.DataFrame)
            assert len(result) == 1
        # 验证返回了数据
        assert isinstance(result, pd.DataFrame), "sync_trade_cal_cache 应返回 DataFrame"
        assert not result.empty, "sync_trade_cal_cache 应返回非空数据"

    def test_sync_single_stock_empty(self, mock_storage):
        """Test syncing a stock with no data."""
        with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
            with patch("src.data.sync.get_daily", return_value=pd.DataFrame()):
                sync = DataSync()
                sync.storage = mock_storage
    def test_sync_trade_cal_cache_has_required_columns(self):
        """测试 sync_trade_cal_cache 返回的数据包含必要字段."""
        from src.data.api_wrappers.api_trade_cal import sync_trade_cal_cache

                result = sync.sync_single_stock(
                    ts_code="INVALID.SZ",
                    start_date="20240101",
                    end_date="20240102",
        result = sync_trade_cal_cache(
            start_date=TEST_START_DATE,
            end_date=TEST_END_DATE,
        )

                assert result.empty
        # 验证必要的列存在
        required_columns = ["cal_date", "is_open"]
        for col in required_columns:
            assert col in result.columns, f"sync_trade_cal_cache 返回应包含 {col} 列"


class TestSyncAll:
    """Test sync_all function."""
class TestSyncNamechange:
    """测试名称变更 sync 接口."""

    def test_full_sync_mode(self, mock_storage):
        """Test full sync mode when force_full=True."""
        with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
            with patch("src.data.sync.get_daily", return_value=pd.DataFrame()):
                sync = DataSync()
                sync.storage = mock_storage
                sync.sync_single_stock = Mock(return_value=pd.DataFrame())
    def test_sync_namechange_returns_data(self):
        """测试 sync_namechange 是否能正常返回数据."""
        from src.data.api_wrappers.api_namechange import sync_namechange

                mock_storage.load.return_value = pd.DataFrame(
                    {
                        "ts_code": ["000001.SZ"],
                    }
        result = sync_namechange()

        # 验证返回了数据(可能是空 DataFrame,因为是历史变更)
        assert isinstance(result, pd.DataFrame), "sync_namechange 应返回 DataFrame"


class TestSyncBakBasic:
    """测试备用股票基础信息 sync 接口."""

    def test_sync_bak_basic_returns_data(self):
        """测试 sync_bak_basic 是否能正常返回数据."""
        from src.data.api_wrappers.api_bak_basic import sync_bak_basic

        result = sync_bak_basic(
            start_date=TEST_START_DATE,
            end_date=TEST_END_DATE,
        )

                result = sync.sync_all(force_full=True)

                # Verify sync_single_stock was called with default start date
                sync.sync_single_stock.assert_called_once()
                call_args = sync.sync_single_stock.call_args
                assert call_args[1]["start_date"] == DEFAULT_START_DATE

    def test_incremental_sync_mode(self, mock_storage):
        """Test incremental sync mode when data exists."""
        with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
            sync = DataSync()
            sync.storage = mock_storage
            sync.sync_single_stock = Mock(return_value=pd.DataFrame())

            # Mock existing data with last date
            mock_storage.load.side_effect = [
                pd.DataFrame(
                    {
                        "ts_code": ["000001.SZ"],
                        "trade_date": ["20240102"],
                    }
                ),  # get_all_stock_codes
                pd.DataFrame(
                    {
                        "ts_code": ["000001.SZ"],
                        "trade_date": ["20240102"],
                    }
                ),  # get_global_last_date
            ]

            result = sync.sync_all(force_full=False)

            # Verify sync_single_stock was called with next date
            sync.sync_single_stock.assert_called_once()
            call_args = sync.sync_single_stock.call_args
            assert call_args[1]["start_date"] == "20240103"

    def test_manual_start_date(self, mock_storage):
        """Test sync with manual start date."""
        with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
            sync = DataSync()
            sync.storage = mock_storage
            sync.sync_single_stock = Mock(return_value=pd.DataFrame())

            mock_storage.load.return_value = pd.DataFrame(
                {
                    "ts_code": ["000001.SZ"],
                }
            )

            result = sync.sync_all(force_full=False, start_date="20230601")

            sync.sync_single_stock.assert_called_once()
            call_args = sync.sync_single_stock.call_args
            assert call_args[1]["start_date"] == "20230601"

    def test_no_stocks_found(self, mock_storage):
        """Test sync when no stocks are found."""
        with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
            sync = DataSync()
            sync.storage = mock_storage

            mock_storage.load.return_value = pd.DataFrame()

            result = sync.sync_all()

            assert result == {}


class TestSyncAllConvenienceFunction:
    """Test sync_all convenience function."""

    def test_sync_all_function(self):
        """Test sync_all convenience function."""
        with patch("src.data.sync.DataSync") as MockSync:
            mock_instance = Mock()
            mock_instance.sync_all.return_value = {}
            MockSync.return_value = mock_instance

            result = sync_all(force_full=True)

            MockSync.assert_called_once()
            mock_instance.sync_all.assert_called_once_with(
                force_full=True,
                start_date=None,
                end_date=None,
                dry_run=False,
            )
        # 验证返回了数据
        assert isinstance(result, pd.DataFrame), "sync_bak_basic 应返回 DataFrame"
        # 注意:bak_basic 可能返回空数据,这是正常的


if __name__ == "__main__":
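上面的 TestDateUtilities 覆盖了 get_next_date 的普通日期、月末和年末三种边界。其一种典型实现思路是把 YYYYMMDD 字符串解析为日期后加一天(下面的 next_date 只是假设的示意实现,实际代码在项目的 utils 模块中):

```python
from datetime import datetime, timedelta


def next_date(yyyymmdd: str) -> str:
    """YYYYMMDD 字符串的次日(示意实现):解析为日期、加一天、再格式化"""
    day = datetime.strptime(yyyymmdd, "%Y%m%d")
    return (day + timedelta(days=1)).strftime("%Y%m%d")


print(next_date("20241231"))  # → 20250101
```

借助 datetime 做加法可以自动处理月末、年末乃至闰年二月的进位,无需手写日历逻辑。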
@@ -1,256 +0,0 @@
"""Tests for data sync with REAL data (read-only).

Tests verify:
1. get_global_last_date() correctly reads local data's max date
2. Incremental sync date calculation (local_last_date + 1)
3. Full sync date calculation (20180101)
4. Multi-stock scenario with real data

⚠️ IMPORTANT: These tests ONLY read data, no write operations.
- NO sync_all() calls (writes daily.h5)
- NO check_sync_needed() calls (writes trade_cal.h5)
"""

import pytest
import pandas as pd
from pathlib import Path

from src.data.sync import (
    DataSync,
    get_next_date,
    DEFAULT_START_DATE,
)
from src.data.storage import Storage


class TestDataSyncReadOnly:
    """Read-only tests for data sync - verify date calculation logic."""

    @pytest.fixture
    def storage(self):
        """Create storage instance."""
        return Storage()

    @pytest.fixture
    def data_sync(self):
        """Create DataSync instance."""
        return DataSync()

    @pytest.fixture
    def daily_exists(self, storage):
        """Check if daily.h5 exists."""
        return storage.exists("daily")

    def test_daily_h5_exists(self, storage):
        """Verify daily.h5 data file exists before running tests."""
        assert storage.exists("daily"), (
            "daily.h5 not found. Please run full sync first: "
            "uv run python -c 'from src.data.sync import sync_all; sync_all(force_full=True)'"
        )

    def test_get_global_last_date(self, data_sync, daily_exists):
        """Test get_global_last_date returns correct max date from local data."""
        if not daily_exists:
            pytest.skip("daily.h5 not found")

        last_date = data_sync.get_global_last_date()

        # Verify it's a valid date string
        assert last_date is not None, "get_global_last_date returned None"
        assert isinstance(last_date, str), f"Expected str, got {type(last_date)}"
        assert len(last_date) == 8, f"Expected 8-digit date, got {last_date}"
        assert last_date.isdigit(), f"Expected numeric date, got {last_date}"

        # Verify by reading storage directly
        daily_data = data_sync.storage.load("daily")
        expected_max = str(daily_data["trade_date"].max())

        assert last_date == expected_max, (
            f"get_global_last_date returned {last_date}, "
            f"but actual max date is {expected_max}"
        )

        print(f"[TEST] Local data last date: {last_date}")

    def test_incremental_sync_date_calculation(self, data_sync, daily_exists):
        """Test incremental sync: start_date = local_last_date + 1.

        This verifies that when local data exists, incremental sync should
        fetch data from (local_last_date + 1), not from 20180101.
        """
        if not daily_exists:
            pytest.skip("daily.h5 not found")

        # Get local last date
        local_last_date = data_sync.get_global_last_date()
        assert local_last_date is not None, "No local data found"

        # Calculate expected incremental start date
        expected_start_date = get_next_date(local_last_date)

        # Verify the calculation is correct
        local_last_int = int(local_last_date)
        expected_int = local_last_int + 1
        actual_int = int(expected_start_date)

        assert actual_int == expected_int, (
            f"Incremental start date calculation error: "
            f"expected {expected_int}, got {actual_int}"
        )

        print(
            f"[TEST] Incremental sync: local_last={local_last_date}, "
            f"start_date should be {expected_start_date}"
        )

        # Verify this is NOT 20180101 (would be full sync)
        assert expected_start_date != DEFAULT_START_DATE, (
            f"Incremental sync should NOT start from {DEFAULT_START_DATE}"
        )

    def test_full_sync_date_calculation(self):
        """Test full sync: start_date = 20180101 when force_full=True.

        This verifies that force_full=True always starts from 20180101.
        """
        # Full sync should always use DEFAULT_START_DATE
        full_sync_start = DEFAULT_START_DATE

        assert full_sync_start == "20180101", (
            f"Full sync should start from 20180101, got {full_sync_start}"
        )

        print(f"[TEST] Full sync start date: {full_sync_start}")

    def test_date_comparison_logic(self, data_sync, daily_exists):
        """Test date comparison: incremental vs full sync selection logic.

        Verify that:
        - If local_last_date < today: incremental sync needed
        - If local_last_date >= today: no sync needed
        """
        if not daily_exists:
            pytest.skip("daily.h5 not found")

        from datetime import datetime

        local_last_date = data_sync.get_global_last_date()
        today = datetime.now().strftime("%Y%m%d")

        local_last_int = int(local_last_date)
        today_int = int(today)

        # Log the comparison
        print(
            f"[TEST] Date comparison: local_last={local_last_date} ({local_last_int}), "
            f"today={today} ({today_int})"
        )

        # This test just verifies the comparison logic works
        if local_last_int < today_int:
            print("[TEST] Local data is older than today - sync needed")
            # Incremental sync should fetch from local_last_date + 1
            sync_start = get_next_date(local_last_date)
            assert int(sync_start) > local_last_int, (
                "Sync start should be after local last"
            )
        else:
            print("[TEST] Local data is up-to-date - no sync needed")

    def test_get_all_stock_codes_real_data(self, data_sync, daily_exists):
        """Test get_all_stock_codes returns multiple real stock codes."""
        if not daily_exists:
            pytest.skip("daily.h5 not found")

        codes = data_sync.get_all_stock_codes()

        # Verify it's a list
        assert isinstance(codes, list), f"Expected list, got {type(codes)}"
        assert len(codes) > 0, "No stock codes found"

        # Verify multiple stocks
        assert len(codes) >= 10, (
            f"Expected at least 10 stocks for multi-stock test, got {len(codes)}"
        )

        # Verify format (should be like 000001.SZ, 600000.SH)
        sample_codes = codes[:5]
        for code in sample_codes:
            assert "." in code, f"Invalid stock code format: {code}"
            suffix = code.split(".")[-1]
            assert suffix in ["SZ", "SH"], f"Invalid exchange suffix: {suffix}"

        print(f"[TEST] Found {len(codes)} stock codes (sample: {sample_codes})")

    def test_multi_stock_date_range(self, data_sync, daily_exists):
        """Test that multiple stocks share the same date range in local data.

        This verifies that local data has consistent date coverage across stocks.
        """
        if not daily_exists:
            pytest.skip("daily.h5 not found")

        daily_data = data_sync.storage.load("daily")

        # Get date range for each stock
        stock_dates = daily_data.groupby("ts_code")["trade_date"].agg(["min", "max"])

        # Get global min and max
        global_min = str(daily_data["trade_date"].min())
        global_max = str(daily_data["trade_date"].max())

        print(f"[TEST] Global date range: {global_min} to {global_max}")
        print(f"[TEST] Total stocks: {len(stock_dates)}")

        # Verify we have data for multiple stocks
        assert len(stock_dates) >= 10, (
            f"Expected at least 10 stocks, got {len(stock_dates)}"
        )

        # Verify date range is reasonable (at least 1 year of data)
        global_min_int = int(global_min)
        global_max_int = int(global_max)
        days_span = global_max_int - global_min_int

        assert days_span > 100, (
            f"Date range too small: {days_span} days. "
            f"Expected at least 100 days of data."
        )

        print(f"[TEST] Date span: {days_span} days")


class TestDateUtilities:
    """Test date utility functions."""

    def test_get_next_date(self):
        """Test get_next_date correctly calculates next day."""
        # Test normal cases
        assert get_next_date("20240101") == "20240102"
        assert get_next_date("20240131") == "20240201"  # Month boundary
        assert get_next_date("20241231") == "20250101"  # Year boundary

    def test_incremental_vs_full_sync_logic(self):
        """Test the logic difference between incremental and full sync.

        Incremental: start_date = local_last_date + 1
        Full: start_date = 20180101
        """
        # Scenario 1: Local data exists
        local_last_date = "20240115"
        incremental_start = get_next_date(local_last_date)

        assert incremental_start == "20240116"
        assert incremental_start != DEFAULT_START_DATE

        # Scenario 2: Force full sync
        full_sync_start = DEFAULT_START_DATE  # "20180101"

        assert full_sync_start == "20180101"
        assert incremental_start != full_sync_start

        print("[TEST] Incremental vs Full sync logic verified")


if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])
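上面被删除的只读测试反复验证同一条选择逻辑:增量同步从 local_last_date + 1 开始,全量同步从 DEFAULT_START_DATE(20180101)开始。这条逻辑可以收敛为一个很小的纯函数(resolve_start_date 为假设的示意写法,项目中的实际组织方式可能不同):

```python
from datetime import datetime, timedelta

DEFAULT_START_DATE = "20180101"


def resolve_start_date(local_last_date, force_full=False):
    """选择同步起始日(示意实现):
    - force_full 或本地无数据:从默认起始日全量同步
    - 否则:从本地最后日期的次日增量同步
    """
    if force_full or not local_last_date:
        return DEFAULT_START_DATE
    day = datetime.strptime(local_last_date, "%Y%m%d")
    return (day + timedelta(days=1)).strftime("%Y%m%d")
```

把这条分支逻辑独立成纯函数后,就可以像上面的测试那样在不触碰存储层的前提下覆盖全部边界情况。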