docs(AGENTS): 新增AI行为准则规范
- 添加代码存放位置规则,强制代码存放于 src/ 或 tests/ 目录 - 添加 Tests 目录代码运行规则,强制使用 pytest 运行测试代码 - 更新 learn_to_rank 实验代码:调整因子列表和处理器配置 - 修复 schema_cache 表结构缓存逻辑
This commit is contained in:
68
AGENTS.md
68
AGENTS.md
@@ -297,6 +297,24 @@ storage = Storage(read_only=True) # 默认只读
|
|||||||
storage = Storage(read_only=False)
|
storage = Storage(read_only=False)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### 财务数据表与 PIT 策略
|
||||||
|
|
||||||
|
**重要**: 并非所有财务数据表都支持 PIT(Point-In-Time)策略。
|
||||||
|
|
||||||
|
**支持 PIT 的财务表**(有 `f_ann_date` 列):
|
||||||
|
- `financial_income` - 利润表
|
||||||
|
- `financial_balance` - 资产负债表
|
||||||
|
- `financial_cashflow` - 现金流量表
|
||||||
|
|
||||||
|
**不支持 PIT 的财务表**(只有 `ann_date` 列):
|
||||||
|
- `financial_fina_indicator` - 财务指标表
|
||||||
|
|
||||||
|
**因子字段路由规则**:
|
||||||
|
因子引擎在动态路由字段时会自动识别财务表。对于同时存在于多个表的字段(如 `ebit`):
|
||||||
|
- 如果字段存在于 `fina_indicator` 表和其他财务表,会优先路由到支持 PIT 的表
|
||||||
|
- `fina_indicator` 表由于缺少 `f_ann_date`,**被排除在动态字段路由之外**
|
||||||
|
- 这确保了所有财务数据都能正确应用 PIT 策略,避免未来数据泄露
|
||||||
|
|
||||||
### 线程与并发
|
### 线程与并发
|
||||||
- 对 I/O 密集型任务(API 调用)使用 `ThreadPoolExecutor`
|
- 对 I/O 密集型任务(API 调用)使用 `ThreadPoolExecutor`
|
||||||
- 实现停止标志以实现优雅关闭:`threading.Event()`
|
- 实现停止标志以实现优雅关闭:`threading.Event()`
|
||||||
@@ -869,6 +887,56 @@ LSP 报错:Syntax error on line 45
|
|||||||
❌ 错误做法:删除文件重新写、或者忽略错误继续
|
❌ 错误做法:删除文件重新写、或者忽略错误继续
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### 代码存放位置规则
|
||||||
|
|
||||||
|
**⚠️ 强制要求:所有代码必须存放在 `src/` 或 `tests/` 目录下。**
|
||||||
|
|
||||||
|
1. **源代码位置**
|
||||||
|
- 所有正式功能代码必须放在 `src/` 目录下
|
||||||
|
- 按照模块分类存放(`src/data/`、`src/factors/`、`src/training/` 等)
|
||||||
|
|
||||||
|
2. **测试代码位置**
|
||||||
|
- 所有测试代码必须放在 `tests/` 目录下
|
||||||
|
- **临时测试代码**:任何临时性、探索性的测试脚本也必须写在 `tests/` 目录下
|
||||||
|
- 禁止在项目根目录或其他位置创建临时测试文件
|
||||||
|
|
||||||
|
3. **禁止事项**
|
||||||
|
- ❌ 禁止在项目根目录创建 `.py` 文件
|
||||||
|
- ❌ 禁止在 `docs/`、`config/`、`data/` 等目录存放代码文件
|
||||||
|
- ❌ 禁止创建 `test_xxx.py`、`tmp_xxx.py`、`scratch_xxx.py` 等临时文件在项目根目录
|
||||||
|
|
||||||
|
4. **正确示例**
|
||||||
|
```
|
||||||
|
✅ src/data/new_feature.py # 新功能代码
|
||||||
|
✅ tests/test_new_feature.py # 正式测试
|
||||||
|
✅ tests/scratch/experiment.py # 临时实验代码(在 tests 下)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Tests 目录代码运行规则
|
||||||
|
|
||||||
|
**⚠️ 强制要求:`tests/` 目录下的代码必须使用 pytest 指令来运行。**
|
||||||
|
|
||||||
|
1. **运行方式**
|
||||||
|
- ✅ **必须**:使用 `uv run pytest tests/xxx.py` 运行测试文件
|
||||||
|
- ❌ **禁止**:直接使用 `uv run python tests/xxx.py` 或 `python tests/xxx.py`
|
||||||
|
|
||||||
|
2. **原因说明**
|
||||||
|
- pytest 提供测试发现、断言重写、fixture 支持等测试专用功能
|
||||||
|
- 统一使用 pytest 确保测试代码在标准测试框架下执行
|
||||||
|
- 便于集成测试报告、覆盖率统计等功能
|
||||||
|
|
||||||
|
3. **正确示例**
|
||||||
|
```bash
|
||||||
|
# ✅ 正确:使用 pytest 运行
|
||||||
|
uv run pytest tests/test_sync.py
|
||||||
|
uv run pytest tests/test_sync.py::TestDataSync
|
||||||
|
uv run pytest tests/ -v
|
||||||
|
|
||||||
|
# ❌ 错误:直接使用 python 运行
|
||||||
|
uv run python tests/test_sync.py
|
||||||
|
python tests/test_sync.py
|
||||||
|
```
|
||||||
|
|
||||||
### Emoji 表情禁用规则
|
### Emoji 表情禁用规则
|
||||||
|
|
||||||
**⚠️ 强制要求:代码和测试文件中禁止出现 emoji 表情。**
|
**⚠️ 强制要求:代码和测试文件中禁止出现 emoji 表情。**
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -1,17 +1,17 @@
|
|||||||
# %% md
|
#%% md
|
||||||
# # Learn-to-Rank 排序学习训练流程
|
# # Learn-to-Rank 排序学习训练流程
|
||||||
#
|
# #
|
||||||
# 本 Notebook 实现基于 LightGBM LambdaRank 的排序学习训练,用于股票排序任务。
|
# 本 Notebook 实现基于 LightGBM LambdaRank 的排序学习训练,用于股票排序任务。
|
||||||
#
|
# #
|
||||||
# ## 核心特点
|
# ## 核心特点
|
||||||
#
|
# #
|
||||||
# 1. **Label 转换**: 将 `future_return_5` 按每日进行 20 分位数划分(qcut)
|
# 1. **Label 转换**: 将 `future_return_5` 按每日进行 20 分位数划分(qcut)
|
||||||
# 2. **排序学习**: 使用 LambdaRank 目标函数,学习每日股票排序
|
# 2. **排序学习**: 使用 LambdaRank 目标函数,学习每日股票排序
|
||||||
# 3. **NDCG 评估**: 使用 NDCG@1/5/10/20 评估排序质量
|
# 3. **NDCG 评估**: 使用 NDCG@1/5/10/20 评估排序质量
|
||||||
# 4. **策略回测**: 基于排序分数构建 Top-k 选股策略
|
# 4. **策略回测**: 基于排序分数构建 Top-k 选股策略
|
||||||
# %% md
|
#%% md
|
||||||
# ## 1. 导入依赖
|
# ## 1. 导入依赖
|
||||||
# %%
|
#%%
|
||||||
import os
|
import os
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List, Tuple, Optional
|
from typing import List, Tuple, Optional
|
||||||
@@ -37,9 +37,9 @@ from src.training.components.models import LightGBMLambdaRankModel
|
|||||||
from src.training.config import TrainingConfig
|
from src.training.config import TrainingConfig
|
||||||
|
|
||||||
|
|
||||||
# %% md
|
#%% md
|
||||||
# ## 2. 辅助函数
|
# ## 2. 辅助函数
|
||||||
# %%
|
#%%
|
||||||
def register_factors(
|
def register_factors(
|
||||||
engine: FactorEngine,
|
engine: FactorEngine,
|
||||||
selected_factors: List[str],
|
selected_factors: List[str],
|
||||||
@@ -240,11 +240,11 @@ def evaluate_ndcg_at_k(
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
# %% md
|
#%% md
|
||||||
# ## 3. 配置参数
|
# ## 3. 配置参数
|
||||||
#
|
# #
|
||||||
# ### 3.1 因子定义
|
# ### 3.1 因子定义
|
||||||
# %%
|
#%%
|
||||||
# 特征因子定义字典(复用 regression.ipynb 的因子定义)
|
# 特征因子定义字典(复用 regression.ipynb 的因子定义)
|
||||||
LABEL_NAME = "future_return_5_rank"
|
LABEL_NAME = "future_return_5_rank"
|
||||||
|
|
||||||
@@ -302,23 +302,23 @@ SELECTED_FACTORS = [
|
|||||||
"return_5_rank",
|
"return_5_rank",
|
||||||
"EP_rank",
|
"EP_rank",
|
||||||
"pe_expansion_trend",
|
"pe_expansion_trend",
|
||||||
"value_price_divergence",
|
# "value_price_divergence",
|
||||||
"active_market_cap",
|
"active_market_cap",
|
||||||
"ebit_rank",
|
# "ebit_rank",
|
||||||
]
|
]
|
||||||
|
|
||||||
# 因子定义字典(完整因子库)
|
# 因子定义字典(完整因子库)
|
||||||
FACTOR_DEFINITIONS = {
|
FACTOR_DEFINITIONS = {
|
||||||
# "turnover_rate_volatility": "ts_std(turnover_rate, 20)"
|
# "turnover_rate_volatility": "ts_std(log(turnover_rate), 20)"
|
||||||
}
|
}
|
||||||
|
|
||||||
# Label 因子定义(不参与训练,用于计算目标)
|
# Label 因子定义(不参与训练,用于计算目标)
|
||||||
LABEL_FACTOR = {
|
LABEL_FACTOR = {
|
||||||
LABEL_NAME: "(ts_delay(close, -5) / ts_delay(open, -1)) - 1",
|
LABEL_NAME: "(ts_delay(close, -5) / ts_delay(open, -1)) - 1",
|
||||||
}
|
}
|
||||||
# %% md
|
#%% md
|
||||||
# ### 3.2 训练参数配置
|
# ### 3.2 训练参数配置
|
||||||
# %%
|
#%%
|
||||||
# 日期范围配置(正确的 train/val/test 三分法)
|
# 日期范围配置(正确的 train/val/test 三分法)
|
||||||
TRAIN_START = "20200101"
|
TRAIN_START = "20200101"
|
||||||
TRAIN_END = "20231231"
|
TRAIN_END = "20231231"
|
||||||
@@ -387,9 +387,9 @@ PERSIST_MODEL = False
|
|||||||
|
|
||||||
# Top N 配置:每日推荐股票数量
|
# Top N 配置:每日推荐股票数量
|
||||||
TOP_N = 5 # 可调整为 10, 20 等
|
TOP_N = 5 # 可调整为 10, 20 等
|
||||||
# %% md
|
#%% md
|
||||||
# ## 4. 训练流程
|
# ## 4. 训练流程
|
||||||
# %%
|
#%%
|
||||||
print("\n" + "=" * 80)
|
print("\n" + "=" * 80)
|
||||||
print("LightGBM LambdaRank 排序学习训练")
|
print("LightGBM LambdaRank 排序学习训练")
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
@@ -429,7 +429,7 @@ print(f"[配置] 特征数: {len(feature_cols)}")
|
|||||||
print(f"[配置] 目标变量: {target_col}({N_QUANTILES}分位数)")
|
print(f"[配置] 目标变量: {target_col}({N_QUANTILES}分位数)")
|
||||||
|
|
||||||
# 6. 创建排序学习模型
|
# 6. 创建排序学习模型
|
||||||
model: LightGBMLambdaRankModel = LightGBMLambdaRankModel(params=MODEL_PARAMS)
|
model = LightGBMLambdaRankModel(params=MODEL_PARAMS)
|
||||||
|
|
||||||
# 7. 创建数据处理器(使用函数返回的完整特征列表)
|
# 7. 创建数据处理器(使用函数返回的完整特征列表)
|
||||||
processors = [
|
processors = [
|
||||||
@@ -469,9 +469,9 @@ trainer = Trainer(
|
|||||||
feature_cols=feature_cols,
|
feature_cols=feature_cols,
|
||||||
persist_model=PERSIST_MODEL,
|
persist_model=PERSIST_MODEL,
|
||||||
)
|
)
|
||||||
# %% md
|
#%% md
|
||||||
# ### 4.1 股票池筛选
|
# ### 4.1 股票池筛选
|
||||||
# %%
|
#%%
|
||||||
print("\n" + "=" * 80)
|
print("\n" + "=" * 80)
|
||||||
print("股票池筛选")
|
print("股票池筛选")
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
@@ -493,9 +493,9 @@ if pool_manager:
|
|||||||
else:
|
else:
|
||||||
filtered_data = data
|
filtered_data = data
|
||||||
print(" 未配置股票池管理器,跳过筛选")
|
print(" 未配置股票池管理器,跳过筛选")
|
||||||
# %% md
|
#%% md
|
||||||
# ### 4.2 数据划分
|
# ### 4.2 数据划分
|
||||||
# %%
|
#%%
|
||||||
print("\n" + "=" * 80)
|
print("\n" + "=" * 80)
|
||||||
print("数据划分")
|
print("数据划分")
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
@@ -519,15 +519,15 @@ if splitter:
|
|||||||
print(f"测试集日均样本数: {np.mean(test_group):.1f}")
|
print(f"测试集日均样本数: {np.mean(test_group):.1f}")
|
||||||
else:
|
else:
|
||||||
raise ValueError("必须配置数据划分器")
|
raise ValueError("必须配置数据划分器")
|
||||||
# %% md
|
#%% md
|
||||||
# ### 4.3 数据质量检查
|
# ### 4.3 数据质量检查
|
||||||
# %%
|
#%%
|
||||||
print("\n" + "=" * 80)
|
print("\n" + "=" * 80)
|
||||||
print("数据质量检查(必须在预处理之前)")
|
print("数据质量检查(必须在预处理之前)")
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
|
|
||||||
print("\n检查训练集...")
|
print("\n检查训练集...")
|
||||||
check_data_quality(train_data, feature_cols, raise_on_error=True)
|
check_data_quality(train_data, feature_cols, raise_on_error=False)
|
||||||
|
|
||||||
print("\n检查验证集...")
|
print("\n检查验证集...")
|
||||||
check_data_quality(val_data, feature_cols, raise_on_error=True)
|
check_data_quality(val_data, feature_cols, raise_on_error=True)
|
||||||
@@ -537,9 +537,9 @@ check_data_quality(test_data, feature_cols, raise_on_error=True)
|
|||||||
|
|
||||||
print("[成功] 数据质量检查通过,未发现异常")
|
print("[成功] 数据质量检查通过,未发现异常")
|
||||||
|
|
||||||
# %% md
|
#%% md
|
||||||
# ### 4.4 数据预处理
|
# ### 4.4 数据预处理
|
||||||
# %%
|
#%%
|
||||||
print("\n" + "=" * 80)
|
print("\n" + "=" * 80)
|
||||||
print("数据预处理")
|
print("数据预处理")
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
@@ -563,9 +563,9 @@ if processors:
|
|||||||
print(f"\n处理后训练集形状: {train_data.shape}")
|
print(f"\n处理后训练集形状: {train_data.shape}")
|
||||||
print(f"处理后验证集形状: {val_data.shape}")
|
print(f"处理后验证集形状: {val_data.shape}")
|
||||||
print(f"处理后测试集形状: {test_data.shape}")
|
print(f"处理后测试集形状: {test_data.shape}")
|
||||||
# %% md
|
#%% md
|
||||||
# ### 4.4 训练 LambdaRank 模型
|
# ### 4.4 训练 LambdaRank 模型
|
||||||
# %%
|
#%%
|
||||||
print("\n" + "=" * 80)
|
print("\n" + "=" * 80)
|
||||||
print("训练 LambdaRank 模型")
|
print("训练 LambdaRank 模型")
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
@@ -593,9 +593,9 @@ model.fit(
|
|||||||
eval_set=(X_val, y_val, val_group),
|
eval_set=(X_val, y_val, val_group),
|
||||||
)
|
)
|
||||||
print("训练完成!")
|
print("训练完成!")
|
||||||
# %% md
|
#%% md
|
||||||
# ### 4.5 训练指标曲线
|
# ### 4.5 训练指标曲线
|
||||||
# %%
|
#%%
|
||||||
print("\n" + "=" * 80)
|
print("\n" + "=" * 80)
|
||||||
print("训练指标曲线")
|
print("训练指标曲线")
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
@@ -645,9 +645,9 @@ else:
|
|||||||
best_val = max(val_metric_list)
|
best_val = max(val_metric_list)
|
||||||
print(f" {metric}: {best_val:.4f} (迭代 {best_iter_metric + 1})")
|
print(f" {metric}: {best_val:.4f} (迭代 {best_iter_metric + 1})")
|
||||||
print(f"\n[重要提醒] 验证集仅用于早停/调参,测试集完全独立于训练过程!")
|
print(f"\n[重要提醒] 验证集仅用于早停/调参,测试集完全独立于训练过程!")
|
||||||
# %% md
|
#%% md
|
||||||
# ### 4.6 模型评估
|
# ### 4.6 模型评估
|
||||||
# %%
|
#%%
|
||||||
print("\n" + "=" * 80)
|
print("\n" + "=" * 80)
|
||||||
print("模型评估")
|
print("模型评估")
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
@@ -685,7 +685,7 @@ if importance is not None:
|
|||||||
top_features = importance.sort_values(ascending=False).head(20)
|
top_features = importance.sort_values(ascending=False).head(20)
|
||||||
for i, (feature, score) in enumerate(top_features.items(), 1):
|
for i, (feature, score) in enumerate(top_features.items(), 1):
|
||||||
print(f" {i:2d}. {feature:30s} {score:10.2f}")
|
print(f" {i:2d}. {feature:30s} {score:10.2f}")
|
||||||
# %%
|
#%%
|
||||||
# 确保输出目录存在
|
# 确保输出目录存在
|
||||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||||
|
|
||||||
@@ -731,36 +731,37 @@ print(f"\n 预览(前15行):")
|
|||||||
print(topn_to_save.head(15))
|
print(topn_to_save.head(15))
|
||||||
|
|
||||||
print("\n训练流程完成!")
|
print("\n训练流程完成!")
|
||||||
# %% md
|
#%% md
|
||||||
# ## 5. 总结
|
# ## 5. 总结
|
||||||
#
|
# #
|
||||||
# 本 Notebook 实现了完整的 Learn-to-Rank 训练流程:
|
# 本 Notebook 实现了完整的 Learn-to-Rank 训练流程:
|
||||||
#
|
# #
|
||||||
# ### 核心步骤
|
# ### 核心步骤
|
||||||
#
|
# #
|
||||||
# 1. **数据准备**: 计算 49 个特征因子,将 `future_return_5` 转换为 20 分位数标签
|
# 1. **数据准备**: 计算 49 个特征因子,将 `future_return_5` 转换为 20 分位数标签
|
||||||
# 2. **模型训练**: 使用 LightGBM LambdaRank 学习每日股票排序
|
# 2. **模型训练**: 使用 LightGBM LambdaRank 学习每日股票排序
|
||||||
# 3. **模型评估**: 使用 NDCG@1/5/10/20 评估排序质量
|
# 3. **模型评估**: 使用 NDCG@1/5/10/20 评估排序质量
|
||||||
# 4. **策略分析**: 基于排序分数构建 Top-k 选股策略
|
# 4. **策略分析**: 基于排序分数构建 Top-k 选股策略
|
||||||
#
|
# #
|
||||||
# ### 关键参数
|
# ### 关键参数
|
||||||
#
|
# #
|
||||||
# - **Objective**: lambdarank
|
# - **Objective**: lambdarank
|
||||||
# - **Metric**: ndcg
|
# - **Metric**: ndcg
|
||||||
# - **Learning Rate**: 0.05
|
# - **Learning Rate**: 0.05
|
||||||
# - **Num Leaves**: 31
|
# - **Num Leaves**: 31
|
||||||
# - **N Quantiles**: 20
|
# - **N Quantiles**: 20
|
||||||
#
|
# #
|
||||||
# ### 输出结果
|
# ### 输出结果
|
||||||
#
|
# #
|
||||||
# - rank_output.csv: 每日Top-N推荐股票(格式:date, score, ts_code)
|
# - rank_output.csv: 每日Top-N推荐股票(格式:date, score, ts_code)
|
||||||
# - 特征重要性排名
|
# - 特征重要性排名
|
||||||
# - Top-k 策略统计和图表
|
# - Top-k 策略统计和图表
|
||||||
# - NDCG训练指标曲线
|
# - NDCG训练指标曲线
|
||||||
#
|
# #
|
||||||
# ### 后续优化方向
|
# ### 后续优化方向
|
||||||
#
|
# #
|
||||||
# 1. **特征工程**: 尝试更多因子组合
|
# 1. **特征工程**: 尝试更多因子组合
|
||||||
# 2. **超参数调优**: 使用网格搜索优化 LambdaRank 参数
|
# 2. **超参数调优**: 使用网格搜索优化 LambdaRank 参数
|
||||||
# 3. **模型集成**: 结合多个排序模型的预测
|
# 3. **模型集成**: 结合多个排序模型的预测
|
||||||
# 4. **更复杂的分组**: 考虑按行业分组排序
|
# 4. **更复杂的分组**: 考虑按行业分组排序
|
||||||
|
#
|
||||||
@@ -30,6 +30,7 @@
|
|||||||
" Trainer,\n",
|
" Trainer,\n",
|
||||||
" Winsorizer,\n",
|
" Winsorizer,\n",
|
||||||
" NullFiller,\n",
|
" NullFiller,\n",
|
||||||
|
" check_data_quality,\n",
|
||||||
")\n",
|
")\n",
|
||||||
"from src.training.config import TrainingConfig\n",
|
"from src.training.config import TrainingConfig\n",
|
||||||
"\n"
|
"\n"
|
||||||
@@ -46,13 +47,13 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"source": [
|
"source": [
|
||||||
"def create_factors_with_metadata(\n",
|
"def register_factors(\n",
|
||||||
" engine: FactorEngine,\n",
|
" engine: FactorEngine,\n",
|
||||||
" selected_factors: List[str],\n",
|
" selected_factors: List[str],\n",
|
||||||
" factor_definitions: dict,\n",
|
" factor_definitions: dict,\n",
|
||||||
" label_factor: dict,\n",
|
" label_factor: dict,\n",
|
||||||
") -> List[str]:\n",
|
") -> List[str]:\n",
|
||||||
" \"\"\"注册因子(SELECTED_FACTORS 从 metadata 查询,FACTOR_DEFINITIONS 用表达式注册)\"\"\"\n",
|
" \"\"\"注册因子(selected_factors 从 metadata 查询,factor_definitions 用 DSL 表达式注册)\"\"\"\n",
|
||||||
" print(\"=\" * 80)\n",
|
" print(\"=\" * 80)\n",
|
||||||
" print(\"注册因子\")\n",
|
" print(\"注册因子\")\n",
|
||||||
" print(\"=\" * 80)\n",
|
" print(\"=\" * 80)\n",
|
||||||
@@ -327,9 +328,6 @@
|
|||||||
" \"random_state\": 42,\n",
|
" \"random_state\": 42,\n",
|
||||||
"}\n",
|
"}\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# 数据处理器配置(新 API:需要传入 feature_cols)\n",
|
|
||||||
"# 注意:processor 现在需要显式指定要处理的特征列\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"# 股票池筛选函数\n",
|
"# 股票池筛选函数\n",
|
||||||
"# 使用新的 StockPoolManager API:传入自定义筛选函数和所需列/因子\n",
|
"# 使用新的 StockPoolManager API:传入自定义筛选函数和所需列/因子\n",
|
||||||
@@ -409,7 +407,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"# 2. 使用 metadata 定义因子\n",
|
"# 2. 使用 metadata 定义因子\n",
|
||||||
"print(\"\\n[2] 定义因子(从 metadata 注册)\")\n",
|
"print(\"\\n[2] 定义因子(从 metadata 注册)\")\n",
|
||||||
"feature_cols = create_factors_with_metadata(\n",
|
"feature_cols = register_factors(\n",
|
||||||
" engine, SELECTED_FACTORS, FACTOR_DEFINITIONS, LABEL_FACTOR\n",
|
" engine, SELECTED_FACTORS, FACTOR_DEFINITIONS, LABEL_FACTOR\n",
|
||||||
")\n",
|
")\n",
|
||||||
"target_col = LABEL_NAME\n",
|
"target_col = LABEL_NAME\n",
|
||||||
@@ -434,7 +432,7 @@
|
|||||||
"# 5. 创建模型\n",
|
"# 5. 创建模型\n",
|
||||||
"model = LightGBMModel(params=MODEL_PARAMS)\n",
|
"model = LightGBMModel(params=MODEL_PARAMS)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# 6. 创建数据处理器(新 API:需要传入 feature_cols)\n",
|
"# 6. 创建数据处理器(使用函数返回的完整特征列表)\n",
|
||||||
"processors = [\n",
|
"processors = [\n",
|
||||||
" NullFiller(feature_cols=feature_cols, strategy=\"mean\"),\n",
|
" NullFiller(feature_cols=feature_cols, strategy=\"mean\"),\n",
|
||||||
" Winsorizer(feature_cols=feature_cols, lower=0.01, upper=0.99),\n",
|
" Winsorizer(feature_cols=feature_cols, lower=0.01, upper=0.99),\n",
|
||||||
@@ -560,8 +558,32 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"source": [
|
"source": [
|
||||||
"# 步骤 3: 训练集数据处理\n",
|
"# 步骤 3: 数据质量检查(必须在预处理之前)\n",
|
||||||
"print(\"\\n[步骤 3/6] 训练集数据处理\")\n",
|
"print(\"\\n[步骤 3/7] 数据质量检查\")\n",
|
||||||
|
"print(\"-\" * 60)\n",
|
||||||
|
"print(\" [说明] 此检查在 fillna 等处理之前执行,用于发现数据问题\")\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"\\n 检查训练集...\")\n",
|
||||||
|
"check_data_quality(train_data, feature_cols, raise_on_error=True)\n",
|
||||||
|
"\n",
|
||||||
|
"if \"val_data\" in locals() and val_data is not None:\n",
|
||||||
|
" print(\"\\n 检查验证集...\")\n",
|
||||||
|
" check_data_quality(val_data, feature_cols, raise_on_error=True)\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"\\n 检查测试集...\")\n",
|
||||||
|
"check_data_quality(test_data, feature_cols, raise_on_error=True)\n",
|
||||||
|
"\n",
|
||||||
|
"print(\" [成功] 数据质量检查通过,未发现异常\")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"metadata": {},
|
||||||
|
"cell_type": "code",
|
||||||
|
"outputs": [],
|
||||||
|
"execution_count": null,
|
||||||
|
"source": [
|
||||||
|
"# 步骤 4: 训练集数据处理\n",
|
||||||
|
"print(\"\\n[步骤 4/7] 训练集数据处理\")\n",
|
||||||
"print(\"-\" * 60)\n",
|
"print(\"-\" * 60)\n",
|
||||||
"fitted_processors = []\n",
|
"fitted_processors = []\n",
|
||||||
"if processors:\n",
|
"if processors:\n",
|
||||||
@@ -595,7 +617,7 @@
|
|||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"source": [
|
"source": [
|
||||||
"# 步骤 4: 训练模型\n",
|
"# 步骤 4: 训练模型\n",
|
||||||
"print(\"\\n[步骤 4/6] 训练模型\")\n",
|
"print(\"\\n[步骤 5/7] 训练模型\")\n",
|
||||||
"print(\"-\" * 60)\n",
|
"print(\"-\" * 60)\n",
|
||||||
"print(f\" 模型类型: LightGBM\")\n",
|
"print(f\" 模型类型: LightGBM\")\n",
|
||||||
"print(f\" 训练样本数: {len(train_data)}\")\n",
|
"print(f\" 训练样本数: {len(train_data)}\")\n",
|
||||||
@@ -624,7 +646,7 @@
|
|||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"source": [
|
"source": [
|
||||||
"# 步骤 5: 测试集数据处理\n",
|
"# 步骤 5: 测试集数据处理\n",
|
||||||
"print(\"\\n[步骤 5/6] 测试集数据处理\")\n",
|
"print(\"\\n[步骤 6/7] 测试集数据处理\")\n",
|
||||||
"print(\"-\" * 60)\n",
|
"print(\"-\" * 60)\n",
|
||||||
"if processors and test_data is not train_data:\n",
|
"if processors and test_data is not train_data:\n",
|
||||||
" for i, processor in enumerate(fitted_processors, 1):\n",
|
" for i, processor in enumerate(fitted_processors, 1):\n",
|
||||||
@@ -647,7 +669,7 @@
|
|||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"source": [
|
"source": [
|
||||||
"# 步骤 6: 生成预测\n",
|
"# 步骤 6: 生成预测\n",
|
||||||
"print(\"\\n[步骤 6/6] 生成预测\")\n",
|
"print(\"\\n[步骤 7/7] 生成预测\")\n",
|
||||||
"print(\"-\" * 60)\n",
|
"print(\"-\" * 60)\n",
|
||||||
"X_test = test_data.select(feature_cols)\n",
|
"X_test = test_data.select(feature_cols)\n",
|
||||||
"print(f\" 测试样本数: {len(X_test)}\")\n",
|
"print(f\" 测试样本数: {len(X_test)}\")\n",
|
||||||
|
|||||||
@@ -110,6 +110,9 @@ class SchemaCache:
|
|||||||
# 字段到表的映射(一个字段可能在多个表中存在)
|
# 字段到表的映射(一个字段可能在多个表中存在)
|
||||||
field_to_tables: Dict[str, List[str]] = {}
|
field_to_tables: Dict[str, List[str]] = {}
|
||||||
for table, fields in table_fields.items():
|
for table, fields in table_fields.items():
|
||||||
|
# 跳过不支持 PIT 的财务表(如 fina_indicator)
|
||||||
|
if table.lower() in self._NON_PIT_FINANCIAL_TABLES:
|
||||||
|
continue
|
||||||
for field in fields:
|
for field in fields:
|
||||||
if field not in field_to_tables:
|
if field not in field_to_tables:
|
||||||
field_to_tables[field] = []
|
field_to_tables[field] = []
|
||||||
@@ -124,17 +127,30 @@ class SchemaCache:
|
|||||||
sorted_tables = sorted(tables, key=lambda t: priority_order.get(t, 999))
|
sorted_tables = sorted(tables, key=lambda t: priority_order.get(t, 999))
|
||||||
self._field_to_table_map[field] = sorted_tables[0]
|
self._field_to_table_map[field] = sorted_tables[0]
|
||||||
|
|
||||||
|
# 不支持 PIT(Point-In-Time)策略的财务表列表
|
||||||
|
# 这些表缺少 f_ann_date 列,无法使用 asof_backward 模式
|
||||||
|
_NON_PIT_FINANCIAL_TABLES = {"financial_fina_indicator"}
|
||||||
|
|
||||||
def is_financial_table(self, table_name: str) -> bool:
|
def is_financial_table(self, table_name: str) -> bool:
|
||||||
"""判断是否为财务数据表。
|
"""判断是否为财务数据表。
|
||||||
|
|
||||||
|
注意:只有支持 PIT 策略(有 f_ann_date 列)的财务表才会返回 True。
|
||||||
|
fina_indicator 表由于只有 ann_date 而没有 f_ann_date,被排除在外。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
table_name: 表名
|
table_name: 表名
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
是否为财务数据表
|
是否为支持 PIT 的财务数据表
|
||||||
"""
|
"""
|
||||||
financial_prefixes = ("financial_", "income", "balance", "cashflow")
|
financial_prefixes = ("financial_", "income", "balance", "cashflow")
|
||||||
return table_name.lower().startswith(financial_prefixes)
|
is_financial = table_name.lower().startswith(financial_prefixes)
|
||||||
|
|
||||||
|
# 排除不支持 PIT 的表
|
||||||
|
if is_financial and table_name.lower() in self._NON_PIT_FINANCIAL_TABLES:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return is_financial
|
||||||
|
|
||||||
def get_table_fields(self, table_name: str) -> List[str]:
|
def get_table_fields(self, table_name: str) -> List[str]:
|
||||||
"""获取指定表的字段列表。
|
"""获取指定表的字段列表。
|
||||||
|
|||||||
Reference in New Issue
Block a user