feat(data): 添加每日指标接口并优化因子引擎
- 新增 api_daily_basic.py 封装 Tushare 每日指标接口 - 因子引擎移除 lookback_days,支持 daily_basic 表字段路由 - 将每日指标纳入自动同步流程 - 删除废弃的 training/main.py
This commit is contained in:
131
AGENTS.md
131
AGENTS.md
@@ -82,20 +82,21 @@ ProStock/
|
|||||||
│ │
|
│ │
|
||||||
│ ├── data/ # 数据获取与存储
|
│ ├── data/ # 数据获取与存储
|
||||||
│ │ ├── api_wrappers/ # Tushare API 封装
|
│ │ ├── api_wrappers/ # Tushare API 封装
|
||||||
│ │ │ ├── base_sync.py # 同步基础抽象类(BaseDataSync/StockBasedSync/DateBasedSync)
|
│ │ │ ├── base_sync.py # 同步基础抽象类(BaseDataSync/StockBasedSync/DateBasedSync)
|
||||||
│ │ │ ├── api_daily.py # 日线数据接口(DailySync)
|
│ │ │ ├── api_daily.py # 日线数据接口(DailySync)
|
||||||
│ │ │ ├── api_pro_bar.py # Pro Bar 数据接口(ProBarSync)
|
│ │ │ ├── api_pro_bar.py # Pro Bar 数据接口(ProBarSync)
|
||||||
│ │ │ ├── api_stock_basic.py # 股票基础信息接口
|
│ │ │ ├── api_stock_basic.py # 股票基础信息接口
|
||||||
│ │ │ ├── api_trade_cal.py # 交易日历接口
|
│ │ │ ├── api_trade_cal.py # 交易日历接口
|
||||||
│ │ │ ├── api_bak_basic.py # 历史股票列表接口(BakBasicSync)
|
│ │ │ ├── api_bak_basic.py # 历史股票列表接口(BakBasicSync)
|
||||||
│ │ │ ├── api_namechange.py # 股票名称变更接口
|
│ │ │ ├── api_namechange.py # 股票名称变更接口
|
||||||
│ │ │ ├── financial_data/ # 财务数据接口
|
│ │ │ ├── financial_data/ # 财务数据接口
|
||||||
│ │ │ │ ├── api_income.py # 利润表接口
|
│ │ │ │ ├── api_income.py # 利润表接口
|
||||||
│ │ │ │ └── api_financial_sync.py # 财务数据同步
|
│ │ │ │ └── api_financial_sync.py # 财务数据同步
|
||||||
│ │ │ └── __init__.py
|
│ │ │ └── __init__.py
|
||||||
│ │ ├── __init__.py
|
│ │ ├── __init__.py
|
||||||
│ │ ├── client.py # Tushare API 客户端(带速率限制)
|
│ │ ├── client.py # Tushare API 客户端(带速率限制)
|
||||||
│ │ ├── config.py # 数据模块配置
|
│ │ ├── config.py # 数据模块配置
|
||||||
|
│ │ ├── data_router.py # 数据路由器(factors/engine 专用)
|
||||||
│ │ ├── db_inspector.py # 数据库信息查看工具
|
│ │ ├── db_inspector.py # 数据库信息查看工具
|
||||||
│ │ ├── db_manager.py # DuckDB 表管理和同步
|
│ │ ├── db_manager.py # DuckDB 表管理和同步
|
||||||
│ │ ├── rate_limiter.py # 令牌桶速率限制器
|
│ │ ├── rate_limiter.py # 令牌桶速率限制器
|
||||||
@@ -104,20 +105,29 @@ ProStock/
|
|||||||
│ │ └── utils.py # 数据模块工具函数
|
│ │ └── utils.py # 数据模块工具函数
|
||||||
│ │
|
│ │
|
||||||
│ ├── factors/ # 因子计算框架(DSL 表达式驱动)
|
│ ├── factors/ # 因子计算框架(DSL 表达式驱动)
|
||||||
|
│ │ ├── engine/ # 执行引擎子模块
|
||||||
|
│ │ │ ├── __init__.py # 导出引擎组件
|
||||||
|
│ │ │ ├── data_spec.py # 数据规格定义(DataSpec, ExecutionPlan)
|
||||||
|
│ │ │ ├── data_router.py # 数据路由器
|
||||||
|
│ │ │ ├── planner.py # 执行计划生成器(ExecutionPlanner)
|
||||||
|
│ │ │ ├── compute_engine.py # 计算引擎(ComputeEngine)
|
||||||
|
│ │ │ └── factor_engine.py # 因子引擎统一入口(FactorEngine)
|
||||||
│ │ ├── __init__.py # 导出所有公开 API
|
│ │ ├── __init__.py # 导出所有公开 API
|
||||||
│ │ ├── dsl.py # DSL 表达式层 - 节点定义和运算符重载
|
│ │ ├── dsl.py # DSL 表达式层 - 节点定义和运算符重载
|
||||||
│ │ ├── api.py # API 层 - 常用符号(close/open等)和函数(ts_mean/cs_rank等)
|
│ │ ├── api.py # API 层 - 常用符号(close/open等)和函数(ts_mean/cs_rank等)
|
||||||
│ │ ├── compiler.py # AST 编译器 - 依赖提取
|
│ │ ├── compiler.py # AST 编译器 - 依赖提取
|
||||||
│ │ ├── translator.py # Polars 表达式翻译器
|
│ │ ├── translator.py # Polars 表达式翻译器
|
||||||
│ │ └── engine.py # 因子执行引擎 - 统一入口
|
│ │ ├── parser.py # 字符串公式解析器(FormulaParser)
|
||||||
|
│ │ ├── registry.py # 函数注册表(FunctionRegistry)
|
||||||
|
│ │ └── exceptions.py # 异常定义(FormulaParseError等)
|
||||||
│ │
|
│ │
|
||||||
│ ├── pipeline/ # 模型训练管道
|
│ ├── pipeline/ # 模型训练管道
|
||||||
│ │ ├── __init__.py
|
│ │ ├── __init__.py
|
||||||
│ │ ├── pipeline.py # 处理流水线
|
│ │ ├── pipeline.py # 处理流水线(ProcessingPipeline)
|
||||||
│ │ ├── registry.py # 插件注册中心
|
│ │ ├── registry.py # 插件注册中心(PluginRegistry)
|
||||||
│ │ ├── core/ # 核心抽象
|
│ │ ├── core/ # 核心抽象
|
||||||
│ │ │ ├── __init__.py
|
│ │ │ ├── __init__.py
|
||||||
│ │ │ ├── base.py # 基类定义
|
│ │ │ ├── base.py # 基类定义(BaseProcessor/BaseModel/BaseSplitter等)
|
||||||
│ │ │ └── splitter.py # 时间序列划分策略
|
│ │ │ └── splitter.py # 时间序列划分策略
|
||||||
│ │ ├── models/ # 模型实现
|
│ │ ├── models/ # 模型实现
|
||||||
│ │ │ ├── __init__.py
|
│ │ │ ├── __init__.py
|
||||||
@@ -128,14 +138,23 @@ ProStock/
|
|||||||
│ │
|
│ │
|
||||||
│ └── training/ # 训练入口
|
│ └── training/ # 训练入口
|
||||||
│ ├── __init__.py
|
│ ├── __init__.py
|
||||||
│ ├── main.py # 训练主程序
|
|
||||||
│ ├── pipeline.py # 训练流程配置
|
│ ├── pipeline.py # 训练流程配置
|
||||||
│ └── output/ # 训练输出
|
│ └── output/ # 训练输出
|
||||||
│ └── top_stocks.tsv # 推荐股票结果
|
│ └── top_stocks.tsv # 推荐股票结果
|
||||||
│
|
│
|
||||||
├── tests/ # 测试文件
|
├── tests/ # 测试文件
|
||||||
│ ├── test_sync.py
|
│ ├── test_sync.py
|
||||||
│ └── test_daily.py
|
│ ├── test_daily.py
|
||||||
|
│ ├── test_factor_engine.py
|
||||||
|
│ ├── test_factor_integration.py
|
||||||
|
│ ├── test_pro_bar.py
|
||||||
|
│ ├── test_601117_factors.py
|
||||||
|
│ ├── test_two_stocks_string_factors.py
|
||||||
|
│ ├── test_db_manager.py
|
||||||
|
│ ├── test_daily_storage.py
|
||||||
|
│ ├── test_tushare_api.py
|
||||||
|
│ └── pipeline/
|
||||||
|
│ └── test_core.py
|
||||||
├── config/ # 配置文件
|
├── config/ # 配置文件
|
||||||
│ └── .env.local # 环境变量(不在 git 中)
|
│ └── .env.local # 环境变量(不在 git 中)
|
||||||
├── data/ # 数据存储(DuckDB)
|
├── data/ # 数据存储(DuckDB)
|
||||||
@@ -266,10 +285,13 @@ except Exception as e:
|
|||||||
### 依赖项
|
### 依赖项
|
||||||
关键包:
|
关键包:
|
||||||
- `pandas>=2.0.0` - 数据处理
|
- `pandas>=2.0.0` - 数据处理
|
||||||
|
- `polars>=0.20.0` - 高性能数据处理(因子计算)
|
||||||
- `numpy>=1.24.0` - 数值计算
|
- `numpy>=1.24.0` - 数值计算
|
||||||
- `tushare>=2.0.0` - A股数据 API
|
- `tushare>=2.0.0` - A股数据 API
|
||||||
- `pydantic>=2.0.0`、`pydantic-settings>=2.0.0` - 配置
|
- `pydantic>=2.0.0`、`pydantic-settings>=2.0.0` - 配置
|
||||||
- `tqdm>=4.65.0` - 进度条
|
- `tqdm>=4.65.0` - 进度条
|
||||||
|
- `lightgbm>=4.0.0` - 机器学习模型
|
||||||
|
- `catboost>=1.2.0` - 机器学习模型
|
||||||
- `pytest` - 测试(开发)
|
- `pytest` - 测试(开发)
|
||||||
|
|
||||||
### 环境变量
|
### 环境变量
|
||||||
@@ -292,6 +314,9 @@ uv run python -c "from src.data.sync import sync_all; sync_all(force_full=True)"
|
|||||||
|
|
||||||
# 自定义线程数
|
# 自定义线程数
|
||||||
uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)"
|
uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)"
|
||||||
|
|
||||||
|
# 运行因子计算测试
|
||||||
|
uv run pytest tests/test_factor_engine.py -v
|
||||||
```
|
```
|
||||||
|
|
||||||
## 架构变更历史
|
## 架构变更历史
|
||||||
@@ -309,8 +334,16 @@ uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)"
|
|||||||
- 新增 `api.py`: 常用符号(close/open/volume等)和函数(ts_mean/cs_rank等)
|
- 新增 `api.py`: 常用符号(close/open/volume等)和函数(ts_mean/cs_rank等)
|
||||||
- 新增 `compiler.py`: AST 编译器,提取表达式依赖
|
- 新增 `compiler.py`: AST 编译器,提取表达式依赖
|
||||||
- 新增 `translator.py`: 将 DSL 表达式翻译为 Polars 表达式
|
- 新增 `translator.py`: 将 DSL 表达式翻译为 Polars 表达式
|
||||||
- 重构 `engine.py`: 统一执行引擎入口,整合 DataRouter、ExecutionPlanner、ComputeEngine
|
- 新增 `parser.py`: 字符串公式解析器(FormulaParser),支持从字符串解析 DSL 表达式
|
||||||
- 移除: `base.py`、`composite.py`、`data_loader.py`、`data_spec.py`
|
- 新增 `registry.py`: 函数注册表(FunctionRegistry),管理字符串函数名到 Python 函数的映射
|
||||||
|
- 新增 `exceptions.py`: 公式解析异常定义(FormulaParseError、UnknownFunctionError等)
|
||||||
|
- 重构 `engine/` 子模块:
|
||||||
|
- `factor_engine.py`: 因子引擎统一入口(FactorEngine)
|
||||||
|
- `data_spec.py`: 数据规格定义(DataSpec, ExecutionPlan)
|
||||||
|
- `data_router.py`: 数据路由器
|
||||||
|
- `planner.py`: 执行计划生成器(ExecutionPlanner)
|
||||||
|
- `compute_engine.py`: 计算引擎(ComputeEngine)
|
||||||
|
- 移除: `base.py`、`composite.py`、`data_loader.py`、根目录的 `data_spec.py`
|
||||||
- 移除: `factors/momentum/` 和 `factors/financial/` 子目录
|
- 移除: `factors/momentum/` 和 `factors/financial/` 子目录
|
||||||
**使用方式对比**:
|
**使用方式对比**:
|
||||||
```python
|
```python
|
||||||
@@ -328,6 +361,11 @@ uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)"
|
|||||||
engine = FactorEngine()
|
engine = FactorEngine()
|
||||||
engine.register("ma20", ma20)
|
engine.register("ma20", ma20)
|
||||||
result = engine.compute(["ma20"], "20240101", "20240131")
|
result = engine.compute(["ma20"], "20240101", "20240131")
|
||||||
|
|
||||||
|
# 字符串公式解析(Phase 1 新增)
|
||||||
|
from src.factors import FormulaParser, FunctionRegistry
|
||||||
|
parser = FormulaParser(FunctionRegistry())
|
||||||
|
expr = parser.parse("ts_mean(close, 20) / close") # 从字符串解析
|
||||||
```
|
```
|
||||||
|
|
||||||
#### data 模块补充完善
|
#### data 模块补充完善
|
||||||
@@ -411,10 +449,20 @@ DSL 层 (dsl.py) <- 因子表达式 (Node)
|
|||||||
Compiler (compiler.py) <- AST 依赖提取
|
Compiler (compiler.py) <- AST 依赖提取
|
||||||
|
|
|
|
||||||
v
|
v
|
||||||
|
Parser (parser.py) <- 字符串公式解析器
|
||||||
|
|
|
||||||
|
v
|
||||||
|
Registry (registry.py) <- 函数注册表
|
||||||
|
|
|
||||||
|
v
|
||||||
Translator (translator.py) <- 翻译为 Polars 表达式
|
Translator (translator.py) <- 翻译为 Polars 表达式
|
||||||
|
|
|
|
||||||
v
|
v
|
||||||
Engine (engine.py) <- 执行引擎 (DataRouter/ExecutionPlanner/ComputeEngine)
|
Engine (engine/) <- 执行引擎
|
||||||
|
| - FactorEngine: 统一入口
|
||||||
|
| - DataRouter: 数据路由
|
||||||
|
| - ExecutionPlanner: 执行计划
|
||||||
|
| - ComputeEngine: 计算引擎
|
||||||
|
|
|
|
||||||
v
|
v
|
||||||
数据层 (data_router.py + DuckDB) <- 数据获取和存储
|
数据层 (data_router.py + DuckDB) <- 数据获取和存储
|
||||||
@@ -422,7 +470,7 @@ Engine (engine.py) <- 执行引擎 (DataRouter/ExecutionPlanner/ComputeEngi
|
|||||||
|
|
||||||
### 使用方式
|
### 使用方式
|
||||||
|
|
||||||
#### 1. 基础表达式
|
#### 1. 基础表达式(DSL 方式)
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from src.factors.api import close, open, ts_mean, cs_rank
|
from src.factors.api import close, open, ts_mean, cs_rank
|
||||||
@@ -435,15 +483,37 @@ price_rank = cs_rank(close) # 收盘价截面排名
|
|||||||
alpha = ma20 * 0.6 + price_rank * 0.4
|
alpha = ma20 * 0.6 + price_rank * 0.4
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 2. 注册和执行
|
#### 2. 字符串公式解析(Phase 1 新增)
|
||||||
|
|
||||||
|
```python
|
||||||
|
from src.factors import FormulaParser, FunctionRegistry
|
||||||
|
|
||||||
|
# 创建解析器(自动加载 api.py 中的所有函数)
|
||||||
|
parser = FormulaParser(FunctionRegistry())
|
||||||
|
|
||||||
|
# 从字符串解析公式
|
||||||
|
expr = parser.parse("ts_mean(close, 20) / close")
|
||||||
|
complex_expr = parser.parse("cs_rank(ts_mean(close, 5) - ts_mean(close, 20))")
|
||||||
|
|
||||||
|
# 支持完整运算符和函数调用
|
||||||
|
expr2 = parser.parse("(close - open) / open * 100") # 涨跌幅
|
||||||
|
expr3 = parser.parse("ts_corr(close, volume, 20)") # 量价相关性
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 3. 注册和执行
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from src.factors import FactorEngine
|
from src.factors import FactorEngine
|
||||||
|
|
||||||
engine = FactorEngine()
|
engine = FactorEngine()
|
||||||
|
|
||||||
|
# 注册 DSL 表达式
|
||||||
engine.register("ma20", ma20)
|
engine.register("ma20", ma20)
|
||||||
engine.register("price_rank", price_rank)
|
engine.register("price_rank", price_rank)
|
||||||
|
|
||||||
|
# 或注册字符串解析的表达式
|
||||||
|
engine.register("alpha", parser.parse("ma20 * 0.6 + price_rank * 0.4"))
|
||||||
|
|
||||||
# 执行计算
|
# 执行计算
|
||||||
result = engine.compute(
|
result = engine.compute(
|
||||||
factor_names=["ma20", "price_rank"],
|
factor_names=["ma20", "price_rank"],
|
||||||
@@ -504,6 +574,27 @@ expr3 = -change # 涨跌额取反
|
|||||||
expr4 = ts_mean(cs_rank(close), 20) # 排名后的20日平滑
|
expr4 = ts_mean(cs_rank(close), 20) # 排名后的20日平滑
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### 异常处理
|
||||||
|
|
||||||
|
框架提供清晰的异常类型帮助定位问题:
|
||||||
|
|
||||||
|
- `FormulaParseError` - 公式解析错误基类
|
||||||
|
- `UnknownFunctionError` - 未知函数错误(提供模糊匹配建议)
|
||||||
|
- `InvalidSyntaxError` - 语法错误
|
||||||
|
- `EmptyExpressionError` - 空表达式错误
|
||||||
|
- `DuplicateFunctionError` - 函数重复注册错误
|
||||||
|
|
||||||
|
示例:
|
||||||
|
```python
|
||||||
|
from src.factors import FormulaParser, FunctionRegistry, UnknownFunctionError
|
||||||
|
|
||||||
|
parser = FormulaParser(FunctionRegistry())
|
||||||
|
try:
|
||||||
|
expr = parser.parse("unknown_func(close)")
|
||||||
|
except UnknownFunctionError as e:
|
||||||
|
print(e) # 显示错误位置和可用函数建议
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
## AI 行为准则
|
## AI 行为准则
|
||||||
## AI 行为准则
|
## AI 行为准则
|
||||||
|
|||||||
557
docs/factor_calculation_flow.md
Normal file
557
docs/factor_calculation_flow.md
Normal file
@@ -0,0 +1,557 @@
|
|||||||
|
# 因子计算流程详解
|
||||||
|
|
||||||
|
本文档详细描述 ProStock 项目中因子计算引擎的完整数据流,以因子表达式 `(close / ts_delay(close, 5)) - 1` 为例,说明从字符串解析到最终计算结果的完整流程。
|
||||||
|
|
||||||
|
## 目录
|
||||||
|
|
||||||
|
1. [整体架构概览](#1-整体架构概览)
|
||||||
|
2. [阶段一:字符串解析](#2-阶段一字符串解析)
|
||||||
|
3. [阶段二:AST 编译与依赖提取](#3-阶段二ast-编译与依赖提取)
|
||||||
|
4. [阶段三:执行计划生成](#4-阶段三执行计划生成)
|
||||||
|
5. [阶段四:数据获取](#5-阶段四数据获取)
|
||||||
|
6. [阶段五:Polars 翻译与计算](#6-阶段五polars-翻译与计算)
|
||||||
|
7. [完整调用链示例](#7-完整调用链示例)
|
||||||
|
8. [数据流时序图](#8-数据流时序图)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. 整体架构概览
|
||||||
|
|
||||||
|
### 1.1 架构层次
|
||||||
|
|
||||||
|
因子框架采用分层设计,从上到下依次为:
|
||||||
|
|
||||||
|
```
|
||||||
|
API 层 (api.py)
|
||||||
|
|
|
||||||
|
v
|
||||||
|
DSL 层 (dsl.py) ← 因子表达式 (Node)
|
||||||
|
|
|
||||||
|
v
|
||||||
|
Parser (parser.py) ← 字符串公式解析
|
||||||
|
|
|
||||||
|
v
|
||||||
|
Registry (registry.py) ← 函数注册表
|
||||||
|
|
|
||||||
|
v
|
||||||
|
Compiler (compiler.py) ← AST 依赖提取
|
||||||
|
|
|
||||||
|
v
|
||||||
|
Planner (planner.py) ← 执行计划生成
|
||||||
|
|
|
||||||
|
v
|
||||||
|
Translator (translator.py) ← 翻译为 Polars 表达式
|
||||||
|
|
|
||||||
|
v
|
||||||
|
Engine (engine/) ← 执行引擎
|
||||||
|
| - FactorEngine: 统一入口
|
||||||
|
| - DataRouter: 数据路由
|
||||||
|
| - ExecutionPlanner: 执行计划
|
||||||
|
| - ComputeEngine: 计算引擎
|
||||||
|
|
|
||||||
|
v
|
||||||
|
数据层 (data/storage.py) ← DuckDB 数据获取和存储
|
||||||
|
```
|
||||||
|
|
||||||
|
### 1.2 核心组件职责
|
||||||
|
|
||||||
|
| 组件 | 文件路径 | 主要职责 |
|
||||||
|
|------|----------|----------|
|
||||||
|
| **FormulaParser** | `factors/parser.py` | 将字符串表达式解析为 DSL AST 节点树 |
|
||||||
|
| **FunctionRegistry** | `factors/registry.py` | 管理函数名到 Python 实现的映射 |
|
||||||
|
| **DSL Nodes** | `factors/dsl.py` | 定义表达式节点(Symbol, FunctionNode, BinaryOpNode 等)|
|
||||||
|
| **DependencyExtractor** | `factors/compiler.py` | 从 AST 提取依赖的数据字段 |
|
||||||
|
| **ExecutionPlanner** | `factors/engine/planner.py` | 整合编译器和翻译器生成执行计划 |
|
||||||
|
| **DataRouter** | `factors/engine/data_router.py` | 按需取数、组装核心宽表 |
|
||||||
|
| **PolarsTranslator** | `factors/translator.py` | 将 DSL AST 翻译为 Polars 表达式 |
|
||||||
|
| **ComputeEngine** | `factors/engine/compute_engine.py` | 执行 Polars 表达式计算 |
|
||||||
|
| **FactorEngine** | `factors/engine/factor_engine.py` | 系统统一入口,协调各组件 |
|
||||||
|
| **Storage** | `data/storage.py` | DuckDB 数据存储和查询接口 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. 阶段一:字符串解析
|
||||||
|
|
||||||
|
### 2.1 解析流程
|
||||||
|
|
||||||
|
当用户调用 `parser.parse("(close / ts_delay(close, 5)) - 1")` 时,解析流程如下:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 1. FormulaParser.parse() 方法
|
||||||
|
formula = "(close / ts_delay(close, 5)) - 1"
|
||||||
|
ast_tree = ast.parse(formula, mode='eval') # Python AST
|
||||||
|
dsl_node = self._visit(ast_tree.body) # 递归转换为 DSL Node
|
||||||
|
```
|
||||||
|
|
||||||
|
**解析步骤:**
|
||||||
|
|
||||||
|
1. **Python AST 解析**:使用 Python 标准库 `ast.parse()` 将字符串解析为 Python AST
|
||||||
|
2. **AST 遍历转换**:通过 `_visit()` 方法递归遍历 Python AST,映射为 DSL 节点
|
||||||
|
3. **节点类型映射**:
|
||||||
|
- `ast.Name` → `Symbol`
|
||||||
|
- `ast.Constant` → `Constant`
|
||||||
|
- `ast.BinOp` → `BinaryOpNode`
|
||||||
|
- `ast.UnaryOp` → `UnaryOpNode`
|
||||||
|
- `ast.Call` → `FunctionNode`
|
||||||
|
|
||||||
|
### 2.2 解析结果 AST 结构
|
||||||
|
|
||||||
|
```
|
||||||
|
BinaryOpNode(op='-', left, right=Constant(1))
|
||||||
|
├── left: BinaryOpNode(op='/', left, right)
|
||||||
|
│ ├── left: Symbol('close')
|
||||||
|
│ └── right: FunctionNode('ts_delay', [Symbol('close'), Constant(5)])
|
||||||
|
└── right: Constant(1)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2.3 函数解析机制
|
||||||
|
|
||||||
|
对于函数调用 `ts_delay(close, 5)`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# _visit_Call 方法处理逻辑
|
||||||
|
def _visit_Call(self, node: ast.Call) -> FunctionNode:
|
||||||
|
func_name = node.func.id # "ts_delay"
|
||||||
|
|
||||||
|
# 从注册表获取函数实现
|
||||||
|
if self.registry.has(func_name):
|
||||||
|
func_impl = self.registry.get(func_name)
|
||||||
|
|
||||||
|
# 递归解析参数
|
||||||
|
args = [self._visit(arg) for arg in node.args]
|
||||||
|
|
||||||
|
# 调用函数实现,返回 FunctionNode
|
||||||
|
return func_impl(*args)
|
||||||
|
```
|
||||||
|
|
||||||
|
**关键点**:`ts_delay` 函数在 `api.py` 中定义:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def ts_delay(x: NodeOrStr, periods: int) -> FunctionNode:
|
||||||
|
return FunctionNode('ts_delay', to_node(x), Constant(periods))
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. 阶段二:AST 编译与依赖提取
|
||||||
|
|
||||||
|
### 3.1 依赖提取流程
|
||||||
|
|
||||||
|
解析后的 AST 需要提取依赖的原始数据字段:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# DependencyExtractor.extract_dependencies(node)
|
||||||
|
extractor = DependencyExtractor()
|
||||||
|
dependencies = extractor.extract_dependencies(dsl_node)
|
||||||
|
# 结果: {'close'}
|
||||||
|
```
|
||||||
|
|
||||||
|
**提取逻辑**:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def _visit(self, node):
|
||||||
|
if isinstance(node, Symbol):
|
||||||
|
return {node.name} # 收集字段名
|
||||||
|
elif isinstance(node, (BinaryOpNode, FunctionNode)):
|
||||||
|
# 递归收集子节点依赖
|
||||||
|
deps = set()
|
||||||
|
for child in node.args:
|
||||||
|
deps.update(self._visit(child))
|
||||||
|
return deps
|
||||||
|
# ... 其他节点类型
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3.2 依赖的作用
|
||||||
|
|
||||||
|
提取的依赖 `{close}` 用于:
|
||||||
|
1. **数据规格推导**:确定需要从数据库读取哪些字段
|
||||||
|
2. **执行计划生成**:明确数据需求,避免读取不必要的字段
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. 阶段三:执行计划生成
|
||||||
|
|
||||||
|
### 4.1 ExecutionPlanner 的作用
|
||||||
|
|
||||||
|
`ExecutionPlanner.create_plan()` 将 AST 转换为可执行的 `ExecutionPlan`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
planner = ExecutionPlanner()
|
||||||
|
plan = planner.create_plan(
|
||||||
|
node=dsl_node, # 解析后的 DSL 节点
|
||||||
|
output_name="returns_5d" # 输出列名
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.2 计划生成流程
|
||||||
|
|
||||||
|
```python
|
||||||
|
def create_plan(self, node, output_name):
|
||||||
|
# 1. 提取依赖
|
||||||
|
dependencies = self.dependency_extractor.extract_dependencies(node)
|
||||||
|
|
||||||
|
# 2. 推导数据规格
|
||||||
|
data_specs = self._infer_data_specs(node, dependencies)
|
||||||
|
# 结果: [DataSpec(table='pro_bar', columns=['close'])]
|
||||||
|
|
||||||
|
# 3. 翻译为 Polars 表达式
|
||||||
|
polars_expr = self.polars_translator.translate(node)
|
||||||
|
# 结果: (pl.col('close') / pl.col('close').shift(5).over('ts_code')) - 1
|
||||||
|
|
||||||
|
# 4. 构建执行计划
|
||||||
|
return ExecutionPlan(
|
||||||
|
data_specs=data_specs,
|
||||||
|
polars_expr=polars_expr,
|
||||||
|
dependencies=dependencies,
|
||||||
|
output_name=output_name
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.3 数据规格推导
|
||||||
|
|
||||||
|
根据依赖字段,`_infer_data_specs` 推导出需要的数据规格:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def _infer_data_specs(self, node, dependencies):
|
||||||
|
return [
|
||||||
|
DataSpec(
|
||||||
|
table='pro_bar', # 默认使用 pro_bar 表
|
||||||
|
columns=list(dependencies), # ['close']
|
||||||
|
)
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
**DataSpec 说明**:
|
||||||
|
- `table`: 数据表名(如 pro_bar、daily 或 daily_basic)
|
||||||
|
- `columns`: 需要的字段列表
|
||||||
|
|
||||||
|
**注意**:数据获取使用用户传入的日期范围,不做自动扩展。时序因子(如 `ts_delay`、`ts_mean`)在数据不足时会返回 null,这是符合预期的行为。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. 阶段四:数据获取
|
||||||
|
|
||||||
|
### 5.1 DataRouter 的核心职责
|
||||||
|
|
||||||
|
`DataRouter.fetch_data()` 按需取数、组装核心宽表:
|
||||||
|
|
||||||
|
```python
|
||||||
|
data_router = DataRouter(storage, start_date, end_date)
|
||||||
|
wide_data = data_router.fetch_data(plan.data_specs)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.2 数据获取流程
|
||||||
|
|
||||||
|
```python
|
||||||
|
def fetch_data(self, data_specs, start_date, end_date):
|
||||||
|
# 1. 合并数据规格,收集所需表和字段
|
||||||
|
required_tables = self._collect_required_tables(data_specs)
|
||||||
|
|
||||||
|
# 2. 加载各表数据
|
||||||
|
table_data = {}
|
||||||
|
for table, columns in required_tables.items():
|
||||||
|
table_data[table] = self._load_table(table, columns, start_date, end_date)
|
||||||
|
|
||||||
|
# 3. 组装宽表(left join 合并)
|
||||||
|
wide_table = self._assemble_wide_table(table_data, data_specs)
|
||||||
|
|
||||||
|
return wide_table
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.3 数据库查询
|
||||||
|
|
||||||
|
`_load_table` 方法从 DuckDB 读取数据:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def _load_table(self, table, columns, start_date, end_date):
|
||||||
|
# 通过 Storage 查询数据库
|
||||||
|
df = self.storage.load_polars(
|
||||||
|
table_name=table,
|
||||||
|
columns=columns + ['ts_code', 'trade_date'], # 必须包含主键
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date
|
||||||
|
)
|
||||||
|
return df
|
||||||
|
```
|
||||||
|
|
||||||
|
**Storage.load_polars 内部实现**:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# data/storage.py
|
||||||
|
SELECT {columns} FROM {table_name}
|
||||||
|
WHERE trade_date BETWEEN '{start}' AND '{end}'
|
||||||
|
ORDER BY ts_code, trade_date
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.4 宽表组装
|
||||||
|
|
||||||
|
对于 `(close / ts_delay(close, 5)) - 1`,DataRouter 返回的宽表结构:
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────┬────────────┬───────┐
|
||||||
|
│ ts_code │ trade_date │ close │
|
||||||
|
├──────────┼────────────┼───────┤
|
||||||
|
│ 000001.SZ│ 20240101 │ 10.5 │
|
||||||
|
│ 000001.SZ│ 20240102 │ 10.6 │
|
||||||
|
│ ... │ ... │ ... │
|
||||||
|
│ 000002.SZ│ 20240101 │ 20.1 │
|
||||||
|
└──────────┴────────────┴───────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
**注意**:数据获取使用用户传入的日期范围,不做自动扩展。对于时序因子(如 `ts_delay(close, 5)`),如果数据不足会返回 null,这是符合预期的行为。用户如需完整计算,应显式扩展日期范围。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. 阶段五:Polars 翻译与计算
|
||||||
|
|
||||||
|
### 6.1 PolarsTranslator 的作用
|
||||||
|
|
||||||
|
将 DSL AST 翻译为 Polars 表达式(惰性计算图):
|
||||||
|
|
||||||
|
```python
|
||||||
|
translator = PolarsTranslator()
|
||||||
|
polars_expr = translator.translate(dsl_node)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6.2 翻译规则
|
||||||
|
|
||||||
|
| DSL 节点类型 | Polars 表达式 | 说明 |
|
||||||
|
|-------------|--------------|------|
|
||||||
|
| `Symbol('close')` | `pl.col('close')` | 列引用 |
|
||||||
|
| `Constant(5)` | `pl.lit(5)` | 字面量 |
|
||||||
|
| `BinaryOpNode('/', a, b)` | `a / b` | 算术运算 |
|
||||||
|
| `FunctionNode('ts_delay', x, n)` | `x.shift(n).over('ts_code')` | 时间序列滞后 |
|
||||||
|
| `FunctionNode('ts_mean', x, n)` | `x.rolling_mean(n).over('ts_code')` | 时间序列均值 |
|
||||||
|
| `FunctionNode('cs_rank', x)` | `x.rank().over('trade_date')` | 截面排名 |
|
||||||
|
|
||||||
|
### 6.3 时间序列函数翻译
|
||||||
|
|
||||||
|
`ts_delay(close, 5)` 翻译为:
|
||||||
|
|
||||||
|
```python
|
||||||
|
pl.col('close').shift(5).over('ts_code')
|
||||||
|
```
|
||||||
|
|
||||||
|
**关键点**:
|
||||||
|
- `.shift(5)`:将序列整体向后平移 5 行,即每一行取该股票 5 个交易日之前的值(前 5 行为 null)
|
||||||
|
- `.over('ts_code')`:按股票代码分组计算(每只股票独立计算)
|
||||||
|
|
||||||
|
### 6.4 计算执行
|
||||||
|
|
||||||
|
`ComputeEngine.execute()` 执行计算:
|
||||||
|
|
||||||
|
```python
|
||||||
|
compute_engine = ComputeEngine()
|
||||||
|
result = compute_engine.execute(plan, wide_data)
|
||||||
|
```
|
||||||
|
|
||||||
|
**执行逻辑**:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def execute(self, plan, data):
|
||||||
|
# 使用 Polars with_columns 添加因子列
|
||||||
|
result = data.with_columns([
|
||||||
|
plan.polars_expr.alias(plan.output_name)
|
||||||
|
])
|
||||||
|
return result
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6.5 计算结果
|
||||||
|
|
||||||
|
最终结果包含原始列和计算出的因子列:
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────┬────────────┬───────┬─────────────┐
|
||||||
|
│ ts_code │ trade_date │ close │ returns_5d │
|
||||||
|
├──────────┼────────────┼───────┼─────────────┤
|
||||||
|
│ 000001.SZ│ 20240101 │ 10.5 │ null │ # 前5天无数据
|
||||||
|
│ 000001.SZ│ 20240106 │ 10.8 │ 0.0286 │ # (10.8/10.5)-1
|
||||||
|
│ ... │ ... │ ... │ ... │
|
||||||
|
└──────────┴────────────┴───────┴─────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 7. 完整调用链示例
|
||||||
|
|
||||||
|
### 7.1 用户代码
|
||||||
|
|
||||||
|
```python
|
||||||
|
from src.factors import FactorEngine, FormulaParser, FunctionRegistry
|
||||||
|
|
||||||
|
# 1. 创建引擎
|
||||||
|
engine = FactorEngine()
|
||||||
|
|
||||||
|
# 2. 解析字符串表达式
|
||||||
|
parser = FormulaParser(FunctionRegistry())
|
||||||
|
expr = parser.parse("(close / ts_delay(close, 5)) - 1")
|
||||||
|
|
||||||
|
# 3. 注册因子
|
||||||
|
engine.register("returns_5d", expr)
|
||||||
|
|
||||||
|
# 4. 执行计算
|
||||||
|
result = engine.compute(
|
||||||
|
factor_names=["returns_5d"],
|
||||||
|
start_date="20240101",
|
||||||
|
end_date="20240131"
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 7.2 内部调用链
|
||||||
|
|
||||||
|
```
|
||||||
|
FactorEngine.compute()
|
||||||
|
│
|
||||||
|
├── 1. 创建 ExecutionPlan
|
||||||
|
│ └── ExecutionPlanner.create_plan()
|
||||||
|
│ ├── DependencyExtractor.extract_dependencies() → {'close'}
|
||||||
|
│ ├── _infer_data_specs() → [DataSpec('pro_bar', ['close'])]
|
||||||
|
│ └── PolarsTranslator.translate() → pl.col('close').shift(5).over('ts_code')...
|
||||||
|
│
|
||||||
|
├── 2. 获取数据
|
||||||
|
│ └── DataRouter.fetch_data(plan.data_specs)
|
||||||
|
│ ├── _load_table('pro_bar', ['close'], start_date, end_date)
|
||||||
|
│ │ └── Storage.load_polars() → 查询 DuckDB
|
||||||
|
│ └── _assemble_wide_table() → Polars DataFrame
|
||||||
|
│
|
||||||
|
└── 3. 执行计算
|
||||||
|
└── ComputeEngine.execute(plan, data)
|
||||||
|
└── data.with_columns([polars_expr.alias('returns_5d')])
|
||||||
|
└── Polars 执行表达式计算
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 8. 数据流时序图
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
sequenceDiagram
|
||||||
|
participant User as 用户代码
|
||||||
|
participant FE as FactorEngine
|
||||||
|
participant PL as ExecutionPlanner
|
||||||
|
participant DR as DataRouter
|
||||||
|
participant ST as Storage
|
||||||
|
participant CE as ComputeEngine
|
||||||
|
participant DB as DuckDB
|
||||||
|
|
||||||
|
User->>FE: compute(["returns_5d"], start, end)
|
||||||
|
|
||||||
|
Note over FE: 阶段1:创建执行计划
|
||||||
|
FE->>PL: create_plan(node, output_name)
|
||||||
|
PL->>PL: extract_dependencies(node)
|
||||||
|
PL->>PL: infer_data_specs(node)
|
||||||
|
PL->>PL: translate_to_polars(node)
|
||||||
|
PL-->>FE: ExecutionPlan
|
||||||
|
|
||||||
|
Note over FE: 阶段2:数据获取
|
||||||
|
FE->>DR: fetch_data(data_specs)
|
||||||
|
|
||||||
|
loop 每个需要的表
|
||||||
|
DR->>ST: load_polars(table, columns, start, end)
|
||||||
|
ST->>DB: SELECT ... WHERE trade_date BETWEEN ...
|
||||||
|
DB-->>ST: 原始数据
|
||||||
|
ST-->>DR: Polars DataFrame
|
||||||
|
end
|
||||||
|
|
||||||
|
DR->>DR: assemble_wide_table()
|
||||||
|
DR->>DR: filter_by_date_range()
|
||||||
|
DR-->>FE: 核心宽表
|
||||||
|
|
||||||
|
Note over FE: 阶段3:执行计算
|
||||||
|
FE->>CE: execute(plan, data)
|
||||||
|
CE->>CE: data.with_columns([polars_expr])
|
||||||
|
CE-->>FE: 计算结果
|
||||||
|
|
||||||
|
FE-->>User: Polars DataFrame (含因子列)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 附录:关键代码片段
|
||||||
|
|
||||||
|
### A.1 FactorEngine.compute 方法
|
||||||
|
|
||||||
|
```python
|
||||||
|
def compute(
|
||||||
|
self,
|
||||||
|
factor_names: list[str],
|
||||||
|
start_date: str,
|
||||||
|
end_date: str
|
||||||
|
) -> pl.DataFrame:
|
||||||
|
# 1. 收集所有执行计划
|
||||||
|
all_plans = []
|
||||||
|
for name in factor_names:
|
||||||
|
plan = self.factors[name]
|
||||||
|
all_plans.append(plan)
|
||||||
|
|
||||||
|
# 2. 合并数据规格
|
||||||
|
merged_specs = self._merge_data_specs([p.data_specs for p in all_plans])
|
||||||
|
|
||||||
|
# 3. 获取数据
|
||||||
|
data_router = DataRouter(self.storage, start_date, end_date)
|
||||||
|
data = data_router.fetch_data(merged_specs)
|
||||||
|
|
||||||
|
# 4. 执行计算
|
||||||
|
compute_engine = ComputeEngine()
|
||||||
|
result = compute_engine.execute_plans(all_plans, data)
|
||||||
|
|
||||||
|
return result
|
||||||
|
```
|
||||||
|
|
||||||
|
### A.2 PolarsTranslator 的函数处理器注册
|
||||||
|
|
||||||
|
```python
|
||||||
|
class PolarsTranslator:
|
||||||
|
def __init__(self):
|
||||||
|
self.handlers: dict[str, Callable] = {}
|
||||||
|
self._register_default_handlers()
|
||||||
|
|
||||||
|
def _register_default_handlers(self):
|
||||||
|
# 时间序列函数
|
||||||
|
self.handlers['ts_delay'] = lambda x, n: x.shift(n)
|
||||||
|
self.handlers['ts_mean'] = lambda x, n: x.rolling_mean(n)
|
||||||
|
self.handlers['ts_std'] = lambda x, n: x.rolling_std(n)
|
||||||
|
|
||||||
|
# 截面函数
|
||||||
|
self.handlers['cs_rank'] = lambda x: x.rank()
|
||||||
|
self.handlers['cs_zscore'] = lambda x: (x - x.mean()) / x.std()
|
||||||
|
```
|
||||||
|
|
||||||
|
### A.3 时序因子数据获取说明
|
||||||
|
|
||||||
|
由于移除了自动日期扩展机制,用户需要显式管理时序因子的日期范围:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 示例:计算 2024-01-15 到 2024-01-20 的 5 日收益率
|
||||||
|
# 需要显式提供足够的历史数据
|
||||||
|
result = engine.compute(
|
||||||
|
factor_names=["returns_5d"], # (close / ts_delay(close, 5)) - 1
|
||||||
|
start_date="20240108", # 向前扩展至少 5 个交易日
|
||||||
|
end_date="20240120"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 在结果中,2024-01-15 之前的日期会因数据不足而返回 null
|
||||||
|
# 用户可以自行过滤到目标日期范围
|
||||||
|
result = result.filter(
|
||||||
|
(pl.col("trade_date") >= "20240115") & (pl.col("trade_date") <= "20240120")
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**设计原则**:
|
||||||
|
- 显式优于隐式:数据来源透明,用户可以完全控制数据范围
|
||||||
|
- 符合 Polars 行为:rolling/shift 操作在窗口不足时返回 null
|
||||||
|
- 可验证性:用户可以明确知道用了哪些数据计算因子
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 总结
|
||||||
|
|
||||||
|
ProStock 的因子计算引擎采用 **DSL(领域特定语言)+ 延迟计算** 的架构设计,具有以下特点:
|
||||||
|
|
||||||
|
1. **声明式**:用户通过数学表达式描述因子逻辑,无需关心实现细节
|
||||||
|
2. **惰性求值**:表达式构建时不立即计算,生成执行计划后统一执行
|
||||||
|
3. **智能数据获取**:自动分析依赖、推导数据规格、按需取数
|
||||||
|
4. **向量化计算**:基于 Polars 的高性能向量化运算,支持时序和截面计算
|
||||||
|
5. **可扩展性**:通过 FunctionRegistry 可以轻松添加新的因子函数
|
||||||
|
|
||||||
|
整个流程从字符串到计算结果,经历了解析 → 编译 → 计划 → 取数 → 计算五个阶段,各组件职责清晰,便于维护和扩展。
|
||||||
@@ -5,6 +5,7 @@ All wrapper files follow the naming convention: api_{data_type}.py
|
|||||||
|
|
||||||
Available APIs:
|
Available APIs:
|
||||||
- api_daily: Daily market data (日线行情)
|
- api_daily: Daily market data (日线行情)
|
||||||
|
- api_daily_basic: Daily basic indicators (每日指标,换手率、PE、PB、市值等)
|
||||||
- api_pro_bar: Pro Bar universal market data (通用行情,后复权)
|
- api_pro_bar: Pro Bar universal market data (通用行情,后复权)
|
||||||
- api_stock_basic: Stock basic information (股票基本信息)
|
- api_stock_basic: Stock basic information (股票基本信息)
|
||||||
- api_trade_cal: Trading calendar (交易日历)
|
- api_trade_cal: Trading calendar (交易日历)
|
||||||
@@ -13,9 +14,10 @@ Available APIs:
|
|||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> from src.data.api_wrappers import get_daily, get_stock_basic, get_trade_cal, get_bak_basic
|
>>> from src.data.api_wrappers import get_daily, get_stock_basic, get_trade_cal, get_bak_basic
|
||||||
>>> from src.data.api_wrappers import get_pro_bar, sync_pro_bar
|
>>> from src.data.api_wrappers import get_pro_bar, sync_pro_bar, get_daily_basic, sync_daily_basic
|
||||||
>>> data = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
|
>>> data = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
|
||||||
>>> pro_data = get_pro_bar('000001.SZ', start_date='20240101', end_date='20240131')
|
>>> pro_data = get_pro_bar('000001.SZ', start_date='20240101', end_date='20240131')
|
||||||
|
>>> daily_basic = get_daily_basic(trade_date='20240101')
|
||||||
>>> stocks = get_stock_basic()
|
>>> stocks = get_stock_basic()
|
||||||
>>> calendar = get_trade_cal('20240101', '20240131')
|
>>> calendar = get_trade_cal('20240101', '20240131')
|
||||||
>>> bak_basic = get_bak_basic(trade_date='20240101')
|
>>> bak_basic = get_bak_basic(trade_date='20240101')
|
||||||
@@ -27,6 +29,12 @@ from src.data.api_wrappers.api_daily import (
|
|||||||
preview_daily_sync,
|
preview_daily_sync,
|
||||||
DailySync,
|
DailySync,
|
||||||
)
|
)
|
||||||
|
from src.data.api_wrappers.api_daily_basic import (
|
||||||
|
get_daily_basic,
|
||||||
|
sync_daily_basic,
|
||||||
|
preview_daily_basic_sync,
|
||||||
|
DailyBasicSync,
|
||||||
|
)
|
||||||
from src.data.api_wrappers.api_pro_bar import (
|
from src.data.api_wrappers.api_pro_bar import (
|
||||||
get_pro_bar,
|
get_pro_bar,
|
||||||
sync_pro_bar,
|
sync_pro_bar,
|
||||||
@@ -55,6 +63,11 @@ __all__ = [
|
|||||||
"sync_daily",
|
"sync_daily",
|
||||||
"preview_daily_sync",
|
"preview_daily_sync",
|
||||||
"DailySync",
|
"DailySync",
|
||||||
|
# Daily basic indicators
|
||||||
|
"get_daily_basic",
|
||||||
|
"sync_daily_basic",
|
||||||
|
"preview_daily_basic_sync",
|
||||||
|
"DailyBasicSync",
|
||||||
# Pro Bar (universal market data)
|
# Pro Bar (universal market data)
|
||||||
"get_pro_bar",
|
"get_pro_bar",
|
||||||
"sync_pro_bar",
|
"sync_pro_bar",
|
||||||
|
|||||||
@@ -495,4 +495,74 @@ df = ts.pro_bar(ts_code='000001.SZ', start_date='20180101', end_date='20181011',
|
|||||||
例如:
|
例如:
|
||||||
|
|
||||||
|
|
||||||
df = ts.pro_bar(ts_code='000001.SH', asset='I', start_date='20180101', end_date='20181011')
|
df = ts.pro_bar(ts_code='000001.SH', asset='I', start_date='20180101', end_date='20181011')
|
||||||
|
|
||||||
|
每日指标
|
||||||
|
接口:daily_basic,可以通过数据工具调试和查看数据。
|
||||||
|
更新时间:交易日每日15点~17点之间
|
||||||
|
描述:获取全部股票每日重要的基本面指标,可用于选股分析、报表展示等。单次请求最大返回6000条数据,可按日线循环提取全部历史。
|
||||||
|
积分:至少2000积分才可以调取,5000积分无总量限制,具体请参阅积分获取办法
|
||||||
|
|
||||||
|
输入参数
|
||||||
|
|
||||||
|
名称 类型 必选 描述
|
||||||
|
ts_code str N 股票代码(与 trade_date 二选一)
|
||||||
|
trade_date str N 交易日期 (二选一)
|
||||||
|
start_date str N 开始日期(YYYYMMDD)
|
||||||
|
end_date str N 结束日期(YYYYMMDD)
|
||||||
|
注:日期都填YYYYMMDD格式,比如20181010
|
||||||
|
|
||||||
|
输出参数
|
||||||
|
|
||||||
|
名称 类型 描述
|
||||||
|
ts_code str TS股票代码
|
||||||
|
trade_date str 交易日期
|
||||||
|
close float 当日收盘价
|
||||||
|
turnover_rate float 换手率(%)
|
||||||
|
turnover_rate_f float 换手率(自由流通股)
|
||||||
|
volume_ratio float 量比
|
||||||
|
pe float 市盈率(总市值/净利润, 亏损的PE为空)
|
||||||
|
pe_ttm float 市盈率(TTM,亏损的PE为空)
|
||||||
|
pb float 市净率(总市值/净资产)
|
||||||
|
ps float 市销率
|
||||||
|
ps_ttm float 市销率(TTM)
|
||||||
|
dv_ratio float 股息率 (%)
|
||||||
|
dv_ttm float 股息率(TTM)(%)
|
||||||
|
total_share float 总股本 (万股)
|
||||||
|
float_share float 流通股本 (万股)
|
||||||
|
free_share float 自由流通股本 (万股)
|
||||||
|
total_mv float 总市值 (万元)
|
||||||
|
circ_mv float 流通市值(万元)
|
||||||
|
接口用法
|
||||||
|
|
||||||
|
|
||||||
|
pro = ts.pro_api()
|
||||||
|
|
||||||
|
df = pro.daily_basic(ts_code='', trade_date='20180726', fields='ts_code,trade_date,turnover_rate,volume_ratio,pe,pb')
|
||||||
|
或者
|
||||||
|
|
||||||
|
|
||||||
|
df = pro.query('daily_basic', ts_code='', trade_date='20180726',fields='ts_code,trade_date,turnover_rate,volume_ratio,pe,pb')
|
||||||
|
数据样例
|
||||||
|
|
||||||
|
ts_code trade_date turnover_rate volume_ratio pe pb
|
||||||
|
0 600230.SH 20180726 2.4584 0.72 8.6928 3.7203
|
||||||
|
1 600237.SH 20180726 1.4737 0.88 166.4001 1.8868
|
||||||
|
2 002465.SZ 20180726 0.7489 0.72 71.8943 2.6391
|
||||||
|
3 300732.SZ 20180726 6.7083 0.77 21.8101 3.2513
|
||||||
|
4 600007.SH 20180726 0.0381 0.61 23.7696 2.3774
|
||||||
|
5 300068.SZ 20180726 1.4583 0.52 27.8166 1.7549
|
||||||
|
6 300552.SZ 20180726 2.0728 0.95 56.8004 2.9279
|
||||||
|
7 601369.SH 20180726 0.2088 0.95 44.1163 1.8001
|
||||||
|
8 002518.SZ 20180726 0.5814 0.76 15.1004 2.5626
|
||||||
|
9 002913.SZ 20180726 12.1096 1.03 33.1279 2.9217
|
||||||
|
10 601818.SH 20180726 0.1893 0.86 6.3064 0.7209
|
||||||
|
11 600926.SH 20180726 0.6065 0.46 9.1772 0.9808
|
||||||
|
12 002166.SZ 20180726 0.7582 0.82 16.9868 3.3452
|
||||||
|
13 600841.SH 20180726 0.3754 1.02 66.2647 2.2302
|
||||||
|
14 300634.SZ 20180726 23.1127 1.26 120.3053 14.3168
|
||||||
|
15 300126.SZ 20180726 1.2304 1.11 348.4306 1.5171
|
||||||
|
16 300718.SZ 20180726 17.6612 0.92 32.0239 3.8661
|
||||||
|
17 000708.SZ 20180726 0.5575 0.70 10.3674 1.0276
|
||||||
|
18 002626.SZ 20180726 0.6187 0.83 22.7580 4.2446
|
||||||
|
19 600816.SH 20180726 0.6745 0.65 11.0778 3.2214
|
||||||
252
src/data/api_wrappers/api_daily_basic.py
Normal file
252
src/data/api_wrappers/api_daily_basic.py
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
"""每日指标数据接口。
|
||||||
|
|
||||||
|
获取全部股票每日重要的基本面指标,包括换手率、市盈率、市净率、
|
||||||
|
总市值、流通市值等,可用于选股分析、报表展示等。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from src.data.client import TushareClient
|
||||||
|
from src.data.api_wrappers.base_sync import DateBasedSync
|
||||||
|
|
||||||
|
|
||||||
|
def get_daily_basic(
    trade_date: Optional[str] = None,
    ts_code: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    client: Optional["TushareClient"] = None,
) -> pd.DataFrame:
    """Fetch daily basic indicators from the Tushare ``daily_basic`` API.

    The endpoint returns per-stock daily fundamentals such as turnover rate,
    PE, PB and market value. Querying by ``trade_date`` returns all stocks
    for that date in one request and is the recommended access pattern.

    Note:
        The Tushare API expects at least one of ``trade_date`` or ``ts_code``.
        This function does not enforce that itself; it forwards whatever
        filters were supplied.

    Args:
        trade_date: Trade date in YYYYMMDD format; selects all stocks for
            that single date.
        ts_code: Stock code (e.g. ``'000001.SZ'``, ``'600000.SH'``).
        start_date: Start of a date range (YYYYMMDD); use with ``end_date``.
        end_date: End of a date range (YYYYMMDD); use with ``start_date``.
        client: Client object exposing ``query(api_name, **params)``. When
            ``None`` a new ``TushareClient`` is created. Pass a shared
            client for concurrent sync work so rate limiting is shared.

    Returns:
        The DataFrame returned by the client. Per the API documentation the
        columns are: ts_code, trade_date, close, turnover_rate (%),
        turnover_rate_f (free-float turnover), volume_ratio, pe, pe_ttm, pb,
        ps, ps_ttm, dv_ratio (%), dv_ttm (%), total_share, float_share,
        free_share (10k shares), total_mv and circ_mv (10k CNY). A column
        named ``date`` is renamed to ``trade_date``.

    Example:
        >>> # All stocks for a single date (recommended)
        >>> data = get_daily_basic(trade_date='20240101')
        >>>
        >>> # One stock on one date
        >>> data = get_daily_basic(ts_code='000001.SZ', trade_date='20240101')
        >>>
        >>> # One stock over a date range
        >>> data = get_daily_basic(
        ...     ts_code='000001.SZ',
        ...     start_date='20240101',
        ...     end_date='20240131'
        ... )
    """
    # Test identity, not truthiness: a caller-supplied client must never be
    # replaced just because the object happens to evaluate as falsy.
    if client is None:
        client = TushareClient()

    # Forward only the filters that were actually supplied; None and empty
    # strings are dropped so the API never sees blank parameters.
    candidate_filters = {
        "trade_date": trade_date,
        "ts_code": ts_code,
        "start_date": start_date,
        "end_date": end_date,
    }
    params = {name: value for name, value in candidate_filters.items() if value}

    data = client.query("daily_basic", **params)

    # Normalise the date column name so downstream code can rely on
    # "trade_date" regardless of what the client returned.
    if "date" in data.columns:
        data = data.rename(columns={"date": "trade_date"})

    return data
|
||||||
|
|
||||||
|
|
||||||
|
class DailyBasicSync(DateBasedSync):
    """Batch synchroniser for daily basic indicators (full or incremental).

    Extends DateBasedSync and retrieves data date by date: a single API call
    per trade date returns the indicators for the whole market.

    Example:
        >>> sync = DailyBasicSync()
        >>> results = sync.sync_all()  # incremental sync
        >>> results = sync.sync_all(force_full=True)  # full sync
        >>> preview = sync.preview_sync()  # preview only
    """

    table_name = "daily_basic"
    default_start_date = "20180101"

    # Table schema: the two key columns, followed by the numeric indicator
    # columns, all of which share the DOUBLE type.
    TABLE_SCHEMA = {
        "ts_code": "VARCHAR(16) NOT NULL",
        "trade_date": "DATE NOT NULL",
        **{
            indicator: "DOUBLE"
            for indicator in (
                "close",
                "turnover_rate",
                "turnover_rate_f",
                "volume_ratio",
                "pe",
                "pe_ttm",
                "pb",
                "ps",
                "ps_ttm",
                "dv_ratio",
                "dv_ttm",
                "total_share",
                "float_share",
                "free_share",
                "total_mv",
                "circ_mv",
            )
        },
    }

    # Index definitions as (index name, indexed columns)
    TABLE_INDEXES = [
        ("idx_daily_basic_date_code", ["trade_date", "ts_code"]),
    ]

    # Primary key columns
    PRIMARY_KEY = ("ts_code", "trade_date")

    def fetch_single_date(self, trade_date: str) -> pd.DataFrame:
        """Fetch the daily basic indicators for one trade date.

        Args:
            trade_date: Trade date (YYYYMMDD).

        Returns:
            DataFrame with the indicators of every stock for that date.
        """
        # Hand over this instance's client so the request goes through the
        # shared client (and its rate limiting) instead of a fresh one.
        return get_daily_basic(trade_date=trade_date, client=self.client)
|
||||||
|
|
||||||
|
|
||||||
|
def sync_daily_basic(
    force_full: bool = False,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    dry_run: bool = False,
) -> pd.DataFrame:
    """Synchronise daily basic indicators for all stocks.

    Main entry point for daily-basic synchronisation; it creates a
    DailyBasicSync and delegates to its ``sync_all`` method.

    Args:
        force_full: When True, force a complete reload starting at 20180101.
        start_date: Manually chosen start date (YYYYMMDD).
        end_date: Manually chosen end date (defaults to today).
        dry_run: When True, only preview what would be synced; nothing is
            written.

    Returns:
        DataFrame with the synced data.

    Example:
        >>> # First sync (full load from 20180101)
        >>> result = sync_daily_basic()
        >>>
        >>> # Later syncs (incremental - new data only)
        >>> result = sync_daily_basic()
        >>>
        >>> # Force a complete reload
        >>> result = sync_daily_basic(force_full=True)
        >>>
        >>> # Explicit date range
        >>> result = sync_daily_basic(start_date='20240101', end_date='20240131')
        >>>
        >>> # Dry run (preview only)
        >>> result = sync_daily_basic(dry_run=True)
    """
    sync_options = {
        "force_full": force_full,
        "start_date": start_date,
        "end_date": end_date,
        "dry_run": dry_run,
    }
    return DailyBasicSync().sync_all(**sync_options)
|
||||||
|
|
||||||
|
|
||||||
|
def preview_daily_basic_sync(
    force_full: bool = False,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    sample_size: int = 3,
) -> Dict[str, Any]:
    """Preview the volume and a sample of a daily-basic sync without syncing.

    Recommended before a real sync, to inspect what would be synchronised.
    Delegates to ``DailyBasicSync.preview_sync``.

    Args:
        force_full: When True, preview a full sync (from 20180101).
        start_date: Manually chosen start date (overrides auto-detection).
        end_date: Manually chosen end date (defaults to today).
        sample_size: Number of sample days used for the preview (default: 3).

    Returns:
        Dictionary with the preview information:
        {
            'sync_needed': bool,
            'date_count': int,
            'start_date': str,
            'end_date': str,
            'estimated_records': int,
            'sample_data': pd.DataFrame,
            'mode': str,  # 'full', 'incremental' or 'none'
        }

    Example:
        >>> # Preview what would be synced
        >>> preview = preview_daily_basic_sync()
        >>>
        >>> # Preview a full sync
        >>> preview = preview_daily_basic_sync(force_full=True)
        >>>
        >>> # Preview with more samples
        >>> preview = preview_daily_basic_sync(sample_size=5)
    """
    preview_options = {
        "force_full": force_full,
        "start_date": start_date,
        "end_date": end_date,
        "sample_size": sample_size,
    }
    return DailyBasicSync().preview_sync(**preview_options)
|
||||||
@@ -7,6 +7,7 @@
|
|||||||
|
|
||||||
✅ 本模块包含的同步逻辑(每日更新):
|
✅ 本模块包含的同步逻辑(每日更新):
|
||||||
- api_daily.py: 日线数据同步 (DailySync 类)
|
- api_daily.py: 日线数据同步 (DailySync 类)
|
||||||
|
- api_daily_basic.py: 每日指标数据同步 (DailyBasicSync 类)
|
||||||
- api_bak_basic.py: 历史股票列表同步 (BakBasicSync 类)
|
- api_bak_basic.py: 历史股票列表同步 (BakBasicSync 类)
|
||||||
- api_pro_bar.py: Pro Bar 数据同步 (ProBarSync 类)
|
- api_pro_bar.py: Pro Bar 数据同步 (ProBarSync 类)
|
||||||
- api_stock_basic.py: 股票基本信息同步
|
- api_stock_basic.py: 股票基本信息同步
|
||||||
@@ -44,6 +45,7 @@ from src.data.api_wrappers import sync_all_stocks
|
|||||||
from src.data.api_wrappers.api_daily import sync_daily, preview_daily_sync
|
from src.data.api_wrappers.api_daily import sync_daily, preview_daily_sync
|
||||||
from src.data.api_wrappers.api_pro_bar import sync_pro_bar
|
from src.data.api_wrappers.api_pro_bar import sync_pro_bar
|
||||||
from src.data.api_wrappers.api_bak_basic import sync_bak_basic
|
from src.data.api_wrappers.api_bak_basic import sync_bak_basic
|
||||||
|
from src.data.api_wrappers.api_daily_basic import sync_daily_basic
|
||||||
|
|
||||||
|
|
||||||
def preview_sync(
|
def preview_sync(
|
||||||
@@ -157,7 +159,8 @@ def sync_all_data(
|
|||||||
2. 股票基本信息 (sync_all_stocks)
|
2. 股票基本信息 (sync_all_stocks)
|
||||||
3. 日线数据 (sync_daily)
|
3. 日线数据 (sync_daily)
|
||||||
4. Pro Bar 数据 (sync_pro_bar)
|
4. Pro Bar 数据 (sync_pro_bar)
|
||||||
5. 历史股票列表 (sync_bak_basic)
|
5. 每日指标数据 (sync_daily_basic)
|
||||||
|
6. 历史股票列表 (sync_bak_basic)
|
||||||
|
|
||||||
【不包含的同步(需单独调用)】
|
【不包含的同步(需单独调用)】
|
||||||
- 财务数据: 利润表、资产负债表、现金流量表(季度更新)
|
- 财务数据: 利润表、资产负债表、现金流量表(季度更新)
|
||||||
@@ -238,7 +241,7 @@ def sync_all_data(
|
|||||||
results["daily"] = pd.DataFrame()
|
results["daily"] = pd.DataFrame()
|
||||||
|
|
||||||
# 4. Sync Pro Bar data
|
# 4. Sync Pro Bar data
|
||||||
print("\n[4/5] Syncing Pro Bar data (with adj, tor, vr)...")
|
print("\n[4/6] Syncing Pro Bar data (with adj, tor, vr)...")
|
||||||
try:
|
try:
|
||||||
# 确保表存在
|
# 确保表存在
|
||||||
from src.data.api_wrappers.api_pro_bar import ProBarSync
|
from src.data.api_wrappers.api_pro_bar import ProBarSync
|
||||||
@@ -255,14 +258,31 @@ def sync_all_data(
|
|||||||
sum(len(df) for df in pro_bar_result.values()) if pro_bar_result else 0
|
sum(len(df) for df in pro_bar_result.values()) if pro_bar_result else 0
|
||||||
)
|
)
|
||||||
print(
|
print(
|
||||||
f"[4/5] Pro Bar data: OK ({total_pro_bar_records} records from {len(pro_bar_result)} stocks)"
|
f"[4/6] Pro Bar data: OK ({total_pro_bar_records} records from {len(pro_bar_result)} stocks)"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[4/5] Pro Bar data: FAILED - {e}")
|
print(f"[4/6] Pro Bar data: FAILED - {e}")
|
||||||
results["pro_bar"] = pd.DataFrame()
|
results["pro_bar"] = pd.DataFrame()
|
||||||
|
|
||||||
# 5. Sync stock historical list (bak_basic)
|
# 5. Sync daily basic indicators
|
||||||
print("\n[5/5] Syncing stock historical list (bak_basic)...")
|
print(
|
||||||
|
"\n[5/6] Syncing daily basic indicators (PE, PB, turnover rate, market value)..."
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
# 确保表存在
|
||||||
|
from src.data.api_wrappers.api_daily_basic import DailyBasicSync
|
||||||
|
|
||||||
|
DailyBasicSync().ensure_table_exists()
|
||||||
|
|
||||||
|
daily_basic_result = sync_daily_basic(force_full=force_full, dry_run=dry_run)
|
||||||
|
results["daily_basic"] = daily_basic_result
|
||||||
|
print(f"[5/6] Daily basic: OK ({len(daily_basic_result)} records)")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[5/6] Daily basic: FAILED - {e}")
|
||||||
|
results["daily_basic"] = pd.DataFrame()
|
||||||
|
|
||||||
|
# 6. Sync stock historical list (bak_basic)
|
||||||
|
print("\n[6/6] Syncing stock historical list (bak_basic)...")
|
||||||
try:
|
try:
|
||||||
# 确保表存在
|
# 确保表存在
|
||||||
from src.data.api_wrappers.api_bak_basic import BakBasicSync
|
from src.data.api_wrappers.api_bak_basic import BakBasicSync
|
||||||
@@ -271,9 +291,9 @@ def sync_all_data(
|
|||||||
|
|
||||||
bak_basic_result = sync_bak_basic(force_full=force_full)
|
bak_basic_result = sync_bak_basic(force_full=force_full)
|
||||||
results["bak_basic"] = bak_basic_result
|
results["bak_basic"] = bak_basic_result
|
||||||
print(f"[5/5] Bak basic: OK ({len(bak_basic_result)} records)")
|
print(f"[6/6] Bak basic: OK ({len(bak_basic_result)} records)")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[5/5] Bak basic: FAILED - {e}")
|
print(f"[6/6] Bak basic: FAILED - {e}")
|
||||||
results["bak_basic"] = pd.DataFrame()
|
results["bak_basic"] = pd.DataFrame()
|
||||||
|
|
||||||
# Summary
|
# Summary
|
||||||
@@ -286,7 +306,7 @@ def sync_all_data(
|
|||||||
total_records = sum(len(df) for df in data.values())
|
total_records = sum(len(df) for df in data.values())
|
||||||
print(f" {data_type}: {len(data)} stocks, {total_records} total records")
|
print(f" {data_type}: {len(data)} stocks, {total_records} total records")
|
||||||
else:
|
else:
|
||||||
# bak_basic 返回的是 DataFrame
|
# daily_basic 和 bak_basic 返回的是 DataFrame
|
||||||
print(f" {data_type}: {len(data)} records")
|
print(f" {data_type}: {len(data)} records")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print("\nNote: namechange is NOT in auto-sync. To sync manually:")
|
print("\nNote: namechange is NOT in auto-sync. To sync manually:")
|
||||||
@@ -308,7 +328,7 @@ if __name__ == "__main__":
|
|||||||
print("")
|
print("")
|
||||||
print(" # Or sync individual data types:")
|
print(" # Or sync individual data types:")
|
||||||
print(" from src.data.sync import sync_all, preview_sync")
|
print(" from src.data.sync import sync_all, preview_sync")
|
||||||
print(" from src.data.sync import sync_bak_basic")
|
print(" from src.data.api_wrappers import sync_daily_basic, sync_bak_basic")
|
||||||
print("")
|
print("")
|
||||||
print(" # Preview before sync (recommended)")
|
print(" # Preview before sync (recommended)")
|
||||||
print(" preview = preview_sync()")
|
print(" preview = preview_sync()")
|
||||||
|
|||||||
@@ -69,16 +69,11 @@ class DataRouter:
|
|||||||
|
|
||||||
# 收集所有需要的表和字段
|
# 收集所有需要的表和字段
|
||||||
required_tables: Dict[str, Set[str]] = {}
|
required_tables: Dict[str, Set[str]] = {}
|
||||||
max_lookback = 0
|
|
||||||
|
|
||||||
for spec in data_specs:
|
for spec in data_specs:
|
||||||
if spec.table not in required_tables:
|
if spec.table not in required_tables:
|
||||||
required_tables[spec.table] = set()
|
required_tables[spec.table] = set()
|
||||||
required_tables[spec.table].update(spec.columns)
|
required_tables[spec.table].update(spec.columns)
|
||||||
max_lookback = max(max_lookback, spec.lookback_days)
|
|
||||||
|
|
||||||
# 调整日期范围以包含回看期
|
|
||||||
adjusted_start = self._adjust_start_date(start_date, max_lookback)
|
|
||||||
|
|
||||||
# 从数据源获取各表数据
|
# 从数据源获取各表数据
|
||||||
table_data = {}
|
table_data = {}
|
||||||
@@ -86,7 +81,7 @@ class DataRouter:
|
|||||||
df = self._load_table(
|
df = self._load_table(
|
||||||
table_name=table_name,
|
table_name=table_name,
|
||||||
columns=list(columns),
|
columns=list(columns),
|
||||||
start_date=adjusted_start,
|
start_date=start_date,
|
||||||
end_date=end_date,
|
end_date=end_date,
|
||||||
stock_codes=stock_codes,
|
stock_codes=stock_codes,
|
||||||
)
|
)
|
||||||
@@ -95,11 +90,6 @@ class DataRouter:
|
|||||||
# 组装核心宽表
|
# 组装核心宽表
|
||||||
core_table = self._assemble_wide_table(table_data, required_tables)
|
core_table = self._assemble_wide_table(table_data, required_tables)
|
||||||
|
|
||||||
# 过滤到实际请求日期范围
|
|
||||||
core_table = core_table.filter(
|
|
||||||
(pl.col("trade_date") >= start_date) & (pl.col("trade_date") <= end_date)
|
|
||||||
)
|
|
||||||
|
|
||||||
return core_table
|
return core_table
|
||||||
|
|
||||||
def _load_table(
|
def _load_table(
|
||||||
@@ -265,34 +255,6 @@ class DataRouter:
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _adjust_start_date(self, start_date: str, lookback_days: int) -> str:
|
|
||||||
"""根据回看天数调整开始日期。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
start_date: 原始开始日期 (YYYYMMDD)
|
|
||||||
lookback_days: 需要回看的交易日数
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
调整后的开始日期
|
|
||||||
"""
|
|
||||||
# 简化的日期调整:假设每月30天,向前推移
|
|
||||||
# 实际应用中应该使用交易日历
|
|
||||||
year = int(start_date[:4])
|
|
||||||
month = int(start_date[4:6])
|
|
||||||
day = int(start_date[6:8])
|
|
||||||
|
|
||||||
total_days = lookback_days + 30 # 额外缓冲
|
|
||||||
|
|
||||||
day -= total_days
|
|
||||||
while day <= 0:
|
|
||||||
month -= 1
|
|
||||||
if month <= 0:
|
|
||||||
month = 12
|
|
||||||
year -= 1
|
|
||||||
day += 30
|
|
||||||
|
|
||||||
return f"{year:04d}{month:02d}{day:02d}"
|
|
||||||
|
|
||||||
def clear_cache(self) -> None:
|
def clear_cache(self) -> None:
|
||||||
"""清除数据缓存。"""
|
"""清除数据缓存。"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
|
|||||||
@@ -18,12 +18,10 @@ class DataSpec:
|
|||||||
Attributes:
|
Attributes:
|
||||||
table: 数据表名称
|
table: 数据表名称
|
||||||
columns: 需要的字段列表
|
columns: 需要的字段列表
|
||||||
lookback_days: 回看天数(用于时序计算)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
table: str
|
table: str
|
||||||
columns: List[str]
|
columns: List[str]
|
||||||
lookback_days: int = 1
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -73,9 +73,9 @@ class ExecutionPlanner:
|
|||||||
) -> List[DataSpec]:
|
) -> List[DataSpec]:
|
||||||
"""从依赖推导数据规格。
|
"""从依赖推导数据规格。
|
||||||
|
|
||||||
根据表达式中的函数类型推断回看天数需求。
|
|
||||||
基础行情字段(open, high, low, close, vol, amount, pre_close, change, pct_chg)
|
基础行情字段(open, high, low, close, vol, amount, pre_close, change, pct_chg)
|
||||||
默认从 pro_bar 表获取。
|
默认从 pro_bar 表获取。
|
||||||
|
每日指标字段(total_mv, circ_mv, pe, pb 等)从 daily_basic 表获取。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dependencies: 依赖的字段集合
|
dependencies: 依赖的字段集合
|
||||||
@@ -84,10 +84,6 @@ class ExecutionPlanner:
|
|||||||
Returns:
|
Returns:
|
||||||
数据规格列表
|
数据规格列表
|
||||||
"""
|
"""
|
||||||
# 计算最大回看窗口
|
|
||||||
max_window = self._extract_max_window(expression)
|
|
||||||
lookback_days = max(1, max_window)
|
|
||||||
|
|
||||||
# 基础行情字段集合(这些字段从 pro_bar 表获取)
|
# 基础行情字段集合(这些字段从 pro_bar 表获取)
|
||||||
pro_bar_fields = {
|
pro_bar_fields = {
|
||||||
"open",
|
"open",
|
||||||
@@ -103,9 +99,27 @@ class ExecutionPlanner:
|
|||||||
"volume_ratio",
|
"volume_ratio",
|
||||||
}
|
}
|
||||||
|
|
||||||
# 将依赖分为 pro_bar 字段和其他字段
|
# 每日指标字段集合(这些字段从 daily_basic 表获取)
|
||||||
|
daily_basic_fields = {
|
||||||
|
"turnover_rate_f",
|
||||||
|
"pe",
|
||||||
|
"pe_ttm",
|
||||||
|
"pb",
|
||||||
|
"ps",
|
||||||
|
"ps_ttm",
|
||||||
|
"dv_ratio",
|
||||||
|
"dv_ttm",
|
||||||
|
"total_share",
|
||||||
|
"float_share",
|
||||||
|
"free_share",
|
||||||
|
"total_mv",
|
||||||
|
"circ_mv",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 将依赖分为不同表的字段
|
||||||
pro_bar_deps = dependencies & pro_bar_fields
|
pro_bar_deps = dependencies & pro_bar_fields
|
||||||
other_deps = dependencies - pro_bar_fields
|
daily_basic_deps = dependencies & daily_basic_fields
|
||||||
|
other_deps = dependencies - pro_bar_fields - daily_basic_fields
|
||||||
|
|
||||||
data_specs = []
|
data_specs = []
|
||||||
|
|
||||||
@@ -115,7 +129,15 @@ class ExecutionPlanner:
|
|||||||
DataSpec(
|
DataSpec(
|
||||||
table="pro_bar",
|
table="pro_bar",
|
||||||
columns=sorted(pro_bar_deps),
|
columns=sorted(pro_bar_deps),
|
||||||
lookback_days=lookback_days,
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# daily_basic 表的数据规格
|
||||||
|
if daily_basic_deps:
|
||||||
|
data_specs.append(
|
||||||
|
DataSpec(
|
||||||
|
table="daily_basic",
|
||||||
|
columns=sorted(daily_basic_deps),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -125,46 +147,7 @@ class ExecutionPlanner:
|
|||||||
DataSpec(
|
DataSpec(
|
||||||
table="daily",
|
table="daily",
|
||||||
columns=sorted(other_deps),
|
columns=sorted(other_deps),
|
||||||
lookback_days=lookback_days,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return data_specs
|
return data_specs
|
||||||
|
|
||||||
def _extract_max_window(self, node: Node) -> int:
|
|
||||||
"""从表达式中提取最大窗口大小。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
node: AST 节点
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
最大窗口大小,无时序函数返回 1
|
|
||||||
"""
|
|
||||||
if isinstance(node, FunctionNode):
|
|
||||||
window = 1
|
|
||||||
# 检查函数参数中的窗口大小
|
|
||||||
for arg in node.args:
|
|
||||||
if (
|
|
||||||
isinstance(arg, Constant)
|
|
||||||
and isinstance(arg.value, int)
|
|
||||||
and arg.value > window
|
|
||||||
):
|
|
||||||
window = arg.value
|
|
||||||
|
|
||||||
# 递归检查子表达式
|
|
||||||
for arg in node.args:
|
|
||||||
if isinstance(arg, Node) and not isinstance(arg, Constant):
|
|
||||||
window = max(window, self._extract_max_window(arg))
|
|
||||||
|
|
||||||
return window
|
|
||||||
|
|
||||||
elif isinstance(node, BinaryOpNode):
|
|
||||||
return max(
|
|
||||||
self._extract_max_window(node.left),
|
|
||||||
self._extract_max_window(node.right),
|
|
||||||
)
|
|
||||||
|
|
||||||
elif isinstance(node, UnaryOpNode):
|
|
||||||
return self._extract_max_window(node.operand)
|
|
||||||
|
|
||||||
return 1
|
|
||||||
|
|||||||
@@ -1,302 +0,0 @@
|
|||||||
"""训练流程入口脚本
|
|
||||||
|
|
||||||
运行方式:
|
|
||||||
uv run python -m src.training.main
|
|
||||||
|
|
||||||
或:
|
|
||||||
uv run python src/training/main.py
|
|
||||||
|
|
||||||
本脚本提供两种运行方式:
|
|
||||||
1. run_full_pipeline(): 完整训练流程(数据准备 -> 训练 -> 预测)
|
|
||||||
2. prepare_data_and_train() + train_and_predict(): 分步执行,便于调试和调整
|
|
||||||
|
|
||||||
因子配置示例:
|
|
||||||
from src.factors import MovingAverageFactor, ReturnRankFactor
|
|
||||||
|
|
||||||
# 直接传入因子实例列表 - 最简单的方式
|
|
||||||
factors = [
|
|
||||||
MovingAverageFactor(period=5),
|
|
||||||
MovingAverageFactor(period=10),
|
|
||||||
MovingAverageFactor(period=20),
|
|
||||||
ReturnRankFactor(period=5),
|
|
||||||
ReturnRankFactor(period=10),
|
|
||||||
]
|
|
||||||
|
|
||||||
# 运行完整流程
|
|
||||||
result = run_full_pipeline(factors=factors)
|
|
||||||
"""
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional, List
|
|
||||||
|
|
||||||
import polars as pl
|
|
||||||
|
|
||||||
from src.factors import BaseFactor
|
|
||||||
from src.training.pipeline import (
|
|
||||||
FactorConfig,
|
|
||||||
predict_top_stocks,
|
|
||||||
prepare_data,
|
|
||||||
save_top_stocks,
|
|
||||||
train_model,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_data_and_train(
|
|
||||||
factors: Optional[List[BaseFactor]] = None,
|
|
||||||
data_dir: str = "data",
|
|
||||||
train_start: str = "20190101",
|
|
||||||
train_end: str = "20231231",
|
|
||||||
val_start: str = "20240102",
|
|
||||||
val_end: str = "20240531",
|
|
||||||
test_start: str = "20240602",
|
|
||||||
test_end: str = "20241231",
|
|
||||||
) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame, FactorConfig, str]:
|
|
||||||
"""第一步:数据处理
|
|
||||||
|
|
||||||
加载原始数据,计算因子和标签,拆分训练/验证/测试集。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
factors: 因子实例列表,默认为 None(使用 MA5, MA10, ReturnRank5)
|
|
||||||
data_dir: 数据目录
|
|
||||||
train_start: 训练集开始日期
|
|
||||||
train_end: 训练集结束日期
|
|
||||||
val_start: 验证集开始日期
|
|
||||||
val_end: 验证集结束日期
|
|
||||||
test_start: 测试集开始日期
|
|
||||||
test_end: 测试集结束日期
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple: (train_data, val_data, test_data, factor_config, label_col)
|
|
||||||
"""
|
|
||||||
print("=" * 50)
|
|
||||||
print("[Step 1] 数据处理")
|
|
||||||
print("=" * 50)
|
|
||||||
print(f"训练集: {train_start} -> {train_end}")
|
|
||||||
print(f"验证集: {val_start} -> {val_end}")
|
|
||||||
print(f"测试集: {test_start} -> {test_end}")
|
|
||||||
print()
|
|
||||||
|
|
||||||
# 1. 准备数据
|
|
||||||
train_data, val_data, test_data, factor_config = prepare_data(
|
|
||||||
factors=factors,
|
|
||||||
data_dir=data_dir,
|
|
||||||
train_start=train_start,
|
|
||||||
train_end=train_end,
|
|
||||||
val_start=val_start,
|
|
||||||
val_end=val_end,
|
|
||||||
test_start=test_start,
|
|
||||||
test_end=test_end,
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"训练集样本数: {len(train_data)}")
|
|
||||||
print(f"验证集样本数: {len(val_data)}")
|
|
||||||
print(f"测试集样本数: {len(test_data)}")
|
|
||||||
print()
|
|
||||||
|
|
||||||
# 打印少量数据样本展示
|
|
||||||
print("=" * 50)
|
|
||||||
print("[数据预览] 训练集前3行:")
|
|
||||||
print(train_data.head(3))
|
|
||||||
print()
|
|
||||||
print("[数据预览] 验证集前3行:")
|
|
||||||
print(val_data.head(3))
|
|
||||||
print()
|
|
||||||
print("[数据预览] 测试集前3行:")
|
|
||||||
print(test_data.head(3))
|
|
||||||
print()
|
|
||||||
|
|
||||||
# 2. 获取特征列名
|
|
||||||
feature_cols = factor_config.get_feature_names()
|
|
||||||
label_col = "label"
|
|
||||||
|
|
||||||
print(f"特征列: {feature_cols}")
|
|
||||||
print(f"标签列: {label_col}")
|
|
||||||
print()
|
|
||||||
|
|
||||||
return train_data, val_data, test_data, factor_config, label_col
|
|
||||||
|
|
||||||
|
|
||||||
def train_and_predict(
|
|
||||||
train_data: pl.DataFrame,
|
|
||||||
val_data: pl.DataFrame,
|
|
||||||
test_data: pl.DataFrame,
|
|
||||||
factor_config: FactorConfig,
|
|
||||||
label_col: str = "label",
|
|
||||||
top_n: int = 5,
|
|
||||||
output_path: str = "output/top_stocks.tsv",
|
|
||||||
) -> pl.DataFrame:
|
|
||||||
"""第二步:训练和预测
|
|
||||||
|
|
||||||
使用处理好的数据训练模型,进行测试集预测并保存结果。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
train_data: 训练数据
|
|
||||||
val_data: 验证数据
|
|
||||||
test_data: 测试数据
|
|
||||||
factor_config: 因子配置对象
|
|
||||||
label_col: 标签列名
|
|
||||||
top_n: 每日选股数量
|
|
||||||
output_path: 输出文件路径
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
选股结果DataFrame
|
|
||||||
"""
|
|
||||||
print("=" * 50)
|
|
||||||
print("[Step 2] 模型训练与预测")
|
|
||||||
print("=" * 50)
|
|
||||||
print()
|
|
||||||
|
|
||||||
# 获取特征列名
|
|
||||||
feature_cols = factor_config.get_feature_names()
|
|
||||||
print(f"使用特征: {feature_cols}")
|
|
||||||
print()
|
|
||||||
|
|
||||||
# 3. 训练模型
|
|
||||||
print("[Training] Training model...")
|
|
||||||
model, pipeline = train_model(
|
|
||||||
train_data=train_data,
|
|
||||||
val_data=val_data,
|
|
||||||
feature_cols=feature_cols,
|
|
||||||
label_col=label_col,
|
|
||||||
)
|
|
||||||
print()
|
|
||||||
|
|
||||||
# 4. 测试集预测
|
|
||||||
print("[Predict] Predicting on test set...")
|
|
||||||
top_stocks = predict_top_stocks(
|
|
||||||
model=model,
|
|
||||||
pipeline=pipeline,
|
|
||||||
test_data=test_data,
|
|
||||||
feature_cols=feature_cols,
|
|
||||||
top_n=top_n,
|
|
||||||
)
|
|
||||||
print()
|
|
||||||
|
|
||||||
# 5. 保存结果
|
|
||||||
print(f"[Saving] Saving results to {output_path}...")
|
|
||||||
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
save_top_stocks(top_stocks, output_path)
|
|
||||||
print()
|
|
||||||
|
|
||||||
return top_stocks
|
|
||||||
|
|
||||||
|
|
||||||
def run_full_pipeline(
    factors: Optional[List[BaseFactor]] = None,
    train_start: str = "20190101",
    train_end: str = "20231231",
    val_start: str = "20240102",
    val_end: str = "20240531",
    test_start: str = "20240602",
    test_end: str = "20241231",
    top_n: int = 5,
    output_path: str = "output/top_stocks.tsv",
) -> pl.DataFrame:
    """Run the whole training workflow in one call.

    Equivalent to calling ``prepare_data_and_train`` followed by
    ``train_and_predict`` with the values it returned.

    Args:
        factors: List of factor instances. ``None`` is forwarded unchanged to
            ``prepare_data_and_train``, which picks the defaults (the
            original note lists MA5, MA10 and ReturnRank5).
        train_start: First date of the training split.
        train_end: Last date of the training split.
        val_start: First date of the validation split.
        val_end: Last date of the validation split.
        test_start: First date of the test split.
        test_end: Last date of the test split.
        top_n: Number of stocks to select per day.
        output_path: Destination file for the selection result.

    Returns:
        The stock selection DataFrame returned by ``train_and_predict``.
    """
    # Step 1: data preparation.
    split_bounds = {
        "train_start": train_start,
        "train_end": train_end,
        "val_start": val_start,
        "val_end": val_end,
        "test_start": test_start,
        "test_end": test_end,
    }
    train_set, val_set, test_set, config, label_name = prepare_data_and_train(
        factors=factors,
        **split_bounds,
    )

    # Step 2: training and prediction.
    top_stocks = train_and_predict(
        train_data=train_set,
        val_data=val_set,
        test_data=test_set,
        factor_config=config,
        label_col=label_name,
        top_n=top_n,
        output_path=output_path,
    )

    banner = "=" * 50
    print(banner)
    print("[Done] 训练流程完成!")
    print(banner)

    return top_stocks
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    from src.factors import MovingAverageFactor, ReturnRankFactor

    # ========== Factor configuration ==========
    # Factor instances are passed in directly: 5/10/20-day moving averages
    # followed by 5/10-day return ranks.
    factors = [MovingAverageFactor(period=window) for window in (5, 10, 20)]
    factors += [ReturnRankFactor(period=window) for window in (5, 10)]

    # ========== How to run ==========

    # Option 1: full workflow in a single call.
    # result = run_full_pipeline(
    #     factors=factors,
    #     train_start="20190101",
    #     train_end="20231231",
    #     val_start="20240102",
    #     val_end="20240531",
    #     test_start="20240602",
    #     test_end="20241231",
    #     top_n=5,
    #     output_path="output/top_stocks.tsv",
    # )

    # Option 2: step by step (easier to debug).
    # Step 1: data preparation.
    split_dates = {
        "train_start": "20190101",
        "train_end": "20231231",
        "val_start": "20240102",
        "val_end": "20240531",
        "test_start": "20240602",
        "test_end": "20241231",
    }
    train_data, val_data, test_data, factor_config, label_col = prepare_data_and_train(
        factors=factors,
        **split_dates,
    )

    # Custom logic can be inserted here, e.g. inspecting the data
    # distribution, adjusting features or saving intermediate results.
    print("\n[Info] 因子配置详情:")
    print(f" 因子列表: {factor_config.get_feature_names()}")
    print(f" 最大回溯天数: {factor_config.get_max_lookback()}")

    # Step 2: training and prediction (disabled; uncomment to run).
    # result = train_and_predict(
    #     train_data=train_data,
    #     val_data=val_data,
    #     test_data=test_data,
    #     factor_config=factor_config,
    #     label_col=label_col,
    #     top_n=5,
    #     output_path="output/top_stocks.tsv",
    # )
    #
    # print("\n[Result] Top stocks selection:")
    # print(result)
|
|
||||||
@@ -71,7 +71,7 @@ class TestFactorEngineEndToEnd:
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def engine(self, mock_data):
|
def engine(self, mock_data):
|
||||||
"""提供配置好的 FactorEngine fixture。"""
|
"""提供配置好的 FactorEngine fixture。"""
|
||||||
data_source = {"daily": mock_data}
|
data_source = {"pro_bar": mock_data}
|
||||||
return FactorEngine(data_source=data_source, max_workers=2)
|
return FactorEngine(data_source=data_source, max_workers=2)
|
||||||
|
|
||||||
def test_simple_symbol_expression(self, engine):
|
def test_simple_symbol_expression(self, engine):
|
||||||
@@ -116,7 +116,7 @@ class TestFullWorkflow:
|
|||||||
|
|
||||||
# 2. 初始化引擎
|
# 2. 初始化引擎
|
||||||
print("\nStep 2: Initialize FactorEngine...")
|
print("\nStep 2: Initialize FactorEngine...")
|
||||||
engine = FactorEngine(data_source={"daily": mock_data})
|
engine = FactorEngine(data_source={"pro_bar": mock_data})
|
||||||
print(" Engine initialized")
|
print(" Engine initialized")
|
||||||
|
|
||||||
# 3. 注册因子 - 使用简单因子避免回看窗口问题
|
# 3. 注册因子 - 使用简单因子避免回看窗口问题
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
2. return_5_rank: 5日收益率在截面上的排名
|
2. return_5_rank: 5日收益率在截面上的排名
|
||||||
3. ma5: 5日均线 (ts_mean(close, 5))
|
3. ma5: 5日均线 (ts_mean(close, 5))
|
||||||
4. ma10: 10日均线 (ts_mean(close, 10))
|
4. ma10: 10日均线 (ts_mean(close, 10))
|
||||||
|
5. market_cap_rank: 市值百分比排名 (cs_rank(total_mv))
|
||||||
|
|
||||||
特点:使用因子字符串架构(add_factor + 字符串表达式)
|
特点:使用因子字符串架构(add_factor + 字符串表达式)
|
||||||
|
|
||||||
@@ -48,6 +49,11 @@ def test_two_stocks_string_factors():
|
|||||||
print("\n[1.4] ma10 = ts_mean(close, 10)")
|
print("\n[1.4] ma10 = ts_mean(close, 10)")
|
||||||
print(f" 字符串表达式: {ma10_str}")
|
print(f" 字符串表达式: {ma10_str}")
|
||||||
|
|
||||||
|
# market_cap_rank: 市值百分比排名 (截面排名)
|
||||||
|
market_cap_rank_str = "cs_rank(total_mv)"
|
||||||
|
print("\n[1.5] market_cap_rank = cs_rank(total_mv)")
|
||||||
|
print(f" 字符串表达式: {market_cap_rank_str}")
|
||||||
|
|
||||||
# ========================================================================
|
# ========================================================================
|
||||||
# 1.5 打印数据来源信息
|
# 1.5 打印数据来源信息
|
||||||
# ========================================================================
|
# ========================================================================
|
||||||
@@ -66,6 +72,7 @@ def test_two_stocks_string_factors():
|
|||||||
"return_5_rank": return_5_rank_str,
|
"return_5_rank": return_5_rank_str,
|
||||||
"ma5": ma5_str,
|
"ma5": ma5_str,
|
||||||
"ma10": ma10_str,
|
"ma10": ma10_str,
|
||||||
|
"market_cap_rank": market_cap_rank_str,
|
||||||
}
|
}
|
||||||
|
|
||||||
for name, expr_str in expressions_str.items():
|
for name, expr_str in expressions_str.items():
|
||||||
@@ -102,9 +109,12 @@ def test_two_stocks_string_factors():
|
|||||||
engine.add_factor("ma10", ma10_str)
|
engine.add_factor("ma10", ma10_str)
|
||||||
print("[2.4] 注册 ma10 (字符串方式)")
|
print("[2.4] 注册 ma10 (字符串方式)")
|
||||||
|
|
||||||
|
engine.add_factor("market_cap_rank", market_cap_rank_str)
|
||||||
|
print("[2.5] 注册 market_cap_rank (市值百分比排名,字符串方式)")
|
||||||
|
|
||||||
# 也注册原始 close 价格用于验证
|
# 也注册原始 close 价格用于验证
|
||||||
engine.add_factor("close_price", "close")
|
engine.add_factor("close_price", "close")
|
||||||
print("[2.5] 注册 close_price (原始收盘价,字符串方式)")
|
print("[2.6] 注册 close_price (原始收盘价,字符串方式)")
|
||||||
|
|
||||||
print(f"\n已注册因子列表: {engine.list_registered()}")
|
print(f"\n已注册因子列表: {engine.list_registered()}")
|
||||||
|
|
||||||
@@ -125,7 +135,6 @@ def test_two_stocks_string_factors():
|
|||||||
for i, spec in enumerate(plan.data_specs, 1):
|
for i, spec in enumerate(plan.data_specs, 1):
|
||||||
print(f" [{i}] 表名: {spec.table}")
|
print(f" [{i}] 表名: {spec.table}")
|
||||||
print(f" 字段: {spec.columns}")
|
print(f" 字段: {spec.columns}")
|
||||||
print(f" 回看天数: {spec.lookback_days}")
|
|
||||||
|
|
||||||
# ========================================================================
|
# ========================================================================
|
||||||
# 3. 执行计算(两支股票)
|
# 3. 执行计算(两支股票)
|
||||||
@@ -143,7 +152,14 @@ def test_two_stocks_string_factors():
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
result = engine.compute(
|
result = engine.compute(
|
||||||
factor_names=["return_5", "return_5_rank", "ma5", "ma10", "close_price"],
|
factor_names=[
|
||||||
|
"return_5",
|
||||||
|
"return_5_rank",
|
||||||
|
"ma5",
|
||||||
|
"ma10",
|
||||||
|
"close_price",
|
||||||
|
"market_cap_rank",
|
||||||
|
],
|
||||||
start_date=start_date,
|
start_date=start_date,
|
||||||
end_date=end_date,
|
end_date=end_date,
|
||||||
stock_codes=stock_codes,
|
stock_codes=stock_codes,
|
||||||
@@ -345,6 +361,10 @@ def test_two_stocks_string_factors():
|
|||||||
print(" - 每天两支股票的排名之和应接近 1")
|
print(" - 每天两支股票的排名之和应接近 1")
|
||||||
print("-" * 60)
|
print("-" * 60)
|
||||||
|
|
||||||
|
# 5.5.1 return_5_rank 截面排名验证
|
||||||
|
print("\n[5.5.1] return_5_rank 截面排名验证:")
|
||||||
|
print("-" * 60)
|
||||||
|
|
||||||
# 获取有效数据
|
# 获取有效数据
|
||||||
result_valid = result.drop_nulls(subset=["return_5_rank"])
|
result_valid = result.drop_nulls(subset=["return_5_rank"])
|
||||||
|
|
||||||
@@ -373,6 +393,39 @@ def test_two_stocks_string_factors():
|
|||||||
else:
|
else:
|
||||||
print(" [警告] 截面排名之和不接近 1")
|
print(" [警告] 截面排名之和不接近 1")
|
||||||
|
|
||||||
|
# 5.5.2 market_cap_rank 市值百分比排名验证
|
||||||
|
print("\n[5.5.2] market_cap_rank 市值百分比排名验证:")
|
||||||
|
print("-" * 60)
|
||||||
|
|
||||||
|
result_valid_mv = result.drop_nulls(subset=["market_cap_rank"])
|
||||||
|
|
||||||
|
if len(result_valid_mv) > 0:
|
||||||
|
min_rank_mv = result_valid_mv["market_cap_rank"].min()
|
||||||
|
max_rank_mv = result_valid_mv["market_cap_rank"].max()
|
||||||
|
print(f"\n市值排名范围: [{min_rank_mv:.4f}, {max_rank_mv:.4f}]")
|
||||||
|
|
||||||
|
if 0 <= min_rank_mv <= 1 and 0 <= max_rank_mv <= 1:
|
||||||
|
print(" [成功] 市值排名值在 [0, 1] 区间内!")
|
||||||
|
else:
|
||||||
|
print(" [警告] 市值排名值超出 [0, 1] 区间")
|
||||||
|
|
||||||
|
# 检查某天两支股票的市值排名之和
|
||||||
|
sample_date_mv = result_valid_mv["trade_date"][0]
|
||||||
|
day_data_mv = result_valid_mv.filter(
|
||||||
|
result_valid_mv["trade_date"] == sample_date_mv
|
||||||
|
)
|
||||||
|
if len(day_data_mv) == 2:
|
||||||
|
rank_sum_mv = day_data_mv["market_cap_rank"].sum()
|
||||||
|
print(f"\n示例日期 {sample_date_mv} 的市值排名验证:")
|
||||||
|
for row in day_data_mv.iter_rows(named=True):
|
||||||
|
print(f" {row['ts_code']}: {row['market_cap_rank']:.4f}")
|
||||||
|
print(f" 排名之和: {rank_sum_mv:.4f} (两支股票应接近 1)")
|
||||||
|
|
||||||
|
if abs(rank_sum_mv - 1.0) < 0.01:
|
||||||
|
print(" [成功] 市值排名之和验证通过!")
|
||||||
|
else:
|
||||||
|
print(" [警告] 市值排名之和不接近 1")
|
||||||
|
|
||||||
# ========================================================================
|
# ========================================================================
|
||||||
# 6. 统计摘要
|
# 6. 统计摘要
|
||||||
# ========================================================================
|
# ========================================================================
|
||||||
@@ -394,7 +447,7 @@ def test_two_stocks_string_factors():
|
|||||||
print(f"总记录数: {len(stock_data)}")
|
print(f"总记录数: {len(stock_data)}")
|
||||||
print(f"有效记录数 (去空值后): {len(stock_valid)}")
|
print(f"有效记录数 (去空值后): {len(stock_valid)}")
|
||||||
|
|
||||||
factor_cols = ["return_5", "return_5_rank", "ma5", "ma10"]
|
factor_cols = ["return_5", "return_5_rank", "ma5", "ma10", "market_cap_rank"]
|
||||||
|
|
||||||
for col in factor_cols:
|
for col in factor_cols:
|
||||||
if col in stock_data.columns:
|
if col in stock_data.columns:
|
||||||
@@ -418,7 +471,6 @@ def test_two_stocks_string_factors():
|
|||||||
# ========================================================================
|
# ========================================================================
|
||||||
print("\n" + "=" * 80)
|
print("\n" + "=" * 80)
|
||||||
|
|
||||||
|
|
||||||
# ========================================================================
|
# ========================================================================
|
||||||
# 8. 测试总结
|
# 8. 测试总结
|
||||||
# ========================================================================
|
# ========================================================================
|
||||||
@@ -434,10 +486,13 @@ def test_two_stocks_string_factors():
|
|||||||
print()
|
print()
|
||||||
print("因子定义方式: 字符串表达式 (add_factor 方法)")
|
print("因子定义方式: 字符串表达式 (add_factor 方法)")
|
||||||
print("计算因子:")
|
print("计算因子:")
|
||||||
print(" 1. return_5 - 5日收益率 (字符串: '(close / ts_delay(close, 5)) - 1')")
|
print(
|
||||||
print(" 2. return_5_rank - 5日收益率截面排名 (字符串: 'cs_rank(...)')")
|
" 1. return_5 - 5日收益率 (字符串: '(close / ts_delay(close, 5)) - 1')"
|
||||||
print(" 3. ma5 - 5日均线 (字符串: 'ts_mean(close, 5)')")
|
)
|
||||||
print(" 4. ma10 - 10日均线 (字符串: 'ts_mean(close, 10)')")
|
print(" 2. return_5_rank - 5日收益率截面排名 (字符串: 'cs_rank(...)')")
|
||||||
|
print(" 3. ma5 - 5日均线 (字符串: 'ts_mean(close, 5)')")
|
||||||
|
print(" 4. ma10 - 10日均线 (字符串: 'ts_mean(close, 10)')")
|
||||||
|
print(" 5. market_cap_rank - 市值百分比排名 (字符串: 'cs_rank(total_mv)')")
|
||||||
print()
|
print()
|
||||||
print("验证结果:")
|
print("验证结果:")
|
||||||
print(" - 字符串表达式解析: 正常")
|
print(" - 字符串表达式解析: 正常")
|
||||||
|
|||||||
Reference in New Issue
Block a user