Files
ProStock/AGENTS.md
liaozhaorun e6c3a918c7 feat(training): 添加 LightGBM LambdaRank 排序学习功能
新增基于 LambdaRank 的排序学习模型,用于股票排序预测任务:
- 实现 LightGBMLambdaRankModel 模型类,支持分位数标签转换
- 提供完整的训练流程和 NDCG 评估指标
- 添加实验 Notebook 演示排序学习全流程
2026-03-10 22:23:44 +08:00

701 lines
22 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# ProStock 代理指南
A股量化投资框架 - Python 项目,用于量化股票投资分析。
## 交流语言要求
**⚠️ 强制要求:所有沟通和思考过程必须使用中文。**
- 所有与 AI Agent 的交流必须使用中文
- 代码中的注释和文档字符串使用中文
- 禁止使用英文进行思考或沟通
## 构建/检查/测试命令
**⚠️ 重要:本项目强制使用 uv 作为 Python 包管理器和运行工具。禁止直接使用 `python``pip` 命令。**
**测试规则:** 当修改或查看 `tests/` 目录下的代码时,必须使用 pytest 命令进行测试验证。
```bash
# 安装依赖(必须使用 uv
uv pip install -e .
# 运行所有测试
uv run pytest
# 运行单个测试文件
uv run pytest tests/test_sync.py
# 运行单个测试类
uv run pytest tests/test_sync.py::TestDataSync
# 运行单个测试方法
uv run pytest tests/test_sync.py::TestDataSync::test_get_all_stock_codes_from_daily
# 使用详细输出运行
uv run pytest -v
# 运行覆盖率测试(如果安装了 pytest-cov
uv run pytest --cov=src --cov-report=term-missing
```
### 禁止的命令 ❌
以下命令在本项目中**严格禁止**
```bash
# 禁止直接使用 python
python -c "…" # 禁止!
python script.py # 禁止!
python -m pytest # 禁止!
python -m pip install # 禁止!
# 禁止直接使用 pip
pip install -e . # 禁止!
pip install package # 禁止!
pip list # 禁止!
```
### 正确的 uv 用法 ✅
```bash
# 运行 Python 代码
uv run python -c "…" # ✅ 正确
uv run python script.py # ✅ 正确
# 安装依赖
uv pip install -e . # ✅ 正确
uv pip install package # ✅ 正确
# 运行测试
uv run pytest # ✅ 正确
uv run pytest tests/test_sync.py # ✅ 正确
```
## 项目结构
```
ProStock/
├── src/ # 源代码
│ ├── config/ # 配置管理
│ │ ├── __init__.py
│ │ └── settings.py # pydantic-settings 配置
│ │
│ ├── data/ # 数据获取与存储
│ │ ├── api_wrappers/ # Tushare API 封装
│ │ │ ├── base_sync.py # 同步基础抽象类
│ │ │ ├── api_daily.py # 日线数据接口
│ │ │ ├── api_pro_bar.py # Pro Bar 数据接口
│ │ │ ├── api_stock_basic.py # 股票基础信息接口
│ │ │ ├── api_trade_cal.py # 交易日历接口
│ │ │ ├── api_bak_basic.py # 历史股票列表接口
│ │ │ ├── api_namechange.py # 股票名称变更接口
│ │ │ ├── api_stock_st.py # ST股票信息接口
│ │ │ ├── api_daily_basic.py # 每日指标接口
│ │ │ ├── api_stk_limit.py # 涨跌停价格接口
│ │ │ ├── financial_data/ # 财务数据接口
│ │ │ │ ├── api_income.py # 利润表接口
│ │ │ │ ├── api_balance.py # 资产负债表接口
│ │ │ │ ├── api_cashflow.py # 现金流量表接口
│ │ │ │ ├── api_fina_indicator.py # 财务指标接口
│ │ │ │ └── api_financial_sync.py # 财务数据同步调度中心
│ │ │ └── __init__.py
│ │ ├── __init__.py
│ │ ├── client.py # Tushare API 客户端(带速率限制)
│ │ ├── storage.py # 数据存储核心
│ │ ├── db_manager.py # DuckDB 表管理和同步
│ │ ├── db_inspector.py # 数据库信息查看工具
│ │ ├── sync.py # 数据同步调度中心
│ │ ├── sync_registry.py # 同步器注册表
│ │ ├── rate_limiter.py # 令牌桶速率限制器
│ │ ├── catalog.py # 数据目录管理
│ │ ├── config.py # 数据模块配置
│ │ ├── utils.py # 数据模块工具函数
│ │ └── financial_loader.py # 财务数据加载器
│ │
│ ├── factors/ # 因子计算框架DSL 表达式驱动)
│ │ ├── engine/ # 执行引擎子模块
│ │ │ ├── __init__.py # 导出引擎组件
│ │ │ ├── data_spec.py # 数据规格定义
│ │ │ ├── data_router.py # 数据路由器
│ │ │ ├── planner.py # 执行计划生成器
│ │ │ ├── compute_engine.py # 计算引擎
│ │ │ ├── schema_cache.py # 表结构缓存
│ │ │ └── factor_engine.py # 因子引擎统一入口
│ │ ├── __init__.py # 导出所有公开 API
│ │ ├── dsl.py # DSL 表达式层 - 节点定义和运算符重载
│ │ ├── api.py # API 层 - 常用符号和函数
│ │ ├── compiler.py # AST 编译器 - 依赖提取
│ │ ├── translator.py # Polars 表达式翻译器
│ │ ├── parser.py # 字符串公式解析器
│ │ ├── registry.py # 函数注册表
│ │ ├── decorators.py # 装饰器工具
│ │ └── exceptions.py # 异常定义
│ │
│ ├── training/ # 训练模块
│ │ ├── core/ # 训练核心组件
│ │ │ ├── __init__.py
│ │ │ ├── trainer.py # 训练器主类
│ │ │ └── stock_pool_manager.py # 股票池管理器
│ │ ├── components/ # 组件
│ │ │ ├── base.py # 基础抽象类
│ │ │ ├── splitters.py # 数据划分器
│ │ │ ├── selectors.py # 股票选择器
│ │ │ ├── filters.py # 数据过滤器
│ │ │ ├── models/ # 模型实现
│ │ │ │ ├── __init__.py
│ │ │ │ └── lightgbm.py # LightGBM 模型
│ │ │ └── processors/ # 数据处理器
│ │ │ ├── __init__.py
│ │ │ └── transforms.py # 变换处理器
│ │ ├── config/ # 配置
│ │ │ ├── __init__.py
│ │ │ └── config.py # 训练配置
│ │ ├── registry.py # 组件注册中心
│ │ └── __init__.py # 导出所有组件
│ │
│ └── experiment/ # 实验代码
│ └── regression.ipynb # 完整训练流程示例
├── tests/ # 测试文件
│ ├── test_sync.py
│ ├── test_daily.py
│ ├── test_factor_engine.py
│ ├── test_factor_integration.py
│ ├── test_pro_bar.py
│ ├── test_db_manager.py
│ ├── test_daily_storage.py
│ ├── test_tushare_api.py
│ └── pipeline/
│ └── test_core.py
├── config/ # 配置文件
│ └── .env.local # 环境变量(不在 git 中)
├── data/ # 数据存储DuckDB
├── docs/ # 文档
├── pyproject.toml # 项目配置
└── README.md
```
## 代码风格指南
### Python 版本
- **需要 Python 3.10+**
- 使用现代 Python 特性match/case、海象运算符、类型提示
### 导入
```python
# 标准库优先
import os
import time
from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional, Dict, Callable
from concurrent.futures import ThreadPoolExecutor
import threading
# 第三方包
import pandas as pd
import numpy as np
import polars as pl
from tqdm import tqdm
from pydantic_settings import BaseSettings
# 本地模块(使用来自 src 的绝对导入)
from src.data.client import TushareClient
from src.data.storage import Storage
from src.config.settings import get_settings
```
### 类型提示
- **始终使用类型提示** 用于函数参数和返回值
- 对可空类型使用 `Optional[X]`
- 当可用时使用现代联合语法 `X | Y`Python 3.10+
-`typing` 导入类型:`Optional``Dict``Callable`
```python
def sync_single_stock(
self,
ts_code: str,
start_date: str,
end_date: str,
) -> pd.DataFrame:
...
```
### 文档字符串
- 使用 **Google 风格文档字符串**
- 包含 Args、Returns 部分
- 第一行保持简短摘要
```python
def get_next_date(date_str: str) -> str:
"""获取给定日期之后的下一天。
Args:
date_str: YYYYMMDD 格式的日期
Returns:
YYYYMMDD 格式的下一天日期
"""
...
```
### 命名约定
- 变量、函数、方法使用 `snake_case`
- 类使用 `PascalCase`
- 常量使用 `UPPER_CASE`
- 私有方法:`_leading_underscore`
- 受保护属性:`_single_underscore`
### 错误处理
- 使用特定的异常,不要使用裸 `except:`
- 使用上下文记录错误:`print(f"[ERROR] 上下文: {e}")`
- 对 API 调用使用指数退避重试逻辑
- 在关键错误时立即停止(设置停止标志)
```python
try:
data = api.query(...)
except Exception as e:
print(f"[ERROR] 获取 {ts_code} 失败: {e}")
raise # 记录后重新抛出
```
### 配置
- 对所有配置使用 **pydantic-settings**
-`config/.env.local` 文件加载
- 环境变量自动转换:`tushare_token` -> `TUSHARE_TOKEN`
- 对配置单例使用 `@lru_cache()`
### 数据存储
- 使用 **DuckDB** 嵌入式 OLAP 数据库进行持久化
- 存储在 `data/` 目录中(通过 `DATA_PATH` 环境变量配置)
- 使用 UPSERT 模式(`INSERT OR REPLACE`)处理重复数据
- 多线程场景使用 `ThreadSafeStorage.queue_save()` + `flush()` 模式
### 线程与并发
- 对 I/O 密集型任务API 调用)使用 `ThreadPoolExecutor`
- 实现停止标志以实现优雅关闭:`threading.Event()`
- 数据同步默认工作线程数10
- 出错时始终使用 `executor.shutdown(wait=False, cancel_futures=True)`
### 日志记录
- 使用带前缀的 print 语句:`[模块名] 消息`
- 错误格式:`[ERROR] 上下文: 异常`
- 进度:循环中使用 `tqdm`
### 测试
- 使用 **pytest** 框架
- 模拟外部依赖Tushare API
- 使用 `@pytest.fixture` 进行测试设置
- 在导入位置打补丁:`patch('src.data.sync.Storage')`
- 测试成功和错误两种情况
### 日期格式
- 使用 `YYYYMMDD` 字符串格式表示日期
- 辅助函数:`get_today_date()``get_next_date()`
- 完全同步的默认开始日期:`20180101`
### 依赖项
关键包:
- `pandas>=2.0.0` - 数据处理
- `polars>=0.20.0` - 高性能数据处理(因子计算)
- `numpy>=1.24.0` - 数值计算
- `tushare>=2.0.0` - A股数据 API
- `pydantic>=2.0.0``pydantic-settings>=2.0.0` - 配置
- `tqdm>=4.65.0` - 进度条
- `lightgbm>=4.0.0` - 机器学习模型
- `pytest` - 测试(开发)
### 环境变量
创建 `config/.env.local`
```bash
TUSHARE_TOKEN=your_token_here
DATA_PATH=data
RATE_LIMIT=100
THREADS=10
```
## 常见任务
```bash
# 同步所有股票(增量)
uv run python -c "from src.data.sync import sync_all; sync_all()"
# 强制完全同步
uv run python -c "from src.data.sync import sync_all; sync_all(force_full=True)"
# 自定义线程数
uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)"
# 同步财务数据
uv run python -c "from src.data.api_wrappers.financial_data import sync_financial; sync_financial()"
# 运行因子计算测试
uv run pytest tests/test_factor_engine.py -v
```
## Factors 框架设计说明
### 架构层次
因子框架采用分层设计,从上到下依次是:
```
API 层 (api.py)
|
v
DSL 层 (dsl.py) <- 因子表达式 (Node)
|
v
Compiler (compiler.py) <- AST 依赖提取
|
v
Parser (parser.py) <- 字符串公式解析器
|
v
Registry (registry.py) <- 函数注册表
|
v
Translator (translator.py) <- 翻译为 Polars 表达式
|
v
Engine (engine/) <- 执行引擎
| - FactorEngine: 统一入口
| - DataRouter: 数据路由
| - ExecutionPlanner: 执行计划
| - ComputeEngine: 计算引擎
|
v
数据层 (data_router.py + DuckDB) <- 数据获取和存储
```
### FactorEngine 核心 API
```python
from src.factors import FactorEngine
# 初始化引擎
engine = FactorEngine()
# 方式1: 使用 DSL 表达式
from src.factors.api import close, ts_mean, cs_rank
engine.register("ma20", ts_mean(close, 20))
engine.register("price_rank", cs_rank(close))
# 方式2: 使用字符串表达式(推荐)
engine.add_factor("ma20", "ts_mean(close, 20)")
engine.add_factor("alpha", "cs_rank(ts_mean(close, 5) - ts_mean(close, 20))")
# 计算因子
result = engine.compute(["ma20", "price_rank"], "20240101", "20240131")
# 查看已注册因子
print(engine.list_registered())
```
### 支持的函数
**时间序列函数 (ts_*)**
- `ts_mean(x, window)` - 滚动均值
- `ts_std(x, window)` - 滚动标准差
- `ts_max(x, window)` - 滚动最大值
- `ts_min(x, window)` - 滚动最小值
- `ts_sum(x, window)` - 滚动求和
- `ts_delay(x, periods)` - 滞后 N 期
- `ts_delta(x, periods)` - 差分 N 期
- `ts_corr(x, y, window)` - 滚动相关系数
- `ts_cov(x, y, window)` - 滚动协方差
- `ts_rank(x, window)` - 滚动排名
**截面函数 (cs_*)**
- `cs_rank(x)` - 截面排名(分位数)
- `cs_zscore(x)` - Z-Score 标准化
- `cs_neutralize(x, group)` - 行业/市值中性化
- `cs_winsorize(x, lower, upper)` - 缩尾处理
- `cs_demean(x)` - 去均值
**数学函数**
- `log(x)` - 自然对数
- `exp(x)` - 指数函数
- `sqrt(x)` - 平方根
- `sign(x)` - 符号函数
- `abs(x)` - 绝对值
- `max_(x, y)` / `min_(x, y)` - 逐元素最值
- `clip(x, lower, upper)` - 数值裁剪
**条件函数**
- `if_(condition, true_val, false_val)` - 条件选择
- `where(condition, true_val, false_val)` - if_ 的别名
### 运算符支持
DSL 表达式支持完整的 Python 运算符:
```python
# 算术运算: +, -, *, /, //, %, **
expr1 = (close - open) / open * 100 # 涨跌幅
# 比较运算: ==, !=, <, <=, >, >=
expr2 = close > open # 是否上涨
# 一元运算: -, +, abs()
expr3 = -change # 涨跌额取反
# 链式调用
expr4 = ts_mean(cs_rank(close), 20) # 排名后的20日平滑
```
### 异常处理
框架提供清晰的异常类型帮助定位问题:
- `FormulaParseError` - 公式解析错误基类
- `UnknownFunctionError` - 未知函数错误(提供模糊匹配建议)
- `InvalidSyntaxError` - 语法错误
- `EmptyExpressionError` - 空表达式错误
- `DuplicateFunctionError` - 函数重复注册错误
```python
from src.factors import FormulaParser, FunctionRegistry, UnknownFunctionError
parser = FormulaParser(FunctionRegistry())
try:
expr = parser.parse("unknown_func(close)")
except UnknownFunctionError as e:
print(e) # 显示错误位置和可用函数建议
```
## Training 模块设计说明
### 架构概述
Training 模块位于 `src/training/` 目录,负责从因子数据到模型训练、预测的完整流程。采用组件化设计,支持数据处理器、模型、过滤器、股票池管理器的灵活组合。
```
src/training/
├── core/
│ ├── trainer.py # Trainer 主类
│ └── stock_pool_manager.py # 股票池管理器
├── components/
│ ├── base.py # BaseModel、BaseProcessor 抽象基类
│ ├── splitters.py # DateSplitter 日期划分器
│ ├── filters.py # STFilter 等过滤器
│ ├── models/
│ │ └── lightgbm.py # LightGBMModel
│ └── processors/
│ └── transforms.py # 数据处理器实现
├── config/
│ └── config.py # TrainingConfig
└── registry.py # 组件注册中心
```
### Trainer 核心流程
```python
from src.training import Trainer, DateSplitter, StockPoolManager
from src.training.components.models import LightGBMModel
from src.training.components.processors import Winsorizer, StandardScaler
from src.training.components.filters import STFilter
import polars as pl
# 1. 创建模型
model = LightGBMModel(params={
"objective": "regression",
"metric": "mae",
"num_leaves": 20,
"learning_rate": 0.01,
"n_estimators": 1000,
})
# 2. 创建数据划分器(正确的 train/val/test 三分法)
splitter = DateSplitter(
train_start="20200101",
train_end="20231231",
val_start="20240101",
val_end="20241231",
test_start="20250101",
test_end="20261231",
)
# 3. 创建数据处理器
processors = [
NullFiller(strategy="mean"),
Winsorizer(lower=0.01, upper=0.99),
StandardScaler(exclude_cols=["ts_code", "trade_date", "target"]),
]
# 4. 创建股票池筛选函数
def stock_pool_filter(df: pl.DataFrame) -> pl.Series:
"""筛选小市值股票,排除创业板/科创板/北交所"""
code_filter = (
~df["ts_code"].str.starts_with("300") & # 排除创业板
~df["ts_code"].str.starts_with("688") & # 排除科创板
~df["ts_code"].str.starts_with("8") & # 排除北交所
~df["ts_code"].str.starts_with("9") &
~df["ts_code"].str.starts_with("4")
)
valid_df = df.filter(code_filter)
n = min(1000, len(valid_df))
small_cap_codes = valid_df.sort("total_mv").head(n)["ts_code"]
return df["ts_code"].is_in(small_cap_codes)
pool_manager = StockPoolManager(
filter_func=stock_pool_filter,
required_columns=["total_mv"],
data_router=engine.router,
)
# 5. 创建 ST 过滤器
st_filter = STFilter(data_router=engine.router)
# 6. 创建训练器
trainer = Trainer(
model=model,
pool_manager=pool_manager,
processors=processors,
filters=[st_filter],
splitter=splitter,
target_col="future_return_5",
feature_cols=["ma_5", "ma_20", "volume_ratio", "roe"],
)
# 7. 执行训练
trainer.train(data)
# 8. 获取结果
results = trainer.get_results()
```
### 数据处理器
**NullFiller** - 缺失值填充:
```python
from src.training.components.processors import NullFiller
# 使用 0 填充
filler = NullFiller(strategy="zero")
# 使用均值填充(每天独立计算截面均值)
filler = NullFiller(strategy="mean", by_date=True)
# 使用指定值填充
filler = NullFiller(strategy="value", fill_value=-999)
```
**Winsorizer** - 缩尾处理:
```python
from src.training.components.processors import Winsorizer
# 全局缩尾(默认)
winsorizer = Winsorizer(lower=0.01, upper=0.99, by_date=False)
# 每天独立缩尾
winsorizer = Winsorizer(lower=0.01, upper=0.99, by_date=True)
```
**StandardScaler** - 标准化:
```python
from src.training.components.processors import StandardScaler
# 全局标准化(学习训练集的均值和标准差)
scaler = StandardScaler(exclude_cols=["ts_code", "trade_date", "target"])
```
**CrossSectionalStandardScaler** - 截面标准化:
```python
from src.training.components.processors import CrossSectionalStandardScaler
# 每天独立标准化(不需要 fit
cs_scaler = CrossSectionalStandardScaler(
exclude_cols=["ts_code", "trade_date", "target"],
date_col="trade_date",
)
```
### 组件注册机制
```python
from src.training.registry import register_model, register_processor
from src.training.components.base import BaseModel, BaseProcessor
# 注册自定义模型
@register_model("custom_model")
class CustomModel(BaseModel):
name = "custom_model"
def fit(self, X, y):
# 训练逻辑
return self
def predict(self, X):
# 预测逻辑
return predictions
# 注册自定义处理器
@register_processor("custom_processor")
class CustomProcessor(BaseProcessor):
name = "custom_processor"
def transform(self, X):
# 转换逻辑
return X
```
## AI 行为准则
### LSP 检测报错处理
**⚠️ 强制要求:当进行 LSP 检测时报错,必定是代码格式问题。**
如果 LSP 检测报错,必须按照以下流程处理:
1. **问题定位**
- 报错必定是由基础格式错误引起:缩进错误、引号括号不匹配、代码格式错误等
- 必须读取对应的代码行,精确定位错误
2. **修复方式**
-**必须**:读取报错文件,检查具体代码行
-**必须**:修复格式错误(缩进、括号匹配、引号闭合等)
-**禁止**:删除文件重新修改
-**禁止**:自行 rollback 文件
-**禁止**:新建文件重新修改
-**禁止**:忽略错误继续执行
3. **验证要求**
- 修复后必须重新运行 LSP 检测确认无错误
- 确保修改仅针对格式问题,不改变代码逻辑
**示例场景**
```
LSP 报错Syntax error on line 45
✅ 正确做法:读取文件第 45 行,发现少了一个右括号,添加后重新检测
❌ 错误做法:删除文件重新写、或者忽略错误继续
```
### Emoji 表情禁用规则
**⚠️ 强制要求:代码和测试文件中禁止出现 emoji 表情。**
1. **禁止范围**
- 所有 `.py` 源代码文件
- 所有测试文件 (`tests/` 目录)
- 配置文件、脚本文件
2. **替代方案**
- ❌ 禁止使用:`print("✅ 成功")``print("❌ 失败")``# 📝 注释`
- ✅ 应使用:`print("[成功]")``print("[失败]")``# 注释`
- 使用方括号 `[成功]``[警告]``[错误]` 等文字标记代替 emoji
3. **唯一例外**
- AGENTS.md 文件本身可以使用 emoji 进行文档强调(如本文件中的 ⚠️)
- 项目文档、README 等对外展示文件可以酌情使用
4. **检查方法**
- 使用正则表达式搜索 emoji`[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F1E0-\U0001F1FF\u2600-\u26FF\u2700-\u27BF]`
- 提交前自查,确保无 emoji 混入代码