refactor: migrate storage layer to DuckDB + module restructuring

- Storage layer rewrite: HDF5 → DuckDB (UPSERT writes, thread-safe storage)
- Sync class migration: DataSync moved from sync.py into api_daily.py (separation of concerns)
- Model module rename: src/models → src/pipeline (clearer naming)
- New factor modules: factors/momentum (MA, return rank) and factors/financial
- New API wrappers: api_namechange, api_bak_basic
- New training entry point: training module (main.py, pipeline config)
- Utility consolidation: get_today_date and friends moved to utils.py
- Docs: AGENTS.md now records the architecture change history
2026-02-23 16:23:53 +08:00
parent 9f95be56a0
commit 593ec99466
32 changed files with 4181 additions and 1395 deletions

.gitignore

@@ -75,3 +75,6 @@ temp/
# Data directory (tracked, but contents ignored)
data/*
# AI Agent working directory
/.sisyphus/

AGENTS.md

@@ -2,6 +2,15 @@
A-share quantitative investment framework: a Python project for quantitative stock analysis.
## Communication Language Requirements
**⚠️ Mandatory: all communication and all reasoning must be in Chinese.**
- All communication with the AI agent must be in Chinese
- Comments and docstrings in code must be in Chinese
- Thinking or communicating in English is not allowed
## Build / Lint / Test Commands
**⚠️ Important: this project mandates uv as the Python package manager and runner. Do not invoke `python` or `pip` directly.**
@@ -67,25 +76,69 @@ uv run pytest tests/test_sync.py # ✅ 正确
```
ProStock/
├── src/                              # Source code
│   ├── config/                       # Configuration management
│   │   ├── __init__.py
│   │   ├── config.py                 # Config (pydantic-settings)
│   │   └── settings.py               # pydantic-settings configuration
│   │
│   ├── data/                         # Data fetching and storage
│   │   ├── api_wrappers/             # Tushare API wrappers
│   │   │   ├── API_INTERFACE_SPEC.md # Interface specification
│   │   │   ├── api.md                # API interface definitions
│   │   │   ├── api_daily.py          # Daily bar interface
│   │   │   ├── api_stock_basic.py    # Stock basic info interface
│   │   │   ├── api_trade_cal.py      # Trading calendar interface
│   │   │   └── __init__.py
│   │   ├── __init__.py
│   │   ├── client.py                 # Tushare API client (rate limited)
│   │   ├── config.py                 # Data module configuration
│   │   ├── db_inspector.py           # Database inspection tool
│   │   ├── db_manager.py             # DuckDB table management and sync
│   │   ├── rate_limiter.py           # Token bucket rate limiter
│   │   ├── storage.py                # Storage core
│   │   └── sync.py                   # Sync orchestration
│   │
│   ├── factors/                      # Factor computation framework
│   │   ├── __init__.py
│   │   ├── base.py                   # Factor base classes (cross-sectional / time-series)
│   │   ├── composite.py              # Composite factors and scalar ops
│   │   ├── data_loader.py            # Data loader
│   │   ├── data_spec.py              # Data spec definitions
│   │   ├── engine.py                 # Factor execution engine
│   │   ├── momentum/                 # Momentum factors
│   │   │   ├── __init__.py
│   │   │   ├── ma.py                 # Moving averages
│   │   │   └── return_rank.py        # Return rank
│   │   └── financial/                # Financial factors
│   │       └── __init__.py
│   │
│   ├── pipeline/                     # Model training pipeline
│   │   ├── __init__.py
│   │   ├── pipeline.py               # Processing pipeline
│   │   ├── registry.py               # Plugin registry
│   │   ├── core/                     # Core abstractions
│   │   │   ├── __init__.py
│   │   │   ├── base.py               # Base classes
│   │   │   └── splitter.py           # Time-series split strategies
│   │   ├── models/                   # Model implementations
│   │   │   ├── __init__.py
│   │   │   └── models.py             # LightGBM, CatBoost, etc.
│   │   └── processors/               # Data processors
│   │       ├── __init__.py
│   │       └── processors.py         # Standardization, winsorization, neutralization, etc.
│   │
│   └── training/                     # Training entry point
│       ├── __init__.py
│       ├── main.py                   # Training main program
│       ├── pipeline.py               # Training pipeline configuration
│       └── output/                   # Training output
│           └── top_stocks.tsv        # Recommended stocks
├── tests/                            # Tests
│   ├── test_sync.py
│   └── test_daily.py
├── config/                           # Config files
│   └── .env.local                    # Environment variables (not in git)
├── data/                             # Data storage (DuckDB)
├── docs/                             # Documentation
├── pyproject.toml                    # Project configuration
└── README.md
```
@@ -182,10 +235,10 @@ except Exception as e:
- Use `@lru_cache()` for configuration singletons
### Data Storage
- Persist data with **DuckDB**, an embedded OLAP database
- Stored under the `data/` directory (configured via the `DATA_PATH` environment variable)
- Handle duplicate data with UPSERT (`INSERT OR REPLACE`)
- In multi-threaded scenarios, use the `ThreadSafeStorage.queue_save()` + `flush()` pattern (see the sketch below)
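A minimal sketch of that queued-write pattern, assuming the `ThreadSafeStorage` interface named above (the table name and input data are illustrative):

```python
import pandas as pd
from src.data.storage import ThreadSafeStorage

storage = ThreadSafeStorage()
chunks = [pd.DataFrame({"ts_code": ["000001.SZ"], "trade_date": ["20240102"]})]

# Worker threads queue chunks instead of writing directly.
for chunk in chunks:
    storage.queue_save("daily", chunk)

# One batched, thread-safe write at the end.
storage.flush()
```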
### Threading and Concurrency
- Use `ThreadPoolExecutor` for I/O-bound tasks (API calls), as in the pattern below
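A minimal sketch of that concurrent-fetch pattern (the worker count and the fetch function are illustrative, not the project's exact code):

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def fetch_one(ts_code: str):
    """One rate-limited API call (placeholder)."""
    ...

stock_codes = ["000001.SZ", "600000.SH"]  # example input

with ThreadPoolExecutor(max_workers=10) as pool:
    futures = {pool.submit(fetch_one, code): code for code in stock_codes}
    for fut in as_completed(futures):
        data = fut.result()  # re-raises any worker exception here
```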
@@ -240,3 +293,39 @@ uv run python -c "from src.data.sync import sync_all; sync_all(force_full=True)"
# Custom worker-thread count
uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)"
```
## Architecture Change History
### v2.0 (2026-02-23): Major Update
#### Storage Layer Rewrite
**Change**: migrated from HDF5 to DuckDB
**Why**: DuckDB offers better query performance, SQL pushdown, and concurrency support
**Impact**: all data tables now live in DuckDB; legacy HDF5 files can be migrated by hand (a sketch follows)
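A rough one-off migration sketch, assuming a legacy `pandas.HDFStore` file; the file names are hypothetical, not a tested project utility:

```python
import duckdb
import pandas as pd

# Hypothetical file names; adjust to the real data/ layout.
df = pd.read_hdf("data/prostock.h5", key="daily")

con = duckdb.connect("data/prostock.duckdb")
con.register("legacy", df)
# One-shot copy; for repeatable runs, create the table with a primary key
# and use INSERT OR REPLACE instead.
con.execute("CREATE TABLE IF NOT EXISTS daily AS SELECT * FROM legacy")
```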
#### Sync Class Migration
**Change**: the `DataSync` class moved from `sync.py` into `api_daily.py`
**Why**: separation of concerns; each API file now owns its sync logic
**Impact**:
- `sync.py` remains the scheduling hub
- `api_daily.py` hosts the `DailySync` class and the `sync_daily` function (usage sketched below)
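A usage sketch based on the exports added in this commit:

```python
from src.data.api_wrappers import sync_daily, preview_daily_sync

# Inspect what would be fetched before spending API tokens.
preview = preview_daily_sync()

if preview["sync_needed"]:
    results = sync_daily(max_workers=10)  # maps ts_code -> DataFrame
```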
#### New Modules
- **pipeline**: machine-learning pipeline components (processors, models, split strategies)
- **training**: training entry point
- **factors/momentum**: momentum factors (MA, return rank)
- **factors/financial**: financial factor framework
- **data/utils.py**: centralized date utilities
#### New API Wrappers
- `api_namechange.py`: stock name-change history (manual sync)
- `api_bak_basic.py`: historical stock list
#### Utility Consolidation
`get_today_date()`, `get_next_date()`, `DEFAULT_START_DATE`, and related helpers are now managed in `src/data/utils.py`.
Other modules should import them from `utils.py` instead of redefining them, as sketched below.
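A sketch of the intended import style (these names appear in this commit's imports; `get_next_date` semantics as defined in the old `sync.py`):

```python
from src.data.utils import DEFAULT_START_DATE, get_next_date, get_today_date

start = get_next_date("20240131")  # -> "20240201"
end = get_today_date()             # today's date as a YYYYMMDD string
print(DEFAULT_START_DATE, start, end)
```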
### v1.x (Historical)
- Initial version, HDF5 storage
- Sync logic centralized in `sync.py`

src/data/api_wrappers/API_INTERFACE_SPEC.md

@@ -1,12 +1,26 @@
# ProStock Data Interface Wrapper Specification
## 1. Overview
This document defines the standard for wrapping new Tushare API interfaces. Every non-special interface must follow it, which guarantees:
- a uniform code style
- automatic sync support
- consistent incremental-update logic
- reduced storage write pressure
- type safety (type hints are mandatory)
### 1.1 Tech Stack
- **Storage**: DuckDB (high-performance embedded OLAP database)
- **Data format**: Pandas DataFrame / Polars DataFrame
- **Rate limiting**: token bucket (TokenBucketRateLimiter; see the sketch after this list)
- **Concurrency**: multi-threading via ThreadPoolExecutor
- **Type system**: Python 3.10+ type hints
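The real limiter lives in `src/data/rate_limiter.py`; the following is only a rough illustration of the token-bucket idea, not the project's actual implementation:

```python
import threading
import time

class TokenBucket:
    """Illustrative token bucket: refills at `rate` tokens/sec up to `capacity`."""

    def __init__(self, rate: float, capacity: int) -> None:
        self.rate = rate
        self.capacity = capacity
        self.tokens = float(capacity)
        self.updated = time.monotonic()
        self.lock = threading.Lock()

    def acquire(self) -> None:
        """Block until one token is available, then consume it."""
        while True:
            with self.lock:
                now = time.monotonic()
                self.tokens = min(self.capacity, self.tokens + (now - self.updated) * self.rate)
                self.updated = now
                if self.tokens >= 1:
                    self.tokens -= 1
                    return
                wait = (1 - self.tokens) / self.rate
            time.sleep(wait)
```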
### 1.2 Automation
The project ships a `prostock-api-interface` Skill that automates the wrapping workflow. After an interface is defined in `api.md`, invoking the Skill generates:
- the data module file (`src/data/api_wrappers/api_{data_type}.py`)
- the database table management configuration
- the test file (`tests/test_{data_type}.py`)
## 2. Interface Classification
@@ -14,37 +28,41 @@
The following interfaces have independent sync logic and do not take part in the automatic sync mechanism:
| Interface type | File | Notes |
|---------|--------|------|
| Trading calendar | `api_trade_cal.py` | Global data, fetched by date range, cached in HDF5 |
| Stock basic info | `api_stock_basic.py` | One-off full fetch, stored as CSV |
| Auxiliary data | `api_industry`, `api_concept` | Low-frequency updates, managed independently |
### 2.2 Standard Interfaces (must follow this spec)
All factor, market, and financial data fetched per stock or per date must follow this specification:
- Fetch by date: **preferred**; supports whole-market batch fetching
- Fetch by stock: only when the API does not support fetching by date
## 3. File Structure Requirements
### 3.1 File Naming
```
api_{data_type}.py
```
- Examples: `api_daily.py`, `api_moneyflow.py`, `api_limit_list.py`
- **Must** start with the `api_` prefix
- Use lowercase letters and underscores
### 3.2 File Location
All interface files must live under `src/data/api_wrappers/`.
### 3.3 Export Requirements
New interfaces must be exported from `src/data/api_wrappers/__init__.py`:
```python
from src.data.api_wrappers.api_{data_type} import get_{data_type}

__all__ = [
    # ... other exports ...
    "get_{data_type}",
@@ -59,7 +77,7 @@ __all__ = [
#### 4.1.1 Interfaces Fetched by Date (preferred)
Applicable to: limit lists, dragon-tiger (top trader) lists, chip distribution, daily indicators, etc.
**Function signature requirements**:
@@ -77,6 +95,7 @@ def get_{data_type}(
- Prefer `trade_date` for fetching one full-market day
- Support `start_date + end_date` for date ranges
- `ts_code` is an optional filter
- **Performance**: one API call covers the whole market for a single day
#### 4.1.2 Interfaces Fetched by Stock
@@ -93,152 +112,504 @@ def get_{data_type}(
) -> pd.DataFrame:
```
**Requirements**:
- `ts_code` is a required parameter
- Whole-market data requires iterating over all stocks
### 4.2 Docstring Requirements
Every function must carry a complete **Google-style** docstring, including:
```python
def get_{data_type}(...) -> pd.DataFrame:
"""Fetch {数据描述} from Tushare.
This interface retrieves {详细描述}.
Args:
ts_code: Stock code (e.g., '000001.SZ', '600000.SH')
trade_date: Specific trade date (YYYYMMDD format)
start_date: Start date (YYYYMMDD format)
end_date: End date (YYYYMMDD format)
# 其他参数...
Returns:
pd.DataFrame with columns:
- ts_code: Stock code
- trade_date: Trade date (YYYYMMDD)
- {其他字段}: {字段描述}
Example:
>>> # Get single date data for all stocks
>>> data = get_{data_type}(trade_date='20240101')
>>>
>>> # Get date range data
>>> data = get_{data_type}(start_date='20240101', end_date='20240131')
>>>
>>> # Get specific stock data
>>> data = get_{data_type}(ts_code='000001.SZ', trade_date='20240101')
"""
```
### 4.3 Date Format Requirements
- All date parameters use the **YYYYMMDD** string format
- Uniformly use `trade_date` as the date field name
- If the API returns another date field name (such as `date` or `end_date`), rename it to `trade_date` before returning:
```python
if "date" in data.columns:
    data = data.rename(columns={"date": "trade_date"})
```
### 4.4 Stock Code Requirements
- Uniformly use `ts_code` as the stock code field name
- Format: `{code}.{exchange}`, e.g. `000001.SZ`, `600000.SH`
- Ensure the returned DataFrame contains `ts_code`
### 4.5 Token Bucket Rate Limiting
All API calls must go through `TushareClient`, which satisfies the token-bucket rate limit automatically:
```python
from src.data.client import TushareClient

def get_{data_type}(...) -> pd.DataFrame:
    client = TushareClient()

    # Build parameters
    params = {}
    if trade_date:
        params["trade_date"] = trade_date
    if ts_code:
        params["ts_code"] = ts_code
    # ...

    # Fetch data (rate limiting handled automatically)
    data = client.query("{api_name}", **params)
    return data
```
**Note**: `TushareClient` automatically handles:
- token-bucket rate limiting
- API retry logic (exponential backoff)
- configuration loading
## 5. DuckDB Storage Specification
### 5.1 Storage Architecture
The project uses **DuckDB** for persistence:
- **Singleton**: the `Storage` class guarantees a single database connection
- **Thread safety**: `ThreadSafeStorage` supports concurrent writes
- **UPSERT**: `INSERT OR REPLACE` handles duplicate data automatically
- **Query pushdown**: WHERE predicates are evaluated inside the database
### 5.2 Table Design
Each data type maps to one DuckDB table:
```sql
CREATE TABLE {data_type} (
    ts_code VARCHAR(16) NOT NULL,
    trade_date DATE NOT NULL,
    -- ... other columns
    PRIMARY KEY (ts_code, trade_date)
);
CREATE INDEX idx_{data_type}_date_code ON {data_type}(trade_date, ts_code);
```
**Primary key requirements** (a usage sketch follows this list):
- Must include `ts_code` and `trade_date`
- Use UPSERT to guarantee idempotency
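A minimal sketch with the plain `duckdb` Python API to illustrate UPSERT and predicate pushdown (a stand-alone `demo` table and hypothetical database path, not the project's `Storage` class):

```python
import duckdb
import pandas as pd

con = duckdb.connect("data/prostock.duckdb")  # hypothetical path
con.execute("""
    CREATE TABLE IF NOT EXISTS demo (
        ts_code VARCHAR(16) NOT NULL,
        trade_date DATE NOT NULL,
        close DOUBLE,
        PRIMARY KEY (ts_code, trade_date)
    )
""")

batch = pd.DataFrame({
    "ts_code": ["000001.SZ"],
    "trade_date": ["2024-01-02"],
    "close": [10.5],
})

# Idempotent write: rows sharing the primary key are replaced, not duplicated.
# (Assumes the DataFrame's column order matches the table definition.)
con.register("batch", batch)
con.execute("INSERT OR REPLACE INTO demo SELECT * FROM batch")

# Predicate pushdown: the WHERE clause is evaluated inside DuckDB.
day = con.execute(
    "SELECT ts_code, close FROM demo WHERE trade_date = ?", ["2024-01-02"]
).df()
```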
### 5.3 Write Strategy
**Batched writes** (recommended for multi-threaded scenarios):
```python
from src.data.storage import ThreadSafeStorage

def sync_{data_type}(self, ...):
    storage = ThreadSafeStorage()
    # Queue chunks without writing immediately
    for data_chunk in data_generator:
        storage.queue_save("{data_type}", data_chunk)
    # Batch-write everything in one flush
    storage.flush()
```
**Direct writes** (for simple cases):
```python
from src.data.storage import Storage

storage = Storage()
storage.save("{data_type}", data, mode="append")
```
### 5.4 Data Type Mapping
Standard field type mapping (`DEFAULT_TYPE_MAPPING`):
```python
DEFAULT_TYPE_MAPPING = {
    "ts_code": "VARCHAR(16)",
    "trade_date": "DATE",
    "open": "DOUBLE",
    "high": "DOUBLE",
    "low": "DOUBLE",
    "close": "DOUBLE",
    "vol": "DOUBLE",
    "amount": "DOUBLE",
    # ... other fields
}
```
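For example, DDL could be assembled from such a mapping like this (an illustrative helper, not necessarily how `db_manager` builds tables):

```python
def build_create_table(table: str, type_mapping: dict[str, str]) -> str:
    """Render a CREATE TABLE statement from a field -> SQL-type mapping."""
    cols = ", ".join(f'"{name}" {sql_type}' for name, sql_type in type_mapping.items())
    return (
        f'CREATE TABLE IF NOT EXISTS "{table}" '
        f'({cols}, PRIMARY KEY ("ts_code", "trade_date"))'
    )

# build_create_table("daily", DEFAULT_TYPE_MAPPING)
# -> CREATE TABLE IF NOT EXISTS "daily" ("ts_code" VARCHAR(16), "trade_date" DATE, ...)
```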
## 6. Sync Integration Specification
### 6.1 Syncing via db_manager
The project uses the `db_manager` module for high-level sync functionality:
```python
from src.data.db_manager import SyncManager, ensure_table

def sync_{data_type}(
    force_full: bool = False,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
) -> pd.DataFrame:
    """Sync {数据描述} to DuckDB."""
    manager = SyncManager()

    # Ensure the table exists
    ensure_table("{data_type}", schema={
        "ts_code": "VARCHAR(16)",
        "trade_date": "DATE",
        # ... other fields
    })

    # Run the sync
    result = manager.sync(
        table_name="{data_type}",
        fetch_func=get_{data_type},
        start_date=start_date,
        end_date=end_date,
        force_full=force_full,
    )
    return result
```
### 6.2 Incremental Update Logic
`SyncManager` handles incremental updates automatically (see the sketch after this list):
1. **Check the local latest date**: read `MAX(trade_date)` from DuckDB
2. **Fetch the trading calendar**: get the trading-day range from `api_trade_cal`
3. **Compute the dates to sync**: from the local latest date + 1 through the latest trading day
4. **Fetch in batches**: by date or by stock
5. **Batch write**: queued writes via `ThreadSafeStorage`
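Step 3's date arithmetic, sketched in isolation (the real logic lives inside `SyncManager`):

```python
from datetime import datetime, timedelta

def next_sync_window(local_last: str | None, calendar_last: str) -> tuple[str, str] | None:
    """Return (start, end) in YYYYMMDD, or None when local data is up to date."""
    if local_last is None:
        return "20180101", calendar_last  # full-sync default per this spec
    if int(local_last) >= int(calendar_last):
        return None  # nothing new to fetch
    start_dt = datetime.strptime(local_last, "%Y%m%d") + timedelta(days=1)
    return start_dt.strftime("%Y%m%d"), calendar_last
```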
### 6.3 Convenience Function
Every interface must provide a top-level convenience function:
```python
def sync_{data_type}(force_full: bool = False) -> pd.DataFrame:
    """Sync {数据描述} to local storage.

    Args:
        force_full: If True, force full reload from 20180101

    Returns:
        DataFrame with synced data
    """
    # Implementation...
```
## 7. Code Templates
### 7.1 Template: Fetch by Date
```python
"""{数据描述} interface.

Fetch {数据描述} data from Tushare.
"""
import pandas as pd
from typing import Optional
from src.data.client import TushareClient


def get_{data_type}(
    trade_date: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    ts_code: Optional[str] = None,
) -> pd.DataFrame:
    """Fetch {数据描述} from Tushare.

    This interface retrieves {详细描述}.

    Args:
        trade_date: Specific trade date (YYYYMMDD format)
        start_date: Start date (YYYYMMDD format)
        end_date: End date (YYYYMMDD format)
        ts_code: Stock code filter (optional)

    Returns:
        pd.DataFrame with columns:
            - ts_code: Stock code
            - trade_date: Trade date (YYYYMMDD)
            - {字段1}: {描述}
            - {字段2}: {描述}

    Example:
        >>> # Get all stocks for a single date
        >>> data = get_{data_type}(trade_date='20240101')
        >>>
        >>> # Get date range data
        >>> data = get_{data_type}(start_date='20240101', end_date='20240131')
    """
    client = TushareClient()

    # Build parameters
    params = {}
    if trade_date:
        params["trade_date"] = trade_date
    if start_date:
        params["start_date"] = start_date
    if end_date:
        params["end_date"] = end_date
    if ts_code:
        params["ts_code"] = ts_code

    # Fetch data
    data = client.query("{tushare_api_name}", **params)

    # Rename date column if needed
    if "date" in data.columns:
        data = data.rename(columns={"date": "trade_date"})

    return data
```
### 7.2 Template: Fetch by Stock
```python
"""{数据描述} interface.

Fetch {数据描述} data from Tushare (per stock).
"""
import pandas as pd
from typing import Optional
from src.data.client import TushareClient


def get_{data_type}(
    ts_code: str,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
) -> pd.DataFrame:
    """Fetch {数据描述} for a specific stock.

    Args:
        ts_code: Stock code (e.g., '000001.SZ')
        start_date: Start date (YYYYMMDD format)
        end_date: End date (YYYYMMDD format)

    Returns:
        pd.DataFrame with {数据描述} data
    """
    client = TushareClient()

    params = {"ts_code": ts_code}
    if start_date:
        params["start_date"] = start_date
    if end_date:
        params["end_date"] = end_date

    data = client.query("{tushare_api_name}", **params)
    return data
```
### 7.3 Sync Function Template
```python
from src.data.db_manager import SyncManager, ensure_table
from src.data.api_wrappers import get_{data_type}


def sync_{data_type}(force_full: bool = False) -> pd.DataFrame:
    """Sync {数据描述} to local DuckDB storage.

    Args:
        force_full: If True, force full reload from 20180101

    Returns:
        DataFrame with synced data
    """
    manager = SyncManager()

    # Ensure table exists with proper schema
    ensure_table("{data_type}", schema={
        "ts_code": "VARCHAR(16)",
        "trade_date": "DATE",
        # Add other fields...
    })

    # Perform sync
    result = manager.sync(
        table_name="{data_type}",
        fetch_func=get_{data_type},
        force_full=force_full,
    )
    return result
```
## 8. Testing Specification
### 8.1 Test File Requirements
A matching test file must be created: `tests/test_{data_type}.py`.
### 8.2 Coverage Requirements
- Test fetching by date
- Test fetching by stock (if supported)
- `TushareClient` must be mocked
- Cover both normal and error cases
```python
import pytest
import pandas as pd
from unittest.mock import patch, MagicMock
from src.data.api_wrappers.api_{data_type} import get_{data_type}
class Test{DataType}:
    """Test suite for {data_type} API wrapper."""

    @patch("src.data.api_wrappers.api_{data_type}.TushareClient")
    def test_get_by_date(self, mock_client_class):
        """Test fetching data by date."""
        # Setup mock
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame({
            "ts_code": ["000001.SZ"],
            "trade_date": ["20240101"],
            # ... other columns
        })

        # Test
        result = get_{data_type}(trade_date="20240101")

        # Assert
        assert not result.empty
        assert "ts_code" in result.columns
        assert "trade_date" in result.columns
        mock_client.query.assert_called_once()

    @patch("src.data.api_wrappers.api_{data_type}.TushareClient")
    def test_get_by_stock(self, mock_client_class):
        """Test fetching data by stock code."""
        # Similar setup...
        pass

    @patch("src.data.api_wrappers.api_{data_type}.TushareClient")
    def test_empty_response(self, mock_client_class):
        """Test handling empty response."""
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.query.return_value = pd.DataFrame()

        result = get_{data_type}(trade_date="20240101")
        assert result.empty
```
### 8.3 Mock Requirements
- Patch at the import site: `patch('src.data.api_wrappers.api_{data_type}.TushareClient')`
- Test both normal and error cases
- Verify that parameters are passed through correctly
## 9. Generating Wrappers with the Skill
### 9.1 Preparation
1. Define the interface in `api.md`, including:
   - the interface name and description
   - input parameters (name, type, required, description)
   - output parameters (name, type, description)
### 9.2 Invoking the Skill
Tell Claude which interface to wrap, e.g.:
> "帮我封装 {data_type} 接口"
> "为 {data_type} 接口生成代码"
### 9.3 Generated Artifacts
The Skill automatically:
1. parses the interface definition in `api.md`
2. generates `src/data/api_wrappers/api_{data_type}.py`
3. updates the exports in `src/data/api_wrappers/__init__.py`
4. generates the `tests/test_{data_type}.py` test file
5. provides a `sync_{data_type}()` function template
## 10. Checklist
### 10.1 File Structure
- [ ] File located at `src/data/api_wrappers/api_{data_type}.py`
- [ ] Public interface exported from `src/data/api_wrappers/__init__.py`
- [ ] Test file `tests/test_{data_type}.py` created
### 10.2 Interface Implementation
- [ ] Fetch function uses `TushareClient`
- [ ] Complete Google-style docstring
- [ ] Date parameters use the `YYYYMMDD` format
- [ ] Returned DataFrame contains `ts_code` and `trade_date`
- [ ] Fetch-by-date implemented first (when the API supports it)
- [ ] Parameters checked for None before being passed
### 10.3 Storage Integration
- [ ] Data stored via `Storage` / `ThreadSafeStorage`
- [ ] Table schema uses `ts_code` + `trade_date` as the primary key
- [ ] UPSERT (`INSERT OR REPLACE`) used
- [ ] Multi-threaded scenarios use the `queue_save()` + `flush()` pattern
### 10.4 Sync Integration
- [ ] Sync managed through the `db_manager` module
- [ ] Convenience function `sync_{data_type}()` implemented
- [ ] `force_full` parameter supported
- [ ] Incremental update logic correct
### 10.5 Testing
- [ ] Unit tests written
- [ ] `TushareClient` mocked
- [ ] Coverage for fetch-by-date and fetch-by-stock
- [ ] Coverage for normal and error cases
## 11. Reference Examples
### 11.1 Full Example: api_daily.py
See `src/data/api_wrappers/api_daily.py` for a complete per-stock daily-bar implementation.
### 11.2 Full Example: api_trade_cal.py
See `src/data/api_wrappers/api_trade_cal.py` for a special interface (trading calendar), including its HDF5 caching logic.
### 11.3 Full Example: api_stock_basic.py
See `src/data/api_wrappers/api_stock_basic.py` for a special interface (stock basic info), including its CSV storage logic.
---
**Last updated**: 2026-02-23
**Version**: v2.0: updated for DuckDB storage; added Skill automation notes

src/data/api_wrappers/__init__.py

@@ -7,15 +7,21 @@ Available APIs:
- api_daily: Daily market data (日线行情)
- api_stock_basic: Stock basic information (股票基本信息)
- api_trade_cal: Trading calendar (交易日历)
- api_namechange: Stock name change history (股票曾用名)
- api_bak_basic: Stock historical list (股票历史列表)
Example:
>>> from src.data.api_wrappers import get_daily, get_stock_basic, get_trade_cal
>>> from src.data.api_wrappers import get_bak_basic, sync_bak_basic
>>> data = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
>>> stocks = get_stock_basic()
>>> calendar = get_trade_cal('20240101', '20240131')
>>> bak_basic = get_bak_basic(trade_date='20240101')
"""
from src.data.api_wrappers.api_daily import get_daily, sync_daily, preview_daily_sync, DailySync
from src.data.api_wrappers.api_bak_basic import get_bak_basic, sync_bak_basic
from src.data.api_wrappers.api_namechange import get_namechange, sync_namechange
from src.data.api_wrappers.api_stock_basic import get_stock_basic, sync_all_stocks
from src.data.api_wrappers.api_trade_cal import (
get_trade_cal,
@@ -28,6 +34,15 @@ from src.data.api_wrappers.api_trade_cal import (
__all__ = [
# Daily market data
"get_daily",
"sync_daily",
"preview_daily_sync",
"DailySync",
# Historical stock list
"get_bak_basic",
"sync_bak_basic",
# Namechange
"get_namechange",
"sync_namechange",
# Stock basic information
"get_stock_basic",
"sync_all_stocks",

src/data/api_wrappers/api.md

@@ -251,3 +251,98 @@ df = pro.query('daily_basic', ts_code='', trade_date='20180726',fields='ts_code,
17 000708.SZ 20180726 0.5575 0.70 10.3674 1.0276
18 002626.SZ 20180726 0.6187 0.83 22.7580 4.2446
19 600816.SH 20180726 0.6745 0.65 11.0778 3.2214
Stock name changes (股票曾用名)
Interface: namechange
Description: historical name-change records

Input parameters

| Name | Type | Required | Description |
|------|------|----------|-------------|
| ts_code | str | N | TS code |
| start_date | str | N | Announcement start date |
| end_date | str | N | Announcement end date |

Output parameters

| Name | Type | Default output | Description |
|------|------|----------------|-------------|
| ts_code | str | Y | TS code |
| name | str | Y | Security name |
| start_date | str | Y | Start date |
| end_date | str | Y | End date |
| ann_date | str | Y | Announcement date |
| change_reason | str | Y | Reason for the change |

Example

pro = ts.pro_api()
df = pro.namechange(ts_code='600848.SH', fields='ts_code,name,start_date,end_date,change_reason')

Sample data

     ts_code   name   start_date  end_date   change_reason
0  600848.SH  上海临港   20151118   None       改名
1  600848.SH  自仪股份   20070514   20151117   撤销ST
2  600848.SH  ST自仪    20061026   20070513   完成股改
3  600848.SH  SST自仪   20061009   20061025   未股改加S
4  600848.SH  ST自仪    20010508   20061008   ST
5  600848.SH  自仪股份   19940324   20010507   其他
Historical stock list (daily stock lists)
Interface: bak_basic
Description: fetch the backup basic list data, available from 2016 onwards
Limit: at most 7000 rows per call; loop over the date parameter to fetch history. Formal access requires 5000 credits.

Input parameters

| Name | Type | Required | Description |
|------|------|----------|-------------|
| trade_date | str | N | Trade date |
| ts_code | str | N | Stock code |

Output parameters

| Name | Type | Default output | Description |
|------|------|----------------|-------------|
| trade_date | str | Y | Trade date |
| ts_code | str | Y | TS stock code |
| name | str | Y | Stock name |
| industry | str | Y | Industry |
| area | str | Y | Region |
| pe | float | Y | P/E ratio (dynamic) |
| float_share | float | Y | Float shares (100M) |
| total_share | float | Y | Total shares (100M) |
| total_assets | float | Y | Total assets (100M) |
| liquid_assets | float | Y | Liquid assets (100M) |
| fixed_assets | float | Y | Fixed assets (100M) |
| reserved | float | Y | Reserve fund |
| reserved_pershare | float | Y | Reserve fund per share |
| eps | float | Y | Earnings per share |
| bvps | float | Y | Book value per share |
| pb | float | Y | P/B ratio |
| list_date | str | Y | Listing date |
| undp | float | Y | Undistributed profit |
| per_undp | float | Y | Undistributed profit per share |
| rev_yoy | float | Y | Revenue YoY (%) |
| profit_yoy | float | Y | Profit YoY (%) |
| gpr | float | Y | Gross profit ratio (%) |
| npr | float | Y | Net profit ratio (%) |
| holder_num | int | Y | Number of shareholders |

Example

pro = ts.pro_api()
df = pro.bak_basic(trade_date='20211012', fields='trade_date,ts_code,name,industry,pe')

Sample data

      trade_date    ts_code   name   industry      pe
0       20211012  300605.SZ  恒锋信息   软件服务   56.4400
1       20211012  301017.SZ  漱玉平民   医药商业   58.7600
2       20211012  300755.SZ  华致酒行   其他商业   23.0000
3       20211012  300255.SZ  常山药业   生物制药   24.9900
4       20211012  688378.SH  奥来德     专用机械   24.9600
...          ...        ...   ...        ...       ...
4529    20211012  688257.SH  新锐股份   机械基件    0.0000
4530    20211012  688255.SH  凯尔达     机械基件    0.0000
4531    20211012  688211.SH  中科微至   专用机械    0.0000
4532    20211012  605567.SH  春雪食品   食品        0.0000
4533    20211012  605566.SH  福莱蒽特   染料涂料    0.0000

src/data/api_wrappers/api_bak_basic.py (new file)

@@ -0,0 +1,243 @@
"""Stock historical list interface.
Fetch daily stock list from Tushare bak_basic API.
Data available from 2016 onwards.
"""
import pandas as pd
from typing import Optional, List
from datetime import datetime, timedelta
from tqdm import tqdm
from src.data.client import TushareClient
from src.data.storage import ThreadSafeStorage, Storage
from src.data.db_manager import ensure_table
def get_bak_basic(
trade_date: Optional[str] = None,
ts_code: Optional[str] = None,
) -> pd.DataFrame:
"""Fetch historical stock list from Tushare.
This interface retrieves the daily stock list including basic information
for all stocks on a specific trade date. Data is available from 2016 onwards.
Args:
trade_date: Specific trade date in YYYYMMDD format
ts_code: Stock code filter (optional, e.g., '000001.SZ')
Returns:
pd.DataFrame with columns:
- trade_date: Trade date (YYYYMMDD)
- ts_code: TS stock code
- name: Stock name
- industry: Industry
- area: Region
- pe: P/E ratio (dynamic)
- float_share: Float shares (100 million)
- total_share: Total shares (100 million)
- total_assets: Total assets (100 million)
- liquid_assets: Liquid assets (100 million)
- fixed_assets: Fixed assets (100 million)
- reserved: Reserve fund
- reserved_pershare: Reserve per share
- eps: Earnings per share
- bvps: Book value per share
- pb: P/B ratio
- list_date: Listing date
- undp: Undistributed profit
- per_undp: Undistributed profit per share
- rev_yoy: Revenue YoY (%)
- profit_yoy: Profit YoY (%)
- gpr: Gross profit ratio (%)
- npr: Net profit ratio (%)
- holder_num: Number of shareholders
Example:
>>> # Get all stocks for a single date
>>> data = get_bak_basic(trade_date='20240101')
>>>
>>> # Get specific stock data
>>> data = get_bak_basic(ts_code='000001.SZ', trade_date='20240101')
"""
client = TushareClient()
# Build parameters
params = {}
if trade_date:
params["trade_date"] = trade_date
if ts_code:
params["ts_code"] = ts_code
# Fetch data
data = client.query("bak_basic", **params)
return data
def sync_bak_basic(
start_date: Optional[str] = None,
end_date: Optional[str] = None,
force_full: bool = False,
) -> pd.DataFrame:
"""Sync historical stock list to DuckDB with intelligent incremental sync.
Logic:
- If table doesn't exist: create table + composite index (trade_date, ts_code) + full sync
- If table exists: incremental sync from last_date + 1
Args:
start_date: Start date for sync (YYYYMMDD format, default: 20160101 for full, last_date+1 for incremental)
end_date: End date for sync (YYYYMMDD format, default: today)
force_full: If True, force full reload from 20160101
Returns:
pd.DataFrame with synced data
"""
TABLE_NAME = "bak_basic"
storage = Storage()
thread_storage = ThreadSafeStorage()
# Default end date
if end_date is None:
end_date = datetime.now().strftime("%Y%m%d")
# Check if table exists
table_exists = storage.exists(TABLE_NAME)
if not table_exists or force_full:
# ===== FULL SYNC =====
# 1. Create table with schema
# 2. Create composite index (trade_date, ts_code)
# 3. Full sync from start_date
if not table_exists:
print(f"[sync_bak_basic] Table '{TABLE_NAME}' doesn't exist, creating...")
# Fetch sample to get schema
sample = get_bak_basic(trade_date=end_date)
if sample.empty:
sample = get_bak_basic(trade_date="20240102")
if sample.empty:
print("[sync_bak_basic] Cannot create table: no sample data available")
return pd.DataFrame()
# Create table with schema
columns = []
for col in sample.columns:
dtype = str(sample[col].dtype)
if "int" in dtype:
col_type = "INTEGER"
elif "float" in dtype:
col_type = "DOUBLE"
else:
col_type = "VARCHAR"
columns.append(f'"{col}" {col_type}')
columns_sql = ", ".join(columns)
create_sql = f'CREATE TABLE IF NOT EXISTS "{TABLE_NAME}" ({columns_sql}, PRIMARY KEY ("trade_date", "ts_code"))'
try:
storage._connection.execute(create_sql)
print(f"[sync_bak_basic] Created table '{TABLE_NAME}'")
except Exception as e:
print(f"[sync_bak_basic] Error creating table: {e}")
# Create composite index
try:
storage._connection.execute(f"""
CREATE INDEX IF NOT EXISTS "idx_bak_basic_date_code"
ON "{TABLE_NAME}"("trade_date", "ts_code")
""")
print(f"[sync_bak_basic] Created composite index on (trade_date, ts_code)")
except Exception as e:
print(f"[sync_bak_basic] Error creating index: {e}")
# Determine sync dates
sync_start = start_date or "20160101"
mode = "FULL"
print(f"[sync_bak_basic] Mode: {mode} SYNC from {sync_start} to {end_date}")
else:
# ===== INCREMENTAL SYNC =====
# Check last date in table, sync from last_date + 1
try:
result = storage._connection.execute(
f'SELECT MAX("trade_date") FROM "{TABLE_NAME}"'
).fetchone()
last_date = result[0] if result and result[0] else None
except Exception as e:
print(f"[sync_bak_basic] Error getting last date: {e}")
last_date = None
if last_date is None:
# Table exists but empty, do full sync
sync_start = start_date or "20160101"
mode = "FULL (empty table)"
else:
# Incremental from last_date + 1
# Handle both YYYYMMDD and YYYY-MM-DD formats
last_date_str = str(last_date).replace("-", "")
last_dt = datetime.strptime(last_date_str, "%Y%m%d")
next_dt = last_dt + timedelta(days=1)
sync_start = next_dt.strftime("%Y%m%d")
mode = "INCREMENTAL"
# Skip if already up to date
if sync_start > end_date:
print(f"[sync_bak_basic] Data is up-to-date (last: {last_date}), skipping sync")
return pd.DataFrame()
print(f"[sync_bak_basic] Mode: {mode} from {sync_start} to {end_date} (last: {last_date})")
# ===== FETCH AND SAVE DATA =====
all_data: List[pd.DataFrame] = []
current = datetime.strptime(sync_start, "%Y%m%d")
end_dt = datetime.strptime(end_date, "%Y%m%d")
# Calculate total days for progress bar
total_days = (end_dt - current).days + 1
print(f"[sync_bak_basic] Fetching data for {total_days} days...")
with tqdm(total=total_days, desc="Syncing dates") as pbar:
while current <= end_dt:
date_str = current.strftime("%Y%m%d")
try:
data = get_bak_basic(trade_date=date_str)
if not data.empty:
all_data.append(data)
pbar.set_postfix({"date": date_str, "records": len(data)})
except Exception as e:
print(f" {date_str}: ERROR - {e}")
current += timedelta(days=1)
pbar.update(1)
if not all_data:
print("[sync_bak_basic] No data fetched")
return pd.DataFrame()
# Combine and save
combined = pd.concat(all_data, ignore_index=True)
print(f"[sync_bak_basic] Total records: {len(combined)}")
# Delete existing data for the date range and append new data
storage._connection.execute(f'DELETE FROM "{TABLE_NAME}" WHERE "trade_date" >= ?', [sync_start])
thread_storage.queue_save(TABLE_NAME, combined)
thread_storage.flush()
print(f"[sync_bak_basic] Saved {len(combined)} records to DuckDB")
return combined
if __name__ == "__main__":
# Test sync
result = sync_bak_basic(end_date="20240102")
print(f"Synced {len(result)} records")
if not result.empty:
print("\nSample data:")
print(result.head())

src/data/api_wrappers/api_daily.py

@@ -2,11 +2,27 @@
A single function to fetch A股日线行情 data from Tushare.
Supports all output fields including tor (换手率) and vr (量比).
This module provides both single-stock fetching (get_daily) and
batch synchronization (DailySync class) for daily market data.
"""
import pandas as pd
from typing import Optional, List, Literal
from typing import Optional, List, Literal, Dict
from datetime import datetime, timedelta
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
from src.data.client import TushareClient
from src.data.storage import ThreadSafeStorage, Storage
from src.data.utils import get_today_date, get_next_date, DEFAULT_START_DATE
from src.data.api_wrappers.api_trade_cal import (
get_first_trading_day,
get_last_trading_day,
sync_trade_cal_cache,
)
from src.data.api_wrappers.api_stock_basic import _get_csv_path, sync_all_stocks
def get_daily(
@@ -71,3 +87,744 @@ def get_daily(
data = client.query("pro_bar", **params)
return data
# =============================================================================
# DailySync - 日线数据批量同步类
# =============================================================================
class DailySync:
"""日线数据批量同步管理器,支持全量/增量同步。
功能特性:
- 多线程并发获取ThreadPoolExecutor
- 增量同步(自动检测上次同步位置)
- 内存缓存(避免重复磁盘读取)
- 异常立即停止(确保数据一致性)
- 预览模式(预览同步数据量,不实际写入)
"""
# 默认工作线程数
DEFAULT_MAX_WORKERS = 10
def __init__(self, max_workers: Optional[int] = None):
"""初始化同步管理器。
Args:
max_workers: 工作线程数(默认: 10
"""
self.storage = ThreadSafeStorage()
self.client = TushareClient()
self.max_workers = max_workers or self.DEFAULT_MAX_WORKERS
self._stop_flag = threading.Event()
self._stop_flag.set() # 初始为未停止状态
self._cached_daily_data: Optional[pd.DataFrame] = None # 日线数据缓存
def _load_daily_data(self) -> pd.DataFrame:
"""从存储加载日线数据(带缓存)。
该方法会将数据缓存在内存中以避免重复磁盘读取。
调用 clear_cache() 可强制重新加载。
Returns:
缓存或从存储加载的日线数据 DataFrame
"""
if self._cached_daily_data is None:
self._cached_daily_data = self.storage.load("daily")
return self._cached_daily_data
def clear_cache(self) -> None:
"""清除缓存的日线数据,强制下次访问时重新加载。"""
self._cached_daily_data = None
def get_all_stock_codes(self, only_listed: bool = True) -> list:
"""从本地存储获取所有股票代码。
优先使用 stock_basic.csv 以确保包含所有股票,
避免回测中的前视偏差。
Args:
only_listed: 若为 True仅返回当前上市股票L 状态)。
设为 False 可包含退市股票(用于完整回测)。
Returns:
股票代码列表
"""
# 确保 stock_basic.csv 是最新的
print("[DailySync] Ensuring stock_basic.csv is up-to-date...")
sync_all_stocks()
# 从 stock_basic.csv 文件获取
stock_csv_path = _get_csv_path()
if stock_csv_path.exists():
print(f"[DailySync] Reading stock_basic from CSV: {stock_csv_path}")
try:
stock_df = pd.read_csv(stock_csv_path, encoding="utf-8-sig")
if not stock_df.empty and "ts_code" in stock_df.columns:
# 根据 list_status 过滤
if only_listed and "list_status" in stock_df.columns:
listed_stocks = stock_df[stock_df["list_status"] == "L"]
codes = listed_stocks["ts_code"].unique().tolist()
total = len(stock_df["ts_code"].unique())
print(
f"[DailySync] Found {len(codes)} listed stocks (filtered from {total} total)"
)
else:
codes = stock_df["ts_code"].unique().tolist()
print(
f"[DailySync] Found {len(codes)} stock codes from stock_basic.csv"
)
return codes
else:
print(
f"[DailySync] stock_basic.csv exists but no ts_code column or empty"
)
except Exception as e:
print(f"[DailySync] Error reading stock_basic.csv: {e}")
# 回退:从日线存储获取
print(
"[DailySync] stock_basic.csv not available, falling back to daily data..."
)
daily_data = self._load_daily_data()
if not daily_data.empty and "ts_code" in daily_data.columns:
codes = daily_data["ts_code"].unique().tolist()
print(f"[DailySync] Found {len(codes)} stock codes from daily data")
return codes
print("[DailySync] No stock codes found in local storage")
return []
def get_global_last_date(self) -> Optional[str]:
"""获取全局最后交易日期。
Returns:
最后交易日期字符串,若无数据则返回 None
"""
daily_data = self._load_daily_data()
if daily_data.empty or "trade_date" not in daily_data.columns:
return None
return str(daily_data["trade_date"].max())
def get_global_first_date(self) -> Optional[str]:
"""获取全局最早交易日期。
Returns:
最早交易日期字符串,若无数据则返回 None
"""
daily_data = self._load_daily_data()
if daily_data.empty or "trade_date" not in daily_data.columns:
return None
return str(daily_data["trade_date"].min())
def get_trade_calendar_bounds(
self, start_date: str, end_date: str
) -> tuple[Optional[str], Optional[str]]:
"""从交易日历获取首尾交易日。
Args:
start_date: 开始日期YYYYMMDD 格式)
end_date: 结束日期YYYYMMDD 格式)
Returns:
(首交易日, 尾交易日) 元组,若出错则返回 (None, None)
"""
try:
first_day = get_first_trading_day(start_date, end_date)
last_day = get_last_trading_day(start_date, end_date)
return (first_day, last_day)
except Exception as e:
print(f"[ERROR] Failed to get trade calendar bounds: {e}")
return (None, None)
def check_sync_needed(
self,
force_full: bool = False,
table_name: str = "daily",
) -> tuple[bool, Optional[str], Optional[str], Optional[str]]:
"""基于交易日历检查是否需要同步。
该方法比较本地数据日期范围与交易日历,
以确定是否需要获取新数据。
逻辑:
- 若 force_full需要同步返回 (True, 20180101, today)
- 若无本地数据:需要同步,返回 (True, 20180101, today)
- 若存在本地数据:
- 从交易日历获取最后交易日
- 若本地最后日期 >= 日历最后日期:无需同步
- 否则:从本地最后日期+1 到最新交易日同步
Args:
force_full: 若为 True始终返回需要同步
table_name: 要检查的表名(默认: "daily"
Returns:
(需要同步, 起始日期, 结束日期, 本地最后日期)
- 需要同步: True 表示应继续同步
- 起始日期: 同步起始日期(无需同步时为 None
- 结束日期: 同步结束日期(无需同步时为 None
- 本地最后日期: 本地数据最后日期(用于增量同步)
"""
# 若 force_full始终同步
if force_full:
print("[DailySync] Force full sync requested")
return (True, DEFAULT_START_DATE, get_today_date(), None)
# 检查特定表的本地数据是否存在
storage = Storage()
table_data = (
storage.load(table_name) if storage.exists(table_name) else pd.DataFrame()
)
if table_data.empty or "trade_date" not in table_data.columns:
print(
f"[DailySync] No local data found for table '{table_name}', full sync needed"
)
return (True, DEFAULT_START_DATE, get_today_date(), None)
# 获取本地数据最后日期
local_last_date = str(table_data["trade_date"].max())
print(f"[DailySync] Local data last date: {local_last_date}")
# 从交易日历获取最新交易日
today = get_today_date()
_, cal_last = self.get_trade_calendar_bounds(DEFAULT_START_DATE, today)
if cal_last is None:
print("[DailySync] Failed to get trade calendar, proceeding with sync")
return (True, DEFAULT_START_DATE, today, local_last_date)
print(f"[DailySync] Calendar last trading day: {cal_last}")
# 比较本地最后日期与日历最后日期
print(
f"[DailySync] Comparing: local={local_last_date} (type={type(local_last_date).__name__}), "
f"cal={cal_last} (type={type(cal_last).__name__})"
)
try:
local_last_int = int(local_last_date)
cal_last_int = int(cal_last)
print(
f"[DailySync] Comparing integers: local={local_last_int} >= cal={cal_last_int} = "
f"{local_last_int >= cal_last_int}"
)
if local_last_int >= cal_last_int:
print(
"[DailySync] Local data is up-to-date, SKIPPING sync (no tokens consumed)"
)
return (False, None, None, None)
except (ValueError, TypeError) as e:
print(f"[ERROR] Date comparison failed: {e}")
# 需要从本地最后日期+1 同步到最新交易日
sync_start = get_next_date(local_last_date)
print(f"[DailySync] Incremental sync needed from {sync_start} to {cal_last}")
return (True, sync_start, cal_last, local_last_date)
def preview_sync(
self,
force_full: bool = False,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
sample_size: int = 3,
) -> dict:
"""预览同步数据量和样本(不实际同步)。
该方法提供即将同步的数据的预览,包括:
- 将同步的股票数量
- 同步日期范围
- 预估总记录数
- 前几只股票的样本数据
Args:
force_full: 若为 True预览全量同步从 20180101
start_date: 手动指定起始日期(覆盖自动检测)
end_date: 手动指定结束日期(默认为今天)
sample_size: 预览用样本股票数量(默认: 3
Returns:
包含预览信息的字典:
{
'sync_needed': bool,
'stock_count': int,
'start_date': str,
'end_date': str,
'estimated_records': int,
'sample_data': pd.DataFrame,
'mode': str, # 'full''incremental'
}
"""
print("\n" + "=" * 60)
print("[DailySync] Preview Mode - Analyzing sync requirements...")
print("=" * 60)
# 首先确保交易日历缓存是最新的
print("[DailySync] Syncing trade calendar cache...")
sync_trade_cal_cache()
# 确定日期范围
if end_date is None:
end_date = get_today_date()
# 检查是否需要同步
sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full)
if not sync_needed:
print("\n" + "=" * 60)
print("[DailySync] Preview Result")
print("=" * 60)
print(" Sync Status: NOT NEEDED")
print(" Reason: Local data is up-to-date with trade calendar")
print("=" * 60)
return {
"sync_needed": False,
"stock_count": 0,
"start_date": None,
"end_date": None,
"estimated_records": 0,
"sample_data": pd.DataFrame(),
"mode": "none",
}
# 使用 check_sync_needed 返回的日期
if cal_start and cal_end:
sync_start_date = cal_start
end_date = cal_end
else:
sync_start_date = start_date or DEFAULT_START_DATE
if end_date is None:
end_date = get_today_date()
# 确定同步模式
if force_full:
mode = "full"
print(f"[DailySync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
elif local_last and cal_start and sync_start_date == get_next_date(local_last):
mode = "incremental"
print(f"[DailySync] Mode: INCREmental SYNC (bandwidth optimized)")
print(f"[DailySync] Sync from: {sync_start_date} to {end_date}")
else:
mode = "partial"
print(f"[DailySync] Mode: SYNC from {sync_start_date} to {end_date}")
# 获取所有股票代码
stock_codes = self.get_all_stock_codes()
if not stock_codes:
print("[DailySync] No stocks found to sync")
return {
"sync_needed": False,
"stock_count": 0,
"start_date": None,
"end_date": None,
"estimated_records": 0,
"sample_data": pd.DataFrame(),
"mode": "none",
}
stock_count = len(stock_codes)
print(f"[DailySync] Total stocks to sync: {stock_count}")
# 从前几只股票获取样本数据
print(f"[DailySync] Fetching sample data from {sample_size} stocks...")
sample_data_list = []
sample_codes = stock_codes[:sample_size]
for ts_code in sample_codes:
try:
data = self.client.query(
"pro_bar",
ts_code=ts_code,
start_date=sync_start_date,
end_date=end_date,
factors="tor,vr",
)
if not data.empty:
sample_data_list.append(data)
print(f" - {ts_code}: {len(data)} records")
except Exception as e:
print(f" - {ts_code}: Error fetching - {e}")
# 合并样本数据
sample_df = (
pd.concat(sample_data_list, ignore_index=True)
if sample_data_list
else pd.DataFrame()
)
# 基于样本估算总记录数
if not sample_df.empty:
avg_records_per_stock = len(sample_df) / len(sample_data_list)
estimated_records = int(avg_records_per_stock * stock_count)
else:
estimated_records = 0
# 显示预览结果
print("\n" + "=" * 60)
print("[DailySync] Preview Result")
print("=" * 60)
print(f" Sync Mode: {mode.upper()}")
print(f" Date Range: {sync_start_date} to {end_date}")
print(f" Stocks to Sync: {stock_count}")
print(f" Sample Stocks Checked: {len(sample_data_list)}/{sample_size}")
print(f" Estimated Total Records: ~{estimated_records:,}")
if not sample_df.empty:
print(f"\n Sample Data Preview (first {len(sample_df)} rows):")
print(" " + "-" * 56)
# 以紧凑格式显示样本数据
preview_cols = [
"ts_code",
"trade_date",
"open",
"high",
"low",
"close",
"vol",
]
available_cols = [c for c in preview_cols if c in sample_df.columns]
sample_display = sample_df[available_cols].head(10)
for idx, row in sample_display.iterrows():
print(f" {row.to_dict()}")
print(" " + "-" * 56)
print("=" * 60)
return {
"sync_needed": True,
"stock_count": stock_count,
"start_date": sync_start_date,
"end_date": end_date,
"estimated_records": estimated_records,
"sample_data": sample_df,
"mode": mode,
}
def sync_single_stock(
self,
ts_code: str,
start_date: str,
end_date: str,
) -> pd.DataFrame:
"""同步单只股票的日线数据。
Args:
ts_code: 股票代码
start_date: 起始日期YYYYMMDD
end_date: 结束日期YYYYMMDD
Returns:
包含日线市场数据的 DataFrame
"""
# 检查是否应该停止同步(用于异常处理)
if not self._stop_flag.is_set():
return pd.DataFrame()
try:
# 使用共享客户端进行跨线程速率限制
data = self.client.query(
"pro_bar",
ts_code=ts_code,
start_date=start_date,
end_date=end_date,
factors="tor,vr",
)
return data
except Exception as e:
# 设置停止标志以通知其他线程停止
self._stop_flag.clear()
print(f"[ERROR] Exception syncing {ts_code}: {e}")
raise
def sync_all(
self,
force_full: bool = False,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
max_workers: Optional[int] = None,
dry_run: bool = False,
) -> Dict[str, pd.DataFrame]:
"""同步本地存储中所有股票的日线数据。
该函数:
1. 从本地存储读取股票代码daily 或 stock_basic
2. 检查交易日历确定是否需要同步:
- 若本地数据匹配交易日历边界,则跳过同步(节省 token
- 否则,从本地最后日期+1 同步到最新交易日(带宽优化)
3. 使用多线程并发获取(带速率限制)
4. 跳过返回空数据的股票(退市/不可用)
5. 遇异常立即停止
Args:
force_full: 若为 True强制从 20180101 完整重载
start_date: 手动指定起始日期(覆盖自动检测)
end_date: 手动指定结束日期(默认为今天)
max_workers: 工作线程数(默认: 10
dry_run: 若为 True仅预览将要同步的内容不写入数据
Returns:
映射 ts_code 到 DataFrame 的字典(若跳过或 dry_run 则为空字典)
"""
print("\n" + "=" * 60)
print("[DailySync] Starting daily data sync...")
print("=" * 60)
# 首先确保交易日历缓存是最新的(使用增量同步)
print("[DailySync] Syncing trade calendar cache...")
sync_trade_cal_cache()
# 确定日期范围
if end_date is None:
end_date = get_today_date()
# 基于交易日历检查是否需要同步
sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full)
if not sync_needed:
# 跳过同步 - 不消耗 token
print("\n" + "=" * 60)
print("[DailySync] Sync Summary")
print("=" * 60)
print(" Sync: SKIPPED (local data up-to-date with trade calendar)")
print(" Tokens saved: 0 consumed")
print("=" * 60)
return {}
# 使用 check_sync_needed 返回的日期(会计算增量起始日期)
if cal_start and cal_end:
sync_start_date = cal_start
end_date = cal_end
else:
# 回退到默认逻辑
sync_start_date = start_date or DEFAULT_START_DATE
if end_date is None:
end_date = get_today_date()
# 确定同步模式
if force_full:
mode = "full"
print(f"[DailySync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
elif local_last and cal_start and sync_start_date == get_next_date(local_last):
mode = "incremental"
print(f"[DailySync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
print(f"[DailySync] Sync from: {sync_start_date} to {end_date}")
else:
mode = "partial"
print(f"[DailySync] Mode: SYNC from {sync_start_date} to {end_date}")
# 获取所有股票代码
stock_codes = self.get_all_stock_codes()
if not stock_codes:
print("[DailySync] No stocks found to sync")
return {}
print(f"[DailySync] Total stocks to sync: {len(stock_codes)}")
print(f"[DailySync] Using {max_workers or self.max_workers} worker threads")
# 处理 dry run 模式
if dry_run:
print("\n" + "=" * 60)
print("[DailySync] DRY RUN MODE - No data will be written")
print("=" * 60)
print(f" Would sync {len(stock_codes)} stocks")
print(f" Date range: {sync_start_date} to {end_date}")
print(f" Mode: {mode}")
print("=" * 60)
return {}
# 为新同步重置停止标志
self._stop_flag.set()
# 多线程并发获取
results: Dict[str, pd.DataFrame] = {}
error_occurred = False
exception_to_raise = None
def sync_task(ts_code: str) -> tuple[str, pd.DataFrame]:
"""每只股票的任务函数。"""
try:
data = self.sync_single_stock(
ts_code=ts_code,
start_date=sync_start_date,
end_date=end_date,
)
return (ts_code, data)
except Exception as e:
# 重新抛出以被 Future 捕获
raise
# 使用 ThreadPoolExecutor 进行并发获取
workers = max_workers or self.max_workers
with ThreadPoolExecutor(max_workers=workers) as executor:
# 提交所有任务并跟踪 futures 与股票代码的映射
future_to_code = {
executor.submit(sync_task, ts_code): ts_code for ts_code in stock_codes
}
# 使用 as_completed 处理结果
error_count = 0
empty_count = 0
success_count = 0
# 创建进度条
pbar = tqdm(total=len(stock_codes), desc="Syncing stocks")
try:
# 处理完成的 futures
for future in as_completed(future_to_code):
ts_code = future_to_code[future]
try:
_, data = future.result()
if data is not None and not data.empty:
results[ts_code] = data
success_count += 1
else:
# 空数据 - 股票可能已退市或不可用
empty_count += 1
print(
f"[DailySync] Stock {ts_code}: empty data (skipped, may be delisted)"
)
except Exception as e:
# 发生异常 - 停止全部并中止
error_occurred = True
exception_to_raise = e
print(f"\n[ERROR] Sync aborted due to exception: {e}")
# 关闭 executor 以停止所有待处理任务
executor.shutdown(wait=False, cancel_futures=True)
raise exception_to_raise
# 更新进度条
pbar.update(1)
except Exception:
error_count = 1
print("[DailySync] Sync stopped due to exception")
finally:
pbar.close()
# 批量写入所有数据(仅在无错误时)
if results and not error_occurred:
for ts_code, data in results.items():
if not data.empty:
self.storage.queue_save("daily", data)
# 一次性刷新所有排队写入
self.storage.flush()
total_rows = sum(len(df) for df in results.values())
print(f"\n[DailySync] Saved {total_rows} rows to storage")
# 摘要
print("\n" + "=" * 60)
print("[DailySync] Sync Summary")
print("=" * 60)
print(f" Total stocks: {len(stock_codes)}")
print(f" Updated: {success_count}")
print(f" Skipped (empty/delisted): {empty_count}")
print(
f" Errors: {error_count} (aborted on first error)"
if error_count
else " Errors: 0"
)
print(f" Date range: {sync_start_date} to {end_date}")
print("=" * 60)
return results
def sync_daily(
force_full: bool = False,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
max_workers: Optional[int] = None,
dry_run: bool = False,
) -> Dict[str, pd.DataFrame]:
"""同步所有股票的日线数据。
这是日线数据同步的主要入口点。
Args:
force_full: 若为 True强制从 20180101 完整重载
start_date: 手动指定起始日期YYYYMMDD
end_date: 手动指定结束日期(默认为今天)
max_workers: 工作线程数(默认: 10
dry_run: 若为 True仅预览将要同步的内容不写入数据
Returns:
映射 ts_code 到 DataFrame 的字典
Example:
>>> # 首次同步(从 20180101 全量加载)
>>> result = sync_daily()
>>>
>>> # 后续同步(增量 - 仅新数据)
>>> result = sync_daily()
>>>
>>> # 强制完整重载
>>> result = sync_daily(force_full=True)
>>>
>>> # 手动指定日期范围
>>> result = sync_daily(start_date='20240101', end_date='20240131')
>>>
>>> # 自定义线程数
>>> result = sync_daily(max_workers=20)
>>>
>>> # Dry run仅预览
>>> result = sync_daily(dry_run=True)
"""
sync_manager = DailySync(max_workers=max_workers)
return sync_manager.sync_all(
force_full=force_full,
start_date=start_date,
end_date=end_date,
dry_run=dry_run,
)
def preview_daily_sync(
force_full: bool = False,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
sample_size: int = 3,
) -> dict:
"""预览日线同步数据量和样本(不实际同步)。
这是推荐的方式,可在实际同步前检查将要同步的内容。
Args:
force_full: 若为 True预览全量同步从 20180101
start_date: 手动指定起始日期(覆盖自动检测)
end_date: 手动指定结束日期(默认为今天)
sample_size: 预览用样本股票数量(默认: 3
Returns:
包含预览信息的字典:
{
'sync_needed': bool,
'stock_count': int,
'start_date': str,
'end_date': str,
'estimated_records': int,
'sample_data': pd.DataFrame,
'mode': str, # 'full', 'incremental', 'partial', 或 'none'
}
Example:
>>> # 预览将要同步的内容
>>> preview = preview_daily_sync()
>>>
>>> # 预览全量同步
>>> preview = preview_daily_sync(force_full=True)
>>>
>>> # 预览更多样本
>>> preview = preview_daily_sync(sample_size=5)
"""
sync_manager = DailySync()
return sync_manager.preview_sync(
force_full=force_full,
start_date=start_date,
end_date=end_date,
sample_size=sample_size,
)

src/data/api_wrappers/api_namechange.py (new file)

@@ -0,0 +1,113 @@
"""Stock name change history interface.
Fetch historical name change records for stocks.
This interface retrieves all historical name changes including name, dates, and change reasons.
"""
import pandas as pd
from pathlib import Path
from typing import Optional, List
from src.data.client import TushareClient
from src.data.config import get_config
# CSV file path for namechange data
def _get_csv_path() -> Path:
"""Get the CSV file path for namechange data."""
cfg = get_config()
return cfg.data_path_resolved / "namechange.csv"
def get_namechange(
ts_code: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
fields: Optional[List[str]] = None,
) -> pd.DataFrame:
"""Fetch stock name change history.
This interface retrieves historical name change records for stocks,
including name, start/end dates, announcement date, and change reason.
Args:
ts_code: TS stock code (optional, if not provided, returns all stocks)
start_date: Start date for announcement date range (YYYYMMDD)
end_date: End date for announcement date range (YYYYMMDD)
fields: Specific fields to return, None returns all fields
Returns:
pd.DataFrame with namechange information containing:
- ts_code: TS stock code
- name: Security name
- start_date: Start date of the name
- end_date: End date of the name
- ann_date: Announcement date
- change_reason: Reason for name change
"""
client = TushareClient()
# Build parameters
params = {}
if ts_code:
params["ts_code"] = ts_code
if start_date:
params["start_date"] = start_date
if end_date:
params["end_date"] = end_date
if fields:
params["fields"] = ",".join(fields)
# Fetch data
data = client.query("namechange", **params)
if data.empty:
print("[get_namechange] No data returned")
return data
def sync_namechange(force: bool = False) -> pd.DataFrame:
"""Fetch and save all stock name change records to local CSV.
This is a full load interface - fetches all historical name change records.
Each call fetches all data (no incremental sync).
Args:
force: If True, force re-fetch even if CSV exists (default: False)
Returns:
pd.DataFrame with all namechange records
"""
csv_path = _get_csv_path()
# Check if CSV file already exists
if csv_path.exists() and not force:
print(f"[sync_namechange] namechange.csv already exists at {csv_path}")
print("[sync_namechange] Use force=True to re-fetch")
return pd.read_csv(csv_path, encoding="utf-8-sig")
print("[sync_namechange] Fetching all stock name changes...")
# Fetch all namechange data (no parameters = all stocks, all history)
client = TushareClient()
data = client.query("namechange")
if data.empty:
print("[sync_namechange] No namechange data fetched")
return pd.DataFrame()
print(f"[sync_namechange] Fetched {len(data)} name change records")
# Save to CSV
data.to_csv(csv_path, index=False, encoding="utf-8-sig")
print(f"[sync_namechange] Saved {len(data)} records to {csv_path}")
return data
if __name__ == "__main__":
# Sync all namechange records to data folder
result = sync_namechange()
print(f"Total records synced: {len(result)}")
if not result.empty:
print("\nSample data:")
print(result.head(10))

src/data/sync.py

@@ -1,701 +1,34 @@
"""Data synchronization module.
"""数据同步调度中心模块。
This module provides data fetching functions with intelligent sync logic:
- If local file doesn't exist: fetch all data (full load from 20180101)
- If local file exists: incremental update (fetch from latest date + 1 day)
- Multi-threaded concurrent fetching for improved performance
- Stop immediately on any exception
- Preview mode: check data volume and samples before actual sync
该模块作为数据同步的调度中心,统一管理各类型数据的同步流程。
具体的同步逻辑已迁移到对应的 api_xxx.py 文件中:
- api_daily.py: 日线数据同步 (DailySync 类)
- api_bak_basic.py: 历史股票列表同步
- api_stock_basic.py: 股票基本信息同步
- api_trade_cal.py: 交易日历同步
Currently supported data types:
- daily: Daily market data (with turnover rate and volume ratio)
注意:名称变更 (namechange) 已从自动同步中移除,
因为股票名称变更不频繁,建议手动定期同步。
Usage:
# Preview sync (check data volume and samples without writing)
preview_sync()
使用方式:
# 预览同步(检查数据量,不写入)
from src.data.sync import preview_sync
preview = preview_sync()
# Sync all stocks (full load)
sync_all()
# 同步所有数据(不包括 namechange
from src.data.sync import sync_all_data
result = sync_all_data()
# Sync all stocks (incremental)
sync_all()
# Force full reload
sync_all(force_full=True)
# Dry run (preview only, no write)
sync_all(dry_run=True)
# 强制全量重载
result = sync_all_data(force_full=True)
"""
import pandas as pd
from typing import Optional, Dict, Callable
from datetime import datetime, timedelta
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
import sys
from src.data.client import TushareClient
from src.data.storage import ThreadSafeStorage
from src.data.api_wrappers import get_daily
from src.data.api_wrappers import (
get_first_trading_day,
get_last_trading_day,
sync_trade_cal_cache,
)
# Default full sync start date
DEFAULT_START_DATE = "20180101"
# Today's date in YYYYMMDD format
TODAY = datetime.now().strftime("%Y%m%d")
def get_today_date() -> str:
"""Get today's date in YYYYMMDD format."""
return TODAY
def get_next_date(date_str: str) -> str:
"""Get the next day after the given date.
Args:
date_str: Date in YYYYMMDD format
Returns:
Next date in YYYYMMDD format
"""
dt = datetime.strptime(date_str, "%Y%m%d")
next_dt = dt + timedelta(days=1)
return next_dt.strftime("%Y%m%d")
class DataSync:
"""Data synchronization manager with full/incremental sync support."""
# Default number of worker threads
DEFAULT_MAX_WORKERS = 10
def __init__(self, max_workers: Optional[int] = None):
"""Initialize sync manager.
Args:
max_workers: Number of worker threads (default: 10)
"""
self.storage = ThreadSafeStorage()
self.client = TushareClient()
self.max_workers = max_workers or self.DEFAULT_MAX_WORKERS
self._stop_flag = threading.Event()
self._stop_flag.set() # Initially not stopped
self._cached_daily_data: Optional[pd.DataFrame] = None # Cache for daily data
def _load_daily_data(self) -> pd.DataFrame:
"""Load daily data from storage with caching.
This method caches the daily data in memory to avoid repeated disk reads.
Call clear_cache() to force reload.
Returns:
DataFrame with daily data (cached or loaded from storage)
"""
if self._cached_daily_data is None:
self._cached_daily_data = self.storage.load("daily")
return self._cached_daily_data
def clear_cache(self) -> None:
"""Clear the cached daily data to force reload on next access."""
self._cached_daily_data = None
def get_all_stock_codes(self, only_listed: bool = True) -> list:
"""Get all stock codes from local storage.
This function prioritizes stock_basic.csv to ensure all stocks
are included for backtesting to avoid look-ahead bias.
Args:
only_listed: If True, only return currently listed stocks (L status).
Set to False to include delisted stocks (for full backtest).
Returns:
List of stock codes
"""
# Import sync_all_stocks here to avoid circular imports
from src.data.api_wrappers import sync_all_stocks
from src.data.api_wrappers.api_stock_basic import _get_csv_path
# First, ensure stock_basic.csv is up-to-date with all stocks
print("[DataSync] Ensuring stock_basic.csv is up-to-date...")
sync_all_stocks()
# Get from stock_basic.csv file
stock_csv_path = _get_csv_path()
if stock_csv_path.exists():
print(f"[DataSync] Reading stock_basic from CSV: {stock_csv_path}")
try:
stock_df = pd.read_csv(stock_csv_path, encoding="utf-8-sig")
if not stock_df.empty and "ts_code" in stock_df.columns:
# Filter by list_status if only_listed is True
if only_listed and "list_status" in stock_df.columns:
listed_stocks = stock_df[stock_df["list_status"] == "L"]
codes = listed_stocks["ts_code"].unique().tolist()
total = len(stock_df["ts_code"].unique())
print(
f"[DataSync] Found {len(codes)} listed stocks (filtered from {total} total)"
)
else:
codes = stock_df["ts_code"].unique().tolist()
print(
f"[DataSync] Found {len(codes)} stock codes from stock_basic.csv"
)
return codes
else:
print(
f"[DataSync] stock_basic.csv exists but no ts_code column or empty"
)
except Exception as e:
print(f"[DataSync] Error reading stock_basic.csv: {e}")
# Fallback: try daily storage if stock_basic not available (using cached data)
print("[DataSync] stock_basic.csv not available, falling back to daily data...")
daily_data = self._load_daily_data()
if not daily_data.empty and "ts_code" in daily_data.columns:
codes = daily_data["ts_code"].unique().tolist()
print(f"[DataSync] Found {len(codes)} stock codes from daily data")
return codes
print("[DataSync] No stock codes found in local storage")
return []
def get_global_last_date(self) -> Optional[str]:
"""Get the global last trade date across all stocks.
Returns:
Last trade date string or None
"""
daily_data = self._load_daily_data()
if daily_data.empty or "trade_date" not in daily_data.columns:
return None
return str(daily_data["trade_date"].max())
def get_global_first_date(self) -> Optional[str]:
"""Get the global first trade date across all stocks.
Returns:
First trade date string or None
"""
daily_data = self._load_daily_data()
if daily_data.empty or "trade_date" not in daily_data.columns:
return None
return str(daily_data["trade_date"].min())
def get_trade_calendar_bounds(
self, start_date: str, end_date: str
) -> tuple[Optional[str], Optional[str]]:
"""Get the first and last trading day from trade calendar.
Args:
start_date: Start date in YYYYMMDD format
end_date: End date in YYYYMMDD format
Returns:
Tuple of (first_trading_day, last_trading_day) or (None, None) if error
"""
try:
first_day = get_first_trading_day(start_date, end_date)
last_day = get_last_trading_day(start_date, end_date)
return (first_day, last_day)
except Exception as e:
print(f"[ERROR] Failed to get trade calendar bounds: {e}")
return (None, None)
def check_sync_needed(
self, force_full: bool = False
) -> tuple[bool, Optional[str], Optional[str], Optional[str]]:
"""Check if sync is needed based on trade calendar.
This method compares local data date range with trade calendar
to determine if new data needs to be fetched.
Logic:
- If force_full: sync needed, return (True, 20180101, today)
- If no local data: sync needed, return (True, 20180101, today)
- If local data exists:
- Get the last trading day from trade calendar
- If local last date >= calendar last date: NO sync needed
- Otherwise: sync needed from local_last_date + 1 to latest trade day
Args:
force_full: If True, always return sync needed
Returns:
Tuple of (sync_needed, start_date, end_date, local_last_date)
- sync_needed: True if sync should proceed, False to skip
- start_date: Sync start date (None if sync not needed)
- end_date: Sync end date (None if sync not needed)
- local_last_date: Local data last date (for incremental sync)
"""
# If force_full, always sync
if force_full:
print("[DataSync] Force full sync requested")
return (True, DEFAULT_START_DATE, get_today_date(), None)
# Check if local data exists (using cached data)
daily_data = self._load_daily_data()
if daily_data.empty or "trade_date" not in daily_data.columns:
print("[DataSync] No local data found, full sync needed")
return (True, DEFAULT_START_DATE, get_today_date(), None)
# Get local data last date (we only care about the latest date, not the first)
local_last_date = str(daily_data["trade_date"].max())
print(f"[DataSync] Local data last date: {local_last_date}")
# Get the latest trading day from trade calendar
today = get_today_date()
_, cal_last = self.get_trade_calendar_bounds(DEFAULT_START_DATE, today)
if cal_last is None:
print("[DataSync] Failed to get trade calendar, proceeding with sync")
return (True, DEFAULT_START_DATE, today, local_last_date)
print(f"[DataSync] Calendar last trading day: {cal_last}")
# Compare local last date with calendar last date
# If local data is already up-to-date or newer, no sync needed
print(
f"[DataSync] Comparing: local={local_last_date} (type={type(local_last_date).__name__}), cal={cal_last} (type={type(cal_last).__name__})"
)
try:
local_last_int = int(local_last_date)
cal_last_int = int(cal_last)
print(
f"[DataSync] Comparing integers: local={local_last_int} >= cal={cal_last_int} = {local_last_int >= cal_last_int}"
)
if local_last_int >= cal_last_int:
print(
"[DataSync] Local data is up-to-date, SKIPPING sync (no tokens consumed)"
)
return (False, None, None, None)
except (ValueError, TypeError) as e:
print(f"[ERROR] Date comparison failed: {e}")
# Need to sync from local_last_date + 1 to latest trade day
sync_start = get_next_date(local_last_date)
print(f"[DataSync] Incremental sync needed from {sync_start} to {cal_last}")
return (True, sync_start, cal_last, local_last_date)
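check_sync_needed 的核心是一段可以单测的纯日期比较逻辑。下面是去掉 I/O 后的示意(日期均为 YYYYMMDD 字符串,decide_sync/_next_day 为示例命名,默认起始日取 20180101
from datetime import datetime, timedelta
def _next_day(d: str) -> str:
    return (datetime.strptime(d, "%Y%m%d") + timedelta(days=1)).strftime("%Y%m%d")
def decide_sync(local_last, cal_last, default_start="20180101"):
    if local_last is None:
        return (True, default_start)      # 无本地数据 -> 全量同步
    if int(local_last) >= int(cal_last):
        return (False, None)              # 本地已最新 -> 跳过,不消耗积分
    return (True, _next_day(local_last))  # 否则从本地末日的下一天开始增量
assert decide_sync(None, "20240105") == (True, "20180101")
assert decide_sync("20240104", "20240105") == (True, "20240105")
assert decide_sync("20240105", "20240105") == (False, None)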
def preview_sync(
self,
force_full: bool = False,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
sample_size: int = 3,
) -> dict:
"""Preview sync data volume and samples without actually syncing.
This method provides a preview of what would be synced, including:
- Number of stocks to be synced
- Date range for sync
- Estimated total records
- Sample data from first few stocks
Args:
force_full: If True, preview full sync from 20180101
start_date: Manual start date (overrides auto-detection)
end_date: Manual end date (defaults to today)
sample_size: Number of sample stocks to fetch for preview (default: 3)
Returns:
Dictionary with preview information:
{
'sync_needed': bool,
'stock_count': int,
'start_date': str,
'end_date': str,
'estimated_records': int,
'sample_data': pd.DataFrame,
'mode': str, # 'full' or 'incremental'
}
"""
print("\n" + "=" * 60)
print("[DataSync] Preview Mode - Analyzing sync requirements...")
print("=" * 60)
# First, ensure trade calendar cache is up-to-date
print("[DataSync] Syncing trade calendar cache...")
sync_trade_cal_cache()
# Determine date range
if end_date is None:
end_date = get_today_date()
# Check if sync is needed
sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full)
if not sync_needed:
print("\n" + "=" * 60)
print("[DataSync] Preview Result")
print("=" * 60)
print(" Sync Status: NOT NEEDED")
print(" Reason: Local data is up-to-date with trade calendar")
print("=" * 60)
return {
"sync_needed": False,
"stock_count": 0,
"start_date": None,
"end_date": None,
"estimated_records": 0,
"sample_data": pd.DataFrame(),
"mode": "none",
}
# Use dates from check_sync_needed
if cal_start and cal_end:
sync_start_date = cal_start
end_date = cal_end
else:
sync_start_date = start_date or DEFAULT_START_DATE
if end_date is None:
end_date = get_today_date()
# Determine sync mode
if force_full:
mode = "full"
print(f"[DataSync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
elif local_last and cal_start and sync_start_date == get_next_date(local_last):
mode = "incremental"
print(f"[DataSync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
print(f"[DataSync] Sync from: {sync_start_date} to {end_date}")
else:
mode = "partial"
print(f"[DataSync] Mode: SYNC from {sync_start_date} to {end_date}")
# Get all stock codes
stock_codes = self.get_all_stock_codes()
if not stock_codes:
print("[DataSync] No stocks found to sync")
return {
"sync_needed": False,
"stock_count": 0,
"start_date": None,
"end_date": None,
"estimated_records": 0,
"sample_data": pd.DataFrame(),
"mode": "none",
}
stock_count = len(stock_codes)
print(f"[DataSync] Total stocks to sync: {stock_count}")
# Fetch sample data from first few stocks
print(f"[DataSync] Fetching sample data from {sample_size} stocks...")
sample_data_list = []
sample_codes = stock_codes[:sample_size]
for ts_code in sample_codes:
try:
data = self.client.query(
"pro_bar",
ts_code=ts_code,
start_date=sync_start_date,
end_date=end_date,
factors="tor,vr",
)
if not data.empty:
sample_data_list.append(data)
print(f" - {ts_code}: {len(data)} records")
except Exception as e:
print(f" - {ts_code}: Error fetching - {e}")
# Combine sample data
sample_df = (
pd.concat(sample_data_list, ignore_index=True)
if sample_data_list
else pd.DataFrame()
)
# Estimate total records based on sample
if not sample_df.empty:
avg_records_per_stock = len(sample_df) / len(sample_data_list)
estimated_records = int(avg_records_per_stock * stock_count)
else:
estimated_records = 0
# Display preview results
print("\n" + "=" * 60)
print("[DataSync] Preview Result")
print("=" * 60)
print(f" Sync Mode: {mode.upper()}")
print(f" Date Range: {sync_start_date} to {end_date}")
print(f" Stocks to Sync: {stock_count}")
print(f" Sample Stocks Checked: {len(sample_data_list)}/{sample_size}")
print(f" Estimated Total Records: ~{estimated_records:,}")
if not sample_df.empty:
print(f"\n Sample Data Preview (first {len(sample_df)} rows):")
print(" " + "-" * 56)
# Display sample data in a compact format
preview_cols = [
"ts_code",
"trade_date",
"open",
"high",
"low",
"close",
"vol",
]
available_cols = [c for c in preview_cols if c in sample_df.columns]
sample_display = sample_df[available_cols].head(10)
for idx, row in sample_display.iterrows():
print(f" {row.to_dict()}")
print(" " + "-" * 56)
print("=" * 60)
return {
"sync_needed": True,
"stock_count": stock_count,
"start_date": sync_start_date,
"end_date": end_date,
"estimated_records": estimated_records,
"sample_data": sample_df,
"mode": mode,
}
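estimated_records 的外推方式很直接:样本股平均行数乘以总股票数。一个两行算例(数字均为假设值):
sample_rows = [61, 58, 60]                                   # 假设 3 只样本股的返回行数
estimated = int(sum(sample_rows) / len(sample_rows) * 5000)  # 假设共 5000 只股票
print(estimated)                                             # 298333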
def sync_single_stock(
self,
ts_code: str,
start_date: str,
end_date: str,
) -> pd.DataFrame:
"""Sync daily data for a single stock.
Args:
ts_code: Stock code
start_date: Start date (YYYYMMDD)
end_date: End date (YYYYMMDD)
Returns:
DataFrame with daily market data
"""
# 运行标志_stop_flag置位表示允许继续被 clear 说明已有线程触发中止
if not self._stop_flag.is_set():
return pd.DataFrame()
try:
# Use shared client for rate limiting across threads
data = self.client.query(
"pro_bar",
ts_code=ts_code,
start_date=start_date,
end_date=end_date,
factors="tor,vr",
)
return data
except Exception as e:
# 清除运行标志,通知其他线程停止(注意:这里是 clear而非 set
self._stop_flag.clear()
print(f"[ERROR] Exception syncing {ts_code}: {e}")
raise
def sync_all(
self,
force_full: bool = False,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
max_workers: Optional[int] = None,
dry_run: bool = False,
) -> Dict[str, pd.DataFrame]:
"""Sync daily data for all stocks in local storage.
This function:
1. Reads stock codes from local storage (daily or stock_basic)
2. Checks trade calendar to determine if sync is needed:
- If local data matches trade calendar bounds, SKIP sync (save tokens)
- Otherwise, sync from local_last_date + 1 to latest trade day (bandwidth optimized)
3. Uses multi-threaded concurrent fetching with rate limiting
4. Skips updating stocks that return empty data (delisted/unavailable)
5. Stops immediately on any exception
Args:
force_full: If True, force full reload from 20180101
start_date: Manual start date (overrides auto-detection)
end_date: Manual end date (defaults to today)
max_workers: Number of worker threads (default: 10)
dry_run: If True, only preview what would be synced without writing data
Returns:
Dict mapping ts_code to DataFrame (empty if sync skipped or dry_run)
"""
print("\n" + "=" * 60)
print("[DataSync] Starting daily data sync...")
print("=" * 60)
# First, ensure trade calendar cache is up-to-date (uses incremental sync)
print("[DataSync] Syncing trade calendar cache...")
sync_trade_cal_cache()
# Determine date range
if end_date is None:
end_date = get_today_date()
# Check if sync is needed based on trade calendar
sync_needed, cal_start, cal_end, local_last = self.check_sync_needed(force_full)
if not sync_needed:
# Sync skipped - no tokens consumed
print("\n" + "=" * 60)
print("[DataSync] Sync Summary")
print("=" * 60)
print(" Sync: SKIPPED (local data up-to-date with trade calendar)")
print(" Tokens saved: 0 consumed")
print("=" * 60)
return {}
# Use dates from check_sync_needed (which calculates incremental start if needed)
if cal_start and cal_end:
sync_start_date = cal_start
end_date = cal_end
else:
# Fallback to default logic
sync_start_date = start_date or DEFAULT_START_DATE
if end_date is None:
end_date = get_today_date()
# Determine sync mode
if force_full:
mode = "full"
print(f"[DataSync] Mode: FULL SYNC from {sync_start_date} to {end_date}")
elif local_last and cal_start and sync_start_date == get_next_date(local_last):
mode = "incremental"
print(f"[DataSync] Mode: INCREMENTAL SYNC (bandwidth optimized)")
print(f"[DataSync] Sync from: {sync_start_date} to {end_date}")
else:
mode = "partial"
print(f"[DataSync] Mode: SYNC from {sync_start_date} to {end_date}")
# Get all stock codes
stock_codes = self.get_all_stock_codes()
if not stock_codes:
print("[DataSync] No stocks found to sync")
return {}
print(f"[DataSync] Total stocks to sync: {len(stock_codes)}")
print(f"[DataSync] Using {max_workers or self.max_workers} worker threads")
# Handle dry run mode
if dry_run:
print("\n" + "=" * 60)
print("[DataSync] DRY RUN MODE - No data will be written")
print("=" * 60)
print(f" Would sync {len(stock_codes)} stocks")
print(f" Date range: {sync_start_date} to {end_date}")
print(f" Mode: {mode}")
print("=" * 60)
return {}
# Reset stop flag for new sync
self._stop_flag.set()
# Multi-threaded concurrent fetching
results: Dict[str, pd.DataFrame] = {}
error_occurred = False
exception_to_raise = None
def sync_task(ts_code: str) -> tuple[str, pd.DataFrame]:
"""Task function for each stock."""
try:
data = self.sync_single_stock(
ts_code=ts_code,
start_date=sync_start_date,
end_date=end_date,
)
return (ts_code, data)
except Exception:
# Re-raise to be caught by Future
raise
# Use ThreadPoolExecutor for concurrent fetching
workers = max_workers or self.max_workers
with ThreadPoolExecutor(max_workers=workers) as executor:
# Submit all tasks and track futures with their stock codes
future_to_code = {
executor.submit(sync_task, ts_code): ts_code for ts_code in stock_codes
}
# Process results using as_completed
error_count = 0
empty_count = 0
success_count = 0
# Create progress bar
pbar = tqdm(total=len(stock_codes), desc="Syncing stocks")
try:
# Process futures as they complete
for future in as_completed(future_to_code):
ts_code = future_to_code[future]
try:
_, data = future.result()
if data is not None and not data.empty:
results[ts_code] = data
success_count += 1
else:
# Empty data - stock may be delisted or unavailable
empty_count += 1
print(
f"[DataSync] Stock {ts_code}: empty data (skipped, may be delisted)"
)
except Exception as e:
# Exception occurred - stop all and abort
error_occurred = True
exception_to_raise = e
print(f"\n[ERROR] Sync aborted due to exception: {e}")
# Shutdown executor to stop all pending tasks
executor.shutdown(wait=False, cancel_futures=True)
raise exception_to_raise
# Update progress bar
pbar.update(1)
except Exception:
error_count = 1
print("[DataSync] Sync stopped due to exception")
finally:
pbar.close()
# Queue all data for batch write (only if no error)
if results and not error_occurred:
for ts_code, data in results.items():
if not data.empty:
self.storage.queue_save("daily", data)
# Flush all queued writes at once
self.storage.flush()
total_rows = sum(len(df) for df in results.values())
print(f"\n[DataSync] Saved {total_rows} rows to storage")
# Summary
print("\n" + "=" * 60)
print("[DataSync] Sync Summary")
print("=" * 60)
print(f" Total stocks: {len(stock_codes)}")
print(f" Updated: {success_count}")
print(f" Skipped (empty/delisted): {empty_count}")
print(
f" Errors: {error_count} (aborted on first error)"
if error_count
else " Errors: 0"
)
print(f" Date range: {sync_start_date} to {end_date}")
print("=" * 60)
return results
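sync_all 中"任一任务异常即整体中止"的并发模式可以脱离本项目独立复现。下面是一个自包含的最小示意fetch 为假设的任务函数,非项目 APIcancel_futures 需要 Python 3.9+
from concurrent.futures import ThreadPoolExecutor, as_completed
def fetch(code: str) -> str:  # 假设的任务函数
    if code == "BAD":
        raise RuntimeError(f"fetch failed: {code}")
    return f"data-{code}"
codes = ["000001.SZ", "000002.SZ", "BAD"]
results = {}
with ThreadPoolExecutor(max_workers=4) as ex:
    futures = {ex.submit(fetch, c): c for c in codes}
    try:
        for fut in as_completed(futures):
            results[futures[fut]] = fut.result()      # 异常在此重新抛出
    except Exception as exc:
        ex.shutdown(wait=False, cancel_futures=True)  # 取消尚未开始的任务
        print(f"aborted: {exc}")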
# Convenience functions
from src.data.api_wrappers import sync_all_stocks
from src.data.api_wrappers.api_daily import sync_daily, preview_daily_sync
def preview_sync(
@@ -705,20 +38,19 @@ def preview_sync(
sample_size: int = 3,
max_workers: Optional[int] = None,
) -> dict:
"""Preview sync data volume and samples without actually syncing.
"""预览日线同步数据量和样本(不实际同步)。
This is the recommended way to check what would be synced before
running the actual synchronization.
这是推荐的方式,可在实际同步前检查将要同步的内容。
Args:
force_full: If True, preview full sync from 20180101
start_date: Manual start date (overrides auto-detection)
end_date: Manual end date (defaults to today)
sample_size: Number of sample stocks to fetch for preview (default: 3)
max_workers: Number of worker threads (not used in preview, for API compatibility)
force_full: 若为 True,预览全量同步(从 20180101
start_date: 手动指定起始日期(覆盖自动检测)
end_date: 手动指定结束日期(默认为今天)
sample_size: 预览用样本股票数量(默认: 3
max_workers: 工作线程数(默认: 10
Returns:
Dictionary with preview information:
包含预览信息的字典:
{
'sync_needed': bool,
'stock_count': int,
@@ -726,21 +58,20 @@ def preview_sync(
'end_date': str,
'estimated_records': int,
'sample_data': pd.DataFrame,
'mode': str, # 'full', 'incremental', 'partial', or 'none'
'mode': str, # 'full', 'incremental', 'partial', 'none'
}
Example:
>>> # Preview what would be synced
>>> # 预览将要同步的内容
>>> preview = preview_sync()
>>>
>>> # Preview full sync
>>> # 预览全量同步
>>> preview = preview_sync(force_full=True)
>>>
>>> # Preview with more samples
>>> # 预览更多样本
>>> preview = preview_sync(sample_size=5)
"""
sync_manager = DataSync(max_workers=max_workers)
return sync_manager.preview_sync(
return preview_daily_sync(
force_full=force_full,
start_date=start_date,
end_date=end_date,
@@ -755,54 +86,168 @@ def sync_all(
max_workers: Optional[int] = None,
dry_run: bool = False,
) -> Dict[str, pd.DataFrame]:
"""Sync daily data for all stocks.
"""同步所有股票的日线数据。
This is the main entry point for data synchronization.
这是日线数据同步的主要入口点。
Args:
force_full: If True, force full reload from 20180101
start_date: Manual start date (YYYYMMDD)
end_date: Manual end date (defaults to today)
max_workers: Number of worker threads (default: 10)
dry_run: If True, only preview what would be synced without writing data
force_full: 若为 True,强制从 20180101 完整重载
start_date: 手动指定起始日期(YYYYMMDD
end_date: 手动指定结束日期(默认为今天)
max_workers: 工作线程数(默认: 10
dry_run: 若为 True,仅预览将要同步的内容,不写入数据
Returns:
Dict mapping ts_code to DataFrame
映射 ts_code DataFrame 的字典
Example:
>>> # First time sync (full load from 20180101)
>>> # 首次同步(从 20180101 全量加载)
>>> result = sync_all()
>>>
>>> # Subsequent sync (incremental - only new data)
>>> # 后续同步(增量 - 仅新数据)
>>> result = sync_all()
>>>
>>> # Force full reload
>>> # 强制完整重载
>>> result = sync_all(force_full=True)
>>>
>>> # Manual date range
>>> # 手动指定日期范围
>>> result = sync_all(start_date='20240101', end_date='20240131')
>>>
>>> # Custom thread count
>>> # 自定义线程数
>>> result = sync_all(max_workers=20)
>>>
>>> # Dry run (preview only)
>>> # Dry run(仅预览)
>>> result = sync_all(dry_run=True)
"""
sync_manager = DataSync(max_workers=max_workers)
return sync_manager.sync_all(
return sync_daily(
force_full=force_full,
start_date=start_date,
end_date=end_date,
max_workers=max_workers,
dry_run=dry_run,
)
def sync_all_data(
force_full: bool = False,
max_workers: Optional[int] = None,
dry_run: bool = False,
) -> Dict[str, pd.DataFrame]:
"""同步所有数据类型(每日同步)。
该函数按顺序同步所有可用的数据类型:
1. 交易日历 (sync_trade_cal_cache)
2. 股票基本信息 (sync_all_stocks)
3. 日线市场数据 (sync_all)
4. 历史股票列表 (sync_bak_basic)
注意:名称变更 (namechange) 不在自动同步中,如需同步请手动调用。
Args:
force_full: 若为 True强制所有数据类型完整重载
max_workers: 日线数据同步的工作线程数(默认: 10
dry_run: 若为 True仅显示将要同步的内容不写入数据
Returns:
映射数据类型到同步结果的字典
Example:
>>> # 同步所有数据(增量)
>>> result = sync_all_data()
>>>
>>> # 强制完整重载
>>> result = sync_all_data(force_full=True)
>>>
>>> # Dry run
>>> result = sync_all_data(dry_run=True)
"""
results: Dict[str, pd.DataFrame] = {}
print("\n" + "=" * 60)
print("[sync_all_data] Starting full data synchronization...")
print("=" * 60)
# 1. Sync trade calendar (always needed first)
print("\n[1/5] Syncing trade calendar cache...")
try:
from src.data.api_wrappers import sync_trade_cal_cache
sync_trade_cal_cache()
results["trade_cal"] = pd.DataFrame()
print("[1/5] Trade calendar: OK")
except Exception as e:
print(f"[1/5] Trade calendar: FAILED - {e}")
results["trade_cal"] = pd.DataFrame()
# 2. Sync stock basic info
print("\n[2/5] Syncing stock basic info...")
try:
sync_all_stocks()
results["stock_basic"] = pd.DataFrame()
print("[2/5] Stock basic: OK")
except Exception as e:
print(f"[2/5] Stock basic: FAILED - {e}")
results["stock_basic"] = pd.DataFrame()
# 3. Sync daily market data
print("\n[3/5] Syncing daily market data...")
try:
daily_result = sync_daily(
force_full=force_full,
max_workers=max_workers,
dry_run=dry_run,
)
results["daily"] = (
pd.concat(daily_result.values(), ignore_index=True)
if daily_result
else pd.DataFrame()
)
print("[3/5] Daily data: OK")
except Exception as e:
print(f"[3/5] Daily data: FAILED - {e}")
results["daily"] = pd.DataFrame()
# 4. Sync stock historical list (bak_basic)
print("\n[4/5] Syncing stock historical list (bak_basic)...")
try:
bak_basic_result = sync_bak_basic(force_full=force_full)
results["bak_basic"] = bak_basic_result
print(f"[4/5] Bak basic: OK ({len(bak_basic_result)} records)")
except Exception as e:
print(f"[4/5] Bak basic: FAILED - {e}")
results["bak_basic"] = pd.DataFrame()
# Summary
print("\n" + "=" * 60)
print("[sync_all_data] Sync Summary")
print("=" * 60)
for data_type, df in results.items():
print(f" {data_type}: {len(df)} records")
print("=" * 60)
print("\nNote: namechange is NOT in auto-sync. To sync manually:")
print(" from src.data.api_wrappers import sync_namechange")
print(" sync_namechange(force=True)")
return results
# 保留向后兼容的导入
from src.data.api_wrappers import sync_bak_basic
if __name__ == "__main__":
print("=" * 60)
print("Data Sync Module")
print("=" * 60)
print("\nUsage:")
print(" # Sync all data types at once (RECOMMENDED)")
print(" from src.data.sync import sync_all_data")
print(" result = sync_all_data() # Incremental sync all")
print(" result = sync_all_data(force_full=True) # Full reload")
print("")
print(" # Or sync individual data types:")
print(" from src.data.sync import sync_all, preview_sync")
print(" from src.data.sync import sync_bak_basic")
print("")
print(" # Preview before sync (recommended)")
print(" preview = preview_sync()")
@@ -813,21 +258,14 @@ if __name__ == "__main__":
print(" # Actual sync")
print(" result = sync_all() # Incremental sync")
print(" result = sync_all(force_full=True) # Full reload")
print("")
print(" # bak_basic sync")
print(" result = sync_bak_basic() # Incremental sync")
print(" result = sync_bak_basic(force_full=True) # Full reload")
print("\n" + "=" * 60)
# Run preview first
print("\n[Main] Running preview first...")
preview = preview_sync()
if preview["sync_needed"]:
# Ask for confirmation
print("\n" + "=" * 60)
response = input("Proceed with sync? (y/n): ").strip().lower()
if response in ("y", "yes"):
print("\n[Main] Starting actual sync...")
result = sync_all()
print(f"\nSynced {len(result)} stocks")
else:
print("\n[Main] Sync cancelled by user")
else:
print("\n[Main] No sync needed - data is up to date")
# Run sync_all_data by default
print("\n[Main] Running sync_all_data()...")
result = sync_all_data()
print("\n[Main] Sync completed!")
print(f"Total data types synced: {len(result)}")

src/data/utils.py (new file, +75)

@@ -0,0 +1,75 @@
"""Data module utility functions.
集中管理数据模块中常用的工具函数,避免重复定义。
"""
from datetime import datetime, timedelta
from typing import Optional
# 默认全量同步开始日期
DEFAULT_START_DATE = "20180101"
# 今日日期 (YYYYMMDD 格式)
TODAY: str = datetime.now().strftime("%Y%m%d")
def get_today_date() -> str:
"""获取今日日期YYYYMMDD 格式)。
Returns:
今日日期字符串,格式为 YYYYMMDD
"""
return TODAY
def get_next_date(date_str: str) -> str:
"""获取给定日期的下一天。
Args:
date_str: YYYYMMDD 格式的日期
Returns:
YYYYMMDD 格式的下一天日期
"""
dt = datetime.strptime(date_str, "%Y%m%d")
next_dt = dt + timedelta(days=1)
return next_dt.strftime("%Y%m%d")
def get_prev_date(date_str: str) -> str:
"""获取给定日期的前一天。
Args:
date_str: YYYYMMDD 格式的日期
Returns:
YYYYMMDD 格式的前一天日期
"""
dt = datetime.strptime(date_str, "%Y%m%d")
prev_dt = dt - timedelta(days=1)
return prev_dt.strftime("%Y%m%d")
def parse_date(date_str: str) -> datetime:
"""解析 YYYYMMDD 格式的日期字符串。
Args:
date_str: YYYYMMDD 格式的日期
Returns:
datetime 对象
"""
return datetime.strptime(date_str, "%Y%m%d")
def format_date(dt: datetime) -> str:
"""将 datetime 对象格式化为 YYYYMMDD 字符串。
Args:
dt: datetime 对象
Returns:
YYYYMMDD 格式的日期字符串
"""
return dt.strftime("%Y%m%d")
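这些日期工具都是纯函数,跨月、跨年和闰年边界可以直接用断言验证(以下为用法示意):
assert get_next_date("20240131") == "20240201"  # 跨月
assert get_next_date("20241231") == "20250101"  # 跨年
assert get_prev_date("20240301") == "20240229"  # 闰年二月
assert format_date(parse_date("20240229")) == "20240229"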

src/factors/__init__.py

@@ -18,28 +18,29 @@
- CompositeFactor: 组合因子
- ScalarFactor: 标量运算因子
动量因子momentum/
- MovingAverageFactor: 移动平均线(时序因子)
- ReturnRankFactor: 收益率排名(截面因子)
财务因子financial/
- (待添加)
数据加载和执行Phase 3-4
- DataLoader: 数据加载器
- FactorEngine: 因子执行引擎
使用示例:
from src.factors import DataSpec, FactorContext, FactorData
from src.factors import CrossSectionalFactor, TimeSeriesFactor
# 使用通用因子(参数化)
from src.factors import MovingAverageFactor, ReturnRankFactor
from src.factors import DataLoader, FactorEngine
# 定义数据需求
spec = DataSpec(
source="daily",
columns=["ts_code", "trade_date", "close"],
lookback_days=20
)
ma5 = MovingAverageFactor(period=5) # 5日MA
ma10 = MovingAverageFactor(period=10) # 10日MA
ret5 = ReturnRankFactor(period=5) # 5日收益率排名
# 初始化引擎
loader = DataLoader(data_dir="data")
engine = FactorEngine(loader)
# 计算因子
result = engine.compute(factor, start_date="20240101", end_date="20240131")
result = engine.compute(ma5, stock_codes=["000001.SZ"], start_date="20240101", end_date="20240131")
"""
from src.factors.data_spec import DataSpec, FactorContext, FactorData
@@ -48,6 +49,9 @@ from src.factors.composite import CompositeFactor, ScalarFactor
from src.factors.data_loader import DataLoader
from src.factors.engine import FactorEngine
# 动量因子
from src.factors.momentum import MovingAverageFactor, ReturnRankFactor
__all__ = [
# Phase 1: 数据类型定义
"DataSpec",
@@ -62,4 +66,7 @@ __all__ = [
# Phase 3-4: 数据加载和执行引擎
"DataLoader",
"FactorEngine",
# 动量因子
"MovingAverageFactor",
"ReturnRankFactor",
]

src/factors/financial/__init__.py (new file, +15)

@@ -0,0 +1,15 @@
"""财务因子模块
本模块提供财务类型的因子:
因子分类:
- financial: 财务因子
- (待添加)
待添加因子:
- PERankFactor: 市盈率排名
- PBFactor: 市净率因子
- DividendFactor: 股息率因子
"""
__all__ = []

src/factors/momentum/__init__.py (new file, +19)

@@ -0,0 +1,19 @@
"""动量因子模块
本模块提供动量类型的因子:
- MovingAverageFactor: 移动平均线(时序因子)
- ReturnRankFactor: 收益率排名(截面因子)
因子分类:
- momentum: 动量因子
- ma: 移动平均线
- return_rank: 收益率排名
"""
from src.factors.momentum.ma import MovingAverageFactor
from src.factors.momentum.return_rank import ReturnRankFactor
__all__ = [
"MovingAverageFactor",
"ReturnRankFactor",
]

src/factors/momentum/ma.py (new file, +78)

@@ -0,0 +1,78 @@
"""动量因子 - 移动平均线
本模块提供通用移动平均线因子,支持参数化配置:
- MovingAverageFactor: 移动平均线(时序因子)
使用示例:
>>> from src.factors.momentum import MovingAverageFactor
>>> ma5 = MovingAverageFactor(period=5) # 5日MA
>>> ma10 = MovingAverageFactor(period=10) # 10日MA
>>> ma20 = MovingAverageFactor(period=20) # 20日MA
"""
from typing import List
import polars as pl
from src.factors.base import TimeSeriesFactor
from src.factors.data_spec import DataSpec, FactorData
class MovingAverageFactor(TimeSeriesFactor):
"""移动平均线因子
计算逻辑对每只股票计算其过去n日收盘价的移动平均值。
特点:
- 参数化因子:训练时通过 period 参数指定计算窗口
- 时序因子:每只股票单独计算,防止股票间数据泄露
Attributes:
period: MA计算期天数默认5
Example:
>>> ma5 = MovingAverageFactor(period=5)
>>> # 计算过去5日的收盘价均值
"""
name: str = "ma"
factor_type: str = "time_series"
category: str = "momentum"
description: str = "移动平均线因子计算过去n日收盘价的均值"
data_specs: List[DataSpec] = [
DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=5)
]
def __init__(self, period: int = 5):
"""初始化因子
Args:
period: MA计算期天数默认5日
"""
super().__init__(period=period)
# 重新创建 DataSpec 以设置正确的 lookback_daysDataSpec 是 frozen 的)
self.data_specs = [
DataSpec(
"daily",
["ts_code", "trade_date", "close"],
lookback_days=period,
)
]
self.name = f"ma_{period}"
def compute(self, data: FactorData) -> pl.Series:
"""计算移动平均线
Args:
data: FactorData包含单只股票的完整时间序列
Returns:
移动平均值序列
"""
# 获取收盘价序列
close_prices = data.get_column("close")
# 计算移动平均
ma = close_prices.rolling_mean(window_size=self.params["period"])
return ma
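rolling_mean 在窗口数据不足时返回 null这正是训练数据在 FillNA(0) 之前缺失值的来源之一。行为示意:
import polars as pl
s = pl.Series("close", [10.0, 11.0, 12.0, 13.0, 14.0])
print(s.rolling_mean(window_size=3).to_list())
# [None, None, 11.0, 12.0, 13.0](前两行窗口不足,结果为 null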

src/factors/momentum/return_rank.py (new file, +100)

@@ -0,0 +1,100 @@
"""动量因子 - 收益率排名
本模块提供收益率排名因子:
- ReturnRankFactor: 过去n日收益率的rank因子截面因子
使用示例:
>>> from src.factors.momentum import ReturnRankFactor
>>> ret5 = ReturnRankFactor(period=5) # 5日收益率排名
>>> ret10 = ReturnRankFactor(period=10) # 10日收益率排名
"""
from typing import List
import polars as pl
from src.factors.base import CrossSectionalFactor
from src.factors.data_spec import DataSpec, FactorData
class ReturnRankFactor(CrossSectionalFactor):
"""过去n日收益率排名因子
计算逻辑每个交易日计算所有股票过去n日的收益率然后进行截面排名。
特点:
- 参数化因子:训练时通过 period 参数指定计算窗口
- 截面因子:每天对所有股票进行横向排名,防止日期泄露
Attributes:
period: 收益率计算期默认5日
Example:
>>> ret5 = ReturnRankFactor(period=5)
>>> # 每个交易日返回所有股票过去5日收益率的排名
"""
name: str = "return_rank"
factor_type: str = "cross_sectional"
category: str = "momentum"
description: str = "过去n日收益率的截面排名因子"
data_specs: List[DataSpec] = [
DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=5)
]
def __init__(self, period: int = 5):
"""初始化因子
Args:
period: 收益率计算期(天数)
"""
super().__init__(period=period)
# 重新创建 DataSpec 以设置正确的 lookback_daysDataSpec 是 frozen 的)
self.data_specs = [
DataSpec(
"daily",
["ts_code", "trade_date", "close"],
lookback_days=period + 1,
)
]
self.name = f"return_{period}_rank"
def compute(self, data: FactorData) -> pl.Series:
"""计算过去n日收益率排名
Args:
data: FactorData包含过去n+1天的截面数据
Returns:
过去n日收益率的截面排名0-1之间
"""
# 转为 Polars DataFrame包含最近 period+1 个交易日的截面数据)
cs = data.to_polars()
# 获取所有交易日期(已按日期排序)
trade_dates = cs["trade_date"].unique().sort()
if len(trade_dates) < self.params["period"] + 1:
# 数据不足(不足 period+1 个交易日),返回空排名
return pl.Series(name=self.name, values=[])
# 获取最新日期的数据
latest_date = trade_dates[-1]
current_data = cs.filter(pl.col("trade_date") == latest_date)
# 获取n天前的日期
n_days_ago = trade_dates[-(self.params["period"] + 1)]
past_data = cs.filter(pl.col("trade_date") == n_days_ago)
# 通过 ts_code join 计算收益率
merged = current_data.select(["ts_code", "close"]).join(
past_data.select(["ts_code", "close"]).rename({"close": "close_past"}),
on="ts_code",
how="inner",
)
# 计算收益率
returns = (merged["close"] - merged["close_past"]) / merged["close_past"]
# 返回排名0-1之间
return returns.rank(method="average") / len(returns)
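截面排名的归一化即 rank / N结果落在 (0, 1] 区间,收益最高者为 1.0。一个三元素算例数值为假设
import polars as pl
returns = pl.Series("ret", [0.05, -0.02, 0.01])
rank = returns.rank(method="average") / len(returns)
print(rank.to_list())  # [1.0, 0.333..., 0.666...]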

src/models/__init__.py → src/pipeline/__init__.py

@@ -1,9 +1,10 @@
"""ProStock 模型训练框架
"""ProStock ML Pipeline 组件库
组件化低耦合插件式的机器学习训练框架
提供组件化低耦合插件式的机器学习流水线组件
包括处理器模型划分策略等可复用组件
示例:
>>> from src.models import (
>>> from src.pipeline import (
... PluginRegistry, ProcessingPipeline,
... PipelineStage, BaseProcessor
... )
@@ -21,7 +22,7 @@
"""
# 导入核心抽象类和划分策略
from src.models.core import (
from src.pipeline.core import (
PipelineStage,
TaskType,
BaseProcessor,
@@ -34,13 +35,13 @@ from src.models.core import (
)
# 导入注册中心
from src.models.registry import PluginRegistry
from src.pipeline.registry import PluginRegistry
# 导入处理流水线
from src.models.pipeline import ProcessingPipeline
from src.pipeline.pipeline import ProcessingPipeline
# 导入并注册内置处理器
from src.models.processors.processors import (
from src.pipeline.processors.processors import (
DropNAProcessor,
FillNAProcessor,
Winsorizer,
@@ -51,7 +52,7 @@ from src.models.processors.processors import (
)
# 导入并注册内置模型
from src.models.models.models import (
from src.pipeline.models.models import (
LightGBMModel,
CatBoostModel,
)

src/models/core/__init__.py → src/pipeline/core/__init__.py

@@ -1,6 +1,6 @@
"""核心模块导出"""
from src.models.core.base import (
from src.pipeline.core.base import (
PipelineStage,
TaskType,
BaseProcessor,
@@ -9,7 +9,7 @@ from src.models.core.base import (
BaseMetric,
)
from src.models.core.splitter import (
from src.pipeline.core.splitter import (
TimeSeriesSplit,
WalkForwardSplit,
ExpandingWindowSplit,

src/models/core/splitter.py → src/pipeline/core/splitter.py

@@ -6,7 +6,7 @@
from typing import Iterator, List, Tuple
import polars as pl
from src.models.core.base import BaseSplitter
from src.pipeline.core.base import BaseSplitter
class TimeSeriesSplit(BaseSplitter):

src/models/models/__init__.py → src/pipeline/models/__init__.py

@@ -1,6 +1,6 @@
"""模型模块"""
from src.models.models.models import (
from src.pipeline.models.models import (
LightGBMModel,
CatBoostModel,
)

src/models/models/models.py → src/pipeline/models/models.py

@@ -7,8 +7,8 @@ from typing import Optional, Dict, Any
import polars as pl
import numpy as np
from src.models.core import BaseModel, TaskType
from src.models.registry import PluginRegistry
from src.pipeline.core import BaseModel, TaskType
from src.pipeline.registry import PluginRegistry
@PluginRegistry.register_model("lightgbm")

src/models/pipeline.py → src/pipeline/pipeline.py

@@ -6,7 +6,7 @@
from typing import List, Dict
import polars as pl
from src.models.core import BaseProcessor, PipelineStage
from src.pipeline.core import BaseProcessor, PipelineStage
class ProcessingPipeline:

src/models/processors/__init__.py → src/pipeline/processors/__init__.py

@@ -1,6 +1,6 @@
"""处理器模块"""
from src.models.processors.processors import (
from src.pipeline.processors.processors import (
DropNAProcessor,
FillNAProcessor,
Winsorizer,

src/models/processors/processors.py → src/pipeline/processors/processors.py

@@ -7,8 +7,8 @@ from typing import List, Optional, Dict, Any
import polars as pl
import numpy as np
from src.models.core import BaseProcessor, PipelineStage
from src.models.registry import PluginRegistry
from src.pipeline.core import BaseProcessor, PipelineStage
from src.pipeline.registry import PluginRegistry
# 数值类型列表
FLOAT_TYPES = [pl.Float32, pl.Float64, pl.Int8, pl.Int16, pl.Int32, pl.Int64]

src/models/registry.py → src/pipeline/registry.py

@@ -17,7 +17,7 @@ from functools import wraps
from weakref import WeakValueDictionary
import contextlib
from src.models.core import BaseProcessor, BaseModel, BaseSplitter, BaseMetric
from src.pipeline.core import BaseProcessor, BaseModel, BaseSplitter, BaseMetric
T = TypeVar("T")

src/training/__init__.py (new file, +46)

@@ -0,0 +1,46 @@
"""ProStock 训练流程模块
本模块提供完整的模型训练流程:
1. 数据处理Fillna(0) -> Dropna
2. 模型训练LightGBM分类模型
3. 预测选股每日top5股票池
使用示例:
from src.training import run_training
# 运行完整训练流程
result = run_training(
train_start="20180101",
train_end="20230101",
test_start="20230101",
test_end="20240101",
top_n=5,
output_path="output/top_stocks.tsv"
)
因子使用:
from src.factors import MovingAverageFactor, ReturnRankFactor
ma5 = MovingAverageFactor(period=5) # 5日移动平均
ma10 = MovingAverageFactor(period=10) # 10日移动平均
ret5 = ReturnRankFactor(period=5) # 5日收益率排名
"""
from src.training.pipeline import (
create_pipeline,
predict_top_stocks,
prepare_data,
run_training,
save_top_stocks,
train_model,
)
__all__ = [
# 管道函数
"prepare_data",
"create_pipeline",
"train_model",
"predict_top_stocks",
"save_top_stocks",
"run_training",
]

src/training/main.py (new file, +27)

@@ -0,0 +1,27 @@
"""训练流程入口脚本
运行方式:
uv run python -m src.training.main
或:
uv run python src/training/main.py
"""
from src.training.pipeline import run_training
if __name__ == "__main__":
# 运行完整训练流程
# 训练集20190101 - 20250101
# 测试集20250101 - 20260101
result = run_training(
train_start="20190101",
train_end="20250101",
test_start="20250101",
test_end="20260101",
top_n=5,
output_path="output/top_stocks.tsv",
)
print("\n[Result] Top stocks selection:")
print(result)

(File diff suppressed because it is too large)

src/training/pipeline.py (new file, +448)

@@ -0,0 +1,448 @@
"""训练管道 - 包含数据处理、模型训练和预测功能
本模块提供:
1. 数据准备:从因子计算结果中准备训练/测试数据
2. 数据处理Fillna(0) -> Dropna
3. 模型训练使用LightGBM训练分类模型
4. 预测和选股输出每日top5股票池
"""
from datetime import datetime
from pathlib import Path
from typing import List, Optional, Tuple
import numpy as np
import polars as pl
from src.factors import DataLoader, FactorEngine
from src.factors.data_spec import DataSpec
from src.pipeline import (
DropNAProcessor,
FillNAProcessor,
LightGBMModel,
PipelineStage,
ProcessingPipeline,
TaskType,
)
def prepare_data(
data_dir: str = "data",
train_start: str = "20180101",
train_end: str = "20230101",
test_start: str = "20230101",
test_end: str = "20240101",
) -> Tuple[pl.DataFrame, pl.DataFrame]:
"""准备训练和测试数据
从DuckDB加载原始日线数据计算所需因子并生成标签。
Args:
data_dir: 数据目录
train_start: 训练集开始日期
train_end: 训练集结束日期
test_start: 测试集开始日期
test_end: 测试集结束日期
Returns:
(train_data, test_data): 训练集和测试集的DataFrame
"""
from src.data.storage import Storage
storage = Storage()
# 加载日线数据:训练集需要额外历史数据用于计算因子 lookback
# 约 20 个交易日即可覆盖 MA10 与 5 日收益率,这里简化为往前多取一年YYYYMMDD 整数减 10000
start_with_lookback = str(int(train_start) - 10000)
# 查询训练集数据
# 注意DuckDB 中 trade_date 是 DATE 类型,需要转换
start_dt = f"{start_with_lookback[:4]}-{start_with_lookback[4:6]}-{start_with_lookback[6:8]}"
end_dt = f"{train_end[:4]}-{train_end[4:6]}-{train_end[6:8]}"
train_query = f"""
SELECT ts_code, trade_date, close, pre_close
FROM daily
WHERE trade_date >= '{start_dt}' AND trade_date <= '{end_dt}'
ORDER BY ts_code, trade_date
"""
train_raw = storage._connection.sql(train_query).pl()
# 转换 trade_date 为字符串格式
train_raw = train_raw.with_columns(
pl.col("trade_date").dt.strftime("%Y-%m-%d").alias("trade_date")
)
# 查询测试集数据(也需要历史数据计算因子)
test_start_dt = f"{test_start[:4]}-{test_start[4:6]}-{test_start[6:8]}"
test_end_dt = f"{test_end[:4]}-{test_end[4:6]}-{test_end[6:8]}"
test_query = f"""
SELECT ts_code, trade_date, close, pre_close
FROM daily
WHERE trade_date >= '{test_start_dt}' AND trade_date <= '{test_end_dt}'
ORDER BY ts_code, trade_date
"""
test_raw = storage._connection.sql(test_query).pl()
# 转换 trade_date 为字符串格式
test_raw = test_raw.with_columns(
pl.col("trade_date").dt.strftime("%Y-%m-%d").alias("trade_date")
)
# 过滤不符合条件的股票
train_raw = _filter_invalid_stocks(train_raw)
test_raw = _filter_invalid_stocks(test_raw)
print(f"[PrepareData] After filtering: train={len(train_raw)}, test={len(test_raw)}")
# 计算因子和标签
train_data = _compute_features_and_label(train_raw, train_start, train_end)
test_data = _compute_features_and_label(test_raw, test_start, test_end)
return train_data, test_data
def _filter_invalid_stocks(df: pl.DataFrame) -> pl.DataFrame:
"""过滤不符合条件的股票
过滤规则:
1. 过滤北交所股票ts_code 以 BJ 结尾)
2. 过滤创业板股票ts_code 以 30 开头)
3. 过滤科创板股票ts_code 以 68 开头)
4. 过滤退市/风险股票ts_code 以 8 开头)
Args:
df: 原始数据
Returns:
过滤后的数据
"""
ts_code_col = pl.col("ts_code")
return df.filter(
~ts_code_col.str.ends_with("BJ")
& ~ts_code_col.str.starts_with("30")
& ~ts_code_col.str.starts_with("68")
& ~ts_code_col.str.starts_with("8")
)
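过滤规则可以用几行样例数据直接验证(以下为板块前缀规则的示意,股票代码为假设值):
import polars as pl
df = pl.DataFrame({"ts_code": ["000001.SZ", "300001.SZ", "688001.SH", "830001.BJ"]})
kept = df.filter(
    ~pl.col("ts_code").str.ends_with("BJ")
    & ~pl.col("ts_code").str.starts_with("30")
    & ~pl.col("ts_code").str.starts_with("68")
    & ~pl.col("ts_code").str.starts_with("8")
)
print(kept["ts_code"].to_list())  # ['000001.SZ']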
def _compute_features_and_label(
raw_data: pl.DataFrame,
start_date: str,
end_date: str,
) -> pl.DataFrame:
"""计算因子和标签
因子:
1. return_5_rank: 收益率截面排名(注意:当前实现基于当日收益率)
2. ma_5: 5日移动平均
3. ma_10: 10日移动平均
标签未来5日收益率大于0为1否则为0
Args:
raw_data: 原始日线数据
start_date: 开始日期
end_date: 结束日期
Returns:
包含因子和标签的DataFrame
"""
# 确保按日期排序
raw_data = raw_data.sort(["ts_code", "trade_date"])
# 计算收益率未来5日
raw_data = raw_data.with_columns(
[
# 当日收益率
((pl.col("close") - pl.col("pre_close")) / pl.col("pre_close")).alias(
"daily_return"
),
]
)
# 按股票分组计算
result_list = []
for ts_code in raw_data["ts_code"].unique():
stock_data = raw_data.filter(pl.col("ts_code") == ts_code).sort("trade_date")
if len(stock_data) < 20:
continue
# 计算MA5和MA10
stock_data = stock_data.with_columns(
[
pl.col("close").rolling_mean(5).alias("ma_5"),
pl.col("close").rolling_mean(10).alias("ma_10"),
]
)
# 计算未来5日收益率用于标签
future_return = stock_data["close"].shift(-5) - stock_data["close"]
future_return_pct = future_return / stock_data["close"]
stock_data = stock_data.with_columns(
[
future_return_pct.alias("future_return_5"),
]
)
# 生成标签:收益率>0为1否则为0
stock_data = stock_data.with_columns(
[
(pl.col("future_return_5") > 0).cast(pl.Int8).alias("label"),
]
)
result_list.append(stock_data)
if not result_list:
return pl.DataFrame()
result = pl.concat(result_list)
# 转换日期格式YYYYMMDD -> YYYY-MM-DD
start_date_formatted = f"{start_date[:4]}-{start_date[4:6]}-{start_date[6:8]}"
end_date_formatted = f"{end_date[:4]}-{end_date[4:6]}-{end_date[6:8]}"
# 过滤有效日期范围
result = result.filter(
(pl.col("trade_date") >= start_date_formatted) & (pl.col("trade_date") <= end_date_formatted)
)
# 计算收益率截面排名注意此处对当日收益率daily_return排名特征名沿用 return_5_rank
result = result.with_columns(
[
pl.col("daily_return")
.rank(method="average")
.over("trade_date")
.alias("return_5_rank")
]
)
# 归一化排名到0-1
result = result.with_columns(
[
(
pl.col("return_5_rank")
/ pl.col("return_5_rank").max().over("trade_date")
).alias("return_5_rank")
]
)
# 选择需要的列
feature_cols = ["trade_date", "ts_code", "return_5_rank", "ma_5", "ma_10", "label"]
result = result.select(feature_cols)
return result
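标签构造依赖 shift(-5):每只股票末尾 5 行取不到未来价格,标签为 null后续由 FillNA 或标签过滤处理)。最小算例(价格为假设值):
import polars as pl
close = pl.Series("close", [10.0, 10.5, 10.2, 10.8, 11.0, 10.9, 11.2])
future_ret = (close.shift(-5) - close) / close
label = (future_ret > 0).cast(pl.Int8)
print(label.to_list())  # [1, 1, None, None, None, None, None]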
def create_pipeline() -> ProcessingPipeline:
"""创建数据处理流水线
处理流程:
1. FillNA(0): 将缺失值填充为0
注意:不使用 Dropna因为会导致训练和预测时的行数不匹配
Returns:
配置好的ProcessingPipeline
"""
processors = [
FillNAProcessor(method="zero"), # 缺失值填充为0
]
return ProcessingPipeline(processors)
def train_model(
train_data: pl.DataFrame,
feature_cols: List[str],
label_col: str = "label",
model_params: Optional[dict] = None,
) -> Tuple[LightGBMModel, ProcessingPipeline]:
"""训练LightGBM分类模型
Args:
train_data: 训练数据
feature_cols: 特征列名列表
label_col: 标签列名
model_params: 模型参数字典
Returns:
(训练好的模型, 处理流水线)
"""
# 创建处理流水线
pipeline = create_pipeline()
print("[TrainModel] Pipeline created: FillNA(0)")
# 准备特征和标签
X_train = train_data.select(feature_cols)
y_train = train_data[label_col]
print(f"[TrainModel] Raw samples: {len(X_train)}, features: {feature_cols}")
# 处理数据
X_train_processed = pipeline.fit_transform(X_train, stage=PipelineStage.TRAIN)
print(f"[TrainModel] After processing: {len(X_train_processed)} samples")
# 过滤有效标签(排除-1等无效值
valid_mask = y_train.is_in([0, 1])
X_train_processed = X_train_processed.filter(valid_mask)
y_train = y_train.filter(valid_mask)
print(f"[TrainModel] After filtering valid labels: {len(X_train_processed)} samples")
print(f"[TrainModel] Label distribution: {dict(y_train.value_counts().sort('label').iter_rows())}")
# 创建模型
params = model_params or {
"n_estimators": 100,
"learning_rate": 0.05,
"max_depth": 5,
"num_leaves": 31,
}
print(f"[TrainModel] Model params: {params}")
model = LightGBMModel(
task_type="classification",
params=params,
)
# 训练模型
print("[TrainModel] Training LightGBM...")
model.fit(X_train_processed, y_train)
print("[TrainModel] Training completed!")
return model, pipeline
def predict_top_stocks(
model: LightGBMModel,
pipeline: ProcessingPipeline,
test_data: pl.DataFrame,
feature_cols: List[str],
top_n: int = 5,
) -> pl.DataFrame:
"""预测并选出每日top N股票
Args:
model: 训练好的模型
pipeline: 数据处理流水线
test_data: 测试数据
feature_cols: 特征列名
top_n: 每日选出的股票数量
Returns:
包含日期和股票代码的DataFrame
"""
# 准备特征和必要列
X_test = test_data.select(feature_cols)
key_cols = ["trade_date", "ts_code"]
key_data = test_data.select(key_cols)
print(f"[Predict] Test samples: {len(X_test)}, top_n: {top_n}")
# 处理数据(使用训练阶段的参数)
X_test_processed = pipeline.transform(X_test, stage=PipelineStage.TEST)
print(f"[Predict] Data processed, shape: {X_test_processed.shape}")
# 预测概率
probs = model.predict_proba(X_test_processed)
print(f"[Predict] Predictions generated, probability shape: {probs.shape}")
# 使用 key_data 添加预测结果,保持行数一致
result = key_data.with_columns(
pl.Series(
name="pred_prob", values=probs[:, 1] if len(probs.shape) > 1 and probs.shape[1] > 1 else probs.flatten()
),
)
# 每日选出top N
top_stocks = []
for date in result["trade_date"].unique().sort():
day_data = result.filter(pl.col("trade_date") == date)
# 按概率降序排序选出top N
day_top = day_data.sort("pred_prob", descending=True).head(top_n)
top_stocks.append(day_top.select(["trade_date", "pred_prob", "ts_code"]).rename({"pred_prob": "score"}))
return pl.concat(top_stocks)
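逐日循环取 top N 也可以用 group_by + head 一次完成。下面是行为等价的写法示意(假设 result 含 trade_date、ts_code、pred_prob 三列daily_top_n 为示例命名,仅供参考,非原实现):
import polars as pl
def daily_top_n(result: pl.DataFrame, top_n: int = 5) -> pl.DataFrame:
    return (
        result.sort(["trade_date", "pred_prob"], descending=[False, True])
        .group_by("trade_date", maintain_order=True)
        .head(top_n)  # 每个日期组取前 top_n 行
        .rename({"pred_prob": "score"})
        .select(["trade_date", "score", "ts_code"])
    )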
def save_top_stocks(top_stocks: pl.DataFrame, output_path: str) -> None:
"""保存选股结果到TSV文件
Args:
top_stocks: 选股结果
output_path: 输出文件路径
"""
# 转换为pandas并保存为TSV
df = top_stocks.to_pandas()
df.to_csv(output_path, sep="\t", index=False)
print(f"[Training] Top stocks saved to: {output_path}")
def run_training(
data_dir: str = "data",
output_path: str = "output/top_stocks.tsv",
train_start: str = "20180101",
train_end: str = "20230101",
test_start: str = "20230101",
test_end: str = "20240101",
top_n: int = 5,
) -> pl.DataFrame:
"""运行完整训练流程
Args:
data_dir: 数据目录
output_path: 输出文件路径
train_start: 训练集开始日期
train_end: 训练集结束日期
test_start: 测试集开始日期
test_end: 测试集结束日期
top_n: 每日选股数量
Returns:
选股结果DataFrame
"""
print(f"[Training] Starting training pipeline...")
print(f"[Training] Train period: {train_start} -> {train_end}")
print(f"[Training] Test period: {test_start} -> {test_end}")
# 1. 准备数据
print("[Training] Preparing data...")
train_data, test_data = prepare_data(
data_dir=data_dir,
train_start=train_start,
train_end=train_end,
test_start=test_start,
test_end=test_end,
)
print(f"[Training] Train samples: {len(train_data)}")
print(f"[Training] Test samples: {len(test_data)}")
# 2. 定义特征列
feature_cols = ["return_5_rank", "ma_5", "ma_10"]
label_col = "label"
# 3. 训练模型
print("[Training] Training model...")
model, pipeline = train_model(
train_data=train_data,
feature_cols=feature_cols,
label_col=label_col,
)
# 4. 测试集预测
print("[Training] Predicting on test set...")
top_stocks = predict_top_stocks(
model=model,
pipeline=pipeline,
test_data=test_data,
feature_cols=feature_cols,
top_n=top_n,
)
# 5. 保存结果
print(f"[Training] Saving results to {output_path}...")
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
save_top_stocks(top_stocks, output_path)
print("[Training] Training completed!")
return top_stocks


@@ -1,4 +1,4 @@
"""模型框架核心测试
"""Pipeline 组件库核心测试
测试核心抽象类插件注册中心处理器模型和划分策略
"""
@@ -9,7 +9,7 @@ import numpy as np
from typing import List, Optional
# 确保导入时注册所有组件
from src.models import (
from src.pipeline import (
PluginRegistry,
PipelineStage,
BaseProcessor,
@@ -17,7 +17,7 @@ from src.models import (
BaseSplitter,
ProcessingPipeline,
)
from src.models.core import TaskType
from src.pipeline.core import TaskType
# ========== 测试核心抽象类 ==========
@@ -232,7 +232,7 @@ class TestBuiltInProcessors:
def test_dropna_processor(self):
"""测试缺失值删除处理器"""
from src.models.processors import DropNAProcessor
from src.pipeline.processors import DropNAProcessor
processor = DropNAProcessor(columns=["a", "b"])
df = pl.DataFrame({"a": [1, None, 3], "b": [4, 5, None], "c": [7, 8, 9]})
@@ -246,7 +246,7 @@ class TestBuiltInProcessors:
def test_fillna_processor(self):
"""测试缺失值填充处理器"""
from src.models.processors import FillNAProcessor
from src.pipeline.processors import FillNAProcessor
processor = FillNAProcessor(columns=["a"], method="mean")
df = pl.DataFrame({"a": [1.0, 2.0, None, 4.0]})
@@ -258,7 +258,7 @@ class TestBuiltInProcessors:
def test_standard_scaler(self):
"""测试标准化处理器"""
from src.models.processors import StandardScaler
from src.pipeline.processors import StandardScaler
processor = StandardScaler(columns=["value"])
df = pl.DataFrame({"value": [1.0, 2.0, 3.0, 4.0, 5.0]})
@@ -271,7 +271,7 @@ class TestBuiltInProcessors:
def test_winsorizer(self):
"""测试缩尾处理器"""
from src.models.processors import Winsorizer
from src.pipeline.processors import Winsorizer
processor = Winsorizer(columns=["value"], lower=0.1, upper=0.9)
df = pl.DataFrame(
@@ -288,7 +288,7 @@ class TestBuiltInProcessors:
def test_rank_transformer(self):
"""测试排名转换处理器"""
from src.models.processors import RankTransformer
from src.pipeline.processors import RankTransformer
processor = RankTransformer(columns=["value"])
df = pl.DataFrame(
@@ -302,7 +302,7 @@ class TestBuiltInProcessors:
def test_neutralizer(self):
"""测试中性化处理器"""
from src.models.processors import Neutralizer
from src.pipeline.processors import Neutralizer
processor = Neutralizer(columns=["value"], group_col="industry")
df = pl.DataFrame(
@@ -331,7 +331,7 @@ class TestProcessingPipeline:
def test_pipeline_fit_transform(self):
"""测试流水线的 fit_transform"""
from src.models.processors import StandardScaler
from src.pipeline.processors import StandardScaler
scaler1 = StandardScaler(columns=["a"])
scaler2 = StandardScaler(columns=["b"])
@@ -348,7 +348,7 @@ class TestProcessingPipeline:
def test_pipeline_transform_uses_fitted_params(self):
"""测试 transform 使用已 fit 的参数"""
from src.models.processors import StandardScaler
from src.pipeline.processors import StandardScaler
scaler = StandardScaler(columns=["value"])
pipeline = ProcessingPipeline([scaler])
@@ -383,7 +383,7 @@ class TestSplitters:
def test_time_series_split(self):
"""测试时间序列划分"""
from src.models.core import TimeSeriesSplit
from src.pipeline.core import TimeSeriesSplit
splitter = TimeSeriesSplit(n_splits=2, gap=1, min_train_size=3)
@@ -406,7 +406,7 @@ class TestSplitters:
def test_walk_forward_split(self):
"""测试滚动前向划分"""
from src.models.core import WalkForwardSplit
from src.pipeline.core import WalkForwardSplit
splitter = WalkForwardSplit(train_window=5, test_window=2, gap=1)
@@ -426,7 +426,7 @@ class TestSplitters:
def test_expanding_window_split(self):
"""测试扩展窗口划分"""
from src.models.core import ExpandingWindowSplit
from src.pipeline.core import ExpandingWindowSplit
splitter = ExpandingWindowSplit(initial_train_size=3, test_window=2, gap=1)
@@ -455,7 +455,7 @@ class TestModels:
@pytest.mark.skip(reason="需要安装 lightgbm")
def test_lightgbm_model(self):
"""测试 LightGBM 模型"""
from src.models.models import LightGBMModel
from src.pipeline.models import LightGBMModel
model = LightGBMModel(task_type="regression", params={"n_estimators": 10})

tests/test_sync.py

@@ -1,277 +1,163 @@
"""Tests for data synchronization module.
"""Sync 接口测试规范与实现。
Tests the sync module's full/incremental sync logic for daily data:
- Full sync when local data doesn't exist (from 20180101)
- Incremental sync when local data exists (from last_date + 1)
- Data integrity validation
【测试规范】
1. 所有 sync 测试只使用 2018-01-01 到 2018-04-01 的数据
2. 只测试接口是否能正常返回数据,不测试落库逻辑
3. 对于按股票查询的接口,只测试 000001.SZ、000002.SZ 两支股票
4. 使用真实 API 调用,确保接口可用性
【测试范围】
- get_daily: 日线数据接口(按股票)
- sync_all_stocks: 股票基础信息接口
- sync_trade_cal_cache: 交易日历接口
- sync_namechange: 名称变更接口
- sync_bak_basic: 备用股票基础信息接口
"""
import pytest
import pandas as pd
from unittest.mock import Mock, patch, MagicMock
from datetime import datetime, timedelta
from datetime import datetime
from src.data.sync import (
DataSync,
sync_all,
get_today_date,
get_next_date,
DEFAULT_START_DATE,
)
from src.data.storage import ThreadSafeStorage
from src.data.client import TushareClient
# 测试用常量
TEST_START_DATE = "20180101"
TEST_END_DATE = "20180401"
TEST_STOCK_CODES = ["000001.SZ", "000002.SZ"]
@pytest.fixture
def mock_storage():
"""Create a mock storage instance."""
storage = Mock(spec=ThreadSafeStorage)
storage.exists = Mock(return_value=False)
storage.load = Mock(return_value=pd.DataFrame())
storage.save = Mock(return_value={"status": "success", "rows": 0})
return storage
class TestGetDaily:
"""测试日线数据 get 接口(按股票查询)."""
def test_get_daily_single_stock(self):
"""测试 get_daily 获取单只股票数据."""
from src.data.api_wrappers.api_daily import get_daily
@pytest.fixture
def mock_client():
"""Create a mock client instance."""
return Mock(spec=TushareClient)
class TestDateUtilities:
"""Test date utility functions."""
def test_get_today_date_format(self):
"""Test today date is in YYYYMMDD format."""
result = get_today_date()
assert len(result) == 8
assert result.isdigit()
def test_get_next_date(self):
"""Test getting next date."""
result = get_next_date("20240101")
assert result == "20240102"
def test_get_next_date_year_end(self):
"""Test getting next date across year boundary."""
result = get_next_date("20241231")
assert result == "20250101"
def test_get_next_date_month_end(self):
"""Test getting next date across month boundary."""
result = get_next_date("20240131")
assert result == "20240201"
class TestDataSync:
"""Test DataSync class functionality."""
def test_get_all_stock_codes_from_daily(self, mock_storage):
"""Test getting stock codes from daily data."""
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
sync = DataSync()
sync.storage = mock_storage
mock_storage.load.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ", "000001.SZ", "600000.SH"],
}
result = get_daily(
ts_code=TEST_STOCK_CODES[0],
start_date=TEST_START_DATE,
end_date=TEST_END_DATE,
)
codes = sync.get_all_stock_codes()
# 验证返回了数据
assert isinstance(result, pd.DataFrame), "get_daily 应返回 DataFrame"
assert not result.empty, "get_daily 应返回非空数据"
assert len(codes) == 2
assert "000001.SZ" in codes
assert "600000.SH" in codes
def test_get_daily_has_required_columns(self):
"""测试 get_daily 返回的数据包含必要字段."""
from src.data.api_wrappers.api_daily import get_daily
def test_get_all_stock_codes_fallback(self, mock_storage):
"""Test fallback to stock_basic when daily is empty."""
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
sync = DataSync()
sync.storage = mock_storage
# First call (daily) returns empty, second call (stock_basic) returns data
mock_storage.load.side_effect = [
pd.DataFrame(), # daily empty
pd.DataFrame({"ts_code": ["000001.SZ", "600000.SH"]}), # stock_basic
]
codes = sync.get_all_stock_codes()
assert len(codes) == 2
def test_get_global_last_date(self, mock_storage):
"""Test getting global last date."""
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
sync = DataSync()
sync.storage = mock_storage
mock_storage.load.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ", "600000.SH"],
"trade_date": ["20240102", "20240103"],
}
result = get_daily(
ts_code=TEST_STOCK_CODES[0],
start_date=TEST_START_DATE,
end_date=TEST_END_DATE,
)
last_date = sync.get_global_last_date()
assert last_date == "20240103"
# 验证必要的列存在
required_columns = ["ts_code", "trade_date", "open", "high", "low", "close"]
for col in required_columns:
assert col in result.columns, f"get_daily 返回应包含 {col}"
def test_get_global_last_date_empty(self, mock_storage):
"""Test getting last date from empty storage."""
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
sync = DataSync()
sync.storage = mock_storage
def test_get_daily_multiple_stocks(self):
"""测试 get_daily 获取多只股票数据."""
from src.data.api_wrappers.api_daily import get_daily
mock_storage.load.return_value = pd.DataFrame()
results = {}
for code in TEST_STOCK_CODES:
result = get_daily(
ts_code=code,
start_date=TEST_START_DATE,
end_date=TEST_END_DATE,
)
results[code] = result
assert isinstance(result, pd.DataFrame), (
f"get_daily({code}) 应返回 DataFrame"
)
assert not result.empty, f"get_daily({code}) 应返回非空数据"
last_date = sync.get_global_last_date()
assert last_date is None
def test_sync_single_stock(self, mock_storage):
"""Test syncing a single stock."""
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
with patch(
"src.data.sync.get_daily",
return_value=pd.DataFrame(
{
"ts_code": ["000001.SZ"],
"trade_date": ["20240102"],
}
),
):
sync = DataSync()
sync.storage = mock_storage
class TestSyncStockBasic:
"""测试股票基础信息 sync 接口."""
result = sync.sync_single_stock(
ts_code="000001.SZ",
start_date="20240101",
end_date="20240102",
def test_sync_all_stocks_returns_data(self):
"""测试 sync_all_stocks 是否能正常返回数据."""
from src.data.api_wrappers.api_stock_basic import sync_all_stocks
result = sync_all_stocks()
# 验证返回了数据
assert isinstance(result, pd.DataFrame), "sync_all_stocks 应返回 DataFrame"
assert not result.empty, "sync_all_stocks 应返回非空数据"
def test_sync_all_stocks_has_required_columns(self):
"""测试 sync_all_stocks 返回的数据包含必要字段."""
from src.data.api_wrappers.api_stock_basic import sync_all_stocks
result = sync_all_stocks()
# 验证必要的列存在
required_columns = ["ts_code"]
for col in required_columns:
assert col in result.columns, f"sync_all_stocks 返回应包含 {col}"
class TestSyncTradeCal:
"""测试交易日历 sync 接口."""
def test_sync_trade_cal_cache_returns_data(self):
"""测试 sync_trade_cal_cache 是否能正常返回数据."""
from src.data.api_wrappers.api_trade_cal import sync_trade_cal_cache
result = sync_trade_cal_cache(
start_date=TEST_START_DATE,
end_date=TEST_END_DATE,
)
assert isinstance(result, pd.DataFrame)
assert len(result) == 1
# 验证返回了数据
assert isinstance(result, pd.DataFrame), "sync_trade_cal_cache 应返回 DataFrame"
assert not result.empty, "sync_trade_cal_cache 应返回非空数据"
def test_sync_single_stock_empty(self, mock_storage):
"""Test syncing a stock with no data."""
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
with patch("src.data.sync.get_daily", return_value=pd.DataFrame()):
sync = DataSync()
sync.storage = mock_storage
def test_sync_trade_cal_cache_has_required_columns(self):
"""测试 sync_trade_cal_cache 返回的数据包含必要字段."""
from src.data.api_wrappers.api_trade_cal import sync_trade_cal_cache
result = sync.sync_single_stock(
ts_code="INVALID.SZ",
start_date="20240101",
end_date="20240102",
result = sync_trade_cal_cache(
start_date=TEST_START_DATE,
end_date=TEST_END_DATE,
)
assert result.empty
# 验证必要的列存在
required_columns = ["cal_date", "is_open"]
for col in required_columns:
assert col in result.columns, f"sync_trade_cal_cache 返回应包含 {col}"
class TestSyncAll:
"""Test sync_all function."""
class TestSyncNamechange:
"""测试名称变更 sync 接口."""
def test_full_sync_mode(self, mock_storage):
"""Test full sync mode when force_full=True."""
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
with patch("src.data.sync.get_daily", return_value=pd.DataFrame()):
sync = DataSync()
sync.storage = mock_storage
sync.sync_single_stock = Mock(return_value=pd.DataFrame())
def test_sync_namechange_returns_data(self):
"""测试 sync_namechange 是否能正常返回数据."""
from src.data.api_wrappers.api_namechange import sync_namechange
mock_storage.load.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ"],
}
result = sync_namechange()
# 验证返回了数据(可能是空 DataFrame因为是历史变更
assert isinstance(result, pd.DataFrame), "sync_namechange 应返回 DataFrame"
class TestSyncBakBasic:
"""测试备用股票基础信息 sync 接口."""
def test_sync_bak_basic_returns_data(self):
"""测试 sync_bak_basic 是否能正常返回数据."""
from src.data.api_wrappers.api_bak_basic import sync_bak_basic
result = sync_bak_basic(
start_date=TEST_START_DATE,
end_date=TEST_END_DATE,
)
result = sync.sync_all(force_full=True)
# Verify sync_single_stock was called with default start date
sync.sync_single_stock.assert_called_once()
call_args = sync.sync_single_stock.call_args
assert call_args[1]["start_date"] == DEFAULT_START_DATE
def test_incremental_sync_mode(self, mock_storage):
"""Test incremental sync mode when data exists."""
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
sync = DataSync()
sync.storage = mock_storage
sync.sync_single_stock = Mock(return_value=pd.DataFrame())
# Mock existing data with last date
mock_storage.load.side_effect = [
pd.DataFrame(
{
"ts_code": ["000001.SZ"],
"trade_date": ["20240102"],
}
), # get_all_stock_codes
pd.DataFrame(
{
"ts_code": ["000001.SZ"],
"trade_date": ["20240102"],
}
), # get_global_last_date
]
result = sync.sync_all(force_full=False)
# Verify sync_single_stock was called with next date
sync.sync_single_stock.assert_called_once()
call_args = sync.sync_single_stock.call_args
assert call_args[1]["start_date"] == "20240103"
def test_manual_start_date(self, mock_storage):
"""Test sync with manual start date."""
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
sync = DataSync()
sync.storage = mock_storage
sync.sync_single_stock = Mock(return_value=pd.DataFrame())
mock_storage.load.return_value = pd.DataFrame(
{
"ts_code": ["000001.SZ"],
}
)
result = sync.sync_all(force_full=False, start_date="20230601")
sync.sync_single_stock.assert_called_once()
call_args = sync.sync_single_stock.call_args
assert call_args[1]["start_date"] == "20230601"
def test_no_stocks_found(self, mock_storage):
"""Test sync when no stocks are found."""
with patch("src.data.sync.ThreadSafeStorage", return_value=mock_storage):
sync = DataSync()
sync.storage = mock_storage
mock_storage.load.return_value = pd.DataFrame()
result = sync.sync_all()
assert result == {}
class TestSyncAllConvenienceFunction:
"""Test sync_all convenience function."""
def test_sync_all_function(self):
"""Test sync_all convenience function."""
with patch("src.data.sync.DataSync") as MockSync:
mock_instance = Mock()
mock_instance.sync_all.return_value = {}
MockSync.return_value = mock_instance
result = sync_all(force_full=True)
MockSync.assert_called_once()
mock_instance.sync_all.assert_called_once_with(
force_full=True,
start_date=None,
end_date=None,
dry_run=False,
)
# 验证返回了数据
assert isinstance(result, pd.DataFrame), "sync_bak_basic 应返回 DataFrame"
# 注意bak_basic 可能返回空数据,这是正常的
if __name__ == "__main__":


@@ -1,256 +0,0 @@
"""Tests for data sync with REAL data (read-only).
Tests verify:
1. get_global_last_date() correctly reads local data's max date
2. Incremental sync date calculation (local_last_date + 1)
3. Full sync date calculation (20180101)
4. Multi-stock scenario with real data
⚠️ IMPORTANT: These tests ONLY read data, no write operations.
- NO sync_all() calls (writes daily.h5)
- NO check_sync_needed() calls (writes trade_cal.h5)
"""
import pytest
import pandas as pd
from pathlib import Path
from src.data.sync import (
DataSync,
get_next_date,
DEFAULT_START_DATE,
)
from src.data.storage import Storage
class TestDataSyncReadOnly:
"""Read-only tests for data sync - verify date calculation logic."""
@pytest.fixture
def storage(self):
"""Create storage instance."""
return Storage()
@pytest.fixture
def data_sync(self):
"""Create DataSync instance."""
return DataSync()
@pytest.fixture
def daily_exists(self, storage):
"""Check if daily.h5 exists."""
return storage.exists("daily")
def test_daily_h5_exists(self, storage):
"""Verify daily.h5 data file exists before running tests."""
assert storage.exists("daily"), (
"daily.h5 not found. Please run full sync first: "
"uv run python -c 'from src.data.sync import sync_all; sync_all(force_full=True)'"
)
def test_get_global_last_date(self, data_sync, daily_exists):
"""Test get_global_last_date returns correct max date from local data."""
if not daily_exists:
pytest.skip("daily.h5 not found")
last_date = data_sync.get_global_last_date()
# Verify it's a valid date string
assert last_date is not None, "get_global_last_date returned None"
assert isinstance(last_date, str), f"Expected str, got {type(last_date)}"
assert len(last_date) == 8, f"Expected 8-digit date, got {last_date}"
assert last_date.isdigit(), f"Expected numeric date, got {last_date}"
# Verify by reading storage directly
daily_data = data_sync.storage.load("daily")
expected_max = str(daily_data["trade_date"].max())
assert last_date == expected_max, (
f"get_global_last_date returned {last_date}, "
f"but actual max date is {expected_max}"
)
print(f"[TEST] Local data last date: {last_date}")
    def test_incremental_sync_date_calculation(self, data_sync, daily_exists):
        """Test incremental sync: start_date = local_last_date + 1 day.

        This verifies that when local data exists, incremental sync should
        fetch data from the day after local_last_date, not from 20180101.
        """
        if not daily_exists:
            pytest.skip("daily.h5 not found")
        from datetime import datetime, timedelta

        # Get local last date
        local_last_date = data_sync.get_global_last_date()
        assert local_last_date is not None, "No local data found"
        # Calculate expected incremental start date
        expected_start_date = get_next_date(local_last_date)
        # Verify the calculation with a datetime round-trip: naive integer
        # "+ 1" would break at month/year boundaries (e.g. 20240131)
        expected = (
            datetime.strptime(local_last_date, "%Y%m%d") + timedelta(days=1)
        ).strftime("%Y%m%d")
        assert expected_start_date == expected, (
            f"Incremental start date calculation error: "
            f"expected {expected}, got {expected_start_date}"
        )
        print(
            f"[TEST] Incremental sync: local_last={local_last_date}, "
            f"start_date should be {expected_start_date}"
        )
        # Verify this is NOT 20180101 (would be full sync)
        assert expected_start_date != DEFAULT_START_DATE, (
            f"Incremental sync should NOT start from {DEFAULT_START_DATE}"
        )
    def test_full_sync_date_calculation(self):
        """Test full sync: start_date = 20180101 when force_full=True.

        This verifies that force_full=True always starts from 20180101.
        """
        # Full sync should always use DEFAULT_START_DATE
        full_sync_start = DEFAULT_START_DATE
        assert full_sync_start == "20180101", (
            f"Full sync should start from 20180101, got {full_sync_start}"
        )
        print(f"[TEST] Full sync start date: {full_sync_start}")

    def test_date_comparison_logic(self, data_sync, daily_exists):
        """Test date comparison: incremental vs full sync selection logic.

        Verify that:
        - If local_last_date < today: incremental sync needed
        - If local_last_date >= today: no sync needed
        """
        if not daily_exists:
            pytest.skip("daily.h5 not found")
        from datetime import datetime

        local_last_date = data_sync.get_global_last_date()
        today = datetime.now().strftime("%Y%m%d")
        local_last_int = int(local_last_date)
        today_int = int(today)
        # Log the comparison (YYYYMMDD integers compare correctly for ordering)
        print(
            f"[TEST] Date comparison: local_last={local_last_date} ({local_last_int}), "
            f"today={today} ({today_int})"
        )
        # This test just verifies the comparison logic works
        if local_last_int < today_int:
            print("[TEST] Local data is older than today - sync needed")
            # Incremental sync should fetch from local_last_date + 1
            sync_start = get_next_date(local_last_date)
            assert int(sync_start) > local_last_int, (
                "Sync start should be after local last"
            )
        else:
            print("[TEST] Local data is up-to-date - no sync needed")

    def test_get_all_stock_codes_real_data(self, data_sync, daily_exists):
        """Test get_all_stock_codes returns multiple real stock codes."""
        if not daily_exists:
            pytest.skip("daily.h5 not found")
        codes = data_sync.get_all_stock_codes()
        # Verify it's a list
        assert isinstance(codes, list), f"Expected list, got {type(codes)}"
        assert len(codes) > 0, "No stock codes found"
        # Verify multiple stocks
        assert len(codes) >= 10, (
            f"Expected at least 10 stocks for multi-stock test, got {len(codes)}"
        )
        # Verify format (should be like 000001.SZ, 600000.SH)
        sample_codes = codes[:5]
        for code in sample_codes:
            assert "." in code, f"Invalid stock code format: {code}"
            suffix = code.split(".")[-1]
            assert suffix in ["SZ", "SH"], f"Invalid exchange suffix: {suffix}"
        print(f"[TEST] Found {len(codes)} stock codes (sample: {sample_codes})")
    def test_multi_stock_date_range(self, data_sync, daily_exists):
        """Test that multiple stocks share the same date range in local data.

        This verifies that local data has consistent date coverage across stocks.
        """
        if not daily_exists:
            pytest.skip("daily.h5 not found")
        daily_data = data_sync.storage.load("daily")
        # Get date range for each stock
        stock_dates = daily_data.groupby("ts_code")["trade_date"].agg(["min", "max"])
        # Get global min and max
        global_min = str(daily_data["trade_date"].min())
        global_max = str(daily_data["trade_date"].max())
        print(f"[TEST] Global date range: {global_min} to {global_max}")
        print(f"[TEST] Total stocks: {len(stock_dates)}")
        # Verify we have data for multiple stocks
        assert len(stock_dates) >= 10, (
            f"Expected at least 10 stocks, got {len(stock_dates)}"
        )
        # Verify the date range is reasonable; subtracting raw YYYYMMDD
        # integers is NOT a day count, so convert to datetimes first
        global_min_dt = pd.to_datetime(global_min, format="%Y%m%d")
        global_max_dt = pd.to_datetime(global_max, format="%Y%m%d")
        days_span = (global_max_dt - global_min_dt).days
        assert days_span > 100, (
            f"Date range too small: {days_span} days. "
            f"Expected at least 100 days of data."
        )
        print(f"[TEST] Date span: {days_span} days")
class TestDateUtilities:
    """Test date utility functions."""

    def test_get_next_date(self):
        """Test get_next_date correctly calculates the next day."""
        # Test normal cases
        assert get_next_date("20240101") == "20240102"
        assert get_next_date("20240131") == "20240201"  # Month boundary
        assert get_next_date("20241231") == "20250101"  # Year boundary

    def test_incremental_vs_full_sync_logic(self):
        """Test the logic difference between incremental and full sync.

        Incremental: start_date = local_last_date + 1
        Full: start_date = 20180101
        """
        # Scenario 1: Local data exists
        local_last_date = "20240115"
        incremental_start = get_next_date(local_last_date)
        assert incremental_start == "20240116"
        assert incremental_start != DEFAULT_START_DATE
        # Scenario 2: Force full sync
        full_sync_start = DEFAULT_START_DATE  # "20180101"
        assert full_sync_start == "20180101"
        assert incremental_start != full_sync_start
        print("[TEST] Incremental vs Full sync logic verified")


if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])
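The `TestDateUtilities` cases above document the contract for `get_next_date` across month and year boundaries (20240131 → 20240201, 20241231 → 20250101). A `datetime` round-trip is the simplest implementation that satisfies those assertions; the following is a self-contained sketch assuming YYYYMMDD strings, not necessarily the actual code in `src/data/sync.py`:

```python
from datetime import datetime, timedelta


def get_next_date(date_str: str) -> str:
    """Return the next calendar day for a YYYYMMDD date string (sketch)."""
    # Parse, add one day, and format back; this rolls over month and
    # year boundaries correctly, unlike naive integer arithmetic
    next_day = datetime.strptime(date_str, "%Y%m%d") + timedelta(days=1)
    return next_day.strftime("%Y%m%d")


assert get_next_date("20240101") == "20240102"
assert get_next_date("20240131") == "20240201"  # month boundary
assert get_next_date("20241231") == "20250101"  # year boundary
```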