Compare commits
2 Commits
181994f063
...
ca27cb297a
| Author | SHA1 | Date | |
|---|---|---|---|
| ca27cb297a | |||
| a22bc2d282 |
@@ -85,8 +85,7 @@ ProStock/
|
|||||||
│ ├── data/ # 数据获取与存储
|
│ ├── data/ # 数据获取与存储
|
||||||
│ │ ├── api_wrappers/ # Tushare API 封装
|
│ │ ├── api_wrappers/ # Tushare API 封装
|
||||||
│ │ │ ├── base_sync.py # 同步基础抽象类
|
│ │ │ ├── base_sync.py # 同步基础抽象类
|
||||||
│ │ │ ├── api_daily.py # 日线数据接口
|
│ │ │ ├── api_pro_bar.py # Pro Bar 行情数据接口(主用)
|
||||||
│ │ │ ├── api_pro_bar.py # Pro Bar 数据接口
|
|
||||||
│ │ │ ├── api_stock_basic.py # 股票基础信息接口
|
│ │ │ ├── api_stock_basic.py # 股票基础信息接口
|
||||||
│ │ │ ├── api_trade_cal.py # 交易日历接口
|
│ │ │ ├── api_trade_cal.py # 交易日历接口
|
||||||
│ │ │ ├── api_bak_basic.py # 历史股票列表接口
|
│ │ │ ├── api_bak_basic.py # 历史股票列表接口
|
||||||
|
|||||||
433
README.md
433
README.md
@@ -1,300 +1,211 @@
|
|||||||
# ProStock
|
# ProStock
|
||||||
|
|
||||||
A股量化投资框架 - 从数据获取到模型训练的完整解决方案
|
A股量化投资框架,用于量化股票投资分析。
|
||||||
|
|
||||||
## 功能特性
|
## 特性
|
||||||
|
|
||||||
### 1. 数据层 (src/data/)
|
- **数据管理**:Tushare API 行情数据获取,DuckDB 本地数据存储
|
||||||
- **多源数据接入**: Tushare API 集成,支持日线、股票基础信息、交易日历
|
- **因子引擎**:DSL 表达式驱动的高性能因子计算框架(基于 Polars)
|
||||||
- **DuckDB 存储**: 高性能嵌入式数据库,支持 SQL 查询下推
|
- **机器学习**:支持 LightGBM 回归和 LambdaRank 排序学习
|
||||||
- **智能同步**: 增量/全量同步策略,自动检测数据更新需求
|
- **组件化设计**:灵活的数据处理器、股票池管理、过滤器组合
|
||||||
- **速率控制**: 令牌桶算法实现 API 限流
|
|
||||||
- **并发优化**: ThreadPoolExecutor 多线程数据获取
|
|
||||||
|
|
||||||
### 2. 因子层 (src/factors/)
|
## 环境要求
|
||||||
- **类型安全**: 严格的截面因子 vs 时序因子区分
|
|
||||||
- **防泄露机制**: 框架层面防止未来数据和跨股票数据泄露
|
|
||||||
- **因子组合**: 支持因子加减乘除和标量运算
|
|
||||||
- **高性能计算**: Polars 向量化操作,零拷贝数据导出
|
|
||||||
- **灵活扩展**: 基类抽象便于自定义因子
|
|
||||||
|
|
||||||
### 3. 模型层 (src/models/)
|
- Python 3.10+
|
||||||
- **插件架构**: 装饰器注册机制,新模型即插即用
|
- uv 包管理器
|
||||||
- **阶段感知**: 训练/测试阶段区分,防止数据泄露
|
|
||||||
- **多模型支持**: LightGBM、CatBoost 等模型统一接口
|
## 安装
|
||||||
- **数据处理**: 缺失值处理、缩尾、标准化、中性化等
|
|
||||||
- **时序划分**: WalkForward、ExpandingWindow 等时间序列划分策略
|
```bash
|
||||||
|
# 克隆项目
|
||||||
|
cd ProStock
|
||||||
|
|
||||||
|
# 使用 uv 安装依赖
|
||||||
|
uv pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
## 配置
|
||||||
|
|
||||||
|
创建 `config/.env.local` 文件:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Tushare Token(必需)
|
||||||
|
TUSHARE_TOKEN=your_token_here
|
||||||
|
|
||||||
|
# 数据存储路径(可选,默认 data/)
|
||||||
|
DATA_PATH=data
|
||||||
|
|
||||||
|
# API 速率限制(可选,默认 100)
|
||||||
|
RATE_LIMIT=100
|
||||||
|
|
||||||
|
# 并发线程数(可选,默认 10)
|
||||||
|
THREADS=10
|
||||||
|
```
|
||||||
|
|
||||||
|
## 快速开始
|
||||||
|
|
||||||
|
### 1. 同步股票数据
|
||||||
|
|
||||||
|
```python
|
||||||
|
from src.data.sync import sync_all
|
||||||
|
|
||||||
|
# 增量同步(默认)
|
||||||
|
sync_all()
|
||||||
|
|
||||||
|
# 强制全量同步
|
||||||
|
sync_all(force_full=True)
|
||||||
|
|
||||||
|
# 自定义线程数
|
||||||
|
sync_all(max_workers=20)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. 计算因子
|
||||||
|
|
||||||
|
```python
|
||||||
|
from src.factors import FactorEngine
|
||||||
|
|
||||||
|
# 初始化引擎
|
||||||
|
engine = FactorEngine()
|
||||||
|
|
||||||
|
# 添加因子(推荐使用字符串表达式)
|
||||||
|
engine.add_factor("ma20", "ts_mean(close, 20)")
|
||||||
|
engine.add_factor("alpha", "cs_rank(ts_mean(close, 5) - ts_mean(close, 20))")
|
||||||
|
|
||||||
|
# 计算因子值
|
||||||
|
result = engine.compute(["ma20", "alpha"], "20240101", "20240131")
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. 训练模型
|
||||||
|
|
||||||
|
```python
|
||||||
|
from src.training import Trainer, DateSplitter, StockPoolManager
|
||||||
|
from src.training.components.models import LightGBMModel
|
||||||
|
|
||||||
|
# 创建模型
|
||||||
|
model = LightGBMModel(params={
|
||||||
|
"objective": "regression",
|
||||||
|
"num_leaves": 20,
|
||||||
|
"learning_rate": 0.01,
|
||||||
|
"n_estimators": 1000,
|
||||||
|
})
|
||||||
|
|
||||||
|
# 创建数据划分器
|
||||||
|
splitter = DateSplitter(
|
||||||
|
train_start="20200101",
|
||||||
|
train_end="20231231",
|
||||||
|
val_start="20240101",
|
||||||
|
val_end="20241231",
|
||||||
|
test_start="20250101",
|
||||||
|
test_end="20251231",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建训练器并训练
|
||||||
|
trainer = Trainer(
|
||||||
|
model=model,
|
||||||
|
splitter=splitter,
|
||||||
|
target_col="future_return_5",
|
||||||
|
feature_cols=["ma_5", "ma_20", "volume_ratio"],
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer.train(data)
|
||||||
|
results = trainer.get_results()
|
||||||
|
```
|
||||||
|
|
||||||
## 项目结构
|
## 项目结构
|
||||||
|
|
||||||
```
|
```
|
||||||
ProStock/
|
ProStock/
|
||||||
├── src/
|
├── src/
|
||||||
│ ├── config/ # 配置管理
|
│ ├── config/ # 配置管理
|
||||||
│ │ ├── settings.py # pydantic-settings 配置
|
│ ├── data/ # 数据获取与存储
|
||||||
│ │ └── __init__.py
|
│ │ ├── api_wrappers/ # Tushare API 封装
|
||||||
│ │
|
│ │ ├── storage.py # DuckDB 存储
|
||||||
│ ├── data/ # 数据获取与存储
|
│ │ └── sync.py # 数据同步调度
|
||||||
│ │ ├── api_wrappers/ # Tushare API 封装
|
│ ├── factors/ # 因子计算框架
|
||||||
│ │ │ ├── api_daily.py # 日线数据接口
|
│ │ ├── engine/ # 执行引擎
|
||||||
│ │ │ ├── api_stock_basic.py # 股票基础信息
|
│ │ ├── metadata/ # 因子元数据管理
|
||||||
│ │ │ └── api_trade_cal.py # 交易日历
|
│ │ ├── dsl.py # DSL 表达式层
|
||||||
│ │ ├── client.py # Tushare 客户端(含限流)
|
│ │ └── translator.py # Polars 翻译器
|
||||||
│ │ ├── config.py # 数据模块配置
|
│ └── training/ # 训练模块
|
||||||
│ │ ├── db_manager.py # DuckDB 表管理和同步
|
│ ├── core/ # 训练核心
|
||||||
│ │ ├── db_inspector.py # 数据库信息查看工具
|
│ └── components/ # 组件(模型、处理器、过滤器)
|
||||||
│ │ ├── rate_limiter.py # 令牌桶限流器
|
├── tests/ # 测试文件
|
||||||
│ │ ├── storage.py # DuckDB 存储核心
|
├── data/ # 数据存储
|
||||||
│ │ ├── sync.py # 数据同步主逻辑
|
└── docs/ # 文档
|
||||||
│ │ └── __init__.py
|
|
||||||
│ │
|
|
||||||
│ ├── factors/ # 因子计算框架
|
|
||||||
│ │ ├── base.py # 因子基类(截面/时序)
|
|
||||||
│ │ ├── composite.py # 组合因子和标量运算
|
|
||||||
│ │ ├── data_loader.py # DuckDB 数据加载器
|
|
||||||
│ │ ├── data_spec.py # 数据规格定义
|
|
||||||
│ │ ├── engine.py # 因子执行引擎
|
|
||||||
│ │ └── __init__.py
|
|
||||||
│ │
|
|
||||||
│ ├── models/ # 模型训练框架
|
|
||||||
│ │ ├── core/ # 核心抽象
|
|
||||||
│ │ │ ├── base.py # 处理器/模型/划分基类
|
|
||||||
│ │ │ └── splitter.py # 时间序列划分策略
|
|
||||||
│ │ ├── models/ # 模型实现
|
|
||||||
│ │ │ └── models.py # LightGBM、CatBoost
|
|
||||||
│ │ ├── processors/ # 数据处理器
|
|
||||||
│ │ │ └── processors.py # 标准化、缩尾、中性化等
|
|
||||||
│ │ ├── pipeline.py # 处理流水线
|
|
||||||
│ │ ├── registry.py # 插件注册中心
|
|
||||||
│ │ └── __init__.py
|
|
||||||
│ │
|
|
||||||
│ └── __init__.py
|
|
||||||
│
|
|
||||||
├── docs/ # 文档
|
|
||||||
│ ├── factor_framework_design.md # 因子框架设计
|
|
||||||
│ ├── ml_framework_design.md # 模型框架设计
|
|
||||||
│ ├── db_sync_guide.md # 数据同步指南
|
|
||||||
│ └── ...
|
|
||||||
│
|
|
||||||
├── data/ # 数据存储(DuckDB)
|
|
||||||
│ ├── prostock.db # 主数据库文件
|
|
||||||
│ └── stock_basic.csv # 股票基础信息缓存
|
|
||||||
│
|
|
||||||
├── config/ # 配置文件
|
|
||||||
│ └── .env.local # 环境变量(API Token等)
|
|
||||||
│
|
|
||||||
└── tests/ # 测试文件
|
|
||||||
├── test_sync.py
|
|
||||||
└── factors/
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## 快速开始
|
## 因子框架
|
||||||
|
|
||||||
### 1. 安装依赖
|
### 支持的函数
|
||||||
|
|
||||||
**⚠️ 本项目强制使用 uv 作为 Python 包管理器**
|
**时间序列函数 (ts_*)**:
|
||||||
|
- `ts_mean`, `ts_std`, `ts_max`, `ts_min`, `ts_sum`
|
||||||
|
- `ts_delay`, `ts_delta`
|
||||||
|
- `ts_corr`, `ts_cov`, `ts_rank`
|
||||||
|
|
||||||
```bash
|
**截面函数 (cs_*)**:
|
||||||
# 安装 uv (如果尚未安装)
|
- `cs_rank` - 截面排名
|
||||||
pip install uv
|
- `cs_zscore` - Z-Score 标准化
|
||||||
|
- `cs_neutralize` - 行业/市值中性化
|
||||||
|
- `cs_winsorize` - 缩尾处理
|
||||||
|
|
||||||
# 安装项目依赖
|
**数学函数**:
|
||||||
uv pip install -e .
|
- `log`, `exp`, `sqrt`, `sign`, `abs`
|
||||||
```
|
- `max_`, `min_`, `clip`
|
||||||
|
- `if_`, `where`
|
||||||
|
|
||||||
### 2. 配置环境变量
|
### 因子元数据管理
|
||||||
|
|
||||||
创建 `config/.env.local` 文件:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
TUSHARE_TOKEN=your_tushare_token_here
|
|
||||||
DATA_PATH=data
|
|
||||||
RATE_LIMIT=100
|
|
||||||
THREADS=10
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. 数据同步
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# 首次同步 - 全量同步(从20180101开始)
|
|
||||||
uv run python -c "from src.data.sync import sync_all; sync_all(force_full=True)"
|
|
||||||
|
|
||||||
# 日常同步 - 增量同步(自动从最新日期开始)
|
|
||||||
uv run python -c "from src.data.sync import sync_all; sync_all()"
|
|
||||||
|
|
||||||
# 预览同步(检查需要同步的数据量)
|
|
||||||
uv run python -c "from src.data.sync import preview_sync; preview_sync()"
|
|
||||||
|
|
||||||
# 自定义线程数
|
|
||||||
uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)"
|
|
||||||
```
|
|
||||||
|
|
||||||
### 4. 查看数据库状态
|
|
||||||
|
|
||||||
```bash
|
|
||||||
uv run python -c "from src.data.db_inspector import get_db_info; get_db_info()"
|
|
||||||
```
|
|
||||||
|
|
||||||
## 使用示例
|
|
||||||
|
|
||||||
### 因子计算
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from src.factors import FactorEngine, DataLoader, DataSpec
|
from src.factors.metadata import FactorManager
|
||||||
from src.factors.base import CrossSectionalFactor, TimeSeriesFactor
|
|
||||||
import polars as pl
|
|
||||||
|
|
||||||
# 自定义截面因子:PE排名
|
# 初始化管理器
|
||||||
class PERankFactor(CrossSectionalFactor):
|
manager = FactorManager()
|
||||||
name = "pe_rank"
|
|
||||||
data_specs = [DataSpec("daily", ["ts_code", "trade_date", "pe"], lookback_days=1)]
|
|
||||||
|
|
||||||
def compute(self, data) -> pl.Series:
|
# 添加因子
|
||||||
cs = data.get_cross_section()
|
manager.add_factor({
|
||||||
return cs["pe"].rank()
|
"factor_id": "F_001",
|
||||||
|
"name": "mom_5d",
|
||||||
|
"desc": "5日价格动量",
|
||||||
|
"dsl": "cs_rank(close / ts_delay(close, 5) - 1)",
|
||||||
|
"category": "momentum",
|
||||||
|
})
|
||||||
|
|
||||||
# 自定义时序因子:20日移动平均
|
# 查询因子
|
||||||
class MA20Factor(TimeSeriesFactor):
|
df = manager.get_factors_by_name("mom_5d")
|
||||||
name = "ma20"
|
|
||||||
data_specs = [DataSpec("daily", ["ts_code", "trade_date", "close"], lookback_days=20)]
|
|
||||||
|
|
||||||
def compute(self, data) -> pl.Series:
|
|
||||||
return data.get_column("close").rolling_mean(window_size=20)
|
|
||||||
|
|
||||||
# 执行计算
|
|
||||||
loader = DataLoader(data_dir="data")
|
|
||||||
engine = FactorEngine(loader)
|
|
||||||
|
|
||||||
# 计算截面因子
|
|
||||||
pe_rank = PERankFactor()
|
|
||||||
result1 = engine.compute(pe_rank, start_date="20240101", end_date="20240131")
|
|
||||||
|
|
||||||
# 计算时序因子
|
|
||||||
ma20 = MA20Factor()
|
|
||||||
result2 = engine.compute(ma20, stock_codes=["000001.SZ"],
|
|
||||||
start_date="20240101", end_date="20240131")
|
|
||||||
|
|
||||||
# 因子组合
|
|
||||||
combined = 0.5 * pe_rank + 0.3 * ma20
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### 模型训练
|
## 常见任务
|
||||||
|
|
||||||
```python
|
```bash
|
||||||
from src.models import PluginRegistry, ProcessingPipeline
|
# 运行所有测试
|
||||||
from src.models.core import PipelineStage
|
uv run pytest
|
||||||
import polars as pl
|
|
||||||
|
|
||||||
# 创建处理流水线
|
# 同步财务数据
|
||||||
pipeline = ProcessingPipeline([
|
uv run python -c "from src.data.api_wrappers.financial_data import sync_financial; sync_financial()"
|
||||||
PluginRegistry.get_processor("dropna")(),
|
|
||||||
PluginRegistry.get_processor("winsorizer")(lower=0.01, upper=0.99),
|
|
||||||
PluginRegistry.get_processor("standard_scaler")(),
|
|
||||||
])
|
|
||||||
|
|
||||||
# 准备数据
|
# 批量注册因子
|
||||||
data = pl.read_csv("features.csv") # 包含特征和标签
|
uv run python src/scripts/register_factors.py
|
||||||
|
|
||||||
# 划分训练/测试集
|
|
||||||
from src.models.core import WalkForwardSplit
|
|
||||||
splitter = WalkForwardSplit(train_window=252, test_window=21)
|
|
||||||
|
|
||||||
# 获取 LightGBM 模型
|
|
||||||
ModelClass = PluginRegistry.get_model("lightgbm")
|
|
||||||
model = ModelClass(task_type="regression", params={"n_estimators": 100})
|
|
||||||
|
|
||||||
# 训练循环
|
|
||||||
for train_idx, test_idx in splitter.split(data):
|
|
||||||
train_data = data[train_idx]
|
|
||||||
test_data = data[test_idx]
|
|
||||||
|
|
||||||
# 数据处理
|
|
||||||
X_train = pipeline.fit_transform(train_data.drop("target"))
|
|
||||||
X_test = pipeline.transform(test_data.drop("target"))
|
|
||||||
y_train = train_data["target"]
|
|
||||||
y_test = test_data["target"]
|
|
||||||
|
|
||||||
# 训练模型
|
|
||||||
model.fit(X_train, y_train)
|
|
||||||
predictions = model.predict(X_test)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## 核心设计
|
## 依赖项
|
||||||
|
|
||||||
### 1. 数据防泄露机制
|
- pandas >= 2.0.0
|
||||||
|
- polars >= 0.20.0
|
||||||
**截面因子 (CrossSectionalFactor)**:
|
- numpy >= 1.24.0
|
||||||
- 防止日期泄露:每天只传入 `[T-lookback+1, T]` 数据
|
- tushare >= 2.0.0
|
||||||
- 允许股票间比较:传入当天所有股票数据
|
- pydantic >= 2.0.0
|
||||||
- 典型应用:PE排名、市值分位数、当日收益率排名
|
- lightgbm >= 4.0.0
|
||||||
|
- pytest
|
||||||
**时序因子 (TimeSeriesFactor)**:
|
|
||||||
- 防止股票泄露:每只股票单独计算
|
|
||||||
- 允许历史数据访问:传入完整时间序列
|
|
||||||
- 典型应用:移动平均线、RSI、历史波动率
|
|
||||||
|
|
||||||
### 2. 插件注册机制
|
|
||||||
|
|
||||||
```python
|
|
||||||
from src.models.registry import PluginRegistry
|
|
||||||
|
|
||||||
# 注册自定义处理器
|
|
||||||
@PluginRegistry.register_processor("my_processor")
|
|
||||||
class MyProcessor(BaseProcessor):
|
|
||||||
stage = PipelineStage.TRAIN
|
|
||||||
|
|
||||||
def fit(self, data):
|
|
||||||
# 学习参数
|
|
||||||
return self
|
|
||||||
|
|
||||||
def transform(self, data):
|
|
||||||
# 转换数据
|
|
||||||
return data
|
|
||||||
|
|
||||||
# 使用
|
|
||||||
processor_class = PluginRegistry.get_processor("my_processor")
|
|
||||||
processor = processor_class()
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. 数据同步策略
|
|
||||||
|
|
||||||
**智能增量同步**:
|
|
||||||
```python
|
|
||||||
from src.data.db_manager import SyncManager
|
|
||||||
|
|
||||||
manager = SyncManager()
|
|
||||||
result = manager.sync(
|
|
||||||
table_name="daily",
|
|
||||||
fetch_func=get_daily,
|
|
||||||
start_date="20240101",
|
|
||||||
end_date="20240131"
|
|
||||||
)
|
|
||||||
# 自动检测:表不存在→全量,表存在→增量
|
|
||||||
```
|
|
||||||
|
|
||||||
## 文档
|
## 文档
|
||||||
|
|
||||||
- [因子框架设计](docs/factor_framework_design.md) - 因子计算架构详解
|
更多详细信息请参阅 `docs/` 目录:
|
||||||
- [模型框架设计](docs/ml_framework_design.md) - 模型训练架构详解
|
|
||||||
- [数据同步指南](docs/db_sync_guide.md) - DuckDB 数据同步 API 说明
|
|
||||||
- [代码审查报告](docs/code_review_factors_20260222.md) - 因子框架代码审查
|
|
||||||
|
|
||||||
## 开发规范
|
- [因子表达式文档](docs/factor_expressions_document.md)
|
||||||
|
- [API 接口规范](docs/api/API_INTERFACE_SPEC.md)
|
||||||
- **Python 版本**: 3.10+
|
- [财务数据接口](docs/api/FINANCIAL_API_SPEC.md)
|
||||||
- **代码风格**: Google 风格文档字符串
|
|
||||||
- **类型提示**: 强制类型注解
|
|
||||||
- **测试**: pytest 框架
|
|
||||||
- **包管理**: uv (禁止直接使用 pip/python)
|
|
||||||
|
|
||||||
## 技术栈
|
|
||||||
|
|
||||||
- **数据处理**: Polars, Pandas, NumPy
|
|
||||||
- **数据存储**: DuckDB (嵌入式 OLAP 数据库)
|
|
||||||
- **API 接口**: Tushare Pro
|
|
||||||
- **机器学习**: LightGBM, CatBoost, scikit-learn
|
|
||||||
- **配置管理**: pydantic-settings
|
|
||||||
|
|
||||||
## 许可证
|
## 许可证
|
||||||
|
|
||||||
MIT License
|
MIT
|
||||||
|
|||||||
@@ -776,9 +776,9 @@ Skill 会自动:
|
|||||||
- [ ] 测试覆盖正常和异常情况
|
- [ ] 测试覆盖正常和异常情况
|
||||||
## 11. 示例参考
|
## 11. 示例参考
|
||||||
|
|
||||||
### 11.1 完整示例:api_daily.py
|
### 11.1 完整示例:api_pro_bar.py
|
||||||
|
|
||||||
参见 `src/data/api_wrappers/api_daily.py` - 按股票获取日线数据的完整实现。
|
参见 `src/data/api_wrappers/api_pro_bar.py` - 按股票获取 Pro Bar 行情数据的完整实现(主力行情表)。
|
||||||
|
|
||||||
### 11.2 完整示例:api_trade_cal.py
|
### 11.2 完整示例:api_trade_cal.py
|
||||||
|
|
||||||
|
|||||||
@@ -222,7 +222,7 @@ def _infer_data_specs(self, node, dependencies):
|
|||||||
```
|
```
|
||||||
|
|
||||||
**DataSpec 说明**:
|
**DataSpec 说明**:
|
||||||
- `table`: 数据表名(pro_bar 或 daily)
|
- `table`: 数据表名(pro_bar 为主力行情表)
|
||||||
- `columns`: 需要的字段列表
|
- `columns`: 需要的字段列表
|
||||||
|
|
||||||
**注意**:数据获取使用用户传入的日期范围,不做自动扩展。时序因子(如 `ts_delay`、`ts_mean`)在数据不足时会返回 null,这是符合预期的行为。
|
**注意**:数据获取使用用户传入的日期范围,不做自动扩展。时序因子(如 `ts_delay`、`ts_mean`)在数据不足时会返回 null,这是符合预期的行为。
|
||||||
@@ -377,19 +377,19 @@ def execute(self, plan, data):
|
|||||||
### 7.1 用户代码
|
### 7.1 用户代码
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from src.factors import FactorEngine, FormulaParser, FunctionRegistry
|
from src.factors import FactorEngine
|
||||||
|
|
||||||
# 1. 创建引擎
|
# 1. 创建引擎
|
||||||
engine = FactorEngine()
|
engine = FactorEngine()
|
||||||
|
|
||||||
# 2. 解析字符串表达式
|
# 2. 使用字符串表达式注册因子(推荐)
|
||||||
parser = FormulaParser(FunctionRegistry())
|
engine.add_factor("returns_5d", "(close / ts_delay(close, 5)) - 1")
|
||||||
expr = parser.parse("(close / ts_delay(close, 5)) - 1")
|
|
||||||
|
|
||||||
# 3. 注册因子
|
# 或者使用 DSL 表达式
|
||||||
engine.register("returns_5d", expr)
|
from src.factors.api import close, ts_delay
|
||||||
|
engine.register("returns_5d", (close / ts_delay(close, 5)) - 1)
|
||||||
|
|
||||||
# 4. 执行计算
|
# 3. 执行计算
|
||||||
result = engine.compute(
|
result = engine.compute(
|
||||||
factor_names=["returns_5d"],
|
factor_names=["returns_5d"],
|
||||||
start_date="20240101",
|
start_date="20240101",
|
||||||
@@ -400,23 +400,27 @@ result = engine.compute(
|
|||||||
### 7.2 内部调用链
|
### 7.2 内部调用链
|
||||||
|
|
||||||
```
|
```
|
||||||
|
FactorEngine.add_factor() / register()
|
||||||
|
│
|
||||||
|
└── 创建并缓存 ExecutionPlan
|
||||||
|
└── ExecutionPlanner.create_plan()
|
||||||
|
├── DependencyExtractor.extract_dependencies() → {'close'}
|
||||||
|
├── _infer_data_specs() → [DataSpec('pro_bar', ['close'], 5)]
|
||||||
|
└── PolarsTranslator.translate() → pl.col('close').shift(5).over('ts_code')...
|
||||||
|
|
||||||
FactorEngine.compute()
|
FactorEngine.compute()
|
||||||
│
|
│
|
||||||
├── 1. 创建 ExecutionPlan
|
├── 1. 获取所有缓存的执行计划
|
||||||
│ └── ExecutionPlanner.create_plan()
|
├── 2. 合并数据规格
|
||||||
│ ├── DependencyExtractor.extract_dependencies() → {'close'}
|
│ └── _merge_data_specs()
|
||||||
│ ├── _infer_data_specs() → [DataSpec('pro_bar', ['close'], 5)]
|
├── 3. 获取数据
|
||||||
│ └── PolarsTranslator.translate() → pl.col('close').shift(5).over('ts_code')...
|
│ └── DataRouter.fetch_data(merged_specs)
|
||||||
│
|
│ ├── _load_table('pro_bar', ['close'], start_date, end_date)
|
||||||
├── 2. 获取数据
|
|
||||||
│ └── DataRouter.fetch_data([plan.data_specs])
|
|
||||||
│ ├── _load_table('pro_bar', ['close'], start_date-5d, end_date)
|
|
||||||
│ │ └── Storage.load_polars() → 查询 DuckDB
|
│ │ └── Storage.load_polars() → 查询 DuckDB
|
||||||
│ └── _assemble_wide_table() → Polars DataFrame
|
│ └── _assemble_wide_table() → Polars DataFrame
|
||||||
│
|
└── 4. 执行计算
|
||||||
└── 3. 执行计算
|
└── ComputeEngine.execute_plans(plans, data)
|
||||||
└── ComputeEngine.execute(plan, data)
|
└── data.with_columns([polars_exprs...])
|
||||||
└── data.with_columns([polars_expr.alias('returns_5d')])
|
|
||||||
└── Polars 执行表达式计算
|
└── Polars 执行表达式计算
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -92,17 +92,17 @@
|
|||||||
|
|
||||||
| 字段名 | 状态 | 数据来源 | 所属类别 |
|
| 字段名 | 状态 | 数据来源 | 所属类别 |
|
||||||
|--------|------|----------|----------|
|
|--------|------|----------|----------|
|
||||||
| `close` | 可用 | daily/pro_bar 表 | 价格 |
|
| `close` | 可用 | pro_bar 表 | 价格 |
|
||||||
| `open` | 可用 | daily/pro_bar 表 | 价格 |
|
| `open` | 可用 | pro_bar 表 | 价格 |
|
||||||
| `high` | 可用 | daily/pro_bar 表 | 价格 |
|
| `high` | 可用 | pro_bar 表 | 价格 |
|
||||||
| `low` | 可用 | daily/pro_bar 表 | 价格 |
|
| `low` | 可用 | pro_bar 表 | 价格 |
|
||||||
| `vol` | 可用 | daily/pro_bar 表 | 成交量 |
|
| `vol` | 可用 | pro_bar 表 | 成交量 |
|
||||||
| `amount` | 可用 | daily/pro_bar 表 | 成交额 |
|
| `amount` | 可用 | pro_bar 表 | 成交额 |
|
||||||
| `pre_close` | 可用 | daily/pro_bar 表 | 价格 |
|
| `pre_close` | 可用 | pro_bar 表 | 价格 |
|
||||||
| `change` | 可用 | daily/pro_bar 表 | 价格变化 |
|
| `change` | 可用 | pro_bar 表 | 价格变化 |
|
||||||
| `pct_chg` | 可用 | daily/pro_bar 表 | 涨跌幅 |
|
| `pct_chg` | 可用 | pro_bar 表 | 涨跌幅 |
|
||||||
| `turnover_rate` | 可用 | daily/pro_bar 表 | 换手率 |
|
| `turnover_rate` | 可用 | pro_bar 表 | 换手率 |
|
||||||
| `volume_ratio` | 可用 | daily/pro_bar 表 | 量比 |
|
| `volume_ratio` | 可用 | pro_bar 表 | 量比 |
|
||||||
|
|
||||||
### 1.8 支持的运算符
|
### 1.8 支持的运算符
|
||||||
|
|
||||||
@@ -482,7 +482,7 @@ spec = DataSpec(
|
|||||||
|
|
||||||
| 数据源 | 依赖因子数 | 实现难度 | 优先级 |
|
| 数据源 | 依赖因子数 | 实现难度 | 优先级 |
|
||||||
|--------|------------|----------|--------|
|
|--------|------------|----------|--------|
|
||||||
| daily/pro_bar (已有) | ~40 | 低 | 高 |
|
| pro_bar (主力行情表) | ~40 | 低 | 高 |
|
||||||
| 纯技术指标 (ts_*) | ~30 | 中 | 高 |
|
| 纯技术指标 (ts_*) | ~30 | 中 | 高 |
|
||||||
| 筹码分布 (cyq) | ~50 | 中 | 中 |
|
| 筹码分布 (cyq) | ~50 | 中 | 中 |
|
||||||
| 资金流向 (moneyflow) | ~30 | 中 | 中 |
|
| 资金流向 (moneyflow) | ~30 | 中 | 中 |
|
||||||
|
|||||||
@@ -524,7 +524,7 @@ def prepare_data(...) -> pl.DataFrame:
|
|||||||
```python
|
```python
|
||||||
# 系统自动识别
|
# 系统自动识别
|
||||||
n_income → financial_income 表 (PIT)
|
n_income → financial_income 表 (PIT)
|
||||||
close → daily 表 (DAILY)
|
close → pro_bar 表 (主力行情表)
|
||||||
```
|
```
|
||||||
|
|
||||||
### 3. 财务数据清洗
|
### 3. 财务数据清洗
|
||||||
@@ -584,10 +584,10 @@ CREATE TABLE financial_income (
|
|||||||
);
|
);
|
||||||
```
|
```
|
||||||
|
|
||||||
### daily(日线行情)
|
### pro_bar(主力行情表)
|
||||||
|
|
||||||
```sql
|
```sql
|
||||||
CREATE TABLE daily (
|
CREATE TABLE pro_bar (
|
||||||
ts_code VARCHAR, -- 股票代码
|
ts_code VARCHAR, -- 股票代码
|
||||||
trade_date DATE, -- 交易日期
|
trade_date DATE, -- 交易日期
|
||||||
open DOUBLE, -- 开盘价
|
open DOUBLE, -- 开盘价
|
||||||
@@ -595,6 +595,10 @@ CREATE TABLE daily (
|
|||||||
low DOUBLE, -- 最低价
|
low DOUBLE, -- 最低价
|
||||||
close DOUBLE, -- 收盘价
|
close DOUBLE, -- 收盘价
|
||||||
vol BIGINT, -- 成交量
|
vol BIGINT, -- 成交量
|
||||||
|
turnover_rate DOUBLE, -- 换手率
|
||||||
|
volume_ratio DOUBLE, -- 量比
|
||||||
... -- 其他行情字段
|
... -- 其他行情字段
|
||||||
);
|
);
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**说明**: pro_bar 表通过 Tushare Pro Bar 接口获取,包含后复权数据和换手率、量比等指标,是主力行情数据表。
|
||||||
|
|||||||
@@ -1,240 +0,0 @@
|
|||||||
"""Simplified daily market data interface.
|
|
||||||
|
|
||||||
A single function to fetch A股日线行情 data from Tushare.
|
|
||||||
Supports all output fields including tor (换手率) and vr (量比).
|
|
||||||
|
|
||||||
This module provides both single-stock fetching (get_daily) and
|
|
||||||
batch synchronization (DailySync class) for daily market data.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pandas as pd
|
|
||||||
from typing import Optional, List, Literal, Dict
|
|
||||||
|
|
||||||
from src.data.client import TushareClient
|
|
||||||
from src.data.api_wrappers.base_sync import StockBasedSync
|
|
||||||
|
|
||||||
|
|
||||||
def get_daily(
|
|
||||||
ts_code: str,
|
|
||||||
start_date: Optional[str] = None,
|
|
||||||
end_date: Optional[str] = None,
|
|
||||||
trade_date: Optional[str] = None,
|
|
||||||
adj: Literal[None, "qfq", "hfq"] = None,
|
|
||||||
factors: Optional[List[Literal["tor", "vr"]]] = None,
|
|
||||||
adjfactor: bool = False,
|
|
||||||
) -> pd.DataFrame:
|
|
||||||
"""Fetch daily market data for A-share stocks.
|
|
||||||
|
|
||||||
This is a simplified interface that combines rate limiting, API calls,
|
|
||||||
and error handling into a single function.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ts_code: Stock code (e.g., '000001.SZ', '600000.SH')
|
|
||||||
start_date: Start date in YYYYMMDD format
|
|
||||||
end_date: End date in YYYYMMDD format
|
|
||||||
trade_date: Specific trade date in YYYYMMDD format
|
|
||||||
adj: Adjustment type - None, 'qfq' (forward), 'hfq' (backward)
|
|
||||||
factors: List of factors to include - 'tor' (turnover rate), 'vr' (volume ratio)
|
|
||||||
adjfactor: Whether to include adjustment factor
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
pd.DataFrame with daily market data containing:
|
|
||||||
- Base fields: ts_code, trade_date, open, high, low, close, pre_close,
|
|
||||||
change, pct_chg, vol, amount
|
|
||||||
- Factor fields (if requested): tor, vr
|
|
||||||
- Adjustment factor (if adjfactor=True): adjfactor
|
|
||||||
|
|
||||||
Example:
|
|
||||||
>>> data = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
|
|
||||||
>>> data = get_daily('600000.SH', factors=['tor', 'vr'])
|
|
||||||
"""
|
|
||||||
# Initialize client
|
|
||||||
client = TushareClient()
|
|
||||||
|
|
||||||
# Build parameters
|
|
||||||
params = {"ts_code": ts_code}
|
|
||||||
|
|
||||||
if start_date:
|
|
||||||
params["start_date"] = start_date
|
|
||||||
if end_date:
|
|
||||||
params["end_date"] = end_date
|
|
||||||
if trade_date:
|
|
||||||
params["trade_date"] = trade_date
|
|
||||||
if adj:
|
|
||||||
params["adj"] = adj
|
|
||||||
if factors:
|
|
||||||
# Tushare expects factors as comma-separated string, not list
|
|
||||||
if isinstance(factors, list):
|
|
||||||
factors_str = ",".join(factors)
|
|
||||||
else:
|
|
||||||
factors_str = factors
|
|
||||||
params["factors"] = factors_str
|
|
||||||
if adjfactor:
|
|
||||||
params["adjfactor"] = "True"
|
|
||||||
|
|
||||||
# Fetch data using pro_bar (supports factors like tor, vr)
|
|
||||||
data = client.query("pro_bar", **params)
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
class DailySync(StockBasedSync):
|
|
||||||
"""日线数据批量同步管理器,支持全量/增量同步。
|
|
||||||
|
|
||||||
继承自 StockBasedSync,使用多线程按股票并发获取数据。
|
|
||||||
|
|
||||||
Example:
|
|
||||||
>>> sync = DailySync()
|
|
||||||
>>> results = sync.sync_all() # 增量同步
|
|
||||||
>>> results = sync.sync_all(force_full=True) # 全量同步
|
|
||||||
>>> preview = sync.preview_sync() # 预览
|
|
||||||
"""
|
|
||||||
|
|
||||||
table_name = "daily"
|
|
||||||
|
|
||||||
# 表结构定义
|
|
||||||
TABLE_SCHEMA = {
|
|
||||||
"ts_code": "VARCHAR(16) NOT NULL",
|
|
||||||
"trade_date": "DATE NOT NULL",
|
|
||||||
"open": "DOUBLE",
|
|
||||||
"high": "DOUBLE",
|
|
||||||
"low": "DOUBLE",
|
|
||||||
"close": "DOUBLE",
|
|
||||||
"pre_close": "DOUBLE",
|
|
||||||
"change": "DOUBLE",
|
|
||||||
"pct_chg": "DOUBLE",
|
|
||||||
"vol": "DOUBLE",
|
|
||||||
"amount": "DOUBLE",
|
|
||||||
"turnover_rate": "DOUBLE",
|
|
||||||
"volume_ratio": "DOUBLE",
|
|
||||||
}
|
|
||||||
|
|
||||||
# 索引定义
|
|
||||||
TABLE_INDEXES = [
|
|
||||||
("idx_daily_date_code", ["trade_date", "ts_code"]),
|
|
||||||
]
|
|
||||||
|
|
||||||
# 主键定义
|
|
||||||
PRIMARY_KEY = ("ts_code", "trade_date")
|
|
||||||
|
|
||||||
def fetch_single_stock(
|
|
||||||
self,
|
|
||||||
ts_code: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str,
|
|
||||||
) -> pd.DataFrame:
|
|
||||||
"""获取单只股票的日线数据。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ts_code: 股票代码
|
|
||||||
start_date: 起始日期(YYYYMMDD)
|
|
||||||
end_date: 结束日期(YYYYMMDD)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
包含日线数据的 DataFrame
|
|
||||||
"""
|
|
||||||
# 使用共享客户端进行跨线程速率限制
|
|
||||||
data = self.client.query(
|
|
||||||
"pro_bar",
|
|
||||||
ts_code=ts_code,
|
|
||||||
start_date=start_date,
|
|
||||||
end_date=end_date,
|
|
||||||
factors="tor,vr",
|
|
||||||
)
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
def sync_daily(
|
|
||||||
force_full: bool = False,
|
|
||||||
start_date: Optional[str] = None,
|
|
||||||
end_date: Optional[str] = None,
|
|
||||||
max_workers: Optional[int] = None,
|
|
||||||
dry_run: bool = False,
|
|
||||||
) -> Dict[str, pd.DataFrame]:
|
|
||||||
"""同步所有股票的日线数据。
|
|
||||||
|
|
||||||
这是日线数据同步的主要入口点。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
force_full: 若为 True,强制从 20180101 完整重载
|
|
||||||
start_date: 手动指定起始日期(YYYYMMDD)
|
|
||||||
end_date: 手动指定结束日期(默认为今天)
|
|
||||||
max_workers: 工作线程数(默认: 10)
|
|
||||||
dry_run: 若为 True,仅预览将要同步的内容,不写入数据
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
映射 ts_code 到 DataFrame 的字典
|
|
||||||
|
|
||||||
Example:
|
|
||||||
>>> # 首次同步(从 20180101 全量加载)
|
|
||||||
>>> result = sync_daily()
|
|
||||||
>>>
|
|
||||||
>>> # 后续同步(增量 - 仅新数据)
|
|
||||||
>>> result = sync_daily()
|
|
||||||
>>>
|
|
||||||
>>> # 强制完整重载
|
|
||||||
>>> result = sync_daily(force_full=True)
|
|
||||||
>>>
|
|
||||||
>>> # 手动指定日期范围
|
|
||||||
>>> result = sync_daily(start_date='20240101', end_date='20240131')
|
|
||||||
>>>
|
|
||||||
>>> # 自定义线程数
|
|
||||||
>>> result = sync_daily(max_workers=20)
|
|
||||||
>>>
|
|
||||||
>>> # Dry run(仅预览)
|
|
||||||
>>> result = sync_daily(dry_run=True)
|
|
||||||
"""
|
|
||||||
sync_manager = DailySync(max_workers=max_workers)
|
|
||||||
return sync_manager.sync_all(
|
|
||||||
force_full=force_full,
|
|
||||||
start_date=start_date,
|
|
||||||
end_date=end_date,
|
|
||||||
dry_run=dry_run,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def preview_daily_sync(
|
|
||||||
force_full: bool = False,
|
|
||||||
start_date: Optional[str] = None,
|
|
||||||
end_date: Optional[str] = None,
|
|
||||||
sample_size: int = 3,
|
|
||||||
) -> dict:
|
|
||||||
"""预览日线同步数据量和样本(不实际同步)。
|
|
||||||
|
|
||||||
这是推荐的方式,可在实际同步前检查将要同步的内容。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
force_full: 若为 True,预览全量同步(从 20180101)
|
|
||||||
start_date: 手动指定起始日期(覆盖自动检测)
|
|
||||||
end_date: 手动指定结束日期(默认为今天)
|
|
||||||
sample_size: 预览用样本股票数量(默认: 3)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
包含预览信息的字典:
|
|
||||||
{
|
|
||||||
'sync_needed': bool,
|
|
||||||
'stock_count': int,
|
|
||||||
'start_date': str,
|
|
||||||
'end_date': str,
|
|
||||||
'estimated_records': int,
|
|
||||||
'sample_data': pd.DataFrame,
|
|
||||||
'mode': str, # 'full', 'incremental', 'partial', 或 'none'
|
|
||||||
}
|
|
||||||
|
|
||||||
Example:
|
|
||||||
>>> # 预览将要同步的内容
|
|
||||||
>>> preview = preview_daily_sync()
|
|
||||||
>>>
|
|
||||||
>>> # 预览全量同步
|
|
||||||
>>> preview = preview_daily_sync(force_full=True)
|
|
||||||
>>>
|
|
||||||
>>> # 预览更多样本
|
|
||||||
>>> preview = preview_daily_sync(sample_size=5)
|
|
||||||
"""
|
|
||||||
sync_manager = DailySync()
|
|
||||||
return sync_manager.preview_sync(
|
|
||||||
force_full=force_full,
|
|
||||||
start_date=start_date,
|
|
||||||
end_date=end_date,
|
|
||||||
sample_size=sample_size,
|
|
||||||
)
|
|
||||||
Reference in New Issue
Block a user