# ProStock 代理指南 A股量化投资框架 - Python 项目,用于量化股票投资分析。 ## 交流语言要求 **⚠️ 强制要求:所有沟通和思考过程必须使用中文。** - 所有与 AI Agent 的交流必须使用中文 - 代码中的注释和文档字符串使用中文 - 禁止使用英文进行思考或沟通 ## 构建/检查/测试命令 **⚠️ 重要:本项目强制使用 uv 作为 Python 包管理器和运行工具。禁止直接使用 `python` 或 `pip` 命令。** **测试规则:** 当修改或查看 `tests/` 目录下的代码时,必须使用 pytest 命令进行测试验证。 ```bash # 安装依赖(必须使用 uv) uv pip install -e . # 运行所有测试 uv run pytest # 运行单个测试文件 uv run pytest tests/test_sync.py # 运行单个测试类 uv run pytest tests/test_sync.py::TestDataSync # 运行单个测试方法 uv run pytest tests/test_sync.py::TestDataSync::test_get_all_stock_codes_from_daily # 使用详细输出运行 uv run pytest -v # 运行覆盖率测试(如果安装了 pytest-cov) uv run pytest --cov=src --cov-report=term-missing ``` ### 禁止的命令 ❌ 以下命令在本项目中**严格禁止**: ```bash # 禁止直接使用 python python -c "…" # 禁止! python script.py # 禁止! python -m pytest # 禁止! python -m pip install # 禁止! # 禁止直接使用 pip pip install -e . # 禁止! pip install package # 禁止! pip list # 禁止! ``` ### 正确的 uv 用法 ✅ ```bash # 运行 Python 代码 uv run python -c "…" # ✅ 正确 uv run python script.py # ✅ 正确 # 安装依赖 uv pip install -e . # ✅ 正确 uv pip install package # ✅ 正确 # 运行测试 uv run pytest # ✅ 正确 uv run pytest tests/test_sync.py # ✅ 正确 ``` ## 项目结构 ``` ProStock/ ├── src/ # 源代码 │ ├── config/ # 配置管理 │ │ ├── __init__.py │ │ └── settings.py # pydantic-settings 配置 │ │ │ ├── data/ # 数据获取与存储 │ │ ├── api_wrappers/ # Tushare API 封装 │ │ │ ├── base_sync.py # 同步基础抽象类 │ │ │ ├── api_pro_bar.py # Pro Bar 行情数据接口(主用) │ │ │ ├── api_stock_basic.py # 股票基础信息接口 │ │ │ ├── api_trade_cal.py # 交易日历接口 │ │ │ ├── api_bak_basic.py # 历史股票列表接口 │ │ │ ├── api_namechange.py # 股票名称变更接口 │ │ │ ├── api_stock_st.py # ST股票信息接口 │ │ │ ├── api_daily_basic.py # 每日指标接口 │ │ │ ├── api_stk_limit.py # 涨跌停价格接口 │ │ │ ├── financial_data/ # 财务数据接口 │ │ │ │ ├── api_income.py # 利润表接口 │ │ │ │ ├── api_balance.py # 资产负债表接口 │ │ │ │ ├── api_cashflow.py # 现金流量表接口 │ │ │ │ ├── api_fina_indicator.py # 财务指标接口 │ │ │ │ └── api_financial_sync.py # 财务数据同步调度中心 │ │ │ └── __init__.py │ │ ├── __init__.py │ │ ├── client.py # Tushare API 客户端(带速率限制) │ │ ├── storage.py # 数据存储核心 │ │ ├── db_manager.py # DuckDB 表管理和同步 │ │ ├── db_inspector.py # 数据库信息查看工具 │ │ ├── sync.py # 数据同步调度中心 │ │ ├── sync_registry.py # 同步器注册表 │ │ ├── rate_limiter.py # 令牌桶速率限制器 │ │ ├── catalog.py # 数据目录管理 │ │ ├── config.py # 数据模块配置 │ │ ├── utils.py # 数据模块工具函数 │ │ └── financial_loader.py # 财务数据加载器 │ │ │ ├── factors/ # 因子计算框架(DSL 表达式驱动) │ │ ├── engine/ # 执行引擎子模块 │ │ │ ├── __init__.py # 导出引擎组件 │ │ │ ├── data_spec.py # 数据规格定义 │ │ │ ├── data_router.py # 数据路由器 │ │ │ ├── planner.py # 执行计划生成器 │ │ │ ├── compute_engine.py # 计算引擎 │ │ │ ├── schema_cache.py # 表结构缓存 │ │ │ └── factor_engine.py # 因子引擎统一入口 │ │ ├── metadata/ # 因子元数据管理 │ │ │ ├── __init__.py # 导出元数据组件 │ │ │ ├── manager.py # 因子管理器主类 │ │ │ ├── validator.py # 字段校验器 │ │ │ └── exceptions.py # 元数据异常定义 │ │ ├── __init__.py # 导出所有公开 API │ │ ├── dsl.py # DSL 表达式层 - 节点定义和运算符重载 │ │ ├── api.py # API 层 - 常用符号和函数 │ │ ├── compiler.py # AST 编译器 - 依赖提取 │ │ ├── translator.py # Polars 表达式翻译器 │ │ ├── parser.py # 字符串公式解析器 │ │ ├── registry.py # 函数注册表 │ │ ├── decorators.py # 装饰器工具 │ │ └── exceptions.py # 异常定义 │ │ │ ├── training/ # 训练模块 │ │ ├── core/ # 训练核心组件 │ │ │ ├── __init__.py │ │ │ ├── trainer.py # 训练器主类 │ │ │ └── stock_pool_manager.py # 股票池管理器 │ │ ├── components/ # 组件 │ │ │ ├── base.py # 基础抽象类 │ │ │ ├── splitters.py # 数据划分器 │ │ │ ├── selectors.py # 股票选择器 │ │ │ ├── filters.py # 数据过滤器 │ │ │ ├── models/ # 模型实现 │ │ │ │ ├── __init__.py │ │ │ │ ├── lightgbm.py # LightGBM 模型 │ │ │ │ └── lightgbm_lambdarank.py # LambdaRank 排序模型 │ │ │ └── processors/ # 数据处理器 │ │ │ ├── __init__.py │ │ │ └── transforms.py # 变换处理器 │ │ ├── config/ # 配置 │ │ │ ├── __init__.py │ │ │ └── config.py # 训练配置 │ │ ├── registry.py # 组件注册中心 │ │ └── __init__.py # 导出所有组件 │ │ │ ├── scripts/ # 脚本工具 │ │ └── register_factors.py # 因子批量注册脚本 │ │ │ └── experiment/ # 实验代码 │ ├── data/ # 实验数据目录 │ ├── regression.py # 回归训练流程(Python脚本) │ ├── learn_to_rank.py # 排序学习训练流程(Python脚本) │ └── regression.ipynb # 完整训练流程示例(Notebook) │ ├── tests/ # 测试文件 │ ├── test_sync.py │ ├── test_daily.py │ ├── test_factor_engine.py │ ├── test_factor_integration.py │ ├── test_pro_bar.py │ ├── test_db_manager.py │ ├── test_daily_storage.py │ ├── test_tushare_api.py │ └── pipeline/ │ └── test_core.py ├── config/ # 配置文件 │ └── .env.local # 环境变量(不在 git 中) ├── data/ # 数据存储(DuckDB) ├── docs/ # 文档 ├── pyproject.toml # 项目配置 └── README.md ``` ## 代码风格指南 ### Python 版本 - **需要 Python 3.10+** - 使用现代 Python 特性(match/case、海象运算符、类型提示) ### 导入 ```python # 标准库优先 import os import time from datetime import datetime, timedelta from pathlib import Path from typing import Optional, Dict, Callable from concurrent.futures import ThreadPoolExecutor import threading # 第三方包 import pandas as pd import numpy as np import polars as pl from tqdm import tqdm from pydantic_settings import BaseSettings # 本地模块(使用来自 src 的绝对导入) from src.data.client import TushareClient from src.data.storage import Storage from src.config.settings import get_settings ``` ### 类型提示 - **始终使用类型提示** 用于函数参数和返回值 - 对可空类型使用 `Optional[X]` - 当可用时使用现代联合语法 `X | Y`(Python 3.10+) - 从 `typing` 导入类型:`Optional`、`Dict`、`Callable` 等 ```python def sync_single_stock( self, ts_code: str, start_date: str, end_date: str, ) -> pd.DataFrame: ... ``` ### 文档字符串 - 使用 **Google 风格文档字符串** - 包含 Args、Returns 部分 - 第一行保持简短摘要 ```python def get_next_date(date_str: str) -> str: """获取给定日期之后的下一天。 Args: date_str: YYYYMMDD 格式的日期 Returns: YYYYMMDD 格式的下一天日期 """ ... ``` ### 命名约定 - 变量、函数、方法使用 `snake_case` - 类使用 `PascalCase` - 常量使用 `UPPER_CASE` - 私有方法:`_leading_underscore` - 受保护属性:`_single_underscore` ### 错误处理 - 使用特定的异常,不要使用裸 `except:` - 使用上下文记录错误:`print(f"[ERROR] 上下文: {e}")` - 对 API 调用使用指数退避重试逻辑 - 在关键错误时立即停止(设置停止标志) ```python try: data = api.query(...) except Exception as e: print(f"[ERROR] 获取 {ts_code} 失败: {e}") raise # 记录后重新抛出 ``` ### 配置 - 对所有配置使用 **pydantic-settings** - 从 `config/.env.local` 文件加载 - 环境变量自动转换:`tushare_token` -> `TUSHARE_TOKEN` - 对配置单例使用 `@lru_cache()` ### 数据存储 - 使用 **DuckDB** 嵌入式 OLAP 数据库进行持久化 - 存储在 `data/` 目录中(通过 `DATA_PATH` 环境变量配置) - 使用 UPSERT 模式(`INSERT OR REPLACE`)处理重复数据 - 多线程场景使用 `ThreadSafeStorage.queue_save()` + `flush()` 模式 - **只读模式支持**: 查询时默认启用 `read_only=True`,避免并发冲突 ```python from src.data.storage import Storage # 查询模式(只读,推荐用于数据查询) storage = Storage(read_only=True) # 默认只读 # 写入模式(用于数据同步) storage = Storage(read_only=False) ``` ### 财务数据表与 PIT 策略 **重要**: 并非所有财务数据表都支持 PIT(Point-In-Time)策略。 **支持 PIT 的财务表**(有 `f_ann_date` 列): - `financial_income` - 利润表 - `financial_balance` - 资产负债表 - `financial_cashflow` - 现金流量表 **不支持 PIT 的财务表**(只有 `ann_date` 列): - `financial_fina_indicator` - 财务指标表 **因子字段路由规则**: 因子引擎在动态路由字段时会自动识别财务表。对于同时存在于多个表的字段(如 `ebit`): - 如果字段存在于 `fina_indicator` 表和其他财务表,会优先路由到支持 PIT 的表 - `fina_indicator` 表由于缺少 `f_ann_date`,**被排除在动态字段路由之外** - 这确保了所有财务数据都能正确应用 PIT 策略,避免未来数据泄露 ### 线程与并发 - 对 I/O 密集型任务(API 调用)使用 `ThreadPoolExecutor` - 实现停止标志以实现优雅关闭:`threading.Event()` - 数据同步默认工作线程数:10 - 出错时始终使用 `executor.shutdown(wait=False, cancel_futures=True)` ### 日志记录 - 使用带前缀的 print 语句:`[模块名] 消息` - 错误格式:`[ERROR] 上下文: 异常` - 进度:循环中使用 `tqdm` ### 测试 - 使用 **pytest** 框架 - 模拟外部依赖(Tushare API) - 使用 `@pytest.fixture` 进行测试设置 - 在导入位置打补丁:`patch('src.data.sync.Storage')` - 测试成功和错误两种情况 ### 日期格式 - 使用 `YYYYMMDD` 字符串格式表示日期 - 辅助函数:`get_today_date()`、`get_next_date()` - 完全同步的默认开始日期:`20180101` ### 依赖项 关键包: - `pandas>=2.0.0` - 数据处理 - `polars>=0.20.0` - 高性能数据处理(因子计算) - `numpy>=1.24.0` - 数值计算 - `tushare>=2.0.0` - A股数据 API - `pydantic>=2.0.0`、`pydantic-settings>=2.0.0` - 配置 - `tqdm>=4.65.0` - 进度条 - `lightgbm>=4.0.0` - 机器学习模型 - `pytest` - 测试(开发) ### 环境变量 创建 `config/.env.local`: ```bash TUSHARE_TOKEN=your_token_here DATA_PATH=data RATE_LIMIT=100 THREADS=10 ``` ## 常见任务 ```bash # 同步所有股票(增量) uv run python -c "from src.data.sync import sync_all; sync_all()" # 强制完全同步 uv run python -c "from src.data.sync import sync_all; sync_all(force_full=True)" # 自定义线程数 uv run python -c "from src.data.sync import sync_all; sync_all(max_workers=20)" # 同步财务数据 uv run python -c "from src.data.api_wrappers.financial_data import sync_financial; sync_financial()" # 运行因子计算测试 uv run pytest tests/test_factor_engine.py -v ``` ## Factors 框架设计说明 ### 架构层次 因子框架采用分层设计,从上到下依次是: ``` API 层 (api.py) | v DSL 层 (dsl.py) <- 因子表达式 (Node) | v Compiler (compiler.py) <- AST 依赖提取 | v Parser (parser.py) <- 字符串公式解析器 | v Registry (registry.py) <- 函数注册表 | v Translator (translator.py) <- 翻译为 Polars 表达式 | v Engine (engine/) <- 执行引擎 | - FactorEngine: 统一入口 | - DataRouter: 数据路由 | - ExecutionPlanner: 执行计划 | - ComputeEngine: 计算引擎 | v Metadata (metadata/) <- 因子元数据管理(可选) | - FactorManager: 元数据管理器 | - FactorValidator: 字段校验器 | v 数据层 (data_router.py + DuckDB) <- 数据获取和存储 ``` ### FactorEngine 核心 API ```python from src.factors import FactorEngine # 初始化引擎(默认启用 metadata 功能) engine = FactorEngine() # 方式1: 使用 DSL 表达式 from src.factors.api import close, ts_mean, cs_rank engine.register("ma20", ts_mean(close, 20)) engine.register("price_rank", cs_rank(close)) # 方式2: 使用字符串表达式(推荐) engine.add_factor("ma20", "ts_mean(close, 20)") engine.add_factor("alpha", "cs_rank(ts_mean(close, 5) - ts_mean(close, 20))") # 方式3: 从 metadata 查询(需先在 metadata 中定义因子) engine.add_factor("mom_5d") # 从 metadata 查询并注册名为 mom_5d 的因子 # 计算因子 result = engine.compute(["ma20", "price_rank"], "20240101", "20240131") # 查看已注册因子 print(engine.list_registered()) # 预览执行计划 plan = engine.preview_plan("ma20") ``` ### 因子元数据管理 (metadata 模块) metadata 模块提供基于 DuckDB 查询 JSONL 文件、零拷贝输出 Polars DataFrame 的因子管理能力。 **核心组件:** - `FactorManager`: 元数据管理器主类,提供因子增删改查接口 - `FactorValidator`: 字段校验器,校验核心字段的存在性和类型 - 异常类: `FactorMetadataError`, `ValidationError`, `DuplicateFactorError` 等 **因子数据结构:** - `factor_id` (str): 全局唯一标识符(如 "F_001") - `name` (str): 可读短名称(如 "mom_5d") - `desc` (str): 详细描述 - `dsl` (str): DSL 计算公式 - 扩展字段: `category`, `author`, `tags`, `notes` 等 **使用示例:** ```python from src.factors.metadata import FactorManager # 初始化管理器(默认路径: data/factors.jsonl) manager = FactorManager() # 添加因子 manager.add_factor({ "factor_id": "F_001", "name": "mom_5d", "desc": "5日价格动量截面排序", "dsl": "cs_rank(close / ts_delay(close, 5) - 1)", "category": "momentum" # 扩展字段 }) # 根据名称查询因子 df = manager.get_factors_by_name("mom_5d") # 使用 SQL 条件查询因子 df = manager.search_factors("category = 'momentum'") df = manager.search_factors("name LIKE 'mom_%'") # 获取所有因子 df = manager.get_all_factors() # 获取因子 DSL 表达式 dsl = manager.get_factor_dsl("F_001") ``` ### 支持的函数 **时间序列函数 (ts_*)**: - `ts_mean(x, window)` - 滚动均值 - `ts_std(x, window)` - 滚动标准差 - `ts_max(x, window)` - 滚动最大值 - `ts_min(x, window)` - 滚动最小值 - `ts_sum(x, window)` - 滚动求和 - `ts_delay(x, periods)` - 滞后 N 期 - `ts_delta(x, periods)` - 差分 N 期 - `ts_corr(x, y, window)` - 滚动相关系数 - `ts_cov(x, y, window)` - 滚动协方差 - `ts_rank(x, window)` - 滚动排名 **截面函数 (cs_*)**: - `cs_rank(x)` - 截面排名(分位数) - `cs_zscore(x)` - Z-Score 标准化 - `cs_neutralize(x, group)` - 行业/市值中性化 - `cs_winsorize(x, lower, upper)` - 缩尾处理 - `cs_demean(x)` - 去均值 **数学函数**: - `log(x)` - 自然对数 - `exp(x)` - 指数函数 - `sqrt(x)` - 平方根 - `sign(x)` - 符号函数 - `abs(x)` - 绝对值 - `max_(x, y)` / `min_(x, y)` - 逐元素最值 - `clip(x, lower, upper)` - 数值裁剪 **条件函数**: - `if_(condition, true_val, false_val)` - 条件选择 - `where(condition, true_val, false_val)` - if_ 的别名 ### 运算符支持 DSL 表达式支持完整的 Python 运算符: ```python # 算术运算: +, -, *, /, //, %, ** expr1 = (close - open) / open * 100 # 涨跌幅 # 比较运算: ==, !=, <, <=, >, >= expr2 = close > open # 是否上涨 # 一元运算: -, +, abs() expr3 = -change # 涨跌额取反 # 链式调用 expr4 = ts_mean(cs_rank(close), 20) # 排名后的20日平滑 ``` ### 异常处理 框架提供清晰的异常类型帮助定位问题: - `FormulaParseError` - 公式解析错误基类 - `UnknownFunctionError` - 未知函数错误(提供模糊匹配建议) - `InvalidSyntaxError` - 语法错误 - `EmptyExpressionError` - 空表达式错误 - `DuplicateFunctionError` - 函数重复注册错误 ```python from src.factors import FormulaParser, FunctionRegistry, UnknownFunctionError parser = FormulaParser(FunctionRegistry()) try: expr = parser.parse("unknown_func(close)") except UnknownFunctionError as e: print(e) # 显示错误位置和可用函数建议 ``` ### 因子批量注册脚本 `src/scripts/register_factors.py` 提供批量注册因子到元数据的功能。用户只需在 `FACTORS` 列表中配置因子定义,脚本自动生成 `factor_id` 并保存到 `factors.jsonl`。 **使用方法:** ```python # 在 register_factors.py 的 FACTORS 列表中定义因子 FACTORS = [ { "name": "mom_5d", "desc": "5日价格动量,收盘价相对于5日前收盘价的涨跌幅进行截面排名", "dsl": "cs_rank(close / ts_delay(close, 5) - 1)", "category": "momentum", # 可选扩展字段 }, { "name": "volatility_20d", "desc": "20日价格波动率,收益率的20日滚动标准差", "dsl": "ts_std(ts_delta(close, 1) / ts_delay(close, 1), 20)", "category": "volatility", }, ] # 运行脚本 # uv run python src/scripts/register_factors.py ``` **脚本特性:** - 自动生成 `F_XXX` 格式的唯一 ID - 自动跳过已存在的因子(通过 `name` 判断) - 支持扩展字段(category, author, tags, notes 等) - 提供注册结果统计(成功/跳过/失败) **命令行使用:** ```bash # 批量注册所有配置的因子 uv run python src/scripts/register_factors.py ``` ## Training 模块设计说明 ### 架构概述 Training 模块位于 `src/training/` 目录,负责从因子数据到模型训练、预测的完整流程。采用组件化设计,支持数据处理器、模型、过滤器、股票池管理器的灵活组合。 ``` src/training/ ├── core/ │ ├── trainer.py # Trainer 主类 │ └── stock_pool_manager.py # 股票池管理器 ├── components/ │ ├── base.py # BaseModel、BaseProcessor 抽象基类 │ ├── splitters.py # DateSplitter 日期划分器 │ ├── selectors.py # 股票选择器(已迁移到 StockPoolManager) │ ├── filters.py # STFilter 等过滤器 │ ├── models/ # 模型实现 │ │ ├── __init__.py │ │ ├── lightgbm.py # LightGBM 回归/分类模型 │ │ └── lightgbm_lambdarank.py # LightGBM LambdaRank 排序模型 │ └── processors/ # 数据处理器 │ ├── __init__.py │ └── transforms.py # 变换处理器 ├── config/ # 配置 │ ├── __init__.py │ └── config.py # 训练配置 ├── registry.py # 组件注册中心 └── __init__.py # 导出所有组件 ``` ### Trainer 核心流程 ```python from src.training import Trainer, DateSplitter, StockPoolManager from src.training.components.models import LightGBMModel from src.training.components.processors import Winsorizer, StandardScaler from src.training.components.filters import STFilter import polars as pl # 1. 创建模型 model = LightGBMModel(params={ "objective": "regression", "metric": "mae", "num_leaves": 20, "learning_rate": 0.01, "n_estimators": 1000, }) # 2. 创建数据划分器(正确的 train/val/test 三分法) splitter = DateSplitter( train_start="20200101", train_end="20231231", test_start="20250101", test_end="20261231", val_start="20240101", val_end="20241231", ) # 3. 创建数据处理器 processors = [ NullFiller(feature_cols=["ma_5", "ma_20", "volume_ratio", "roe"], strategy="mean"), Winsorizer(feature_cols=["ma_5", "ma_20", "volume_ratio", "roe"], lower=0.01, upper=0.99), StandardScaler(feature_cols=["ma_5", "ma_20", "volume_ratio", "roe"]), ] # 4. 创建股票池筛选函数 def stock_pool_filter(df: pl.DataFrame) -> pl.Series: """筛选小市值股票,排除创业板/科创板/北交所""" code_filter = ( ~df["ts_code"].str.starts_with("300") & # 排除创业板 ~df["ts_code"].str.starts_with("688") & # 排除科创板 ~df["ts_code"].str.starts_with("8") & # 排除北交所 ~df["ts_code"].str.starts_with("9") & ~df["ts_code"].str.starts_with("4") ) valid_df = df.filter(code_filter) n = min(1000, len(valid_df)) small_cap_codes = valid_df.sort("total_mv").head(n)["ts_code"] return df["ts_code"].is_in(small_cap_codes) pool_manager = StockPoolManager( filter_func=stock_pool_filter, required_columns=["total_mv"], data_router=engine.router, ) # 5. 创建 ST 过滤器 st_filter = STFilter(data_router=engine.router) # 6. 创建训练器 trainer = Trainer( model=model, pool_manager=pool_manager, processors=processors, filters=[st_filter], splitter=splitter, target_col="future_return_5", feature_cols=["ma_5", "ma_20", "volume_ratio", "roe"], ) # 7. 执行训练 trainer.train(data) # 8. 获取结果 results = trainer.get_results() ``` ### 数据处理器 **NullFiller** - 缺失值填充: ```python from src.training.components.processors import NullFiller # 使用 0 填充 filler = NullFiller(feature_cols=["factor1", "factor2"], strategy="zero") # 使用均值填充(每天独立计算截面均值) filler = NullFiller( feature_cols=["factor1", "factor2"], strategy="mean", by_date=True ) # 使用指定值填充 filler = NullFiller( feature_cols=["factor1", "factor2"], strategy="value", fill_value=-999 ) ``` **Winsorizer** - 缩尾处理: ```python from src.training.components.processors import Winsorizer # 全局缩尾(默认) winsorizer = Winsorizer( feature_cols=["factor1", "factor2"], lower=0.01, upper=0.99, by_date=False ) # 每天独立缩尾 winsorizer = Winsorizer( feature_cols=["factor1", "factor2"], lower=0.01, upper=0.99, by_date=True ) ``` **StandardScaler** - 标准化: ```python from src.training.components.processors import StandardScaler # 全局标准化(学习训练集的均值和标准差) scaler = StandardScaler(feature_cols=["factor1", "factor2", "factor3"]) ``` **CrossSectionalStandardScaler** - 截面标准化: ```python from src.training.components.processors import CrossSectionalStandardScaler # 每天独立标准化(不需要 fit) cs_scaler = CrossSectionalStandardScaler( feature_cols=["factor1", "factor2", "factor3"], date_col="trade_date", ) ``` ### 排序学习 (LambdaRank) **LightGBMLambdaRankModel** - 基于 LambdaRank 的排序学习模型,适用于股票排序任务: ```python from src.training.components.models import LightGBMLambdaRankModel from src.training import Trainer # 创建排序学习模型 rank_model = LightGBMLambdaRankModel( params={ "objective": "lambdarank", "metric": "ndcg", "ndcg_eval_at": [1, 5, 10, 20], "num_leaves": 31, "learning_rate": 0.05, "n_estimators": 500, "label_gain": [i for i in range(21)], # 20分位数 } ) # 创建训练器(注意:排序学习需要 qid 分组) trainer = Trainer( model=rank_model, pool_manager=pool_manager, processors=processors, filters=[st_filter], splitter=splitter, target_col="label", # 必须是整数标签(分位数编码) feature_cols=feature_cols, date_col="trade_date", # 必须指定,用于构建 qid ) # 训练并评估 results = trainer.train(data) ``` **关键特性:** - **LambdaRank 目标函数**: 使用 LightGBM 的 lambdarank 优化排序 - **NDCG 评估**: 支持 NDCG@1/5/10/20 指标评估排序质量 - **自动分组**: 根据 `date_col` 自动构建 query group (qid) - **Label 要求**: 目标变量必须是整数(如分位数编码的等级) **使用场景:** - 将未来收益率转换为分位数等级作为 label - 学习每日股票的相对排序 - 构建 Top-k 选股策略 ### 组件注册机制 ```python from src.training.registry import register_model, register_processor from src.training.components.base import BaseModel, BaseProcessor # 注册自定义模型 @register_model("custom_model") class CustomModel(BaseModel): name = "custom_model" def fit(self, X, y): # 训练逻辑 return self def predict(self, X): # 预测逻辑 return predictions # 注册自定义处理器 @register_processor("custom_processor") class CustomProcessor(BaseProcessor): name = "custom_processor" def transform(self, X): # 转换逻辑 return X ``` ## AI 行为准则 ### LSP 检测报错处理 **⚠️ 强制要求:当进行 LSP 检测时报错,必定是代码格式问题。** 如果 LSP 检测报错,必须按照以下流程处理: 1. **问题定位** - 报错必定是由基础格式错误引起:缩进错误、引号括号不匹配、代码格式错误等 - 必须读取对应的代码行,精确定位错误 2. **修复方式** - ✅ **必须**:读取报错文件,检查具体代码行 - ✅ **必须**:修复格式错误(缩进、括号匹配、引号闭合等) - ❌ **禁止**:删除文件重新修改 - ❌ **禁止**:自行 rollback 文件 - ❌ **禁止**:新建文件重新修改 - ❌ **禁止**:忽略错误继续执行 3. **验证要求** - 修复后必须重新运行 LSP 检测确认无错误 - 确保修改仅针对格式问题,不改变代码逻辑 **示例场景**: ``` LSP 报错:Syntax error on line 45 ✅ 正确做法:读取文件第 45 行,发现少了一个右括号,添加后重新检测 ❌ 错误做法:删除文件重新写、或者忽略错误继续 ``` ### 代码存放位置规则 **⚠️ 强制要求:所有代码必须存放在 `src/` 或 `tests/` 目录下。** 1. **源代码位置** - 所有正式功能代码必须放在 `src/` 目录下 - 按照模块分类存放(`src/data/`、`src/factors/`、`src/training/` 等) 2. **测试代码位置** - 所有测试代码必须放在 `tests/` 目录下 - **临时测试代码**:任何临时性、探索性的测试脚本也必须写在 `tests/` 目录下 - 禁止在项目根目录或其他位置创建临时测试文件 3. **禁止事项** - ❌ 禁止在项目根目录创建 `.py` 文件 - ❌ 禁止在 `docs/`、`config/`、`data/` 等目录存放代码文件 - ❌ 禁止创建 `test_xxx.py`、`tmp_xxx.py`、`scratch_xxx.py` 等临时文件在项目根目录 4. **正确示例** ``` ✅ src/data/new_feature.py # 新功能代码 ✅ tests/test_new_feature.py # 正式测试 ✅ tests/scratch/experiment.py # 临时实验代码(在 tests 下) ``` ### Tests 目录代码运行规则 **⚠️ 强制要求:`tests/` 目录下的代码必须使用 pytest 指令来运行。** 1. **运行方式** - ✅ **必须**:使用 `uv run pytest tests/xxx.py` 运行测试文件 - ❌ **禁止**:直接使用 `uv run python tests/xxx.py` 或 `python tests/xxx.py` 2. **原因说明** - pytest 提供测试发现、断言重写、fixture 支持等测试专用功能 - 统一使用 pytest 确保测试代码在标准测试框架下执行 - 便于集成测试报告、覆盖率统计等功能 3. **正确示例** ```bash # ✅ 正确:使用 pytest 运行 uv run pytest tests/test_sync.py uv run pytest tests/test_sync.py::TestDataSync uv run pytest tests/ -v # ❌ 错误:直接使用 python 运行 uv run python tests/test_sync.py python tests/test_sync.py ``` ### 因子编写规范 **⚠️ 强制要求:编写因子时,优先使用字符串表达式而非 DSL 表达式。** 1. **推荐方式(字符串表达式)** ```python from src.factors import FactorEngine engine = FactorEngine() engine.add_factor("ma20", "ts_mean(close, 20)") engine.add_factor("alpha", "cs_rank(ts_mean(close, 5) - ts_mean(close, 20))") ``` 2. **不推荐方式(DSL 表达式)** ```python from src.factors.api import close, ts_mean, cs_rank engine.register("ma20", ts_mean(close, 20)) # 不推荐 ``` 3. **原因说明** - 字符串表达式更易于序列化存储到因子元数据(`factors.jsonl`) - 字符串表达式支持从元数据动态加载和复用 - 字符串表达式便于在配置文件中定义和维护 - 与 `src/scripts/register_factors.py` 批量注册脚本兼容 4. **使用场景** - ✅ 在 `register_factors.py` 的 `FACTORS` 列表中定义因子 - ✅ 动态添加因子到 FactorEngine - ✅ 从因子元数据查询并注册因子 ### Emoji 表情禁用规则 **⚠️ 强制要求:代码和测试文件中禁止出现 emoji 表情。** 1. **禁止范围** - 所有 `.py` 源代码文件 - 所有测试文件 (`tests/` 目录) - 配置文件、脚本文件 2. **替代方案** - ❌ 禁止使用:`print("✅ 成功")`、`print("❌ 失败")`、`# 📝 注释` - ✅ 应使用:`print("[成功]")`、`print("[失败]")`、`# 注释` - 使用方括号 `[成功]`、`[警告]`、`[错误]` 等文字标记代替 emoji 3. **唯一例外** - AGENTS.md 文件本身可以使用 emoji 进行文档强调(如本文件中的 ⚠️) - 项目文档、README 等对外展示文件可以酌情使用 4. **检查方法** - 使用正则表达式搜索 emoji:`[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F1E0-\U0001F1FF\u2600-\u26FF\u2700-\u27BF]` - 提交前自查,确保无 emoji 混入代码