docs: 全面更新 AGENTS.md 文档
This commit is contained in:
237
AGENTS.md
237
AGENTS.md
@@ -123,6 +123,11 @@ ProStock/
|
||||
│ │ │ ├── compute_engine.py # 计算引擎
|
||||
│ │ │ ├── schema_cache.py # 表结构缓存
|
||||
│ │ │ └── factor_engine.py # 因子引擎统一入口
|
||||
│ │ ├── metadata/ # 因子元数据管理
|
||||
│ │ │ ├── __init__.py # 导出元数据组件
|
||||
│ │ │ ├── manager.py # 因子管理器主类
|
||||
│ │ │ ├── validator.py # 字段校验器
|
||||
│ │ │ └── exceptions.py # 元数据异常定义
|
||||
│ │ ├── __init__.py # 导出所有公开 API
|
||||
│ │ ├── dsl.py # DSL 表达式层 - 节点定义和运算符重载
|
||||
│ │ ├── api.py # API 层 - 常用符号和函数
|
||||
@@ -145,7 +150,8 @@ ProStock/
|
||||
│ │ │ ├── filters.py # 数据过滤器
|
||||
│ │ │ ├── models/ # 模型实现
|
||||
│ │ │ │ ├── __init__.py
|
||||
│ │ │ │ └── lightgbm.py # LightGBM 模型
|
||||
│ │ │ │ ├── lightgbm.py # LightGBM 模型
|
||||
│ │ │ │ └── lightgbm_lambdarank.py # LambdaRank 排序模型
|
||||
│ │ │ └── processors/ # 数据处理器
|
||||
│ │ │ ├── __init__.py
|
||||
│ │ │ └── transforms.py # 变换处理器
|
||||
@@ -155,8 +161,14 @@ ProStock/
|
||||
│ │ ├── registry.py # 组件注册中心
|
||||
│ │ └── __init__.py # 导出所有组件
|
||||
│ │
|
||||
│ ├── scripts/ # 脚本工具
|
||||
│ │ └── register_factors.py # 因子批量注册脚本
|
||||
│ │
|
||||
│ └── experiment/ # 实验代码
|
||||
│ └── regression.ipynb # 完整训练流程示例
|
||||
│ ├── data/ # 实验数据目录
|
||||
│ ├── regression.py # 回归训练流程(Python脚本)
|
||||
│ ├── learn_to_rank.py # 排序学习训练流程(Python脚本)
|
||||
│ └── regression.ipynb # 完整训练流程示例(Notebook)
|
||||
│
|
||||
├── tests/ # 测试文件
|
||||
│ ├── test_sync.py
|
||||
@@ -273,6 +285,17 @@ except Exception as e:
|
||||
- 存储在 `data/` 目录中(通过 `DATA_PATH` 环境变量配置)
|
||||
- 使用 UPSERT 模式(`INSERT OR REPLACE`)处理重复数据
|
||||
- 多线程场景使用 `ThreadSafeStorage.queue_save()` + `flush()` 模式
|
||||
- **只读模式支持**: 查询时默认启用 `read_only=True`,避免并发冲突
|
||||
|
||||
```python
|
||||
from src.data.storage import Storage
|
||||
|
||||
# 查询模式(只读,推荐用于数据查询)
|
||||
storage = Storage(read_only=True) # 默认只读
|
||||
|
||||
# 写入模式(用于数据同步)
|
||||
storage = Storage(read_only=False)
|
||||
```
|
||||
|
||||
### 线程与并发
|
||||
- 对 I/O 密集型任务(API 调用)使用 `ThreadPoolExecutor`
|
||||
@@ -368,6 +391,11 @@ Engine (engine/) <- 执行引擎
|
||||
| - ComputeEngine: 计算引擎
|
||||
|
|
||||
v
|
||||
Metadata (metadata/) <- 因子元数据管理(可选)
|
||||
| - FactorManager: 元数据管理器
|
||||
| - FactorValidator: 字段校验器
|
||||
|
|
||||
v
|
||||
数据层 (data_router.py + DuckDB) <- 数据获取和存储
|
||||
```
|
||||
|
||||
@@ -376,7 +404,7 @@ Engine (engine/) <- 执行引擎
|
||||
```python
|
||||
from src.factors import FactorEngine
|
||||
|
||||
# 初始化引擎
|
||||
# 初始化引擎(默认启用 metadata 功能)
|
||||
engine = FactorEngine()
|
||||
|
||||
# 方式1: 使用 DSL 表达式
|
||||
@@ -388,11 +416,64 @@ engine.register("price_rank", cs_rank(close))
|
||||
engine.add_factor("ma20", "ts_mean(close, 20)")
|
||||
engine.add_factor("alpha", "cs_rank(ts_mean(close, 5) - ts_mean(close, 20))")
|
||||
|
||||
# 方式3: 从 metadata 查询(需先在 metadata 中定义因子)
|
||||
engine.add_factor("mom_5d") # 从 metadata 查询并注册名为 mom_5d 的因子
|
||||
|
||||
# 计算因子
|
||||
result = engine.compute(["ma20", "price_rank"], "20240101", "20240131")
|
||||
|
||||
# 查看已注册因子
|
||||
print(engine.list_registered())
|
||||
|
||||
# 预览执行计划
|
||||
plan = engine.preview_plan("ma20")
|
||||
```
|
||||
|
||||
### 因子元数据管理 (metadata 模块)
|
||||
|
||||
metadata 模块提供基于 DuckDB 查询 JSONL 文件、零拷贝输出 Polars DataFrame 的因子管理能力。
|
||||
|
||||
**核心组件:**
|
||||
- `FactorManager`: 元数据管理器主类,提供因子增删改查接口
|
||||
- `FactorValidator`: 字段校验器,校验核心字段的存在性和类型
|
||||
- 异常类: `FactorMetadataError`, `ValidationError`, `DuplicateFactorError` 等
|
||||
|
||||
**因子数据结构:**
|
||||
- `factor_id` (str): 全局唯一标识符(如 "F_001")
|
||||
- `name` (str): 可读短名称(如 "mom_5d")
|
||||
- `desc` (str): 详细描述
|
||||
- `dsl` (str): DSL 计算公式
|
||||
- 扩展字段: `category`, `author`, `tags`, `notes` 等
|
||||
|
||||
**使用示例:**
|
||||
|
||||
```python
|
||||
from src.factors.metadata import FactorManager
|
||||
|
||||
# 初始化管理器(默认路径: data/factors.jsonl)
|
||||
manager = FactorManager()
|
||||
|
||||
# 添加因子
|
||||
manager.add_factor({
|
||||
"factor_id": "F_001",
|
||||
"name": "mom_5d",
|
||||
"desc": "5日价格动量截面排序",
|
||||
"dsl": "cs_rank(close / ts_delay(close, 5) - 1)",
|
||||
"category": "momentum" # 扩展字段
|
||||
})
|
||||
|
||||
# 根据名称查询因子
|
||||
df = manager.get_factors_by_name("mom_5d")
|
||||
|
||||
# 使用 SQL 条件查询因子
|
||||
df = manager.search_factors("category = 'momentum'")
|
||||
df = manager.search_factors("name LIKE 'mom_%'")
|
||||
|
||||
# 获取所有因子
|
||||
df = manager.get_all_factors()
|
||||
|
||||
# 获取因子 DSL 表达式
|
||||
dsl = manager.get_factor_dsl("F_001")
|
||||
```
|
||||
|
||||
### 支持的函数
|
||||
@@ -468,6 +549,46 @@ except UnknownFunctionError as e:
|
||||
```
|
||||
|
||||
|
||||
### 因子批量注册脚本
|
||||
|
||||
`src/scripts/register_factors.py` 提供批量注册因子到元数据的功能。用户只需在 `FACTORS` 列表中配置因子定义,脚本自动生成 `factor_id` 并保存到 `factors.jsonl`。
|
||||
|
||||
**使用方法:**
|
||||
|
||||
```python
|
||||
# 在 register_factors.py 的 FACTORS 列表中定义因子
|
||||
FACTORS = [
|
||||
{
|
||||
"name": "mom_5d",
|
||||
"desc": "5日价格动量,收盘价相对于5日前收盘价的涨跌幅进行截面排名",
|
||||
"dsl": "cs_rank(close / ts_delay(close, 5) - 1)",
|
||||
"category": "momentum", # 可选扩展字段
|
||||
},
|
||||
{
|
||||
"name": "volatility_20d",
|
||||
"desc": "20日价格波动率,收益率的20日滚动标准差",
|
||||
"dsl": "ts_std(ts_delta(close, 1) / ts_delay(close, 1), 20)",
|
||||
"category": "volatility",
|
||||
},
|
||||
]
|
||||
|
||||
# 运行脚本
|
||||
# uv run python src/scripts/register_factors.py
|
||||
```
|
||||
|
||||
**脚本特性:**
|
||||
- 自动生成 `F_XXX` 格式的唯一 ID
|
||||
- 自动跳过已存在的因子(通过 `name` 判断)
|
||||
- 支持扩展字段(category, author, tags, notes 等)
|
||||
- 提供注册结果统计(成功/跳过/失败)
|
||||
|
||||
**命令行使用:**
|
||||
```bash
|
||||
# 批量注册所有配置的因子
|
||||
uv run python src/scripts/register_factors.py
|
||||
```
|
||||
|
||||
|
||||
## Training 模块设计说明
|
||||
|
||||
### 架构概述
|
||||
@@ -482,14 +603,20 @@ src/training/
|
||||
├── components/
|
||||
│ ├── base.py # BaseModel、BaseProcessor 抽象基类
|
||||
│ ├── splitters.py # DateSplitter 日期划分器
|
||||
│ ├── selectors.py # 股票选择器(已迁移到 StockPoolManager)
|
||||
│ ├── filters.py # STFilter 等过滤器
|
||||
│ ├── models/
|
||||
│ │ └── lightgbm.py # LightGBMModel
|
||||
│ └── processors/
|
||||
│ └── transforms.py # 数据处理器实现
|
||||
├── config/
|
||||
│ └── config.py # TrainingConfig
|
||||
└── registry.py # 组件注册中心
|
||||
│ ├── models/ # 模型实现
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── lightgbm.py # LightGBM 回归/分类模型
|
||||
│ │ └── lightgbm_lambdarank.py # LightGBM LambdaRank 排序模型
|
||||
│ └── processors/ # 数据处理器
|
||||
│ ├── __init__.py
|
||||
│ └── transforms.py # 变换处理器
|
||||
├── config/ # 配置
|
||||
│ ├── __init__.py
|
||||
│ └── config.py # 训练配置
|
||||
├── registry.py # 组件注册中心
|
||||
└── __init__.py # 导出所有组件
|
||||
```
|
||||
|
||||
### Trainer 核心流程
|
||||
@@ -514,17 +641,17 @@ model = LightGBMModel(params={
|
||||
splitter = DateSplitter(
|
||||
train_start="20200101",
|
||||
train_end="20231231",
|
||||
val_start="20240101",
|
||||
val_end="20241231",
|
||||
test_start="20250101",
|
||||
test_end="20261231",
|
||||
val_start="20240101",
|
||||
val_end="20241231",
|
||||
)
|
||||
|
||||
# 3. 创建数据处理器
|
||||
processors = [
|
||||
NullFiller(strategy="mean"),
|
||||
Winsorizer(lower=0.01, upper=0.99),
|
||||
StandardScaler(exclude_cols=["ts_code", "trade_date", "target"]),
|
||||
NullFiller(feature_cols=["ma_5", "ma_20", "volume_ratio", "roe"], strategy="mean"),
|
||||
Winsorizer(feature_cols=["ma_5", "ma_20", "volume_ratio", "roe"], lower=0.01, upper=0.99),
|
||||
StandardScaler(feature_cols=["ma_5", "ma_20", "volume_ratio", "roe"]),
|
||||
]
|
||||
|
||||
# 4. 创建股票池筛选函数
|
||||
@@ -576,13 +703,21 @@ results = trainer.get_results()
|
||||
from src.training.components.processors import NullFiller
|
||||
|
||||
# 使用 0 填充
|
||||
filler = NullFiller(strategy="zero")
|
||||
filler = NullFiller(feature_cols=["factor1", "factor2"], strategy="zero")
|
||||
|
||||
# 使用均值填充(每天独立计算截面均值)
|
||||
filler = NullFiller(strategy="mean", by_date=True)
|
||||
filler = NullFiller(
|
||||
feature_cols=["factor1", "factor2"],
|
||||
strategy="mean",
|
||||
by_date=True
|
||||
)
|
||||
|
||||
# 使用指定值填充
|
||||
filler = NullFiller(strategy="value", fill_value=-999)
|
||||
filler = NullFiller(
|
||||
feature_cols=["factor1", "factor2"],
|
||||
strategy="value",
|
||||
fill_value=-999
|
||||
)
|
||||
```
|
||||
|
||||
**Winsorizer** - 缩尾处理:
|
||||
@@ -590,10 +725,20 @@ filler = NullFiller(strategy="value", fill_value=-999)
|
||||
from src.training.components.processors import Winsorizer
|
||||
|
||||
# 全局缩尾(默认)
|
||||
winsorizer = Winsorizer(lower=0.01, upper=0.99, by_date=False)
|
||||
winsorizer = Winsorizer(
|
||||
feature_cols=["factor1", "factor2"],
|
||||
lower=0.01,
|
||||
upper=0.99,
|
||||
by_date=False
|
||||
)
|
||||
|
||||
# 每天独立缩尾
|
||||
winsorizer = Winsorizer(lower=0.01, upper=0.99, by_date=True)
|
||||
winsorizer = Winsorizer(
|
||||
feature_cols=["factor1", "factor2"],
|
||||
lower=0.01,
|
||||
upper=0.99,
|
||||
by_date=True
|
||||
)
|
||||
```
|
||||
|
||||
**StandardScaler** - 标准化:
|
||||
@@ -601,7 +746,7 @@ winsorizer = Winsorizer(lower=0.01, upper=0.99, by_date=True)
|
||||
from src.training.components.processors import StandardScaler
|
||||
|
||||
# 全局标准化(学习训练集的均值和标准差)
|
||||
scaler = StandardScaler(exclude_cols=["ts_code", "trade_date", "target"])
|
||||
scaler = StandardScaler(feature_cols=["factor1", "factor2", "factor3"])
|
||||
```
|
||||
|
||||
**CrossSectionalStandardScaler** - 截面标准化:
|
||||
@@ -610,11 +755,59 @@ from src.training.components.processors import CrossSectionalStandardScaler
|
||||
|
||||
# 每天独立标准化(不需要 fit)
|
||||
cs_scaler = CrossSectionalStandardScaler(
|
||||
exclude_cols=["ts_code", "trade_date", "target"],
|
||||
feature_cols=["factor1", "factor2", "factor3"],
|
||||
date_col="trade_date",
|
||||
)
|
||||
```
|
||||
|
||||
### 排序学习 (LambdaRank)
|
||||
|
||||
**LightGBMLambdaRankModel** - 基于 LambdaRank 的排序学习模型,适用于股票排序任务:
|
||||
|
||||
```python
|
||||
from src.training.components.models import LightGBMLambdaRankModel
|
||||
from src.training import Trainer
|
||||
|
||||
# 创建排序学习模型
|
||||
rank_model = LightGBMLambdaRankModel(
|
||||
params={
|
||||
"objective": "lambdarank",
|
||||
"metric": "ndcg",
|
||||
"ndcg_eval_at": [1, 5, 10, 20],
|
||||
"num_leaves": 31,
|
||||
"learning_rate": 0.05,
|
||||
"n_estimators": 500,
|
||||
"label_gain": [i for i in range(21)], # 20分位数
|
||||
}
|
||||
)
|
||||
|
||||
# 创建训练器(注意:排序学习需要 qid 分组)
|
||||
trainer = Trainer(
|
||||
model=rank_model,
|
||||
pool_manager=pool_manager,
|
||||
processors=processors,
|
||||
filters=[st_filter],
|
||||
splitter=splitter,
|
||||
target_col="label", # 必须是整数标签(分位数编码)
|
||||
feature_cols=feature_cols,
|
||||
date_col="trade_date", # 必须指定,用于构建 qid
|
||||
)
|
||||
|
||||
# 训练并评估
|
||||
results = trainer.train(data)
|
||||
```
|
||||
|
||||
**关键特性:**
|
||||
- **LambdaRank 目标函数**: 使用 LightGBM 的 lambdarank 优化排序
|
||||
- **NDCG 评估**: 支持 NDCG@1/5/10/20 指标评估排序质量
|
||||
- **自动分组**: 根据 `date_col` 自动构建 query group (qid)
|
||||
- **Label 要求**: 目标变量必须是整数(如分位数编码的等级)
|
||||
|
||||
**使用场景:**
|
||||
- 将未来收益率转换为分位数等级作为 label
|
||||
- 学习每日股票的相对排序
|
||||
- 构建 Top-k 选股策略
|
||||
|
||||
### 组件注册机制
|
||||
|
||||
```python
|
||||
|
||||
@@ -308,7 +308,7 @@ SELECTED_FACTORS = [
|
||||
|
||||
# 因子定义字典(完整因子库)
|
||||
FACTOR_DEFINITIONS = {
|
||||
# "turnover_volatility_ratio": "log(ts_std(turnover_rate, 20))"
|
||||
# "turnover_rate_volatility": "ts_std(turnover_rate, 20)"
|
||||
}
|
||||
|
||||
# Label 因子定义(不参与训练,用于计算目标)
|
||||
@@ -330,13 +330,13 @@ TEST_END = "20251231"
|
||||
MODEL_PARAMS = {
|
||||
"objective": "lambdarank",
|
||||
"metric": "ndcg",
|
||||
"ndcg_at": [1, 5, 10, 20], # 评估 NDCG@k
|
||||
"learning_rate": 0.05,
|
||||
"ndcg_at": 2, # 评估 NDCG@k
|
||||
"learning_rate": 0.01,
|
||||
"num_leaves": 31,
|
||||
"max_depth": 6,
|
||||
"min_data_in_leaf": 20,
|
||||
"n_estimators": 1000,
|
||||
"early_stopping_rounds": 50,
|
||||
"n_estimators": 2000,
|
||||
"early_stopping_round": 300,
|
||||
"subsample": 0.8,
|
||||
"colsample_bytree": 0.8,
|
||||
"reg_alpha": 0.1,
|
||||
|
||||
@@ -35,35 +35,11 @@ from src.factors.metadata.exceptions import DuplicateFactorError, ValidationErro
|
||||
FACTORS: List[Dict[str, Any]] = [
|
||||
# 示例因子,请根据实际需要修改或添加
|
||||
{
|
||||
"name": "mom_5d",
|
||||
"name": "turnover_volatility_ratio",
|
||||
"desc": "5日价格动量,收盘价相对于5日前收盘价的涨跌幅进行截面排名",
|
||||
"dsl": "cs_rank(close / ts_delay(close, 5) - 1)",
|
||||
"category": "momentum",
|
||||
},
|
||||
{
|
||||
"name": "mom_20d",
|
||||
"desc": "20日价格动量,收盘价相对于20日前收盘价的涨跌幅进行截面排名",
|
||||
"dsl": "cs_rank(close / ts_delay(close, 20) - 1)",
|
||||
"category": "momentum",
|
||||
},
|
||||
{
|
||||
"name": "volatility_20d",
|
||||
"desc": "20日价格波动率,收益率的20日滚动标准差",
|
||||
"dsl": "ts_std(ts_delta(close, 1) / ts_delay(close, 1), 20)",
|
||||
"category": "volatility",
|
||||
},
|
||||
{
|
||||
"name": "price_ma_ratio",
|
||||
"desc": "价格与20日均线的偏离度",
|
||||
"dsl": "close / ts_mean(close, 20) - 1",
|
||||
"category": "mean_reversion",
|
||||
},
|
||||
{
|
||||
"name": "volume_ratio",
|
||||
"desc": "成交量比率,当日成交量相对于20日均量的比值",
|
||||
"dsl": "volume / ts_mean(volume, 20)",
|
||||
"category": "volume",
|
||||
},
|
||||
]
|
||||
|
||||
# 因子存储路径(默认使用实验目录)
|
||||
|
||||
Reference in New Issue
Block a user