fix(factors): 修复 ts_corr/ts_cov 实现并添加 abs 函数支持

- 修复 ts_corr 和 ts_cov 使用 pl.rolling_corr/pl.rolling_cov 模块级函数
- 添加 abs 函数处理器到 translator
- 扩展 notebook 中的因子定义(24 -> 49 个)
- 更新 AGENTS.md 文档结构和 Training 模块说明
This commit is contained in:
2026-03-09 23:37:20 +08:00
parent 88fa848b96
commit f1811815e7
3 changed files with 675 additions and 482 deletions

457
AGENTS.md
View File

@@ -6,9 +6,9 @@ A股量化投资框架 - Python 项目,用于量化股票投资分析。
**⚠️ 强制要求:所有沟通和思考过程必须使用中文。**
所有与 AI Agent 的交流必须使用中文
代码中的注释和文档字符串使用中文
禁止使用英文进行思考或沟通
- 所有与 AI Agent 的交流必须使用中文
- 代码中的注释和文档字符串使用中文
- 禁止使用英文进行思考或沟通
## 构建/检查/测试命令
@@ -44,7 +44,7 @@ uv run pytest --cov=src --cov-report=term-missing
```bash
# 禁止直接使用 python
python -c "..." # 禁止!
python -c "" # 禁止!
python script.py # 禁止!
python -m pytest # 禁止!
python -m pip install # 禁止!
@@ -59,7 +59,7 @@ pip list # 禁止!
```bash
# 运行 Python 代码
uv run python -c "..." # ✅ 正确
uv run python -c "" # ✅ 正确
uv run python script.py # ✅ 正确
# 安装依赖
@@ -82,65 +82,79 @@ ProStock/
│ │
│ ├── data/ # 数据获取与存储
│ │ ├── api_wrappers/ # Tushare API 封装
│ │ │ ├── base_sync.py # 同步基础抽象类(BaseDataSync/StockBasedSync/DateBasedSync)
│ │ │ ├── api_daily.py # 日线数据接口(DailySync)
│ │ │ ├── api_pro_bar.py # Pro Bar 数据接口(ProBarSync)
│ │ │ ├── base_sync.py # 同步基础抽象类
│ │ │ ├── api_daily.py # 日线数据接口
│ │ │ ├── api_pro_bar.py # Pro Bar 数据接口
│ │ │ ├── api_stock_basic.py # 股票基础信息接口
│ │ │ ├── api_trade_cal.py # 交易日历接口
│ │ │ ├── api_bak_basic.py # 历史股票列表接口(BakBasicSync)
│ │ │ ├── api_bak_basic.py # 历史股票列表接口
│ │ │ ├── api_namechange.py # 股票名称变更接口
│ │ │ ├── api_stock_st.py # ST股票信息接口
│ │ │ ├── api_daily_basic.py # 每日指标接口
│ │ │ ├── api_stk_limit.py # 涨跌停价格接口
│ │ │ ├── financial_data/ # 财务数据接口
│ │ │ │ ├── api_income.py # 利润表接口
│ │ │ │ ── api_financial_sync.py # 财务数据同步
│ │ │ │ ── api_balance.py # 资产负债表接口
│ │ │ │ ├── api_cashflow.py # 现金流量表接口
│ │ │ │ ├── api_fina_indicator.py # 财务指标接口
│ │ │ │ └── api_financial_sync.py # 财务数据同步调度中心
│ │ │ └── __init__.py
│ │ ├── __init__.py
│ │ ├── client.py # Tushare API 客户端(带速率限制)
│ │ ├── config.py # 数据模块配置
│ │ ├── data_router.py # 数据路由器factors/engine 专用)
│ │ ├── db_inspector.py # 数据库信息查看工具
│ │ ├── db_manager.py # DuckDB 表管理和同步
│ │ ├── rate_limiter.py # 令牌桶速率限制器
│ │ ├── storage.py # 数据存储核心
│ │ ├── db_manager.py # DuckDB 表管理和同步
│ │ ├── db_inspector.py # 数据库信息查看工具
│ │ ├── sync.py # 数据同步调度中心
│ │ ── utils.py # 数据模块工具函数
│ │ ── sync_registry.py # 同步器注册表
│ │ ├── rate_limiter.py # 令牌桶速率限制器
│ │ ├── catalog.py # 数据目录管理
│ │ ├── config.py # 数据模块配置
│ │ ├── utils.py # 数据模块工具函数
│ │ └── financial_loader.py # 财务数据加载器
│ │
│ ├── factors/ # 因子计算框架DSL 表达式驱动)
│ │ ├── engine/ # 执行引擎子模块
│ │ │ ├── __init__.py # 导出引擎组件
│ │ │ ├── data_spec.py # 数据规格定义(DataSpec, ExecutionPlan)
│ │ │ ├── data_spec.py # 数据规格定义
│ │ │ ├── data_router.py # 数据路由器
│ │ │ ├── planner.py # 执行计划生成器(ExecutionPlanner)
│ │ │ ├── compute_engine.py # 计算引擎(ComputeEngine)
│ │ │ ── factor_engine.py # 因子引擎统一入口(FactorEngine)
│ │ │ ├── planner.py # 执行计划生成器
│ │ │ ├── compute_engine.py # 计算引擎
│ │ │ ── schema_cache.py # 表结构缓存
│ │ │ └── factor_engine.py # 因子引擎统一入口
│ │ ├── __init__.py # 导出所有公开 API
│ │ ├── dsl.py # DSL 表达式层 - 节点定义和运算符重载
│ │ ├── api.py # API 层 - 常用符号(close/open等)和函数(ts_mean/cs_rank等)
│ │ ├── api.py # API 层 - 常用符号和函数
│ │ ├── compiler.py # AST 编译器 - 依赖提取
│ │ ├── translator.py # Polars 表达式翻译器
│ │ ├── parser.py # 字符串公式解析器(FormulaParser)
│ │ ├── registry.py # 函数注册表(FunctionRegistry)
│ │ ── exceptions.py # 异常定义(FormulaParseError等)
│ │ ├── parser.py # 字符串公式解析器
│ │ ├── registry.py # 函数注册表
│ │ ── decorators.py # 装饰器工具
│ │ └── exceptions.py # 异常定义
│ │
│ ├── pipeline/ # 模型训练管道
│ │ ├── __init__.py
│ │ ├── pipeline.py # 处理流水线(ProcessingPipeline)
│ │ ├── registry.py # 插件注册中心(PluginRegistry)
│ │ ├── core/ # 核心抽象
│ ├── training/ # 训练模块
│ │ ├── core/ # 训练核心组件
│ │ │ ├── __init__.py
│ │ │ ├── base.py # 基类定义(BaseProcessor/BaseModel/BaseSplitter等)
│ │ │ └── splitter.py # 时间序列划分策略
│ │ ├── models/ # 模型实现
│ │ │ ├── trainer.py # 训练器主类
│ │ │ └── stock_pool_manager.py # 股票池管理器
│ │ ├── components/ # 组件
│ │ │ ├── base.py # 基础抽象类
│ │ │ ├── splitters.py # 数据划分器
│ │ │ ├── selectors.py # 股票选择器
│ │ │ ├── filters.py # 数据过滤器
│ │ │ ├── models/ # 模型实现
│ │ │ │ ├── __init__.py
│ │ │ │ └── lightgbm.py # LightGBM 模型
│ │ │ └── processors/ # 数据处理器
│ │ │ ├── __init__.py
│ │ │ └── transforms.py # 变换处理器
│ │ ├── config/ # 配置
│ │ │ ├── __init__.py
│ │ │ └── models.py # LightGBM、CatBoost 等
│ │ ── processors/ # 数据处理器
│ │ ── __init__.py
│ │ └── processors.py # 标准化、缩尾、中性化等
│ │ │ └── config.py # 训练配置
│ │ ── registry.py # 组件注册中心
│ │ ── __init__.py # 导出所有组件
│ │
│ └── training/ # 训练入口
── __init__.py
│ ├── pipeline.py # 训练流程配置
│ └── output/ # 训练输出
│ └── top_stocks.tsv # 推荐股票结果
│ └── experiment/ # 实验代码
── regression.ipynb # 完整训练流程示例
├── tests/ # 测试文件
│ ├── test_sync.py
@@ -148,8 +162,6 @@ ProStock/
│ ├── test_factor_engine.py
│ ├── test_factor_integration.py
│ ├── test_pro_bar.py
│ ├── test_601117_factors.py
│ ├── test_two_stocks_string_factors.py
│ ├── test_db_manager.py
│ ├── test_daily_storage.py
│ ├── test_tushare_api.py
@@ -183,6 +195,7 @@ import threading
# 第三方包
import pandas as pd
import numpy as np
import polars as pl
from tqdm import tqdm
from pydantic_settings import BaseSettings
@@ -250,7 +263,7 @@ except Exception as e:
### 配置
- 对所有配置使用 **pydantic-settings**
-`config/.env.local` 文件加载
- 环境变量自动转换:`tushare_token` `TUSHARE_TOKEN`
- 环境变量自动转换:`tushare_token` -> `TUSHARE_TOKEN`
- 对配置单例使用 `@lru_cache()`
### 数据存储
@@ -291,7 +304,6 @@ except Exception as e:
- `pydantic>=2.0.0``pydantic-settings>=2.0.0` - 配置
- `tqdm>=4.65.0` - 进度条
- `lightgbm>=4.0.0` - 机器学习模型
- `catboost>=1.2.0` - 机器学习模型
- `pytest` - 测试(开发)
### 环境变量
@@ -315,124 +327,13 @@ uv run python -c "from src.data.sync import sync_all; sync_all(force_full=True)"
# 自定义线程数
uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)"
# 同步财务数据
uv run python -c "from src.data.api_wrappers.financial_data import sync_financial; sync_financial()"
# 运行因子计算测试
uv run pytest tests/test_factor_engine.py -v
```
## 架构变更历史
### v2.2 (2026-03-01) - 因子框架 DSL 化重构
#### 因子计算框架重构
**变更**: 从基类继承方式迁移到 DSL 表达式方式
**原因**:
- 提供更直观的数学公式表达方式
- 支持因子表达式的组合和嵌套
- 更好的类型安全和编译期检查
**架构变化**:
- 新增 `dsl.py`: 表达式节点基类和运算符重载Symbol、FunctionNode等
- 新增 `api.py`: 常用符号close/open/volume等和函数ts_mean/cs_rank等
- 新增 `compiler.py`: AST 编译器,提取表达式依赖
- 新增 `translator.py`: 将 DSL 表达式翻译为 Polars 表达式
- 新增 `parser.py`: 字符串公式解析器FormulaParser支持从字符串解析 DSL 表达式
- 新增 `registry.py`: 函数注册表FunctionRegistry管理字符串函数名到 Python 函数的映射
- 新增 `exceptions.py`: 公式解析异常定义FormulaParseError、UnknownFunctionError等
- 重构 `engine/` 子模块:
- `factor_engine.py`: 因子引擎统一入口FactorEngine
- `data_spec.py`: 数据规格定义DataSpec, ExecutionPlan
- `data_router.py`: 数据路由器
- `planner.py`: 执行计划生成器ExecutionPlanner
- `compute_engine.py`: 计算引擎ComputeEngine
- 移除: `base.py``composite.py``data_loader.py`、根目录的 `data_spec.py`
- 移除: `factors/momentum/``factors/financial/` 子目录
**使用方式对比**:
```python
# 旧方式(基类继承)
class MA20Factor(TimeSeriesFactor):
name = "ma20"
data_specs = [DataSpec("daily", ["close"], 20)]
def compute(self, data):
return data.get_column("close").rolling_mean(20)
# 新方式DSL 表达式)
from src.factors.api import close, ts_mean
ma20 = ts_mean(close, 20) # 直接编写数学表达式
engine = FactorEngine()
engine.register("ma20", ma20)
result = engine.compute(["ma20"], "20240101", "20240131")
# 字符串公式解析Phase 1 新增)
from src.factors import FormulaParser, FunctionRegistry
parser = FormulaParser(FunctionRegistry())
expr = parser.parse("ts_mean(close, 20) / close") # 从字符串解析
```
#### data 模块补充完善
**新增文件**:
- `api_wrappers/base_sync.py`: 数据同步基础抽象类BaseDataSync、StockBasedSync、DateBasedSync
- `data_router.py`: 数据路由器(已集成到 factors/engine.py 中的 DataRouter
- `utils.py`: 日期工具函数get_today_date、get_next_date、is_quarter_end等
**影响**: 数据同步逻辑更加规范化,支持按股票和按日期两种同步模式
### v2.1 (2026-02-28) - 同步模块规范更新
**变更**: 明确 `sync.py` 只包含每日更新的数据同步
**原因**: 区分高频(每日)和低频(季度/年度)数据,避免不必要的 API 调用
**规范**:
- `sync.py` / `sync_all_data()`: **仅包含每日更新的数据**
- 日线数据 (`api_daily`)
- Pro Bar 数据 (`api_pro_bar`)
- 交易日历 (`api_trade_cal`)
- 股票基本信息 (`api_stock_basic`)
- 历史股票列表 (`api_bak_basic`)
- **不应放入 `sync.py` 的季度/低频数据**:
- 财务数据 (`financial_data/` 目录): 利润表、资产负债表、现金流量表等
- 名称变更 (`api_namechange`): 已移除自动同步,建议手动定期同步
- **季度数据同步方式**:
```python
# 财务数据单独同步(不在 sync_all_data 中)
from src.data.api_wrappers.financial_data.api_financial_sync import sync_financial
sync_financial() # 增量同步利润表
# 名称变更手动同步
from src.data.api_wrappers import sync_namechange
sync_namechange(force=True)
```
### v2.0 (2026-02-23) - 重要更新
#### 存储层重构
**变更**: 从 HDF5 迁移到 DuckDB
**原因**: DuckDB 提供更好的查询性能、SQL 下推能力、并发支持
**影响**: 所有数据表现在使用 DuckDB 存储,旧 HDF5 文件可手动迁移
#### Sync 类迁移
**变更**: `DataSync` 类从 `sync.py` 迁移到 `api_daily.py`
**原因**: 实现代码职责分离,每个 API 文件包含自己的同步逻辑
**影响**:
- `sync.py` 保留为调度中心
- `api_daily.py` 包含 `DailySync` 类和 `sync_daily` 函数
#### 新增模块
**pipeline 模块**: 机器学习流水线组件(处理器、模型、划分策略)
**training 模块**: 训练入口程序
**factors/momentum**: 动量因子MA、收益率排名
**factors/financial**: 财务因子框架
**data/utils.py**: 日期工具函数集中管理
#### 新增 API 接口
`api_namechange.py`: 股票曾用名接口(手动同步)
`api_bak_basic.py`: 历史股票列表接口
#### 工具函数统一
`get_today_date()`、`get_next_date()`、`DEFAULT_START_DATE` 等函数统一在 `src/data/utils.py` 中管理
其他模块应从 `utils.py` 导入这些函数,避免重复定义
其他模块应从 `utils.py` 导入这些函数,避免重复定义
## Factors 框架设计说明
### 架构层次
@@ -468,58 +369,28 @@ Engine (engine/) <- 执行引擎
数据层 (data_router.py + DuckDB) <- 数据获取和存储
```
### 使用方式
#### 1. 基础表达式DSL 方式)
```python
from src.factors.api import close, open, ts_mean, cs_rank
# 定义因子表达式(惰性计算)
ma20 = ts_mean(close, 20) # 20日移动平均
price_rank = cs_rank(close) # 收盘价截面排名
# 组合运算
alpha = ma20 * 0.6 + price_rank * 0.4
```
#### 2. 字符串公式解析Phase 1 新增)
```python
from src.factors import FormulaParser, FunctionRegistry
# 创建解析器(自动加载 api.py 中的所有函数)
parser = FormulaParser(FunctionRegistry())
# 从字符串解析公式
expr = parser.parse("ts_mean(close, 20) / close")
complex_expr = parser.parse("cs_rank(ts_mean(close, 5) - ts_mean(close, 20))")
# 支持完整运算符和函数调用
expr2 = parser.parse("(close - open) / open * 100") # 涨跌幅
expr3 = parser.parse("ts_corr(close, volume, 20)") # 量价相关性
```
#### 3. 注册和执行
### FactorEngine 核心 API
```python
from src.factors import FactorEngine
# 初始化引擎
engine = FactorEngine()
# 注册 DSL 表达式
engine.register("ma20", ma20)
engine.register("price_rank", price_rank)
# 方式1: 使用 DSL 表达式
from src.factors.api import close, ts_mean, cs_rank
engine.register("ma20", ts_mean(close, 20))
engine.register("price_rank", cs_rank(close))
# 或注册字符串解析的表达式
engine.register("alpha", parser.parse("ma20 * 0.6 + price_rank * 0.4"))
# 方式2: 使用字符串表达式(推荐)
engine.add_factor("ma20", "ts_mean(close, 20)")
engine.add_factor("alpha", "cs_rank(ts_mean(close, 5) - ts_mean(close, 20))")
# 执行计算
result = engine.compute(
factor_names=["ma20", "price_rank"],
start_date="20240101",
end_date="20240131",
)
# 计算因子
result = engine.compute(["ma20", "price_rank"], "20240101", "20240131")
# 查看已注册因子
print(engine.list_registered())
```
### 支持的函数
@@ -584,7 +455,6 @@ expr4 = ts_mean(cs_rank(close), 20) # 排名后的20日平滑
- `EmptyExpressionError` - 空表达式错误
- `DuplicateFunctionError` - 函数重复注册错误
示例:
```python
from src.factors import FormulaParser, FunctionRegistry, UnknownFunctionError
@@ -596,7 +466,183 @@ except UnknownFunctionError as e:
```
## AI 行为准则
## Training 模块设计说明
### 架构概述
Training 模块位于 `src/training/` 目录,负责从因子数据到模型训练、预测的完整流程。采用组件化设计,支持数据处理器、模型、过滤器、股票池管理器的灵活组合。
```
src/training/
├── core/
│ ├── trainer.py # Trainer 主类
│ └── stock_pool_manager.py # 股票池管理器
├── components/
│ ├── base.py # BaseModel、BaseProcessor 抽象基类
│ ├── splitters.py # DateSplitter 日期划分器
│ ├── filters.py # STFilter 等过滤器
│ ├── models/
│ │ └── lightgbm.py # LightGBMModel
│ └── processors/
│ └── transforms.py # 数据处理器实现
├── config/
│ └── config.py # TrainingConfig
└── registry.py # 组件注册中心
```
### Trainer 核心流程
```python
from src.training import Trainer, DateSplitter, StockPoolManager
from src.training.components.models import LightGBMModel
from src.training.components.processors import Winsorizer, StandardScaler
from src.training.components.filters import STFilter
import polars as pl
# 1. 创建模型
model = LightGBMModel(params={
"objective": "regression",
"metric": "mae",
"num_leaves": 20,
"learning_rate": 0.01,
"n_estimators": 1000,
})
# 2. 创建数据划分器(正确的 train/val/test 三分法)
splitter = DateSplitter(
train_start="20200101",
train_end="20231231",
val_start="20240101",
val_end="20241231",
test_start="20250101",
test_end="20261231",
)
# 3. 创建数据处理器
processors = [
NullFiller(strategy="mean"),
Winsorizer(lower=0.01, upper=0.99),
StandardScaler(exclude_cols=["ts_code", "trade_date", "target"]),
]
# 4. 创建股票池筛选函数
def stock_pool_filter(df: pl.DataFrame) -> pl.Series:
"""筛选小市值股票,排除创业板/科创板/北交所"""
code_filter = (
~df["ts_code"].str.starts_with("300") & # 排除创业板
~df["ts_code"].str.starts_with("688") & # 排除科创板
~df["ts_code"].str.starts_with("8") & # 排除北交所
~df["ts_code"].str.starts_with("9") &
~df["ts_code"].str.starts_with("4")
)
valid_df = df.filter(code_filter)
n = min(1000, len(valid_df))
small_cap_codes = valid_df.sort("total_mv").head(n)["ts_code"]
return df["ts_code"].is_in(small_cap_codes)
pool_manager = StockPoolManager(
filter_func=stock_pool_filter,
required_columns=["total_mv"],
data_router=engine.router,
)
# 5. 创建 ST 过滤器
st_filter = STFilter(data_router=engine.router)
# 6. 创建训练器
trainer = Trainer(
model=model,
pool_manager=pool_manager,
processors=processors,
filters=[st_filter],
splitter=splitter,
target_col="future_return_5",
feature_cols=["ma_5", "ma_20", "volume_ratio", "roe"],
)
# 7. 执行训练
trainer.train(data)
# 8. 获取结果
results = trainer.get_results()
```
### 数据处理器
**NullFiller** - 缺失值填充:
```python
from src.training.components.processors import NullFiller
# 使用 0 填充
filler = NullFiller(strategy="zero")
# 使用均值填充(每天独立计算截面均值)
filler = NullFiller(strategy="mean", by_date=True)
# 使用指定值填充
filler = NullFiller(strategy="value", fill_value=-999)
```
**Winsorizer** - 缩尾处理:
```python
from src.training.components.processors import Winsorizer
# 全局缩尾(默认)
winsorizer = Winsorizer(lower=0.01, upper=0.99, by_date=False)
# 每天独立缩尾
winsorizer = Winsorizer(lower=0.01, upper=0.99, by_date=True)
```
**StandardScaler** - 标准化:
```python
from src.training.components.processors import StandardScaler
# 全局标准化(学习训练集的均值和标准差)
scaler = StandardScaler(exclude_cols=["ts_code", "trade_date", "target"])
```
**CrossSectionalStandardScaler** - 截面标准化:
```python
from src.training.components.processors import CrossSectionalStandardScaler
# 每天独立标准化(不需要 fit
cs_scaler = CrossSectionalStandardScaler(
exclude_cols=["ts_code", "trade_date", "target"],
date_col="trade_date",
)
```
### 组件注册机制
```python
from src.training.registry import register_model, register_processor
from src.training.components.base import BaseModel, BaseProcessor
# 注册自定义模型
@register_model("custom_model")
class CustomModel(BaseModel):
name = "custom_model"
def fit(self, X, y):
# 训练逻辑
return self
def predict(self, X):
# 预测逻辑
return predictions
# 注册自定义处理器
@register_processor("custom_processor")
class CustomProcessor(BaseProcessor):
name = "custom_processor"
def transform(self, X):
# 转换逻辑
return X
```
## AI 行为准则
### LSP 检测报错处理
@@ -649,3 +695,4 @@ LSP 报错Syntax error on line 45
4. **检查方法**
- 使用正则表达式搜索 emoji`[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F1E0-\U0001F1FF\u2600-\u26FF\u2700-\u27BF]`
- 提交前自查,确保无 emoji 混入代码