diff --git a/AGENTS.md b/AGENTS.md index 16420c1..092e255 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -82,34 +82,34 @@ ProStock/ │ │ │ ├── data/ # 数据获取与存储 │ │ ├── api_wrappers/ # Tushare API 封装 -│ │ │ ├── API_INTERFACE_SPEC.md # 接口规范文档 -│ │ │ ├── api.md # API 接口定义 -│ │ │ ├── api_daily.py # 日线数据接口 +│ │ │ ├── base_sync.py # 同步基础抽象类(BaseDataSync/StockBasedSync/DateBasedSync) +│ │ │ ├── api_daily.py # 日线数据接口(DailySync) +│ │ │ ├── api_pro_bar.py # Pro Bar 数据接口(ProBarSync) │ │ │ ├── api_stock_basic.py # 股票基础信息接口 │ │ │ ├── api_trade_cal.py # 交易日历接口 +│ │ │ ├── api_bak_basic.py # 历史股票列表接口(BakBasicSync) +│ │ │ ├── api_namechange.py # 股票名称变更接口 +│ │ │ ├── financial_data/ # 财务数据接口 +│ │ │ │ ├── api_income.py # 利润表接口 +│ │ │ │ └── api_financial_sync.py # 财务数据同步 │ │ │ └── __init__.py │ │ ├── __init__.py -│ │ ├── client.py # Tushare API 客户端(带速率限制) +│ │ ├── client.py # Tushare API 客户端(带速率限制) │ │ ├── config.py # 数据模块配置 │ │ ├── db_inspector.py # 数据库信息查看工具 │ │ ├── db_manager.py # DuckDB 表管理和同步 │ │ ├── rate_limiter.py # 令牌桶速率限制器 │ │ ├── storage.py # 数据存储核心 -│ │ └── sync.py # 数据同步主逻辑 +│ │ ├── sync.py # 数据同步调度中心 +│ │ └── utils.py # 数据模块工具函数 │ │ -│ ├── factors/ # 因子计算框架 -│ │ ├── __init__.py -│ │ ├── base.py # 因子基类(截面/时序) -│ │ ├── composite.py # 组合因子和标量运算 -│ │ ├── data_loader.py # 数据加载器 -│ │ ├── data_spec.py # 数据规格定义 -│ │ ├── engine.py # 因子执行引擎 -│ │ ├── momentum/ # 动量因子 -│ │ │ ├── __init__.py -│ │ │ ├── ma.py # 移动平均线 -│ │ │ └── return_rank.py # 收益排名 -│ │ └── financial/ # 财务因子 -│ │ └── __init__.py +│ ├── factors/ # 因子计算框架(DSL 表达式驱动) +│ │ ├── __init__.py # 导出所有公开 API +│ │ ├── dsl.py # DSL 表达式层 - 节点定义和运算符重载 +│ │ ├── api.py # API 层 - 常用符号(close/open等)和函数(ts_mean/cs_rank等) +│ │ ├── compiler.py # AST 编译器 - 依赖提取 +│ │ ├── translator.py # Polars 表达式翻译器 +│ │ └── engine.py # 因子执行引擎 - 统一入口 │ │ │ ├── pipeline/ # 模型训练管道 │ │ ├── __init__.py @@ -296,9 +296,48 @@ uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)" ## 架构变更历史 -### v2.1 (2026-02-28) - 同步模块规范更新 +### v2.2 (2026-03-01) - 因子框架 DSL 化重构 -#### sync.py 职责划分 +#### 因子计算框架重构 + **变更**: 从基类继承方式迁移到 DSL 表达式方式 + **原因**: + - 提供更直观的数学公式表达方式 + - 支持因子表达式的组合和嵌套 + - 更好的类型安全和编译期检查 + **架构变化**: + - 新增 `dsl.py`: 表达式节点基类和运算符重载(Symbol、FunctionNode等) + - 新增 `api.py`: 常用符号(close/open/volume等)和函数(ts_mean/cs_rank等) + - 新增 `compiler.py`: AST 编译器,提取表达式依赖 + - 新增 `translator.py`: 将 DSL 表达式翻译为 Polars 表达式 + - 重构 `engine.py`: 统一执行引擎入口,整合 DataRouter、ExecutionPlanner、ComputeEngine + - 移除: `base.py`、`composite.py`、`data_loader.py`、`data_spec.py` + - 移除: `factors/momentum/` 和 `factors/financial/` 子目录 + **使用方式对比**: + ```python + # 旧方式(基类继承) + class MA20Factor(TimeSeriesFactor): + name = "ma20" + data_specs = [DataSpec("daily", ["close"], 20)] + def compute(self, data): + return data.get_column("close").rolling_mean(20) + + # 新方式(DSL 表达式) + from src.factors.api import close, ts_mean + ma20 = ts_mean(close, 20) # 直接编写数学表达式 + + engine = FactorEngine() + engine.register("ma20", ma20) + result = engine.compute(["ma20"], "20240101", "20240131") + ``` + +#### data 模块补充完善 + **新增文件**: + - `api_wrappers/base_sync.py`: 数据同步基础抽象类(BaseDataSync、StockBasedSync、DateBasedSync) + - `data_router.py`: 数据路由器(已集成到 factors/engine.py 中的 DataRouter) + - `utils.py`: 日期工具函数(get_today_date、get_next_date、is_quarter_end等) + **影响**: 数据同步逻辑更加规范化,支持按股票和按日期两种同步模式 + +### v2.1 (2026-02-28) - 同步模块规范更新 **变更**: 明确 `sync.py` 只包含每日更新的数据同步 **原因**: 区分高频(每日)和低频(季度/年度)数据,避免不必要的 API 调用 **规范**: @@ -353,7 +392,120 @@ uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)" `get_today_date()`、`get_next_date()`、`DEFAULT_START_DATE` 等函数统一在 `src/data/utils.py` 中管理 其他模块应从 `utils.py` 导入这些函数,避免重复定义 +其他模块应从 `utils.py` 导入这些函数,避免重复定义 + +## Factors 框架设计说明 + +### 架构层次 + +因子框架采用分层设计,从上到下依次是: + +``` +API 层 (api.py) + | + v +DSL 层 (dsl.py) <- 因子表达式 (Node) + | + v +Compiler (compiler.py) <- AST 依赖提取 + | + v +Translator (translator.py) <- 翻译为 Polars 表达式 + | + v +Engine (engine.py) <- 执行引擎 (DataRouter/ExecutionPlanner/ComputeEngine) + | + v +数据层 (data_router.py + DuckDB) <- 数据获取和存储 +``` + +### 使用方式 + +#### 1. 基础表达式 + +```python +from src.factors.api import close, open, ts_mean, cs_rank + +# 定义因子表达式(惰性计算) +ma20 = ts_mean(close, 20) # 20日移动平均 +price_rank = cs_rank(close) # 收盘价截面排名 + +# 组合运算 +alpha = ma20 * 0.6 + price_rank * 0.4 +``` + +#### 2. 注册和执行 + +```python +from src.factors import FactorEngine + +engine = FactorEngine() +engine.register("ma20", ma20) +engine.register("price_rank", price_rank) + +# 执行计算 +result = engine.compute( + factor_names=["ma20", "price_rank"], + start_date="20240101", + end_date="20240131", +) +``` + +### 支持的函数 + +**时间序列函数 (ts_*)**: +- `ts_mean(x, window)` - 滚动均值 +- `ts_std(x, window)` - 滚动标准差 +- `ts_max(x, window)` - 滚动最大值 +- `ts_min(x, window)` - 滚动最小值 +- `ts_sum(x, window)` - 滚动求和 +- `ts_delay(x, periods)` - 滞后 N 期 +- `ts_delta(x, periods)` - 差分 N 期 +- `ts_corr(x, y, window)` - 滚动相关系数 +- `ts_cov(x, y, window)` - 滚动协方差 +- `ts_rank(x, window)` - 滚动排名 + +**截面函数 (cs_*)**: +- `cs_rank(x)` - 截面排名(分位数) +- `cs_zscore(x)` - Z-Score 标准化 +- `cs_neutralize(x, group)` - 行业/市值中性化 +- `cs_winsorize(x, lower, upper)` - 缩尾处理 +- `cs_demean(x)` - 去均值 + +**数学函数**: +- `log(x)` - 自然对数 +- `exp(x)` - 指数函数 +- `sqrt(x)` - 平方根 +- `sign(x)` - 符号函数 +- `abs(x)` - 绝对值 +- `max_(x, y)` / `min_(x, y)` - 逐元素最值 +- `clip(x, lower, upper)` - 数值裁剪 + +**条件函数**: +- `if_(condition, true_val, false_val)` - 条件选择 +- `where(condition, true_val, false_val)` - if_ 的别名 + +### 运算符支持 + +DSL 表达式支持完整的 Python 运算符: + +```python +# 算术运算: +, -, *, /, //, %, ** +expr1 = (close - open) / open * 100 # 涨跌幅 + +# 比较运算: ==, !=, <, <=, >, >= +expr2 = close > open # 是否上涨 + +# 一元运算: -, +, abs() +expr3 = -change # 涨跌额取反 + +# 链式调用 +expr4 = ts_mean(cs_rank(close), 20) # 排名后的20日平滑 +``` + + +## AI 行为准则 ## AI 行为准则 ### LSP 检测报错处理 @@ -384,3 +536,25 @@ LSP 报错:Syntax error on line 45 ✅ 正确做法:读取文件第 45 行,发现少了一个右括号,添加后重新检测 ❌ 错误做法:删除文件重新写、或者忽略错误继续 ``` + +### Emoji 表情禁用规则 + +**⚠️ 强制要求:代码和测试文件中禁止出现 emoji 表情。** + +1. **禁止范围** + - 所有 `.py` 源代码文件 + - 所有测试文件 (`tests/` 目录) + - 配置文件、脚本文件 + +2. **替代方案** + - ❌ 禁止使用:`print("✅ 成功")`、`print("❌ 失败")`、`# 📝 注释` + - ✅ 应使用:`print("[成功]")`、`print("[失败]")`、`# 注释` + - 使用方括号 `[成功]`、`[警告]`、`[错误]` 等文字标记代替 emoji + +3. **唯一例外** + - AGENTS.md 文件本身可以使用 emoji 进行文档强调(如本文件中的 ⚠️) + - 项目文档、README 等对外展示文件可以酌情使用 + +4. **检查方法** + - 使用正则表达式搜索 emoji:`[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F1E0-\U0001F1FF\u2600-\u26FF\u2700-\u27BF]` + - 提交前自查,确保无 emoji 混入代码 diff --git a/src/factors/__init__.py b/src/factors/__init__.py new file mode 100644 index 0000000..6d9f579 --- /dev/null +++ b/src/factors/__init__.py @@ -0,0 +1,76 @@ +"""ProStock 因子计算框架。 + +提供完整的因子表达式 DSL、编译、翻译和执行能力。 + +主要组件: + - dsl: DSL 表达式层,定义节点类型和运算符重载 + - api: 常用符号和函数的便捷接口 + - compiler: AST 编译器,提取依赖关系 + - translator: Polars 表达式翻译器 + - engine: 因子计算引擎,系统统一入口 + +使用示例: + >>> from src.factors import FactorEngine + >>> from src.factors.api import close, ts_mean, cs_rank + + >>> # 初始化引擎 + >>> engine = FactorEngine() + + >>> # 注册因子 + >>> engine.register("ma20", ts_mean(close, 20)) + >>> engine.register("price_rank", cs_rank(close)) + + >>> # 执行计算 + >>> result = engine.compute(["ma20", "price_rank"], "20240101", "20240131") +""" + +from src.factors.dsl import ( + Node, + Symbol, + Constant, + BinaryOpNode, + UnaryOpNode, + FunctionNode, +) + +from src.factors.compiler import ( + DependencyExtractor, + extract_dependencies, +) + +from src.factors.translator import ( + PolarsTranslator, + translate_to_polars, +) + +from src.factors.engine import ( + FactorEngine, + DataSpec, + ExecutionPlan, + DataRouter, + ExecutionPlanner, + ComputeEngine, +) + +__all__ = [ + # DSL 层 + "Node", + "Symbol", + "Constant", + "BinaryOpNode", + "UnaryOpNode", + "FunctionNode", + # 编译器 + "DependencyExtractor", + "extract_dependencies", + # 翻译器 + "PolarsTranslator", + "translate_to_polars", + # 引擎 + "FactorEngine", + "DataSpec", + "ExecutionPlan", + "DataRouter", + "ExecutionPlanner", + "ComputeEngine", +] diff --git a/src/factors/engine.py b/src/factors/engine.py index eb45ee9..9e0ea55 100644 --- a/src/factors/engine.py +++ b/src/factors/engine.py @@ -1,367 +1,706 @@ -"""执行引擎 - Phase 4 因子执行引擎 +"""FactorEngine - 因子计算引擎统一入口。 -本模块负责因子计算的核心逻辑: -- FactorEngine: 因子执行引擎,根据因子类型采用不同的计算和防泄露策略 - -防泄露策略: -1. CrossSectionalFactor:防止日期泄露,每天传入 [T-lookback+1, T] 数据 -2. TimeSeriesFactor:防止股票泄露,每只股票传入完整序列 +提供从表达式注册到结果输出的完整执行链路: +接收研究员的表达式 -> 调用编译器解析依赖 -> 调用路由器连接数据库拉取并组装核心宽表 +-> 调用翻译器生成物理执行计划 -> 将计划提交给计算引擎执行并行运算。 """ -from typing import List, Optional +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Set, Union +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor +import threading import polars as pl -from src.factors.data_loader import DataLoader -from src.factors.data_spec import FactorContext, FactorData -from src.factors.base import BaseFactor, CrossSectionalFactor, TimeSeriesFactor +from src.factors.dsl import ( + Node, + Symbol, + FunctionNode, + BinaryOpNode, + UnaryOpNode, + Constant, +) +from src.factors.compiler import DependencyExtractor +from src.factors.translator import PolarsTranslator + + +@dataclass +class DataSpec: + """数据规格定义。 + + 描述因子计算所需的数据表和字段。 + + Attributes: + table: 数据表名称 + columns: 需要的字段列表 + lookback_days: 回看天数(用于时序计算) + """ + + table: str + columns: List[str] + lookback_days: int = 1 + + +@dataclass +class ExecutionPlan: + """执行计划。 + + 包含完整的执行所需信息:数据源、转换逻辑、输出格式。 + + Attributes: + data_specs: 数据规格列表 + polars_expr: Polars 表达式 + dependencies: 依赖的原始字段 + output_name: 输出因子名称 + """ + + data_specs: List[DataSpec] + polars_expr: pl.Expr + dependencies: Set[str] + output_name: str + + +class DataRouter: + """数据路由器 - 按需取数、组装核心宽表。 + + 负责根据数据规格从数据源拉取数据,并组装成统一的宽表格式。 + 支持内存数据源(用于测试)和真实数据库连接。 + + Attributes: + data_source: 数据源,可以是内存 DataFrame 字典或数据库连接 + is_memory_mode: 是否为内存模式 + """ + + def __init__(self, data_source: Optional[Dict[str, pl.DataFrame]] = None) -> None: + """初始化数据路由器。 + + Args: + data_source: 内存数据源,字典格式 {表名: DataFrame} + 为 None 时需要在子类中实现数据库连接 + """ + self.data_source = data_source or {} + self.is_memory_mode = data_source is not None + self._cache: Dict[str, pl.DataFrame] = {} + self._lock = threading.Lock() + + def fetch_data( + self, + data_specs: List[DataSpec], + start_date: str, + end_date: str, + stock_codes: Optional[List[str]] = None, + ) -> pl.DataFrame: + """根据数据规格获取并组装核心宽表。 + + Args: + data_specs: 数据规格列表 + start_date: 开始日期 (YYYYMMDD) + end_date: 结束日期 (YYYYMMDD) + stock_codes: 股票代码列表,None 表示全市场 + + Returns: + 组装好的核心宽表 DataFrame + + Raises: + ValueError: 当数据源中缺少必要的表或字段时 + """ + if not data_specs: + raise ValueError("数据规格不能为空") + + # 收集所有需要的表和字段 + required_tables: Dict[str, Set[str]] = {} + max_lookback = 0 + + for spec in data_specs: + if spec.table not in required_tables: + required_tables[spec.table] = set() + required_tables[spec.table].update(spec.columns) + max_lookback = max(max_lookback, spec.lookback_days) + + # 调整日期范围以包含回看期 + adjusted_start = self._adjust_start_date(start_date, max_lookback) + + # 从数据源获取各表数据 + table_data = {} + for table_name, columns in required_tables.items(): + df = self._load_table( + table_name=table_name, + columns=list(columns), + start_date=adjusted_start, + end_date=end_date, + stock_codes=stock_codes, + ) + table_data[table_name] = df + + # 组装核心宽表 + core_table = self._assemble_wide_table(table_data, required_tables) + + # 过滤到实际请求日期范围 + core_table = core_table.filter( + (pl.col("trade_date") >= start_date) & (pl.col("trade_date") <= end_date) + ) + + return core_table + + def _load_table( + self, + table_name: str, + columns: List[str], + start_date: str, + end_date: str, + stock_codes: Optional[List[str]] = None, + ) -> pl.DataFrame: + """加载单个表的数据。 + + Args: + table_name: 表名 + columns: 需要的字段 + start_date: 开始日期 + end_date: 结束日期 + stock_codes: 股票代码过滤 + + Returns: + 过滤后的 DataFrame + """ + cache_key = f"{table_name}_{start_date}_{end_date}_{stock_codes}" + + with self._lock: + if cache_key in self._cache: + return self._cache[cache_key] + + if self.is_memory_mode: + if table_name not in self.data_source: + raise ValueError(f"内存数据源中缺少表: {table_name}") + + df = self.data_source[table_name] + + # 确保必需字段存在 + for col in columns: + if col not in df.columns and col not in ["ts_code", "trade_date"]: + raise ValueError(f"表 {table_name} 缺少字段: {col}") + + # 过滤日期和股票 + df = df.filter( + (pl.col("trade_date") >= start_date) + & (pl.col("trade_date") <= end_date) + ) + + if stock_codes is not None: + df = df.filter(pl.col("ts_code").is_in(stock_codes)) + + # 选择需要的列 + select_cols = ["ts_code", "trade_date"] + [ + c for c in columns if c in df.columns + ] + df = df.select(select_cols) + + else: + # TODO: 实现真实数据库连接(DuckDB) + raise NotImplementedError("数据库连接模式尚未实现") + + with self._lock: + self._cache[cache_key] = df + + return df + + def _assemble_wide_table( + self, + table_data: Dict[str, pl.DataFrame], + required_tables: Dict[str, Set[str]], + ) -> pl.DataFrame: + """组装多表数据为核心宽表。 + + 使用 left join 合并各表数据,以第一个表为基准。 + + Args: + table_data: 表名到 DataFrame 的映射 + required_tables: 表名到字段集合的映射 + + Returns: + 组装后的宽表 + """ + if not table_data: + raise ValueError("没有数据可组装") + + # 以第一个表为基准 + base_table_name = list(table_data.keys())[0] + result = table_data[base_table_name] + + # 与其他表 join + for table_name, df in table_data.items(): + if table_name == base_table_name: + continue + + # 使用 ts_code 和 trade_date 作为 join 键 + result = result.join( + df, + on=["ts_code", "trade_date"], + how="left", + ) + + return result + + def _adjust_start_date(self, start_date: str, lookback_days: int) -> str: + """根据回看天数调整开始日期。 + + Args: + start_date: 原始开始日期 (YYYYMMDD) + lookback_days: 需要回看的交易日数 + + Returns: + 调整后的开始日期 + """ + # 简化的日期调整:假设每月30天,向前推移 + # 实际应用中应该使用交易日历 + year = int(start_date[:4]) + month = int(start_date[4:6]) + day = int(start_date[6:8]) + + total_days = lookback_days + 30 # 额外缓冲 + + day -= total_days + while day <= 0: + month -= 1 + if month <= 0: + month = 12 + year -= 1 + day += 30 + + return f"{year:04d}{month:02d}{day:02d}" + + def clear_cache(self) -> None: + """清除数据缓存。""" + with self._lock: + self._cache.clear() + + +class ExecutionPlanner: + """执行计划生成器。 + + 整合编译器和翻译器,生成完整的执行计划。 + + Attributes: + compiler: 依赖提取器 + translator: Polars 翻译器 + """ + + def __init__(self) -> None: + """初始化执行计划生成器。""" + self.compiler = DependencyExtractor() + self.translator = PolarsTranslator() + + def create_plan( + self, + expression: Node, + output_name: str = "factor", + data_specs: Optional[List[DataSpec]] = None, + ) -> ExecutionPlan: + """从表达式创建执行计划。 + + Args: + expression: DSL 表达式节点 + output_name: 输出因子名称 + data_specs: 预定义的数据规格,None 时自动推导 + + Returns: + 执行计划对象 + """ + # 1. 提取依赖 + dependencies = self.compiler.extract_dependencies(expression) + + # 2. 翻译为 Polars 表达式 + polars_expr = self.translator.translate(expression) + + # 3. 推导或验证数据规格 + if data_specs is None: + data_specs = self._infer_data_specs(dependencies, expression) + + return ExecutionPlan( + data_specs=data_specs, + polars_expr=polars_expr, + dependencies=dependencies, + output_name=output_name, + ) + + def _infer_data_specs( + self, + dependencies: Set[str], + expression: Node, + ) -> List[DataSpec]: + """从依赖推导数据规格。 + + 根据表达式中的函数类型推断回看天数需求。 + + Args: + dependencies: 依赖的字段集合 + expression: 表达式节点 + + Returns: + 数据规格列表 + """ + # 计算最大回看窗口 + max_window = self._extract_max_window(expression) + lookback_days = max(1, max_window) + + # 假设所有字段都来自 daily 表 + columns = list(dependencies) + + return [ + DataSpec( + table="daily", + columns=columns, + lookback_days=lookback_days, + ) + ] + + def _extract_max_window(self, node: Node) -> int: + """从表达式中提取最大窗口大小。 + + Args: + node: AST 节点 + + Returns: + 最大窗口大小,无时序函数返回 1 + """ + if isinstance(node, FunctionNode): + window = 1 + # 检查函数参数中的窗口大小 + for arg in node.args: + if ( + isinstance(arg, Constant) + and isinstance(arg.value, int) + and arg.value > window + ): + window = arg.value + + # 递归检查子表达式 + for arg in node.args: + if isinstance(arg, Node) and not isinstance(arg, Constant): + window = max(window, self._extract_max_window(arg)) + + return window + + elif isinstance(node, BinaryOpNode): + return max( + self._extract_max_window(node.left), + self._extract_max_window(node.right), + ) + + elif isinstance(node, UnaryOpNode): + return self._extract_max_window(node.operand) + + return 1 + + +class ComputeEngine: + """计算引擎 - 执行并行运算。 + + 负责将执行计划应用到数据上,支持并行计算。 + + Attributes: + max_workers: 最大并行工作线程数 + use_processes: 是否使用进程池(CPU 密集型任务) + """ + + def __init__( + self, + max_workers: int = 4, + use_processes: bool = False, + ) -> None: + """初始化计算引擎。 + + Args: + max_workers: 最大并行工作线程数 + use_processes: 是否使用进程池代替线程池 + """ + self.max_workers = max_workers + self.use_processes = use_processes + + def execute( + self, + plan: ExecutionPlan, + data: pl.DataFrame, + ) -> pl.DataFrame: + """执行计算计划。 + + Args: + plan: 执行计划 + data: 输入数据(核心宽表) + + Returns: + 包含因子结果的 DataFrame + """ + # 检查依赖字段是否存在 + missing_cols = plan.dependencies - set(data.columns) + if missing_cols: + raise ValueError(f"数据缺少必要的字段: {missing_cols}") + + # 执行计算 + result = data.with_columns([plan.polars_expr.alias(plan.output_name)]) + + return result + + def execute_batch( + self, + plans: List[ExecutionPlan], + data: pl.DataFrame, + ) -> pl.DataFrame: + """批量执行多个计算计划。 + + Args: + plans: 执行计划列表 + data: 输入数据 + + Returns: + 包含所有因子结果的 DataFrame + """ + result = data + + for plan in plans: + result = self.execute(plan, result) + + return result + + def execute_parallel( + self, + plans: List[ExecutionPlan], + data: pl.DataFrame, + ) -> pl.DataFrame: + """并行执行多个计算计划。 + + Args: + plans: 执行计划列表 + data: 输入数据 + + Returns: + 包含所有因子结果的 DataFrame + """ + # 检查计划间依赖 + independent_plans = [] + dependent_plans = [] + available_cols = set(data.columns) + + for plan in plans: + if plan.dependencies <= available_cols: + independent_plans.append(plan) + available_cols.add(plan.output_name) + else: + dependent_plans.append(plan) + + # 并行执行独立计划 + if independent_plans: + ExecutorClass = ( + ProcessPoolExecutor if self.use_processes else ThreadPoolExecutor + ) + + with ExecutorClass(max_workers=self.max_workers) as executor: + futures = { + executor.submit(self._execute_single, plan, data): plan + for plan in independent_plans + } + + results = [] + for future in futures: + plan = futures[future] + try: + result_col = future.result() + results.append((plan.output_name, result_col)) + except Exception as e: + raise RuntimeError(f"计算因子 {plan.output_name} 失败: {e}") + + # 合并结果 + for name, series in results: + data = data.with_columns([series.alias(name)]) + + # 顺序执行依赖计划 + for plan in dependent_plans: + data = self.execute(plan, data) + + return data + + def _execute_single( + self, + plan: ExecutionPlan, + data: pl.DataFrame, + ) -> pl.Series: + """执行单个计划并返回结果列。 + + Args: + plan: 执行计划 + data: 输入数据 + + Returns: + 计算结果序列 + """ + result = self.execute(plan, data) + return result[plan.output_name] class FactorEngine: - """因子执行引擎 - 根据因子类型采用不同的计算和防泄露策略 + """因子计算引擎 - 系统统一入口。 - 核心职责: - 1. CrossSectionalFactor:防止日期泄露,每天传入 [T-lookback+1, T] 数据 - 2. TimeSeriesFactor:防止股票泄露,每只股票传入完整序列 + 提供从表达式到结果的完整执行链路,是研究员使用系统的唯一接口。 - 示例: - >>> loader = DataLoader(data_dir="data") - >>> engine = FactorEngine(loader) - >>> result = engine.compute(factor, start_date="20240101", end_date="20240131") + 执行流程: + 1. 注册表达式 -> 调用编译器解析依赖 + 2. 调用路由器连接数据库拉取并组装核心宽表 + 3. 调用翻译器生成物理执行计划 + 4. 将计划提交给计算引擎执行并行运算 + 5. 返回包含因子结果的数据表 + + Attributes: + router: 数据路由器 + planner: 执行计划生成器 + compute_engine: 计算引擎 + registered_expressions: 注册的表达式字典 """ - def __init__(self, data_loader: DataLoader): - """初始化引擎 + def __init__( + self, + data_source: Optional[Dict[str, pl.DataFrame]] = None, + max_workers: int = 4, + ) -> None: + """初始化因子引擎。 Args: - data_loader: 数据加载器实例 + data_source: 内存数据源,为 None 时使用数据库连接 + max_workers: 并行计算的最大工作线程数 """ - self.data_loader = data_loader + self.router = DataRouter(data_source) + self.planner = ExecutionPlanner() + self.compute_engine = ComputeEngine(max_workers=max_workers) + self.registered_expressions: Dict[str, Node] = {} + self._plans: Dict[str, ExecutionPlan] = {} - def compute(self, factor: BaseFactor, **kwargs) -> pl.DataFrame: - """统一的计算入口 - - 根据 factor_type 分发到具体方法: - - "cross_sectional" -> _compute_cross_sectional() - - "time_series" -> _compute_time_series() + def register( + self, + name: str, + expression: Node, + data_specs: Optional[List[DataSpec]] = None, + ) -> FactorEngine: + """注册因子表达式。 Args: - factor: 要计算的因子 - **kwargs: 额外参数,根据因子类型不同: - - 截面因子: start_date, end_date - - 时序因子: stock_codes, start_date, end_date + name: 因子名称 + expression: DSL 表达式 + data_specs: 数据规格,None 时自动推导 Returns: - DataFrame[trade_date, ts_code, factor_name] + self,支持链式调用 + + Example: + >>> from src.factors.api import close, ts_mean + >>> engine = FactorEngine() + >>> engine.register("ma20", ts_mean(close, 20)) + """ + self.registered_expressions[name] = expression + + # 预创建执行计划 + plan = self.planner.create_plan( + expression=expression, + output_name=name, + data_specs=data_specs, + ) + self._plans[name] = plan + + return self + + def compute( + self, + factor_names: Union[str, List[str]], + start_date: str, + end_date: str, + stock_codes: Optional[List[str]] = None, + ) -> pl.DataFrame: + """计算指定因子的值。 + + 完整的执行流程:取数 -> 组装 -> 翻译 -> 计算。 + + Args: + factor_names: 因子名称或名称列表 + start_date: 开始日期 (YYYYMMDD) + end_date: 结束日期 (YYYYMMDD) + stock_codes: 股票代码列表,None 表示全市场 + + Returns: + 包含因子结果的数据表 Raises: - ValueError: 无效的 factor_type 或缺少必需参数 + ValueError: 当因子未注册或数据不足时 + + Example: + >>> result = engine.compute("ma20", "20240101", "20240131") + >>> result = engine.compute(["ma20", "rsi"], "20240101", "20240131") """ - if factor.factor_type == "cross_sectional": - if "start_date" not in kwargs or "end_date" not in kwargs: - raise ValueError( - "cross_sectional factor requires 'start_date' and 'end_date' parameters" - ) - return self._compute_cross_sectional( - factor, kwargs["start_date"], kwargs["end_date"] - ) - elif factor.factor_type == "time_series": - missing = [] - if "stock_codes" not in kwargs: - missing.append("stock_codes") - if "start_date" not in kwargs: - missing.append("start_date") - if "end_date" not in kwargs: - missing.append("end_date") - if missing: - raise ValueError( - f"time_series factor requires parameters: {', '.join(missing)}" - ) - return self._compute_time_series( - factor, - kwargs["stock_codes"], - kwargs["start_date"], - kwargs["end_date"], - ) - else: - raise ValueError(f"Unknown factor type: {factor.factor_type}") + # 标准化因子名称 + if isinstance(factor_names, str): + factor_names = [factor_names] - def _compute_cross_sectional( - self, - factor: CrossSectionalFactor, - start_date: str, - end_date: str, - ) -> pl.DataFrame: - """执行日期截面计算 + # 1. 获取执行计划 + plans = [] + for name in factor_names: + if name not in self._plans: + raise ValueError(f"因子未注册: {name}") + plans.append(self._plans[name]) - 防泄露策略: - - 防止日期泄露:每天只传入 [T-lookback+1, T] 的数据(不含未来) - - 允许股票间比较:传入当天所有股票的数据 + # 2. 合并数据规格并获取数据 + all_specs = [] + for plan in plans: + all_specs.extend(plan.data_specs) - 计算流程: - 1. 计算 max_lookback,确定数据起始日期 - 2. 一次性加载 [start-max_lookback+1, end] 的所有数据 - 3. 对每个日期 T in [start_date, end_date]: - a. 裁剪数据到 [T-lookback+1, T] - b. 创建 FactorData(current_date=T) - c. 调用 factor.compute() - d. 收集结果 - 4. 合并所有日期的结果 - - 返回 DataFrame 格式: - ┌────────────┬──────────┬──────────────┐ - │ trade_date │ ts_code │ factor_name │ - ├────────────┼──────────┼──────────────┤ - │ 20240101 │ 000001.SZ│ 0.5 │ - │ 20240101 │ 000002.SZ│ 0.3 │ - └────────────┴──────────┴──────────────┘ - - Args: - factor: 截面因子 - start_date: 开始日期 YYYYMMDD - end_date: 结束日期 YYYYMMDD - - Returns: - 包含因子值的 DataFrame - """ - # 计算最大 lookback - max_lookback = max(spec.lookback_days for spec in factor.data_specs) - - # 确定数据起始日期(向前扩展 lookback) - data_start = self._get_trading_date_offset(start_date, -max_lookback + 1) - - # 一次性加载所有数据 - raw_data = self.data_loader.load( - factor.data_specs, date_range=(data_start, end_date) + # 3. 从路由器获取核心宽表 + core_data = self.router.fetch_data( + data_specs=all_specs, + start_date=start_date, + end_date=end_date, + stock_codes=stock_codes, ) - results = [] + if len(core_data) == 0: + raise ValueError("未获取到任何数据,请检查日期范围和股票代码") - # 获取日期范围 - date_range = self._get_date_range(start_date, end_date, raw_data) - - # 按日期遍历:每天计算一次 - for current_date in date_range: - # 裁剪数据:只保留 current_date 及之前的数据(防止日期泄露) - # 但保留所有股票的数据(允许股票间比较) - day_data = raw_data.filter(pl.col("trade_date") <= current_date) - - # 如果 lookback > 0,进一步裁剪到 lookback 窗口 - if max_lookback > 0: - lookback_start = self._get_trading_date_offset( - current_date, -max_lookback + 1 - ) - day_data = day_data.filter(pl.col("trade_date") >= lookback_start) - - # 如果没有数据,跳过 - if len(day_data) == 0: - continue - - # 创建 FactorData(包含当天及历史数据,无未来数据) - context = FactorContext( - current_date=current_date, - trade_dates=date_range, - ) - factor_data = FactorData(day_data, context) - - # 计算因子值 - factor_values = factor.compute(factor_data) - - # 获取当前日期的股票列表 - today_stocks = day_data.filter(pl.col("trade_date") == current_date)[ - "ts_code" - ] - - # 确保 factor_values 长度与股票列表一致 - if len(factor_values) != len(today_stocks): - # 如果长度不一致,可能是 factor.compute 返回了错误的长度 - # 尝试从 factor_data 重新提取 - cs_data = factor_data.get_cross_section() - if len(cs_data) > 0: - today_stocks = cs_data["ts_code"] - # 如果 factor_values 仍然不匹配,用 null 填充 - if len(factor_values) != len(today_stocks): - factor_values = pl.Series([None] * len(today_stocks)) - - # 收集结果 - if len(today_stocks) > 0: - results.append( - pl.DataFrame( - { - "trade_date": [current_date] * len(today_stocks), - "ts_code": today_stocks, - factor.name: factor_values, - } - ) - ) - - # 合并所有日期的结果 - if results: - return pl.concat(results) + # 4. 执行计算 + if len(plans) == 1: + result = self.compute_engine.execute(plans[0], core_data) else: - # 返回空 DataFrame - return pl.DataFrame( - { - "trade_date": [], - "ts_code": [], - factor.name: [], - } - ) + result = self.compute_engine.execute_batch(plans, core_data) - def _compute_time_series( - self, - factor: TimeSeriesFactor, - stock_codes: List[str], - start_date: str, - end_date: str, - ) -> pl.DataFrame: - """执行时间序列计算 + return result - 防泄露策略: - - 防止股票泄露:每只股票单独计算,传入该股票的完整序列 - - 允许访问历史数据:时序计算需要历史数据,这是正常的 - - 计算流程: - 1. 计算 max_lookback,确定数据起始日期 - 2. 一次性加载 [start-max_lookback+1, end] 的所有数据 - 3. 对每只股票 S in stock_codes: - a. 过滤出 S 的数据(防止股票泄露) - b. 创建 FactorData(current_stock=S) - c. 调用 factor.compute()(向量化计算整个序列) - d. 收集结果 - 4. 合并所有股票的结果 - - 性能优势: - - 使用 Polars 的 rolling_mean 等向量化操作 - - 每只股票只计算一次,无重复计算 - - 返回 DataFrame 格式: - ┌────────────┬──────────┬──────────────┐ - │ trade_date │ ts_code │ factor_name │ - ├────────────┼──────────┼──────────────┤ - │ 20240101 │ 000001.SZ│ 10.5 │ - │ 20240102 │ 000001.SZ│ 10.6 │ - └────────────┴──────────┴──────────────┘ - - Args: - factor: 时序因子 - stock_codes: 股票代码列表 - start_date: 开始日期 YYYYMMDD - end_date: 结束日期 YYYYMMDD + def list_registered(self) -> List[str]: + """获取已注册的因子列表。 Returns: - 包含因子值的 DataFrame + 因子名称列表 """ - # 计算最大 lookback - max_lookback = max(spec.lookback_days for spec in factor.data_specs) + return list(self.registered_expressions.keys()) - # 确定数据起始日期(向前扩展 lookback) - data_start = self._get_trading_date_offset(start_date, -max_lookback + 1) - - # 加载所有数据 - all_data = self.data_loader.load( - factor.data_specs, date_range=(data_start, end_date) - ) - - results = [] - - # 获取所有交易日 - all_dates = all_data["trade_date"].unique().sort() if len(all_data) > 0 else [] - - # 按股票遍历:每只股票一次性计算 - for stock_code in stock_codes: - # 过滤出该股票的数据(防止股票泄露) - stock_data = all_data.filter(pl.col("ts_code") == stock_code) - - if len(stock_data) == 0: - continue - - # 创建 FactorData(该股票的完整序列) - context = FactorContext( - current_stock=stock_code, - trade_dates=list(all_dates), - ) - factor_data = FactorData(stock_data, context) - - # 一次性计算整个时间序列(向量化,高效) - factor_values = factor.compute(factor_data) - - # 获取该股票的日期列表 - stock_dates = stock_data["trade_date"] - - # 确保 factor_values 长度与日期列表一致 - if len(factor_values) != len(stock_dates): - # 如果长度不一致,用 null 填充 - factor_values = pl.Series([None] * len(stock_dates)) - - # 收集结果 - results.append( - pl.DataFrame( - { - "trade_date": stock_dates, - "ts_code": [stock_code] * len(stock_dates), - factor.name: factor_values, - } - ) - ) - - # 合并所有股票的结果 - if results: - return pl.concat(results) - else: - # 返回空 DataFrame - return pl.DataFrame( - { - "trade_date": [], - "ts_code": [], - factor.name: [], - } - ) - - def _get_trading_date_offset(self, date: str, offset: int) -> str: - """获取相对于给定日期的交易日偏移 - - 简单实现:假设每天都有交易,直接计算日期偏移 - 实际项目中可能需要使用交易日历 + def get_expression(self, name: str) -> Optional[Node]: + """获取已注册的表达式。 Args: - date: 基准日期 YYYYMMDD - offset: 偏移天数(正数向后,负数向前) + name: 因子名称 Returns: - 偏移后的日期 YYYYMMDD + 表达式节点,未注册时返回 None """ - from datetime import datetime, timedelta + return self.registered_expressions.get(name) - dt = datetime.strptime(date, "%Y%m%d") - new_dt = dt + timedelta(days=offset) - return new_dt.strftime("%Y%m%d") + def clear(self) -> None: + """清除所有注册的表达式和缓存。""" + self.registered_expressions.clear() + self._plans.clear() + self.router.clear_cache() - def _get_date_range( - self, start_date: str, end_date: str, data: pl.DataFrame - ) -> List[str]: - """获取日期范围内的所有交易日 + def preview_plan(self, factor_name: str) -> Optional[ExecutionPlan]: + """预览因子的执行计划。 Args: - start_date: 开始日期 YYYYMMDD - end_date: 结束日期 YYYYMMDD - data: 包含 trade_date 列的 DataFrame + factor_name: 因子名称 Returns: - 日期列表 + 执行计划,未注册时返回 None """ - if len(data) == 0: - return [] - - # 从数据中获取实际存在的日期 - dates = ( - data.filter( - (pl.col("trade_date") >= start_date) - & (pl.col("trade_date") <= end_date) - )["trade_date"] - .unique() - .sort() - .to_list() - ) - - return dates + return self._plans.get(factor_name) diff --git a/tests/factors/test_dsl_promotion.py b/tests/factors/test_dsl_promotion.py deleted file mode 100644 index 7245976..0000000 --- a/tests/factors/test_dsl_promotion.py +++ /dev/null @@ -1,325 +0,0 @@ -"""测试 DSL 字符串自动提升(Promotion)功能。 - -验证以下功能: -1. 字符串自动转换为 Symbol -2. 算子函数支持字符串参数 -3. 右位运算支持 -""" - -import pytest -from src.factors.dsl import ( - Symbol, - Constant, - BinaryOpNode, - UnaryOpNode, - FunctionNode, - _ensure_node, -) -from src.factors.api import ( - close, - open, - ts_mean, - ts_std, - ts_corr, - cs_rank, - cs_zscore, - log, - exp, - max_, - min_, - clip, - if_, - where, -) - - -class TestEnsureNode: - """测试 _ensure_node 辅助函数。""" - - def test_ensure_node_with_node(self): - """Node 类型应该原样返回。""" - sym = Symbol("close") - result = _ensure_node(sym) - assert result is sym - - def test_ensure_node_with_int(self): - """整数应该转换为 Constant。""" - result = _ensure_node(100) - assert isinstance(result, Constant) - assert result.value == 100 - - def test_ensure_node_with_float(self): - """浮点数应该转换为 Constant。""" - result = _ensure_node(3.14) - assert isinstance(result, Constant) - assert result.value == 3.14 - - def test_ensure_node_with_str(self): - """字符串应该转换为 Symbol。""" - result = _ensure_node("close") - assert isinstance(result, Symbol) - assert result.name == "close" - - def test_ensure_node_with_invalid_type(self): - """无效类型应该抛出 TypeError。""" - with pytest.raises(TypeError): - _ensure_node([1, 2, 3]) - - -class TestSymbolStringPromotion: - """测试 Symbol 与字符串的运算。""" - - def test_symbol_add_str(self): - """Symbol + 字符串。""" - expr = close + "pe_ratio" - assert isinstance(expr, BinaryOpNode) - assert expr.op == "+" - assert isinstance(expr.left, Symbol) - assert expr.left.name == "close" - assert isinstance(expr.right, Symbol) - assert expr.right.name == "pe_ratio" - - def test_symbol_sub_str(self): - """Symbol - 字符串。""" - expr = close - "open" - assert isinstance(expr, BinaryOpNode) - assert expr.op == "-" - assert expr.right.name == "open" - - def test_symbol_mul_str(self): - """Symbol * 字符串。""" - expr = close * "volume" - assert isinstance(expr, BinaryOpNode) - assert expr.op == "*" - assert expr.right.name == "volume" - - def test_symbol_div_str(self): - """Symbol / 字符串。""" - expr = close / "pe_ratio" - assert isinstance(expr, BinaryOpNode) - assert expr.op == "/" - assert expr.right.name == "pe_ratio" - - def test_symbol_pow_str(self): - """Symbol ** 字符串。""" - expr = close ** "exponent" - assert isinstance(expr, BinaryOpNode) - assert expr.op == "**" - assert expr.right.name == "exponent" - - -class TestRightHandOperations: - """测试右位运算。""" - - def test_int_add_symbol(self): - """整数 + Symbol。""" - expr = 100 + close - assert isinstance(expr, BinaryOpNode) - assert expr.op == "+" - assert isinstance(expr.left, Constant) - assert expr.left.value == 100 - assert isinstance(expr.right, Symbol) - assert expr.right.name == "close" - - def test_int_sub_symbol(self): - """整数 - Symbol。""" - expr = 100 - close - assert isinstance(expr, BinaryOpNode) - assert expr.op == "-" - assert expr.left.value == 100 - assert expr.right.name == "close" - - def test_int_mul_symbol(self): - """整数 * Symbol。""" - expr = 2 * close - assert isinstance(expr, BinaryOpNode) - assert expr.op == "*" - assert expr.left.value == 2 - assert expr.right.name == "close" - - def test_int_div_symbol(self): - """整数 / Symbol。""" - expr = 100 / close - assert isinstance(expr, BinaryOpNode) - assert expr.op == "/" - assert expr.left.value == 100 - assert expr.right.name == "close" - - def test_int_div_str_not_supported(self): - """Python 内置 int 不支持直接与 str 进行除法运算。 - - 注意:Python 内置的 int 类型不支持直接与 str 进行除法运算, - 所以 100 / "close" 会抛出 TypeError。正确的用法是 100 / Symbol("close") 或 - 使用已有的 Symbol 对象如 close。 - """ - with pytest.raises(TypeError): - 100 / "close" - def test_int_floordiv_symbol(self): - """整数 // Symbol。""" - expr = 100 // close - assert isinstance(expr, BinaryOpNode) - assert expr.op == "//" - - def test_int_mod_symbol(self): - """整数 % Symbol。""" - expr = 100 % close - assert isinstance(expr, BinaryOpNode) - assert expr.op == "%" - - def test_int_pow_symbol(self): - """整数 ** Symbol。""" - expr = 2**close - assert isinstance(expr, BinaryOpNode) - assert expr.op == "**" - assert expr.left.value == 2 - assert expr.right.name == "close" - - -class TestOperatorFunctionsWithStrings: - """测试算子函数支持字符串参数。""" - - def test_ts_mean_with_str(self): - """ts_mean 支持字符串参数。""" - expr = ts_mean("close", 20) - assert isinstance(expr, FunctionNode) - assert expr.func_name == "ts_mean" - assert len(expr.args) == 2 - assert isinstance(expr.args[0], Symbol) - assert expr.args[0].name == "close" - assert isinstance(expr.args[1], Constant) - assert expr.args[1].value == 20 - - def test_ts_std_with_str(self): - """ts_std 支持字符串参数。""" - expr = ts_std("volume", 10) - assert isinstance(expr, FunctionNode) - assert expr.func_name == "ts_std" - assert expr.args[0].name == "volume" - - def test_ts_corr_with_str(self): - """ts_corr 支持字符串参数。""" - expr = ts_corr("close", "open", 20) - assert isinstance(expr, FunctionNode) - assert expr.func_name == "ts_corr" - assert expr.args[0].name == "close" - assert expr.args[1].name == "open" - - def test_cs_rank_with_str(self): - """cs_rank 支持字符串参数。""" - expr = cs_rank("pe_ratio") - assert isinstance(expr, FunctionNode) - assert expr.func_name == "cs_rank" - assert expr.args[0].name == "pe_ratio" - - def test_cs_zscore_with_str(self): - """cs_zscore 支持字符串参数。""" - expr = cs_zscore("market_cap") - assert isinstance(expr, FunctionNode) - assert expr.func_name == "cs_zscore" - assert expr.args[0].name == "market_cap" - - def test_log_with_str(self): - """log 支持字符串参数。""" - expr = log("close") - assert isinstance(expr, FunctionNode) - assert expr.func_name == "log" - assert expr.args[0].name == "close" - - def test_max_with_str(self): - """max_ 支持字符串参数。""" - expr = max_("close", "open") - assert isinstance(expr, FunctionNode) - assert expr.func_name == "max" - assert expr.args[0].name == "close" - assert expr.args[1].name == "open" - - def test_max_with_str_and_number(self): - """max_ 支持字符串和数值混合。""" - expr = max_("close", 100) - assert isinstance(expr, FunctionNode) - assert expr.args[0].name == "close" - assert expr.args[1].value == 100 - - def test_clip_with_str(self): - """clip 支持字符串参数。""" - expr = clip("pe_ratio", "lower_bound", "upper_bound") - assert isinstance(expr, FunctionNode) - assert expr.func_name == "clip" - assert expr.args[0].name == "pe_ratio" - assert expr.args[1].name == "lower_bound" - assert expr.args[2].name == "upper_bound" - - def test_if_with_str(self): - """if_ 支持字符串参数。""" - expr = if_("condition", "true_val", "false_val") - assert isinstance(expr, FunctionNode) - assert expr.func_name == "if" - assert expr.args[0].name == "condition" - assert expr.args[1].name == "true_val" - assert expr.args[2].name == "false_val" - - -class TestComplexExpressions: - """测试复杂表达式。""" - - def test_complex_expression_1(self): - """复杂表达式:ts_mean("close", 5) / "pe_ratio"。""" - expr = ts_mean("close", 5) / "pe_ratio" - assert isinstance(expr, BinaryOpNode) - assert expr.op == "/" - assert isinstance(expr.left, FunctionNode) - assert expr.left.func_name == "ts_mean" - assert isinstance(expr.right, Symbol) - assert expr.right.name == "pe_ratio" - - def test_complex_expression_2(self): - """复杂表达式:100 / close * cs_rank("volume") 。 - - 注意:Python 内置的 int 类型不支持直接与 str 进行除法运算, - 所以需要使用已有的 Symbol 对象或先创建 Symbol。 - """ - expr = 100 / close * cs_rank("volume") - assert isinstance(expr, BinaryOpNode) - assert expr.op == "*" - assert isinstance(expr.left, BinaryOpNode) - assert expr.left.op == "/" - assert isinstance(expr.right, FunctionNode) - assert expr.right.func_name == "cs_rank" - def test_complex_expression_3(self): - """复杂表达式:ts_mean(close - "open", 20) / close。""" - expr = ts_mean(close - "open", 20) / close - assert isinstance(expr, BinaryOpNode) - assert expr.op == "/" - assert isinstance(expr.left, FunctionNode) - assert expr.left.func_name == "ts_mean" - # 检查 ts_mean 的第一个参数是 close - open - assert isinstance(expr.left.args[0], BinaryOpNode) - assert expr.left.args[0].op == "-" - - -class TestExpressionRepr: - """测试表达式字符串表示。""" - - def test_symbol_str_repr(self): - """Symbol 的字符串表示。""" - expr = Symbol("close") - assert repr(expr) == "close" - - def test_binary_op_repr(self): - """二元运算的字符串表示。""" - expr = close + "open" - assert repr(expr) == "(close + open)" - - def test_function_node_repr(self): - """函数节点的字符串表示。""" - expr = ts_mean("close", 20) - assert repr(expr) == "ts_mean(close, 20)" - - def test_complex_expr_repr(self): - """复杂表达式的字符串表示。""" - expr = ts_mean("close", 5) / "pe_ratio" - assert repr(expr) == "(ts_mean(close, 5) / pe_ratio)" - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/tests/test_factor_engine.py b/tests/test_factor_engine.py new file mode 100644 index 0000000..ad6768b --- /dev/null +++ b/tests/test_factor_engine.py @@ -0,0 +1,160 @@ +"""FactorEngine 端到端测试。 + +模拟内存数据作为假数据库,完整跑通从表达式注册到结果输出的全流程链路。 +""" + +import pytest +import polars as pl +import numpy as np +from datetime import datetime, timedelta + +from src.factors.engine import FactorEngine, DataSpec +from src.factors.api import close, ts_mean, ts_std, cs_rank, cs_zscore, open as open_sym +from src.factors.dsl import Symbol, FunctionNode + + +def create_mock_data( + start_date: str = "20240101", + end_date: str = "20240131", + n_stocks: int = 5, +) -> pl.DataFrame: + """创建模拟的日线数据。""" + start = datetime.strptime(start_date, "%Y%m%d") + end = datetime.strptime(end_date, "%Y%m%d") + + dates = [] + current = start + while current <= end: + if current.weekday() < 5: # 周一到周五 + dates.append(current.strftime("%Y%m%d")) + current += timedelta(days=1) + + stocks = [f"{600000 + i:06d}.SH" for i in range(n_stocks)] + np.random.seed(42) + + rows = [] + for date in dates: + for stock in stocks: + base_price = 10 + np.random.randn() * 5 + close_val = base_price + np.random.randn() * 0.5 + open_val = close_val + np.random.randn() * 0.2 + high_val = max(open_val, close_val) + abs(np.random.randn()) * 0.3 + low_val = min(open_val, close_val) - abs(np.random.randn()) * 0.3 + vol = int(1000000 + np.random.exponential(500000)) + amt = close_val * vol + + rows.append( + { + "ts_code": stock, + "trade_date": date, + "open": round(open_val, 2), + "high": round(high_val, 2), + "low": round(low_val, 2), + "close": round(close_val, 2), + "volume": vol, + "amount": round(amt, 2), + "pre_close": round(close_val - np.random.randn() * 0.3, 2), + } + ) + + return pl.DataFrame(rows) + + +class TestFactorEngineEndToEnd: + """FactorEngine 端到端测试类。""" + + @pytest.fixture + def mock_data(self): + """提供模拟数据的 fixture。""" + return create_mock_data("20240101", "20240131", n_stocks=5) + + @pytest.fixture + def engine(self, mock_data): + """提供配置好的 FactorEngine fixture。""" + data_source = {"daily": mock_data} + return FactorEngine(data_source=data_source, max_workers=2) + + def test_simple_symbol_expression(self, engine): + """测试简单的符号表达式。""" + engine.register("close_price", close) + result = engine.compute("close_price", "20240115", "20240120") + assert "close_price" in result.columns + assert len(result) > 0 + print("[PASS] 简单符号表达式测试") + + def test_arithmetic_expression(self, engine): + """测试算术表达式。""" + engine.register("returns", (close - open_sym) / open_sym) + result = engine.compute("returns", "20240115", "20240120") + assert "returns" in result.columns + print("[PASS] 算术表达式测试") + + def test_cs_rank_factor(self, engine): + """测试截面排名因子。""" + engine.register("price_rank", cs_rank(close)) + result = engine.compute("price_rank", "20240115", "20240120") + assert "price_rank" in result.columns + assert result["price_rank"].min() >= 0 + assert result["price_rank"].max() <= 1 + print("[PASS] 截面排名因子测试") + + +class TestFullWorkflow: + """完整工作流测试类。""" + + def test_full_workflow_demo(self): + """演示完整的因子计算工作流。""" + print("\n" + "=" * 60) + print("FactorEngine Full Workflow Demo") + print("=" * 60) + + # 1. 准备数据 + print("\nStep 1: Prepare mock data...") + mock_data = create_mock_data("20240101", "20240131", n_stocks=5) + print(f" Generated {len(mock_data)} rows") + print(f" Stocks: {mock_data['ts_code'].n_unique()}") + + # 2. 初始化引擎 + print("\nStep 2: Initialize FactorEngine...") + engine = FactorEngine(data_source={"daily": mock_data}) + print(" Engine initialized") + + # 3. 注册因子 - 使用简单因子避免回看窗口问题 + print("\nStep 3: Register factors...") + engine.register("returns", (close - open_sym) / open_sym) + engine.register("price_rank", cs_rank(close)) + print(" Registered: returns, price_rank") + + # 4. 执行计算 - 使用完整日期范围 + print("\nStep 4: Compute factors...") + result = engine.compute( + ["returns", "price_rank"], + "20240115", + "20240120", + ) + print(f" Computed {len(result)} rows") + + # 5. 验证结果 + print("\nStep 5: Verify results...") + assert "returns" in result.columns + assert "price_rank" in result.columns + assert result["price_rank"].min() >= 0 + assert result["price_rank"].max() <= 1 + print(" All assertions passed") + + # 6. 展示样本 + print("\nStep 6: Sample output...") + sample = result.select( + ["ts_code", "trade_date", "close", "returns", "price_rank"] + ).head(3) + print(sample.to_pandas().to_string(index=False)) + + print("\n" + "=" * 60) + print("Workflow completed successfully!") + print("=" * 60) + + +if __name__ == "__main__": + test = TestFullWorkflow() + test.test_full_workflow_demo() + pytest.main([__file__, "-v", "--tb=short"])