新增基于 LambdaRank 的排序学习模型,用于股票排序预测任务: - 实现 LightGBMLambdaRankModel 模型类,支持分位数标签转换 - 提供完整的训练流程和 NDCG 评估指标 - 添加实验 Notebook 演示排序学习全流程
22 KiB
22 KiB
ProStock 代理指南
A股量化投资框架 - Python 项目,用于量化股票投资分析。
交流语言要求
⚠️ 强制要求:所有沟通和思考过程必须使用中文。
- 所有与 AI Agent 的交流必须使用中文
- 代码中的注释和文档字符串使用中文
- 禁止使用英文进行思考或沟通
构建/检查/测试命令
⚠️ 重要:本项目强制使用 uv 作为 Python 包管理器和运行工具。禁止直接使用 python 或 pip 命令。
测试规则: 当修改或查看 tests/ 目录下的代码时,必须使用 pytest 命令进行测试验证。
# 安装依赖(必须使用 uv)
uv pip install -e .
# 运行所有测试
uv run pytest
# 运行单个测试文件
uv run pytest tests/test_sync.py
# 运行单个测试类
uv run pytest tests/test_sync.py::TestDataSync
# 运行单个测试方法
uv run pytest tests/test_sync.py::TestDataSync::test_get_all_stock_codes_from_daily
# 使用详细输出运行
uv run pytest -v
# 运行覆盖率测试(如果安装了 pytest-cov)
uv run pytest --cov=src --cov-report=term-missing
禁止的命令 ❌
以下命令在本项目中严格禁止:
# 禁止直接使用 python
python -c "…" # 禁止!
python script.py # 禁止!
python -m pytest # 禁止!
python -m pip install # 禁止!
# 禁止直接使用 pip
pip install -e . # 禁止!
pip install package # 禁止!
pip list # 禁止!
正确的 uv 用法 ✅
# 运行 Python 代码
uv run python -c "…" # ✅ 正确
uv run python script.py # ✅ 正确
# 安装依赖
uv pip install -e . # ✅ 正确
uv pip install package # ✅ 正确
# 运行测试
uv run pytest # ✅ 正确
uv run pytest tests/test_sync.py # ✅ 正确
项目结构
ProStock/
├── src/ # 源代码
│ ├── config/ # 配置管理
│ │ ├── __init__.py
│ │ └── settings.py # pydantic-settings 配置
│ │
│ ├── data/ # 数据获取与存储
│ │ ├── api_wrappers/ # Tushare API 封装
│ │ │ ├── base_sync.py # 同步基础抽象类
│ │ │ ├── api_daily.py # 日线数据接口
│ │ │ ├── api_pro_bar.py # Pro Bar 数据接口
│ │ │ ├── api_stock_basic.py # 股票基础信息接口
│ │ │ ├── api_trade_cal.py # 交易日历接口
│ │ │ ├── api_bak_basic.py # 历史股票列表接口
│ │ │ ├── api_namechange.py # 股票名称变更接口
│ │ │ ├── api_stock_st.py # ST股票信息接口
│ │ │ ├── api_daily_basic.py # 每日指标接口
│ │ │ ├── api_stk_limit.py # 涨跌停价格接口
│ │ │ ├── financial_data/ # 财务数据接口
│ │ │ │ ├── api_income.py # 利润表接口
│ │ │ │ ├── api_balance.py # 资产负债表接口
│ │ │ │ ├── api_cashflow.py # 现金流量表接口
│ │ │ │ ├── api_fina_indicator.py # 财务指标接口
│ │ │ │ └── api_financial_sync.py # 财务数据同步调度中心
│ │ │ └── __init__.py
│ │ ├── __init__.py
│ │ ├── client.py # Tushare API 客户端(带速率限制)
│ │ ├── storage.py # 数据存储核心
│ │ ├── db_manager.py # DuckDB 表管理和同步
│ │ ├── db_inspector.py # 数据库信息查看工具
│ │ ├── sync.py # 数据同步调度中心
│ │ ├── sync_registry.py # 同步器注册表
│ │ ├── rate_limiter.py # 令牌桶速率限制器
│ │ ├── catalog.py # 数据目录管理
│ │ ├── config.py # 数据模块配置
│ │ ├── utils.py # 数据模块工具函数
│ │ └── financial_loader.py # 财务数据加载器
│ │
│ ├── factors/ # 因子计算框架(DSL 表达式驱动)
│ │ ├── engine/ # 执行引擎子模块
│ │ │ ├── __init__.py # 导出引擎组件
│ │ │ ├── data_spec.py # 数据规格定义
│ │ │ ├── data_router.py # 数据路由器
│ │ │ ├── planner.py # 执行计划生成器
│ │ │ ├── compute_engine.py # 计算引擎
│ │ │ ├── schema_cache.py # 表结构缓存
│ │ │ └── factor_engine.py # 因子引擎统一入口
│ │ ├── __init__.py # 导出所有公开 API
│ │ ├── dsl.py # DSL 表达式层 - 节点定义和运算符重载
│ │ ├── api.py # API 层 - 常用符号和函数
│ │ ├── compiler.py # AST 编译器 - 依赖提取
│ │ ├── translator.py # Polars 表达式翻译器
│ │ ├── parser.py # 字符串公式解析器
│ │ ├── registry.py # 函数注册表
│ │ ├── decorators.py # 装饰器工具
│ │ └── exceptions.py # 异常定义
│ │
│ ├── training/ # 训练模块
│ │ ├── core/ # 训练核心组件
│ │ │ ├── __init__.py
│ │ │ ├── trainer.py # 训练器主类
│ │ │ └── stock_pool_manager.py # 股票池管理器
│ │ ├── components/ # 组件
│ │ │ ├── base.py # 基础抽象类
│ │ │ ├── splitters.py # 数据划分器
│ │ │ ├── selectors.py # 股票选择器
│ │ │ ├── filters.py # 数据过滤器
│ │ │ ├── models/ # 模型实现
│ │ │ │ ├── __init__.py
│ │ │ │ └── lightgbm.py # LightGBM 模型
│ │ │ └── processors/ # 数据处理器
│ │ │ ├── __init__.py
│ │ │ └── transforms.py # 变换处理器
│ │ ├── config/ # 配置
│ │ │ ├── __init__.py
│ │ │ └── config.py # 训练配置
│ │ ├── registry.py # 组件注册中心
│ │ └── __init__.py # 导出所有组件
│ │
│ └── experiment/ # 实验代码
│ └── regression.ipynb # 完整训练流程示例
│
├── tests/ # 测试文件
│ ├── test_sync.py
│ ├── test_daily.py
│ ├── test_factor_engine.py
│ ├── test_factor_integration.py
│ ├── test_pro_bar.py
│ ├── test_db_manager.py
│ ├── test_daily_storage.py
│ ├── test_tushare_api.py
│ └── pipeline/
│ └── test_core.py
├── config/ # 配置文件
│ └── .env.local # 环境变量(不在 git 中)
├── data/ # 数据存储(DuckDB)
├── docs/ # 文档
├── pyproject.toml # 项目配置
└── README.md
代码风格指南
Python 版本
- 需要 Python 3.10+
- 使用现代 Python 特性(match/case、海象运算符、类型提示)
导入
# 标准库优先
import os
import time
from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional, Dict, Callable
from concurrent.futures import ThreadPoolExecutor
import threading
# 第三方包
import pandas as pd
import numpy as np
import polars as pl
from tqdm import tqdm
from pydantic_settings import BaseSettings
# 本地模块(使用来自 src 的绝对导入)
from src.data.client import TushareClient
from src.data.storage import Storage
from src.config.settings import get_settings
类型提示
- 始终使用类型提示 用于函数参数和返回值
- 对可空类型使用
Optional[X] - 当可用时使用现代联合语法
X | Y(Python 3.10+) - 从
typing导入类型:Optional、Dict、Callable等
def sync_single_stock(
self,
ts_code: str,
start_date: str,
end_date: str,
) -> pd.DataFrame:
...
文档字符串
- 使用 Google 风格文档字符串
- 包含 Args、Returns 部分
- 第一行保持简短摘要
def get_next_date(date_str: str) -> str:
"""获取给定日期之后的下一天。
Args:
date_str: YYYYMMDD 格式的日期
Returns:
YYYYMMDD 格式的下一天日期
"""
...
命名约定
- 变量、函数、方法使用
snake_case - 类使用
PascalCase - 常量使用
UPPER_CASE - 私有方法:
_leading_underscore - 受保护属性:
_single_underscore
错误处理
- 使用特定的异常,不要使用裸
except: - 使用上下文记录错误:
print(f"[ERROR] 上下文: {e}") - 对 API 调用使用指数退避重试逻辑
- 在关键错误时立即停止(设置停止标志)
try:
data = api.query(...)
except Exception as e:
print(f"[ERROR] 获取 {ts_code} 失败: {e}")
raise # 记录后重新抛出
配置
- 对所有配置使用 pydantic-settings
- 从
config/.env.local文件加载 - 环境变量自动转换:
tushare_token->TUSHARE_TOKEN - 对配置单例使用
@lru_cache()
数据存储
- 使用 DuckDB 嵌入式 OLAP 数据库进行持久化
- 存储在
data/目录中(通过DATA_PATH环境变量配置) - 使用 UPSERT 模式(
INSERT OR REPLACE)处理重复数据 - 多线程场景使用
ThreadSafeStorage.queue_save()+flush()模式
线程与并发
- 对 I/O 密集型任务(API 调用)使用
ThreadPoolExecutor - 实现停止标志以实现优雅关闭:
threading.Event() - 数据同步默认工作线程数:10
- 出错时始终使用
executor.shutdown(wait=False, cancel_futures=True)
日志记录
- 使用带前缀的 print 语句:
[模块名] 消息 - 错误格式:
[ERROR] 上下文: 异常 - 进度:循环中使用
tqdm
测试
- 使用 pytest 框架
- 模拟外部依赖(Tushare API)
- 使用
@pytest.fixture进行测试设置 - 在导入位置打补丁:
patch('src.data.sync.Storage') - 测试成功和错误两种情况
日期格式
- 使用
YYYYMMDD字符串格式表示日期 - 辅助函数:
get_today_date()、get_next_date() - 完全同步的默认开始日期:
20180101
依赖项
关键包:
pandas>=2.0.0- 数据处理polars>=0.20.0- 高性能数据处理(因子计算)numpy>=1.24.0- 数值计算tushare>=2.0.0- A股数据 APIpydantic>=2.0.0、pydantic-settings>=2.0.0- 配置tqdm>=4.65.0- 进度条lightgbm>=4.0.0- 机器学习模型pytest- 测试(开发)
环境变量
创建 config/.env.local:
TUSHARE_TOKEN=your_token_here
DATA_PATH=data
RATE_LIMIT=100
THREADS=10
常见任务
# 同步所有股票(增量)
uv run python -c "from src.data.sync import sync_all; sync_all()"
# 强制完全同步
uv run python -c "from src.data.sync import sync_all; sync_all(force_full=True)"
# 自定义线程数
uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)"
# 同步财务数据
uv run python -c "from src.data.api_wrappers.financial_data import sync_financial; sync_financial()"
# 运行因子计算测试
uv run pytest tests/test_factor_engine.py -v
Factors 框架设计说明
架构层次
因子框架采用分层设计,从上到下依次是:
API 层 (api.py)
|
v
DSL 层 (dsl.py) <- 因子表达式 (Node)
|
v
Compiler (compiler.py) <- AST 依赖提取
|
v
Parser (parser.py) <- 字符串公式解析器
|
v
Registry (registry.py) <- 函数注册表
|
v
Translator (translator.py) <- 翻译为 Polars 表达式
|
v
Engine (engine/) <- 执行引擎
| - FactorEngine: 统一入口
| - DataRouter: 数据路由
| - ExecutionPlanner: 执行计划
| - ComputeEngine: 计算引擎
|
v
数据层 (data_router.py + DuckDB) <- 数据获取和存储
FactorEngine 核心 API
from src.factors import FactorEngine
# 初始化引擎
engine = FactorEngine()
# 方式1: 使用 DSL 表达式
from src.factors.api import close, ts_mean, cs_rank
engine.register("ma20", ts_mean(close, 20))
engine.register("price_rank", cs_rank(close))
# 方式2: 使用字符串表达式(推荐)
engine.add_factor("ma20", "ts_mean(close, 20)")
engine.add_factor("alpha", "cs_rank(ts_mean(close, 5) - ts_mean(close, 20))")
# 计算因子
result = engine.compute(["ma20", "price_rank"], "20240101", "20240131")
# 查看已注册因子
print(engine.list_registered())
支持的函数
时间序列函数 (ts_*):
ts_mean(x, window)- 滚动均值ts_std(x, window)- 滚动标准差ts_max(x, window)- 滚动最大值ts_min(x, window)- 滚动最小值ts_sum(x, window)- 滚动求和ts_delay(x, periods)- 滞后 N 期ts_delta(x, periods)- 差分 N 期ts_corr(x, y, window)- 滚动相关系数ts_cov(x, y, window)- 滚动协方差ts_rank(x, window)- 滚动排名
截面函数 (cs_*):
cs_rank(x)- 截面排名(分位数)cs_zscore(x)- Z-Score 标准化cs_neutralize(x, group)- 行业/市值中性化cs_winsorize(x, lower, upper)- 缩尾处理cs_demean(x)- 去均值
数学函数:
log(x)- 自然对数exp(x)- 指数函数sqrt(x)- 平方根sign(x)- 符号函数abs(x)- 绝对值max_(x, y)/min_(x, y)- 逐元素最值clip(x, lower, upper)- 数值裁剪
条件函数:
if_(condition, true_val, false_val)- 条件选择where(condition, true_val, false_val)- if_ 的别名
运算符支持
DSL 表达式支持完整的 Python 运算符:
# 算术运算: +, -, *, /, //, %, **
expr1 = (close - open) / open * 100 # 涨跌幅
# 比较运算: ==, !=, <, <=, >, >=
expr2 = close > open # 是否上涨
# 一元运算: -, +, abs()
expr3 = -change # 涨跌额取反
# 链式调用
expr4 = ts_mean(cs_rank(close), 20) # 排名后的20日平滑
异常处理
框架提供清晰的异常类型帮助定位问题:
FormulaParseError- 公式解析错误基类UnknownFunctionError- 未知函数错误(提供模糊匹配建议)InvalidSyntaxError- 语法错误EmptyExpressionError- 空表达式错误DuplicateFunctionError- 函数重复注册错误
from src.factors import FormulaParser, FunctionRegistry, UnknownFunctionError
parser = FormulaParser(FunctionRegistry())
try:
expr = parser.parse("unknown_func(close)")
except UnknownFunctionError as e:
print(e) # 显示错误位置和可用函数建议
Training 模块设计说明
架构概述
Training 模块位于 src/training/ 目录,负责从因子数据到模型训练、预测的完整流程。采用组件化设计,支持数据处理器、模型、过滤器、股票池管理器的灵活组合。
src/training/
├── core/
│ ├── trainer.py # Trainer 主类
│ └── stock_pool_manager.py # 股票池管理器
├── components/
│ ├── base.py # BaseModel、BaseProcessor 抽象基类
│ ├── splitters.py # DateSplitter 日期划分器
│ ├── filters.py # STFilter 等过滤器
│ ├── models/
│ │ └── lightgbm.py # LightGBMModel
│ └── processors/
│ └── transforms.py # 数据处理器实现
├── config/
│ └── config.py # TrainingConfig
└── registry.py # 组件注册中心
Trainer 核心流程
from src.training import Trainer, DateSplitter, StockPoolManager
from src.training.components.models import LightGBMModel
from src.training.components.processors import Winsorizer, StandardScaler
from src.training.components.filters import STFilter
import polars as pl
# 1. 创建模型
model = LightGBMModel(params={
"objective": "regression",
"metric": "mae",
"num_leaves": 20,
"learning_rate": 0.01,
"n_estimators": 1000,
})
# 2. 创建数据划分器(正确的 train/val/test 三分法)
splitter = DateSplitter(
train_start="20200101",
train_end="20231231",
val_start="20240101",
val_end="20241231",
test_start="20250101",
test_end="20261231",
)
# 3. 创建数据处理器
processors = [
NullFiller(strategy="mean"),
Winsorizer(lower=0.01, upper=0.99),
StandardScaler(exclude_cols=["ts_code", "trade_date", "target"]),
]
# 4. 创建股票池筛选函数
def stock_pool_filter(df: pl.DataFrame) -> pl.Series:
"""筛选小市值股票,排除创业板/科创板/北交所"""
code_filter = (
~df["ts_code"].str.starts_with("300") & # 排除创业板
~df["ts_code"].str.starts_with("688") & # 排除科创板
~df["ts_code"].str.starts_with("8") & # 排除北交所
~df["ts_code"].str.starts_with("9") &
~df["ts_code"].str.starts_with("4")
)
valid_df = df.filter(code_filter)
n = min(1000, len(valid_df))
small_cap_codes = valid_df.sort("total_mv").head(n)["ts_code"]
return df["ts_code"].is_in(small_cap_codes)
pool_manager = StockPoolManager(
filter_func=stock_pool_filter,
required_columns=["total_mv"],
data_router=engine.router,
)
# 5. 创建 ST 过滤器
st_filter = STFilter(data_router=engine.router)
# 6. 创建训练器
trainer = Trainer(
model=model,
pool_manager=pool_manager,
processors=processors,
filters=[st_filter],
splitter=splitter,
target_col="future_return_5",
feature_cols=["ma_5", "ma_20", "volume_ratio", "roe"],
)
# 7. 执行训练
trainer.train(data)
# 8. 获取结果
results = trainer.get_results()
数据处理器
NullFiller - 缺失值填充:
from src.training.components.processors import NullFiller
# 使用 0 填充
filler = NullFiller(strategy="zero")
# 使用均值填充(每天独立计算截面均值)
filler = NullFiller(strategy="mean", by_date=True)
# 使用指定值填充
filler = NullFiller(strategy="value", fill_value=-999)
Winsorizer - 缩尾处理:
from src.training.components.processors import Winsorizer
# 全局缩尾(默认)
winsorizer = Winsorizer(lower=0.01, upper=0.99, by_date=False)
# 每天独立缩尾
winsorizer = Winsorizer(lower=0.01, upper=0.99, by_date=True)
StandardScaler - 标准化:
from src.training.components.processors import StandardScaler
# 全局标准化(学习训练集的均值和标准差)
scaler = StandardScaler(exclude_cols=["ts_code", "trade_date", "target"])
CrossSectionalStandardScaler - 截面标准化:
from src.training.components.processors import CrossSectionalStandardScaler
# 每天独立标准化(不需要 fit)
cs_scaler = CrossSectionalStandardScaler(
exclude_cols=["ts_code", "trade_date", "target"],
date_col="trade_date",
)
组件注册机制
from src.training.registry import register_model, register_processor
from src.training.components.base import BaseModel, BaseProcessor
# 注册自定义模型
@register_model("custom_model")
class CustomModel(BaseModel):
name = "custom_model"
def fit(self, X, y):
# 训练逻辑
return self
def predict(self, X):
# 预测逻辑
return predictions
# 注册自定义处理器
@register_processor("custom_processor")
class CustomProcessor(BaseProcessor):
name = "custom_processor"
def transform(self, X):
# 转换逻辑
return X
AI 行为准则
LSP 检测报错处理
⚠️ 强制要求:当进行 LSP 检测时报错,必定是代码格式问题。
如果 LSP 检测报错,必须按照以下流程处理:
-
问题定位
- 报错必定是由基础格式错误引起:缩进错误、引号括号不匹配、代码格式错误等
- 必须读取对应的代码行,精确定位错误
-
修复方式
- ✅ 必须:读取报错文件,检查具体代码行
- ✅ 必须:修复格式错误(缩进、括号匹配、引号闭合等)
- ❌ 禁止:删除文件重新修改
- ❌ 禁止:自行 rollback 文件
- ❌ 禁止:新建文件重新修改
- ❌ 禁止:忽略错误继续执行
-
验证要求
- 修复后必须重新运行 LSP 检测确认无错误
- 确保修改仅针对格式问题,不改变代码逻辑
示例场景:
LSP 报错:Syntax error on line 45
✅ 正确做法:读取文件第 45 行,发现少了一个右括号,添加后重新检测
❌ 错误做法:删除文件重新写、或者忽略错误继续
Emoji 表情禁用规则
⚠️ 强制要求:代码和测试文件中禁止出现 emoji 表情。
-
禁止范围
- 所有
.py源代码文件 - 所有测试文件 (
tests/目录) - 配置文件、脚本文件
- 所有
-
替代方案
- ❌ 禁止使用:
print("✅ 成功")、print("❌ 失败")、# 📝 注释 - ✅ 应使用:
print("[成功]")、print("[失败]")、# 注释 - 使用方括号
[成功]、[警告]、[错误]等文字标记代替 emoji
- ❌ 禁止使用:
-
唯一例外
- AGENTS.md 文件本身可以使用 emoji 进行文档强调(如本文件中的 ⚠️)
- 项目文档、README 等对外展示文件可以酌情使用
-
检查方法
- 使用正则表达式搜索 emoji:
[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F1E0-\U0001F1FF\u2600-\u26FF\u2700-\u27BF] - 提交前自查,确保无 emoji 混入代码
- 使用正则表达式搜索 emoji: