fix(factors): 修复 ts_corr/ts_cov 实现并添加 abs 函数支持

- 修复 ts_corr 和 ts_cov 使用 pl.rolling_corr/pl.rolling_cov 模块级函数
- 添加 abs 函数处理器到 translator
- 扩展 notebook 中的因子定义(24 -> 49 个)
- 更新 AGENTS.md 文档结构和 Training 模块说明
This commit is contained in:
2026-03-09 23:37:20 +08:00
parent 88fa848b96
commit f1811815e7
3 changed files with 675 additions and 482 deletions

457
AGENTS.md
View File

@@ -6,9 +6,9 @@ A股量化投资框架 - Python 项目,用于量化股票投资分析。
**⚠️ 强制要求:所有沟通和思考过程必须使用中文。** **⚠️ 强制要求:所有沟通和思考过程必须使用中文。**
所有与 AI Agent 的交流必须使用中文 - 所有与 AI Agent 的交流必须使用中文
代码中的注释和文档字符串使用中文 - 代码中的注释和文档字符串使用中文
禁止使用英文进行思考或沟通 - 禁止使用英文进行思考或沟通
## 构建/检查/测试命令 ## 构建/检查/测试命令
@@ -44,7 +44,7 @@ uv run pytest --cov=src --cov-report=term-missing
```bash ```bash
# 禁止直接使用 python # 禁止直接使用 python
python -c "..." # 禁止! python -c "" # 禁止!
python script.py # 禁止! python script.py # 禁止!
python -m pytest # 禁止! python -m pytest # 禁止!
python -m pip install # 禁止! python -m pip install # 禁止!
@@ -59,7 +59,7 @@ pip list # 禁止!
```bash ```bash
# 运行 Python 代码 # 运行 Python 代码
uv run python -c "..." # ✅ 正确 uv run python -c "" # ✅ 正确
uv run python script.py # ✅ 正确 uv run python script.py # ✅ 正确
# 安装依赖 # 安装依赖
@@ -82,65 +82,79 @@ ProStock/
│ │ │ │
│ ├── data/ # 数据获取与存储 │ ├── data/ # 数据获取与存储
│ │ ├── api_wrappers/ # Tushare API 封装 │ │ ├── api_wrappers/ # Tushare API 封装
│ │ │ ├── base_sync.py # 同步基础抽象类(BaseDataSync/StockBasedSync/DateBasedSync) │ │ │ ├── base_sync.py # 同步基础抽象类
│ │ │ ├── api_daily.py # 日线数据接口(DailySync) │ │ │ ├── api_daily.py # 日线数据接口
│ │ │ ├── api_pro_bar.py # Pro Bar 数据接口(ProBarSync) │ │ │ ├── api_pro_bar.py # Pro Bar 数据接口
│ │ │ ├── api_stock_basic.py # 股票基础信息接口 │ │ │ ├── api_stock_basic.py # 股票基础信息接口
│ │ │ ├── api_trade_cal.py # 交易日历接口 │ │ │ ├── api_trade_cal.py # 交易日历接口
│ │ │ ├── api_bak_basic.py # 历史股票列表接口(BakBasicSync) │ │ │ ├── api_bak_basic.py # 历史股票列表接口
│ │ │ ├── api_namechange.py # 股票名称变更接口 │ │ │ ├── api_namechange.py # 股票名称变更接口
│ │ │ ├── api_stock_st.py # ST股票信息接口
│ │ │ ├── api_daily_basic.py # 每日指标接口
│ │ │ ├── api_stk_limit.py # 涨跌停价格接口
│ │ │ ├── financial_data/ # 财务数据接口 │ │ │ ├── financial_data/ # 财务数据接口
│ │ │ │ ├── api_income.py # 利润表接口 │ │ │ │ ├── api_income.py # 利润表接口
│ │ │ │ ── api_financial_sync.py # 财务数据同步 │ │ │ │ ── api_balance.py # 资产负债表接口
│ │ │ │ ├── api_cashflow.py # 现金流量表接口
│ │ │ │ ├── api_fina_indicator.py # 财务指标接口
│ │ │ │ └── api_financial_sync.py # 财务数据同步调度中心
│ │ │ └── __init__.py │ │ │ └── __init__.py
│ │ ├── __init__.py │ │ ├── __init__.py
│ │ ├── client.py # Tushare API 客户端(带速率限制) │ │ ├── client.py # Tushare API 客户端(带速率限制)
│ │ ├── config.py # 数据模块配置
│ │ ├── data_router.py # 数据路由器factors/engine 专用)
│ │ ├── db_inspector.py # 数据库信息查看工具
│ │ ├── db_manager.py # DuckDB 表管理和同步
│ │ ├── rate_limiter.py # 令牌桶速率限制器
│ │ ├── storage.py # 数据存储核心 │ │ ├── storage.py # 数据存储核心
│ │ ├── db_manager.py # DuckDB 表管理和同步
│ │ ├── db_inspector.py # 数据库信息查看工具
│ │ ├── sync.py # 数据同步调度中心 │ │ ├── sync.py # 数据同步调度中心
│ │ ── utils.py # 数据模块工具函数 │ │ ── sync_registry.py # 同步器注册表
│ │ ├── rate_limiter.py # 令牌桶速率限制器
│ │ ├── catalog.py # 数据目录管理
│ │ ├── config.py # 数据模块配置
│ │ ├── utils.py # 数据模块工具函数
│ │ └── financial_loader.py # 财务数据加载器
│ │ │ │
│ ├── factors/ # 因子计算框架DSL 表达式驱动) │ ├── factors/ # 因子计算框架DSL 表达式驱动)
│ │ ├── engine/ # 执行引擎子模块 │ │ ├── engine/ # 执行引擎子模块
│ │ │ ├── __init__.py # 导出引擎组件 │ │ │ ├── __init__.py # 导出引擎组件
│ │ │ ├── data_spec.py # 数据规格定义(DataSpec, ExecutionPlan) │ │ │ ├── data_spec.py # 数据规格定义
│ │ │ ├── data_router.py # 数据路由器 │ │ │ ├── data_router.py # 数据路由器
│ │ │ ├── planner.py # 执行计划生成器(ExecutionPlanner) │ │ │ ├── planner.py # 执行计划生成器
│ │ │ ├── compute_engine.py # 计算引擎(ComputeEngine) │ │ │ ├── compute_engine.py # 计算引擎
│ │ │ ── factor_engine.py # 因子引擎统一入口(FactorEngine) │ │ │ ── schema_cache.py # 表结构缓存
│ │ │ └── factor_engine.py # 因子引擎统一入口
│ │ ├── __init__.py # 导出所有公开 API │ │ ├── __init__.py # 导出所有公开 API
│ │ ├── dsl.py # DSL 表达式层 - 节点定义和运算符重载 │ │ ├── dsl.py # DSL 表达式层 - 节点定义和运算符重载
│ │ ├── api.py # API 层 - 常用符号(close/open等)和函数(ts_mean/cs_rank等) │ │ ├── api.py # API 层 - 常用符号和函数
│ │ ├── compiler.py # AST 编译器 - 依赖提取 │ │ ├── compiler.py # AST 编译器 - 依赖提取
│ │ ├── translator.py # Polars 表达式翻译器 │ │ ├── translator.py # Polars 表达式翻译器
│ │ ├── parser.py # 字符串公式解析器(FormulaParser) │ │ ├── parser.py # 字符串公式解析器
│ │ ├── registry.py # 函数注册表(FunctionRegistry) │ │ ├── registry.py # 函数注册表
│ │ ── exceptions.py # 异常定义(FormulaParseError等) │ │ ── decorators.py # 装饰器工具
│ │ └── exceptions.py # 异常定义
│ │ │ │
│ ├── pipeline/ # 模型训练管道 │ ├── training/ # 训练模块
│ │ ├── __init__.py │ │ ├── core/ # 训练核心组件
│ │ ├── pipeline.py # 处理流水线(ProcessingPipeline)
│ │ ├── registry.py # 插件注册中心(PluginRegistry)
│ │ ├── core/ # 核心抽象
│ │ │ ├── __init__.py │ │ │ ├── __init__.py
│ │ │ ├── base.py # 基类定义(BaseProcessor/BaseModel/BaseSplitter等) │ │ │ ├── trainer.py # 训练器主类
│ │ │ └── splitter.py # 时间序列划分策略 │ │ │ └── stock_pool_manager.py # 股票池管理器
│ │ ├── models/ # 模型实现 │ │ ├── components/ # 组件
│ │ │ ├── base.py # 基础抽象类
│ │ │ ├── splitters.py # 数据划分器
│ │ │ ├── selectors.py # 股票选择器
│ │ │ ├── filters.py # 数据过滤器
│ │ │ ├── models/ # 模型实现
│ │ │ │ ├── __init__.py
│ │ │ │ └── lightgbm.py # LightGBM 模型
│ │ │ └── processors/ # 数据处理器
│ │ │ ├── __init__.py
│ │ │ └── transforms.py # 变换处理器
│ │ ├── config/ # 配置
│ │ │ ├── __init__.py │ │ │ ├── __init__.py
│ │ │ └── models.py # LightGBM、CatBoost 等 │ │ │ └── config.py # 训练配置
│ │ ── processors/ # 数据处理器 │ │ ── registry.py # 组件注册中心
│ │ ── __init__.py │ │ ── __init__.py # 导出所有组件
│ │ └── processors.py # 标准化、缩尾、中性化等
│ │ │ │
│ └── training/ # 训练入口 │ └── experiment/ # 实验代码
── __init__.py ── regression.ipynb # 完整训练流程示例
│ ├── pipeline.py # 训练流程配置
│ └── output/ # 训练输出
│ └── top_stocks.tsv # 推荐股票结果
├── tests/ # 测试文件 ├── tests/ # 测试文件
│ ├── test_sync.py │ ├── test_sync.py
@@ -148,8 +162,6 @@ ProStock/
│ ├── test_factor_engine.py │ ├── test_factor_engine.py
│ ├── test_factor_integration.py │ ├── test_factor_integration.py
│ ├── test_pro_bar.py │ ├── test_pro_bar.py
│ ├── test_601117_factors.py
│ ├── test_two_stocks_string_factors.py
│ ├── test_db_manager.py │ ├── test_db_manager.py
│ ├── test_daily_storage.py │ ├── test_daily_storage.py
│ ├── test_tushare_api.py │ ├── test_tushare_api.py
@@ -183,6 +195,7 @@ import threading
# 第三方包 # 第三方包
import pandas as pd import pandas as pd
import numpy as np import numpy as np
import polars as pl
from tqdm import tqdm from tqdm import tqdm
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
@@ -250,7 +263,7 @@ except Exception as e:
### 配置 ### 配置
- 对所有配置使用 **pydantic-settings** - 对所有配置使用 **pydantic-settings**
-`config/.env.local` 文件加载 -`config/.env.local` 文件加载
- 环境变量自动转换:`tushare_token` `TUSHARE_TOKEN` - 环境变量自动转换:`tushare_token` -> `TUSHARE_TOKEN`
- 对配置单例使用 `@lru_cache()` - 对配置单例使用 `@lru_cache()`
### 数据存储 ### 数据存储
@@ -291,7 +304,6 @@ except Exception as e:
- `pydantic>=2.0.0``pydantic-settings>=2.0.0` - 配置 - `pydantic>=2.0.0``pydantic-settings>=2.0.0` - 配置
- `tqdm>=4.65.0` - 进度条 - `tqdm>=4.65.0` - 进度条
- `lightgbm>=4.0.0` - 机器学习模型 - `lightgbm>=4.0.0` - 机器学习模型
- `catboost>=1.2.0` - 机器学习模型
- `pytest` - 测试(开发) - `pytest` - 测试(开发)
### 环境变量 ### 环境变量
@@ -315,124 +327,13 @@ uv run python -c "from src.data.sync import sync_all; sync_all(force_full=True)"
# 自定义线程数 # 自定义线程数
uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)" uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)"
# 同步财务数据
uv run python -c "from src.data.api_wrappers.financial_data import sync_financial; sync_financial()"
# 运行因子计算测试 # 运行因子计算测试
uv run pytest tests/test_factor_engine.py -v uv run pytest tests/test_factor_engine.py -v
``` ```
## 架构变更历史
### v2.2 (2026-03-01) - 因子框架 DSL 化重构
#### 因子计算框架重构
**变更**: 从基类继承方式迁移到 DSL 表达式方式
**原因**:
- 提供更直观的数学公式表达方式
- 支持因子表达式的组合和嵌套
- 更好的类型安全和编译期检查
**架构变化**:
- 新增 `dsl.py`: 表达式节点基类和运算符重载Symbol、FunctionNode等
- 新增 `api.py`: 常用符号close/open/volume等和函数ts_mean/cs_rank等
- 新增 `compiler.py`: AST 编译器,提取表达式依赖
- 新增 `translator.py`: 将 DSL 表达式翻译为 Polars 表达式
- 新增 `parser.py`: 字符串公式解析器FormulaParser支持从字符串解析 DSL 表达式
- 新增 `registry.py`: 函数注册表FunctionRegistry管理字符串函数名到 Python 函数的映射
- 新增 `exceptions.py`: 公式解析异常定义FormulaParseError、UnknownFunctionError等
- 重构 `engine/` 子模块:
- `factor_engine.py`: 因子引擎统一入口FactorEngine
- `data_spec.py`: 数据规格定义DataSpec, ExecutionPlan
- `data_router.py`: 数据路由器
- `planner.py`: 执行计划生成器ExecutionPlanner
- `compute_engine.py`: 计算引擎ComputeEngine
- 移除: `base.py``composite.py``data_loader.py`、根目录的 `data_spec.py`
- 移除: `factors/momentum/``factors/financial/` 子目录
**使用方式对比**:
```python
# 旧方式(基类继承)
class MA20Factor(TimeSeriesFactor):
name = "ma20"
data_specs = [DataSpec("daily", ["close"], 20)]
def compute(self, data):
return data.get_column("close").rolling_mean(20)
# 新方式DSL 表达式)
from src.factors.api import close, ts_mean
ma20 = ts_mean(close, 20) # 直接编写数学表达式
engine = FactorEngine()
engine.register("ma20", ma20)
result = engine.compute(["ma20"], "20240101", "20240131")
# 字符串公式解析Phase 1 新增)
from src.factors import FormulaParser, FunctionRegistry
parser = FormulaParser(FunctionRegistry())
expr = parser.parse("ts_mean(close, 20) / close") # 从字符串解析
```
#### data 模块补充完善
**新增文件**:
- `api_wrappers/base_sync.py`: 数据同步基础抽象类BaseDataSync、StockBasedSync、DateBasedSync
- `data_router.py`: 数据路由器(已集成到 factors/engine.py 中的 DataRouter
- `utils.py`: 日期工具函数get_today_date、get_next_date、is_quarter_end等
**影响**: 数据同步逻辑更加规范化,支持按股票和按日期两种同步模式
### v2.1 (2026-02-28) - 同步模块规范更新
**变更**: 明确 `sync.py` 只包含每日更新的数据同步
**原因**: 区分高频(每日)和低频(季度/年度)数据,避免不必要的 API 调用
**规范**:
- `sync.py` / `sync_all_data()`: **仅包含每日更新的数据**
- 日线数据 (`api_daily`)
- Pro Bar 数据 (`api_pro_bar`)
- 交易日历 (`api_trade_cal`)
- 股票基本信息 (`api_stock_basic`)
- 历史股票列表 (`api_bak_basic`)
- **不应放入 `sync.py` 的季度/低频数据**:
- 财务数据 (`financial_data/` 目录): 利润表、资产负债表、现金流量表等
- 名称变更 (`api_namechange`): 已移除自动同步,建议手动定期同步
- **季度数据同步方式**:
```python
# 财务数据单独同步(不在 sync_all_data 中)
from src.data.api_wrappers.financial_data.api_financial_sync import sync_financial
sync_financial() # 增量同步利润表
# 名称变更手动同步
from src.data.api_wrappers import sync_namechange
sync_namechange(force=True)
```
### v2.0 (2026-02-23) - 重要更新
#### 存储层重构
**变更**: 从 HDF5 迁移到 DuckDB
**原因**: DuckDB 提供更好的查询性能、SQL 下推能力、并发支持
**影响**: 所有数据表现在使用 DuckDB 存储,旧 HDF5 文件可手动迁移
#### Sync 类迁移
**变更**: `DataSync` 类从 `sync.py` 迁移到 `api_daily.py`
**原因**: 实现代码职责分离,每个 API 文件包含自己的同步逻辑
**影响**:
- `sync.py` 保留为调度中心
- `api_daily.py` 包含 `DailySync` 类和 `sync_daily` 函数
#### 新增模块
**pipeline 模块**: 机器学习流水线组件(处理器、模型、划分策略)
**training 模块**: 训练入口程序
**factors/momentum**: 动量因子MA、收益率排名
**factors/financial**: 财务因子框架
**data/utils.py**: 日期工具函数集中管理
#### 新增 API 接口
`api_namechange.py`: 股票曾用名接口(手动同步)
`api_bak_basic.py`: 历史股票列表接口
#### 工具函数统一
`get_today_date()`、`get_next_date()`、`DEFAULT_START_DATE` 等函数统一在 `src/data/utils.py` 中管理
其他模块应从 `utils.py` 导入这些函数,避免重复定义
其他模块应从 `utils.py` 导入这些函数,避免重复定义
## Factors 框架设计说明 ## Factors 框架设计说明
### 架构层次 ### 架构层次
@@ -468,58 +369,28 @@ Engine (engine/) <- 执行引擎
数据层 (data_router.py + DuckDB) <- 数据获取和存储 数据层 (data_router.py + DuckDB) <- 数据获取和存储
``` ```
### 使用方式 ### FactorEngine 核心 API
#### 1. 基础表达式DSL 方式)
```python
from src.factors.api import close, open, ts_mean, cs_rank
# 定义因子表达式(惰性计算)
ma20 = ts_mean(close, 20) # 20日移动平均
price_rank = cs_rank(close) # 收盘价截面排名
# 组合运算
alpha = ma20 * 0.6 + price_rank * 0.4
```
#### 2. 字符串公式解析Phase 1 新增)
```python
from src.factors import FormulaParser, FunctionRegistry
# 创建解析器(自动加载 api.py 中的所有函数)
parser = FormulaParser(FunctionRegistry())
# 从字符串解析公式
expr = parser.parse("ts_mean(close, 20) / close")
complex_expr = parser.parse("cs_rank(ts_mean(close, 5) - ts_mean(close, 20))")
# 支持完整运算符和函数调用
expr2 = parser.parse("(close - open) / open * 100") # 涨跌幅
expr3 = parser.parse("ts_corr(close, volume, 20)") # 量价相关性
```
#### 3. 注册和执行
```python ```python
from src.factors import FactorEngine from src.factors import FactorEngine
# 初始化引擎
engine = FactorEngine() engine = FactorEngine()
# 注册 DSL 表达式 # 方式1: 使用 DSL 表达式
engine.register("ma20", ma20) from src.factors.api import close, ts_mean, cs_rank
engine.register("price_rank", price_rank) engine.register("ma20", ts_mean(close, 20))
engine.register("price_rank", cs_rank(close))
# 或注册字符串解析的表达式 # 方式2: 使用字符串表达式(推荐)
engine.register("alpha", parser.parse("ma20 * 0.6 + price_rank * 0.4")) engine.add_factor("ma20", "ts_mean(close, 20)")
engine.add_factor("alpha", "cs_rank(ts_mean(close, 5) - ts_mean(close, 20))")
# 执行计算 # 计算因子
result = engine.compute( result = engine.compute(["ma20", "price_rank"], "20240101", "20240131")
factor_names=["ma20", "price_rank"],
start_date="20240101", # 查看已注册因子
end_date="20240131", print(engine.list_registered())
)
``` ```
### 支持的函数 ### 支持的函数
@@ -584,7 +455,6 @@ expr4 = ts_mean(cs_rank(close), 20) # 排名后的20日平滑
- `EmptyExpressionError` - 空表达式错误 - `EmptyExpressionError` - 空表达式错误
- `DuplicateFunctionError` - 函数重复注册错误 - `DuplicateFunctionError` - 函数重复注册错误
示例:
```python ```python
from src.factors import FormulaParser, FunctionRegistry, UnknownFunctionError from src.factors import FormulaParser, FunctionRegistry, UnknownFunctionError
@@ -596,7 +466,183 @@ except UnknownFunctionError as e:
``` ```
## AI 行为准则 ## Training 模块设计说明
### 架构概述
Training 模块位于 `src/training/` 目录,负责从因子数据到模型训练、预测的完整流程。采用组件化设计,支持数据处理器、模型、过滤器、股票池管理器的灵活组合。
```
src/training/
├── core/
│ ├── trainer.py # Trainer 主类
│ └── stock_pool_manager.py # 股票池管理器
├── components/
│ ├── base.py # BaseModel、BaseProcessor 抽象基类
│ ├── splitters.py # DateSplitter 日期划分器
│ ├── filters.py # STFilter 等过滤器
│ ├── models/
│ │ └── lightgbm.py # LightGBMModel
│ └── processors/
│ └── transforms.py # 数据处理器实现
├── config/
│ └── config.py # TrainingConfig
└── registry.py # 组件注册中心
```
### Trainer 核心流程
```python
from src.training import Trainer, DateSplitter, StockPoolManager
from src.training.components.models import LightGBMModel
from src.training.components.processors import Winsorizer, StandardScaler
from src.training.components.filters import STFilter
import polars as pl
# 1. 创建模型
model = LightGBMModel(params={
"objective": "regression",
"metric": "mae",
"num_leaves": 20,
"learning_rate": 0.01,
"n_estimators": 1000,
})
# 2. 创建数据划分器(正确的 train/val/test 三分法)
splitter = DateSplitter(
train_start="20200101",
train_end="20231231",
val_start="20240101",
val_end="20241231",
test_start="20250101",
test_end="20261231",
)
# 3. 创建数据处理器
processors = [
NullFiller(strategy="mean"),
Winsorizer(lower=0.01, upper=0.99),
StandardScaler(exclude_cols=["ts_code", "trade_date", "target"]),
]
# 4. 创建股票池筛选函数
def stock_pool_filter(df: pl.DataFrame) -> pl.Series:
"""筛选小市值股票,排除创业板/科创板/北交所"""
code_filter = (
~df["ts_code"].str.starts_with("300") & # 排除创业板
~df["ts_code"].str.starts_with("688") & # 排除科创板
~df["ts_code"].str.starts_with("8") & # 排除北交所
~df["ts_code"].str.starts_with("9") &
~df["ts_code"].str.starts_with("4")
)
valid_df = df.filter(code_filter)
n = min(1000, len(valid_df))
small_cap_codes = valid_df.sort("total_mv").head(n)["ts_code"]
return df["ts_code"].is_in(small_cap_codes)
pool_manager = StockPoolManager(
filter_func=stock_pool_filter,
required_columns=["total_mv"],
data_router=engine.router,
)
# 5. 创建 ST 过滤器
st_filter = STFilter(data_router=engine.router)
# 6. 创建训练器
trainer = Trainer(
model=model,
pool_manager=pool_manager,
processors=processors,
filters=[st_filter],
splitter=splitter,
target_col="future_return_5",
feature_cols=["ma_5", "ma_20", "volume_ratio", "roe"],
)
# 7. 执行训练
trainer.train(data)
# 8. 获取结果
results = trainer.get_results()
```
### 数据处理器
**NullFiller** - 缺失值填充:
```python
from src.training.components.processors import NullFiller
# 使用 0 填充
filler = NullFiller(strategy="zero")
# 使用均值填充(每天独立计算截面均值)
filler = NullFiller(strategy="mean", by_date=True)
# 使用指定值填充
filler = NullFiller(strategy="value", fill_value=-999)
```
**Winsorizer** - 缩尾处理:
```python
from src.training.components.processors import Winsorizer
# 全局缩尾(默认)
winsorizer = Winsorizer(lower=0.01, upper=0.99, by_date=False)
# 每天独立缩尾
winsorizer = Winsorizer(lower=0.01, upper=0.99, by_date=True)
```
**StandardScaler** - 标准化:
```python
from src.training.components.processors import StandardScaler
# 全局标准化(学习训练集的均值和标准差)
scaler = StandardScaler(exclude_cols=["ts_code", "trade_date", "target"])
```
**CrossSectionalStandardScaler** - 截面标准化:
```python
from src.training.components.processors import CrossSectionalStandardScaler
# 每天独立标准化(不需要 fit
cs_scaler = CrossSectionalStandardScaler(
exclude_cols=["ts_code", "trade_date", "target"],
date_col="trade_date",
)
```
### 组件注册机制
```python
from src.training.registry import register_model, register_processor
from src.training.components.base import BaseModel, BaseProcessor
# 注册自定义模型
@register_model("custom_model")
class CustomModel(BaseModel):
name = "custom_model"
def fit(self, X, y):
# 训练逻辑
return self
def predict(self, X):
# 预测逻辑
return predictions
# 注册自定义处理器
@register_processor("custom_processor")
class CustomProcessor(BaseProcessor):
name = "custom_processor"
def transform(self, X):
# 转换逻辑
return X
```
## AI 行为准则 ## AI 行为准则
### LSP 检测报错处理 ### LSP 检测报错处理
@@ -649,3 +695,4 @@ LSP 报错Syntax error on line 45
4. **检查方法** 4. **检查方法**
- 使用正则表达式搜索 emoji`[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F1E0-\U0001F1FF\u2600-\u26FF\u2700-\u27BF]` - 使用正则表达式搜索 emoji`[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F1E0-\U0001F1FF\u2600-\u26FF\u2700-\u27BF]`
- 提交前自查,确保无 emoji 混入代码 - 提交前自查,确保无 emoji 混入代码

File diff suppressed because one or more lines are too long

View File

@@ -78,6 +78,7 @@ class PolarsTranslator:
self.register_handler("cs_neutral", self._handle_cs_neutral) self.register_handler("cs_neutral", self._handle_cs_neutral)
# 元素级数学函数 (element_wise) # 元素级数学函数 (element_wise)
self.register_handler("abs", self._handle_abs)
self.register_handler("log", self._handle_log) self.register_handler("log", self._handle_log)
self.register_handler("exp", self._handle_exp) self.register_handler("exp", self._handle_exp)
self.register_handler("sqrt", self._handle_sqrt) self.register_handler("sqrt", self._handle_sqrt)
@@ -297,23 +298,23 @@ class PolarsTranslator:
@time_series @time_series
def _handle_ts_corr(self, node: FunctionNode) -> pl.Expr: def _handle_ts_corr(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_corr(x, y, window) -> rolling_corr(y, window)。""" """处理 ts_corr(x, y, window) -> rolling_corr(x, y, window_size)。"""
if len(node.args) != 3: if len(node.args) != 3:
raise ValueError("ts_corr 需要 3 个参数: (x, y, window)") raise ValueError("ts_corr 需要 3 个参数: (x, y, window)")
x = self.translate(node.args[0]) x = self.translate(node.args[0])
y = self.translate(node.args[1]) y = self.translate(node.args[1])
window = self._extract_window(node.args[2]) window = self._extract_window(node.args[2])
return x.rolling_corr(y, window_size=window) return pl.rolling_corr(x, y, window_size=window)
@time_series @time_series
def _handle_ts_cov(self, node: FunctionNode) -> pl.Expr: def _handle_ts_cov(self, node: FunctionNode) -> pl.Expr:
"""处理 ts_cov(x, y, window) -> rolling_cov(y, window)。""" """处理 ts_cov(x, y, window) -> rolling_cov(x, y, window_size)。"""
if len(node.args) != 3: if len(node.args) != 3:
raise ValueError("ts_cov 需要 3 个参数: (x, y, window)") raise ValueError("ts_cov 需要 3 个参数: (x, y, window)")
x = self.translate(node.args[0]) x = self.translate(node.args[0])
y = self.translate(node.args[1]) y = self.translate(node.args[1])
window = self._extract_window(node.args[2]) window = self._extract_window(node.args[2])
return x.rolling_cov(y, window_size=window) return pl.rolling_cov(x, y, window_size=window)
@time_series @time_series
def _handle_ts_var(self, node: FunctionNode) -> pl.Expr: def _handle_ts_var(self, node: FunctionNode) -> pl.Expr:
@@ -494,6 +495,14 @@ class PolarsTranslator:
expr = self.translate(node.args[0]) expr = self.translate(node.args[0])
return expr.log() return expr.log()
@element_wise
def _handle_abs(self, node: FunctionNode) -> pl.Expr:
"""处理 abs(expr) -> 绝对值。"""
if len(node.args) != 1:
raise ValueError("abs 需要 1 个参数: (expr)")
expr = self.translate(node.args[0])
return expr.abs()
@element_wise @element_wise
def _handle_exp(self, node: FunctionNode) -> pl.Expr: def _handle_exp(self, node: FunctionNode) -> pl.Expr:
"""处理 exp(expr) -> 指数函数。""" """处理 exp(expr) -> 指数函数。"""