docs(AGENTS): 新增AI行为准则规范
- 添加代码存放位置规则,强制代码存放于 src/ 或 tests/ 目录 - 添加 Tests 目录代码运行规则,强制使用 pytest 运行测试代码 - 更新 learn_to_rank 实验代码:调整因子列表和处理器配置 - 修复 schema_cache 表结构缓存逻辑
This commit is contained in:
68
AGENTS.md
68
AGENTS.md
@@ -297,6 +297,24 @@ storage = Storage(read_only=True) # 默认只读
|
||||
storage = Storage(read_only=False)
|
||||
```
|
||||
|
||||
### 财务数据表与 PIT 策略
|
||||
|
||||
**重要**: 并非所有财务数据表都支持 PIT(Point-In-Time)策略。
|
||||
|
||||
**支持 PIT 的财务表**(有 `f_ann_date` 列):
|
||||
- `financial_income` - 利润表
|
||||
- `financial_balance` - 资产负债表
|
||||
- `financial_cashflow` - 现金流量表
|
||||
|
||||
**不支持 PIT 的财务表**(只有 `ann_date` 列):
|
||||
- `financial_fina_indicator` - 财务指标表
|
||||
|
||||
**因子字段路由规则**:
|
||||
因子引擎在动态路由字段时会自动识别财务表。对于同时存在于多个表的字段(如 `ebit`):
|
||||
- 如果字段存在于 `fina_indicator` 表和其他财务表,会优先路由到支持 PIT 的表
|
||||
- `fina_indicator` 表由于缺少 `f_ann_date`,**被排除在动态字段路由之外**
|
||||
- 这确保了所有财务数据都能正确应用 PIT 策略,避免未来数据泄露
|
||||
|
||||
### 线程与并发
|
||||
- 对 I/O 密集型任务(API 调用)使用 `ThreadPoolExecutor`
|
||||
- 实现停止标志以实现优雅关闭:`threading.Event()`
|
||||
@@ -869,6 +887,56 @@ LSP 报错:Syntax error on line 45
|
||||
❌ 错误做法:删除文件重新写、或者忽略错误继续
|
||||
```
|
||||
|
||||
### 代码存放位置规则
|
||||
|
||||
**⚠️ 强制要求:所有代码必须存放在 `src/` 或 `tests/` 目录下。**
|
||||
|
||||
1. **源代码位置**
|
||||
- 所有正式功能代码必须放在 `src/` 目录下
|
||||
- 按照模块分类存放(`src/data/`、`src/factors/`、`src/training/` 等)
|
||||
|
||||
2. **测试代码位置**
|
||||
- 所有测试代码必须放在 `tests/` 目录下
|
||||
- **临时测试代码**:任何临时性、探索性的测试脚本也必须写在 `tests/` 目录下
|
||||
- 禁止在项目根目录或其他位置创建临时测试文件
|
||||
|
||||
3. **禁止事项**
|
||||
- ❌ 禁止在项目根目录创建 `.py` 文件
|
||||
- ❌ 禁止在 `docs/`、`config/`、`data/` 等目录存放代码文件
|
||||
- ❌ 禁止创建 `test_xxx.py`、`tmp_xxx.py`、`scratch_xxx.py` 等临时文件在项目根目录
|
||||
|
||||
4. **正确示例**
|
||||
```
|
||||
✅ src/data/new_feature.py # 新功能代码
|
||||
✅ tests/test_new_feature.py # 正式测试
|
||||
✅ tests/scratch/experiment.py # 临时实验代码(在 tests 下)
|
||||
```
|
||||
|
||||
### Tests 目录代码运行规则
|
||||
|
||||
**⚠️ 强制要求:`tests/` 目录下的代码必须使用 pytest 指令来运行。**
|
||||
|
||||
1. **运行方式**
|
||||
- ✅ **必须**:使用 `uv run pytest tests/xxx.py` 运行测试文件
|
||||
- ❌ **禁止**:直接使用 `uv run python tests/xxx.py` 或 `python tests/xxx.py`
|
||||
|
||||
2. **原因说明**
|
||||
- pytest 提供测试发现、断言重写、fixture 支持等测试专用功能
|
||||
- 统一使用 pytest 确保测试代码在标准测试框架下执行
|
||||
- 便于集成测试报告、覆盖率统计等功能
|
||||
|
||||
3. **正确示例**
|
||||
```bash
|
||||
# ✅ 正确:使用 pytest 运行
|
||||
uv run pytest tests/test_sync.py
|
||||
uv run pytest tests/test_sync.py::TestDataSync
|
||||
uv run pytest tests/ -v
|
||||
|
||||
# ❌ 错误:直接使用 python 运行
|
||||
uv run python tests/test_sync.py
|
||||
python tests/test_sync.py
|
||||
```
|
||||
|
||||
### Emoji 表情禁用规则
|
||||
|
||||
**⚠️ 强制要求:代码和测试文件中禁止出现 emoji 表情。**
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -1,17 +1,17 @@
|
||||
# %% md
|
||||
#%% md
|
||||
# # Learn-to-Rank 排序学习训练流程
|
||||
#
|
||||
# #
|
||||
# 本 Notebook 实现基于 LightGBM LambdaRank 的排序学习训练,用于股票排序任务。
|
||||
#
|
||||
# #
|
||||
# ## 核心特点
|
||||
#
|
||||
# #
|
||||
# 1. **Label 转换**: 将 `future_return_5` 按每日进行 20 分位数划分(qcut)
|
||||
# 2. **排序学习**: 使用 LambdaRank 目标函数,学习每日股票排序
|
||||
# 3. **NDCG 评估**: 使用 NDCG@1/5/10/20 评估排序质量
|
||||
# 4. **策略回测**: 基于排序分数构建 Top-k 选股策略
|
||||
# %% md
|
||||
#%% md
|
||||
# ## 1. 导入依赖
|
||||
# %%
|
||||
#%%
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import List, Tuple, Optional
|
||||
@@ -37,9 +37,9 @@ from src.training.components.models import LightGBMLambdaRankModel
|
||||
from src.training.config import TrainingConfig
|
||||
|
||||
|
||||
# %% md
|
||||
#%% md
|
||||
# ## 2. 辅助函数
|
||||
# %%
|
||||
#%%
|
||||
def register_factors(
|
||||
engine: FactorEngine,
|
||||
selected_factors: List[str],
|
||||
@@ -240,11 +240,11 @@ def evaluate_ndcg_at_k(
|
||||
return results
|
||||
|
||||
|
||||
# %% md
|
||||
#%% md
|
||||
# ## 3. 配置参数
|
||||
#
|
||||
# #
|
||||
# ### 3.1 因子定义
|
||||
# %%
|
||||
#%%
|
||||
# 特征因子定义字典(复用 regression.ipynb 的因子定义)
|
||||
LABEL_NAME = "future_return_5_rank"
|
||||
|
||||
@@ -302,23 +302,23 @@ SELECTED_FACTORS = [
|
||||
"return_5_rank",
|
||||
"EP_rank",
|
||||
"pe_expansion_trend",
|
||||
"value_price_divergence",
|
||||
# "value_price_divergence",
|
||||
"active_market_cap",
|
||||
"ebit_rank",
|
||||
# "ebit_rank",
|
||||
]
|
||||
|
||||
# 因子定义字典(完整因子库)
|
||||
FACTOR_DEFINITIONS = {
|
||||
# "turnover_rate_volatility": "ts_std(turnover_rate, 20)"
|
||||
# "turnover_rate_volatility": "ts_std(log(turnover_rate), 20)"
|
||||
}
|
||||
|
||||
# Label 因子定义(不参与训练,用于计算目标)
|
||||
LABEL_FACTOR = {
|
||||
LABEL_NAME: "(ts_delay(close, -5) / ts_delay(open, -1)) - 1",
|
||||
}
|
||||
# %% md
|
||||
#%% md
|
||||
# ### 3.2 训练参数配置
|
||||
# %%
|
||||
#%%
|
||||
# 日期范围配置(正确的 train/val/test 三分法)
|
||||
TRAIN_START = "20200101"
|
||||
TRAIN_END = "20231231"
|
||||
@@ -387,9 +387,9 @@ PERSIST_MODEL = False
|
||||
|
||||
# Top N 配置:每日推荐股票数量
|
||||
TOP_N = 5 # 可调整为 10, 20 等
|
||||
# %% md
|
||||
#%% md
|
||||
# ## 4. 训练流程
|
||||
# %%
|
||||
#%%
|
||||
print("\n" + "=" * 80)
|
||||
print("LightGBM LambdaRank 排序学习训练")
|
||||
print("=" * 80)
|
||||
@@ -429,7 +429,7 @@ print(f"[配置] 特征数: {len(feature_cols)}")
|
||||
print(f"[配置] 目标变量: {target_col}({N_QUANTILES}分位数)")
|
||||
|
||||
# 6. 创建排序学习模型
|
||||
model: LightGBMLambdaRankModel = LightGBMLambdaRankModel(params=MODEL_PARAMS)
|
||||
model = LightGBMLambdaRankModel(params=MODEL_PARAMS)
|
||||
|
||||
# 7. 创建数据处理器(使用函数返回的完整特征列表)
|
||||
processors = [
|
||||
@@ -469,9 +469,9 @@ trainer = Trainer(
|
||||
feature_cols=feature_cols,
|
||||
persist_model=PERSIST_MODEL,
|
||||
)
|
||||
# %% md
|
||||
#%% md
|
||||
# ### 4.1 股票池筛选
|
||||
# %%
|
||||
#%%
|
||||
print("\n" + "=" * 80)
|
||||
print("股票池筛选")
|
||||
print("=" * 80)
|
||||
@@ -493,9 +493,9 @@ if pool_manager:
|
||||
else:
|
||||
filtered_data = data
|
||||
print(" 未配置股票池管理器,跳过筛选")
|
||||
# %% md
|
||||
#%% md
|
||||
# ### 4.2 数据划分
|
||||
# %%
|
||||
#%%
|
||||
print("\n" + "=" * 80)
|
||||
print("数据划分")
|
||||
print("=" * 80)
|
||||
@@ -519,15 +519,15 @@ if splitter:
|
||||
print(f"测试集日均样本数: {np.mean(test_group):.1f}")
|
||||
else:
|
||||
raise ValueError("必须配置数据划分器")
|
||||
# %% md
|
||||
#%% md
|
||||
# ### 4.3 数据质量检查
|
||||
# %%
|
||||
#%%
|
||||
print("\n" + "=" * 80)
|
||||
print("数据质量检查(必须在预处理之前)")
|
||||
print("=" * 80)
|
||||
|
||||
print("\n检查训练集...")
|
||||
check_data_quality(train_data, feature_cols, raise_on_error=True)
|
||||
check_data_quality(train_data, feature_cols, raise_on_error=False)
|
||||
|
||||
print("\n检查验证集...")
|
||||
check_data_quality(val_data, feature_cols, raise_on_error=True)
|
||||
@@ -537,9 +537,9 @@ check_data_quality(test_data, feature_cols, raise_on_error=True)
|
||||
|
||||
print("[成功] 数据质量检查通过,未发现异常")
|
||||
|
||||
# %% md
|
||||
#%% md
|
||||
# ### 4.4 数据预处理
|
||||
# %%
|
||||
#%%
|
||||
print("\n" + "=" * 80)
|
||||
print("数据预处理")
|
||||
print("=" * 80)
|
||||
@@ -563,9 +563,9 @@ if processors:
|
||||
print(f"\n处理后训练集形状: {train_data.shape}")
|
||||
print(f"处理后验证集形状: {val_data.shape}")
|
||||
print(f"处理后测试集形状: {test_data.shape}")
|
||||
# %% md
|
||||
#%% md
|
||||
# ### 4.4 训练 LambdaRank 模型
|
||||
# %%
|
||||
#%%
|
||||
print("\n" + "=" * 80)
|
||||
print("训练 LambdaRank 模型")
|
||||
print("=" * 80)
|
||||
@@ -593,9 +593,9 @@ model.fit(
|
||||
eval_set=(X_val, y_val, val_group),
|
||||
)
|
||||
print("训练完成!")
|
||||
# %% md
|
||||
#%% md
|
||||
# ### 4.5 训练指标曲线
|
||||
# %%
|
||||
#%%
|
||||
print("\n" + "=" * 80)
|
||||
print("训练指标曲线")
|
||||
print("=" * 80)
|
||||
@@ -645,9 +645,9 @@ else:
|
||||
best_val = max(val_metric_list)
|
||||
print(f" {metric}: {best_val:.4f} (迭代 {best_iter_metric + 1})")
|
||||
print(f"\n[重要提醒] 验证集仅用于早停/调参,测试集完全独立于训练过程!")
|
||||
# %% md
|
||||
#%% md
|
||||
# ### 4.6 模型评估
|
||||
# %%
|
||||
#%%
|
||||
print("\n" + "=" * 80)
|
||||
print("模型评估")
|
||||
print("=" * 80)
|
||||
@@ -685,7 +685,7 @@ if importance is not None:
|
||||
top_features = importance.sort_values(ascending=False).head(20)
|
||||
for i, (feature, score) in enumerate(top_features.items(), 1):
|
||||
print(f" {i:2d}. {feature:30s} {score:10.2f}")
|
||||
# %%
|
||||
#%%
|
||||
# 确保输出目录存在
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
@@ -731,36 +731,37 @@ print(f"\n 预览(前15行):")
|
||||
print(topn_to_save.head(15))
|
||||
|
||||
print("\n训练流程完成!")
|
||||
# %% md
|
||||
#%% md
|
||||
# ## 5. 总结
|
||||
#
|
||||
# #
|
||||
# 本 Notebook 实现了完整的 Learn-to-Rank 训练流程:
|
||||
#
|
||||
# #
|
||||
# ### 核心步骤
|
||||
#
|
||||
# #
|
||||
# 1. **数据准备**: 计算 49 个特征因子,将 `future_return_5` 转换为 20 分位数标签
|
||||
# 2. **模型训练**: 使用 LightGBM LambdaRank 学习每日股票排序
|
||||
# 3. **模型评估**: 使用 NDCG@1/5/10/20 评估排序质量
|
||||
# 4. **策略分析**: 基于排序分数构建 Top-k 选股策略
|
||||
#
|
||||
# #
|
||||
# ### 关键参数
|
||||
#
|
||||
# #
|
||||
# - **Objective**: lambdarank
|
||||
# - **Metric**: ndcg
|
||||
# - **Learning Rate**: 0.05
|
||||
# - **Num Leaves**: 31
|
||||
# - **N Quantiles**: 20
|
||||
#
|
||||
# #
|
||||
# ### 输出结果
|
||||
#
|
||||
# #
|
||||
# - rank_output.csv: 每日Top-N推荐股票(格式:date, score, ts_code)
|
||||
# - 特征重要性排名
|
||||
# - Top-k 策略统计和图表
|
||||
# - NDCG训练指标曲线
|
||||
#
|
||||
# #
|
||||
# ### 后续优化方向
|
||||
#
|
||||
# #
|
||||
# 1. **特征工程**: 尝试更多因子组合
|
||||
# 2. **超参数调优**: 使用网格搜索优化 LambdaRank 参数
|
||||
# 3. **模型集成**: 结合多个排序模型的预测
|
||||
# 4. **更复杂的分组**: 考虑按行业分组排序
|
||||
#
|
||||
@@ -30,6 +30,7 @@
|
||||
" Trainer,\n",
|
||||
" Winsorizer,\n",
|
||||
" NullFiller,\n",
|
||||
" check_data_quality,\n",
|
||||
")\n",
|
||||
"from src.training.config import TrainingConfig\n",
|
||||
"\n"
|
||||
@@ -46,13 +47,13 @@
|
||||
"outputs": [],
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"def create_factors_with_metadata(\n",
|
||||
"def register_factors(\n",
|
||||
" engine: FactorEngine,\n",
|
||||
" selected_factors: List[str],\n",
|
||||
" factor_definitions: dict,\n",
|
||||
" label_factor: dict,\n",
|
||||
") -> List[str]:\n",
|
||||
" \"\"\"注册因子(SELECTED_FACTORS 从 metadata 查询,FACTOR_DEFINITIONS 用表达式注册)\"\"\"\n",
|
||||
" \"\"\"注册因子(selected_factors 从 metadata 查询,factor_definitions 用 DSL 表达式注册)\"\"\"\n",
|
||||
" print(\"=\" * 80)\n",
|
||||
" print(\"注册因子\")\n",
|
||||
" print(\"=\" * 80)\n",
|
||||
@@ -327,9 +328,6 @@
|
||||
" \"random_state\": 42,\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# 数据处理器配置(新 API:需要传入 feature_cols)\n",
|
||||
"# 注意:processor 现在需要显式指定要处理的特征列\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# 股票池筛选函数\n",
|
||||
"# 使用新的 StockPoolManager API:传入自定义筛选函数和所需列/因子\n",
|
||||
@@ -409,7 +407,7 @@
|
||||
"\n",
|
||||
"# 2. 使用 metadata 定义因子\n",
|
||||
"print(\"\\n[2] 定义因子(从 metadata 注册)\")\n",
|
||||
"feature_cols = create_factors_with_metadata(\n",
|
||||
"feature_cols = register_factors(\n",
|
||||
" engine, SELECTED_FACTORS, FACTOR_DEFINITIONS, LABEL_FACTOR\n",
|
||||
")\n",
|
||||
"target_col = LABEL_NAME\n",
|
||||
@@ -434,7 +432,7 @@
|
||||
"# 5. 创建模型\n",
|
||||
"model = LightGBMModel(params=MODEL_PARAMS)\n",
|
||||
"\n",
|
||||
"# 6. 创建数据处理器(新 API:需要传入 feature_cols)\n",
|
||||
"# 6. 创建数据处理器(使用函数返回的完整特征列表)\n",
|
||||
"processors = [\n",
|
||||
" NullFiller(feature_cols=feature_cols, strategy=\"mean\"),\n",
|
||||
" Winsorizer(feature_cols=feature_cols, lower=0.01, upper=0.99),\n",
|
||||
@@ -560,8 +558,32 @@
|
||||
"outputs": [],
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# 步骤 3: 训练集数据处理\n",
|
||||
"print(\"\\n[步骤 3/6] 训练集数据处理\")\n",
|
||||
"# 步骤 3: 数据质量检查(必须在预处理之前)\n",
|
||||
"print(\"\\n[步骤 3/7] 数据质量检查\")\n",
|
||||
"print(\"-\" * 60)\n",
|
||||
"print(\" [说明] 此检查在 fillna 等处理之前执行,用于发现数据问题\")\n",
|
||||
"\n",
|
||||
"print(\"\\n 检查训练集...\")\n",
|
||||
"check_data_quality(train_data, feature_cols, raise_on_error=True)\n",
|
||||
"\n",
|
||||
"if \"val_data\" in locals() and val_data is not None:\n",
|
||||
" print(\"\\n 检查验证集...\")\n",
|
||||
" check_data_quality(val_data, feature_cols, raise_on_error=True)\n",
|
||||
"\n",
|
||||
"print(\"\\n 检查测试集...\")\n",
|
||||
"check_data_quality(test_data, feature_cols, raise_on_error=True)\n",
|
||||
"\n",
|
||||
"print(\" [成功] 数据质量检查通过,未发现异常\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"metadata": {},
|
||||
"cell_type": "code",
|
||||
"outputs": [],
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# 步骤 4: 训练集数据处理\n",
|
||||
"print(\"\\n[步骤 4/7] 训练集数据处理\")\n",
|
||||
"print(\"-\" * 60)\n",
|
||||
"fitted_processors = []\n",
|
||||
"if processors:\n",
|
||||
@@ -595,7 +617,7 @@
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# 步骤 4: 训练模型\n",
|
||||
"print(\"\\n[步骤 4/6] 训练模型\")\n",
|
||||
"print(\"\\n[步骤 5/7] 训练模型\")\n",
|
||||
"print(\"-\" * 60)\n",
|
||||
"print(f\" 模型类型: LightGBM\")\n",
|
||||
"print(f\" 训练样本数: {len(train_data)}\")\n",
|
||||
@@ -624,7 +646,7 @@
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# 步骤 5: 测试集数据处理\n",
|
||||
"print(\"\\n[步骤 5/6] 测试集数据处理\")\n",
|
||||
"print(\"\\n[步骤 6/7] 测试集数据处理\")\n",
|
||||
"print(\"-\" * 60)\n",
|
||||
"if processors and test_data is not train_data:\n",
|
||||
" for i, processor in enumerate(fitted_processors, 1):\n",
|
||||
@@ -647,7 +669,7 @@
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# 步骤 6: 生成预测\n",
|
||||
"print(\"\\n[步骤 6/6] 生成预测\")\n",
|
||||
"print(\"\\n[步骤 7/7] 生成预测\")\n",
|
||||
"print(\"-\" * 60)\n",
|
||||
"X_test = test_data.select(feature_cols)\n",
|
||||
"print(f\" 测试样本数: {len(X_test)}\")\n",
|
||||
|
||||
@@ -110,6 +110,9 @@ class SchemaCache:
|
||||
# 字段到表的映射(一个字段可能在多个表中存在)
|
||||
field_to_tables: Dict[str, List[str]] = {}
|
||||
for table, fields in table_fields.items():
|
||||
# 跳过不支持 PIT 的财务表(如 fina_indicator)
|
||||
if table.lower() in self._NON_PIT_FINANCIAL_TABLES:
|
||||
continue
|
||||
for field in fields:
|
||||
if field not in field_to_tables:
|
||||
field_to_tables[field] = []
|
||||
@@ -124,17 +127,30 @@ class SchemaCache:
|
||||
sorted_tables = sorted(tables, key=lambda t: priority_order.get(t, 999))
|
||||
self._field_to_table_map[field] = sorted_tables[0]
|
||||
|
||||
# 不支持 PIT(Point-In-Time)策略的财务表列表
|
||||
# 这些表缺少 f_ann_date 列,无法使用 asof_backward 模式
|
||||
_NON_PIT_FINANCIAL_TABLES = {"financial_fina_indicator"}
|
||||
|
||||
def is_financial_table(self, table_name: str) -> bool:
|
||||
"""判断是否为财务数据表。
|
||||
|
||||
注意:只有支持 PIT 策略(有 f_ann_date 列)的财务表才会返回 True。
|
||||
fina_indicator 表由于只有 ann_date 而没有 f_ann_date,被排除在外。
|
||||
|
||||
Args:
|
||||
table_name: 表名
|
||||
|
||||
Returns:
|
||||
是否为财务数据表
|
||||
是否为支持 PIT 的财务数据表
|
||||
"""
|
||||
financial_prefixes = ("financial_", "income", "balance", "cashflow")
|
||||
return table_name.lower().startswith(financial_prefixes)
|
||||
is_financial = table_name.lower().startswith(financial_prefixes)
|
||||
|
||||
# 排除不支持 PIT 的表
|
||||
if is_financial and table_name.lower() in self._NON_PIT_FINANCIAL_TABLES:
|
||||
return False
|
||||
|
||||
return is_financial
|
||||
|
||||
def get_table_fields(self, table_name: str) -> List[str]:
|
||||
"""获取指定表的字段列表。
|
||||
|
||||
Reference in New Issue
Block a user