Files
ProStock/AGENTS.md

28 KiB
Raw Blame History

ProStock 代理指南

A股量化投资框架 - Python 项目,用于量化股票投资分析。

交流语言要求

⚠️ 强制要求:所有沟通和思考过程必须使用中文。

  • 所有与 AI Agent 的交流必须使用中文
  • 代码中的注释和文档字符串使用中文
  • 禁止使用英文进行思考或沟通

构建/检查/测试命令

⚠️ 重要:本项目强制使用 uv 作为 Python 包管理器和运行工具。禁止直接使用 pythonpip 命令。

测试规则: 当修改或查看 tests/ 目录下的代码时,必须使用 pytest 命令进行测试验证。

# 安装依赖(必须使用 uv
uv pip install -e .

# 运行所有测试
uv run pytest

# 运行单个测试文件
uv run pytest tests/test_sync.py

# 运行单个测试类
uv run pytest tests/test_sync.py::TestDataSync

# 运行单个测试方法
uv run pytest tests/test_sync.py::TestDataSync::test_get_all_stock_codes_from_daily

# 使用详细输出运行
uv run pytest -v

# 运行覆盖率测试(如果安装了 pytest-cov
uv run pytest --cov=src --cov-report=term-missing

禁止的命令

以下命令在本项目中严格禁止

# 禁止直接使用 python
python -c "…"           # 禁止!
python script.py          # 禁止!
python -m pytest          # 禁止!
python -m pip install     # 禁止!

# 禁止直接使用 pip
pip install -e .          # 禁止!
pip install package       # 禁止!
pip list                  # 禁止!

正确的 uv 用法

# 运行 Python 代码
uv run python -c "…"           # ✅ 正确
uv run python script.py          # ✅ 正确

# 安装依赖
uv pip install -e .              # ✅ 正确
uv pip install package           # ✅ 正确

# 运行测试
uv run pytest                    # ✅ 正确
uv run pytest tests/test_sync.py # ✅ 正确

项目结构

ProStock/
├── src/                    # 源代码
│   ├── config/            # 配置管理
│   │   ├── __init__.py
│   │   └── settings.py    # pydantic-settings 配置
│   │
│   ├── data/              # 数据获取与存储
│   │   ├── api_wrappers/  # Tushare API 封装
│   │   │   ├── base_sync.py              # 同步基础抽象类
│   │   │   ├── api_daily.py              # 日线数据接口
│   │   │   ├── api_pro_bar.py            # Pro Bar 数据接口
│   │   │   ├── api_stock_basic.py        # 股票基础信息接口
│   │   │   ├── api_trade_cal.py          # 交易日历接口
│   │   │   ├── api_bak_basic.py          # 历史股票列表接口
│   │   │   ├── api_namechange.py         # 股票名称变更接口
│   │   │   ├── api_stock_st.py           # ST股票信息接口
│   │   │   ├── api_daily_basic.py        # 每日指标接口
│   │   │   ├── api_stk_limit.py          # 涨跌停价格接口
│   │   │   ├── financial_data/           # 财务数据接口
│   │   │   │   ├── api_income.py         # 利润表接口
│   │   │   │   ├── api_balance.py        # 资产负债表接口
│   │   │   │   ├── api_cashflow.py       # 现金流量表接口
│   │   │   │   ├── api_fina_indicator.py # 财务指标接口
│   │   │   │   └── api_financial_sync.py # 财务数据同步调度中心
│   │   │   └── __init__.py
│   │   ├── __init__.py
│   │   ├── client.py       # Tushare API 客户端(带速率限制)
│   │   ├── storage.py      # 数据存储核心
│   │   ├── db_manager.py   # DuckDB 表管理和同步
│   │   ├── db_inspector.py # 数据库信息查看工具
│   │   ├── sync.py         # 数据同步调度中心
│   │   ├── sync_registry.py # 同步器注册表
│   │   ├── rate_limiter.py # 令牌桶速率限制器
│   │   ├── catalog.py      # 数据目录管理
│   │   ├── config.py       # 数据模块配置
│   │   ├── utils.py        # 数据模块工具函数
│   │   └── financial_loader.py # 财务数据加载器
│   │
│   ├── factors/           # 因子计算框架DSL 表达式驱动)
│   │   ├── engine/        # 执行引擎子模块
│   │   │   ├── __init__.py        # 导出引擎组件
│   │   │   ├── data_spec.py       # 数据规格定义
│   │   │   ├── data_router.py     # 数据路由器
│   │   │   ├── planner.py         # 执行计划生成器
│   │   │   ├── compute_engine.py  # 计算引擎
│   │   │   ├── schema_cache.py    # 表结构缓存
│   │   │   └── factor_engine.py   # 因子引擎统一入口
│   │   ├── metadata/      # 因子元数据管理
│   │   │   ├── __init__.py        # 导出元数据组件
│   │   │   ├── manager.py         # 因子管理器主类
│   │   │   ├── validator.py       # 字段校验器
│   │   │   └── exceptions.py      # 元数据异常定义
│   │   ├── __init__.py     # 导出所有公开 API
│   │   ├── dsl.py          # DSL 表达式层 - 节点定义和运算符重载
│   │   ├── api.py          # API 层 - 常用符号和函数
│   │   ├── compiler.py     # AST 编译器 - 依赖提取
│   │   ├── translator.py   # Polars 表达式翻译器
│   │   ├── parser.py       # 字符串公式解析器
│   │   ├── registry.py     # 函数注册表
│   │   ├── decorators.py   # 装饰器工具
│   │   └── exceptions.py   # 异常定义
│   │
│   ├── training/          # 训练模块
│   │   ├── core/           # 训练核心组件
│   │   │   ├── __init__.py
│   │   │   ├── trainer.py      # 训练器主类
│   │   │   └── stock_pool_manager.py # 股票池管理器
│   │   ├── components/     # 组件
│   │   │   ├── base.py     # 基础抽象类
│   │   │   ├── splitters.py # 数据划分器
│   │   │   ├── selectors.py # 股票选择器
│   │   │   ├── filters.py  # 数据过滤器
│   │   │   ├── models/     # 模型实现
│   │   │   │   ├── __init__.py
│   │   │   │   ├── lightgbm.py # LightGBM 模型
│   │   │   │   └── lightgbm_lambdarank.py # LambdaRank 排序模型
│   │   │   └── processors/ # 数据处理器
│   │   │       ├── __init__.py
│   │   │       └── transforms.py # 变换处理器
│   │   ├── config/         # 配置
│   │   │   ├── __init__.py
│   │   │   └── config.py   # 训练配置
│   │   ├── registry.py     # 组件注册中心
│   │   └── __init__.py     # 导出所有组件
│   │
│   ├── scripts/           # 脚本工具
│   │   └── register_factors.py    # 因子批量注册脚本
│   │
│   └── experiment/        # 实验代码
│       ├── data/          # 实验数据目录
│       ├── regression.py  # 回归训练流程Python脚本
│       ├── learn_to_rank.py  # 排序学习训练流程Python脚本
│       └── regression.ipynb  # 完整训练流程示例Notebook
│
├── tests/                 # 测试文件
│   ├── test_sync.py
│   ├── test_daily.py
│   ├── test_factor_engine.py
│   ├── test_factor_integration.py
│   ├── test_pro_bar.py
│   ├── test_db_manager.py
│   ├── test_daily_storage.py
│   ├── test_tushare_api.py
│   └── pipeline/
│       └── test_core.py
├── config/                # 配置文件
│   └── .env.local         # 环境变量(不在 git 中)
├── data/                  # 数据存储DuckDB
├── docs/                  # 文档
├── pyproject.toml         # 项目配置
└── README.md

代码风格指南

Python 版本

  • 需要 Python 3.10+
  • 使用现代 Python 特性match/case、海象运算符、类型提示

导入

# 标准库优先
import os
import time
from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional, Dict, Callable
from concurrent.futures import ThreadPoolExecutor
import threading

# 第三方包
import pandas as pd
import numpy as np
import polars as pl
from tqdm import tqdm
from pydantic_settings import BaseSettings

# 本地模块(使用来自 src 的绝对导入)
from src.data.client import TushareClient
from src.data.storage import Storage
from src.config.settings import get_settings

类型提示

  • 始终使用类型提示 用于函数参数和返回值
  • 对可空类型使用 Optional[X]
  • 当可用时使用现代联合语法 X | YPython 3.10+
  • typing 导入类型:OptionalDictCallable
def sync_single_stock(
    self,
    ts_code: str,
    start_date: str,
    end_date: str,
) -> pd.DataFrame:
    ...

文档字符串

  • 使用 Google 风格文档字符串
  • 包含 Args、Returns 部分
  • 第一行保持简短摘要
def get_next_date(date_str: str) -> str:
    """获取给定日期之后的下一天。

    Args:
        date_str: YYYYMMDD 格式的日期

    Returns:
        YYYYMMDD 格式的下一天日期
    """
    ...

命名约定

  • 变量、函数、方法使用 snake_case
  • 类使用 PascalCase
  • 常量使用 UPPER_CASE
  • 私有方法:_leading_underscore
  • 受保护属性:_single_underscore

错误处理

  • 使用特定的异常,不要使用裸 except:
  • 使用上下文记录错误:print(f"[ERROR] 上下文: {e}")
  • 对 API 调用使用指数退避重试逻辑
  • 在关键错误时立即停止(设置停止标志)
try:
    data = api.query(...)
except Exception as e:
    print(f"[ERROR] 获取 {ts_code} 失败: {e}")
    raise  # 记录后重新抛出

配置

  • 对所有配置使用 pydantic-settings
  • config/.env.local 文件加载
  • 环境变量自动转换:tushare_token -> TUSHARE_TOKEN
  • 对配置单例使用 @lru_cache()

数据存储

  • 使用 DuckDB 嵌入式 OLAP 数据库进行持久化
  • 存储在 data/ 目录中(通过 DATA_PATH 环境变量配置)
  • 使用 UPSERT 模式(INSERT OR REPLACE)处理重复数据
  • 多线程场景使用 ThreadSafeStorage.queue_save() + flush() 模式
  • 只读模式支持: 查询时默认启用 read_only=True,避免并发冲突
from src.data.storage import Storage

# 查询模式(只读,推荐用于数据查询)
storage = Storage(read_only=True)  # 默认只读

# 写入模式(用于数据同步)
storage = Storage(read_only=False)

线程与并发

  • 对 I/O 密集型任务API 调用)使用 ThreadPoolExecutor
  • 实现停止标志以实现优雅关闭:threading.Event()
  • 数据同步默认工作线程数10
  • 出错时始终使用 executor.shutdown(wait=False, cancel_futures=True)

日志记录

  • 使用带前缀的 print 语句:[模块名] 消息
  • 错误格式:[ERROR] 上下文: 异常
  • 进度:循环中使用 tqdm

测试

  • 使用 pytest 框架
  • 模拟外部依赖Tushare API
  • 使用 @pytest.fixture 进行测试设置
  • 在导入位置打补丁:patch('src.data.sync.Storage')
  • 测试成功和错误两种情况

日期格式

  • 使用 YYYYMMDD 字符串格式表示日期
  • 辅助函数:get_today_date()get_next_date()
  • 完全同步的默认开始日期:20180101

依赖项

关键包:

  • pandas>=2.0.0 - 数据处理
  • polars>=0.20.0 - 高性能数据处理(因子计算)
  • numpy>=1.24.0 - 数值计算
  • tushare>=2.0.0 - A股数据 API
  • pydantic>=2.0.0pydantic-settings>=2.0.0 - 配置
  • tqdm>=4.65.0 - 进度条
  • lightgbm>=4.0.0 - 机器学习模型
  • pytest - 测试(开发)

环境变量

创建 config/.env.local

TUSHARE_TOKEN=your_token_here
DATA_PATH=data
RATE_LIMIT=100
THREADS=10

常见任务

# 同步所有股票(增量)
uv run python -c "from src.data.sync import sync_all; sync_all()"

# 强制完全同步
uv run python -c "from src.data.sync import sync_all; sync_all(force_full=True)"

# 自定义线程数
uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)"

# 同步财务数据
uv run python -c "from src.data.api_wrappers.financial_data import sync_financial; sync_financial()"

# 运行因子计算测试
uv run pytest tests/test_factor_engine.py -v

Factors 框架设计说明

架构层次

因子框架采用分层设计,从上到下依次是:

API 层 (api.py)
    |
    v
DSL 层 (dsl.py)         <- 因子表达式 (Node)
    |
    v
Compiler (compiler.py)  <- AST 依赖提取
    |
    v
Parser (parser.py)      <- 字符串公式解析器
    |
    v
Registry (registry.py)  <- 函数注册表
    |
    v
Translator (translator.py) <- 翻译为 Polars 表达式
    |
    v
Engine (engine/)        <- 执行引擎
    |   - FactorEngine: 统一入口
    |   - DataRouter: 数据路由
    |   - ExecutionPlanner: 执行计划
    |   - ComputeEngine: 计算引擎
    |
    v
Metadata (metadata/)    <- 因子元数据管理(可选)
    |   - FactorManager: 元数据管理器
    |   - FactorValidator: 字段校验器
    |
    v
数据层 (data_router.py + DuckDB) <- 数据获取和存储

FactorEngine 核心 API

from src.factors import FactorEngine

# 初始化引擎(默认启用 metadata 功能)
engine = FactorEngine()

# 方式1: 使用 DSL 表达式
from src.factors.api import close, ts_mean, cs_rank
engine.register("ma20", ts_mean(close, 20))
engine.register("price_rank", cs_rank(close))

# 方式2: 使用字符串表达式(推荐)
engine.add_factor("ma20", "ts_mean(close, 20)")
engine.add_factor("alpha", "cs_rank(ts_mean(close, 5) - ts_mean(close, 20))")

# 方式3: 从 metadata 查询(需先在 metadata 中定义因子)
engine.add_factor("mom_5d")  # 从 metadata 查询并注册名为 mom_5d 的因子

# 计算因子
result = engine.compute(["ma20", "price_rank"], "20240101", "20240131")

# 查看已注册因子
print(engine.list_registered())

# 预览执行计划
plan = engine.preview_plan("ma20")

因子元数据管理 (metadata 模块)

metadata 模块提供基于 DuckDB 查询 JSONL 文件、零拷贝输出 Polars DataFrame 的因子管理能力。

核心组件:

  • FactorManager: 元数据管理器主类,提供因子增删改查接口
  • FactorValidator: 字段校验器,校验核心字段的存在性和类型
  • 异常类: FactorMetadataError, ValidationError, DuplicateFactorError

因子数据结构:

  • factor_id (str): 全局唯一标识符(如 "F_001"
  • name (str): 可读短名称(如 "mom_5d"
  • desc (str): 详细描述
  • dsl (str): DSL 计算公式
  • 扩展字段: category, author, tags, notes

使用示例:

from src.factors.metadata import FactorManager

# 初始化管理器(默认路径: data/factors.jsonl
manager = FactorManager()

# 添加因子
manager.add_factor({
    "factor_id": "F_001",
    "name": "mom_5d",
    "desc": "5日价格动量截面排序",
    "dsl": "cs_rank(close / ts_delay(close, 5) - 1)",
    "category": "momentum"  # 扩展字段
})

# 根据名称查询因子
df = manager.get_factors_by_name("mom_5d")

# 使用 SQL 条件查询因子
df = manager.search_factors("category = 'momentum'")
df = manager.search_factors("name LIKE 'mom_%'")

# 获取所有因子
df = manager.get_all_factors()

# 获取因子 DSL 表达式
dsl = manager.get_factor_dsl("F_001")

支持的函数

时间序列函数 (ts_*)

  • ts_mean(x, window) - 滚动均值
  • ts_std(x, window) - 滚动标准差
  • ts_max(x, window) - 滚动最大值
  • ts_min(x, window) - 滚动最小值
  • ts_sum(x, window) - 滚动求和
  • ts_delay(x, periods) - 滞后 N 期
  • ts_delta(x, periods) - 差分 N 期
  • ts_corr(x, y, window) - 滚动相关系数
  • ts_cov(x, y, window) - 滚动协方差
  • ts_rank(x, window) - 滚动排名

截面函数 (cs_*)

  • cs_rank(x) - 截面排名(分位数)
  • cs_zscore(x) - Z-Score 标准化
  • cs_neutralize(x, group) - 行业/市值中性化
  • cs_winsorize(x, lower, upper) - 缩尾处理
  • cs_demean(x) - 去均值

数学函数

  • log(x) - 自然对数
  • exp(x) - 指数函数
  • sqrt(x) - 平方根
  • sign(x) - 符号函数
  • abs(x) - 绝对值
  • max_(x, y) / min_(x, y) - 逐元素最值
  • clip(x, lower, upper) - 数值裁剪

条件函数

  • if_(condition, true_val, false_val) - 条件选择
  • where(condition, true_val, false_val) - if_ 的别名

运算符支持

DSL 表达式支持完整的 Python 运算符:

# 算术运算: +, -, *, /, //, %, **
expr1 = (close - open) / open * 100  # 涨跌幅

# 比较运算: ==, !=, <, <=, >, >=
expr2 = close > open  # 是否上涨

# 一元运算: -, +, abs()
expr3 = -change  # 涨跌额取反

# 链式调用
expr4 = ts_mean(cs_rank(close), 20)  # 排名后的20日平滑

异常处理

框架提供清晰的异常类型帮助定位问题:

  • FormulaParseError - 公式解析错误基类
  • UnknownFunctionError - 未知函数错误(提供模糊匹配建议)
  • InvalidSyntaxError - 语法错误
  • EmptyExpressionError - 空表达式错误
  • DuplicateFunctionError - 函数重复注册错误
from src.factors import FormulaParser, FunctionRegistry, UnknownFunctionError

parser = FormulaParser(FunctionRegistry())
try:
    expr = parser.parse("unknown_func(close)")
except UnknownFunctionError as e:
    print(e)  # 显示错误位置和可用函数建议

因子批量注册脚本

src/scripts/register_factors.py 提供批量注册因子到元数据的功能。用户只需在 FACTORS 列表中配置因子定义,脚本自动生成 factor_id 并保存到 factors.jsonl

使用方法:

# 在 register_factors.py 的 FACTORS 列表中定义因子
FACTORS = [
    {
        "name": "mom_5d",
        "desc": "5日价格动量收盘价相对于5日前收盘价的涨跌幅进行截面排名",
        "dsl": "cs_rank(close / ts_delay(close, 5) - 1)",
        "category": "momentum",  # 可选扩展字段
    },
    {
        "name": "volatility_20d",
        "desc": "20日价格波动率收益率的20日滚动标准差",
        "dsl": "ts_std(ts_delta(close, 1) / ts_delay(close, 1), 20)",
        "category": "volatility",
    },
]

# 运行脚本
# uv run python src/scripts/register_factors.py

脚本特性:

  • 自动生成 F_XXX 格式的唯一 ID
  • 自动跳过已存在的因子(通过 name 判断)
  • 支持扩展字段category, author, tags, notes 等)
  • 提供注册结果统计(成功/跳过/失败)

命令行使用:

# 批量注册所有配置的因子
uv run python src/scripts/register_factors.py

Training 模块设计说明

架构概述

Training 模块位于 src/training/ 目录,负责从因子数据到模型训练、预测的完整流程。采用组件化设计,支持数据处理器、模型、过滤器、股票池管理器的灵活组合。

src/training/
├── core/
│   ├── trainer.py          # Trainer 主类
│   └── stock_pool_manager.py # 股票池管理器
├── components/
│   ├── base.py             # BaseModel、BaseProcessor 抽象基类
│   ├── splitters.py        # DateSplitter 日期划分器
│   ├── selectors.py        # 股票选择器(已迁移到 StockPoolManager
│   ├── filters.py          # STFilter 等过滤器
│   ├── models/             # 模型实现
│   │   ├── __init__.py
│   │   ├── lightgbm.py     # LightGBM 回归/分类模型
│   │   └── lightgbm_lambdarank.py # LightGBM LambdaRank 排序模型
│   └── processors/         # 数据处理器
│       ├── __init__.py
│       └── transforms.py   # 变换处理器
├── config/                 # 配置
│   ├── __init__.py
│   └── config.py           # 训练配置
├── registry.py             # 组件注册中心
└── __init__.py             # 导出所有组件

Trainer 核心流程

from src.training import Trainer, DateSplitter, StockPoolManager
from src.training.components.models import LightGBMModel
from src.training.components.processors import Winsorizer, StandardScaler
from src.training.components.filters import STFilter
import polars as pl

# 1. 创建模型
model = LightGBMModel(params={
    "objective": "regression",
    "metric": "mae",
    "num_leaves": 20,
    "learning_rate": 0.01,
    "n_estimators": 1000,
})

# 2. 创建数据划分器(正确的 train/val/test 三分法)
splitter = DateSplitter(
    train_start="20200101",
    train_end="20231231",
    test_start="20250101",
    test_end="20261231",
    val_start="20240101",
    val_end="20241231",
)

# 3. 创建数据处理器
processors = [
    NullFiller(feature_cols=["ma_5", "ma_20", "volume_ratio", "roe"], strategy="mean"),
    Winsorizer(feature_cols=["ma_5", "ma_20", "volume_ratio", "roe"], lower=0.01, upper=0.99),
    StandardScaler(feature_cols=["ma_5", "ma_20", "volume_ratio", "roe"]),
]

# 4. 创建股票池筛选函数
def stock_pool_filter(df: pl.DataFrame) -> pl.Series:
    """筛选小市值股票,排除创业板/科创板/北交所"""
    code_filter = (
        ~df["ts_code"].str.starts_with("300") &  # 排除创业板
        ~df["ts_code"].str.starts_with("688") &  # 排除科创板
        ~df["ts_code"].str.starts_with("8") &    # 排除北交所
        ~df["ts_code"].str.starts_with("9") &
        ~df["ts_code"].str.starts_with("4")
    )
    valid_df = df.filter(code_filter)
    n = min(1000, len(valid_df))
    small_cap_codes = valid_df.sort("total_mv").head(n)["ts_code"]
    return df["ts_code"].is_in(small_cap_codes)

pool_manager = StockPoolManager(
    filter_func=stock_pool_filter,
    required_columns=["total_mv"],
    data_router=engine.router,
)

# 5. 创建 ST 过滤器
st_filter = STFilter(data_router=engine.router)

# 6. 创建训练器
trainer = Trainer(
    model=model,
    pool_manager=pool_manager,
    processors=processors,
    filters=[st_filter],
    splitter=splitter,
    target_col="future_return_5",
    feature_cols=["ma_5", "ma_20", "volume_ratio", "roe"],
)

# 7. 执行训练
trainer.train(data)

# 8. 获取结果
results = trainer.get_results()

数据处理器

NullFiller - 缺失值填充:

from src.training.components.processors import NullFiller

# 使用 0 填充
filler = NullFiller(feature_cols=["factor1", "factor2"], strategy="zero")

# 使用均值填充(每天独立计算截面均值)
filler = NullFiller(
    feature_cols=["factor1", "factor2"],
    strategy="mean",
    by_date=True
)

# 使用指定值填充
filler = NullFiller(
    feature_cols=["factor1", "factor2"],
    strategy="value",
    fill_value=-999
)

Winsorizer - 缩尾处理:

from src.training.components.processors import Winsorizer

# 全局缩尾(默认)
winsorizer = Winsorizer(
    feature_cols=["factor1", "factor2"],
    lower=0.01,
    upper=0.99,
    by_date=False
)

# 每天独立缩尾
winsorizer = Winsorizer(
    feature_cols=["factor1", "factor2"],
    lower=0.01,
    upper=0.99,
    by_date=True
)

StandardScaler - 标准化:

from src.training.components.processors import StandardScaler

# 全局标准化(学习训练集的均值和标准差)
scaler = StandardScaler(feature_cols=["factor1", "factor2", "factor3"])

CrossSectionalStandardScaler - 截面标准化:

from src.training.components.processors import CrossSectionalStandardScaler

# 每天独立标准化(不需要 fit
cs_scaler = CrossSectionalStandardScaler(
    feature_cols=["factor1", "factor2", "factor3"],
    date_col="trade_date",
)

排序学习 (LambdaRank)

LightGBMLambdaRankModel - 基于 LambdaRank 的排序学习模型,适用于股票排序任务:

from src.training.components.models import LightGBMLambdaRankModel
from src.training import Trainer

# 创建排序学习模型
rank_model = LightGBMLambdaRankModel(
    params={
        "objective": "lambdarank",
        "metric": "ndcg",
        "ndcg_eval_at": [1, 5, 10, 20],
        "num_leaves": 31,
        "learning_rate": 0.05,
        "n_estimators": 500,
        "label_gain": [i for i in range(21)],  # 20分位数
    }
)

# 创建训练器(注意:排序学习需要 qid 分组)
trainer = Trainer(
    model=rank_model,
    pool_manager=pool_manager,
    processors=processors,
    filters=[st_filter],
    splitter=splitter,
    target_col="label",  # 必须是整数标签(分位数编码)
    feature_cols=feature_cols,
    date_col="trade_date",  # 必须指定,用于构建 qid
)

# 训练并评估
results = trainer.train(data)

关键特性:

  • LambdaRank 目标函数: 使用 LightGBM 的 lambdarank 优化排序
  • NDCG 评估: 支持 NDCG@1/5/10/20 指标评估排序质量
  • 自动分组: 根据 date_col 自动构建 query group (qid)
  • Label 要求: 目标变量必须是整数(如分位数编码的等级)

使用场景:

  • 将未来收益率转换为分位数等级作为 label
  • 学习每日股票的相对排序
  • 构建 Top-k 选股策略

组件注册机制

from src.training.registry import register_model, register_processor
from src.training.components.base import BaseModel, BaseProcessor

# 注册自定义模型
@register_model("custom_model")
class CustomModel(BaseModel):
    name = "custom_model"
    
    def fit(self, X, y):
        # 训练逻辑
        return self
    
    def predict(self, X):
        # 预测逻辑
        return predictions

# 注册自定义处理器
@register_processor("custom_processor")
class CustomProcessor(BaseProcessor):
    name = "custom_processor"
    
    def transform(self, X):
        # 转换逻辑
        return X

AI 行为准则

LSP 检测报错处理

⚠️ 强制要求:当进行 LSP 检测时报错,必定是代码格式问题。

如果 LSP 检测报错,必须按照以下流程处理:

  1. 问题定位

    • 报错必定是由基础格式错误引起:缩进错误、引号括号不匹配、代码格式错误等
    • 必须读取对应的代码行,精确定位错误
  2. 修复方式

    • 必须:读取报错文件,检查具体代码行
    • 必须:修复格式错误(缩进、括号匹配、引号闭合等)
    • 禁止:删除文件重新修改
    • 禁止:自行 rollback 文件
    • 禁止:新建文件重新修改
    • 禁止:忽略错误继续执行
  3. 验证要求

    • 修复后必须重新运行 LSP 检测确认无错误
    • 确保修改仅针对格式问题,不改变代码逻辑

示例场景

LSP 报错Syntax error on line 45
✅ 正确做法:读取文件第 45 行,发现少了一个右括号,添加后重新检测
❌ 错误做法:删除文件重新写、或者忽略错误继续

Emoji 表情禁用规则

⚠️ 强制要求:代码和测试文件中禁止出现 emoji 表情。

  1. 禁止范围

    • 所有 .py 源代码文件
    • 所有测试文件 (tests/ 目录)
    • 配置文件、脚本文件
  2. 替代方案

    • 禁止使用:print("✅ 成功")print("❌ 失败")# 📝 注释
    • 应使用:print("[成功]")print("[失败]")# 注释
    • 使用方括号 [成功][警告][错误] 等文字标记代替 emoji
  3. 唯一例外

    • AGENTS.md 文件本身可以使用 emoji 进行文档强调(如本文件中的 ⚠️
    • 项目文档、README 等对外展示文件可以酌情使用
  4. 检查方法

    • 使用正则表达式搜索 emoji[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F1E0-\U0001F1FF\u2600-\u26FF\u2700-\u27BF]
    • 提交前自查,确保无 emoji 混入代码