feat(data): 添加每日指标接口并优化因子引擎
- 新增 api_daily_basic.py 封装 Tushare 每日指标接口 - 因子引擎移除 lookback_days,支持 daily_basic 表字段路由 - 将每日指标纳入自动同步流程 - 删除废弃的 training/main.py
This commit is contained in:
131
AGENTS.md
131
AGENTS.md
@@ -82,20 +82,21 @@ ProStock/
|
||||
│ │
|
||||
│ ├── data/ # 数据获取与存储
|
||||
│ │ ├── api_wrappers/ # Tushare API 封装
|
||||
│ │ │ ├── base_sync.py # 同步基础抽象类(BaseDataSync/StockBasedSync/DateBasedSync)
|
||||
│ │ │ ├── api_daily.py # 日线数据接口(DailySync)
|
||||
│ │ │ ├── api_pro_bar.py # Pro Bar 数据接口(ProBarSync)
|
||||
│ │ │ ├── api_stock_basic.py # 股票基础信息接口
|
||||
│ │ │ ├── api_trade_cal.py # 交易日历接口
|
||||
│ │ │ ├── api_bak_basic.py # 历史股票列表接口(BakBasicSync)
|
||||
│ │ │ ├── api_namechange.py # 股票名称变更接口
|
||||
│ │ │ ├── financial_data/ # 财务数据接口
|
||||
│ │ │ │ ├── api_income.py # 利润表接口
|
||||
│ │ │ ├── base_sync.py # 同步基础抽象类(BaseDataSync/StockBasedSync/DateBasedSync)
|
||||
│ │ │ ├── api_daily.py # 日线数据接口(DailySync)
|
||||
│ │ │ ├── api_pro_bar.py # Pro Bar 数据接口(ProBarSync)
|
||||
│ │ │ ├── api_stock_basic.py # 股票基础信息接口
|
||||
│ │ │ ├── api_trade_cal.py # 交易日历接口
|
||||
│ │ │ ├── api_bak_basic.py # 历史股票列表接口(BakBasicSync)
|
||||
│ │ │ ├── api_namechange.py # 股票名称变更接口
|
||||
│ │ │ ├── financial_data/ # 财务数据接口
|
||||
│ │ │ │ ├── api_income.py # 利润表接口
|
||||
│ │ │ │ └── api_financial_sync.py # 财务数据同步
|
||||
│ │ │ └── __init__.py
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── client.py # Tushare API 客户端(带速率限制)
|
||||
│ │ ├── config.py # 数据模块配置
|
||||
│ │ ├── data_router.py # 数据路由器(factors/engine 专用)
|
||||
│ │ ├── db_inspector.py # 数据库信息查看工具
|
||||
│ │ ├── db_manager.py # DuckDB 表管理和同步
|
||||
│ │ ├── rate_limiter.py # 令牌桶速率限制器
|
||||
@@ -104,20 +105,29 @@ ProStock/
|
||||
│ │ └── utils.py # 数据模块工具函数
|
||||
│ │
|
||||
│ ├── factors/ # 因子计算框架(DSL 表达式驱动)
|
||||
│ │ ├── engine/ # 执行引擎子模块
|
||||
│ │ │ ├── __init__.py # 导出引擎组件
|
||||
│ │ │ ├── data_spec.py # 数据规格定义(DataSpec, ExecutionPlan)
|
||||
│ │ │ ├── data_router.py # 数据路由器
|
||||
│ │ │ ├── planner.py # 执行计划生成器(ExecutionPlanner)
|
||||
│ │ │ ├── compute_engine.py # 计算引擎(ComputeEngine)
|
||||
│ │ │ └── factor_engine.py # 因子引擎统一入口(FactorEngine)
|
||||
│ │ ├── __init__.py # 导出所有公开 API
|
||||
│ │ ├── dsl.py # DSL 表达式层 - 节点定义和运算符重载
|
||||
│ │ ├── api.py # API 层 - 常用符号(close/open等)和函数(ts_mean/cs_rank等)
|
||||
│ │ ├── compiler.py # AST 编译器 - 依赖提取
|
||||
│ │ ├── translator.py # Polars 表达式翻译器
|
||||
│ │ └── engine.py # 因子执行引擎 - 统一入口
|
||||
│ │ ├── parser.py # 字符串公式解析器(FormulaParser)
|
||||
│ │ ├── registry.py # 函数注册表(FunctionRegistry)
|
||||
│ │ └── exceptions.py # 异常定义(FormulaParseError等)
|
||||
│ │
|
||||
│ ├── pipeline/ # 模型训练管道
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── pipeline.py # 处理流水线
|
||||
│ │ ├── registry.py # 插件注册中心
|
||||
│ │ ├── pipeline.py # 处理流水线(ProcessingPipeline)
|
||||
│ │ ├── registry.py # 插件注册中心(PluginRegistry)
|
||||
│ │ ├── core/ # 核心抽象
|
||||
│ │ │ ├── __init__.py
|
||||
│ │ │ ├── base.py # 基类定义
|
||||
│ │ │ ├── base.py # 基类定义(BaseProcessor/BaseModel/BaseSplitter等)
|
||||
│ │ │ └── splitter.py # 时间序列划分策略
|
||||
│ │ ├── models/ # 模型实现
|
||||
│ │ │ ├── __init__.py
|
||||
@@ -128,14 +138,23 @@ ProStock/
|
||||
│ │
|
||||
│ └── training/ # 训练入口
|
||||
│ ├── __init__.py
|
||||
│ ├── main.py # 训练主程序
|
||||
│ ├── pipeline.py # 训练流程配置
|
||||
│ └── output/ # 训练输出
|
||||
│ └── top_stocks.tsv # 推荐股票结果
|
||||
│
|
||||
├── tests/ # 测试文件
|
||||
│ ├── test_sync.py
|
||||
│ └── test_daily.py
|
||||
│ ├── test_daily.py
|
||||
│ ├── test_factor_engine.py
|
||||
│ ├── test_factor_integration.py
|
||||
│ ├── test_pro_bar.py
|
||||
│ ├── test_601117_factors.py
|
||||
│ ├── test_two_stocks_string_factors.py
|
||||
│ ├── test_db_manager.py
|
||||
│ ├── test_daily_storage.py
|
||||
│ ├── test_tushare_api.py
|
||||
│ └── pipeline/
|
||||
│ └── test_core.py
|
||||
├── config/ # 配置文件
|
||||
│ └── .env.local # 环境变量(不在 git 中)
|
||||
├── data/ # 数据存储(DuckDB)
|
||||
@@ -266,10 +285,13 @@ except Exception as e:
|
||||
### 依赖项
|
||||
关键包:
|
||||
- `pandas>=2.0.0` - 数据处理
|
||||
- `polars>=0.20.0` - 高性能数据处理(因子计算)
|
||||
- `numpy>=1.24.0` - 数值计算
|
||||
- `tushare>=2.0.0` - A股数据 API
|
||||
- `pydantic>=2.0.0`、`pydantic-settings>=2.0.0` - 配置
|
||||
- `tqdm>=4.65.0` - 进度条
|
||||
- `lightgbm>=4.0.0` - 机器学习模型
|
||||
- `catboost>=1.2.0` - 机器学习模型
|
||||
- `pytest` - 测试(开发)
|
||||
|
||||
### 环境变量
|
||||
@@ -292,6 +314,9 @@ uv run python -c "from src.data.sync import sync_all; sync_all(force_full=True)"
|
||||
|
||||
# 自定义线程数
|
||||
uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)"
|
||||
|
||||
# 运行因子计算测试
|
||||
uv run pytest tests/test_factor_engine.py -v
|
||||
```
|
||||
|
||||
## 架构变更历史
|
||||
@@ -309,8 +334,16 @@ uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)"
|
||||
- 新增 `api.py`: 常用符号(close/open/volume等)和函数(ts_mean/cs_rank等)
|
||||
- 新增 `compiler.py`: AST 编译器,提取表达式依赖
|
||||
- 新增 `translator.py`: 将 DSL 表达式翻译为 Polars 表达式
|
||||
- 重构 `engine.py`: 统一执行引擎入口,整合 DataRouter、ExecutionPlanner、ComputeEngine
|
||||
- 移除: `base.py`、`composite.py`、`data_loader.py`、`data_spec.py`
|
||||
- 新增 `parser.py`: 字符串公式解析器(FormulaParser),支持从字符串解析 DSL 表达式
|
||||
- 新增 `registry.py`: 函数注册表(FunctionRegistry),管理字符串函数名到 Python 函数的映射
|
||||
- 新增 `exceptions.py`: 公式解析异常定义(FormulaParseError、UnknownFunctionError等)
|
||||
- 重构 `engine/` 子模块:
|
||||
- `factor_engine.py`: 因子引擎统一入口(FactorEngine)
|
||||
- `data_spec.py`: 数据规格定义(DataSpec, ExecutionPlan)
|
||||
- `data_router.py`: 数据路由器
|
||||
- `planner.py`: 执行计划生成器(ExecutionPlanner)
|
||||
- `compute_engine.py`: 计算引擎(ComputeEngine)
|
||||
- 移除: `base.py`、`composite.py`、`data_loader.py`、根目录的 `data_spec.py`
|
||||
- 移除: `factors/momentum/` 和 `factors/financial/` 子目录
|
||||
**使用方式对比**:
|
||||
```python
|
||||
@@ -328,6 +361,11 @@ uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)"
|
||||
engine = FactorEngine()
|
||||
engine.register("ma20", ma20)
|
||||
result = engine.compute(["ma20"], "20240101", "20240131")
|
||||
|
||||
# 字符串公式解析(Phase 1 新增)
|
||||
from src.factors import FormulaParser, FunctionRegistry
|
||||
parser = FormulaParser(FunctionRegistry())
|
||||
expr = parser.parse("ts_mean(close, 20) / close") # 从字符串解析
|
||||
```
|
||||
|
||||
#### data 模块补充完善
|
||||
@@ -411,10 +449,20 @@ DSL 层 (dsl.py) <- 因子表达式 (Node)
|
||||
Compiler (compiler.py) <- AST 依赖提取
|
||||
|
|
||||
v
|
||||
Parser (parser.py) <- 字符串公式解析器
|
||||
|
|
||||
v
|
||||
Registry (registry.py) <- 函数注册表
|
||||
|
|
||||
v
|
||||
Translator (translator.py) <- 翻译为 Polars 表达式
|
||||
|
|
||||
v
|
||||
Engine (engine.py) <- 执行引擎 (DataRouter/ExecutionPlanner/ComputeEngine)
|
||||
Engine (engine/) <- 执行引擎
|
||||
| - FactorEngine: 统一入口
|
||||
| - DataRouter: 数据路由
|
||||
| - ExecutionPlanner: 执行计划
|
||||
| - ComputeEngine: 计算引擎
|
||||
|
|
||||
v
|
||||
数据层 (data_router.py + DuckDB) <- 数据获取和存储
|
||||
@@ -422,7 +470,7 @@ Engine (engine.py) <- 执行引擎 (DataRouter/ExecutionPlanner/ComputeEngi
|
||||
|
||||
### 使用方式
|
||||
|
||||
#### 1. 基础表达式
|
||||
#### 1. 基础表达式(DSL 方式)
|
||||
|
||||
```python
|
||||
from src.factors.api import close, open, ts_mean, cs_rank
|
||||
@@ -435,15 +483,37 @@ price_rank = cs_rank(close) # 收盘价截面排名
|
||||
alpha = ma20 * 0.6 + price_rank * 0.4
|
||||
```
|
||||
|
||||
#### 2. 注册和执行
|
||||
#### 2. 字符串公式解析(Phase 1 新增)
|
||||
|
||||
```python
|
||||
from src.factors import FormulaParser, FunctionRegistry
|
||||
|
||||
# 创建解析器(自动加载 api.py 中的所有函数)
|
||||
parser = FormulaParser(FunctionRegistry())
|
||||
|
||||
# 从字符串解析公式
|
||||
expr = parser.parse("ts_mean(close, 20) / close")
|
||||
complex_expr = parser.parse("cs_rank(ts_mean(close, 5) - ts_mean(close, 20))")
|
||||
|
||||
# 支持完整运算符和函数调用
|
||||
expr2 = parser.parse("(close - open) / open * 100") # 涨跌幅
|
||||
expr3 = parser.parse("ts_corr(close, volume, 20)") # 量价相关性
|
||||
```
|
||||
|
||||
#### 3. 注册和执行
|
||||
|
||||
```python
|
||||
from src.factors import FactorEngine
|
||||
|
||||
engine = FactorEngine()
|
||||
|
||||
# 注册 DSL 表达式
|
||||
engine.register("ma20", ma20)
|
||||
engine.register("price_rank", price_rank)
|
||||
|
||||
# 或注册字符串解析的表达式
|
||||
engine.register("alpha", parser.parse("ma20 * 0.6 + price_rank * 0.4"))
|
||||
|
||||
# 执行计算
|
||||
result = engine.compute(
|
||||
factor_names=["ma20", "price_rank"],
|
||||
@@ -504,6 +574,27 @@ expr3 = -change # 涨跌额取反
|
||||
expr4 = ts_mean(cs_rank(close), 20) # 排名后的20日平滑
|
||||
```
|
||||
|
||||
### 异常处理
|
||||
|
||||
框架提供清晰的异常类型帮助定位问题:
|
||||
|
||||
- `FormulaParseError` - 公式解析错误基类
|
||||
- `UnknownFunctionError` - 未知函数错误(提供模糊匹配建议)
|
||||
- `InvalidSyntaxError` - 语法错误
|
||||
- `EmptyExpressionError` - 空表达式错误
|
||||
- `DuplicateFunctionError` - 函数重复注册错误
|
||||
|
||||
示例:
|
||||
```python
|
||||
from src.factors import FormulaParser, FunctionRegistry, UnknownFunctionError
|
||||
|
||||
parser = FormulaParser(FunctionRegistry())
|
||||
try:
|
||||
expr = parser.parse("unknown_func(close)")
|
||||
except UnknownFunctionError as e:
|
||||
print(e) # 显示错误位置和可用函数建议
|
||||
```
|
||||
|
||||
|
||||
## AI 行为准则
|
||||
## AI 行为准则
|
||||
|
||||
557
docs/factor_calculation_flow.md
Normal file
557
docs/factor_calculation_flow.md
Normal file
@@ -0,0 +1,557 @@
|
||||
# 因子计算流程详解
|
||||
|
||||
本文档详细描述 ProStock 项目中因子计算引擎的完整数据流,以因子表达式 `(close / ts_delay(close, 5)) - 1` 为例,说明从字符串解析到最终计算结果的完整流程。
|
||||
|
||||
## 目录
|
||||
|
||||
1. [整体架构概览](#1-整体架构概览)
|
||||
2. [阶段一:字符串解析](#2-阶段一字符串解析)
|
||||
3. [阶段二:AST 编译与依赖提取](#3-阶段二ast-编译与依赖提取)
|
||||
4. [阶段三:执行计划生成](#4-阶段三执行计划生成)
|
||||
5. [阶段四:数据获取](#5-阶段四数据获取)
|
||||
6. [阶段五:Polars 翻译与计算](#6-阶段五polars-翻译与计算)
|
||||
7. [完整调用链示例](#7-完整调用链示例)
|
||||
8. [数据流时序图](#8-数据流时序图)
|
||||
|
||||
---
|
||||
|
||||
## 1. 整体架构概览
|
||||
|
||||
### 1.1 架构层次
|
||||
|
||||
因子框架采用分层设计,从上到下依次为:
|
||||
|
||||
```
|
||||
API 层 (api.py)
|
||||
|
|
||||
v
|
||||
DSL 层 (dsl.py) ← 因子表达式 (Node)
|
||||
|
|
||||
v
|
||||
Parser (parser.py) ← 字符串公式解析
|
||||
|
|
||||
v
|
||||
Registry (registry.py) ← 函数注册表
|
||||
|
|
||||
v
|
||||
Compiler (compiler.py) ← AST 依赖提取
|
||||
|
|
||||
v
|
||||
Planner (planner.py) ← 执行计划生成
|
||||
|
|
||||
v
|
||||
Translator (translator.py) ← 翻译为 Polars 表达式
|
||||
|
|
||||
v
|
||||
Engine (engine/) ← 执行引擎
|
||||
| - FactorEngine: 统一入口
|
||||
| - DataRouter: 数据路由
|
||||
| - ExecutionPlanner: 执行计划
|
||||
| - ComputeEngine: 计算引擎
|
||||
|
|
||||
v
|
||||
数据层 (data/storage.py) ← DuckDB 数据获取和存储
|
||||
```
|
||||
|
||||
### 1.2 核心组件职责
|
||||
|
||||
| 组件 | 文件路径 | 主要职责 |
|
||||
|------|----------|----------|
|
||||
| **FormulaParser** | `factors/parser.py` | 将字符串表达式解析为 DSL AST 节点树 |
|
||||
| **FunctionRegistry** | `factors/registry.py` | 管理函数名到 Python 实现的映射 |
|
||||
| **DSL Nodes** | `factors/dsl.py` | 定义表达式节点(Symbol, FunctionNode, BinaryOpNode 等)|
|
||||
| **DependencyExtractor** | `factors/compiler.py` | 从 AST 提取依赖的数据字段 |
|
||||
| **ExecutionPlanner** | `factors/engine/planner.py` | 整合编译器和翻译器生成执行计划 |
|
||||
| **DataRouter** | `factors/engine/data_router.py` | 按需取数、组装核心宽表 |
|
||||
| **PolarsTranslator** | `factors/translator.py` | 将 DSL AST 翻译为 Polars 表达式 |
|
||||
| **ComputeEngine** | `factors/engine/compute_engine.py` | 执行 Polars 表达式计算 |
|
||||
| **FactorEngine** | `factors/engine/factor_engine.py` | 系统统一入口,协调各组件 |
|
||||
| **Storage** | `data/storage.py` | DuckDB 数据存储和查询接口 |
|
||||
|
||||
---
|
||||
|
||||
## 2. 阶段一:字符串解析
|
||||
|
||||
### 2.1 解析流程
|
||||
|
||||
当用户调用 `parser.parse("(close / ts_delay(close, 5)) - 1")` 时,解析流程如下:
|
||||
|
||||
```python
|
||||
# 1. FormulaParser.parse() 方法
|
||||
formula = "(close / ts_delay(close, 5)) - 1"
|
||||
ast_tree = ast.parse(formula, mode='eval') # Python AST
|
||||
dsl_node = self._visit(ast_tree.body) # 递归转换为 DSL Node
|
||||
```
|
||||
|
||||
**解析步骤:**
|
||||
|
||||
1. **Python AST 解析**:使用 Python 标准库 `ast.parse()` 将字符串解析为 Python AST
|
||||
2. **AST 遍历转换**:通过 `_visit()` 方法递归遍历 Python AST,映射为 DSL 节点
|
||||
3. **节点类型映射**:
|
||||
- `ast.Name` → `Symbol`
|
||||
- `ast.Constant` → `Constant`
|
||||
- `ast.BinOp` → `BinaryOpNode`
|
||||
- `ast.UnaryOp` → `UnaryOpNode`
|
||||
- `ast.Call` → `FunctionNode`
|
||||
|
||||
### 2.2 解析结果 AST 结构
|
||||
|
||||
```
|
||||
BinaryOpNode(op='-', left, right=Constant(1))
|
||||
├── left: BinaryOpNode(op='/', left, right)
|
||||
│ ├── left: Symbol('close')
|
||||
│ └── right: FunctionNode('ts_delay', [Symbol('close'), Constant(5)])
|
||||
└── right: Constant(1)
|
||||
```
|
||||
|
||||
### 2.3 函数解析机制
|
||||
|
||||
对于函数调用 `ts_delay(close, 5)`:
|
||||
|
||||
```python
|
||||
# _visit_Call 方法处理逻辑
|
||||
def _visit_Call(self, node: ast.Call) -> FunctionNode:
|
||||
func_name = node.func.id # "ts_delay"
|
||||
|
||||
# 从注册表获取函数实现
|
||||
if self.registry.has(func_name):
|
||||
func_impl = self.registry.get(func_name)
|
||||
|
||||
# 递归解析参数
|
||||
args = [self._visit(arg) for arg in node.args]
|
||||
|
||||
# 调用函数实现,返回 FunctionNode
|
||||
return func_impl(*args)
|
||||
```
|
||||
|
||||
**关键点**:`ts_delay` 函数在 `api.py` 中定义:
|
||||
|
||||
```python
|
||||
def ts_delay(x: NodeOrStr, periods: int) -> FunctionNode:
|
||||
return FunctionNode('ts_delay', to_node(x), Constant(periods))
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 3. 阶段二:AST 编译与依赖提取
|
||||
|
||||
### 3.1 依赖提取流程
|
||||
|
||||
解析后的 AST 需要提取依赖的原始数据字段:
|
||||
|
||||
```python
|
||||
# DependencyExtractor.extract_dependencies(node)
|
||||
extractor = DependencyExtractor()
|
||||
dependencies = extractor.extract_dependencies(dsl_node)
|
||||
# 结果: {'close'}
|
||||
```
|
||||
|
||||
**提取逻辑**:
|
||||
|
||||
```python
|
||||
def _visit(self, node):
|
||||
if isinstance(node, Symbol):
|
||||
return {node.name} # 收集字段名
|
||||
elif isinstance(node, (BinaryOpNode, FunctionNode)):
|
||||
# 递归收集子节点依赖
|
||||
deps = set()
|
||||
for child in node.args:
|
||||
deps.update(self._visit(child))
|
||||
return deps
|
||||
# ... 其他节点类型
|
||||
```
|
||||
|
||||
### 3.2 依赖的作用
|
||||
|
||||
提取的依赖 `{close}` 用于:
|
||||
1. **数据规格推导**:确定需要从数据库读取哪些字段
|
||||
2. **执行计划生成**:明确数据需求,避免读取不必要的字段
|
||||
|
||||
---
|
||||
|
||||
## 4. 阶段三:执行计划生成
|
||||
|
||||
### 4.1 ExecutionPlanner 的作用
|
||||
|
||||
`ExecutionPlanner.create_plan()` 将 AST 转换为可执行的 `ExecutionPlan`:
|
||||
|
||||
```python
|
||||
planner = ExecutionPlanner()
|
||||
plan = planner.create_plan(
|
||||
node=dsl_node, # 解析后的 DSL 节点
|
||||
output_name="returns_5d" # 输出列名
|
||||
)
|
||||
```
|
||||
|
||||
### 4.2 计划生成流程
|
||||
|
||||
```python
|
||||
def create_plan(self, node, output_name):
|
||||
# 1. 提取依赖
|
||||
dependencies = self.dependency_extractor.extract_dependencies(node)
|
||||
|
||||
# 2. 推导数据规格
|
||||
data_specs = self._infer_data_specs(node, dependencies)
|
||||
# 结果: [DataSpec(table='pro_bar', columns=['close'])]
|
||||
|
||||
# 3. 翻译为 Polars 表达式
|
||||
polars_expr = self.polars_translator.translate(node)
|
||||
# 结果: (pl.col('close') / pl.col('close').shift(5)) - 1
|
||||
|
||||
# 4. 构建执行计划
|
||||
return ExecutionPlan(
|
||||
data_specs=data_specs,
|
||||
polars_expr=polars_expr,
|
||||
dependencies=dependencies,
|
||||
output_name=output_name
|
||||
)
|
||||
```
|
||||
|
||||
### 4.3 数据规格推导
|
||||
|
||||
根据依赖字段,`_infer_data_specs` 推导出需要的数据规格:
|
||||
|
||||
```python
|
||||
def _infer_data_specs(self, node, dependencies):
|
||||
return [
|
||||
DataSpec(
|
||||
table='pro_bar', # 默认使用 pro_bar 表
|
||||
columns=list(dependencies), # ['close']
|
||||
)
|
||||
]
|
||||
```
|
||||
|
||||
**DataSpec 说明**:
|
||||
- `table`: 数据表名(pro_bar 或 daily)
|
||||
- `columns`: 需要的字段列表
|
||||
|
||||
**注意**:数据获取使用用户传入的日期范围,不做自动扩展。时序因子(如 `ts_delay`、`ts_mean`)在数据不足时会返回 null,这是符合预期的行为。
|
||||
|
||||
---
|
||||
|
||||
## 5. 阶段四:数据获取
|
||||
|
||||
### 5.1 DataRouter 的核心职责
|
||||
|
||||
`DataRouter.fetch_data()` 按需取数、组装核心宽表:
|
||||
|
||||
```python
|
||||
data_router = DataRouter(storage, start_date, end_date)
|
||||
wide_data = data_router.fetch_data(plan.data_specs)
|
||||
```
|
||||
|
||||
### 5.2 数据获取流程
|
||||
|
||||
```python
|
||||
def fetch_data(self, data_specs, start_date, end_date):
|
||||
# 1. 合并数据规格,收集所需表和字段
|
||||
required_tables = self._collect_required_tables(data_specs)
|
||||
|
||||
# 2. 加载各表数据
|
||||
table_data = {}
|
||||
for table, columns in required_tables.items():
|
||||
table_data[table] = self._load_table(table, columns, start_date, end_date)
|
||||
|
||||
# 3. 组装宽表(left join 合并)
|
||||
wide_table = self._assemble_wide_table(table_data, data_specs)
|
||||
|
||||
return wide_table
|
||||
```
|
||||
|
||||
### 5.3 数据库查询
|
||||
|
||||
`_load_table` 方法从 DuckDB 读取数据:
|
||||
|
||||
```python
|
||||
def _load_table(self, table, columns, start_date, end_date):
|
||||
# 通过 Storage 查询数据库
|
||||
df = self.storage.load_polars(
|
||||
table_name=table,
|
||||
columns=columns + ['ts_code', 'trade_date'], # 必须包含主键
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
return df
|
||||
```
|
||||
|
||||
**Storage.load_polars 内部实现**:
|
||||
|
||||
```python
|
||||
# data/storage.py
|
||||
SELECT {columns} FROM {table_name}
|
||||
WHERE trade_date BETWEEN '{start}' AND '{end}'
|
||||
ORDER BY ts_code, trade_date
|
||||
```
|
||||
|
||||
### 5.4 宽表组装
|
||||
|
||||
对于 `(close / ts_delay(close, 5)) - 1`,DataRouter 返回的宽表结构:
|
||||
|
||||
```
|
||||
┌──────────┬────────────┬───────┐
|
||||
│ ts_code │ trade_date │ close │
|
||||
├──────────┼────────────┼───────┤
|
||||
│ 000001.SZ│ 20240101 │ 10.5 │
|
||||
│ 000001.SZ│ 20240102 │ 10.6 │
|
||||
│ ... │ ... │ ... │
|
||||
│ 000002.SZ│ 20240101 │ 20.1 │
|
||||
└──────────┴────────────┴───────┘
|
||||
```
|
||||
|
||||
**注意**:数据获取使用用户传入的日期范围,不做自动扩展。对于时序因子(如 `ts_delay(close, 5)`),如果数据不足会返回 null,这是符合预期的行为。用户如需完整计算,应显式扩展日期范围。
|
||||
|
||||
---
|
||||
|
||||
## 6. 阶段五:Polars 翻译与计算
|
||||
|
||||
### 6.1 PolarsTranslator 的作用
|
||||
|
||||
将 DSL AST 翻译为 Polars 表达式(惰性计算图):
|
||||
|
||||
```python
|
||||
translator = PolarsTranslator()
|
||||
polars_expr = translator.translate(dsl_node)
|
||||
```
|
||||
|
||||
### 6.2 翻译规则
|
||||
|
||||
| DSL 节点类型 | Polars 表达式 | 说明 |
|
||||
|-------------|--------------|------|
|
||||
| `Symbol('close')` | `pl.col('close')` | 列引用 |
|
||||
| `Constant(5)` | `pl.lit(5)` | 字面量 |
|
||||
| `BinaryOpNode('/', a, b)` | `a / b` | 算术运算 |
|
||||
| `FunctionNode('ts_delay', x, n)` | `x.shift(n).over('ts_code')` | 时间序列滞后 |
|
||||
| `FunctionNode('ts_mean', x, n)` | `x.rolling_mean(n).over('ts_code')` | 时间序列均值 |
|
||||
| `FunctionNode('cs_rank', x)` | `x.rank().over('trade_date')` | 截面排名 |
|
||||
|
||||
### 6.3 时间序列函数翻译
|
||||
|
||||
`ts_delay(close, 5)` 翻译为:
|
||||
|
||||
```python
|
||||
pl.col('close').shift(5).over('ts_code')
|
||||
```
|
||||
|
||||
**关键点**:
|
||||
- `.shift(5)`:向后偏移 5 个位置
|
||||
- `.over('ts_code')`:按股票代码分组计算(每只股票独立计算)
|
||||
|
||||
### 6.4 计算执行
|
||||
|
||||
`ComputeEngine.execute()` 执行计算:
|
||||
|
||||
```python
|
||||
compute_engine = ComputeEngine()
|
||||
result = compute_engine.execute(plan, wide_data)
|
||||
```
|
||||
|
||||
**执行逻辑**:
|
||||
|
||||
```python
|
||||
def execute(self, plan, data):
|
||||
# 使用 Polars with_columns 添加因子列
|
||||
result = data.with_columns([
|
||||
plan.polars_expr.alias(plan.output_name)
|
||||
])
|
||||
return result
|
||||
```
|
||||
|
||||
### 6.5 计算结果
|
||||
|
||||
最终结果包含原始列和计算出的因子列:
|
||||
|
||||
```
|
||||
┌──────────┬────────────┬───────┬─────────────┐
|
||||
│ ts_code │ trade_date │ close │ returns_5d │
|
||||
├──────────┼────────────┼───────┼─────────────┤
|
||||
│ 000001.SZ│ 20240101 │ 10.5 │ null │ # 前5天无数据
|
||||
│ 000001.SZ│ 20240106 │ 10.8 │ 0.0286 │ # (10.8/10.5)-1
|
||||
│ ... │ ... │ ... │ ... │
|
||||
└──────────┴────────────┴───────┴─────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7. 完整调用链示例
|
||||
|
||||
### 7.1 用户代码
|
||||
|
||||
```python
|
||||
from src.factors import FactorEngine, FormulaParser, FunctionRegistry
|
||||
|
||||
# 1. 创建引擎
|
||||
engine = FactorEngine()
|
||||
|
||||
# 2. 解析字符串表达式
|
||||
parser = FormulaParser(FunctionRegistry())
|
||||
expr = parser.parse("(close / ts_delay(close, 5)) - 1")
|
||||
|
||||
# 3. 注册因子
|
||||
engine.register("returns_5d", expr)
|
||||
|
||||
# 4. 执行计算
|
||||
result = engine.compute(
|
||||
factor_names=["returns_5d"],
|
||||
start_date="20240101",
|
||||
end_date="20240131"
|
||||
)
|
||||
```
|
||||
|
||||
### 7.2 内部调用链
|
||||
|
||||
```
|
||||
FactorEngine.compute()
|
||||
│
|
||||
├── 1. 创建 ExecutionPlan
|
||||
│ └── ExecutionPlanner.create_plan()
|
||||
│ ├── DependencyExtractor.extract_dependencies() → {'close'}
|
||||
│ ├── _infer_data_specs() → [DataSpec('pro_bar', ['close'])]
|
||||
│ └── PolarsTranslator.translate() → pl.col('close').shift(5).over('ts_code')...
|
||||
│
|
||||
├── 2. 获取数据
|
||||
│ └── DataRouter.fetch_data(plan.data_specs)
|
||||
│ ├── _load_table('pro_bar', ['close'], start_date, end_date)
|
||||
│ │ └── Storage.load_polars() → 查询 DuckDB
|
||||
│ └── _assemble_wide_table() → Polars DataFrame
|
||||
│
|
||||
└── 3. 执行计算
|
||||
└── ComputeEngine.execute(plan, data)
|
||||
└── data.with_columns([polars_expr.alias('returns_5d')])
|
||||
└── Polars 执行表达式计算
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 8. 数据流时序图
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant User as 用户代码
|
||||
participant FE as FactorEngine
|
||||
participant PL as ExecutionPlanner
|
||||
participant DR as DataRouter
|
||||
participant ST as Storage
|
||||
participant CE as ComputeEngine
|
||||
participant DB as DuckDB
|
||||
|
||||
User->>FE: compute("(close/ts_delay(close,5))-1", start, end)
|
||||
|
||||
Note over FE: 阶段1:创建执行计划
|
||||
FE->>PL: create_plan(node, output_name)
|
||||
PL->>PL: extract_dependencies(node)
|
||||
PL->>PL: infer_data_specs(node)
|
||||
PL->>PL: translate_to_polars(node)
|
||||
PL-->>FE: ExecutionPlan
|
||||
|
||||
Note over FE: 阶段2:数据获取
|
||||
FE->>DR: fetch_data(data_specs)
|
||||
|
||||
loop 每个需要的表
|
||||
DR->>ST: load_polars(table, columns, start, end)
|
||||
ST->>DB: SELECT ... WHERE trade_date BETWEEN ...
|
||||
DB-->>ST: 原始数据
|
||||
ST-->>DR: Polars DataFrame
|
||||
end
|
||||
|
||||
DR->>DR: assemble_wide_table()
|
||||
DR->>DR: filter_by_date_range()
|
||||
DR-->>FE: 核心宽表
|
||||
|
||||
Note over FE: 阶段3:执行计算
|
||||
FE->>CE: execute(plan, data)
|
||||
CE->>CE: data.with_columns([polars_expr])
|
||||
CE-->>FE: 计算结果
|
||||
|
||||
FE-->>User: Polars DataFrame (含因子列)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 附录:关键代码片段
|
||||
|
||||
### A.1 FactorEngine.compute 方法
|
||||
|
||||
```python
|
||||
def compute(
|
||||
self,
|
||||
factor_names: list[str],
|
||||
start_date: str,
|
||||
end_date: str
|
||||
) -> pl.DataFrame:
|
||||
# 1. 收集所有执行计划
|
||||
all_plans = []
|
||||
for name in factor_names:
|
||||
plan = self.factors[name]
|
||||
all_plans.append(plan)
|
||||
|
||||
# 2. 合并数据规格
|
||||
merged_specs = self._merge_data_specs([p.data_specs for p in all_plans])
|
||||
|
||||
# 3. 获取数据
|
||||
data_router = DataRouter(self.storage, start_date, end_date)
|
||||
data = data_router.fetch_data(merged_specs)
|
||||
|
||||
# 4. 执行计算
|
||||
compute_engine = ComputeEngine()
|
||||
result = compute_engine.execute_plans(all_plans, data)
|
||||
|
||||
return result
|
||||
```
|
||||
|
||||
### A.2 PolarsTranslator 的函数处理器注册
|
||||
|
||||
```python
|
||||
class PolarsTranslator:
|
||||
def __init__(self):
|
||||
self.handlers: dict[str, Callable] = {}
|
||||
self._register_default_handlers()
|
||||
|
||||
def _register_default_handlers(self):
|
||||
# 时间序列函数
|
||||
self.handlers['ts_delay'] = lambda x, n: x.shift(n)
|
||||
self.handlers['ts_mean'] = lambda x, n: x.rolling_mean(n)
|
||||
self.handlers['ts_std'] = lambda x, n: x.rolling_std(n)
|
||||
|
||||
# 截面函数
|
||||
self.handlers['cs_rank'] = lambda x: x.rank()
|
||||
self.handlers['cs_zscore'] = lambda x: (x - x.mean()) / x.std()
|
||||
```
|
||||
|
||||
### A.3 时序因子数据获取说明
|
||||
|
||||
由于移除了自动日期扩展机制,用户需要显式管理时序因子的日期范围:
|
||||
|
||||
```python
|
||||
# 示例:计算 2024-01-15 到 2024-01-20 的 5 日收益率
|
||||
# 需要显式提供足够的历史数据
|
||||
result = engine.compute(
|
||||
factor_names=["returns_5d"], # (close / ts_delay(close, 5)) - 1
|
||||
start_date="20240108", # 向前扩展至少 5 个交易日
|
||||
end_date="20240120"
|
||||
)
|
||||
|
||||
# 在结果中,2024-01-15 之前的日期会因数据不足而返回 null
|
||||
# 用户可以自行过滤到目标日期范围
|
||||
result = result.filter(
|
||||
(pl.col("trade_date") >= "20240115") & (pl.col("trade_date") <= "20240120")
|
||||
)
|
||||
```
|
||||
|
||||
**设计原则**:
|
||||
- 显式优于隐式:数据来源透明,用户可以完全控制数据范围
|
||||
- 符合 Polars 行为:rolling/shift 操作在窗口不足时返回 null
|
||||
- 可验证性:用户可以明确知道用了哪些数据计算因子
|
||||
|
||||
---
|
||||
|
||||
## 总结
|
||||
|
||||
ProStock 的因子计算引擎采用 **DSL(领域特定语言)+ 延迟计算** 的架构设计,具有以下特点:
|
||||
|
||||
1. **声明式**:用户通过数学表达式描述因子逻辑,无需关心实现细节
|
||||
2. **惰性求值**:表达式构建时不立即计算,生成执行计划后统一执行
|
||||
3. **智能数据获取**:自动分析依赖、推导数据规格、按需取数
|
||||
4. **向量化计算**:基于 Polars 的高性能向量化运算,支持时序和截面计算
|
||||
5. **可扩展性**:通过 FunctionRegistry 可以轻松添加新的因子函数
|
||||
|
||||
整个流程从字符串到计算结果,经历了解析 → 编译 → 计划 → 取数 → 计算五个阶段,各组件职责清晰,便于维护和扩展。
|
||||
@@ -5,6 +5,7 @@ All wrapper files follow the naming convention: api_{data_type}.py
|
||||
|
||||
Available APIs:
|
||||
- api_daily: Daily market data (日线行情)
|
||||
- api_daily_basic: Daily basic indicators (每日指标,换手率、PE、PB、市值等)
|
||||
- api_pro_bar: Pro Bar universal market data (通用行情,后复权)
|
||||
- api_stock_basic: Stock basic information (股票基本信息)
|
||||
- api_trade_cal: Trading calendar (交易日历)
|
||||
@@ -13,9 +14,10 @@ Available APIs:
|
||||
|
||||
Example:
|
||||
>>> from src.data.api_wrappers import get_daily, get_stock_basic, get_trade_cal, get_bak_basic
|
||||
>>> from src.data.api_wrappers import get_pro_bar, sync_pro_bar
|
||||
>>> from src.data.api_wrappers import get_pro_bar, sync_pro_bar, get_daily_basic, sync_daily_basic
|
||||
>>> data = get_daily('000001.SZ', start_date='20240101', end_date='20240131')
|
||||
>>> pro_data = get_pro_bar('000001.SZ', start_date='20240101', end_date='20240131')
|
||||
>>> daily_basic = get_daily_basic(trade_date='20240101')
|
||||
>>> stocks = get_stock_basic()
|
||||
>>> calendar = get_trade_cal('20240101', '20240131')
|
||||
>>> bak_basic = get_bak_basic(trade_date='20240101')
|
||||
@@ -27,6 +29,12 @@ from src.data.api_wrappers.api_daily import (
|
||||
preview_daily_sync,
|
||||
DailySync,
|
||||
)
|
||||
from src.data.api_wrappers.api_daily_basic import (
|
||||
get_daily_basic,
|
||||
sync_daily_basic,
|
||||
preview_daily_basic_sync,
|
||||
DailyBasicSync,
|
||||
)
|
||||
from src.data.api_wrappers.api_pro_bar import (
|
||||
get_pro_bar,
|
||||
sync_pro_bar,
|
||||
@@ -55,6 +63,11 @@ __all__ = [
|
||||
"sync_daily",
|
||||
"preview_daily_sync",
|
||||
"DailySync",
|
||||
# Daily basic indicators
|
||||
"get_daily_basic",
|
||||
"sync_daily_basic",
|
||||
"preview_daily_basic_sync",
|
||||
"DailyBasicSync",
|
||||
# Pro Bar (universal market data)
|
||||
"get_pro_bar",
|
||||
"sync_pro_bar",
|
||||
|
||||
@@ -496,3 +496,73 @@ df = ts.pro_bar(ts_code='000001.SZ', start_date='20180101', end_date='20181011',
|
||||
|
||||
|
||||
df = ts.pro_bar(ts_code='000001.SH', asset='I', start_date='20180101', end_date='20181011')
|
||||
|
||||
每日指标
|
||||
接口:daily_basic,可以通过数据工具调试和查看数据。
|
||||
更新时间:交易日每日15点~17点之间
|
||||
描述:获取全部股票每日重要的基本面指标,可用于选股分析、报表展示等。单次请求最大返回6000条数据,可按日线循环提取全部历史。
|
||||
积分:至少2000积分才可以调取,5000积分无总量限制,具体请参阅积分获取办法
|
||||
|
||||
输入参数
|
||||
|
||||
名称 类型 必选 描述
|
||||
ts_code str Y 股票代码(二选一)
|
||||
trade_date str N 交易日期 (二选一)
|
||||
start_date str N 开始日期(YYYYMMDD)
|
||||
end_date str N 结束日期(YYYYMMDD)
|
||||
注:日期都填YYYYMMDD格式,比如20181010
|
||||
|
||||
输出参数
|
||||
|
||||
名称 类型 描述
|
||||
ts_code str TS股票代码
|
||||
trade_date str 交易日期
|
||||
close float 当日收盘价
|
||||
turnover_rate float 换手率(%)
|
||||
turnover_rate_f float 换手率(自由流通股)
|
||||
volume_ratio float 量比
|
||||
pe float 市盈率(总市值/净利润, 亏损的PE为空)
|
||||
pe_ttm float 市盈率(TTM,亏损的PE为空)
|
||||
pb float 市净率(总市值/净资产)
|
||||
ps float 市销率
|
||||
ps_ttm float 市销率(TTM)
|
||||
dv_ratio float 股息率 (%)
|
||||
dv_ttm float 股息率(TTM)(%)
|
||||
total_share float 总股本 (万股)
|
||||
float_share float 流通股本 (万股)
|
||||
free_share float 自由流通股本 (万)
|
||||
total_mv float 总市值 (万元)
|
||||
circ_mv float 流通市值(万元)
|
||||
接口用法
|
||||
|
||||
|
||||
pro = ts.pro_api()
|
||||
|
||||
df = pro.daily_basic(ts_code='', trade_date='20180726', fields='ts_code,trade_date,turnover_rate,volume_ratio,pe,pb')
|
||||
或者
|
||||
|
||||
|
||||
df = pro.query('daily_basic', ts_code='', trade_date='20180726',fields='ts_code,trade_date,turnover_rate,volume_ratio,pe,pb')
|
||||
数据样例
|
||||
|
||||
ts_code trade_date turnover_rate volume_ratio pe pb
|
||||
0 600230.SH 20180726 2.4584 0.72 8.6928 3.7203
|
||||
1 600237.SH 20180726 1.4737 0.88 166.4001 1.8868
|
||||
2 002465.SZ 20180726 0.7489 0.72 71.8943 2.6391
|
||||
3 300732.SZ 20180726 6.7083 0.77 21.8101 3.2513
|
||||
4 600007.SH 20180726 0.0381 0.61 23.7696 2.3774
|
||||
5 300068.SZ 20180726 1.4583 0.52 27.8166 1.7549
|
||||
6 300552.SZ 20180726 2.0728 0.95 56.8004 2.9279
|
||||
7 601369.SH 20180726 0.2088 0.95 44.1163 1.8001
|
||||
8 002518.SZ 20180726 0.5814 0.76 15.1004 2.5626
|
||||
9 002913.SZ 20180726 12.1096 1.03 33.1279 2.9217
|
||||
10 601818.SH 20180726 0.1893 0.86 6.3064 0.7209
|
||||
11 600926.SH 20180726 0.6065 0.46 9.1772 0.9808
|
||||
12 002166.SZ 20180726 0.7582 0.82 16.9868 3.3452
|
||||
13 600841.SH 20180726 0.3754 1.02 66.2647 2.2302
|
||||
14 300634.SZ 20180726 23.1127 1.26 120.3053 14.3168
|
||||
15 300126.SZ 20180726 1.2304 1.11 348.4306 1.5171
|
||||
16 300718.SZ 20180726 17.6612 0.92 32.0239 3.8661
|
||||
17 000708.SZ 20180726 0.5575 0.70 10.3674 1.0276
|
||||
18 002626.SZ 20180726 0.6187 0.83 22.7580 4.2446
|
||||
19 600816.SH 20180726 0.6745 0.65 11.0778 3.2214
|
||||
252
src/data/api_wrappers/api_daily_basic.py
Normal file
252
src/data/api_wrappers/api_daily_basic.py
Normal file
@@ -0,0 +1,252 @@
|
||||
"""每日指标数据接口。
|
||||
|
||||
获取全部股票每日重要的基本面指标,包括换手率、市盈率、市净率、
|
||||
总市值、流通市值等,可用于选股分析、报表展示等。
|
||||
"""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from src.data.client import TushareClient
|
||||
from src.data.api_wrappers.base_sync import DateBasedSync
|
||||
|
||||
|
||||
def get_daily_basic(
|
||||
trade_date: Optional[str] = None,
|
||||
ts_code: Optional[str] = None,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
client: Optional[TushareClient] = None,
|
||||
) -> pd.DataFrame:
|
||||
"""Fetch daily basic indicators from Tushare.
|
||||
|
||||
This interface retrieves important daily fundamental indicators for all stocks,
|
||||
including turnover rate, PE, PB, market value, etc. It can be used for stock
|
||||
selection analysis and report display.
|
||||
|
||||
Note: At least one of trade_date or ts_code must be provided. The recommended
|
||||
approach is to use trade_date to fetch data for all stocks on a specific date,
|
||||
which is more efficient than fetching by individual stock codes.
|
||||
|
||||
Args:
|
||||
trade_date: Specific trade date (YYYYMMDD format). Use this to get all
|
||||
stocks' data for a single date. More efficient than ts_code.
|
||||
ts_code: Stock code (e.g., '000001.SZ', '600000.SH'). Optional if
|
||||
trade_date is provided.
|
||||
start_date: Start date (YYYYMMDD format). Use with end_date for date range.
|
||||
end_date: End date (YYYYMMDD format). Use with start_date for date range.
|
||||
client: Optional TushareClient instance for shared rate limiting.
|
||||
If None, creates a new client. For concurrent sync operations,
|
||||
pass a shared client to ensure proper rate limiting.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame with columns:
|
||||
- ts_code: TS stock code
|
||||
- trade_date: Trade date (YYYYMMDD)
|
||||
- close: Closing price
|
||||
- turnover_rate: Turnover rate (%)
|
||||
- turnover_rate_f: Turnover rate (free float shares)
|
||||
- volume_ratio: Volume ratio
|
||||
- pe: Price-to-earnings ratio (total market cap / net profit)
|
||||
- pe_ttm: PE ratio (TTM)
|
||||
- pb: Price-to-book ratio (total market cap / net assets)
|
||||
- ps: Price-to-sales ratio
|
||||
- ps_ttm: PS ratio (TTM)
|
||||
- dv_ratio: Dividend yield (%)
|
||||
- dv_ttm: Dividend yield (TTM) (%)
|
||||
- total_share: Total shares (10k shares)
|
||||
- float_share: Float shares (10k shares)
|
||||
- free_share: Free float shares (10k shares)
|
||||
- total_mv: Total market value (10k CNY)
|
||||
- circ_mv: Circulating market value (10k CNY)
|
||||
|
||||
Example:
|
||||
>>> # Get all stocks for a single date (recommended)
|
||||
>>> data = get_daily_basic(trade_date='20240101')
|
||||
>>>
|
||||
>>> # Get specific stock data
|
||||
>>> data = get_daily_basic(ts_code='000001.SZ', trade_date='20240101')
|
||||
>>>
|
||||
>>> # Get date range data for a specific stock
|
||||
>>> data = get_daily_basic(
|
||||
... ts_code='000001.SZ',
|
||||
... start_date='20240101',
|
||||
... end_date='20240131'
|
||||
... )
|
||||
"""
|
||||
client = client or TushareClient()
|
||||
|
||||
# Build parameters
|
||||
params = {}
|
||||
if trade_date:
|
||||
params["trade_date"] = trade_date
|
||||
if ts_code:
|
||||
params["ts_code"] = ts_code
|
||||
if start_date:
|
||||
params["start_date"] = start_date
|
||||
if end_date:
|
||||
params["end_date"] = end_date
|
||||
|
||||
# Fetch data using daily_basic API
|
||||
data = client.query("daily_basic", **params)
|
||||
|
||||
# Rename date column if needed
|
||||
if "date" in data.columns:
|
||||
data = data.rename(columns={"date": "trade_date"})
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class DailyBasicSync(DateBasedSync):
|
||||
"""每日指标数据批量同步管理器,支持全量/增量同步。
|
||||
|
||||
继承自 DateBasedSync,按日期顺序获取数据。
|
||||
每日指标数据适合按日期获取,一次 API 调用即可获取全市场数据。
|
||||
|
||||
Example:
|
||||
>>> sync = DailyBasicSync()
|
||||
>>> results = sync.sync_all() # 增量同步
|
||||
>>> results = sync.sync_all(force_full=True) # 全量同步
|
||||
>>> preview = sync.preview_sync() # 预览
|
||||
"""
|
||||
|
||||
table_name = "daily_basic"
|
||||
default_start_date = "20180101"
|
||||
|
||||
# 表结构定义
|
||||
TABLE_SCHEMA = {
|
||||
"ts_code": "VARCHAR(16) NOT NULL",
|
||||
"trade_date": "DATE NOT NULL",
|
||||
"close": "DOUBLE",
|
||||
"turnover_rate": "DOUBLE",
|
||||
"turnover_rate_f": "DOUBLE",
|
||||
"volume_ratio": "DOUBLE",
|
||||
"pe": "DOUBLE",
|
||||
"pe_ttm": "DOUBLE",
|
||||
"pb": "DOUBLE",
|
||||
"ps": "DOUBLE",
|
||||
"ps_ttm": "DOUBLE",
|
||||
"dv_ratio": "DOUBLE",
|
||||
"dv_ttm": "DOUBLE",
|
||||
"total_share": "DOUBLE",
|
||||
"float_share": "DOUBLE",
|
||||
"free_share": "DOUBLE",
|
||||
"total_mv": "DOUBLE",
|
||||
"circ_mv": "DOUBLE",
|
||||
}
|
||||
|
||||
# 索引定义
|
||||
TABLE_INDEXES = [
|
||||
("idx_daily_basic_date_code", ["trade_date", "ts_code"]),
|
||||
]
|
||||
|
||||
# 主键定义
|
||||
PRIMARY_KEY = ("ts_code", "trade_date")
|
||||
|
||||
def fetch_single_date(self, trade_date: str) -> pd.DataFrame:
|
||||
"""获取单日的每日指标数据。
|
||||
|
||||
Args:
|
||||
trade_date: 交易日期(YYYYMMDD)
|
||||
|
||||
Returns:
|
||||
包含当日所有股票指标的 DataFrame
|
||||
"""
|
||||
# 使用 get_daily_basic 获取数据(传递共享 client)
|
||||
data = get_daily_basic(
|
||||
trade_date=trade_date,
|
||||
client=self.client, # 传递共享客户端以确保限流
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
def sync_daily_basic(
|
||||
force_full: bool = False,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
dry_run: bool = False,
|
||||
) -> pd.DataFrame:
|
||||
"""同步所有股票的每日指标数据。
|
||||
|
||||
这是每日指标数据同步的主要入口点。
|
||||
|
||||
Args:
|
||||
force_full: 若为 True,强制从 20180101 完整重载
|
||||
start_date: 手动指定起始日期(YYYYMMDD)
|
||||
end_date: 手动指定结束日期(默认为今天)
|
||||
dry_run: 若为 True,仅预览将要同步的内容,不写入数据
|
||||
|
||||
Returns:
|
||||
同步的数据 DataFrame
|
||||
|
||||
Example:
|
||||
>>> # 首次同步(从 20180101 全量加载)
|
||||
>>> result = sync_daily_basic()
|
||||
>>>
|
||||
>>> # 后续同步(增量 - 仅新数据)
|
||||
>>> result = sync_daily_basic()
|
||||
>>>
|
||||
>>> # 强制完整重载
|
||||
>>> result = sync_daily_basic(force_full=True)
|
||||
>>>
|
||||
>>> # 手动指定日期范围
|
||||
>>> result = sync_daily_basic(start_date='20240101', end_date='20240131')
|
||||
>>>
|
||||
>>> # Dry run(仅预览)
|
||||
>>> result = sync_daily_basic(dry_run=True)
|
||||
"""
|
||||
sync_manager = DailyBasicSync()
|
||||
return sync_manager.sync_all(
|
||||
force_full=force_full,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
|
||||
|
||||
def preview_daily_basic_sync(
|
||||
force_full: bool = False,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
sample_size: int = 3,
|
||||
) -> Dict[str, Any]:
|
||||
"""预览每日指标同步数据量和样本(不实际同步)。
|
||||
|
||||
这是推荐的方式,可在实际同步前检查将要同步的内容。
|
||||
|
||||
Args:
|
||||
force_full: 若为 True,预览全量同步(从 20180101)
|
||||
start_date: 手动指定起始日期(覆盖自动检测)
|
||||
end_date: 手动指定结束日期(默认为今天)
|
||||
sample_size: 预览用样本天数(默认: 3)
|
||||
|
||||
Returns:
|
||||
包含预览信息的字典:
|
||||
{
|
||||
'sync_needed': bool,
|
||||
'date_count': int,
|
||||
'start_date': str,
|
||||
'end_date': str,
|
||||
'estimated_records': int,
|
||||
'sample_data': pd.DataFrame,
|
||||
'mode': str, # 'full', 'incremental', 或 'none'
|
||||
}
|
||||
|
||||
Example:
|
||||
>>> # 预览将要同步的内容
|
||||
>>> preview = preview_daily_basic_sync()
|
||||
>>>
|
||||
>>> # 预览全量同步
|
||||
>>> preview = preview_daily_basic_sync(force_full=True)
|
||||
>>>
|
||||
>>> # 预览更多样本
|
||||
>>> preview = preview_daily_basic_sync(sample_size=5)
|
||||
"""
|
||||
sync_manager = DailyBasicSync()
|
||||
return sync_manager.preview_sync(
|
||||
force_full=force_full,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
sample_size=sample_size,
|
||||
)
|
||||
@@ -7,6 +7,7 @@
|
||||
|
||||
✅ 本模块包含的同步逻辑(每日更新):
|
||||
- api_daily.py: 日线数据同步 (DailySync 类)
|
||||
- api_daily_basic.py: 每日指标数据同步 (DailyBasicSync 类)
|
||||
- api_bak_basic.py: 历史股票列表同步 (BakBasicSync 类)
|
||||
- api_pro_bar.py: Pro Bar 数据同步 (ProBarSync 类)
|
||||
- api_stock_basic.py: 股票基本信息同步
|
||||
@@ -44,6 +45,7 @@ from src.data.api_wrappers import sync_all_stocks
|
||||
from src.data.api_wrappers.api_daily import sync_daily, preview_daily_sync
|
||||
from src.data.api_wrappers.api_pro_bar import sync_pro_bar
|
||||
from src.data.api_wrappers.api_bak_basic import sync_bak_basic
|
||||
from src.data.api_wrappers.api_daily_basic import sync_daily_basic
|
||||
|
||||
|
||||
def preview_sync(
|
||||
@@ -157,7 +159,8 @@ def sync_all_data(
|
||||
2. 股票基本信息 (sync_all_stocks)
|
||||
3. 日线数据 (sync_daily)
|
||||
4. Pro Bar 数据 (sync_pro_bar)
|
||||
5. 历史股票列表 (sync_bak_basic)
|
||||
5. 每日指标数据 (sync_daily_basic)
|
||||
6. 历史股票列表 (sync_bak_basic)
|
||||
|
||||
【不包含的同步(需单独调用)】
|
||||
- 财务数据: 利润表、资产负债表、现金流量表(季度更新)
|
||||
@@ -238,7 +241,7 @@ def sync_all_data(
|
||||
results["daily"] = pd.DataFrame()
|
||||
|
||||
# 4. Sync Pro Bar data
|
||||
print("\n[4/5] Syncing Pro Bar data (with adj, tor, vr)...")
|
||||
print("\n[4/6] Syncing Pro Bar data (with adj, tor, vr)...")
|
||||
try:
|
||||
# 确保表存在
|
||||
from src.data.api_wrappers.api_pro_bar import ProBarSync
|
||||
@@ -255,14 +258,31 @@ def sync_all_data(
|
||||
sum(len(df) for df in pro_bar_result.values()) if pro_bar_result else 0
|
||||
)
|
||||
print(
|
||||
f"[4/5] Pro Bar data: OK ({total_pro_bar_records} records from {len(pro_bar_result)} stocks)"
|
||||
f"[4/6] Pro Bar data: OK ({total_pro_bar_records} records from {len(pro_bar_result)} stocks)"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"[4/5] Pro Bar data: FAILED - {e}")
|
||||
print(f"[4/6] Pro Bar data: FAILED - {e}")
|
||||
results["pro_bar"] = pd.DataFrame()
|
||||
|
||||
# 5. Sync stock historical list (bak_basic)
|
||||
print("\n[5/5] Syncing stock historical list (bak_basic)...")
|
||||
# 5. Sync daily basic indicators
|
||||
print(
|
||||
"\n[5/6] Syncing daily basic indicators (PE, PB, turnover rate, market value)..."
|
||||
)
|
||||
try:
|
||||
# 确保表存在
|
||||
from src.data.api_wrappers.api_daily_basic import DailyBasicSync
|
||||
|
||||
DailyBasicSync().ensure_table_exists()
|
||||
|
||||
daily_basic_result = sync_daily_basic(force_full=force_full, dry_run=dry_run)
|
||||
results["daily_basic"] = daily_basic_result
|
||||
print(f"[5/6] Daily basic: OK ({len(daily_basic_result)} records)")
|
||||
except Exception as e:
|
||||
print(f"[5/6] Daily basic: FAILED - {e}")
|
||||
results["daily_basic"] = pd.DataFrame()
|
||||
|
||||
# 6. Sync stock historical list (bak_basic)
|
||||
print("\n[6/6] Syncing stock historical list (bak_basic)...")
|
||||
try:
|
||||
# 确保表存在
|
||||
from src.data.api_wrappers.api_bak_basic import BakBasicSync
|
||||
@@ -271,9 +291,9 @@ def sync_all_data(
|
||||
|
||||
bak_basic_result = sync_bak_basic(force_full=force_full)
|
||||
results["bak_basic"] = bak_basic_result
|
||||
print(f"[5/5] Bak basic: OK ({len(bak_basic_result)} records)")
|
||||
print(f"[6/6] Bak basic: OK ({len(bak_basic_result)} records)")
|
||||
except Exception as e:
|
||||
print(f"[5/5] Bak basic: FAILED - {e}")
|
||||
print(f"[6/6] Bak basic: FAILED - {e}")
|
||||
results["bak_basic"] = pd.DataFrame()
|
||||
|
||||
# Summary
|
||||
@@ -286,7 +306,7 @@ def sync_all_data(
|
||||
total_records = sum(len(df) for df in data.values())
|
||||
print(f" {data_type}: {len(data)} stocks, {total_records} total records")
|
||||
else:
|
||||
# bak_basic 返回的是 DataFrame
|
||||
# daily_basic 和 bak_basic 返回的是 DataFrame
|
||||
print(f" {data_type}: {len(data)} records")
|
||||
print("=" * 60)
|
||||
print("\nNote: namechange is NOT in auto-sync. To sync manually:")
|
||||
@@ -308,7 +328,7 @@ if __name__ == "__main__":
|
||||
print("")
|
||||
print(" # Or sync individual data types:")
|
||||
print(" from src.data.sync import sync_all, preview_sync")
|
||||
print(" from src.data.sync import sync_bak_basic")
|
||||
print(" from src.data.api_wrappers import sync_daily_basic, sync_bak_basic")
|
||||
print("")
|
||||
print(" # Preview before sync (recommended)")
|
||||
print(" preview = preview_sync()")
|
||||
|
||||
@@ -69,16 +69,11 @@ class DataRouter:
|
||||
|
||||
# 收集所有需要的表和字段
|
||||
required_tables: Dict[str, Set[str]] = {}
|
||||
max_lookback = 0
|
||||
|
||||
for spec in data_specs:
|
||||
if spec.table not in required_tables:
|
||||
required_tables[spec.table] = set()
|
||||
required_tables[spec.table].update(spec.columns)
|
||||
max_lookback = max(max_lookback, spec.lookback_days)
|
||||
|
||||
# 调整日期范围以包含回看期
|
||||
adjusted_start = self._adjust_start_date(start_date, max_lookback)
|
||||
|
||||
# 从数据源获取各表数据
|
||||
table_data = {}
|
||||
@@ -86,7 +81,7 @@ class DataRouter:
|
||||
df = self._load_table(
|
||||
table_name=table_name,
|
||||
columns=list(columns),
|
||||
start_date=adjusted_start,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
stock_codes=stock_codes,
|
||||
)
|
||||
@@ -95,11 +90,6 @@ class DataRouter:
|
||||
# 组装核心宽表
|
||||
core_table = self._assemble_wide_table(table_data, required_tables)
|
||||
|
||||
# 过滤到实际请求日期范围
|
||||
core_table = core_table.filter(
|
||||
(pl.col("trade_date") >= start_date) & (pl.col("trade_date") <= end_date)
|
||||
)
|
||||
|
||||
return core_table
|
||||
|
||||
def _load_table(
|
||||
@@ -265,34 +255,6 @@ class DataRouter:
|
||||
|
||||
return result
|
||||
|
||||
def _adjust_start_date(self, start_date: str, lookback_days: int) -> str:
|
||||
"""根据回看天数调整开始日期。
|
||||
|
||||
Args:
|
||||
start_date: 原始开始日期 (YYYYMMDD)
|
||||
lookback_days: 需要回看的交易日数
|
||||
|
||||
Returns:
|
||||
调整后的开始日期
|
||||
"""
|
||||
# 简化的日期调整:假设每月30天,向前推移
|
||||
# 实际应用中应该使用交易日历
|
||||
year = int(start_date[:4])
|
||||
month = int(start_date[4:6])
|
||||
day = int(start_date[6:8])
|
||||
|
||||
total_days = lookback_days + 30 # 额外缓冲
|
||||
|
||||
day -= total_days
|
||||
while day <= 0:
|
||||
month -= 1
|
||||
if month <= 0:
|
||||
month = 12
|
||||
year -= 1
|
||||
day += 30
|
||||
|
||||
return f"{year:04d}{month:02d}{day:02d}"
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""清除数据缓存。"""
|
||||
with self._lock:
|
||||
|
||||
@@ -18,12 +18,10 @@ class DataSpec:
|
||||
Attributes:
|
||||
table: 数据表名称
|
||||
columns: 需要的字段列表
|
||||
lookback_days: 回看天数(用于时序计算)
|
||||
"""
|
||||
|
||||
table: str
|
||||
columns: List[str]
|
||||
lookback_days: int = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -73,9 +73,9 @@ class ExecutionPlanner:
|
||||
) -> List[DataSpec]:
|
||||
"""从依赖推导数据规格。
|
||||
|
||||
根据表达式中的函数类型推断回看天数需求。
|
||||
基础行情字段(open, high, low, close, vol, amount, pre_close, change, pct_chg)
|
||||
默认从 pro_bar 表获取。
|
||||
每日指标字段(total_mv, circ_mv, pe, pb 等)从 daily_basic 表获取。
|
||||
|
||||
Args:
|
||||
dependencies: 依赖的字段集合
|
||||
@@ -84,10 +84,6 @@ class ExecutionPlanner:
|
||||
Returns:
|
||||
数据规格列表
|
||||
"""
|
||||
# 计算最大回看窗口
|
||||
max_window = self._extract_max_window(expression)
|
||||
lookback_days = max(1, max_window)
|
||||
|
||||
# 基础行情字段集合(这些字段从 pro_bar 表获取)
|
||||
pro_bar_fields = {
|
||||
"open",
|
||||
@@ -103,9 +99,27 @@ class ExecutionPlanner:
|
||||
"volume_ratio",
|
||||
}
|
||||
|
||||
# 将依赖分为 pro_bar 字段和其他字段
|
||||
# 每日指标字段集合(这些字段从 daily_basic 表获取)
|
||||
daily_basic_fields = {
|
||||
"turnover_rate_f",
|
||||
"pe",
|
||||
"pe_ttm",
|
||||
"pb",
|
||||
"ps",
|
||||
"ps_ttm",
|
||||
"dv_ratio",
|
||||
"dv_ttm",
|
||||
"total_share",
|
||||
"float_share",
|
||||
"free_share",
|
||||
"total_mv",
|
||||
"circ_mv",
|
||||
}
|
||||
|
||||
# 将依赖分为不同表的字段
|
||||
pro_bar_deps = dependencies & pro_bar_fields
|
||||
other_deps = dependencies - pro_bar_fields
|
||||
daily_basic_deps = dependencies & daily_basic_fields
|
||||
other_deps = dependencies - pro_bar_fields - daily_basic_fields
|
||||
|
||||
data_specs = []
|
||||
|
||||
@@ -115,7 +129,15 @@ class ExecutionPlanner:
|
||||
DataSpec(
|
||||
table="pro_bar",
|
||||
columns=sorted(pro_bar_deps),
|
||||
lookback_days=lookback_days,
|
||||
)
|
||||
)
|
||||
|
||||
# daily_basic 表的数据规格
|
||||
if daily_basic_deps:
|
||||
data_specs.append(
|
||||
DataSpec(
|
||||
table="daily_basic",
|
||||
columns=sorted(daily_basic_deps),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -125,46 +147,7 @@ class ExecutionPlanner:
|
||||
DataSpec(
|
||||
table="daily",
|
||||
columns=sorted(other_deps),
|
||||
lookback_days=lookback_days,
|
||||
)
|
||||
)
|
||||
|
||||
return data_specs
|
||||
|
||||
def _extract_max_window(self, node: Node) -> int:
|
||||
"""从表达式中提取最大窗口大小。
|
||||
|
||||
Args:
|
||||
node: AST 节点
|
||||
|
||||
Returns:
|
||||
最大窗口大小,无时序函数返回 1
|
||||
"""
|
||||
if isinstance(node, FunctionNode):
|
||||
window = 1
|
||||
# 检查函数参数中的窗口大小
|
||||
for arg in node.args:
|
||||
if (
|
||||
isinstance(arg, Constant)
|
||||
and isinstance(arg.value, int)
|
||||
and arg.value > window
|
||||
):
|
||||
window = arg.value
|
||||
|
||||
# 递归检查子表达式
|
||||
for arg in node.args:
|
||||
if isinstance(arg, Node) and not isinstance(arg, Constant):
|
||||
window = max(window, self._extract_max_window(arg))
|
||||
|
||||
return window
|
||||
|
||||
elif isinstance(node, BinaryOpNode):
|
||||
return max(
|
||||
self._extract_max_window(node.left),
|
||||
self._extract_max_window(node.right),
|
||||
)
|
||||
|
||||
elif isinstance(node, UnaryOpNode):
|
||||
return self._extract_max_window(node.operand)
|
||||
|
||||
return 1
|
||||
|
||||
@@ -1,302 +0,0 @@
|
||||
"""训练流程入口脚本
|
||||
|
||||
运行方式:
|
||||
uv run python -m src.training.main
|
||||
|
||||
或:
|
||||
uv run python src/training/main.py
|
||||
|
||||
本脚本提供两种运行方式:
|
||||
1. run_full_pipeline(): 完整训练流程(数据准备 -> 训练 -> 预测)
|
||||
2. prepare_data_and_train() + train_and_predict(): 分步执行,便于调试和调整
|
||||
|
||||
因子配置示例:
|
||||
from src.factors import MovingAverageFactor, ReturnRankFactor
|
||||
|
||||
# 直接传入因子实例列表 - 最简单的方式
|
||||
factors = [
|
||||
MovingAverageFactor(period=5),
|
||||
MovingAverageFactor(period=10),
|
||||
MovingAverageFactor(period=20),
|
||||
ReturnRankFactor(period=5),
|
||||
ReturnRankFactor(period=10),
|
||||
]
|
||||
|
||||
# 运行完整流程
|
||||
result = run_full_pipeline(factors=factors)
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
|
||||
import polars as pl
|
||||
|
||||
from src.factors import BaseFactor
|
||||
from src.training.pipeline import (
|
||||
FactorConfig,
|
||||
predict_top_stocks,
|
||||
prepare_data,
|
||||
save_top_stocks,
|
||||
train_model,
|
||||
)
|
||||
|
||||
|
||||
def prepare_data_and_train(
|
||||
factors: Optional[List[BaseFactor]] = None,
|
||||
data_dir: str = "data",
|
||||
train_start: str = "20190101",
|
||||
train_end: str = "20231231",
|
||||
val_start: str = "20240102",
|
||||
val_end: str = "20240531",
|
||||
test_start: str = "20240602",
|
||||
test_end: str = "20241231",
|
||||
) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame, FactorConfig, str]:
|
||||
"""第一步:数据处理
|
||||
|
||||
加载原始数据,计算因子和标签,拆分训练/验证/测试集。
|
||||
|
||||
Args:
|
||||
factors: 因子实例列表,默认为 None(使用 MA5, MA10, ReturnRank5)
|
||||
data_dir: 数据目录
|
||||
train_start: 训练集开始日期
|
||||
train_end: 训练集结束日期
|
||||
val_start: 验证集开始日期
|
||||
val_end: 验证集结束日期
|
||||
test_start: 测试集开始日期
|
||||
test_end: 测试集结束日期
|
||||
|
||||
Returns:
|
||||
tuple: (train_data, val_data, test_data, factor_config, label_col)
|
||||
"""
|
||||
print("=" * 50)
|
||||
print("[Step 1] 数据处理")
|
||||
print("=" * 50)
|
||||
print(f"训练集: {train_start} -> {train_end}")
|
||||
print(f"验证集: {val_start} -> {val_end}")
|
||||
print(f"测试集: {test_start} -> {test_end}")
|
||||
print()
|
||||
|
||||
# 1. 准备数据
|
||||
train_data, val_data, test_data, factor_config = prepare_data(
|
||||
factors=factors,
|
||||
data_dir=data_dir,
|
||||
train_start=train_start,
|
||||
train_end=train_end,
|
||||
val_start=val_start,
|
||||
val_end=val_end,
|
||||
test_start=test_start,
|
||||
test_end=test_end,
|
||||
)
|
||||
|
||||
print(f"训练集样本数: {len(train_data)}")
|
||||
print(f"验证集样本数: {len(val_data)}")
|
||||
print(f"测试集样本数: {len(test_data)}")
|
||||
print()
|
||||
|
||||
# 打印少量数据样本展示
|
||||
print("=" * 50)
|
||||
print("[数据预览] 训练集前3行:")
|
||||
print(train_data.head(3))
|
||||
print()
|
||||
print("[数据预览] 验证集前3行:")
|
||||
print(val_data.head(3))
|
||||
print()
|
||||
print("[数据预览] 测试集前3行:")
|
||||
print(test_data.head(3))
|
||||
print()
|
||||
|
||||
# 2. 获取特征列名
|
||||
feature_cols = factor_config.get_feature_names()
|
||||
label_col = "label"
|
||||
|
||||
print(f"特征列: {feature_cols}")
|
||||
print(f"标签列: {label_col}")
|
||||
print()
|
||||
|
||||
return train_data, val_data, test_data, factor_config, label_col
|
||||
|
||||
|
||||
def train_and_predict(
|
||||
train_data: pl.DataFrame,
|
||||
val_data: pl.DataFrame,
|
||||
test_data: pl.DataFrame,
|
||||
factor_config: FactorConfig,
|
||||
label_col: str = "label",
|
||||
top_n: int = 5,
|
||||
output_path: str = "output/top_stocks.tsv",
|
||||
) -> pl.DataFrame:
|
||||
"""第二步:训练和预测
|
||||
|
||||
使用处理好的数据训练模型,进行测试集预测并保存结果。
|
||||
|
||||
Args:
|
||||
train_data: 训练数据
|
||||
val_data: 验证数据
|
||||
test_data: 测试数据
|
||||
factor_config: 因子配置对象
|
||||
label_col: 标签列名
|
||||
top_n: 每日选股数量
|
||||
output_path: 输出文件路径
|
||||
|
||||
Returns:
|
||||
选股结果DataFrame
|
||||
"""
|
||||
print("=" * 50)
|
||||
print("[Step 2] 模型训练与预测")
|
||||
print("=" * 50)
|
||||
print()
|
||||
|
||||
# 获取特征列名
|
||||
feature_cols = factor_config.get_feature_names()
|
||||
print(f"使用特征: {feature_cols}")
|
||||
print()
|
||||
|
||||
# 3. 训练模型
|
||||
print("[Training] Training model...")
|
||||
model, pipeline = train_model(
|
||||
train_data=train_data,
|
||||
val_data=val_data,
|
||||
feature_cols=feature_cols,
|
||||
label_col=label_col,
|
||||
)
|
||||
print()
|
||||
|
||||
# 4. 测试集预测
|
||||
print("[Predict] Predicting on test set...")
|
||||
top_stocks = predict_top_stocks(
|
||||
model=model,
|
||||
pipeline=pipeline,
|
||||
test_data=test_data,
|
||||
feature_cols=feature_cols,
|
||||
top_n=top_n,
|
||||
)
|
||||
print()
|
||||
|
||||
# 5. 保存结果
|
||||
print(f"[Saving] Saving results to {output_path}...")
|
||||
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
save_top_stocks(top_stocks, output_path)
|
||||
print()
|
||||
|
||||
return top_stocks
|
||||
|
||||
|
||||
def run_full_pipeline(
|
||||
factors: Optional[List[BaseFactor]] = None,
|
||||
train_start: str = "20190101",
|
||||
train_end: str = "20231231",
|
||||
val_start: str = "20240102",
|
||||
val_end: str = "20240531",
|
||||
test_start: str = "20240602",
|
||||
test_end: str = "20241231",
|
||||
top_n: int = 5,
|
||||
output_path: str = "output/top_stocks.tsv",
|
||||
) -> pl.DataFrame:
|
||||
"""运行完整训练流程
|
||||
|
||||
相当于依次调用 prepare_data_and_train 和 train_and_predict。
|
||||
|
||||
Args:
|
||||
factors: 因子实例列表,默认为 None(使用 MA5, MA10, ReturnRank5)
|
||||
train_start: 训练集开始日期
|
||||
train_end: 训练集结束日期
|
||||
val_start: 验证集开始日期
|
||||
val_end: 验证集结束日期
|
||||
test_start: 测试集开始日期
|
||||
test_end: 测试集结束日期
|
||||
top_n: 每日选股数量
|
||||
output_path: 输出文件路径
|
||||
|
||||
Returns:
|
||||
选股结果DataFrame
|
||||
"""
|
||||
# 第一步:数据处理
|
||||
train_data, val_data, test_data, factor_config, label_col = prepare_data_and_train(
|
||||
factors=factors,
|
||||
train_start=train_start,
|
||||
train_end=train_end,
|
||||
val_start=val_start,
|
||||
val_end=val_end,
|
||||
test_start=test_start,
|
||||
test_end=test_end,
|
||||
)
|
||||
|
||||
# 第二步:训练和预测
|
||||
result = train_and_predict(
|
||||
train_data=train_data,
|
||||
val_data=val_data,
|
||||
test_data=test_data,
|
||||
factor_config=factor_config,
|
||||
label_col=label_col,
|
||||
top_n=top_n,
|
||||
output_path=output_path,
|
||||
)
|
||||
|
||||
print("=" * 50)
|
||||
print("[Done] 训练流程完成!")
|
||||
print("=" * 50)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from src.factors import MovingAverageFactor, ReturnRankFactor
|
||||
|
||||
# ========== 因子配置 ==========
|
||||
# 直接传入因子实例列表 - 简单直观
|
||||
factors = [
|
||||
MovingAverageFactor(period=5), # 5日移动平均线
|
||||
MovingAverageFactor(period=10), # 10日移动平均线
|
||||
MovingAverageFactor(period=20), # 20日移动平均线
|
||||
ReturnRankFactor(period=5), # 5日收益率排名
|
||||
ReturnRankFactor(period=10), # 10日收益率排名
|
||||
]
|
||||
|
||||
# ========== 运行方式 ==========
|
||||
|
||||
# 方式一:完整流程(一次性执行)
|
||||
# result = run_full_pipeline(
|
||||
# factors=factors,
|
||||
# train_start="20190101",
|
||||
# train_end="20231231",
|
||||
# val_start="20240102",
|
||||
# val_end="20240531",
|
||||
# test_start="20240602",
|
||||
# test_end="20241231",
|
||||
# top_n=5,
|
||||
# output_path="output/top_stocks.tsv",
|
||||
# )
|
||||
|
||||
# 方式二:分步执行(便于调试)
|
||||
# 第一步:数据处理
|
||||
train_data, val_data, test_data, factor_config, label_col = prepare_data_and_train(
|
||||
factors=factors,
|
||||
train_start="20190101",
|
||||
train_end="20231231",
|
||||
val_start="20240102",
|
||||
val_end="20240531",
|
||||
test_start="20240602",
|
||||
test_end="20241231",
|
||||
)
|
||||
|
||||
# 可在此处添加自定义逻辑,例如:
|
||||
# - 查看数据分布
|
||||
# - 调整特征
|
||||
# - 保存中间结果
|
||||
print("\n[Info] 因子配置详情:")
|
||||
print(f" 因子列表: {factor_config.get_feature_names()}")
|
||||
print(f" 最大回溯天数: {factor_config.get_max_lookback()}")
|
||||
|
||||
# 第二步:训练和预测
|
||||
# result = train_and_predict(
|
||||
# train_data=train_data,
|
||||
# val_data=val_data,
|
||||
# test_data=test_data,
|
||||
# factor_config=factor_config,
|
||||
# label_col=label_col,
|
||||
# top_n=5,
|
||||
# output_path="output/top_stocks.tsv",
|
||||
# )
|
||||
#
|
||||
# print("\n[Result] Top stocks selection:")
|
||||
# print(result)
|
||||
@@ -71,7 +71,7 @@ class TestFactorEngineEndToEnd:
|
||||
@pytest.fixture
|
||||
def engine(self, mock_data):
|
||||
"""提供配置好的 FactorEngine fixture。"""
|
||||
data_source = {"daily": mock_data}
|
||||
data_source = {"pro_bar": mock_data}
|
||||
return FactorEngine(data_source=data_source, max_workers=2)
|
||||
|
||||
def test_simple_symbol_expression(self, engine):
|
||||
@@ -116,7 +116,7 @@ class TestFullWorkflow:
|
||||
|
||||
# 2. 初始化引擎
|
||||
print("\nStep 2: Initialize FactorEngine...")
|
||||
engine = FactorEngine(data_source={"daily": mock_data})
|
||||
engine = FactorEngine(data_source={"pro_bar": mock_data})
|
||||
print(" Engine initialized")
|
||||
|
||||
# 3. 注册因子 - 使用简单因子避免回看窗口问题
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
2. return_5_rank: 5日收益率在截面上的排名
|
||||
3. ma5: 5日均线 (ts_mean(close, 5))
|
||||
4. ma10: 10日均线 (ts_mean(close, 10))
|
||||
5. market_cap_rank: 市值百分比排名 (cs_rank(total_mv))
|
||||
|
||||
特点:使用因子字符串架构(add_factor + 字符串表达式)
|
||||
|
||||
@@ -48,6 +49,11 @@ def test_two_stocks_string_factors():
|
||||
print("\n[1.4] ma10 = ts_mean(close, 10)")
|
||||
print(f" 字符串表达式: {ma10_str}")
|
||||
|
||||
# market_cap_rank: 市值百分比排名 (截面排名)
|
||||
market_cap_rank_str = "cs_rank(total_mv)"
|
||||
print("\n[1.5] market_cap_rank = cs_rank(total_mv)")
|
||||
print(f" 字符串表达式: {market_cap_rank_str}")
|
||||
|
||||
# ========================================================================
|
||||
# 1.5 打印数据来源信息
|
||||
# ========================================================================
|
||||
@@ -66,6 +72,7 @@ def test_two_stocks_string_factors():
|
||||
"return_5_rank": return_5_rank_str,
|
||||
"ma5": ma5_str,
|
||||
"ma10": ma10_str,
|
||||
"market_cap_rank": market_cap_rank_str,
|
||||
}
|
||||
|
||||
for name, expr_str in expressions_str.items():
|
||||
@@ -102,9 +109,12 @@ def test_two_stocks_string_factors():
|
||||
engine.add_factor("ma10", ma10_str)
|
||||
print("[2.4] 注册 ma10 (字符串方式)")
|
||||
|
||||
engine.add_factor("market_cap_rank", market_cap_rank_str)
|
||||
print("[2.5] 注册 market_cap_rank (市值百分比排名,字符串方式)")
|
||||
|
||||
# 也注册原始 close 价格用于验证
|
||||
engine.add_factor("close_price", "close")
|
||||
print("[2.5] 注册 close_price (原始收盘价,字符串方式)")
|
||||
print("[2.6] 注册 close_price (原始收盘价,字符串方式)")
|
||||
|
||||
print(f"\n已注册因子列表: {engine.list_registered()}")
|
||||
|
||||
@@ -125,7 +135,6 @@ def test_two_stocks_string_factors():
|
||||
for i, spec in enumerate(plan.data_specs, 1):
|
||||
print(f" [{i}] 表名: {spec.table}")
|
||||
print(f" 字段: {spec.columns}")
|
||||
print(f" 回看天数: {spec.lookback_days}")
|
||||
|
||||
# ========================================================================
|
||||
# 3. 执行计算(两支股票)
|
||||
@@ -143,7 +152,14 @@ def test_two_stocks_string_factors():
|
||||
|
||||
try:
|
||||
result = engine.compute(
|
||||
factor_names=["return_5", "return_5_rank", "ma5", "ma10", "close_price"],
|
||||
factor_names=[
|
||||
"return_5",
|
||||
"return_5_rank",
|
||||
"ma5",
|
||||
"ma10",
|
||||
"close_price",
|
||||
"market_cap_rank",
|
||||
],
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
stock_codes=stock_codes,
|
||||
@@ -345,6 +361,10 @@ def test_two_stocks_string_factors():
|
||||
print(" - 每天两支股票的排名之和应接近 1")
|
||||
print("-" * 60)
|
||||
|
||||
# 5.5.1 return_5_rank 截面排名验证
|
||||
print("\n[5.5.1] return_5_rank 截面排名验证:")
|
||||
print("-" * 60)
|
||||
|
||||
# 获取有效数据
|
||||
result_valid = result.drop_nulls(subset=["return_5_rank"])
|
||||
|
||||
@@ -373,6 +393,39 @@ def test_two_stocks_string_factors():
|
||||
else:
|
||||
print(" [警告] 截面排名之和不接近 1")
|
||||
|
||||
# 5.5.2 market_cap_rank 市值百分比排名验证
|
||||
print("\n[5.5.2] market_cap_rank 市值百分比排名验证:")
|
||||
print("-" * 60)
|
||||
|
||||
result_valid_mv = result.drop_nulls(subset=["market_cap_rank"])
|
||||
|
||||
if len(result_valid_mv) > 0:
|
||||
min_rank_mv = result_valid_mv["market_cap_rank"].min()
|
||||
max_rank_mv = result_valid_mv["market_cap_rank"].max()
|
||||
print(f"\n市值排名范围: [{min_rank_mv:.4f}, {max_rank_mv:.4f}]")
|
||||
|
||||
if 0 <= min_rank_mv <= 1 and 0 <= max_rank_mv <= 1:
|
||||
print(" [成功] 市值排名值在 [0, 1] 区间内!")
|
||||
else:
|
||||
print(" [警告] 市值排名值超出 [0, 1] 区间")
|
||||
|
||||
# 检查某天两支股票的市值排名之和
|
||||
sample_date_mv = result_valid_mv["trade_date"][0]
|
||||
day_data_mv = result_valid_mv.filter(
|
||||
result_valid_mv["trade_date"] == sample_date_mv
|
||||
)
|
||||
if len(day_data_mv) == 2:
|
||||
rank_sum_mv = day_data_mv["market_cap_rank"].sum()
|
||||
print(f"\n示例日期 {sample_date_mv} 的市值排名验证:")
|
||||
for row in day_data_mv.iter_rows(named=True):
|
||||
print(f" {row['ts_code']}: {row['market_cap_rank']:.4f}")
|
||||
print(f" 排名之和: {rank_sum_mv:.4f} (两支股票应接近 1)")
|
||||
|
||||
if abs(rank_sum_mv - 1.0) < 0.01:
|
||||
print(" [成功] 市值排名之和验证通过!")
|
||||
else:
|
||||
print(" [警告] 市值排名之和不接近 1")
|
||||
|
||||
# ========================================================================
|
||||
# 6. 统计摘要
|
||||
# ========================================================================
|
||||
@@ -394,7 +447,7 @@ def test_two_stocks_string_factors():
|
||||
print(f"总记录数: {len(stock_data)}")
|
||||
print(f"有效记录数 (去空值后): {len(stock_valid)}")
|
||||
|
||||
factor_cols = ["return_5", "return_5_rank", "ma5", "ma10"]
|
||||
factor_cols = ["return_5", "return_5_rank", "ma5", "ma10", "market_cap_rank"]
|
||||
|
||||
for col in factor_cols:
|
||||
if col in stock_data.columns:
|
||||
@@ -418,7 +471,6 @@ def test_two_stocks_string_factors():
|
||||
# ========================================================================
|
||||
print("\n" + "=" * 80)
|
||||
|
||||
|
||||
# ========================================================================
|
||||
# 8. 测试总结
|
||||
# ========================================================================
|
||||
@@ -434,10 +486,13 @@ def test_two_stocks_string_factors():
|
||||
print()
|
||||
print("因子定义方式: 字符串表达式 (add_factor 方法)")
|
||||
print("计算因子:")
|
||||
print(" 1. return_5 - 5日收益率 (字符串: '(close / ts_delay(close, 5)) - 1')")
|
||||
print(" 2. return_5_rank - 5日收益率截面排名 (字符串: 'cs_rank(...)')")
|
||||
print(" 3. ma5 - 5日均线 (字符串: 'ts_mean(close, 5)')")
|
||||
print(" 4. ma10 - 10日均线 (字符串: 'ts_mean(close, 10)')")
|
||||
print(
|
||||
" 1. return_5 - 5日收益率 (字符串: '(close / ts_delay(close, 5)) - 1')"
|
||||
)
|
||||
print(" 2. return_5_rank - 5日收益率截面排名 (字符串: 'cs_rank(...)')")
|
||||
print(" 3. ma5 - 5日均线 (字符串: 'ts_mean(close, 5)')")
|
||||
print(" 4. ma10 - 10日均线 (字符串: 'ts_mean(close, 10)')")
|
||||
print(" 5. market_cap_rank - 市值百分比排名 (字符串: 'cs_rank(total_mv)')")
|
||||
print()
|
||||
print("验证结果:")
|
||||
print(" - 字符串表达式解析: 正常")
|
||||
|
||||
Reference in New Issue
Block a user