From 53225b94437883e8e16f989bd5a199401f41880e Mon Sep 17 00:00:00 2001 From: liaozhaorun <1300336796@qq.com> Date: Tue, 3 Mar 2026 17:09:39 +0800 Subject: [PATCH] =?UTF-8?q?feat(data):=20=E6=B7=BB=E5=8A=A0=E6=AF=8F?= =?UTF-8?q?=E6=97=A5=E6=8C=87=E6=A0=87=E6=8E=A5=E5=8F=A3=E5=B9=B6=E4=BC=98?= =?UTF-8?q?=E5=8C=96=E5=9B=A0=E5=AD=90=E5=BC=95=E6=93=8E=20-=20=E6=96=B0?= =?UTF-8?q?=E5=A2=9E=20api=5Fdaily=5Fbasic.py=20=E5=B0=81=E8=A3=85=20Tusha?= =?UTF-8?q?re=20=E6=AF=8F=E6=97=A5=E6=8C=87=E6=A0=87=E6=8E=A5=E5=8F=A3=20-?= =?UTF-8?q?=20=E5=9B=A0=E5=AD=90=E5=BC=95=E6=93=8E=E7=A7=BB=E9=99=A4=20loo?= =?UTF-8?q?kback=5Fdays=EF=BC=8C=E6=94=AF=E6=8C=81=20daily=5Fbasic=20?= =?UTF-8?q?=E8=A1=A8=E5=AD=97=E6=AE=B5=E8=B7=AF=E7=94=B1=20-=20=E5=B0=86?= =?UTF-8?q?=E6=AF=8F=E6=97=A5=E6=8C=87=E6=A0=87=E7=BA=B3=E5=85=A5=E8=87=AA?= =?UTF-8?q?=E5=8A=A8=E5=90=8C=E6=AD=A5=E6=B5=81=E7=A8=8B=20-=20=E5=88=A0?= =?UTF-8?q?=E9=99=A4=E5=BA=9F=E5=BC=83=E7=9A=84=20training/main.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- AGENTS.md | 131 +++++- docs/factor_calculation_flow.md | 557 +++++++++++++++++++++++ src/data/api_wrappers/__init__.py | 15 +- src/data/api_wrappers/api.md | 72 ++- src/data/api_wrappers/api_daily_basic.py | 252 ++++++++++ src/data/sync.py | 40 +- src/factors/engine/data_router.py | 40 +- src/factors/engine/data_spec.py | 2 - src/factors/engine/planner.py | 77 ++-- src/training/main.py | 302 ------------ tests/test_factor_engine.py | 4 +- tests/test_two_stocks_string_factors.py | 73 ++- 12 files changed, 1132 insertions(+), 433 deletions(-) create mode 100644 docs/factor_calculation_flow.md create mode 100644 src/data/api_wrappers/api_daily_basic.py delete mode 100644 src/training/main.py diff --git a/AGENTS.md b/AGENTS.md index 092e255..025e54f 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -82,20 +82,21 @@ ProStock/ │ │ │ ├── data/ # 数据获取与存储 │ │ ├── api_wrappers/ # Tushare API 封装 -│ │ │ ├── base_sync.py # 
同步基础抽象类(BaseDataSync/StockBasedSync/DateBasedSync) -│ │ │ ├── api_daily.py # 日线数据接口(DailySync) -│ │ │ ├── api_pro_bar.py # Pro Bar 数据接口(ProBarSync) -│ │ │ ├── api_stock_basic.py # 股票基础信息接口 -│ │ │ ├── api_trade_cal.py # 交易日历接口 -│ │ │ ├── api_bak_basic.py # 历史股票列表接口(BakBasicSync) -│ │ │ ├── api_namechange.py # 股票名称变更接口 -│ │ │ ├── financial_data/ # 财务数据接口 -│ │ │ │ ├── api_income.py # 利润表接口 +│ │ │ ├── base_sync.py # 同步基础抽象类(BaseDataSync/StockBasedSync/DateBasedSync) +│ │ │ ├── api_daily.py # 日线数据接口(DailySync) +│ │ │ ├── api_pro_bar.py # Pro Bar 数据接口(ProBarSync) +│ │ │ ├── api_stock_basic.py # 股票基础信息接口 +│ │ │ ├── api_trade_cal.py # 交易日历接口 +│ │ │ ├── api_bak_basic.py # 历史股票列表接口(BakBasicSync) +│ │ │ ├── api_namechange.py # 股票名称变更接口 +│ │ │ ├── financial_data/ # 财务数据接口 +│ │ │ │ ├── api_income.py # 利润表接口 │ │ │ │ └── api_financial_sync.py # 财务数据同步 │ │ │ └── __init__.py │ │ ├── __init__.py │ │ ├── client.py # Tushare API 客户端(带速率限制) │ │ ├── config.py # 数据模块配置 +│ │ ├── data_router.py # 数据路由器(factors/engine 专用) │ │ ├── db_inspector.py # 数据库信息查看工具 │ │ ├── db_manager.py # DuckDB 表管理和同步 │ │ ├── rate_limiter.py # 令牌桶速率限制器 @@ -104,20 +105,29 @@ ProStock/ │ │ └── utils.py # 数据模块工具函数 │ │ │ ├── factors/ # 因子计算框架(DSL 表达式驱动) +│ │ ├── engine/ # 执行引擎子模块 +│ │ │ ├── __init__.py # 导出引擎组件 +│ │ │ ├── data_spec.py # 数据规格定义(DataSpec, ExecutionPlan) +│ │ │ ├── data_router.py # 数据路由器 +│ │ │ ├── planner.py # 执行计划生成器(ExecutionPlanner) +│ │ │ ├── compute_engine.py # 计算引擎(ComputeEngine) +│ │ │ └── factor_engine.py # 因子引擎统一入口(FactorEngine) │ │ ├── __init__.py # 导出所有公开 API │ │ ├── dsl.py # DSL 表达式层 - 节点定义和运算符重载 │ │ ├── api.py # API 层 - 常用符号(close/open等)和函数(ts_mean/cs_rank等) │ │ ├── compiler.py # AST 编译器 - 依赖提取 │ │ ├── translator.py # Polars 表达式翻译器 -│ │ └── engine.py # 因子执行引擎 - 统一入口 +│ │ ├── parser.py # 字符串公式解析器(FormulaParser) +│ │ ├── registry.py # 函数注册表(FunctionRegistry) +│ │ └── exceptions.py # 异常定义(FormulaParseError等) │ │ │ ├── pipeline/ # 模型训练管道 │ │ ├── __init__.py -│ │ ├── pipeline.py # 处理流水线 -│ │ ├── 
registry.py # 插件注册中心 +│ │ ├── pipeline.py # 处理流水线(ProcessingPipeline) +│ │ ├── registry.py # 插件注册中心(PluginRegistry) │ │ ├── core/ # 核心抽象 │ │ │ ├── __init__.py -│ │ │ ├── base.py # 基类定义 +│ │ │ ├── base.py # 基类定义(BaseProcessor/BaseModel/BaseSplitter等) │ │ │ └── splitter.py # 时间序列划分策略 │ │ ├── models/ # 模型实现 │ │ │ ├── __init__.py @@ -128,14 +138,23 @@ ProStock/ │ │ │ └── training/ # 训练入口 │ ├── __init__.py -│ ├── main.py # 训练主程序 │ ├── pipeline.py # 训练流程配置 │ └── output/ # 训练输出 │ └── top_stocks.tsv # 推荐股票结果 │ ├── tests/ # 测试文件 │ ├── test_sync.py -│ └── test_daily.py +│ ├── test_daily.py +│ ├── test_factor_engine.py +│ ├── test_factor_integration.py +│ ├── test_pro_bar.py +│ ├── test_601117_factors.py +│ ├── test_two_stocks_string_factors.py +│ ├── test_db_manager.py +│ ├── test_daily_storage.py +│ ├── test_tushare_api.py +│ └── pipeline/ +│ └── test_core.py ├── config/ # 配置文件 │ └── .env.local # 环境变量(不在 git 中) ├── data/ # 数据存储(DuckDB) @@ -266,10 +285,13 @@ except Exception as e: ### 依赖项 关键包: - `pandas>=2.0.0` - 数据处理 +- `polars>=0.20.0` - 高性能数据处理(因子计算) - `numpy>=1.24.0` - 数值计算 - `tushare>=2.0.0` - A股数据 API - `pydantic>=2.0.0`、`pydantic-settings>=2.0.0` - 配置 - `tqdm>=4.65.0` - 进度条 +- `lightgbm>=4.0.0` - 机器学习模型 +- `catboost>=1.2.0` - 机器学习模型 - `pytest` - 测试(开发) ### 环境变量 @@ -292,6 +314,9 @@ uv run python -c "from src.data.sync import sync_all; sync_all(force_full=True)" # 自定义线程数 uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)" + +# 运行因子计算测试 +uv run pytest tests/test_factor_engine.py -v ``` ## 架构变更历史 @@ -309,8 +334,16 @@ uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)" - 新增 `api.py`: 常用符号(close/open/volume等)和函数(ts_mean/cs_rank等) - 新增 `compiler.py`: AST 编译器,提取表达式依赖 - 新增 `translator.py`: 将 DSL 表达式翻译为 Polars 表达式 - - 重构 `engine.py`: 统一执行引擎入口,整合 DataRouter、ExecutionPlanner、ComputeEngine - - 移除: `base.py`、`composite.py`、`data_loader.py`、`data_spec.py` + - 新增 `parser.py`: 字符串公式解析器(FormulaParser),支持从字符串解析 DSL 表达式 + - 新增 
`registry.py`: 函数注册表(FunctionRegistry),管理字符串函数名到 Python 函数的映射 + - 新增 `exceptions.py`: 公式解析异常定义(FormulaParseError、UnknownFunctionError等) + - 重构 `engine/` 子模块: + - `factor_engine.py`: 因子引擎统一入口(FactorEngine) + - `data_spec.py`: 数据规格定义(DataSpec, ExecutionPlan) + - `data_router.py`: 数据路由器 + - `planner.py`: 执行计划生成器(ExecutionPlanner) + - `compute_engine.py`: 计算引擎(ComputeEngine) + - 移除: `base.py`、`composite.py`、`data_loader.py`、根目录的 `data_spec.py` - 移除: `factors/momentum/` 和 `factors/financial/` 子目录 **使用方式对比**: ```python @@ -328,6 +361,11 @@ uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)" engine = FactorEngine() engine.register("ma20", ma20) result = engine.compute(["ma20"], "20240101", "20240131") + + # 字符串公式解析(Phase 1 新增) + from src.factors import FormulaParser, FunctionRegistry + parser = FormulaParser(FunctionRegistry()) + expr = parser.parse("ts_mean(close, 20) / close") # 从字符串解析 ``` #### data 模块补充完善 @@ -411,10 +449,20 @@ DSL 层 (dsl.py) <- 因子表达式 (Node) Compiler (compiler.py) <- AST 依赖提取 | v +Parser (parser.py) <- 字符串公式解析器 + | + v +Registry (registry.py) <- 函数注册表 + | + v Translator (translator.py) <- 翻译为 Polars 表达式 | v -Engine (engine.py) <- 执行引擎 (DataRouter/ExecutionPlanner/ComputeEngine) +Engine (engine/) <- 执行引擎 + | - FactorEngine: 统一入口 + | - DataRouter: 数据路由 + | - ExecutionPlanner: 执行计划 + | - ComputeEngine: 计算引擎 | v 数据层 (data_router.py + DuckDB) <- 数据获取和存储 @@ -422,7 +470,7 @@ Engine (engine.py) <- 执行引擎 (DataRouter/ExecutionPlanner/ComputeEngi ### 使用方式 -#### 1. 基础表达式 +#### 1. 基础表达式(DSL 方式) ```python from src.factors.api import close, open, ts_mean, cs_rank @@ -435,15 +483,37 @@ price_rank = cs_rank(close) # 收盘价截面排名 alpha = ma20 * 0.6 + price_rank * 0.4 ``` -#### 2. 注册和执行 +#### 2. 
字符串公式解析(Phase 1 新增) + +```python +from src.factors import FormulaParser, FunctionRegistry + +# 创建解析器(自动加载 api.py 中的所有函数) +parser = FormulaParser(FunctionRegistry()) + +# 从字符串解析公式 +expr = parser.parse("ts_mean(close, 20) / close") +complex_expr = parser.parse("cs_rank(ts_mean(close, 5) - ts_mean(close, 20))") + +# 支持完整运算符和函数调用 +expr2 = parser.parse("(close - open) / open * 100") # 涨跌幅 +expr3 = parser.parse("ts_corr(close, volume, 20)") # 量价相关性 +``` + +#### 3. 注册和执行 ```python from src.factors import FactorEngine engine = FactorEngine() + +# 注册 DSL 表达式 engine.register("ma20", ma20) engine.register("price_rank", price_rank) +# 或注册字符串解析的表达式 +engine.register("alpha", parser.parse("ma20 * 0.6 + price_rank * 0.4")) + # 执行计算 result = engine.compute( factor_names=["ma20", "price_rank"], @@ -504,6 +574,27 @@ expr3 = -change # 涨跌额取反 expr4 = ts_mean(cs_rank(close), 20) # 排名后的20日平滑 ``` +### 异常处理 + +框架提供清晰的异常类型帮助定位问题: + +- `FormulaParseError` - 公式解析错误基类 +- `UnknownFunctionError` - 未知函数错误(提供模糊匹配建议) +- `InvalidSyntaxError` - 语法错误 +- `EmptyExpressionError` - 空表达式错误 +- `DuplicateFunctionError` - 函数重复注册错误 + +示例: +```python +from src.factors import FormulaParser, FunctionRegistry, UnknownFunctionError + +parser = FormulaParser(FunctionRegistry()) +try: + expr = parser.parse("unknown_func(close)") +except UnknownFunctionError as e: + print(e) # 显示错误位置和可用函数建议 +``` + ## AI 行为准则 ## AI 行为准则 diff --git a/docs/factor_calculation_flow.md b/docs/factor_calculation_flow.md new file mode 100644 index 0000000..8b80b94 --- /dev/null +++ b/docs/factor_calculation_flow.md @@ -0,0 +1,557 @@ +# 因子计算流程详解 + +本文档详细描述 ProStock 项目中因子计算引擎的完整数据流,以因子表达式 `(close / ts_delay(close, 5)) - 1` 为例,说明从字符串解析到最终计算结果的完整流程。 + +## 目录 + +1. [整体架构概览](#1-整体架构概览) +2. [阶段一:字符串解析](#2-阶段一字符串解析) +3. [阶段二:AST 编译与依赖提取](#3-阶段二ast-编译与依赖提取) +4. [阶段三:执行计划生成](#4-阶段三执行计划生成) +5. [阶段四:数据获取](#5-阶段四数据获取) +6. [阶段五:Polars 翻译与计算](#6-阶段五polars-翻译与计算) +7. [完整调用链示例](#7-完整调用链示例) +8. [数据流时序图](#8-数据流时序图) + +--- + +## 1. 
整体架构概览 + +### 1.1 架构层次 + +因子框架采用分层设计,从上到下依次为: + +``` +API 层 (api.py) + | + v +DSL 层 (dsl.py) ← 因子表达式 (Node) + | + v +Parser (parser.py) ← 字符串公式解析 + | + v +Registry (registry.py) ← 函数注册表 + | + v +Compiler (compiler.py) ← AST 依赖提取 + | + v +Planner (planner.py) ← 执行计划生成 + | + v +Translator (translator.py) ← 翻译为 Polars 表达式 + | + v +Engine (engine/) ← 执行引擎 + | - FactorEngine: 统一入口 + | - DataRouter: 数据路由 + | - ExecutionPlanner: 执行计划 + | - ComputeEngine: 计算引擎 + | + v +数据层 (data/storage.py) ← DuckDB 数据获取和存储 +``` + +### 1.2 核心组件职责 + +| 组件 | 文件路径 | 主要职责 | +|------|----------|----------| +| **FormulaParser** | `factors/parser.py` | 将字符串表达式解析为 DSL AST 节点树 | +| **FunctionRegistry** | `factors/registry.py` | 管理函数名到 Python 实现的映射 | +| **DSL Nodes** | `factors/dsl.py` | 定义表达式节点(Symbol, FunctionNode, BinaryOpNode 等)| +| **DependencyExtractor** | `factors/compiler.py` | 从 AST 提取依赖的数据字段 | +| **ExecutionPlanner** | `factors/engine/planner.py` | 整合编译器和翻译器生成执行计划 | +| **DataRouter** | `factors/engine/data_router.py` | 按需取数、组装核心宽表 | +| **PolarsTranslator** | `factors/translator.py` | 将 DSL AST 翻译为 Polars 表达式 | +| **ComputeEngine** | `factors/engine/compute_engine.py` | 执行 Polars 表达式计算 | +| **FactorEngine** | `factors/engine/factor_engine.py` | 系统统一入口,协调各组件 | +| **Storage** | `data/storage.py` | DuckDB 数据存储和查询接口 | + +--- + +## 2. 阶段一:字符串解析 + +### 2.1 解析流程 + +当用户调用 `parser.parse("(close / ts_delay(close, 5)) - 1")` 时,解析流程如下: + +```python +# 1. FormulaParser.parse() 方法 +formula = "(close / ts_delay(close, 5)) - 1" +ast_tree = ast.parse(formula, mode='eval') # Python AST +dsl_node = self._visit(ast_tree.body) # 递归转换为 DSL Node +``` + +**解析步骤:** + +1. **Python AST 解析**:使用 Python 标准库 `ast.parse()` 将字符串解析为 Python AST +2. **AST 遍历转换**:通过 `_visit()` 方法递归遍历 Python AST,映射为 DSL 节点 +3. 
**节点类型映射**: + - `ast.Name` → `Symbol` + - `ast.Constant` → `Constant` + - `ast.BinOp` → `BinaryOpNode` + - `ast.UnaryOp` → `UnaryOpNode` + - `ast.Call` → `FunctionNode` + +### 2.2 解析结果 AST 结构 + +``` +BinaryOpNode(op='-', left, right=Constant(1)) + ├── left: BinaryOpNode(op='/', left, right) + │ ├── left: Symbol('close') + │ └── right: FunctionNode('ts_delay', [Symbol('close'), Constant(5)]) + └── right: Constant(1) +``` + +### 2.3 函数解析机制 + +对于函数调用 `ts_delay(close, 5)`: + +```python +# _visit_Call 方法处理逻辑 +def _visit_Call(self, node: ast.Call) -> FunctionNode: + func_name = node.func.id # "ts_delay" + + # 从注册表获取函数实现 + if self.registry.has(func_name): + func_impl = self.registry.get(func_name) + + # 递归解析参数 + args = [self._visit(arg) for arg in node.args] + + # 调用函数实现,返回 FunctionNode + return func_impl(*args) +``` + +**关键点**:`ts_delay` 函数在 `api.py` 中定义: + +```python +def ts_delay(x: NodeOrStr, periods: int) -> FunctionNode: + return FunctionNode('ts_delay', to_node(x), Constant(periods)) +``` + +--- + +## 3. 阶段二:AST 编译与依赖提取 + +### 3.1 依赖提取流程 + +解析后的 AST 需要提取依赖的原始数据字段: + +```python +# DependencyExtractor.extract_dependencies(node) +extractor = DependencyExtractor() +dependencies = extractor.extract_dependencies(dsl_node) +# 结果: {'close'} +``` + +**提取逻辑**: + +```python +def _visit(self, node): + if isinstance(node, Symbol): + return {node.name} # 收集字段名 + elif isinstance(node, (BinaryOpNode, FunctionNode)): + # 递归收集子节点依赖 + deps = set() + for child in node.args: + deps.update(self._visit(child)) + return deps + # ... 其他节点类型 +``` + +### 3.2 依赖的作用 + +提取的依赖 `{close}` 用于: +1. **数据规格推导**:确定需要从数据库读取哪些字段 +2. **执行计划生成**:明确数据需求,避免读取不必要的字段 + +--- + +## 4. 
阶段三:执行计划生成 + +### 4.1 ExecutionPlanner 的作用 + +`ExecutionPlanner.create_plan()` 将 AST 转换为可执行的 `ExecutionPlan`: + +```python +planner = ExecutionPlanner() +plan = planner.create_plan( + node=dsl_node, # 解析后的 DSL 节点 + output_name="returns_5d" # 输出列名 +) +``` + +### 4.2 计划生成流程 + +```python +def create_plan(self, node, output_name): + # 1. 提取依赖 + dependencies = self.dependency_extractor.extract_dependencies(node) + + # 2. 推导数据规格 + data_specs = self._infer_data_specs(node, dependencies) + # 结果: [DataSpec(table='pro_bar', columns=['close'])] + + # 3. 翻译为 Polars 表达式 + polars_expr = self.polars_translator.translate(node) + # 结果: (pl.col('close') / pl.col('close').shift(5)) - 1 + + # 4. 构建执行计划 + return ExecutionPlan( + data_specs=data_specs, + polars_expr=polars_expr, + dependencies=dependencies, + output_name=output_name + ) +``` + +### 4.3 数据规格推导 + +根据依赖字段,`_infer_data_specs` 推导出需要的数据规格: + +```python +def _infer_data_specs(self, node, dependencies): + return [ + DataSpec( + table='pro_bar', # 默认使用 pro_bar 表 + columns=list(dependencies), # ['close'] + ) + ] +``` + +**DataSpec 说明**: +- `table`: 数据表名(pro_bar 或 daily) +- `columns`: 需要的字段列表 + +**注意**:数据获取使用用户传入的日期范围,不做自动扩展。时序因子(如 `ts_delay`、`ts_mean`)在数据不足时会返回 null,这是符合预期的行为。 + +--- + +## 5. 阶段四:数据获取 + +### 5.1 DataRouter 的核心职责 + +`DataRouter.fetch_data()` 按需取数、组装核心宽表: + +```python +data_router = DataRouter(storage, start_date, end_date) +wide_data = data_router.fetch_data([plan.data_specs]) +``` + +### 5.2 数据获取流程 + +```python +def fetch_data(self, data_specs, start_date, end_date): + # 1. 合并数据规格,收集所需表和字段 + required_tables = self._collect_required_tables(data_specs) + + # 2. 加载各表数据 + table_data = {} + for table, columns in required_tables.items(): + table_data[table] = self._load_table(table, columns, start_date, end_date) + + # 3. 
组装宽表(left join 合并) + wide_table = self._assemble_wide_table(table_data, data_specs) + + return wide_table +``` + +### 5.3 数据库查询 + +`_load_table` 方法从 DuckDB 读取数据: + +```python +def _load_table(self, table, columns, start_date, end_date): + # 通过 Storage 查询数据库 + df = self.storage.load_polars( + table_name=table, + columns=columns + ['ts_code', 'trade_date'], # 必须包含主键 + start_date=start_date, + end_date=end_date + ) + return df +``` + +**Storage.load_polars 内部实现**: + +```python +# data/storage.py +SELECT {columns} FROM {table_name} +WHERE trade_date BETWEEN '{start}' AND '{end}' +ORDER BY ts_code, trade_date +``` + +### 5.4 宽表组装 + +对于 `(close / ts_delay(close, 5)) - 1`,DataRouter 返回的宽表结构: + +``` +┌──────────┬────────────┬───────┐ +│ ts_code │ trade_date │ close │ +├──────────┼────────────┼───────┤ +│ 000001.SZ│ 20240101 │ 10.5 │ +│ 000001.SZ│ 20240102 │ 10.6 │ +│ ... │ ... │ ... │ +│ 000002.SZ│ 20240101 │ 20.1 │ +└──────────┴────────────┴───────┘ +``` + +**注意**:数据获取使用用户传入的日期范围,不做自动扩展。对于时序因子(如 `ts_delay(close, 5)`),如果数据不足会返回 null,这是符合预期的行为。用户如需完整计算,应显式扩展日期范围。 + +--- + +## 6. 
阶段五:Polars 翻译与计算 + +### 6.1 PolarsTranslator 的作用 + +将 DSL AST 翻译为 Polars 表达式(惰性计算图): + +```python +translator = PolarsTranslator() +polars_expr = translator.translate(dsl_node) +``` + +### 6.2 翻译规则 + +| DSL 节点类型 | Polars 表达式 | 说明 | +|-------------|--------------|------| +| `Symbol('close')` | `pl.col('close')` | 列引用 | +| `Constant(5)` | `pl.lit(5)` | 字面量 | +| `BinaryOpNode('/', a, b)` | `a / b` | 算术运算 | +| `FunctionNode('ts_delay', x, n)` | `x.shift(n).over('ts_code')` | 时间序列滞后 | +| `FunctionNode('ts_mean', x, n)` | `x.rolling_mean(n).over('ts_code')` | 时间序列均值 | +| `FunctionNode('cs_rank', x)` | `x.rank().over('trade_date')` | 截面排名 | + +### 6.3 时间序列函数翻译 + +`ts_delay(close, 5)` 翻译为: + +```python +pl.col('close').shift(5).over('ts_code') +``` + +**关键点**: +- `.shift(5)`:向后偏移 5 个位置 +- `.over('ts_code')`:按股票代码分组计算(每只股票独立计算) + +### 6.4 计算执行 + +`ComputeEngine.execute()` 执行计算: + +```python +compute_engine = ComputeEngine() +result = compute_engine.execute(plan, wide_data) +``` + +**执行逻辑**: + +```python +def execute(self, plan, data): + # 使用 Polars with_columns 添加因子列 + result = data.with_columns([ + plan.polars_expr.alias(plan.output_name) + ]) + return result +``` + +### 6.5 计算结果 + +最终结果包含原始列和计算出的因子列: + +``` +┌──────────┬────────────┬───────┬─────────────┐ +│ ts_code │ trade_date │ close │ returns_5d │ +├──────────┼────────────┼───────┼─────────────┤ +│ 000001.SZ│ 20240101 │ 10.5 │ null │ # 前5天无数据 +│ 000001.SZ│ 20240106 │ 10.8 │ 0.0286 │ # (10.8/10.5)-1 +│ ... │ ... │ ... │ ... │ +└──────────┴────────────┴───────┴─────────────┘ +``` + +--- + +## 7. 完整调用链示例 + +### 7.1 用户代码 + +```python +from src.factors import FactorEngine, FormulaParser, FunctionRegistry + +# 1. 创建引擎 +engine = FactorEngine() + +# 2. 解析字符串表达式 +parser = FormulaParser(FunctionRegistry()) +expr = parser.parse("(close / ts_delay(close, 5)) - 1") + +# 3. 注册因子 +engine.register("returns_5d", expr) + +# 4. 
执行计算 +result = engine.compute( + factor_names=["returns_5d"], + start_date="20240101", + end_date="20240131" +) +``` + +### 7.2 内部调用链 + +``` +FactorEngine.compute() + │ + ├── 1. 创建 ExecutionPlan + │ └── ExecutionPlanner.create_plan() + │ ├── DependencyExtractor.extract_dependencies() → {'close'} + │ ├── _infer_data_specs() → [DataSpec('pro_bar', ['close'])] + │ └── PolarsTranslator.translate() → pl.col('close').shift(5).over('ts_code')... + │ + ├── 2. 获取数据 + │ └── DataRouter.fetch_data([plan.data_specs]) + │ ├── _load_table('pro_bar', ['close'], start_date, end_date) + │ │ └── Storage.load_polars() → 查询 DuckDB + │ └── _assemble_wide_table() → Polars DataFrame + │ + └── 3. 执行计算 + └── ComputeEngine.execute(plan, data) + └── data.with_columns([polars_expr.alias('returns_5d')]) + └── Polars 执行表达式计算 +``` + +--- + +## 8. 数据流时序图 + +```mermaid +sequenceDiagram + participant User as 用户代码 + participant FE as FactorEngine + participant PL as ExecutionPlanner + participant DR as DataRouter + participant ST as Storage + participant CE as ComputeEngine + participant DB as DuckDB + + User->>FE: compute("(close/ts_delay(close,5))-1", start, end) + + Note over FE: 阶段1:创建执行计划 + FE->>PL: create_plan(node, output_name) + PL->>PL: extract_dependencies(node) + PL->>PL: infer_data_specs(node) + PL->>PL: translate_to_polars(node) + PL-->>FE: ExecutionPlan + + Note over FE: 阶段2:数据获取 + FE->>DR: fetch_data(data_specs) + + loop 每个需要的表 + DR->>ST: load_polars(table, columns, start, end) + ST->>DB: SELECT ... WHERE trade_date BETWEEN ... 
+ DB-->>ST: 原始数据 + ST-->>DR: Polars DataFrame + end + + DR->>DR: assemble_wide_table() + DR->>DR: filter_by_date_range() + DR-->>FE: 核心宽表 + + Note over FE: 阶段3:执行计算 + FE->>CE: execute(plan, data) + CE->>CE: data.with_columns([polars_expr]) + CE-->>FE: 计算结果 + + FE-->>User: Polars DataFrame (含因子列) +``` + +--- + +## 附录:关键代码片段 + +### A.1 FactorEngine.compute 方法 + +```python +def compute( + self, + factor_names: list[str], + start_date: str, + end_date: str +) -> pl.DataFrame: + # 1. 收集所有执行计划 + all_plans = [] + for name in factor_names: + plan = self.factors[name] + all_plans.append(plan) + + # 2. 合并数据规格 + merged_specs = self._merge_data_specs([p.data_specs for p in all_plans]) + + # 3. 获取数据 + data_router = DataRouter(self.storage, start_date, end_date) + data = data_router.fetch_data(merged_specs) + + # 4. 执行计算 + compute_engine = ComputeEngine() + result = compute_engine.execute_plans(all_plans, data) + + return result +``` + +### A.2 PolarsTranslator 的函数处理器注册 + +```python +class PolarsTranslator: + def __init__(self): + self.handlers: dict[str, Callable] = {} + self._register_default_handlers() + + def _register_default_handlers(self): + # 时间序列函数 + self.handlers['ts_delay'] = lambda x, n: x.shift(n) + self.handlers['ts_mean'] = lambda x, n: x.rolling_mean(n) + self.handlers['ts_std'] = lambda x, n: x.rolling_std(n) + + # 截面函数 + self.handlers['cs_rank'] = lambda x: x.rank() + self.handlers['cs_zscore'] = lambda x: (x - x.mean()) / x.std() +``` + +### A.3 时序因子数据获取说明 + +由于移除了自动日期扩展机制,用户需要显式管理时序因子的日期范围: + +```python +# 示例:计算 2024-01-15 到 2024-01-20 的 5 日收益率 +# 需要显式提供足够的历史数据 +result = engine.compute( + factor_names=["returns_5d"], # (close / ts_delay(close, 5)) - 1 + start_date="20240108", # 向前扩展至少 5 个交易日 + end_date="20240120" +) + +# 在结果中,2024-01-15 之前的日期会因数据不足而返回 null +# 用户可以自行过滤到目标日期范围 +result = result.filter( + (pl.col("trade_date") >= "20240115") & (pl.col("trade_date") <= "20240120") +) +``` + +**设计原则**: +- 显式优于隐式:数据来源透明,用户可以完全控制数据范围 +- 符合 Polars 行为:rolling/shift 
操作在窗口不足时返回 null +- 可验证性:用户可以明确知道用了哪些数据计算因子 + +--- + +## 总结 + +ProStock 的因子计算引擎采用 **DSL(领域特定语言)+ 延迟计算** 的架构设计,具有以下特点: + +1. **声明式**:用户通过数学表达式描述因子逻辑,无需关心实现细节 +2. **惰性求值**:表达式构建时不立即计算,生成执行计划后统一执行 +3. **智能数据获取**:自动分析依赖、推导数据规格、按需取数 +4. **向量化计算**:基于 Polars 的高性能向量化运算,支持时序和截面计算 +5. **可扩展性**:通过 FunctionRegistry 可以轻松添加新的因子函数 + +整个流程从字符串到计算结果,经历了解析 → 编译 → 计划 → 取数 → 计算五个阶段,各组件职责清晰,便于维护和扩展。 diff --git a/src/data/api_wrappers/__init__.py b/src/data/api_wrappers/__init__.py index 1060376..035ed4e 100644 --- a/src/data/api_wrappers/__init__.py +++ b/src/data/api_wrappers/__init__.py @@ -5,6 +5,7 @@ All wrapper files follow the naming convention: api_{data_type}.py Available APIs: - api_daily: Daily market data (日线行情) + - api_daily_basic: Daily basic indicators (每日指标,换手率、PE、PB、市值等) - api_pro_bar: Pro Bar universal market data (通用行情,后复权) - api_stock_basic: Stock basic information (股票基本信息) - api_trade_cal: Trading calendar (交易日历) @@ -13,9 +14,10 @@ Available APIs: Example: >>> from src.data.api_wrappers import get_daily, get_stock_basic, get_trade_cal, get_bak_basic - >>> from src.data.api_wrappers import get_pro_bar, sync_pro_bar + >>> from src.data.api_wrappers import get_pro_bar, sync_pro_bar, get_daily_basic, sync_daily_basic >>> data = get_daily('000001.SZ', start_date='20240101', end_date='20240131') >>> pro_data = get_pro_bar('000001.SZ', start_date='20240101', end_date='20240131') + >>> daily_basic = get_daily_basic(trade_date='20240101') >>> stocks = get_stock_basic() >>> calendar = get_trade_cal('20240101', '20240131') >>> bak_basic = get_bak_basic(trade_date='20240101') @@ -27,6 +29,12 @@ from src.data.api_wrappers.api_daily import ( preview_daily_sync, DailySync, ) +from src.data.api_wrappers.api_daily_basic import ( + get_daily_basic, + sync_daily_basic, + preview_daily_basic_sync, + DailyBasicSync, +) from src.data.api_wrappers.api_pro_bar import ( get_pro_bar, sync_pro_bar, @@ -55,6 +63,11 @@ __all__ = [ "sync_daily", "preview_daily_sync", "DailySync", + # Daily basic 
indicators + "get_daily_basic", + "sync_daily_basic", + "preview_daily_basic_sync", + "DailyBasicSync", # Pro Bar (universal market data) "get_pro_bar", "sync_pro_bar", diff --git a/src/data/api_wrappers/api.md b/src/data/api_wrappers/api.md index 26f398e..71383b5 100644 --- a/src/data/api_wrappers/api.md +++ b/src/data/api_wrappers/api.md @@ -495,4 +495,74 @@ df = ts.pro_bar(ts_code='000001.SZ', start_date='20180101', end_date='20181011', 例如: -df = ts.pro_bar(ts_code='000001.SH', asset='I', start_date='20180101', end_date='20181011') \ No newline at end of file +df = ts.pro_bar(ts_code='000001.SH', asset='I', start_date='20180101', end_date='20181011') + +每日指标 +接口:daily_basic,可以通过数据工具调试和查看数据。 +更新时间:交易日每日15点~17点之间 +描述:获取全部股票每日重要的基本面指标,可用于选股分析、报表展示等。单次请求最大返回6000条数据,可按日线循环提取全部历史。 +积分:至少2000积分才可以调取,5000积分无总量限制,具体请参阅积分获取办法 + +输入参数 + +名称 类型 必选 描述 +ts_code str Y 股票代码(二选一) +trade_date str N 交易日期 (二选一) +start_date str N 开始日期(YYYYMMDD) +end_date str N 结束日期(YYYYMMDD) +注:日期都填YYYYMMDD格式,比如20181010 + +输出参数 + +名称 类型 描述 +ts_code str TS股票代码 +trade_date str 交易日期 +close float 当日收盘价 +turnover_rate float 换手率(%) +turnover_rate_f float 换手率(自由流通股) +volume_ratio float 量比 +pe float 市盈率(总市值/净利润, 亏损的PE为空) +pe_ttm float 市盈率(TTM,亏损的PE为空) +pb float 市净率(总市值/净资产) +ps float 市销率 +ps_ttm float 市销率(TTM) +dv_ratio float 股息率 (%) +dv_ttm float 股息率(TTM)(%) +total_share float 总股本 (万股) +float_share float 流通股本 (万股) +free_share float 自由流通股本 (万) +total_mv float 总市值 (万元) +circ_mv float 流通市值(万元) +接口用法 + + +pro = ts.pro_api() + +df = pro.daily_basic(ts_code='', trade_date='20180726', fields='ts_code,trade_date,turnover_rate,volume_ratio,pe,pb') +或者 + + +df = pro.query('daily_basic', ts_code='', trade_date='20180726',fields='ts_code,trade_date,turnover_rate,volume_ratio,pe,pb') +数据样例 + + ts_code trade_date turnover_rate volume_ratio pe pb +0 600230.SH 20180726 2.4584 0.72 8.6928 3.7203 +1 600237.SH 20180726 1.4737 0.88 166.4001 1.8868 +2 002465.SZ 20180726 0.7489 0.72 71.8943 2.6391 +3 300732.SZ 20180726 6.7083 
0.77 21.8101 3.2513 +4 600007.SH 20180726 0.0381 0.61 23.7696 2.3774 +5 300068.SZ 20180726 1.4583 0.52 27.8166 1.7549 +6 300552.SZ 20180726 2.0728 0.95 56.8004 2.9279 +7 601369.SH 20180726 0.2088 0.95 44.1163 1.8001 +8 002518.SZ 20180726 0.5814 0.76 15.1004 2.5626 +9 002913.SZ 20180726 12.1096 1.03 33.1279 2.9217 +10 601818.SH 20180726 0.1893 0.86 6.3064 0.7209 +11 600926.SH 20180726 0.6065 0.46 9.1772 0.9808 +12 002166.SZ 20180726 0.7582 0.82 16.9868 3.3452 +13 600841.SH 20180726 0.3754 1.02 66.2647 2.2302 +14 300634.SZ 20180726 23.1127 1.26 120.3053 14.3168 +15 300126.SZ 20180726 1.2304 1.11 348.4306 1.5171 +16 300718.SZ 20180726 17.6612 0.92 32.0239 3.8661 +17 000708.SZ 20180726 0.5575 0.70 10.3674 1.0276 +18 002626.SZ 20180726 0.6187 0.83 22.7580 4.2446 +19 600816.SH 20180726 0.6745 0.65 11.0778 3.2214 \ No newline at end of file diff --git a/src/data/api_wrappers/api_daily_basic.py b/src/data/api_wrappers/api_daily_basic.py new file mode 100644 index 0000000..110161e --- /dev/null +++ b/src/data/api_wrappers/api_daily_basic.py @@ -0,0 +1,252 @@ +"""每日指标数据接口。 + +获取全部股票每日重要的基本面指标,包括换手率、市盈率、市净率、 +总市值、流通市值等,可用于选股分析、报表展示等。 +""" + +from typing import Optional, Dict, Any + +import pandas as pd + +from src.data.client import TushareClient +from src.data.api_wrappers.base_sync import DateBasedSync + + +def get_daily_basic( + trade_date: Optional[str] = None, + ts_code: Optional[str] = None, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + client: Optional[TushareClient] = None, +) -> pd.DataFrame: + """Fetch daily basic indicators from Tushare. + + This interface retrieves important daily fundamental indicators for all stocks, + including turnover rate, PE, PB, market value, etc. It can be used for stock + selection analysis and report display. + + Note: At least one of trade_date or ts_code must be provided. 
The recommended + approach is to use trade_date to fetch data for all stocks on a specific date, + which is more efficient than fetching by individual stock codes. + + Args: + trade_date: Specific trade date (YYYYMMDD format). Use this to get all + stocks' data for a single date. More efficient than ts_code. + ts_code: Stock code (e.g., '000001.SZ', '600000.SH'). Optional if + trade_date is provided. + start_date: Start date (YYYYMMDD format). Use with end_date for date range. + end_date: End date (YYYYMMDD format). Use with start_date for date range. + client: Optional TushareClient instance for shared rate limiting. + If None, creates a new client. For concurrent sync operations, + pass a shared client to ensure proper rate limiting. + + Returns: + pd.DataFrame with columns: + - ts_code: TS stock code + - trade_date: Trade date (YYYYMMDD) + - close: Closing price + - turnover_rate: Turnover rate (%) + - turnover_rate_f: Turnover rate (free float shares) + - volume_ratio: Volume ratio + - pe: Price-to-earnings ratio (total market cap / net profit) + - pe_ttm: PE ratio (TTM) + - pb: Price-to-book ratio (total market cap / net assets) + - ps: Price-to-sales ratio + - ps_ttm: PS ratio (TTM) + - dv_ratio: Dividend yield (%) + - dv_ttm: Dividend yield (TTM) (%) + - total_share: Total shares (10k shares) + - float_share: Float shares (10k shares) + - free_share: Free float shares (10k shares) + - total_mv: Total market value (10k CNY) + - circ_mv: Circulating market value (10k CNY) + + Example: + >>> # Get all stocks for a single date (recommended) + >>> data = get_daily_basic(trade_date='20240101') + >>> + >>> # Get specific stock data + >>> data = get_daily_basic(ts_code='000001.SZ', trade_date='20240101') + >>> + >>> # Get date range data for a specific stock + >>> data = get_daily_basic( + ... ts_code='000001.SZ', + ... start_date='20240101', + ... end_date='20240131' + ... 
) + """ + client = client or TushareClient() + + # Build parameters + params = {} + if trade_date: + params["trade_date"] = trade_date + if ts_code: + params["ts_code"] = ts_code + if start_date: + params["start_date"] = start_date + if end_date: + params["end_date"] = end_date + + # Fetch data using daily_basic API + data = client.query("daily_basic", **params) + + # Rename date column if needed + if "date" in data.columns: + data = data.rename(columns={"date": "trade_date"}) + + return data + + +class DailyBasicSync(DateBasedSync): + """每日指标数据批量同步管理器,支持全量/增量同步。 + + 继承自 DateBasedSync,按日期顺序获取数据。 + 每日指标数据适合按日期获取,一次 API 调用即可获取全市场数据。 + + Example: + >>> sync = DailyBasicSync() + >>> results = sync.sync_all() # 增量同步 + >>> results = sync.sync_all(force_full=True) # 全量同步 + >>> preview = sync.preview_sync() # 预览 + """ + + table_name = "daily_basic" + default_start_date = "20180101" + + # 表结构定义 + TABLE_SCHEMA = { + "ts_code": "VARCHAR(16) NOT NULL", + "trade_date": "DATE NOT NULL", + "close": "DOUBLE", + "turnover_rate": "DOUBLE", + "turnover_rate_f": "DOUBLE", + "volume_ratio": "DOUBLE", + "pe": "DOUBLE", + "pe_ttm": "DOUBLE", + "pb": "DOUBLE", + "ps": "DOUBLE", + "ps_ttm": "DOUBLE", + "dv_ratio": "DOUBLE", + "dv_ttm": "DOUBLE", + "total_share": "DOUBLE", + "float_share": "DOUBLE", + "free_share": "DOUBLE", + "total_mv": "DOUBLE", + "circ_mv": "DOUBLE", + } + + # 索引定义 + TABLE_INDEXES = [ + ("idx_daily_basic_date_code", ["trade_date", "ts_code"]), + ] + + # 主键定义 + PRIMARY_KEY = ("ts_code", "trade_date") + + def fetch_single_date(self, trade_date: str) -> pd.DataFrame: + """获取单日的每日指标数据。 + + Args: + trade_date: 交易日期(YYYYMMDD) + + Returns: + 包含当日所有股票指标的 DataFrame + """ + # 使用 get_daily_basic 获取数据(传递共享 client) + data = get_daily_basic( + trade_date=trade_date, + client=self.client, # 传递共享客户端以确保限流 + ) + return data + + +def sync_daily_basic( + force_full: bool = False, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + dry_run: bool = False, +) -> 
pd.DataFrame: + """同步所有股票的每日指标数据。 + + 这是每日指标数据同步的主要入口点。 + + Args: + force_full: 若为 True,强制从 20180101 完整重载 + start_date: 手动指定起始日期(YYYYMMDD) + end_date: 手动指定结束日期(默认为今天) + dry_run: 若为 True,仅预览将要同步的内容,不写入数据 + + Returns: + 同步的数据 DataFrame + + Example: + >>> # 首次同步(从 20180101 全量加载) + >>> result = sync_daily_basic() + >>> + >>> # 后续同步(增量 - 仅新数据) + >>> result = sync_daily_basic() + >>> + >>> # 强制完整重载 + >>> result = sync_daily_basic(force_full=True) + >>> + >>> # 手动指定日期范围 + >>> result = sync_daily_basic(start_date='20240101', end_date='20240131') + >>> + >>> # Dry run(仅预览) + >>> result = sync_daily_basic(dry_run=True) + """ + sync_manager = DailyBasicSync() + return sync_manager.sync_all( + force_full=force_full, + start_date=start_date, + end_date=end_date, + dry_run=dry_run, + ) + + +def preview_daily_basic_sync( + force_full: bool = False, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + sample_size: int = 3, +) -> Dict[str, Any]: + """预览每日指标同步数据量和样本(不实际同步)。 + + 这是推荐的方式,可在实际同步前检查将要同步的内容。 + + Args: + force_full: 若为 True,预览全量同步(从 20180101) + start_date: 手动指定起始日期(覆盖自动检测) + end_date: 手动指定结束日期(默认为今天) + sample_size: 预览用样本天数(默认: 3) + + Returns: + 包含预览信息的字典: + { + 'sync_needed': bool, + 'date_count': int, + 'start_date': str, + 'end_date': str, + 'estimated_records': int, + 'sample_data': pd.DataFrame, + 'mode': str, # 'full', 'incremental', 或 'none' + } + + Example: + >>> # 预览将要同步的内容 + >>> preview = preview_daily_basic_sync() + >>> + >>> # 预览全量同步 + >>> preview = preview_daily_basic_sync(force_full=True) + >>> + >>> # 预览更多样本 + >>> preview = preview_daily_basic_sync(sample_size=5) + """ + sync_manager = DailyBasicSync() + return sync_manager.preview_sync( + force_full=force_full, + start_date=start_date, + end_date=end_date, + sample_size=sample_size, + ) diff --git a/src/data/sync.py b/src/data/sync.py index bf83b28..822a1cf 100644 --- a/src/data/sync.py +++ b/src/data/sync.py @@ -7,6 +7,7 @@ ✅ 本模块包含的同步逻辑(每日更新): - api_daily.py: 日线数据同步 (DailySync 类) + - 
api_daily_basic.py: 每日指标数据同步 (DailyBasicSync 类) - api_bak_basic.py: 历史股票列表同步 (BakBasicSync 类) - api_pro_bar.py: Pro Bar 数据同步 (ProBarSync 类) - api_stock_basic.py: 股票基本信息同步 @@ -44,6 +45,7 @@ from src.data.api_wrappers import sync_all_stocks from src.data.api_wrappers.api_daily import sync_daily, preview_daily_sync from src.data.api_wrappers.api_pro_bar import sync_pro_bar from src.data.api_wrappers.api_bak_basic import sync_bak_basic +from src.data.api_wrappers.api_daily_basic import sync_daily_basic def preview_sync( @@ -157,7 +159,8 @@ def sync_all_data( 2. 股票基本信息 (sync_all_stocks) 3. 日线数据 (sync_daily) 4. Pro Bar 数据 (sync_pro_bar) - 5. 历史股票列表 (sync_bak_basic) + 5. 每日指标数据 (sync_daily_basic) + 6. 历史股票列表 (sync_bak_basic) 【不包含的同步(需单独调用)】 - 财务数据: 利润表、资产负债表、现金流量表(季度更新) @@ -238,7 +241,7 @@ def sync_all_data( results["daily"] = pd.DataFrame() # 4. Sync Pro Bar data - print("\n[4/5] Syncing Pro Bar data (with adj, tor, vr)...") + print("\n[4/6] Syncing Pro Bar data (with adj, tor, vr)...") try: # 确保表存在 from src.data.api_wrappers.api_pro_bar import ProBarSync @@ -255,14 +258,31 @@ def sync_all_data( sum(len(df) for df in pro_bar_result.values()) if pro_bar_result else 0 ) print( - f"[4/5] Pro Bar data: OK ({total_pro_bar_records} records from {len(pro_bar_result)} stocks)" + f"[4/6] Pro Bar data: OK ({total_pro_bar_records} records from {len(pro_bar_result)} stocks)" ) except Exception as e: - print(f"[4/5] Pro Bar data: FAILED - {e}") + print(f"[4/6] Pro Bar data: FAILED - {e}") results["pro_bar"] = pd.DataFrame() - # 5. Sync stock historical list (bak_basic) - print("\n[5/5] Syncing stock historical list (bak_basic)...") + # 5. Sync daily basic indicators + print( + "\n[5/6] Syncing daily basic indicators (PE, PB, turnover rate, market value)..." 
+ ) + try: + # 确保表存在 + from src.data.api_wrappers.api_daily_basic import DailyBasicSync + + DailyBasicSync().ensure_table_exists() + + daily_basic_result = sync_daily_basic(force_full=force_full, dry_run=dry_run) + results["daily_basic"] = daily_basic_result + print(f"[5/6] Daily basic: OK ({len(daily_basic_result)} records)") + except Exception as e: + print(f"[5/6] Daily basic: FAILED - {e}") + results["daily_basic"] = pd.DataFrame() + + # 6. Sync stock historical list (bak_basic) + print("\n[6/6] Syncing stock historical list (bak_basic)...") try: # 确保表存在 from src.data.api_wrappers.api_bak_basic import BakBasicSync @@ -271,9 +291,9 @@ def sync_all_data( bak_basic_result = sync_bak_basic(force_full=force_full) results["bak_basic"] = bak_basic_result - print(f"[5/5] Bak basic: OK ({len(bak_basic_result)} records)") + print(f"[6/6] Bak basic: OK ({len(bak_basic_result)} records)") except Exception as e: - print(f"[5/5] Bak basic: FAILED - {e}") + print(f"[6/6] Bak basic: FAILED - {e}") results["bak_basic"] = pd.DataFrame() # Summary @@ -286,7 +306,7 @@ def sync_all_data( total_records = sum(len(df) for df in data.values()) print(f" {data_type}: {len(data)} stocks, {total_records} total records") else: - # bak_basic 返回的是 DataFrame + # daily_basic 和 bak_basic 返回的是 DataFrame print(f" {data_type}: {len(data)} records") print("=" * 60) print("\nNote: namechange is NOT in auto-sync. 
To sync manually:") @@ -308,7 +328,7 @@ if __name__ == "__main__": print("") print(" # Or sync individual data types:") print(" from src.data.sync import sync_all, preview_sync") - print(" from src.data.sync import sync_bak_basic") + print(" from src.data.api_wrappers import sync_daily_basic, sync_bak_basic") print("") print(" # Preview before sync (recommended)") print(" preview = preview_sync()") diff --git a/src/factors/engine/data_router.py b/src/factors/engine/data_router.py index 6cccc29..e784449 100644 --- a/src/factors/engine/data_router.py +++ b/src/factors/engine/data_router.py @@ -69,16 +69,11 @@ class DataRouter: # 收集所有需要的表和字段 required_tables: Dict[str, Set[str]] = {} - max_lookback = 0 for spec in data_specs: if spec.table not in required_tables: required_tables[spec.table] = set() required_tables[spec.table].update(spec.columns) - max_lookback = max(max_lookback, spec.lookback_days) - - # 调整日期范围以包含回看期 - adjusted_start = self._adjust_start_date(start_date, max_lookback) # 从数据源获取各表数据 table_data = {} @@ -86,7 +81,7 @@ class DataRouter: df = self._load_table( table_name=table_name, columns=list(columns), - start_date=adjusted_start, + start_date=start_date, end_date=end_date, stock_codes=stock_codes, ) @@ -95,11 +90,6 @@ class DataRouter: # 组装核心宽表 core_table = self._assemble_wide_table(table_data, required_tables) - # 过滤到实际请求日期范围 - core_table = core_table.filter( - (pl.col("trade_date") >= start_date) & (pl.col("trade_date") <= end_date) - ) - return core_table def _load_table( @@ -265,34 +255,6 @@ class DataRouter: return result - def _adjust_start_date(self, start_date: str, lookback_days: int) -> str: - """根据回看天数调整开始日期。 - - Args: - start_date: 原始开始日期 (YYYYMMDD) - lookback_days: 需要回看的交易日数 - - Returns: - 调整后的开始日期 - """ - # 简化的日期调整:假设每月30天,向前推移 - # 实际应用中应该使用交易日历 - year = int(start_date[:4]) - month = int(start_date[4:6]) - day = int(start_date[6:8]) - - total_days = lookback_days + 30 # 额外缓冲 - - day -= total_days - while day <= 0: - month -= 1 - if month 
<= 0: - month = 12 - year -= 1 - day += 30 - - return f"{year:04d}{month:02d}{day:02d}" - def clear_cache(self) -> None: """清除数据缓存。""" with self._lock: diff --git a/src/factors/engine/data_spec.py b/src/factors/engine/data_spec.py index 8a7e81e..a41ad24 100644 --- a/src/factors/engine/data_spec.py +++ b/src/factors/engine/data_spec.py @@ -18,12 +18,10 @@ class DataSpec: Attributes: table: 数据表名称 columns: 需要的字段列表 - lookback_days: 回看天数(用于时序计算) """ table: str columns: List[str] - lookback_days: int = 1 @dataclass diff --git a/src/factors/engine/planner.py b/src/factors/engine/planner.py index aece8bc..75a8bb7 100644 --- a/src/factors/engine/planner.py +++ b/src/factors/engine/planner.py @@ -73,9 +73,9 @@ class ExecutionPlanner: ) -> List[DataSpec]: """从依赖推导数据规格。 - 根据表达式中的函数类型推断回看天数需求。 基础行情字段(open, high, low, close, vol, amount, pre_close, change, pct_chg) 默认从 pro_bar 表获取。 + 每日指标字段(total_mv, circ_mv, pe, pb 等)从 daily_basic 表获取。 Args: dependencies: 依赖的字段集合 @@ -84,10 +84,6 @@ class ExecutionPlanner: Returns: 数据规格列表 """ - # 计算最大回看窗口 - max_window = self._extract_max_window(expression) - lookback_days = max(1, max_window) - # 基础行情字段集合(这些字段从 pro_bar 表获取) pro_bar_fields = { "open", @@ -103,9 +99,27 @@ class ExecutionPlanner: "volume_ratio", } - # 将依赖分为 pro_bar 字段和其他字段 + # 每日指标字段集合(这些字段从 daily_basic 表获取) + daily_basic_fields = { + "turnover_rate_f", + "pe", + "pe_ttm", + "pb", + "ps", + "ps_ttm", + "dv_ratio", + "dv_ttm", + "total_share", + "float_share", + "free_share", + "total_mv", + "circ_mv", + } + + # 将依赖分为不同表的字段 pro_bar_deps = dependencies & pro_bar_fields - other_deps = dependencies - pro_bar_fields + daily_basic_deps = dependencies & daily_basic_fields + other_deps = dependencies - pro_bar_fields - daily_basic_fields data_specs = [] @@ -115,7 +129,15 @@ class ExecutionPlanner: DataSpec( table="pro_bar", columns=sorted(pro_bar_deps), - lookback_days=lookback_days, + ) + ) + + # daily_basic 表的数据规格 + if daily_basic_deps: + data_specs.append( + DataSpec( + 
table="daily_basic", + columns=sorted(daily_basic_deps), ) ) @@ -125,46 +147,7 @@ class ExecutionPlanner: DataSpec( table="daily", columns=sorted(other_deps), - lookback_days=lookback_days, ) ) return data_specs - - def _extract_max_window(self, node: Node) -> int: - """从表达式中提取最大窗口大小。 - - Args: - node: AST 节点 - - Returns: - 最大窗口大小,无时序函数返回 1 - """ - if isinstance(node, FunctionNode): - window = 1 - # 检查函数参数中的窗口大小 - for arg in node.args: - if ( - isinstance(arg, Constant) - and isinstance(arg.value, int) - and arg.value > window - ): - window = arg.value - - # 递归检查子表达式 - for arg in node.args: - if isinstance(arg, Node) and not isinstance(arg, Constant): - window = max(window, self._extract_max_window(arg)) - - return window - - elif isinstance(node, BinaryOpNode): - return max( - self._extract_max_window(node.left), - self._extract_max_window(node.right), - ) - - elif isinstance(node, UnaryOpNode): - return self._extract_max_window(node.operand) - - return 1 diff --git a/src/training/main.py b/src/training/main.py deleted file mode 100644 index 204811c..0000000 --- a/src/training/main.py +++ /dev/null @@ -1,302 +0,0 @@ -"""训练流程入口脚本 - -运行方式: - uv run python -m src.training.main - -或: - uv run python src/training/main.py - -本脚本提供两种运行方式: -1. run_full_pipeline(): 完整训练流程(数据准备 -> 训练 -> 预测) -2. 
prepare_data_and_train() + train_and_predict(): 分步执行,便于调试和调整 - -因子配置示例: - from src.factors import MovingAverageFactor, ReturnRankFactor - - # 直接传入因子实例列表 - 最简单的方式 - factors = [ - MovingAverageFactor(period=5), - MovingAverageFactor(period=10), - MovingAverageFactor(period=20), - ReturnRankFactor(period=5), - ReturnRankFactor(period=10), - ] - - # 运行完整流程 - result = run_full_pipeline(factors=factors) -""" - -from pathlib import Path -from typing import Optional, List - -import polars as pl - -from src.factors import BaseFactor -from src.training.pipeline import ( - FactorConfig, - predict_top_stocks, - prepare_data, - save_top_stocks, - train_model, -) - - -def prepare_data_and_train( - factors: Optional[List[BaseFactor]] = None, - data_dir: str = "data", - train_start: str = "20190101", - train_end: str = "20231231", - val_start: str = "20240102", - val_end: str = "20240531", - test_start: str = "20240602", - test_end: str = "20241231", -) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame, FactorConfig, str]: - """第一步:数据处理 - - 加载原始数据,计算因子和标签,拆分训练/验证/测试集。 - - Args: - factors: 因子实例列表,默认为 None(使用 MA5, MA10, ReturnRank5) - data_dir: 数据目录 - train_start: 训练集开始日期 - train_end: 训练集结束日期 - val_start: 验证集开始日期 - val_end: 验证集结束日期 - test_start: 测试集开始日期 - test_end: 测试集结束日期 - - Returns: - tuple: (train_data, val_data, test_data, factor_config, label_col) - """ - print("=" * 50) - print("[Step 1] 数据处理") - print("=" * 50) - print(f"训练集: {train_start} -> {train_end}") - print(f"验证集: {val_start} -> {val_end}") - print(f"测试集: {test_start} -> {test_end}") - print() - - # 1. 
准备数据 - train_data, val_data, test_data, factor_config = prepare_data( - factors=factors, - data_dir=data_dir, - train_start=train_start, - train_end=train_end, - val_start=val_start, - val_end=val_end, - test_start=test_start, - test_end=test_end, - ) - - print(f"训练集样本数: {len(train_data)}") - print(f"验证集样本数: {len(val_data)}") - print(f"测试集样本数: {len(test_data)}") - print() - - # 打印少量数据样本展示 - print("=" * 50) - print("[数据预览] 训练集前3行:") - print(train_data.head(3)) - print() - print("[数据预览] 验证集前3行:") - print(val_data.head(3)) - print() - print("[数据预览] 测试集前3行:") - print(test_data.head(3)) - print() - - # 2. 获取特征列名 - feature_cols = factor_config.get_feature_names() - label_col = "label" - - print(f"特征列: {feature_cols}") - print(f"标签列: {label_col}") - print() - - return train_data, val_data, test_data, factor_config, label_col - - -def train_and_predict( - train_data: pl.DataFrame, - val_data: pl.DataFrame, - test_data: pl.DataFrame, - factor_config: FactorConfig, - label_col: str = "label", - top_n: int = 5, - output_path: str = "output/top_stocks.tsv", -) -> pl.DataFrame: - """第二步:训练和预测 - - 使用处理好的数据训练模型,进行测试集预测并保存结果。 - - Args: - train_data: 训练数据 - val_data: 验证数据 - test_data: 测试数据 - factor_config: 因子配置对象 - label_col: 标签列名 - top_n: 每日选股数量 - output_path: 输出文件路径 - - Returns: - 选股结果DataFrame - """ - print("=" * 50) - print("[Step 2] 模型训练与预测") - print("=" * 50) - print() - - # 获取特征列名 - feature_cols = factor_config.get_feature_names() - print(f"使用特征: {feature_cols}") - print() - - # 3. 训练模型 - print("[Training] Training model...") - model, pipeline = train_model( - train_data=train_data, - val_data=val_data, - feature_cols=feature_cols, - label_col=label_col, - ) - print() - - # 4. 测试集预测 - print("[Predict] Predicting on test set...") - top_stocks = predict_top_stocks( - model=model, - pipeline=pipeline, - test_data=test_data, - feature_cols=feature_cols, - top_n=top_n, - ) - print() - - # 5. 
保存结果 - print(f"[Saving] Saving results to {output_path}...") - Path(output_path).parent.mkdir(parents=True, exist_ok=True) - save_top_stocks(top_stocks, output_path) - print() - - return top_stocks - - -def run_full_pipeline( - factors: Optional[List[BaseFactor]] = None, - train_start: str = "20190101", - train_end: str = "20231231", - val_start: str = "20240102", - val_end: str = "20240531", - test_start: str = "20240602", - test_end: str = "20241231", - top_n: int = 5, - output_path: str = "output/top_stocks.tsv", -) -> pl.DataFrame: - """运行完整训练流程 - - 相当于依次调用 prepare_data_and_train 和 train_and_predict。 - - Args: - factors: 因子实例列表,默认为 None(使用 MA5, MA10, ReturnRank5) - train_start: 训练集开始日期 - train_end: 训练集结束日期 - val_start: 验证集开始日期 - val_end: 验证集结束日期 - test_start: 测试集开始日期 - test_end: 测试集结束日期 - top_n: 每日选股数量 - output_path: 输出文件路径 - - Returns: - 选股结果DataFrame - """ - # 第一步:数据处理 - train_data, val_data, test_data, factor_config, label_col = prepare_data_and_train( - factors=factors, - train_start=train_start, - train_end=train_end, - val_start=val_start, - val_end=val_end, - test_start=test_start, - test_end=test_end, - ) - - # 第二步:训练和预测 - result = train_and_predict( - train_data=train_data, - val_data=val_data, - test_data=test_data, - factor_config=factor_config, - label_col=label_col, - top_n=top_n, - output_path=output_path, - ) - - print("=" * 50) - print("[Done] 训练流程完成!") - print("=" * 50) - - return result - - -if __name__ == "__main__": - from src.factors import MovingAverageFactor, ReturnRankFactor - - # ========== 因子配置 ========== - # 直接传入因子实例列表 - 简单直观 - factors = [ - MovingAverageFactor(period=5), # 5日移动平均线 - MovingAverageFactor(period=10), # 10日移动平均线 - MovingAverageFactor(period=20), # 20日移动平均线 - ReturnRankFactor(period=5), # 5日收益率排名 - ReturnRankFactor(period=10), # 10日收益率排名 - ] - - # ========== 运行方式 ========== - - # 方式一:完整流程(一次性执行) - # result = run_full_pipeline( - # factors=factors, - # train_start="20190101", - # train_end="20231231", - # 
val_start="20240102", - # val_end="20240531", - # test_start="20240602", - # test_end="20241231", - # top_n=5, - # output_path="output/top_stocks.tsv", - # ) - - # 方式二:分步执行(便于调试) - # 第一步:数据处理 - train_data, val_data, test_data, factor_config, label_col = prepare_data_and_train( - factors=factors, - train_start="20190101", - train_end="20231231", - val_start="20240102", - val_end="20240531", - test_start="20240602", - test_end="20241231", - ) - - # 可在此处添加自定义逻辑,例如: - # - 查看数据分布 - # - 调整特征 - # - 保存中间结果 - print("\n[Info] 因子配置详情:") - print(f" 因子列表: {factor_config.get_feature_names()}") - print(f" 最大回溯天数: {factor_config.get_max_lookback()}") - - # 第二步:训练和预测 - # result = train_and_predict( - # train_data=train_data, - # val_data=val_data, - # test_data=test_data, - # factor_config=factor_config, - # label_col=label_col, - # top_n=5, - # output_path="output/top_stocks.tsv", - # ) - # - # print("\n[Result] Top stocks selection:") - # print(result) diff --git a/tests/test_factor_engine.py b/tests/test_factor_engine.py index ad6768b..1286b9e 100644 --- a/tests/test_factor_engine.py +++ b/tests/test_factor_engine.py @@ -71,7 +71,7 @@ class TestFactorEngineEndToEnd: @pytest.fixture def engine(self, mock_data): """提供配置好的 FactorEngine fixture。""" - data_source = {"daily": mock_data} + data_source = {"pro_bar": mock_data} return FactorEngine(data_source=data_source, max_workers=2) def test_simple_symbol_expression(self, engine): @@ -116,7 +116,7 @@ class TestFullWorkflow: # 2. 初始化引擎 print("\nStep 2: Initialize FactorEngine...") - engine = FactorEngine(data_source={"daily": mock_data}) + engine = FactorEngine(data_source={"pro_bar": mock_data}) print(" Engine initialized") # 3. 注册因子 - 使用简单因子避免回看窗口问题 diff --git a/tests/test_two_stocks_string_factors.py b/tests/test_two_stocks_string_factors.py index 20f45f1..4400799 100644 --- a/tests/test_two_stocks_string_factors.py +++ b/tests/test_two_stocks_string_factors.py @@ -5,6 +5,7 @@ 2. return_5_rank: 5日收益率在截面上的排名 3. 
ma5: 5日均线 (ts_mean(close, 5)) 4. ma10: 10日均线 (ts_mean(close, 10)) +5. market_cap_rank: 市值百分比排名 (cs_rank(total_mv)) 特点:使用因子字符串架构(add_factor + 字符串表达式) @@ -48,6 +49,11 @@ def test_two_stocks_string_factors(): print("\n[1.4] ma10 = ts_mean(close, 10)") print(f" 字符串表达式: {ma10_str}") + # market_cap_rank: 市值百分比排名 (截面排名) + market_cap_rank_str = "cs_rank(total_mv)" + print("\n[1.5] market_cap_rank = cs_rank(total_mv)") + print(f" 字符串表达式: {market_cap_rank_str}") + # ======================================================================== # 1.5 打印数据来源信息 # ======================================================================== @@ -66,6 +72,7 @@ def test_two_stocks_string_factors(): "return_5_rank": return_5_rank_str, "ma5": ma5_str, "ma10": ma10_str, + "market_cap_rank": market_cap_rank_str, } for name, expr_str in expressions_str.items(): @@ -102,9 +109,12 @@ def test_two_stocks_string_factors(): engine.add_factor("ma10", ma10_str) print("[2.4] 注册 ma10 (字符串方式)") + engine.add_factor("market_cap_rank", market_cap_rank_str) + print("[2.5] 注册 market_cap_rank (市值百分比排名,字符串方式)") + # 也注册原始 close 价格用于验证 engine.add_factor("close_price", "close") - print("[2.5] 注册 close_price (原始收盘价,字符串方式)") + print("[2.6] 注册 close_price (原始收盘价,字符串方式)") print(f"\n已注册因子列表: {engine.list_registered()}") @@ -125,7 +135,6 @@ def test_two_stocks_string_factors(): for i, spec in enumerate(plan.data_specs, 1): print(f" [{i}] 表名: {spec.table}") print(f" 字段: {spec.columns}") - print(f" 回看天数: {spec.lookback_days}") # ======================================================================== # 3. 
执行计算(两支股票) @@ -143,7 +152,14 @@ def test_two_stocks_string_factors(): try: result = engine.compute( - factor_names=["return_5", "return_5_rank", "ma5", "ma10", "close_price"], + factor_names=[ + "return_5", + "return_5_rank", + "ma5", + "ma10", + "close_price", + "market_cap_rank", + ], start_date=start_date, end_date=end_date, stock_codes=stock_codes, @@ -345,6 +361,10 @@ def test_two_stocks_string_factors(): print(" - 每天两支股票的排名之和应接近 1") print("-" * 60) + # 5.5.1 return_5_rank 截面排名验证 + print("\n[5.5.1] return_5_rank 截面排名验证:") + print("-" * 60) + # 获取有效数据 result_valid = result.drop_nulls(subset=["return_5_rank"]) @@ -373,6 +393,39 @@ def test_two_stocks_string_factors(): else: print(" [警告] 截面排名之和不接近 1") + # 5.5.2 market_cap_rank 市值百分比排名验证 + print("\n[5.5.2] market_cap_rank 市值百分比排名验证:") + print("-" * 60) + + result_valid_mv = result.drop_nulls(subset=["market_cap_rank"]) + + if len(result_valid_mv) > 0: + min_rank_mv = result_valid_mv["market_cap_rank"].min() + max_rank_mv = result_valid_mv["market_cap_rank"].max() + print(f"\n市值排名范围: [{min_rank_mv:.4f}, {max_rank_mv:.4f}]") + + if 0 <= min_rank_mv <= 1 and 0 <= max_rank_mv <= 1: + print(" [成功] 市值排名值在 [0, 1] 区间内!") + else: + print(" [警告] 市值排名值超出 [0, 1] 区间") + + # 检查某天两支股票的市值排名之和 + sample_date_mv = result_valid_mv["trade_date"][0] + day_data_mv = result_valid_mv.filter( + result_valid_mv["trade_date"] == sample_date_mv + ) + if len(day_data_mv) == 2: + rank_sum_mv = day_data_mv["market_cap_rank"].sum() + print(f"\n示例日期 {sample_date_mv} 的市值排名验证:") + for row in day_data_mv.iter_rows(named=True): + print(f" {row['ts_code']}: {row['market_cap_rank']:.4f}") + print(f" 排名之和: {rank_sum_mv:.4f} (两支股票应接近 1)") + + if abs(rank_sum_mv - 1.0) < 0.01: + print(" [成功] 市值排名之和验证通过!") + else: + print(" [警告] 市值排名之和不接近 1") + # ======================================================================== # 6. 
统计摘要 # ======================================================================== @@ -394,7 +447,7 @@ def test_two_stocks_string_factors(): print(f"总记录数: {len(stock_data)}") print(f"有效记录数 (去空值后): {len(stock_valid)}") - factor_cols = ["return_5", "return_5_rank", "ma5", "ma10"] + factor_cols = ["return_5", "return_5_rank", "ma5", "ma10", "market_cap_rank"] for col in factor_cols: if col in stock_data.columns: @@ -418,7 +471,6 @@ def test_two_stocks_string_factors(): # ======================================================================== print("\n" + "=" * 80) - # ======================================================================== # 8. 测试总结 # ======================================================================== @@ -434,10 +486,13 @@ def test_two_stocks_string_factors(): print() print("因子定义方式: 字符串表达式 (add_factor 方法)") print("计算因子:") - print(" 1. return_5 - 5日收益率 (字符串: '(close / ts_delay(close, 5)) - 1')") - print(" 2. return_5_rank - 5日收益率截面排名 (字符串: 'cs_rank(...)')") - print(" 3. ma5 - 5日均线 (字符串: 'ts_mean(close, 5)')") - print(" 4. ma10 - 10日均线 (字符串: 'ts_mean(close, 10)')") + print( + " 1. return_5 - 5日收益率 (字符串: '(close / ts_delay(close, 5)) - 1')" + ) + print(" 2. return_5_rank - 5日收益率截面排名 (字符串: 'cs_rank(...)')") + print(" 3. ma5 - 5日均线 (字符串: 'ts_mean(close, 5)')") + print(" 4. ma10 - 10日均线 (字符串: 'ts_mean(close, 10)')") + print(" 5. market_cap_rank - 市值百分比排名 (字符串: 'cs_rank(total_mv)')") print() print("验证结果:") print(" - 字符串表达式解析: 正常")