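refactor(factor): complete the DSL-based refactor of the factor framework

- Refactor FactorEngine to implement the complete DSL expression execution path
- Add DataRouter, a data router that supports in-memory mode and core wide-table assembly
- Add ExecutionPlanner, an execution-plan generator that integrates the compiler and the translator
- Add ComputeEngine, a compute engine that supports parallel computation
- Complete the public API exports in factors/__init__.py
- Add test_factor_engine.py with engine unit tests
- Remove the old engine implementation and the obsolete DSL promotion tests
- Update AGENTS.md with the v2.2 architecture change history and the Factors framework design notes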
214
AGENTS.md
@@ -82,34 +82,34 @@ ProStock/
 │   │
 │   ├── data/  # Data acquisition and storage
 │   │   ├── api_wrappers/  # Tushare API wrappers
-│   │   │   ├── API_INTERFACE_SPEC.md  # Interface specification document
-│   │   │   ├── api.md  # API interface definitions
-│   │   │   ├── api_daily.py  # Daily bar data interface
+│   │   │   ├── base_sync.py  # Base abstract sync classes (BaseDataSync/StockBasedSync/DateBasedSync)
+│   │   │   ├── api_daily.py  # Daily bar data interface (DailySync)
+│   │   │   ├── api_pro_bar.py  # Pro Bar data interface (ProBarSync)
 │   │   │   ├── api_stock_basic.py  # Stock basic information interface
 │   │   │   ├── api_trade_cal.py  # Trading calendar interface
+│   │   │   ├── api_bak_basic.py  # Historical stock list interface (BakBasicSync)
+│   │   │   ├── api_namechange.py  # Stock name change interface
+│   │   │   ├── financial_data/  # Financial data interfaces
+│   │   │   │   ├── api_income.py  # Income statement interface
+│   │   │   │   └── api_financial_sync.py  # Financial data sync
 │   │   │   └── __init__.py
 │   │   ├── __init__.py
 │   │   ├── client.py  # Tushare API client (rate-limited)
 │   │   ├── config.py  # Data module configuration
 │   │   ├── db_inspector.py  # Database inspection tool
 │   │   ├── db_manager.py  # DuckDB table management and sync
 │   │   ├── rate_limiter.py  # Token-bucket rate limiter
 │   │   ├── storage.py  # Core data storage
-│   │   └── sync.py  # Main data sync logic
+│   │   ├── sync.py  # Data sync scheduling hub
+│   │   └── utils.py  # Data module utility functions
 │   │
-│   ├── factors/  # Factor computation framework
-│   │   ├── __init__.py
-│   │   ├── base.py  # Factor base classes (cross-sectional/time-series)
-│   │   ├── composite.py  # Composite factors and scalar operations
-│   │   ├── data_loader.py  # Data loader
-│   │   ├── data_spec.py  # Data specification definitions
-│   │   ├── engine.py  # Factor execution engine
-│   │   ├── momentum/  # Momentum factors
-│   │   │   ├── __init__.py
-│   │   │   ├── ma.py  # Moving averages
-│   │   │   └── return_rank.py  # Return ranking
-│   │   └── financial/  # Financial factors
-│   │       └── __init__.py
+│   ├── factors/  # Factor computation framework (DSL-expression driven)
+│   │   ├── __init__.py  # Exports all public APIs
+│   │   ├── dsl.py  # DSL expression layer - node definitions and operator overloading
+│   │   ├── api.py  # API layer - common symbols (close/open, etc.) and functions (ts_mean/cs_rank, etc.)
+│   │   ├── compiler.py  # AST compiler - dependency extraction
+│   │   ├── translator.py  # Polars expression translator
+│   │   └── engine.py  # Factor execution engine - unified entry point
 │   │
 │   ├── pipeline/  # Model training pipeline
 │   │   ├── __init__.py
@@ -296,9 +296,48 @@ uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)"
 
 ## Architecture Change History
 
-### v2.1 (2026-02-28) - Sync module specification update
+### v2.2 (2026-03-01) - DSL-based refactor of the factor framework
 
-#### sync.py responsibilities
+#### Factor computation framework refactor
+**Change**: Migrated from base-class inheritance to DSL expressions
+**Reasons**:
+- Provides a more intuitive way to express mathematical formulas
+- Supports composing and nesting factor expressions
+- Better type safety and compile-time checks
+**Architecture changes**:
+- Added `dsl.py`: expression node base class and operator overloading (Symbol, FunctionNode, etc.)
+- Added `api.py`: common symbols (close/open/volume, etc.) and functions (ts_mean/cs_rank, etc.)
+- Added `compiler.py`: AST compiler that extracts expression dependencies
+- Added `translator.py`: translates DSL expressions into Polars expressions
+- Refactored `engine.py`: unified execution engine entry point integrating DataRouter, ExecutionPlanner, and ComputeEngine
+- Removed: `base.py`, `composite.py`, `data_loader.py`, `data_spec.py`
+- Removed: the `factors/momentum/` and `factors/financial/` subdirectories
+**Usage comparison**:
+```python
+# Old way (base-class inheritance)
+class MA20Factor(TimeSeriesFactor):
+    name = "ma20"
+    data_specs = [DataSpec("daily", ["close"], 20)]
+    def compute(self, data):
+        return data.get_column("close").rolling_mean(20)
+
+# New way (DSL expressions)
+from src.factors.api import close, ts_mean
+ma20 = ts_mean(close, 20)  # write the mathematical expression directly
+
+engine = FactorEngine()
+engine.register("ma20", ma20)
+result = engine.compute(["ma20"], "20240101", "20240131")
+```
+
+#### Additions to the data module
+**New files**:
+- `api_wrappers/base_sync.py`: base abstract classes for data sync (BaseDataSync, StockBasedSync, DateBasedSync)
+- `data_router.py`: data router (integrated into factors/engine.py as DataRouter)
+- `utils.py`: date utility functions (get_today_date, get_next_date, is_quarter_end, etc.)
+**Impact**: Data sync logic is more standardized and supports two sync modes, by stock and by date
+
+### v2.1 (2026-02-28) - Sync module specification update
 **Change**: Clarify that `sync.py` only contains sync logic for data updated daily
 **Reason**: Separate high-frequency (daily) data from low-frequency (quarterly/annual) data to avoid unnecessary API calls
 **Specification**:
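The move from subclassing to expressions described in this hunk rests on operator overloading: arithmetic on a node returns another node instead of a number, so a formula stays a lazy tree until something evaluates it. The following is a minimal sketch of that idea. The class names echo those listed for `dsl.py`, but the bodies, the `wrap` helper, and the `evaluate` function are assumptions made for illustration, not the project's implementation.

```python
from dataclasses import dataclass
from typing import Union

Number = Union[int, float]


class Node:
    """Base class: arithmetic on nodes builds a tree instead of computing."""

    def __add__(self, other):
        return BinaryOpNode("+", self, wrap(other))

    def __sub__(self, other):
        return BinaryOpNode("-", self, wrap(other))

    def __mul__(self, other):
        return BinaryOpNode("*", self, wrap(other))

    def __truediv__(self, other):
        return BinaryOpNode("/", self, wrap(other))


@dataclass(frozen=True)
class Symbol(Node):
    name: str


@dataclass(frozen=True)
class Constant(Node):
    value: Number


@dataclass(frozen=True)
class BinaryOpNode(Node):
    op: str
    left: Node
    right: Node


def wrap(value):
    """Turn plain numbers into Constant nodes; leave nodes untouched."""
    return value if isinstance(value, Node) else Constant(value)


def evaluate(node, row):
    """Hypothetical evaluator: resolve symbols from a dict of column values."""
    if isinstance(node, Symbol):
        return row[node.name]
    if isinstance(node, Constant):
        return node.value
    ops = {
        "+": lambda a, b: a + b,
        "-": lambda a, b: a - b,
        "*": lambda a, b: a * b,
        "/": lambda a, b: a / b,
    }
    return ops[node.op](evaluate(node.left, row), evaluate(node.right, row))


close, open_ = Symbol("close"), Symbol("open")
pct_change = (close - open_) / open_ * 100  # builds a tree; nothing is computed yet
value = evaluate(pct_change, {"close": 12.0, "open": 8.0})
```

Because the expression is only data, the same tree can be inspected for dependencies, printed, or translated to another backend before any numbers are touched.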
@@ -353,7 +392,120 @@ uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)"
 `get_today_date()`, `get_next_date()`, `DEFAULT_START_DATE`, and similar helpers are managed centrally in `src/data/utils.py`
 Other modules should import these from `utils.py` to avoid duplicate definitions
 
+Other modules should import these from `utils.py` to avoid duplicate definitions
+
+
+## Factors Framework Design Notes
+
+### Architecture layers
+
+The factor framework uses a layered design. From top to bottom:
+
+```
+API layer (api.py)
+    |
+    v
+DSL layer (dsl.py) <- factor expressions (Node)
+    |
+    v
+Compiler (compiler.py) <- AST dependency extraction
+    |
+    v
+Translator (translator.py) <- translation into Polars expressions
+    |
+    v
+Engine (engine.py) <- execution engine (DataRouter/ExecutionPlanner/ComputeEngine)
+    |
+    v
+Data layer (data_router.py + DuckDB) <- data retrieval and storage
+```
+
+### Usage
+
+#### 1. Basic expressions
+
+```python
+from src.factors.api import close, open, ts_mean, cs_rank
+
+# Define factor expressions (lazily evaluated)
+ma20 = ts_mean(close, 20)  # 20-day moving average
+price_rank = cs_rank(close)  # cross-sectional rank of the closing price
+
+# Combine expressions
+alpha = ma20 * 0.6 + price_rank * 0.4
+```
+
+#### 2. Registration and execution
+
+```python
+from src.factors import FactorEngine
+
+engine = FactorEngine()
+engine.register("ma20", ma20)
+engine.register("price_rank", price_rank)
+
+# Run the computation
+result = engine.compute(
+    factor_names=["ma20", "price_rank"],
+    start_date="20240101",
+    end_date="20240131",
+)
+```
+
+### Supported functions
+
+**Time-series functions (ts_*)**:
+- `ts_mean(x, window)` - rolling mean
+- `ts_std(x, window)` - rolling standard deviation
+- `ts_max(x, window)` - rolling maximum
+- `ts_min(x, window)` - rolling minimum
+- `ts_sum(x, window)` - rolling sum
+- `ts_delay(x, periods)` - lag by N periods
+- `ts_delta(x, periods)` - difference over N periods
+- `ts_corr(x, y, window)` - rolling correlation coefficient
+- `ts_cov(x, y, window)` - rolling covariance
+- `ts_rank(x, window)` - rolling rank
+
+**Cross-sectional functions (cs_*)**:
+- `cs_rank(x)` - cross-sectional rank (quantile)
+- `cs_zscore(x)` - z-score standardization
+- `cs_neutralize(x, group)` - industry/market-cap neutralization
+- `cs_winsorize(x, lower, upper)` - winsorization
+- `cs_demean(x)` - demeaning
+
+**Math functions**:
+- `log(x)` - natural logarithm
+- `exp(x)` - exponential function
+- `sqrt(x)` - square root
+- `sign(x)` - sign function
+- `abs(x)` - absolute value
+- `max_(x, y)` / `min_(x, y)` - element-wise maximum/minimum
+- `clip(x, lower, upper)` - value clipping
+
+**Conditional functions**:
+- `if_(condition, true_val, false_val)` - conditional selection
+- `where(condition, true_val, false_val)` - alias of if_
+
+### Operator support
+
+DSL expressions support the full set of Python operators:
+
+```python
+# Arithmetic: +, -, *, /, //, %, **
+expr1 = (close - open) / open * 100  # percentage change
+
+# Comparison: ==, !=, <, <=, >, >=
+expr2 = close > open  # whether the price rose
+
+# Unary: -, +, abs()
+expr3 = -change  # negated price change
+
+# Chained calls
+expr4 = ts_mean(cs_rank(close), 20)  # 20-day smoothing of the rank
+```
+
 
+## AI Behavior Guidelines
 ## AI Behavior Guidelines
 
 ### Handling LSP diagnostic errors
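The `cs_*` functions listed in this hunk operate across stocks within a single trading day, in contrast to the `ts_*` functions, which run along each stock's history. The sketch below illustrates that grouping in plain Python with no Polars dependency. The rank normalization used here (rank divided by group size, ties not handled) is an assumption: the source only says the rank is a quantile, and the project's translator may define it differently. `cs_apply`, `pct_rank`, and `zscore` are hypothetical names.

```python
from collections import defaultdict
from statistics import mean, pstdev


def cs_apply(rows, column, func):
    """Apply func to the values of `column` within each trade_date group."""
    groups = defaultdict(list)
    for i, row in enumerate(rows):
        groups[row["trade_date"]].append(i)
    out = [None] * len(rows)
    for indices in groups.values():
        values = [rows[i][column] for i in indices]
        for i, result in zip(indices, func(values)):
            out[i] = result
    return out


def pct_rank(values):
    """Rank divided by group size, so results fall in (0, 1]; assumes no ties."""
    order = sorted(values)
    return [(order.index(v) + 1) / len(values) for v in values]


def zscore(values):
    """(x - mean) / population std; a constant group maps to 0.0."""
    mu, sigma = mean(values), pstdev(values)
    return [(v - mu) / sigma if sigma else 0.0 for v in values]


rows = [
    {"ts_code": "600000.SH", "trade_date": "20240115", "close": 10.0},
    {"ts_code": "600001.SH", "trade_date": "20240115", "close": 12.0},
    {"ts_code": "600000.SH", "trade_date": "20240116", "close": 11.0},
    {"ts_code": "600001.SH", "trade_date": "20240116", "close": 9.0},
]
ranks = cs_apply(rows, "close", pct_rank)   # ranked per day, not across days
scores = cs_apply(rows, "close", zscore)
```

Each output value depends only on rows that share its `trade_date`, which is why a cross-sectional factor needs no history before the requested start date.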
@@ -384,3 +536,25 @@ LSP error: Syntax error on line 45
 ✅ Correct: read line 45 of the file, find that a closing parenthesis is missing, add it, and re-run the check
 ❌ Wrong: delete the file and rewrite it, or ignore the error and carry on
 ```
+
+### Emoji prohibition rule
+
+**⚠️ Mandatory: emoji must not appear in code or test files.**
+
+1. **Scope of the ban**
+   - All `.py` source files
+   - All test files (the `tests/` directory)
+   - Configuration files and scripts
+
+2. **Alternatives**
+   - ❌ Do not use: `print("✅ Success")`, `print("❌ Failure")`, `# 📝 comment`
+   - ✅ Use instead: `print("[Success]")`, `print("[Failure]")`, `# comment`
+   - Use bracketed text markers such as `[Success]`, `[Warning]`, and `[Error]` in place of emoji
+
+3. **Sole exceptions**
+   - AGENTS.md itself may use emoji for emphasis (such as the ⚠️ in this file)
+   - Project documentation, README files, and other outward-facing files may use emoji at their discretion
+
+4. **How to check**
+   - Search for emoji with the regular expression: `[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F1E0-\U0001F1FF\u2600-\u26FF\u2700-\u27BF]`
+   - Self-check before committing to make sure no emoji slips into the code
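The check described in item 4 of the emoji rule can be automated with the standard library. The sketch below reuses the character ranges from the rule verbatim; those ranges do not cover every emoji in Unicode. The `find_emoji` helper and the temporary-directory demo are hypothetical and only show one way to run the search.

```python
import re
import tempfile
from pathlib import Path

# Character ranges copied from the rule; written as escapes so that this
# snippet itself contains no emoji.
EMOJI_RE = re.compile(
    "[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF"
    "\U0001F1E0-\U0001F1FF\u2600-\u26FF\u2700-\u27BF]"
)


def find_emoji(root, pattern="*.py"):
    """Return (path, line_number, line) for every line containing an emoji."""
    hits = []
    for path in sorted(Path(root).rglob(pattern)):
        text = path.read_text(encoding="utf-8")
        for lineno, line in enumerate(text.splitlines(), start=1):
            if EMOJI_RE.search(line):
                hits.append((str(path), lineno, line.strip()))
    return hits


# Demo on a throwaway directory: one offending file, one clean file.
with tempfile.TemporaryDirectory() as tmp:
    Path(tmp, "bad.py").write_text('print("\u2705 done")\n', encoding="utf-8")
    Path(tmp, "good.py").write_text('print("[done]")\n', encoding="utf-8")
    hits = find_emoji(tmp)
```

In a repository, calling `find_emoji("src")` and `find_emoji("tests")` and failing when either returns a non-empty list would turn the manual self-check into a pre-commit step.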
76
src/factors/__init__.py
Normal file
@@ -0,0 +1,76 @@
+"""ProStock factor computation framework.
+
+Provides a complete factor expression DSL together with compilation, translation, and execution.
+
+Main components:
+    - dsl: DSL expression layer; defines node types and operator overloading
+    - api: convenience interface for common symbols and functions
+    - compiler: AST compiler; extracts dependencies
+    - translator: Polars expression translator
+    - engine: factor computation engine; the unified entry point of the system
+
+Example:
+    >>> from src.factors import FactorEngine
+    >>> from src.factors.api import close, ts_mean, cs_rank
+
+    >>> # Initialize the engine
+    >>> engine = FactorEngine()
+
+    >>> # Register factors
+    >>> engine.register("ma20", ts_mean(close, 20))
+    >>> engine.register("price_rank", cs_rank(close))
+
+    >>> # Run the computation
+    >>> result = engine.compute(["ma20", "price_rank"], "20240101", "20240131")
+"""
+
+from src.factors.dsl import (
+    Node,
+    Symbol,
+    Constant,
+    BinaryOpNode,
+    UnaryOpNode,
+    FunctionNode,
+)
+
+from src.factors.compiler import (
+    DependencyExtractor,
+    extract_dependencies,
+)
+
+from src.factors.translator import (
+    PolarsTranslator,
+    translate_to_polars,
+)
+
+from src.factors.engine import (
+    FactorEngine,
+    DataSpec,
+    ExecutionPlan,
+    DataRouter,
+    ExecutionPlanner,
+    ComputeEngine,
+)
+
+__all__ = [
+    # DSL layer
+    "Node",
+    "Symbol",
+    "Constant",
+    "BinaryOpNode",
+    "UnaryOpNode",
+    "FunctionNode",
+    # Compiler
+    "DependencyExtractor",
+    "extract_dependencies",
+    # Translator
+    "PolarsTranslator",
+    "translate_to_polars",
+    # Engine
+    "FactorEngine",
+    "DataSpec",
+    "ExecutionPlan",
+    "DataRouter",
+    "ExecutionPlanner",
+    "ComputeEngine",
+]
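`compiler.py` is described as extracting the dependencies of an expression, which is to say the raw columns a factor needs before it can be computed. A minimal sketch of such a tree walk follows. The node classes are simplified stand-ins, and this `extract_dependencies` is an assumed implementation that shares only its name with the function exported above.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Symbol:
    name: str


@dataclass(frozen=True)
class Constant:
    value: float


@dataclass(frozen=True)
class FunctionNode:
    func_name: str
    args: Tuple[object, ...]


@dataclass(frozen=True)
class BinaryOpNode:
    op: str
    left: object
    right: object


def extract_dependencies(node):
    """Collect the names of all Symbol leaves reachable from node."""
    if isinstance(node, Symbol):
        return {node.name}
    if isinstance(node, FunctionNode):
        children = node.args
    elif isinstance(node, BinaryOpNode):
        children = (node.left, node.right)
    else:  # Constant, or any node without children
        children = ()
    deps = set()
    for child in children:
        deps |= extract_dependencies(child)
    return deps


# ts_mean(close - open, 20) / close, written out as explicit nodes
spread = BinaryOpNode("-", Symbol("close"), Symbol("open"))
expr = BinaryOpNode(
    "/",
    FunctionNode("ts_mean", (spread, Constant(20))),
    Symbol("close"),
)
deps = extract_dependencies(expr)
```

Returning a set means a column referenced several times, such as `close` here, is requested from the data layer only once.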
File diff suppressed because it is too large
@@ -1,325 +0,0 @@
-"""Tests for automatic string promotion in the DSL.
-
-Verifies the following:
-1. Strings are automatically converted to Symbol
-2. Operator functions accept string arguments
-3. Right-hand (reflected) operations are supported
-"""
-
-import pytest
-from src.factors.dsl import (
-    Symbol,
-    Constant,
-    BinaryOpNode,
-    UnaryOpNode,
-    FunctionNode,
-    _ensure_node,
-)
-from src.factors.api import (
-    close,
-    open,
-    ts_mean,
-    ts_std,
-    ts_corr,
-    cs_rank,
-    cs_zscore,
-    log,
-    exp,
-    max_,
-    min_,
-    clip,
-    if_,
-    where,
-)
-
-
-class TestEnsureNode:
-    """Tests for the _ensure_node helper."""
-
-    def test_ensure_node_with_node(self):
-        """A Node should be returned unchanged."""
-        sym = Symbol("close")
-        result = _ensure_node(sym)
-        assert result is sym
-
-    def test_ensure_node_with_int(self):
-        """An integer should be converted to a Constant."""
-        result = _ensure_node(100)
-        assert isinstance(result, Constant)
-        assert result.value == 100
-
-    def test_ensure_node_with_float(self):
-        """A float should be converted to a Constant."""
-        result = _ensure_node(3.14)
-        assert isinstance(result, Constant)
-        assert result.value == 3.14
-
-    def test_ensure_node_with_str(self):
-        """A string should be converted to a Symbol."""
-        result = _ensure_node("close")
-        assert isinstance(result, Symbol)
-        assert result.name == "close"
-
-    def test_ensure_node_with_invalid_type(self):
-        """An invalid type should raise TypeError."""
-        with pytest.raises(TypeError):
-            _ensure_node([1, 2, 3])
-
-
-class TestSymbolStringPromotion:
-    """Tests for operations between a Symbol and a string."""
-
-    def test_symbol_add_str(self):
-        """Symbol + string."""
-        expr = close + "pe_ratio"
-        assert isinstance(expr, BinaryOpNode)
-        assert expr.op == "+"
-        assert isinstance(expr.left, Symbol)
-        assert expr.left.name == "close"
-        assert isinstance(expr.right, Symbol)
-        assert expr.right.name == "pe_ratio"
-
-    def test_symbol_sub_str(self):
-        """Symbol - string."""
-        expr = close - "open"
-        assert isinstance(expr, BinaryOpNode)
-        assert expr.op == "-"
-        assert expr.right.name == "open"
-
-    def test_symbol_mul_str(self):
-        """Symbol * string."""
-        expr = close * "volume"
-        assert isinstance(expr, BinaryOpNode)
-        assert expr.op == "*"
-        assert expr.right.name == "volume"
-
-    def test_symbol_div_str(self):
-        """Symbol / string."""
-        expr = close / "pe_ratio"
-        assert isinstance(expr, BinaryOpNode)
-        assert expr.op == "/"
-        assert expr.right.name == "pe_ratio"
-
-    def test_symbol_pow_str(self):
-        """Symbol ** string."""
-        expr = close ** "exponent"
-        assert isinstance(expr, BinaryOpNode)
-        assert expr.op == "**"
-        assert expr.right.name == "exponent"
-
-
-class TestRightHandOperations:
-    """Tests for right-hand (reflected) operations."""
-
-    def test_int_add_symbol(self):
-        """Integer + Symbol."""
-        expr = 100 + close
-        assert isinstance(expr, BinaryOpNode)
-        assert expr.op == "+"
-        assert isinstance(expr.left, Constant)
-        assert expr.left.value == 100
-        assert isinstance(expr.right, Symbol)
-        assert expr.right.name == "close"
-
-    def test_int_sub_symbol(self):
-        """Integer - Symbol."""
-        expr = 100 - close
-        assert isinstance(expr, BinaryOpNode)
-        assert expr.op == "-"
-        assert expr.left.value == 100
-        assert expr.right.name == "close"
-
-    def test_int_mul_symbol(self):
-        """Integer * Symbol."""
-        expr = 2 * close
-        assert isinstance(expr, BinaryOpNode)
-        assert expr.op == "*"
-        assert expr.left.value == 2
-        assert expr.right.name == "close"
-
-    def test_int_div_symbol(self):
-        """Integer / Symbol."""
-        expr = 100 / close
-        assert isinstance(expr, BinaryOpNode)
-        assert expr.op == "/"
-        assert expr.left.value == 100
-        assert expr.right.name == "close"
-
-    def test_int_div_str_not_supported(self):
-        """Python's built-in int does not support division by a str.
-
-        Note: Python's built-in int type cannot be divided by a str directly,
-        so 100 / "close" raises TypeError. The correct usage is 100 / Symbol("close"),
-        or an existing Symbol object such as close.
-        """
-        with pytest.raises(TypeError):
-            100 / "close"
-    def test_int_floordiv_symbol(self):
-        """Integer // Symbol."""
-        expr = 100 // close
-        assert isinstance(expr, BinaryOpNode)
-        assert expr.op == "//"
-
-    def test_int_mod_symbol(self):
-        """Integer % Symbol."""
-        expr = 100 % close
-        assert isinstance(expr, BinaryOpNode)
-        assert expr.op == "%"
-
-    def test_int_pow_symbol(self):
-        """Integer ** Symbol."""
-        expr = 2**close
-        assert isinstance(expr, BinaryOpNode)
-        assert expr.op == "**"
-        assert expr.left.value == 2
-        assert expr.right.name == "close"
-
-
-class TestOperatorFunctionsWithStrings:
-    """Tests that operator functions accept string arguments."""
-
-    def test_ts_mean_with_str(self):
-        """ts_mean accepts a string argument."""
-        expr = ts_mean("close", 20)
-        assert isinstance(expr, FunctionNode)
-        assert expr.func_name == "ts_mean"
-        assert len(expr.args) == 2
-        assert isinstance(expr.args[0], Symbol)
-        assert expr.args[0].name == "close"
-        assert isinstance(expr.args[1], Constant)
-        assert expr.args[1].value == 20
-
-    def test_ts_std_with_str(self):
-        """ts_std accepts a string argument."""
-        expr = ts_std("volume", 10)
-        assert isinstance(expr, FunctionNode)
-        assert expr.func_name == "ts_std"
-        assert expr.args[0].name == "volume"
-
-    def test_ts_corr_with_str(self):
-        """ts_corr accepts string arguments."""
-        expr = ts_corr("close", "open", 20)
-        assert isinstance(expr, FunctionNode)
-        assert expr.func_name == "ts_corr"
-        assert expr.args[0].name == "close"
-        assert expr.args[1].name == "open"
-
-    def test_cs_rank_with_str(self):
-        """cs_rank accepts a string argument."""
-        expr = cs_rank("pe_ratio")
-        assert isinstance(expr, FunctionNode)
-        assert expr.func_name == "cs_rank"
-        assert expr.args[0].name == "pe_ratio"
-
-    def test_cs_zscore_with_str(self):
-        """cs_zscore accepts a string argument."""
-        expr = cs_zscore("market_cap")
-        assert isinstance(expr, FunctionNode)
-        assert expr.func_name == "cs_zscore"
-        assert expr.args[0].name == "market_cap"
-
-    def test_log_with_str(self):
-        """log accepts a string argument."""
-        expr = log("close")
-        assert isinstance(expr, FunctionNode)
-        assert expr.func_name == "log"
-        assert expr.args[0].name == "close"
-
-    def test_max_with_str(self):
-        """max_ accepts string arguments."""
-        expr = max_("close", "open")
-        assert isinstance(expr, FunctionNode)
-        assert expr.func_name == "max"
-        assert expr.args[0].name == "close"
-        assert expr.args[1].name == "open"
-
-    def test_max_with_str_and_number(self):
-        """max_ accepts a mix of string and numeric arguments."""
-        expr = max_("close", 100)
-        assert isinstance(expr, FunctionNode)
-        assert expr.args[0].name == "close"
-        assert expr.args[1].value == 100
-
-    def test_clip_with_str(self):
-        """clip accepts string arguments."""
-        expr = clip("pe_ratio", "lower_bound", "upper_bound")
-        assert isinstance(expr, FunctionNode)
-        assert expr.func_name == "clip"
-        assert expr.args[0].name == "pe_ratio"
-        assert expr.args[1].name == "lower_bound"
-        assert expr.args[2].name == "upper_bound"
-
-    def test_if_with_str(self):
-        """if_ accepts string arguments."""
-        expr = if_("condition", "true_val", "false_val")
-        assert isinstance(expr, FunctionNode)
-        assert expr.func_name == "if"
-        assert expr.args[0].name == "condition"
-        assert expr.args[1].name == "true_val"
-        assert expr.args[2].name == "false_val"
-
-
-class TestComplexExpressions:
-    """Tests for complex expressions."""
-
-    def test_complex_expression_1(self):
-        """Complex expression: ts_mean("close", 5) / "pe_ratio"."""
-        expr = ts_mean("close", 5) / "pe_ratio"
-        assert isinstance(expr, BinaryOpNode)
-        assert expr.op == "/"
-        assert isinstance(expr.left, FunctionNode)
-        assert expr.left.func_name == "ts_mean"
-        assert isinstance(expr.right, Symbol)
-        assert expr.right.name == "pe_ratio"
-
-    def test_complex_expression_2(self):
-        """Complex expression: 100 / close * cs_rank("volume").
-
-        Note: Python's built-in int type cannot be divided by a str directly,
-        so an existing Symbol object must be used, or a Symbol created first.
-        """
-        expr = 100 / close * cs_rank("volume")
-        assert isinstance(expr, BinaryOpNode)
-        assert expr.op == "*"
-        assert isinstance(expr.left, BinaryOpNode)
-        assert expr.left.op == "/"
-        assert isinstance(expr.right, FunctionNode)
-        assert expr.right.func_name == "cs_rank"
-    def test_complex_expression_3(self):
-        """Complex expression: ts_mean(close - "open", 20) / close."""
-        expr = ts_mean(close - "open", 20) / close
-        assert isinstance(expr, BinaryOpNode)
-        assert expr.op == "/"
-        assert isinstance(expr.left, FunctionNode)
-        assert expr.left.func_name == "ts_mean"
-        # Check that the first argument of ts_mean is close - open
-        assert isinstance(expr.left.args[0], BinaryOpNode)
-        assert expr.left.args[0].op == "-"
-
-
-class TestExpressionRepr:
-    """Tests for the string representation of expressions."""
-
-    def test_symbol_str_repr(self):
-        """String representation of a Symbol."""
-        expr = Symbol("close")
-        assert repr(expr) == "close"
-
-    def test_binary_op_repr(self):
-        """String representation of a binary operation."""
-        expr = close + "open"
-        assert repr(expr) == "(close + open)"
-
-    def test_function_node_repr(self):
-        """String representation of a function node."""
-        expr = ts_mean("close", 20)
-        assert repr(expr) == "ts_mean(close, 20)"
-
-    def test_complex_expr_repr(self):
-        """String representation of a complex expression."""
-        expr = ts_mean("close", 5) / "pe_ratio"
-        assert repr(expr) == "(ts_mean(close, 5) / pe_ratio)"
-
-
-if __name__ == "__main__":
-    pytest.main([__file__, "-v"])
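The removed tests exercised two behaviours: plain strings and numbers are promoted to nodes, and reflected operators let a number appear on the left-hand side. The sketch below shows one way both can work, and why `100 / "close"` still fails: Python only falls back to a reflected method such as `__rtruediv__` on the right operand, and `str` defines none. The helper mirrors the `_ensure_node` name used in the tests, but its body and the simplified classes are assumptions.

```python
class Node:
    def __truediv__(self, other):
        return BinaryOpNode("/", self, _ensure_node(other))

    def __rtruediv__(self, other):
        # Called for `other / self` when `other` cannot handle the division.
        return BinaryOpNode("/", _ensure_node(other), self)


class Symbol(Node):
    def __init__(self, name):
        self.name = name


class Constant(Node):
    def __init__(self, value):
        self.value = value


class BinaryOpNode(Node):
    def __init__(self, op, left, right):
        self.op, self.left, self.right = op, left, right


def _ensure_node(value):
    """Promote str to Symbol and numbers to Constant; reject everything else."""
    if isinstance(value, Node):
        return value
    if isinstance(value, str):
        return Symbol(value)
    if isinstance(value, (int, float)):
        return Constant(value)
    raise TypeError(f"cannot convert {type(value).__name__} to a DSL node")


close = Symbol("close")
left_str = close / "pe_ratio"   # the str on the right is promoted to a Symbol
right_num = 100 / close         # int.__truediv__ gives up, so __rtruediv__ runs

try:
    100 / "close"               # no Node is involved, so no promotion can happen
    raised = False
except TypeError:
    raised = True
```

The asymmetry is inherent to Python's operator dispatch: at least one operand has to be a `Node` for the DSL to take part in the expression at all.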
160
tests/test_factor_engine.py
Normal file
@@ -0,0 +1,160 @@
+"""End-to-end tests for FactorEngine.
+
+Uses in-memory mock data as a fake database and runs the full pipeline from expression registration to result output.
+"""
+
+import pytest
+import polars as pl
+import numpy as np
+from datetime import datetime, timedelta
+
+from src.factors.engine import FactorEngine, DataSpec
+from src.factors.api import close, ts_mean, ts_std, cs_rank, cs_zscore, open as open_sym
+from src.factors.dsl import Symbol, FunctionNode
+
+
+def create_mock_data(
+    start_date: str = "20240101",
+    end_date: str = "20240131",
+    n_stocks: int = 5,
+) -> pl.DataFrame:
+    """Create mock daily bar data."""
+    start = datetime.strptime(start_date, "%Y%m%d")
+    end = datetime.strptime(end_date, "%Y%m%d")
+
+    dates = []
+    current = start
+    while current <= end:
+        if current.weekday() < 5:  # Monday through Friday
+            dates.append(current.strftime("%Y%m%d"))
+        current += timedelta(days=1)
+
+    stocks = [f"{600000 + i:06d}.SH" for i in range(n_stocks)]
+    np.random.seed(42)
+
+    rows = []
+    for date in dates:
+        for stock in stocks:
+            base_price = 10 + np.random.randn() * 5
+            close_val = base_price + np.random.randn() * 0.5
+            open_val = close_val + np.random.randn() * 0.2
+            high_val = max(open_val, close_val) + abs(np.random.randn()) * 0.3
+            low_val = min(open_val, close_val) - abs(np.random.randn()) * 0.3
+            vol = int(1000000 + np.random.exponential(500000))
+            amt = close_val * vol
+
+            rows.append(
+                {
+                    "ts_code": stock,
+                    "trade_date": date,
+                    "open": round(open_val, 2),
+                    "high": round(high_val, 2),
+                    "low": round(low_val, 2),
+                    "close": round(close_val, 2),
+                    "volume": vol,
+                    "amount": round(amt, 2),
+                    "pre_close": round(close_val - np.random.randn() * 0.3, 2),
+                }
+            )
+
+    return pl.DataFrame(rows)
+
+
+class TestFactorEngineEndToEnd:
+    """End-to-end test class for FactorEngine."""
+
+    @pytest.fixture
+    def mock_data(self):
+        """Fixture that provides mock data."""
+        return create_mock_data("20240101", "20240131", n_stocks=5)
+
+    @pytest.fixture
+    def engine(self, mock_data):
+        """Fixture that provides a configured FactorEngine."""
+        data_source = {"daily": mock_data}
+        return FactorEngine(data_source=data_source, max_workers=2)
+
+    def test_simple_symbol_expression(self, engine):
+        """Test a simple symbol expression."""
+        engine.register("close_price", close)
+        result = engine.compute("close_price", "20240115", "20240120")
+        assert "close_price" in result.columns
+        assert len(result) > 0
+        print("[PASS] simple symbol expression test")
+
+    def test_arithmetic_expression(self, engine):
+        """Test an arithmetic expression."""
+        engine.register("returns", (close - open_sym) / open_sym)
+        result = engine.compute("returns", "20240115", "20240120")
+        assert "returns" in result.columns
+        print("[PASS] arithmetic expression test")
+
+    def test_cs_rank_factor(self, engine):
+        """Test the cross-sectional rank factor."""
+        engine.register("price_rank", cs_rank(close))
+        result = engine.compute("price_rank", "20240115", "20240120")
+        assert "price_rank" in result.columns
+        assert result["price_rank"].min() >= 0
+        assert result["price_rank"].max() <= 1
+        print("[PASS] cross-sectional rank factor test")
+
+
+class TestFullWorkflow:
+    """Full workflow test class."""
+
+    def test_full_workflow_demo(self):
+        """Demonstrate the complete factor computation workflow."""
+        print("\n" + "=" * 60)
+        print("FactorEngine Full Workflow Demo")
+        print("=" * 60)
+
+        # 1. Prepare data
+        print("\nStep 1: Prepare mock data...")
+        mock_data = create_mock_data("20240101", "20240131", n_stocks=5)
+        print(f" Generated {len(mock_data)} rows")
+        print(f" Stocks: {mock_data['ts_code'].n_unique()}")
+
+        # 2. Initialize the engine
+        print("\nStep 2: Initialize FactorEngine...")
+        engine = FactorEngine(data_source={"daily": mock_data})
+        print(" Engine initialized")
+
+        # 3. Register factors - use simple factors to avoid the lookback-window problem
+        print("\nStep 3: Register factors...")
+        engine.register("returns", (close - open_sym) / open_sym)
+        engine.register("price_rank", cs_rank(close))
+        print(" Registered: returns, price_rank")
+
+        # 4. Run the computation - use the full date range
+        print("\nStep 4: Compute factors...")
+        result = engine.compute(
+            ["returns", "price_rank"],
+            "20240115",
+            "20240120",
+        )
+        print(f" Computed {len(result)} rows")
+
+        # 5. Verify the results
+        print("\nStep 5: Verify results...")
+        assert "returns" in result.columns
+        assert "price_rank" in result.columns
+        assert result["price_rank"].min() >= 0
+        assert result["price_rank"].max() <= 1
+        print(" All assertions passed")
+
+        # 6. Show a sample
+        print("\nStep 6: Sample output...")
+        sample = result.select(
+            ["ts_code", "trade_date", "close", "returns", "price_rank"]
+        ).head(3)
+        print(sample.to_pandas().to_string(index=False))
+
+        print("\n" + "=" * 60)
+        print("Workflow completed successfully!")
+        print("=" * 60)
+
+
+if __name__ == "__main__":
+    test = TestFullWorkflow()
+    test.test_full_workflow_demo()
+    pytest.main([__file__, "-v", "--tb=short"])
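The workflow test registers only window-free factors, and its comment mentions avoiding a lookback-window problem. The sketch below illustrates what that problem presumably is: a rolling function such as `ts_mean(x, window)` has no value for the first `window - 1` rows, so a request for a date range only yields complete results if earlier history is loaded as well. The `rolling_mean` function and the variable names are hypothetical and are not taken from the engine.

```python
def rolling_mean(values, window):
    """Rolling mean; positions without a full window are None."""
    out = []
    for i in range(len(values)):
        if i + 1 < window:
            out.append(None)
        else:
            out.append(sum(values[i + 1 - window : i + 1]) / window)
    return out


dates = ["20240110", "20240111", "20240112", "20240115", "20240116"]
close = [10.0, 11.0, 12.0, 13.0, 14.0]
window = 3
requested = slice(2, 5)  # rows for 20240112 through 20240116

# Loading only the requested range leaves the first window - 1 rows empty.
naive = rolling_mean(close[requested], window)

# Loading earlier history first and trimming afterwards fills every row.
padded = rolling_mean(close, window)[requested]
by_date = dict(zip(dates[requested], padded))
```

Under this reading, an engine that supports `ts_*` factors has to widen the data request by at least `window - 1` trading days before the start date and then cut the result back to the requested range.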