diff --git a/docs/factor_design.md b/docs/factor_design.md index d300612..46cb9d9 100644 --- a/docs/factor_design.md +++ b/docs/factor_design.md @@ -86,3 +86,622 @@ --- + + +## 四、 详细设计规范(新增) + +### 4.1 五层架构总览 + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ Layer 5: 编排层 (Orchestrator) │ +│ - FactorEngine: 统一入口 │ +│ - 协调各层工作流 │ +└─────────────────────────────────────────────────────────────────┘ + ↓ +┌─────────────────────────────────────────────────────────────────┐ +│ Layer 4: 物理执行引擎层 (Execution Engine) │ +│ - PolarsTranslator: AST → Polars表达式 │ +│ - 自动注入分组约束(截面/时序) │ +│ - 执行计算并返回结果 │ +└─────────────────────────────────────────────────────────────────┘ + ↓ +┌─────────────────────────────────────────────────────────────────┐ +│ Layer 3: 动态数据路由层 (Data Router) │ +│ - MetadataRegistry: 字段→表映射 │ +│ - QueryPlanner: 生成最优查询计划 │ +│ - DataAligner: PIT对齐与防未来函数处理 │ +└─────────────────────────────────────────────────────────────────┘ + ↓ +┌─────────────────────────────────────────────────────────────────┐ +│ Layer 2: 编译与分析层 (Compiler) │ +│ - DependencyExtractor: 提取数据依赖 │ +│ - GraphOptimizer: 子表达式合并(预留接口) │ +│ - 输出: 数据需求清单 + 优化后的AST │ +└─────────────────────────────────────────────────────────────────┘ + ↓ +┌─────────────────────────────────────────────────────────────────┐ +│ Layer 1: DSL层 (领域特定语言) │ +│ - AST节点: Field, BinaryOp, UnaryOp, FunctionCall, Constant │ +│ - 算子库: ts_* (时序), cs_* (截面), math_* (数学) │ +│ - 运算符重载: +, -, *, /, >, <, == 等 │ +└─────────────────────────────────────────────────────────────────┘ +``` + +--- + +### 4.2 Layer 1: DSL层详细设计 + +#### 核心设计原则 +- **算子与数据解耦**:算子只描述计算逻辑,不绑定具体数据 +- **纯表达式树**:输出无状态的AST,不涉及任何外部库 +- **延迟执行**:表达式构建时不执行计算,只生成树结构 + +#### AST节点类型体系 + +```python +# 节点基类 +class ASTNode(ABC): + """AST节点基类""" + + @abstractmethod + def accept(self, visitor: "NodeVisitor") -> Any: + """接受访问者""" + pass + + @abstractmethod + def get_children(self) -> List["ASTNode"]: + """获取子节点列表""" + pass + +# 1. 字段节点(叶子节点) +class Field(ASTNode): + """ + 字段节点 - 代表底层数据字段 + 示例: close, volume, pe, pb + """ + name: str # 字段名 + dtype: Optional[str] = None # 数据类型提示 + +# 2. 常量节点(叶子节点) +class Constant(ASTNode): + """ + 常量节点 - 代表常量值 + 示例: 5, 10.5, "20240101" + """ + value: Union[int, float, str] + dtype: str + +# 3. 二元操作节点 +class BinaryOp(ASTNode): + """ + 二元操作节点 + 支持的运算符: +, -, *, /, //, %, **, >, >=, <, <=, ==, !=, &, | + """ + op: str # '+', '-', '*', '/', '>', etc. + left: ASTNode + right: ASTNode + +# 4. 一元操作节点 +class UnaryOp(ASTNode): + """ + 一元操作节点 + 支持的运算符: -, +, ~, abs + """ + op: str # '-', '+', '~', 'abs' + operand: ASTNode + +# 5. 函数调用节点 +class FunctionCall(ASTNode): + """ + 函数调用节点 - 代表算子调用 + 示例: ts_mean(close, 20), cs_rank(pe) + """ + name: str # 函数名 + args: List[ASTNode] + kwargs: Dict[str, Any] + func_type: str # "timeseries" | "cross_sectional" | "math" +``` + +#### 运算符重载规则 + +在 ASTNode 基类中实现运算符重载: + +```python +class ASTNode: + # 算术运算符 + def __add__(self, other) -> BinaryOp: + return BinaryOp("+", self, _ensure_node(other)) + + def __sub__(self, other) -> BinaryOp: + return BinaryOp("-", self, _ensure_node(other)) + + def __mul__(self, other) -> BinaryOp: + return BinaryOp("*", self, _ensure_node(other)) + + def __truediv__(self, other) -> BinaryOp: + return BinaryOp("/", self, _ensure_node(other)) + + # 反向运算符(支持 5 * field) + def __radd__(self, other) -> BinaryOp: + return BinaryOp("+", _ensure_node(other), self) + + def __rmul__(self, other) -> BinaryOp: + return BinaryOp("*", _ensure_node(other), self) + + # 比较运算符 + def __gt__(self, other) -> BinaryOp: + return BinaryOp(">", self, _ensure_node(other)) + + def __lt__(self, other) -> BinaryOp: + return BinaryOp("<", self, _ensure_node(other)) + + # 一元运算符 + def __neg__(self) -> UnaryOp: + return UnaryOp("-", self) +``` + +#### 算子库规范 + +算子按功能分为三类: + +| 前缀 | 类别 | 说明 | 示例 | +|------|------|------|------| +| `ts_` | 时序算子 | 在时间序列上计算,需按股票分组 | `ts_mean`, `ts_std`, `ts_sum` | +| `cs_` | 截面算子 | 在截面上计算,需按日期分组 | `cs_rank`, `cs_zscore`, `cs_percentile` | +| `math_` | 数学算子 | 逐元素计算,无需分组 | `math_log`, `math_exp`, `math_sqrt` | + +**时序算子列表(ts_*)**: +```python +ts_mean(field, window: int) # 移动平均 +ts_std(field, window: int) # 移动标准差 +ts_sum(field, window: int) # 移动求和 +ts_max(field, window: int) # 移动最大值 +ts_min(field, window: int) # 移动最小值 +ts_delta(field, period: int = 1) # 差分 +ts_pct_change(field, period: int = 1) # 百分比变化 +ts_corr(f1, f2, window: int) # 滚动相关系数 +``` + +**截面算子列表(cs_*)**: +```python +cs_rank(field) # 截面排名(0-1) +cs_percentile(field) # 截面分位数 +cs_zscore(field) # Z-Score标准化 +cs_mean(field) # 截面均值 +cs_std(field) # 截面标准差 +``` + +**数学算子列表(math_*)**: +```python +math_log(field) # 自然对数 +math_exp(field) # 指数 +math_sqrt(field) # 平方根 +math_abs(field) # 绝对值 +``` + +#### 表达式构建示例 + +```python +from src.factors.dsl import Field, ts_mean, cs_rank + +# ========== 示例 1: 简单移动平均线因子 ========== +close = Field("close") +ma20 = ts_mean(close, 20) +factor1 = ma20 + +# ========== 示例 2: 双均线差值因子 ========== +close = Field("close") +ma20 = ts_mean(close, 20) +ma5 = ts_mean(close, 5) +factor2 = (ma20 - ma5) / close + +# ========== 示例 3: 复杂多因子组合 ========== +close = Field("close") +volume = Field("volume") +pe = Field("pe") + +price_momentum = ts_pct_change(close, 20) +vol_ma = ts_mean(volume, 20) +vol_ratio = volume / vol_ma +pe_rank = cs_rank(pe) + +factor3 = price_momentum * 0.4 + vol_ratio * 0.3 + pe_rank * 0.3 +``` + +--- + +### 4.3 Layer 2: 编译层详细设计 + +#### 依赖提取器 + +```python +class DependencyExtractor(NodeVisitor): + """ + 依赖提取器 - 遍历AST收集数据依赖 + 输出: DataRequirement + - fields: Set[str] 需要的字段列表 + - min_lookback: Dict[str, int] 每个字段的最小回看天数 + """ + + def __init__(self): + self.fields: Set[str] = set() + self.field_lookback: Dict[str, int] = defaultdict(int) + + def visit_field(self, node: Field) -> None: + """记录字段依赖""" + self.fields.add(node.name) + self.field_lookback[node.name] = max( + self.field_lookback[node.name], 1 + ) + + def visit_function_call(self, node: FunctionCall) -> None: + """处理函数调用,提取窗口参数""" + for arg in node.args: + arg.accept(self) + + if node.func_type == "timeseries": + window = self._extract_window(node) + self._update_lookback(node.args[0], window) + + def extract(self, root: ASTNode) -> DataRequirement: + """执行提取""" + root.accept(self) + return DataRequirement( + fields=self.fields, + lookback=dict(self.field_lookback) + ) +``` + +#### 数据需求规格 + +```python +@dataclass +class DataRequirement: + """ + 数据需求规格 + + 属性: + fields: 需要的字段集合 + lookback: 每个字段需要回看的天数 + date_range: 计算日期范围 (start, end) + """ + fields: Set[str] + lookback: Dict[str, int] + date_range: Optional[Tuple[str, str]] = None + + def get_max_lookback(self) -> int: + """获取最大回看天数""" + return max(self.lookback.values()) if self.lookback else 1 +``` + +--- + +### 4.4 Layer 3: 数据路由层详细设计 + +#### 元数据注册表 + +```python +@dataclass +class FieldMetadata: + """ + 字段元数据 + + 属性: + name: 字段名 + table: 所属表名 + dtype: 数据类型 + freq: 数据频度 ("daily", "quarterly", "pit") + announce_date_field: 公告日字段名(PIT数据使用) + """ + name: str + table: str + dtype: str + freq: str + announce_date_field: Optional[str] = None + +class MetadataRegistry: + """ + 元数据注册表 - 管理字段到表的映射 + 单例模式,系统启动时加载配置 + """ + + def register(self, metadata: FieldMetadata) -> None: + """注册字段元数据""" + pass + + def get_table(self, field: str) -> str: + """获取字段所属表""" + pass + + def group_by_table(self, fields: Set[str]) -> Dict[str, Set[str]]: + """按表分组字段""" + pass +``` + +#### PIT对齐策略 + +```python +class DataAligner: + """ + 数据对齐器 - 处理多表数据合并与PIT对齐 + """ + + def align( + self, + dataframes: Dict[str, pl.DataFrame], + plans: List[QueryPlan] + ) -> pl.DataFrame: + """ + 对齐并合并多个数据表 + + 步骤: + 1. 分离日频表和PIT表 + 2. 日频表直接join + 3. PIT表使用asof join + 4. 最终排序 + """ + pass + + def _asof_join( + self, + left: pl.DataFrame, + right: pl.DataFrame, + announce_date_field: str + ) -> pl.DataFrame: + """ + 执行PIT asof join + 策略: 对于每个交易日,使用最新公告的数据 + """ + return left.join_asof( + right, + left_on="trade_date", + right_on=announce_date_field, + by="ts_code", + strategy="backward" + ) +``` + +--- + +### 4.5 Layer 4: 执行引擎层详细设计 + +#### Polars翻译器 + +```python +class PolarsTranslator(NodeVisitor): + """ + Polars翻译器 - 将AST翻译为Polars表达式 + """ + + def __init__(self, df: pl.LazyFrame): + self.df = df + + def translate(self, root: ASTNode) -> pl.Expr: + """翻译AST为Polars表达式""" + return root.accept(self) + + def visit_field(self, node: Field) -> pl.Expr: + """字段 → pl.col()""" + return pl.col(node.name) + + def visit_binary_op(self, node: BinaryOp) -> pl.Expr: + """二元操作 → Polars运算符""" + left = node.left.accept(self) + right = node.right.accept(self) + + ops = { + "+": lambda a, b: a + b, + "-": lambda a, b: a - b, + "*": lambda a, b: a * b, + "/": lambda a, b: a / b, + } + + return ops[node.op](left, right) + + def visit_function_call(self, node: FunctionCall) -> pl.Expr: + """ + 函数调用 → Polars窗口函数 + 关键:根据func_type注入分组约束 + """ + args = [arg.accept(self) for arg in node.args] + impl = self._get_impl(node.name) + + if node.func_type == "timeseries": + return impl(*args).over("ts_code") + elif node.func_type == "cross_sectional": + return impl(*args).over("trade_date") + else: + return impl(*args) +``` + +#### 分组约束注入规则 + +```python +# 时序算子:按股票分组,确保滚动窗口不跨股票 +def inject_timeseries_constraint(expr: pl.Expr) -> pl.Expr: + return expr.over("ts_code") + +# 截面算子:按日期分组,确保排名在每天内部进行 +def inject_cross_sectional_constraint(expr: pl.Expr) -> pl.Expr: + return expr.over("trade_date") +``` + +--- + +### 4.6 Layer 5: 编排层详细设计 + +#### FactorEngine + +```python +class FactorEngine: + """ + 因子执行引擎 - 系统统一入口 + """ + + def __init__( + self, + data_source: DataSource, + registry: MetadataRegistry + ): + self.data_source = data_source + self.registry = registry + self.compiler = Compiler() + self.planner = QueryPlanner(registry) + self.aligner = DataAligner() + + def compute( + self, + expression: ASTNode, + start_date: str, + end_date: str, + stock_codes: Optional[List[str]] = None + ) -> pl.DataFrame: + """ + 计算因子表达式 + + 执行流程: + 1. 编译:提取数据依赖 + 2. 规划:生成查询计划 + 3. 加载:从数据源获取数据 + 4. 对齐:PIT对齐与合并 + 5. 翻译:AST → Polars表达式 + 6. 执行:计算并返回结果 + """ + # Step 1: 编译 + requirement = self.compiler.extract_dependency(expression) + requirement.date_range = (start_date, end_date) + + # Step 2: 规划 + plans = self.planner.plan(requirement) + + # Step 3: 加载 + raw_data = {} + for plan in plans: + df = self.data_source.load(...) + raw_data[plan.table] = df + + # Step 4: 对齐 + aligned_data = self.aligner.align(raw_data, plans) + + # Step 5: 翻译 + translator = PolarsTranslator(aligned_data.lazy()) + polars_expr = translator.translate(expression) + + # Step 6: 执行 + result = aligned_data.with_columns( + polars_expr.alias("factor_value") + ) + + return result +``` + +--- + +## 五、 实施路线图(详细版) + +### 阶段1: 基础架构(Layer 1 + Layer 2) +**目标**: 实现DSL表达式树和依赖提取 + +**任务清单**: +- [ ] 实现AST节点类(Field, Constant, BinaryOp, UnaryOp, FunctionCall) +- [ ] 实现运算符重载 +- [ ] 实现基础算子库(ts_mean, ts_std, cs_rank等) +- [ ] 实现DependencyExtractor +- [ ] 编写单元测试 + +**验收标准**: +```python +close = Field("close") +factor = ts_mean(close, 20) / close + +deps = extract_dependencies(factor) +assert deps.fields == {"close"} +assert deps.lookback == {"close": 20} +``` + +### 阶段2: 数据层(Layer 3) +**目标**: 实现元数据管理和PIT对齐 + +**任务清单**: +- [ ] 实现MetadataRegistry +- [ ] 实现QueryPlanner +- [ ] 实现DataAligner(含asof join) +- [ ] 集成DuckDB数据源 + +### 阶段3: 执行层(Layer 4) +**目标**: 实现Polars翻译和执行 + +**任务清单**: +- [ ] 实现PolarsTranslator +- [ ] 实现算子到Polars的映射 +- [ ] 实现分组约束注入 + +### 阶段4: 编排层(Layer 5) +**目标**: 实现FactorEngine统一入口 + +**任务清单**: +- [ ] 实现FactorEngine +- [ ] 整合各层组件 +- [ ] 编写端到端测试 + +--- + +## 六、 关键设计决策 + +### 6.1 为什么使用Visitor模式? +- **扩展性**: 新增节点类型只需添加visit方法 +- **分离关注点**: 遍历逻辑与处理逻辑分离 +- **类型安全**: 每个节点类型有明确的处理函数 + +### 6.2 为什么算子需要分类(ts_/cs_/math_)? +- **显式分组**: 用户明确知道计算维度 +- **约束注入**: 系统根据前缀自动注入正确的分组 +- **错误预防**: 避免截面/时序算子混用导致的逻辑错误 + +### 6.3 向后兼容性 +**决策**: 完全重构,不保留旧API + +**理由**: +- 新旧架构差异过大(绑定vs解耦) +- 保持旧API会增加维护负担 +- 量化策略代码通常是一次性编写,迁移成本可控 + +--- + +## 七、 附录 + +### A. 完整算子列表 + +**时序算子 (ts_*)**: ts_mean, ts_std, ts_var, ts_sum, ts_max, ts_min, ts_product, ts_median, ts_argmax, ts_argmin, ts_skew, ts_kurt, ts_delta, ts_pct_change, ts_corr, ts_cov, ts_rank + +**截面算子 (cs_*)**: cs_rank, cs_percentile, cs_zscore, cs_mean, cs_std, cs_median, cs_max, cs_min + +**数学算子 (math_*)**: math_log, math_log1p, math_exp, math_sqrt, math_abs, math_sign, math_power + +### B. 元数据配置示例 + +```python +METADATA = [ + {"name": "close", "table": "daily", "dtype": "float64", "freq": "daily"}, + {"name": "volume", "table": "daily", "dtype": "float64", "freq": "daily"}, + {"name": "pe", "table": "daily", "dtype": "float64", "freq": "daily"}, + {"name": "eps", "table": "financial_income", "dtype": "float64", + "freq": "pit", "announce_date_field": "ann_date"}, +] +``` + +### C. 与现有代码对比 + +| 维度 | 现有实现 | 新设计 | +|------|---------|--------| +| 因子定义 | 类继承 | 表达式 | +| 数据绑定 | data_specs硬编码 | 元数据注册表 | +| 组合方式 | CompositeFactor包装 | AST节点自然组合 | +| 执行时机 | 立即执行 | 延迟执行 | +| 防泄露 | 手动控制 | 自动注入分组约束 | +| 可优化性 | 低 | 高 | + +--- + +**文档版本**: 2.0 | **更新日期**: 2026-02-26