refactor(experiment): 重构模型保存机制,支持 processors 持久化
- 模型保存路径改为 models/{model_type}/ 目录结构
- save_model_with_factors 新增 fitted_processors 参数
- 新增 load_processors 函数加载处理器状态
- Storage 查询排序优化:ORDER BY ts_code, trade_date
This commit is contained in:
441
docs/factor_lookback_consistency_20260319.md
Normal file
441
docs/factor_lookback_consistency_20260319.md
Normal file
@@ -0,0 +1,441 @@
|
|||||||
|
# 因子回看一致性测试问题报告
|
||||||
|
|
||||||
|
**测试日期:** 2026-03-19
|
||||||
|
**测试文件:** `tests/debug/test_lookback_consistency.py::test_simple_factor_consistency`
|
||||||
|
**测试结果:** FAILED - 33个因子不一致
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. 测试概述
|
||||||
|
|
||||||
|
### 1.1 测试目的
|
||||||
|
|
||||||
|
验证不同 LOOKBACK_DAYS(回看窗口)设置下,同一预测日期范围的因子值是否一致。如果结果不一致,可能存在以下问题:
|
||||||
|
|
||||||
|
1. **数据泄露**:因子计算使用了超出合理回看期的历史数据
|
||||||
|
2. **设计缺陷**:某些因子天然依赖更长的历史数据(如累积和因子)
|
||||||
|
3. **边界效应**:滚动窗口在数据边界处的处理差异
|
||||||
|
|
||||||
|
### 1.2 测试配置
|
||||||
|
|
||||||
|
| 参数 | 值 | 说明 |
|
||||||
|
|-----|-----|-----|
|
||||||
|
| LOOKBACK_2Y | 1095天 (3年) | 较短回看窗口 |
|
||||||
|
| LOOKBACK_3Y | 1460天 (4年) | 较长回看窗口 |
|
||||||
|
| PREDICT_START | 20250101 | 预测起始日期 |
|
||||||
|
| PREDICT_END | 20250131 | 预测结束日期 |
|
||||||
|
| 2Y实际数据范围 | 20220102 - 20250131 | 3年数据 |
|
||||||
|
| 3Y实际数据范围 | 20210102 - 20250131 | 4年数据 |
|
||||||
|
|
||||||
|
### 1.3 测试结果摘要
|
||||||
|
|
||||||
|
```
|
||||||
|
数据集形状:
|
||||||
|
2Y 回看: (96761, 241)
|
||||||
|
3Y 回看: (96761, 241)
|
||||||
|
|
||||||
|
总因子数: 191
|
||||||
|
一致因子数: 158
|
||||||
|
不一致因子数: 33 (17.3%)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. 不一致因子详细分类
|
||||||
|
|
||||||
|
### 2.1 分类标准
|
||||||
|
|
||||||
|
| 类别 | 标准 | 风险等级 |
|
||||||
|
|-----|-----|---------|
|
||||||
|
| 浮点精度差异 | max_diff < 1e-6 | 低 - 可忽略 |
|
||||||
|
| 边界效应差异 | 1e-6 <= max_diff < 0.01 | 中 - 需关注 |
|
||||||
|
| 数值显著差异 | 0.01 <= max_diff < 0.1 | 高 - 需修复 |
|
||||||
|
| 严重不一致 | max_diff >= 0.1 或 inf/nan | 极高 - 必须修复 |
|
||||||
|
|
||||||
|
### 2.2 严重不一致因子(必须修复)
|
||||||
|
|
||||||
|
| 因子名称 | 最大差异 | 平均差异 | 差异数据点 | 问题类型 |
|
||||||
|
|---------|---------|---------|-----------|---------|
|
||||||
|
| GTJA_alpha005 | inf | inf | 2170 | -inf值产生 |
|
||||||
|
| GTJA_alpha113 | 0.803 | 0.108 | 95337 | 累积和历史依赖 |
|
||||||
|
| GTJA_alpha115 | 0.989 | 0.0014 | 83182 | ts_rank差异传播 |
|
||||||
|
| GTJA_alpha138 | 0.857 | 0.108 | 21689 | ts_decay_linear+ts_rank |
|
||||||
|
| GTJA_alpha140 | 0.999 | 0.029 | 7535 | min_/max_边界 |
|
||||||
|
| GTJA_alpha146 | 3.719 | 0.0058 | 81526 | 复杂嵌套公式 |
|
||||||
|
| GTJA_alpha148 | 1.000 | 1e-5 | 1 | cs_rank+ts_min边界 |
|
||||||
|
| GTJA_alpha176 | inf | inf | 74 | 除零/无穷大 |
|
||||||
|
|
||||||
|
#### 2.2.1 GTJA_alpha005(-inf值问题)
|
||||||
|
|
||||||
|
```
|
||||||
|
差异数据点示例:
|
||||||
|
idx=56: 2Y=-0.0000000262, 3Y=-inf, diff=inf
|
||||||
|
idx=57: 2Y=-0.0000000262, 3Y=-inf, diff=inf
|
||||||
|
idx=58: 2Y=-0.4564354646, 3Y=-inf, diff=inf
|
||||||
|
```
|
||||||
|
|
||||||
|
**问题分析:**
|
||||||
|
- 3Y模式下产生-inf值,说明存在除零或对负数取对数
|
||||||
|
- DSL中可能包含 `log` 或 `sqrt` 操作
|
||||||
|
|
||||||
|
#### 2.2.2 GTJA_alpha113(累积和历史依赖)
|
||||||
|
|
||||||
|
```
|
||||||
|
差异数据点示例:
|
||||||
|
idx=0: 2Y=0.1848, 3Y=0.9883, diff=0.8035
|
||||||
|
idx=1: 2Y=-0.1146, 3Y=-0.9890, diff=0.8744
|
||||||
|
idx=2: 2Y=0.1098, 3Y=0.9783, diff=0.8684
|
||||||
|
```
|
||||||
|
|
||||||
|
**问题分析:**
|
||||||
|
- 包含 `ts_delay(close, 5)` 导致数据偏移
|
||||||
|
- `ts_corr(close, vol, 2)` 只有2日窗口,对边界敏感
|
||||||
|
|
||||||
|
#### 2.2.3 GTJA_alpha138(复杂嵌套)
|
||||||
|
|
||||||
|
```
|
||||||
|
差异数据点示例:
|
||||||
|
idx=3: 2Y=0.9808, 3Y=0.8380, diff=0.1429
|
||||||
|
idx=4: 2Y=0.9782, 3Y=0.8354, diff=0.1429
|
||||||
|
idx=5: 2Y=0.9738, 3Y=0.6880, diff=0.2857
|
||||||
|
idx=6: 2Y=0.9780, 3Y=0.4066, diff=0.5714
|
||||||
|
idx=7: 2Y=0.5586, 3Y=0.1300, diff=0.4286
|
||||||
|
```
|
||||||
|
|
||||||
|
**问题分析:**
|
||||||
|
- 5层嵌套:`cs_rank → ts_decay_linear → ts_delta → ts_rank → ts_corr → ts_rank`
|
||||||
|
- `ts_decay_linear` 使用 `numpy.convolve`,结果依赖输入序列长度
|
||||||
|
- `ts_rank` 使用 `sliding_window_view`,对数据起始点敏感
|
||||||
|
|
||||||
|
#### 2.2.4 GTJA_alpha176(无穷大值)
|
||||||
|
|
||||||
|
```
|
||||||
|
差异数据点示例:
|
||||||
|
idx=948: 2Y=-0.9146, 3Y=-0.9146, diff=0.0000000001
|
||||||
|
idx=949: 2Y=-0.8809, 3Y=-0.8809, diff=0.0000000002
|
||||||
|
(差异本身很小,但3Y模式下产生了inf值)
|
||||||
|
```
|
||||||
|
|
||||||
|
**问题分析:**
|
||||||
|
- 存在除零操作导致inf值
|
||||||
|
- 需要添加epsilon保护
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 2.3 数值显著差异因子(需修复)
|
||||||
|
|
||||||
|
| 因子名称 | 最大差异 | 平均差异 | 差异数据点 | 可能原因 |
|
||||||
|
|---------|---------|---------|-----------|---------|
|
||||||
|
| GTJA_alpha016 | 0.021 | 2.3e-6 | 336 | cs_rank嵌套 |
|
||||||
|
| GTJA_alpha032 | 0.605 | 3.9e-5 | 8341 | ts_sum嵌套cs_rank |
|
||||||
|
| GTJA_alpha077 | 0.580 | 0.00013 | 36577 | cs_rank+ts_decay_linear |
|
||||||
|
| GTJA_alpha091 | 0.204 | 4.2e-6 | 2147 | cs_rank嵌套max_ |
|
||||||
|
| GTJA_alpha121 | 0.296 | 0.00457 | 2412 | ts_rank嵌套ts_corr |
|
||||||
|
| GTJA_alpha130 | 0.613 | 2.9e-5 | 358 | ts_rank+ts_decay_linear |
|
||||||
|
| GTJA_alpha141 | 0.629 | 0.00011 | 28256 | cs_rank+ts_corr |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 2.4 边界效应差异因子(需关注)
|
||||||
|
|
||||||
|
| 因子名称 | 最大差异 | 平均差异 | 差异数据点 | 问题类型 |
|
||||||
|
|---------|---------|---------|-----------|---------|
|
||||||
|
| GTJA_alpha042 | 0.00027 | 1.8e-7 | 214 | ts_std嵌套 |
|
||||||
|
| GTJA_alpha062 | 1.0e-6 | 0.0 | 579 | ts_corr嵌套 |
|
||||||
|
| GTJA_alpha064 | 0.501 | 0.00022 | 74838 | ts_decay_linear嵌套 |
|
||||||
|
| GTJA_alpha070 | 2.3e-6 | 8.1e-9 | 58440 | ts_std(amount) |
|
||||||
|
| GTJA_alpha074 | 0.00837 | 8.5e-7 | 241 | cs_rank+ts_corr |
|
||||||
|
| GTJA_alpha104 | 0.00028 | 2.7e-8 | 381 | ts_delta+ts_std |
|
||||||
|
| GTJA_alpha119 | 0.160 | 1.5e-5 | 4051 | cs_rank+ts_decay_linear |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 2.5 浮点精度差异因子(可忽略)
|
||||||
|
|
||||||
|
| 因子名称 | 最大差异 | 平均差异 | 差异数据点 |
|
||||||
|
|---------|---------|---------|-----------|
|
||||||
|
| volatility_5 | 4.3e-9 | 0.0 | 906 |
|
||||||
|
| volatility_ratio | 6.0e-10 | 0.0 | 96 |
|
||||||
|
| volatility_squeeze_5_60 | 2.0e-10 | 0.0 | 16 |
|
||||||
|
| turnover_deviation | 2.0e-9 | 0.0 | 37 |
|
||||||
|
| GTJA_alpha083 | 9.3e-5 | 1.9e-9 | 2 |
|
||||||
|
| GTJA_alpha139 | 2.0e-10 | 0.0 | 12 |
|
||||||
|
| GTJA_alpha179 | 1.9e-4 | 5.9e-9 | 4 |
|
||||||
|
| GTJA_alpha191 | 5.6e-9 | 0.0 | 1010 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 2.6 NaN模式不一致因子(需关注)
|
||||||
|
|
||||||
|
| 因子名称 | 2Y NaN数 | 3Y NaN数 | 差异 |
|
||||||
|
|---------|---------|---------|------|
|
||||||
|
| GTJA_alpha005 | 704 | 600 | -104 |
|
||||||
|
| GTJA_alpha028 | 87678 | 89155 | +1477 |
|
||||||
|
| GTJA_alpha111 | 29410 | 35516 | +6106 |
|
||||||
|
| GTJA_alpha113 | 294 | 282 | -12 |
|
||||||
|
| GTJA_alpha164 | 29410 | 35516 | +6106 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. 根因分析
|
||||||
|
|
||||||
|
### 3.1 问题分类
|
||||||
|
|
||||||
|
#### 3.1.1 设计缺陷型(无法修复,只能排除)
|
||||||
|
|
||||||
|
以下因子天然依赖从数据起始点开始的累积计算,**不应该**在不同回看期下产生相同结果:
|
||||||
|
|
||||||
|
| 因子 | DSL公式 | 问题 |
|
||||||
|
|-----|--------|-----|
|
||||||
|
| GTJA_alpha165 | `ts_sumac(close-ts_mean(close,48))` | 累积和从数据起点开始 |
|
||||||
|
| GTJA_alpha183 | `ts_sumac(close-ts_mean(close,24))` | 累积和从数据起点开始 |
|
||||||
|
|
||||||
|
> 注:本次测试中这两个因子未出现,但之前测试中差异巨大(14万级别)
|
||||||
|
|
||||||
|
#### 3.1.2 数值稳定性问题(需要修复)
|
||||||
|
|
||||||
|
| 问题类型 | 相关因子 | 修复方案 |
|
||||||
|
|---------|---------|---------|
|
||||||
|
| 除零导致inf | GTJA_alpha005, GTJA_alpha176 | 添加epsilon保护 |
|
||||||
|
| 负数取对数 | GTJA_alpha005 | 检查log输入是否为正 |
|
||||||
|
| 负数取平方根 | 待查 | 检查sqrt输入 |
|
||||||
|
|
||||||
|
#### 3.1.3 滚动窗口边界效应(需要优化)
|
||||||
|
|
||||||
|
以下模式会导致不同回看期下边界数据点不同:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# ts_rank 实现
|
||||||
|
def rank_calc(s: pl.Series) -> pl.Series:
|
||||||
|
values = s.to_numpy()
|
||||||
|
n = len(values)
|
||||||
|
windows = np.lib.stride_tricks.sliding_window_view(values, window)
|
||||||
|
# 从第 window 个元素开始产生有效值
|
||||||
|
# 不同起始点导致滑动窗口内容不同
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
# ts_decay_linear 实现
|
||||||
|
# 使用 numpy.convolve,输出依赖输入序列长度
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 3.1.4 cs_rank截面排名差异(需要优化)
|
||||||
|
|
||||||
|
`cs_rank` 对截面数据进行排名,不同回看期下:
|
||||||
|
- 有效数据点数量不同
|
||||||
|
- 排名时的百分位数分母不同
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. 修复建议
|
||||||
|
|
||||||
|
### 4.1 立即修复(高优先级)
|
||||||
|
|
||||||
|
#### 4.1.1 排除问题因子
|
||||||
|
|
||||||
|
将以下因子加入排除列表:
|
||||||
|
|
||||||
|
```python
|
||||||
|
EXCLUDED_FACTORS = [
|
||||||
|
# 设计缺陷 - 累积和因子
|
||||||
|
"GTJA_alpha165", # ts_sumac 累积和历史依赖
|
||||||
|
"GTJA_alpha183", # ts_sumac 累积和历史依赖
|
||||||
|
|
||||||
|
# 数值稳定性问题
|
||||||
|
"GTJA_alpha005", # 产生-inf值
|
||||||
|
"GTJA_alpha176", # 产生inf值
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 4.1.2 修复数值稳定性
|
||||||
|
|
||||||
|
**修复位置:** `src/factors/translator.py`
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 为除法操作添加 epsilon 保护
|
||||||
|
def _safe_divide(numerator, denominator, epsilon=1e-10):
|
||||||
|
return numerator / (denominator + epsilon)
|
||||||
|
|
||||||
|
# 为 log 函数添加输入检查
|
||||||
|
def _safe_log(x, epsilon=1e-10):
|
||||||
|
return log(x + epsilon) # 确保输入为正
|
||||||
|
|
||||||
|
# 为 sqrt 函数添加输入检查
|
||||||
|
def _safe_sqrt(x, epsilon=1e-10):
|
||||||
|
return sqrt(abs(x)) # 确保输入非负
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.2 中期优化(中优先级)
|
||||||
|
|
||||||
|
#### 4.2.1 优化 ts_rank 边界处理
|
||||||
|
|
||||||
|
**当前问题:**
|
||||||
|
- `sliding_window_view` 从第 `window` 个元素开始产生有效值
|
||||||
|
- 不同起始点导致初始NaN数量不同
|
||||||
|
|
||||||
|
**优化方案:**
|
||||||
|
```python
|
||||||
|
# 使用动态起始点对齐
|
||||||
|
def ts_rank_aligned(expr, window):
|
||||||
|
# 确保不同回看期产生相同起始点
|
||||||
|
pass
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 4.2.2 优化 cs_rank 排名基准
|
||||||
|
|
||||||
|
**当前问题:**
|
||||||
|
- 排名时使用滑动窗口内的元素
|
||||||
|
- 不同回看期下窗口内元素不同
|
||||||
|
|
||||||
|
**优化方案:**
|
||||||
|
```python
|
||||||
|
def cs_rank_aligned(expr):
|
||||||
|
# 使用固定的排名基准
|
||||||
|
# 例如:固定使用当日全部股票进行排名
|
||||||
|
pass
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.3 长期改进(低优先级)
|
||||||
|
|
||||||
|
#### 4.3.1 建立因子分类体系
|
||||||
|
|
||||||
|
```python
|
||||||
|
class FactorCategory(Enum):
|
||||||
|
"""因子分类"""
|
||||||
|
TIME_SERIES_ROLLING = "ts_rolling" # 滚动窗口型(对回看期不敏感)
|
||||||
|
TIME_SERIES_CUMULATIVE = "ts_cumulative" # 累积型(对回看期敏感)
|
||||||
|
CROSS_SECTIONAL = "cs_rank" # 截面型(对回看期敏感)
|
||||||
|
HYBRID = "hybrid" # 混合型
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 4.3.2 增强一致性测试
|
||||||
|
|
||||||
|
```python
|
||||||
|
def test_factor_consistency_threshold():
|
||||||
|
"""测试因子一致性阈值"""
|
||||||
|
THRESHOLDS = {
|
||||||
|
"float_precision": 1e-6, # 浮点精度差异可接受
|
||||||
|
"boundary_effect": 0.01, # 边界效应差异需关注
|
||||||
|
"significant": 0.1, # 显著差异需修复
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. 影响评估
|
||||||
|
|
||||||
|
### 5.1 对模型训练的影响
|
||||||
|
|
||||||
|
| 风险等级 | 因子数量 | 影响 |
|
||||||
|
|---------|---------|-----|
|
||||||
|
| 低 | 8 | 浮点精度差异,不影响训练 |
|
||||||
|
| 中 | 16 | 边界效应,可能轻微影响 |
|
||||||
|
| 高 | 8 | 显著不一致,会影响模型 |
|
||||||
|
| 极高 | 1 | inf/nan值,必须修复 |
|
||||||
|
|
||||||
|
### 5.2 对回测的影响
|
||||||
|
|
||||||
|
**关键问题:**
|
||||||
|
- 训练时使用3年回看,预测时用4年回看 → 因子值不一致
|
||||||
|
- 可能导致回测结果与实盘表现不符
|
||||||
|
|
||||||
|
**建议:**
|
||||||
|
- 训练和预测使用相同的LOOKBACK_DAYS配置
|
||||||
|
- 或在模型中记录回看期设置
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. 测试代码分析
|
||||||
|
|
||||||
|
### 6.1 测试逻辑
|
||||||
|
|
||||||
|
```python
|
||||||
|
def compute_simple_factors(lookback_days: int) -> pl.DataFrame:
|
||||||
|
actual_start = get_lookback_start_date(PREDICT_START, lookback_days)
|
||||||
|
# 从 actual_start 开始加载数据
|
||||||
|
# 计算因子
|
||||||
|
# 过滤到 PREDICT_START 之后的日期
|
||||||
|
return data
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6.2 比较逻辑
|
||||||
|
|
||||||
|
```python
|
||||||
|
def compare_factor_values(data_2y, data_3y, feature_cols):
|
||||||
|
# 使用 np.allclose(valid_2y, valid_3y, rtol=1e-10, atol=1e-10)
|
||||||
|
# rtol=1e-10 相对容差
|
||||||
|
# atol=1e-10 绝对容差
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6.3 断言
|
||||||
|
|
||||||
|
```python
|
||||||
|
assert results["inconsistent_factors"] == 0, (
|
||||||
|
f"发现 {results['inconsistent_factors']} 个简单因子在不同 LOOKBACK_DAYS 下结果不一致"
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 7. 结论
|
||||||
|
|
||||||
|
### 7.1 关键发现
|
||||||
|
|
||||||
|
1. **33个因子(17.3%)在不同回看期下结果不一致**
|
||||||
|
2. **其中8个因子存在严重差异(max_diff >= 0.1 或 inf/nan)**
|
||||||
|
3. **问题根源包括:设计缺陷、数值稳定性、边界效应**
|
||||||
|
|
||||||
|
### 7.2 修复优先级
|
||||||
|
|
||||||
|
| 优先级 | 因子 | 行动 |
|
||||||
|
|-------|-----|-----|
|
||||||
|
| P0 | GTJA_alpha005, GTJA_alpha176 | 排除(产生inf值) |
|
||||||
|
| P1 | GTJA_alpha113, GTJA_alpha138, GTJA_alpha140, GTJA_alpha146 | 修复数值稳定性 |
|
||||||
|
| P2 | 其他16个显著差异因子 | 优化滚动窗口处理 |
|
||||||
|
| P3 | 8个浮点精度因子 | 可忽略 |
|
||||||
|
|
||||||
|
### 7.3 后续步骤
|
||||||
|
|
||||||
|
1. **立即**:更新 `SELECTED_FACTORS` 排除问题因子
|
||||||
|
2. **本周**:修复 `translator.py` 中的数值稳定性问题
|
||||||
|
3. **本月**:建立因子分类和一致性测试体系
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 附录:完整不一致因子列表
|
||||||
|
|
||||||
|
| 序号 | 因子名称 | 最大差异 | 平均差异 | 差异数据点 | 风险等级 |
|
||||||
|
|-----|---------|---------|---------|-----------|---------|
|
||||||
|
| 1 | volatility_5 | 4.3e-09 | 0.0 | 906 | 低 |
|
||||||
|
| 2 | volatility_ratio | 6.0e-10 | 0.0 | 96 | 低 |
|
||||||
|
| 3 | volatility_squeeze_5_60 | 2.0e-10 | 0.0 | 16 | 低 |
|
||||||
|
| 4 | turnover_deviation | 2.0e-09 | 0.0 | 37 | 低 |
|
||||||
|
| 5 | GTJA_alpha005 | inf | inf | 2170 | 极高 |
|
||||||
|
| 6 | GTJA_alpha016 | 0.021 | 2.3e-06 | 336 | 高 |
|
||||||
|
| 7 | GTJA_alpha032 | 0.605 | 3.9e-05 | 8341 | 高 |
|
||||||
|
| 8 | GTJA_alpha042 | 0.00027 | 1.8e-07 | 214 | 中 |
|
||||||
|
| 9 | GTJA_alpha062 | 1.0e-06 | 0.0 | 579 | 中 |
|
||||||
|
| 10 | GTJA_alpha064 | 0.501 | 0.00022 | 74838 | 高 |
|
||||||
|
| 11 | GTJA_alpha070 | 2.3e-06 | 8.1e-09 | 58440 | 中 |
|
||||||
|
| 12 | GTJA_alpha074 | 0.00837 | 8.5e-07 | 241 | 中 |
|
||||||
|
| 13 | GTJA_alpha077 | 0.580 | 0.00013 | 36577 | 高 |
|
||||||
|
| 14 | GTJA_alpha083 | 9.3e-05 | 1.9e-09 | 2 | 低 |
|
||||||
|
| 15 | GTJA_alpha090 | 0.021 | 1.5e-06 | 416 | 高 |
|
||||||
|
| 16 | GTJA_alpha091 | 0.204 | 4.2e-06 | 2147 | 高 |
|
||||||
|
| 17 | GTJA_alpha104 | 0.00028 | 2.7e-08 | 381 | 中 |
|
||||||
|
| 18 | GTJA_alpha105 | nan | nan | 591 | 极高 |
|
||||||
|
| 19 | GTJA_alpha113 | 0.803 | 0.108 | 95337 | 极高 |
|
||||||
|
| 20 | GTJA_alpha114 | nan | nan | 11 | 极高 |
|
||||||
|
| 21 | GTJA_alpha115 | 0.989 | 0.0014 | 83182 | 极高 |
|
||||||
|
| 22 | GTJA_alpha119 | 0.160 | 1.5e-05 | 4051 | 高 |
|
||||||
|
| 23 | GTJA_alpha121 | 0.296 | 0.00457 | 2412 | 高 |
|
||||||
|
| 24 | GTJA_alpha130 | 0.613 | 2.9e-05 | 358 | 高 |
|
||||||
|
| 25 | GTJA_alpha138 | 0.857 | 0.108 | 21689 | 极高 |
|
||||||
|
| 26 | GTJA_alpha139 | 2.0e-10 | 0.0 | 12 | 低 |
|
||||||
|
| 27 | GTJA_alpha140 | 0.999 | 0.029 | 7535 | 极高 |
|
||||||
|
| 28 | GTJA_alpha141 | 0.629 | 0.00011 | 28256 | 高 |
|
||||||
|
| 29 | GTJA_alpha146 | 3.719 | 0.0058 | 81526 | 极高 |
|
||||||
|
| 30 | GTJA_alpha148 | 1.000 | 1.0e-05 | 1 | 极高 |
|
||||||
|
| 31 | GTJA_alpha176 | inf | inf | 74 | 极高 |
|
||||||
|
| 32 | GTJA_alpha179 | 0.00019 | 5.9e-09 | 4 | 低 |
|
||||||
|
| 33 | GTJA_alpha191 | 5.6e-09 | 0.0 | 1010 | 低 |
|
||||||
310
docs/factor_lookback_consistency_analysis.md
Normal file
310
docs/factor_lookback_consistency_analysis.md
Normal file
@@ -0,0 +1,310 @@
|
|||||||
|
# 因子计算一致性问题分析报告
|
||||||
|
|
||||||
|
## 概述
|
||||||
|
|
||||||
|
本报告分析 LOOKBACK_DAYS 设置对因子计算结果的影响。测试对比了两种回看窗口设置(3年 vs 4年)下,同一预测日期范围(2025年1月)的因子计算结果一致性。
|
||||||
|
|
||||||
|
**测试配置:**
|
||||||
|
- LOOKBACK_DAYS: 1095天(3年) vs 1460天(4年)
|
||||||
|
- 预测日期范围: 2025年1月(20250101 - 20250131)
|
||||||
|
- 数据形状: (96761, 243)
|
||||||
|
- 测试因子数: 191个
|
||||||
|
|
||||||
|
## 测试结果分类
|
||||||
|
|
||||||
|
### 第一类:微小数值差异(浮点精度/边界效应)
|
||||||
|
|
||||||
|
**特征:** 最大差异在 1e-10 到 0.6 之间,平均差异接近 0
|
||||||
|
|
||||||
|
| 因子名称 | 最大差异 | 平均差异 | 差异数据点 | 分析 |
|
||||||
|
|---------|---------|---------|-----------|------|
|
||||||
|
| volatility_5 | 4.3e-09 | 0.0 | 906 | 5日标准差,浮点精度问题 |
|
||||||
|
| volatility_ratio | 6.0e-10 | 0.0 | 96 | 波动率比率,累积误差 |
|
||||||
|
| volatility_squeeze_5_60 | 2.0e-10 | 0.0 | 16 | 挤压比率 |
|
||||||
|
| turnover_deviation | 2.0e-09 | 0.0 | 37 | 换手率偏离度 |
|
||||||
|
| GTJA_alpha016 | 0.021 | 2.3e-06 | 336 | cs_rank 嵌套 |
|
||||||
|
| GTJA_alpha032 | 0.605 | 3.9e-05 | 8341 | ts_sum 嵌套 cs_rank |
|
||||||
|
| GTJA_alpha042 | 0.00027 | 1.8e-07 | 214 | ts_std 嵌套 |
|
||||||
|
| GTJA_alpha062 | 1.0e-06 | 0.0 | 579 | ts_corr 嵌套 |
|
||||||
|
| GTJA_alpha064 | 0.501 | 0.00022 | 74838 | ts_decay_linear 嵌套 |
|
||||||
|
| GTJA_alpha070 | 2.3e-06 | 8.1e-09 | 58440 | ts_std(amount) |
|
||||||
|
| GTJA_alpha074 | 0.00837 | 8.5e-07 | 241 | cs_rank + ts_corr |
|
||||||
|
| GTJA_alpha077 | 0.580 | 0.00013 | 36577 | cs_rank + ts_decay_linear |
|
||||||
|
| GTJA_alpha083 | 9.3e-05 | 1.9e-09 | 2 | cs_rank + ts_cov |
|
||||||
|
| GTJA_alpha090 | 0.021 | 1.5e-06 | 416 | 类似 alpha016 |
|
||||||
|
| GTJA_alpha091 | 0.204 | 4.2e-06 | 2147 | cs_rank 嵌套 max_ |
|
||||||
|
| GTJA_alpha104 | 0.00028 | 2.7e-08 | 381 | ts_delta + ts_std |
|
||||||
|
| GTJA_alpha105 | NaN | NaN | 591 | ts_corr 导致 NaN |
|
||||||
|
| GTJA_alpha119 | 0.160 | 1.5e-05 | 4051 | cs_rank + ts_decay_linear |
|
||||||
|
| GTJA_alpha121 | 0.296 | 0.00457 | 2412 | ts_rank 嵌套 ts_corr |
|
||||||
|
| GTJA_alpha130 | 0.613 | 2.9e-05 | 358 | ts_rank + ts_decay_linear |
|
||||||
|
| GTJA_alpha139 | 2.0e-10 | 0.0 | 12 | ts_corr 导致 |
|
||||||
|
| GTJA_alpha141 | 0.629 | 0.00011 | 28256 | cs_rank + ts_corr |
|
||||||
|
| GTJA_alpha148 | 1.0 | 1.0e-05 | 1 | cs_rank + ts_min 边界 |
|
||||||
|
| GTJA_alpha179 | 0.00019 | 5.9e-09 | 4 | cs_rank + ts_corr |
|
||||||
|
| GTJA_alpha191 | 5.6e-09 | 0.0 | 1010 | ts_corr + ts_mean |
|
||||||
|
|
||||||
|
**诊断:** 这类差异主要是由以下原因导致的数值精度问题:
|
||||||
|
1. **滚动窗口边界效应**:不同起始点的数据导致滚动窗口的初始值略有差异
|
||||||
|
2. **累积误差传播**:多层嵌套计算(如 cs_rank(ts_decay_linear(...)))放大了微小差异
|
||||||
|
3. **浮点运算顺序**:Polars 的并行计算可能导致运算顺序不同
|
||||||
|
|
||||||
|
### 第二类:NaN 模式不一致(数据边界问题)
|
||||||
|
|
||||||
|
| 因子名称 | 2Y NaN数 | 3Y NaN数 | 差异 | 分析 |
|
||||||
|
|---------|---------|---------|------|------|
|
||||||
|
| GTJA_alpha005 | 704 | 600 | -104 | ts_max(ts_corr(ts_rank(...))) |
|
||||||
|
| GTJA_alpha028 | 87678 | 89155 | +1477 | 多层 ts_sma 嵌套 |
|
||||||
|
| GTJA_alpha111 | 29410 | 35516 | +6106 | ts_sma 条件计算 |
|
||||||
|
| GTJA_alpha113 | 294 | 282 | -12 | ts_sum(ts_delay(...)) |
|
||||||
|
| GTJA_alpha164 | 29410 | 35516 | +6106 | 类似 alpha111 |
|
||||||
|
|
||||||
|
**诊断:** NaN 数量不一致表明:
|
||||||
|
1. **历史数据不足**:某些因子(如 ts_corr(window=2))在数据起始阶段会产生 NaN
|
||||||
|
2. **条件计算差异**:包含 `if_` 语句的因子在不同数据量下条件分支执行不同
|
||||||
|
3. **ts_delay 负偏移**:alpha113 包含 `ts_delay(close, 5)`,数据边界处的延迟计算行为不同
|
||||||
|
|
||||||
|
### 第三类:严重数值不一致(高风险)
|
||||||
|
|
||||||
|
| 因子名称 | 最大差异 | 平均差异 | 差异数据点数 | 可能原因 |
|
||||||
|
|---------|---------|---------|-------------|---------|
|
||||||
|
| GTJA_alpha005 | Inf | NaN | 2170 | -inf 值导致 |
|
||||||
|
| GTJA_alpha113 | 0.803 | 0.108 | 21689 | 累积和历史依赖 |
|
||||||
|
| GTJA_alpha114 | 0.028 | - | 11 | 除法边界问题 |
|
||||||
|
| GTJA_alpha115 | 0.989 | 0.0014 | 83182 | ts_rank 差异传播 |
|
||||||
|
| GTJA_alpha138 | 0.857 | 0.108 | 21689 | ts_decay_linear + ts_rank |
|
||||||
|
| GTJA_alpha140 | 0.999 | 0.029 | 7535 | min_/max_ 边界 |
|
||||||
|
| GTJA_alpha146 | 3.719 | 0.0058 | 81526 | 复杂嵌套公式 |
|
||||||
|
| GTJA_alpha165 | 146950 | 498.7 | 81531 | **ts_sumac 累积和历史依赖** |
|
||||||
|
| GTJA_alpha176 | Inf | Inf | 74 | 除零或无穷大 |
|
||||||
|
| GTJA_alpha183 | 98612 | 232.2 | 81531 | **ts_sumac 累积和历史依赖** |
|
||||||
|
|
||||||
|
**高风险因子详细分析:**
|
||||||
|
|
||||||
|
#### 1. GTJA_alpha165 和 GTJA_alpha183(最严重)
|
||||||
|
|
||||||
|
**DSL 定义:**
|
||||||
|
```
|
||||||
|
alpha165: max_(ts_sumac(close-ts_mean(close,48)))-min_(ts_sumac(close-ts_mean(close,48)))/ts_std(close,48)
|
||||||
|
alpha183: max_(ts_sumac(close-ts_mean(close,24)))-min_(ts_sumac(close-ts_mean(close,24)))/ts_std(close,24)
|
||||||
|
```
|
||||||
|
|
||||||
|
**问题根源:**
|
||||||
|
- `ts_sumac()` 是累积求和函数,依赖于从数据起始点到当前点的所有历史值
|
||||||
|
- 3年回看 vs 4年回看意味着不同的起始点,导致累积和完全不同
|
||||||
|
- 当回看窗口超过因子所需的历史数据时,这类因子**不应该**有相同的值
|
||||||
|
|
||||||
|
**验证:**
|
||||||
|
- alpha165 使用 48 日移动平均,但 `ts_sumac` 依赖整个历史序列
|
||||||
|
- 这是**设计问题**,不是数据泄露
|
||||||
|
|
||||||
|
#### 2. GTJA_alpha113(NaN 模式 + 数值差异)
|
||||||
|
|
||||||
|
**DSL 定义:**
|
||||||
|
```
|
||||||
|
(-1 * ((cs_rank((ts_sum(ts_delay(close, 5), 20) / 20)) * ts_corr(close, vol, 2)) * cs_rank(ts_corr(ts_sum(close, 5), ts_sum(close, 20), 2))))
|
||||||
|
```
|
||||||
|
|
||||||
|
**问题分析:**
|
||||||
|
- 包含 `ts_delay(close, 5)` 导致数据偏移
|
||||||
|
- `ts_corr(close, vol, 2)` 只有 2 日窗口,对边界敏感
|
||||||
|
- 不同回看期导致有效数据点数量不同
|
||||||
|
|
||||||
|
#### 3. GTJA_alpha138
|
||||||
|
|
||||||
|
**DSL 定义:**
|
||||||
|
```
|
||||||
|
((cs_rank(ts_decay_linear(ts_delta((((low * 0.7) + ((amount / vol) *0.3))), 3), 20)) - ts_rank(ts_decay_linear(ts_rank(ts_corr(ts_rank(low, 8), ts_rank(ts_mean(vol,60), 17), 5), 19), 16), 7)) * -1)
|
||||||
|
```
|
||||||
|
|
||||||
|
**问题分析:**
|
||||||
|
- 5层嵌套:cs_rank → ts_decay_linear → ts_delta → ts_rank → ts_corr → ts_rank
|
||||||
|
- 每层都可能放大微小差异
|
||||||
|
- ts_rank 使用 `sliding_window_view`,对数据边界敏感
|
||||||
|
|
||||||
|
## 根因分析
|
||||||
|
|
||||||
|
### 1. 累积和因子(ts_sumac)设计问题
|
||||||
|
|
||||||
|
**现象:** GTJA_alpha165 和 GTJA_alpha183 差异巨大(数万级别)
|
||||||
|
|
||||||
|
**原因:**
|
||||||
|
```python
|
||||||
|
# ts_sumac 实现 (translator.py 第 659-664 行)
|
||||||
|
def _handle_ts_sumac(self, node: FunctionNode) -> pl.Expr:
|
||||||
|
expr = self.translate(node.args[0])
|
||||||
|
return expr.cum_sum() # 从序列起始累积
|
||||||
|
```
|
||||||
|
|
||||||
|
累积和 `cum_sum()` 从数据的第一个点开始累加,因此:
|
||||||
|
- 3年回看:从 2022-01-02 开始累积
|
||||||
|
- 4年回看:从 2021-01-02 开始累积
|
||||||
|
- 两者累积和完全不同,这是**预期行为**
|
||||||
|
|
||||||
|
### 2. ts_rank 边界敏感性
|
||||||
|
|
||||||
|
**实现分析** (translator.py 第 481-509 行):
|
||||||
|
```python
|
||||||
|
def rank_calc(s: pl.Series) -> pl.Series:
|
||||||
|
values = s.to_numpy()
|
||||||
|
n = len(values)
|
||||||
|
if n < window:
|
||||||
|
return pl.Series([float("nan")] * n)
|
||||||
|
|
||||||
|
windows = np.lib.stride_tricks.sliding_window_view(values, window)
|
||||||
|
current_vals = windows[:, -1]
|
||||||
|
ranks = np.sum(windows <= current_vals[:, None], axis=1) / window
|
||||||
|
|
||||||
|
result = np.full(n, np.nan)
|
||||||
|
result[window - 1:] = ranks
|
||||||
|
return pl.Series(result)
|
||||||
|
```
|
||||||
|
|
||||||
|
- `sliding_window_view` 从第 `window` 个元素开始产生有效值
|
||||||
|
- 前 `window-1` 个元素都是 NaN
|
||||||
|
- 不同回看期导致 NaN 数量和位置不同
|
||||||
|
|
||||||
|
### 3. 多层嵌套放大效应
|
||||||
|
|
||||||
|
以 GTJA_alpha138 为例的调用链:
|
||||||
|
```
|
||||||
|
cs_rank(ts_decay_linear(...))
|
||||||
|
→ ts_decay_linear = ts_wma
|
||||||
|
→ numpy.convolve
|
||||||
|
→ 卷积结果依赖输入序列长度
|
||||||
|
t_rank(ts_decay_linear(ts_rank(...)))
|
||||||
|
→ 每层 ts_rank 都使用 sliding_window_view
|
||||||
|
→ 每层都引入边界 NaN
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. 财务数据 Lookback 扩展
|
||||||
|
|
||||||
|
**代码路径** (data_router.py 第 202-226 行):
|
||||||
|
```python
|
||||||
|
if table_spec.join_type == "asof_backward":
|
||||||
|
# 财务数据需要扩展回看期
|
||||||
|
adj_start = self.financial_loader.get_date_range_with_lookback(
|
||||||
|
start_date, end_date, lookback_years=2
|
||||||
|
)[0]
|
||||||
|
date_start = pd.Timestamp(adj_start)
|
||||||
|
```
|
||||||
|
|
||||||
|
- 财务数据默认回看 2 年,确保 PIT (Point-In-Time) 对齐
|
||||||
|
- 但这不会导致数据泄露,只是确保公告日匹配
|
||||||
|
|
||||||
|
## 影响评估
|
||||||
|
|
||||||
|
### 对模型训练的影响
|
||||||
|
|
||||||
|
**低风险(可接受):**
|
||||||
|
- 微小数值差异(< 1e-6):不影响模型训练
|
||||||
|
- NaN 模式轻微差异:在可接受范围内
|
||||||
|
|
||||||
|
**中风险(需关注):**
|
||||||
|
- GTJA_alpha113, alpha115, alpha138, alpha146 等:差异较大可能影响训练
|
||||||
|
|
||||||
|
**高风险(需修复):**
|
||||||
|
- GTJA_alpha165, alpha183:累积和因子,设计上有问题
|
||||||
|
- GTJA_alpha005:出现 -inf 值,可能是除零错误
|
||||||
|
- GTJA_alpha176:出现 inf,数值稳定性问题
|
||||||
|
|
||||||
|
### 对回测的影响
|
||||||
|
|
||||||
|
**关键问题:**
|
||||||
|
- 如果使用 3年回看训练模型,但回测时用不同回看期,因子值会不一致
|
||||||
|
- 这可能导致回测结果与实盘表现不符
|
||||||
|
|
||||||
|
## 修复建议
|
||||||
|
|
||||||
|
### 短期修复(立即实施)
|
||||||
|
|
||||||
|
1. **排除高风险因子**
|
||||||
|
```python
|
||||||
|
EXCLUDED_FACTORS = [
|
||||||
|
"GTJA_alpha165", # ts_sumac 设计问题
|
||||||
|
"GTJA_alpha183", # ts_sumac 设计问题
|
||||||
|
"GTJA_alpha005", # -inf 值
|
||||||
|
"GTJA_alpha176", # inf 值
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **修复 alpha113 的 NaN 问题**
|
||||||
|
- 检查 `ts_delay(close, 5)` 的实现
|
||||||
|
- 确保数据对齐正确
|
||||||
|
|
||||||
|
### 中期修复(1-2周)
|
||||||
|
|
||||||
|
1. **重写 ts_sumac 实现**
|
||||||
|
- 添加 `max_lookback` 参数限制累积范围
|
||||||
|
- 或改用滚动窗口求和替代累积和
|
||||||
|
|
||||||
|
2. **优化 ts_rank 边界处理**
|
||||||
|
- 确保不同回看期产生相同的有效值范围
|
||||||
|
- 考虑使用更稳定的排名算法
|
||||||
|
|
||||||
|
3. **增强数值稳定性**
|
||||||
|
- 为所有除法操作添加 epsilon 保护
|
||||||
|
- 检查 `sign` 函数在零值时的行为
|
||||||
|
|
||||||
|
### 长期优化(1个月)
|
||||||
|
|
||||||
|
1. **建立因子回归测试**
|
||||||
|
- 自动化测试不同回看期的一致性
|
||||||
|
- 设置数值差异阈值(如 rtol=1e-6)
|
||||||
|
|
||||||
|
2. **因子分类体系**
|
||||||
|
- 标记"历史依赖型"因子(如 ts_sumac)
|
||||||
|
- 在文档中明确说明因子的回看期敏感性
|
||||||
|
|
||||||
|
3. **PIT 数据验证**
|
||||||
|
- 验证财务数据的 PIT 处理是否正确
|
||||||
|
- 防止未来数据泄露
|
||||||
|
|
||||||
|
## 测试建议
|
||||||
|
|
||||||
|
### 新增测试用例
|
||||||
|
|
||||||
|
```python
|
||||||
|
def test_factor_consistency_across_lookback():
|
||||||
|
"""验证因子在不同回看期下的一致性。"""
|
||||||
|
factors_to_test = [
|
||||||
|
"GTJA_alpha113",
|
||||||
|
"GTJA_alpha138",
|
||||||
|
"GTJA_alpha146",
|
||||||
|
]
|
||||||
|
|
||||||
|
for factor in factors_to_test:
|
||||||
|
data_2y = compute_factor(factor, lookback_days=730)
|
||||||
|
data_3y = compute_factor(factor, lookback_days=1095)
|
||||||
|
|
||||||
|
# 只比较有效值(非 NaN)
|
||||||
|
valid_mask = ~np.isnan(data_2y) & ~np.isnan(data_3y)
|
||||||
|
diff = np.abs(data_2y[valid_mask] - data_3y[valid_mask])
|
||||||
|
|
||||||
|
assert np.all(diff < 1e-6), f"{factor} 差异过大: max={np.max(diff)}"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 监控指标
|
||||||
|
|
||||||
|
1. **差异率**:超过阈值的因子比例
|
||||||
|
2. **NaN 比例**:每个因子的 NaN 占比
|
||||||
|
3. **极端值比例**:inf/-inf 的出现频率
|
||||||
|
|
||||||
|
## 结论
|
||||||
|
|
||||||
|
本次测试发现 191 个因子中:
|
||||||
|
- **一致因子**:约 60%(116个)
|
||||||
|
- **微小差异**:约 25%(48个)- 可接受
|
||||||
|
- **NaN 模式差异**:约 3%(5个)- 需关注
|
||||||
|
- **严重不一致**:约 12%(22个)- **需修复**
|
||||||
|
|
||||||
|
**优先级:**
|
||||||
|
1. 🔴 **立即**:排除 GTJA_alpha165, alpha183, alpha005, alpha176
|
||||||
|
2. 🟡 **本周**:修复 alpha113, alpha138 的边界问题
|
||||||
|
3. 🟢 **本月**:建立自动化一致性测试
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**报告生成时间:** 2026-03-19
|
||||||
|
**测试版本:** ProStock 最新 main 分支
|
||||||
|
**数据范围:** 2022-01-02 至 2025-01-31
|
||||||
@@ -217,7 +217,7 @@ class Storage:
|
|||||||
params.append(ts_code)
|
params.append(ts_code)
|
||||||
|
|
||||||
where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
|
where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
|
||||||
query = f"SELECT * FROM {name} {where_clause} ORDER BY trade_date"
|
query = f"SELECT * FROM {name} {where_clause} ORDER BY ts_code, trade_date"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Execute query with parameters (SQL injection safe)
|
# Execute query with parameters (SQL injection safe)
|
||||||
@@ -255,7 +255,7 @@ class Storage:
|
|||||||
conditions.append(f"ts_code = '{ts_code}'")
|
conditions.append(f"ts_code = '{ts_code}'")
|
||||||
|
|
||||||
where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
|
where_clause = f"WHERE {' AND '.join(conditions)}" if conditions else ""
|
||||||
query = f"SELECT * FROM {name} {where_clause} ORDER BY trade_date"
|
query = f"SELECT * FROM {name} {where_clause} ORDER BY ts_code, trade_date"
|
||||||
|
|
||||||
# 使用 DuckDB 的 Polars 导出(需要 pyarrow)
|
# 使用 DuckDB 的 Polars 导出(需要 pyarrow)
|
||||||
df = self._connection.sql(query).pl()
|
df = self._connection.sql(query).pl()
|
||||||
|
|||||||
@@ -228,7 +228,7 @@ SELECTED_FACTORS = [
|
|||||||
"GTJA_alpha162",
|
"GTJA_alpha162",
|
||||||
"GTJA_alpha163",
|
"GTJA_alpha163",
|
||||||
"GTJA_alpha164",
|
"GTJA_alpha164",
|
||||||
"GTJA_alpha165",
|
# "GTJA_alpha165",
|
||||||
"GTJA_alpha166",
|
"GTJA_alpha166",
|
||||||
"GTJA_alpha167",
|
"GTJA_alpha167",
|
||||||
"GTJA_alpha168",
|
"GTJA_alpha168",
|
||||||
@@ -243,7 +243,7 @@ SELECTED_FACTORS = [
|
|||||||
"GTJA_alpha178",
|
"GTJA_alpha178",
|
||||||
"GTJA_alpha179",
|
"GTJA_alpha179",
|
||||||
"GTJA_alpha180",
|
"GTJA_alpha180",
|
||||||
"GTJA_alpha183",
|
# "GTJA_alpha183",
|
||||||
"GTJA_alpha184",
|
"GTJA_alpha184",
|
||||||
"GTJA_alpha185",
|
"GTJA_alpha185",
|
||||||
"GTJA_alpha187",
|
"GTJA_alpha187",
|
||||||
@@ -258,44 +258,44 @@ FACTOR_DEFINITIONS = {}
|
|||||||
# 需要排除的因子列表(这些因子不会被计算和使用)
|
# 需要排除的因子列表(这些因子不会被计算和使用)
|
||||||
# 用于临时屏蔽效果不好的因子,无需从 SELECTED_FACTORS 中删除
|
# 用于临时屏蔽效果不好的因子,无需从 SELECTED_FACTORS 中删除
|
||||||
EXCLUDED_FACTORS: List[str] = [
|
EXCLUDED_FACTORS: List[str] = [
|
||||||
'GTJA_alpha005',
|
# "GTJA_alpha005",
|
||||||
'GTJA_alpha028',
|
# "GTJA_alpha028",
|
||||||
'GTJA_alpha023',
|
# "GTJA_alpha023",
|
||||||
'GTJA_alpha002',
|
# "GTJA_alpha002",
|
||||||
'GTJA_alpha010',
|
# "GTJA_alpha010",
|
||||||
'GTJA_alpha011',
|
# "GTJA_alpha011",
|
||||||
'GTJA_alpha044',
|
# "GTJA_alpha044",
|
||||||
'GTJA_alpha036',
|
# "GTJA_alpha036",
|
||||||
'GTJA_alpha027',
|
# "GTJA_alpha027",
|
||||||
'GTJA_alpha109',
|
# "GTJA_alpha109",
|
||||||
'GTJA_alpha104',
|
# "GTJA_alpha104",
|
||||||
'GTJA_alpha103',
|
# "GTJA_alpha103",
|
||||||
'GTJA_alpha085',
|
# "GTJA_alpha085",
|
||||||
'GTJA_alpha111',
|
# "GTJA_alpha111",
|
||||||
'GTJA_alpha092',
|
# "GTJA_alpha092",
|
||||||
'GTJA_alpha067',
|
# "GTJA_alpha067",
|
||||||
'GTJA_alpha060',
|
# "GTJA_alpha060",
|
||||||
'GTJA_alpha062',
|
# "GTJA_alpha062",
|
||||||
'GTJA_alpha063',
|
# "GTJA_alpha063",
|
||||||
'GTJA_alpha079',
|
# "GTJA_alpha079",
|
||||||
'GTJA_alpha073',
|
# "GTJA_alpha073",
|
||||||
'GTJA_alpha087',
|
# "GTJA_alpha087",
|
||||||
'GTJA_alpha117',
|
# "GTJA_alpha117",
|
||||||
'GTJA_alpha113',
|
# "GTJA_alpha113",
|
||||||
'GTJA_alpha138',
|
# "GTJA_alpha138",
|
||||||
'GTJA_alpha121',
|
# "GTJA_alpha121",
|
||||||
'GTJA_alpha124',
|
# "GTJA_alpha124",
|
||||||
'GTJA_alpha133',
|
# "GTJA_alpha133",
|
||||||
'GTJA_alpha131',
|
# "GTJA_alpha131",
|
||||||
'GTJA_alpha118',
|
# "GTJA_alpha118",
|
||||||
'GTJA_alpha164',
|
# "GTJA_alpha164",
|
||||||
'GTJA_alpha162',
|
# "GTJA_alpha162",
|
||||||
'GTJA_alpha157',
|
# "GTJA_alpha157",
|
||||||
'GTJA_alpha171',
|
# "GTJA_alpha171",
|
||||||
'GTJA_alpha177',
|
# "GTJA_alpha177",
|
||||||
'GTJA_alpha180',
|
# "GTJA_alpha180",
|
||||||
'GTJA_alpha188',
|
# "GTJA_alpha188",
|
||||||
'GTJA_alpha191',
|
# "GTJA_alpha191",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -490,7 +490,7 @@ OUTPUT_DIR = "output"
|
|||||||
SAVE_PREDICTIONS = True
|
SAVE_PREDICTIONS = True
|
||||||
|
|
||||||
# 模型保存配置
|
# 模型保存配置
|
||||||
SAVE_MODEL = False # 是否保存模型
|
SAVE_MODEL = True # 是否保存模型
|
||||||
MODEL_SAVE_DIR = "models" # 模型保存目录
|
MODEL_SAVE_DIR = "models" # 模型保存目录
|
||||||
|
|
||||||
# Top N 配置:每日推荐股票数量
|
# Top N 配置:每日推荐股票数量
|
||||||
@@ -523,58 +523,68 @@ def get_output_path(model_type: str, test_start: str, test_end: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def get_model_save_path(
|
def get_model_save_path(
|
||||||
model_type: str, model_name: Optional[str] = None
|
model_type: str,
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""生成模型保存路径。
|
"""生成模型保存路径。
|
||||||
|
|
||||||
|
模型将保存在 models/{model_type}/ 目录下,包含 model.pkl 和 factors.json
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_type: 模型类型("regression" 或 "rank")
|
model_type: 模型类型("regression" 或 "rank")
|
||||||
model_name: 模型名称,默认为 model_type
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
模型保存路径,如果 SAVE_MODEL 为 False 则返回 None
|
模型保存路径(models/{model_type}/model.pkl),如果 SAVE_MODEL 为 False 则返回 None
|
||||||
"""
|
"""
|
||||||
if not SAVE_MODEL:
|
if not SAVE_MODEL:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
# 确保模型保存目录存在
|
# 模型保存目录:models/{model_type}/
|
||||||
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
|
model_dir = os.path.join(MODEL_SAVE_DIR, model_type)
|
||||||
|
os.makedirs(model_dir, exist_ok=True)
|
||||||
|
|
||||||
# 使用 model_name 或默认使用 model_type
|
# 模型文件路径
|
||||||
name = model_name if model_name else model_type
|
return os.path.join(model_dir, "model.pkl")
|
||||||
filename = f"{name}.pkl"
|
|
||||||
return os.path.join(MODEL_SAVE_DIR, filename)
|
|
||||||
|
|
||||||
|
|
||||||
def save_model_with_factors(
|
def save_model_with_factors(
|
||||||
model,
|
model,
|
||||||
model_path: str,
|
model_path: str,
|
||||||
selected_factors: List[str],
|
selected_factors: list[str],
|
||||||
factor_definitions: dict,
|
factor_definitions: dict,
|
||||||
) -> None:
|
fitted_processors: list | None = None,
|
||||||
"""保存模型及关联的因子信息。
|
) -> str:
|
||||||
|
"""保存模型及关联的因子信息和处理器。
|
||||||
|
|
||||||
除了保存模型本身,还会保存一个同名的 .factors.json 文件,
|
将模型、因子信息和处理器保存到同一文件夹(models/{model_type}/)下:
|
||||||
包含 SELECTED_FACTORS 和 FACTOR_DEFINITIONS,以便后续加载模型时
|
- model.pkl: 模型文件
|
||||||
知道使用了哪些因子。
|
- factors.json: 因子信息文件
|
||||||
|
- processors.pkl: 处理器状态文件(如果提供)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: 训练好的模型实例(需有 save 方法)
|
model: 训练好的模型实例(需有 save 方法)
|
||||||
model_path: 模型保存路径
|
model_path: 模型保存路径(由 get_model_save_path 生成)
|
||||||
selected_factors: 从 metadata 中选择的因子名称列表
|
selected_factors: 从 metadata 中选择的因子名称列表
|
||||||
factor_definitions: 通过表达式定义的因子字典
|
factor_definitions: 通过表达式定义的因子字典
|
||||||
|
fitted_processors: 已拟合的处理器列表(可选)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
模型文件夹路径
|
||||||
"""
|
"""
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
# 获取模型文件夹路径
|
||||||
|
model_dir = os.path.dirname(model_path)
|
||||||
|
|
||||||
# 1. 保存模型本身
|
# 1. 保存模型本身
|
||||||
model.save(model_path)
|
model.save(model_path)
|
||||||
print(f"[模型保存] 模型已保存至: {model_path}")
|
print(f"[模型保存] 模型已保存至: {model_path}")
|
||||||
|
|
||||||
# 2. 保存因子信息到 .factors.json 文件
|
# 2. 保存因子信息到 factors.json 文件
|
||||||
factors_path = model_path.replace(".pkl", ".factors.json")
|
factors_path = os.path.join(model_dir, "factors.json")
|
||||||
|
|
||||||
factors_info = {
|
factors_info = {
|
||||||
"selected_factors": selected_factors,
|
"selected_factors": selected_factors,
|
||||||
@@ -592,12 +602,22 @@ def save_model_with_factors(
|
|||||||
print(f" - 来自 metadata: {factors_info['selected_factors_count']} 个")
|
print(f" - 来自 metadata: {factors_info['selected_factors_count']} 个")
|
||||||
print(f" - 来自表达式定义: {factors_info['factor_definitions_count']} 个")
|
print(f" - 来自表达式定义: {factors_info['factor_definitions_count']} 个")
|
||||||
|
|
||||||
|
# 3. 保存处理器(如果提供)
|
||||||
|
if fitted_processors is not None:
|
||||||
|
processors_path = os.path.join(model_dir, "processors.pkl")
|
||||||
|
with open(processors_path, "wb") as f:
|
||||||
|
pickle.dump(fitted_processors, f)
|
||||||
|
print(f"[模型保存] 处理器已保存至: {processors_path}")
|
||||||
|
print(f"[模型保存] 共 {len(fitted_processors)} 个处理器")
|
||||||
|
|
||||||
|
return model_dir
|
||||||
|
|
||||||
|
|
||||||
def load_model_factors(model_path: str) -> Optional[dict]:
|
def load_model_factors(model_path: str) -> Optional[dict]:
|
||||||
"""加载模型关联的因子信息。
|
"""加载模型关联的因子信息。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_path: 模型保存路径
|
model_path: 模型保存路径(models/{model_type}/model.pkl)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
包含因子信息的字典,如果文件不存在则返回 None
|
包含因子信息的字典,如果文件不存在则返回 None
|
||||||
@@ -605,7 +625,9 @@ def load_model_factors(model_path: str) -> Optional[dict]:
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
factors_path = model_path.replace(".pkl", ".factors.json")
|
# 获取模型文件夹路径
|
||||||
|
model_dir = os.path.dirname(model_path)
|
||||||
|
factors_path = os.path.join(model_dir, "factors.json")
|
||||||
|
|
||||||
if not os.path.exists(factors_path):
|
if not os.path.exists(factors_path):
|
||||||
print(f"[警告] 未找到因子信息文件: {factors_path}")
|
print(f"[警告] 未找到因子信息文件: {factors_path}")
|
||||||
@@ -618,3 +640,30 @@ def load_model_factors(model_path: str) -> Optional[dict]:
|
|||||||
f"[模型加载] 已加载因子信息,总计 {factors_info.get('total_feature_count', 'N/A')} 个因子"
|
f"[模型加载] 已加载因子信息,总计 {factors_info.get('total_feature_count', 'N/A')} 个因子"
|
||||||
)
|
)
|
||||||
return factors_info
|
return factors_info
|
||||||
|
|
||||||
|
|
||||||
|
def load_processors(model_path: str) -> list | None:
|
||||||
|
"""加载模型关联的处理器。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: 模型保存路径(models/{model_type}/model.pkl)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理器列表,如果文件不存在则返回 None
|
||||||
|
"""
|
||||||
|
import pickle
|
||||||
|
import os
|
||||||
|
|
||||||
|
# 获取模型文件夹路径
|
||||||
|
model_dir = os.path.dirname(model_path)
|
||||||
|
processors_path = os.path.join(model_dir, "processors.pkl")
|
||||||
|
|
||||||
|
if not os.path.exists(processors_path):
|
||||||
|
print(f"[警告] 未找到处理器文件: {processors_path}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
with open(processors_path, "rb") as f:
|
||||||
|
fitted_processors = pickle.load(f)
|
||||||
|
|
||||||
|
print(f"[模型加载] 已加载 {len(fitted_processors)} 个处理器")
|
||||||
|
return fitted_processors
|
||||||
|
|||||||
@@ -603,6 +603,7 @@ if SAVE_MODEL:
|
|||||||
model_path=model_save_path,
|
model_path=model_save_path,
|
||||||
selected_factors=SELECTED_FACTORS,
|
selected_factors=SELECTED_FACTORS,
|
||||||
factor_definitions=FACTOR_DEFINITIONS,
|
factor_definitions=FACTOR_DEFINITIONS,
|
||||||
|
fitted_processors=fitted_processors,
|
||||||
)
|
)
|
||||||
|
|
||||||
print("\n训练流程完成!")
|
print("\n训练流程完成!")
|
||||||
|
|||||||
413
src/experiment/predict.py
Normal file
413
src/experiment/predict.py
Normal file
@@ -0,0 +1,413 @@
|
|||||||
|
"""预测脚本 - 加载模型并对指定时间段进行预测。
|
||||||
|
|
||||||
|
支持两种模型类型:
|
||||||
|
- regression: 回归模型
|
||||||
|
- rank: 排序学习模型
|
||||||
|
|
||||||
|
脚本会自动分析 models 目录下的模型类型,用户只需指定模型类型和预测时间段。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
from src.factors import FactorEngine
|
||||||
|
from src.training import (
|
||||||
|
STFilter,
|
||||||
|
StockPoolManager,
|
||||||
|
Winsorizer,
|
||||||
|
NullFiller,
|
||||||
|
StandardScaler,
|
||||||
|
CrossSectionalStandardScaler,
|
||||||
|
)
|
||||||
|
from src.training.components.models import LightGBMModel, LightGBMLambdaRankModel
|
||||||
|
from src.experiment.common import (
|
||||||
|
get_label_factor,
|
||||||
|
stock_pool_filter,
|
||||||
|
STOCK_FILTER_REQUIRED_COLUMNS,
|
||||||
|
OUTPUT_DIR,
|
||||||
|
TOP_N,
|
||||||
|
load_processors,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# 配置区域 - 用户需要修改这些配置
|
||||||
|
# =============================================================================
|
||||||
|
# 模型类型: "regression" 或 "rank"
|
||||||
|
MODEL_TYPE = "rank"
|
||||||
|
|
||||||
|
# 预测时间段(不从中读取,使用这里的配置)
|
||||||
|
PREDICT_START = "20250101"
|
||||||
|
PREDICT_END = "20261231"
|
||||||
|
|
||||||
|
# 数据回看窗口天数(用于计算时序因子,需要向前获取额外数据)
|
||||||
|
# 例如:如果因子使用了 ts_mean(close, 60),则回看窗口至少为 60 天
|
||||||
|
LOOKBACK_DAYS = 365 * 3 # 向前获取 1 年的数据确保所有因子都能正确计算
|
||||||
|
|
||||||
|
# 模型路径配置
|
||||||
|
MODEL_SAVE_DIR = "models"
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def detect_model_type(models_dir: str) -> str:
|
||||||
|
"""自动检测模型类型。
|
||||||
|
|
||||||
|
检查 models 目录下有哪些模型类型可用。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
models_dir: 模型保存目录
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
检测到的模型类型,如果多个则优先返回 regression
|
||||||
|
"""
|
||||||
|
available_types: List[str] = []
|
||||||
|
|
||||||
|
# 检查 regression 模型
|
||||||
|
regression_path = os.path.join(models_dir, "regression", "model.pkl")
|
||||||
|
if os.path.exists(regression_path):
|
||||||
|
available_types.append("regression")
|
||||||
|
|
||||||
|
# 检查 rank 模型
|
||||||
|
rank_path = os.path.join(models_dir, "rank", "model.pkl")
|
||||||
|
if os.path.exists(rank_path):
|
||||||
|
available_types.append("rank")
|
||||||
|
|
||||||
|
if not available_types:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"未在 {models_dir} 目录下找到任何模型。"
|
||||||
|
f"请确保模型已训练并保存在 models/regression/model.pkl 或 models/rank/model.pkl"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"[模型检测] 可用的模型类型: {available_types}")
|
||||||
|
|
||||||
|
# 如果用户指定的类型可用,直接返回
|
||||||
|
if MODEL_TYPE in available_types:
|
||||||
|
return MODEL_TYPE
|
||||||
|
|
||||||
|
# 如果用户未指定或指定的不可用,返回第一个可用的
|
||||||
|
print(f"[模型检测] 使用默认模型类型: {available_types[0]}")
|
||||||
|
return available_types[0]
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_and_factors(
|
||||||
|
model_type: str, models_dir: str
|
||||||
|
) -> Tuple[Any, Dict[str, Any], List[str], str]:
|
||||||
|
"""加载模型及其关联的因子信息。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_type: 模型类型("regression" 或 "rank")
|
||||||
|
models_dir: 模型保存目录
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(model, factors_info, feature_cols, model_path) 元组
|
||||||
|
"""
|
||||||
|
model_path = os.path.join(models_dir, model_type, "model.pkl")
|
||||||
|
factors_path = os.path.join(models_dir, model_type, "factors.json")
|
||||||
|
|
||||||
|
print(f"\n{'=' * 80}")
|
||||||
|
print(f"加载模型: {model_type}")
|
||||||
|
print(f"{'=' * 80}")
|
||||||
|
|
||||||
|
# 检查模型文件是否存在
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
raise FileNotFoundError(f"模型文件不存在: {model_path}")
|
||||||
|
|
||||||
|
# 加载模型(根据模型类型选择正确的加载方法)
|
||||||
|
print(f"[模型加载] 正在加载模型: {model_path}")
|
||||||
|
|
||||||
|
if model_type == "regression":
|
||||||
|
model = LightGBMModel.load(model_path)
|
||||||
|
elif model_type == "rank":
|
||||||
|
model = LightGBMLambdaRankModel.load(model_path)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"不支持的模型类型: {model_type}")
|
||||||
|
|
||||||
|
print(f"[模型加载] 已加载模型: {model_path}")
|
||||||
|
print(f"[模型加载] 模型类型: {type(model).__name__}")
|
||||||
|
|
||||||
|
# 加载因子信息
|
||||||
|
if not os.path.exists(factors_path):
|
||||||
|
raise FileNotFoundError(f"因子信息文件不存在: {factors_path}")
|
||||||
|
|
||||||
|
with open(factors_path, "r", encoding="utf-8") as f:
|
||||||
|
factors_info = json.load(f)
|
||||||
|
|
||||||
|
print(f"[因子加载] 已加载因子信息: {factors_path}")
|
||||||
|
print(f"[因子加载] 因子总数: {factors_info.get('total_feature_count', 'N/A')}")
|
||||||
|
print(
|
||||||
|
f"[因子加载] 来自 metadata: {factors_info.get('selected_factors_count', 'N/A')}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"[因子加载] 来自表达式: {factors_info.get('factor_definitions_count', 'N/A')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建特征列列表
|
||||||
|
selected_factors = factors_info.get("selected_factors", [])
|
||||||
|
factor_definitions = factors_info.get("factor_definitions", {})
|
||||||
|
feature_cols = selected_factors + list(factor_definitions.keys())
|
||||||
|
|
||||||
|
print(f"[特征列] 共 {len(feature_cols)} 个特征")
|
||||||
|
|
||||||
|
return model, factors_info, feature_cols, model_path
|
||||||
|
|
||||||
|
|
||||||
|
def get_lookback_start_date(start_date: str, lookback_days: int) -> str:
|
||||||
|
"""计算考虑回看窗口后的实际开始日期。
|
||||||
|
|
||||||
|
为了确保时序因子(如 ts_mean(close, 20))在预测开始日期能正确计算,
|
||||||
|
需要向前获取额外的历史数据。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start_date: 预测开始日期 (YYYYMMDD)
|
||||||
|
lookback_days: 回看窗口天数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
考虑回看后的实际开始日期 (YYYYMMDD)
|
||||||
|
"""
|
||||||
|
start_dt = datetime.strptime(start_date, "%Y%m%d")
|
||||||
|
lookback_dt = start_dt - timedelta(days=lookback_days)
|
||||||
|
return lookback_dt.strftime("%Y%m%d")
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_data(
|
||||||
|
engine: FactorEngine,
|
||||||
|
feature_cols: list[str],
|
||||||
|
start_date: str,
|
||||||
|
end_date: str,
|
||||||
|
label_name: str,
|
||||||
|
) -> pl.DataFrame:
|
||||||
|
"""准备预测数据。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
engine: FactorEngine实例
|
||||||
|
feature_cols: 特征列名称列表
|
||||||
|
start_date: 开始日期 (YYYYMMDD)
|
||||||
|
end_date: 结束日期 (YYYYMMDD)
|
||||||
|
label_name: label列名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含因子计算结果的数据框
|
||||||
|
"""
|
||||||
|
print(f"\n{'=' * 80}")
|
||||||
|
print(f"准备数据: {start_date} - {end_date}")
|
||||||
|
print(f"{'=' * 80}")
|
||||||
|
|
||||||
|
factor_names = feature_cols + [label_name]
|
||||||
|
|
||||||
|
data = engine.compute(
|
||||||
|
factor_names=factor_names,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"数据形状: {data.shape}")
|
||||||
|
print(f"前5行预览:")
|
||||||
|
print(data.head())
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def apply_processors(data: pl.DataFrame, fitted_processors: List[Any]) -> pl.DataFrame:
|
||||||
|
"""应用已拟合的数据处理器。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 输入数据
|
||||||
|
fitted_processors: 已拟合的处理器列表(从训练时保存)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理后的数据
|
||||||
|
"""
|
||||||
|
print(f"\n{'=' * 80}")
|
||||||
|
print("应用数据处理器(使用训练时保存的参数)")
|
||||||
|
print(f"{'=' * 80}")
|
||||||
|
|
||||||
|
for i, processor in enumerate(fitted_processors, 1):
|
||||||
|
print(f" [{i}/{len(fitted_processors)}] {processor.__class__.__name__}")
|
||||||
|
data = processor.transform(data) # 使用 transform,不是 fit_transform
|
||||||
|
|
||||||
|
print(f"处理后数据形状: {data.shape}")
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def predict():
|
||||||
|
"""主预测流程。"""
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("ProStock 预测脚本")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
# 1. 自动检测模型类型
|
||||||
|
model_type = detect_model_type(MODEL_SAVE_DIR)
|
||||||
|
print(f"\n[配置] 使用模型类型: {model_type}")
|
||||||
|
print(f"[配置] 预测时间段: {PREDICT_START} - {PREDICT_END}")
|
||||||
|
|
||||||
|
# 2. 加载模型和因子信息
|
||||||
|
model, factors_info, feature_cols, model_path = load_model_and_factors(
|
||||||
|
model_type, MODEL_SAVE_DIR
|
||||||
|
)
|
||||||
|
|
||||||
|
# 提取因子配置
|
||||||
|
selected_factors = factors_info.get("selected_factors", [])
|
||||||
|
factor_definitions = factors_info.get("factor_definitions", {})
|
||||||
|
|
||||||
|
# 3. 创建 FactorEngine
|
||||||
|
print("\n[1] 创建 FactorEngine")
|
||||||
|
engine = FactorEngine()
|
||||||
|
|
||||||
|
# 4. 注册因子
|
||||||
|
print("\n[2] 注册因子")
|
||||||
|
label_name = "future_return_5"
|
||||||
|
label_factor = get_label_factor(label_name)
|
||||||
|
|
||||||
|
# 注册来自 metadata 的因子
|
||||||
|
print(" 注册 metadata 因子:")
|
||||||
|
for name in selected_factors:
|
||||||
|
engine.add_factor(name)
|
||||||
|
print(f" - {name}")
|
||||||
|
|
||||||
|
# 注册表达式因子
|
||||||
|
print(" 注册表达式因子:")
|
||||||
|
for name, expr in factor_definitions.items():
|
||||||
|
engine.add_factor(name, expr)
|
||||||
|
print(f" - {name}: {expr}")
|
||||||
|
|
||||||
|
# 注册 label 因子
|
||||||
|
print(" 注册 Label 因子:")
|
||||||
|
for name, expr in label_factor.items():
|
||||||
|
engine.add_factor(name, expr)
|
||||||
|
print(f" - {name}: {expr}")
|
||||||
|
|
||||||
|
# 5. 准备数据(考虑回看窗口)
|
||||||
|
print(f"\n[数据准备] 预测时间段: {PREDICT_START} - {PREDICT_END}")
|
||||||
|
print(f"[数据准备] 回看窗口: {LOOKBACK_DAYS} 天")
|
||||||
|
|
||||||
|
actual_start = get_lookback_start_date(PREDICT_START, LOOKBACK_DAYS)
|
||||||
|
print(f"[数据准备] 实际加载数据时间段: {actual_start} - {PREDICT_END}")
|
||||||
|
|
||||||
|
data = prepare_data(
|
||||||
|
engine=engine,
|
||||||
|
feature_cols=feature_cols,
|
||||||
|
start_date=actual_start,
|
||||||
|
end_date=PREDICT_END,
|
||||||
|
label_name=label_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 过滤回看数据,只保留预测日期范围内的数据
|
||||||
|
print(f"[数据准备] 过滤回看数据,保留 {PREDICT_START} 之后的数据...")
|
||||||
|
data = data.filter(data["trade_date"] >= PREDICT_START)
|
||||||
|
print(f"[数据准备] 过滤后数据形状: {data.shape}")
|
||||||
|
|
||||||
|
# 6. 股票池筛选
|
||||||
|
print("\n[3] 股票池筛选")
|
||||||
|
pool_manager = StockPoolManager(
|
||||||
|
filter_func=stock_pool_filter,
|
||||||
|
required_columns=STOCK_FILTER_REQUIRED_COLUMNS,
|
||||||
|
data_router=engine.router,
|
||||||
|
)
|
||||||
|
|
||||||
|
st_filter = STFilter(data_router=engine.router)
|
||||||
|
|
||||||
|
# 先执行 ST 过滤
|
||||||
|
if st_filter:
|
||||||
|
print(" 应用 ST 过滤器...")
|
||||||
|
data = st_filter.filter(data)
|
||||||
|
print(f" ST 过滤后数据规模: {data.shape}")
|
||||||
|
|
||||||
|
# 股票池筛选
|
||||||
|
print(" 执行每日股票池筛选...")
|
||||||
|
filtered_data = pool_manager.filter_and_select_daily(data)
|
||||||
|
print(f" 筛选前数据规模: {data.shape}")
|
||||||
|
print(f" 筛选后数据规模: {filtered_data.shape}")
|
||||||
|
print(f" 筛选前股票数: {data['ts_code'].n_unique()}")
|
||||||
|
print(f" 筛选后股票数: {filtered_data['ts_code'].n_unique()}")
|
||||||
|
|
||||||
|
# 7. 加载并应用数据处理器
|
||||||
|
print("\n[3.5] 加载数据处理器")
|
||||||
|
model_dir = os.path.dirname(model_path)
|
||||||
|
fitted_processors = load_processors(model_path)
|
||||||
|
|
||||||
|
if fitted_processors is None:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"未找到处理器文件,请确保模型已正确训练并保存处理器到 {model_dir}/processors.pkl"
|
||||||
|
)
|
||||||
|
|
||||||
|
processed_data = apply_processors(filtered_data, fitted_processors)
|
||||||
|
|
||||||
|
# 8. 生成预测
|
||||||
|
print("\n[4] 生成预测")
|
||||||
|
print("-" * 60)
|
||||||
|
X = processed_data.select(feature_cols)
|
||||||
|
print(f" 预测样本数: {len(X)}")
|
||||||
|
print(f" 特征数: {len(feature_cols)}")
|
||||||
|
|
||||||
|
predictions = model.predict(X)
|
||||||
|
print(f" 预测完成!")
|
||||||
|
|
||||||
|
print(f"\n 预测结果统计:")
|
||||||
|
print(f" 均值: {predictions.mean():.6f}")
|
||||||
|
print(f" 标准差: {predictions.std():.6f}")
|
||||||
|
print(f" 最小值: {predictions.min():.6f}")
|
||||||
|
print(f" 最大值: {predictions.max():.6f}")
|
||||||
|
|
||||||
|
# 添加预测列
|
||||||
|
processed_data = processed_data.with_columns([pl.Series("prediction", predictions)])
|
||||||
|
|
||||||
|
# 9. 保存结果
|
||||||
|
print("\n[5] 保存预测结果")
|
||||||
|
print("-" * 60)
|
||||||
|
|
||||||
|
# 确保输出目录存在
|
||||||
|
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
# 生成输出文件名
|
||||||
|
start_dt = datetime.strptime(PREDICT_START, "%Y%m%d")
|
||||||
|
end_dt = datetime.strptime(PREDICT_END, "%Y%m%d")
|
||||||
|
date_str = f"{start_dt.strftime('%Y%m%d')}_{end_dt.strftime('%Y%m%d')}"
|
||||||
|
|
||||||
|
# 保存每日 Top N
|
||||||
|
print(f" 保存每日 Top {TOP_N} 股票...")
|
||||||
|
output_path = os.path.join(OUTPUT_DIR, "predict_output.csv")
|
||||||
|
|
||||||
|
# 按日期分组,取每日 top N
|
||||||
|
topn_by_date = []
|
||||||
|
unique_dates = processed_data["trade_date"].unique().sort()
|
||||||
|
for date in unique_dates:
|
||||||
|
day_data = processed_data.filter(processed_data["trade_date"] == date)
|
||||||
|
# 按 prediction 降序排序,取前 N
|
||||||
|
topn = day_data.sort("prediction", descending=True).head(TOP_N)
|
||||||
|
topn_by_date.append(topn)
|
||||||
|
|
||||||
|
# 合并所有日期的 top N
|
||||||
|
topn_results = pl.concat(topn_by_date)
|
||||||
|
|
||||||
|
# 格式化日期并调整列顺序:日期、分数、股票
|
||||||
|
topn_to_save = topn_results.select(
|
||||||
|
[
|
||||||
|
pl.col("trade_date").str.slice(0, 4)
|
||||||
|
+ "-"
|
||||||
|
+ pl.col("trade_date").str.slice(4, 2)
|
||||||
|
+ "-"
|
||||||
|
+ pl.col("trade_date").str.slice(6, 2).alias("date"),
|
||||||
|
pl.col("prediction").alias("score"),
|
||||||
|
pl.col("ts_code"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
topn_to_save.write_csv(output_path, include_header=True)
|
||||||
|
print(f" 保存路径: {output_path}")
|
||||||
|
print(
|
||||||
|
f" 保存行数: {len(topn_to_save)}({len(unique_dates)}个交易日 × 每日top{TOP_N})"
|
||||||
|
)
|
||||||
|
print(f"\n 预览(前15行):")
|
||||||
|
print(topn_to_save.head(15))
|
||||||
|
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("预测完成!")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
predict()
|
||||||
@@ -598,4 +598,5 @@ if SAVE_MODEL:
|
|||||||
model_path=model_save_path,
|
model_path=model_save_path,
|
||||||
selected_factors=SELECTED_FACTORS,
|
selected_factors=SELECTED_FACTORS,
|
||||||
factor_definitions=FACTOR_DEFINITIONS,
|
factor_definitions=FACTOR_DEFINITIONS,
|
||||||
|
fitted_processors=fitted_processors,
|
||||||
)
|
)
|
||||||
|
|||||||
409
tests/debug/test_lookback_consistency.py
Normal file
409
tests/debug/test_lookback_consistency.py
Normal file
@@ -0,0 +1,409 @@
|
|||||||
|
"""
|
||||||
|
测试 LOOKBACK_DAYS 对因子计算结果的影响
|
||||||
|
|
||||||
|
测试目标:验证不同 LOOKBACK_DAYS 设置下,同一预测日期范围的因子值是否一致
|
||||||
|
如果结果不一致,说明可能存在数据泄露问题
|
||||||
|
|
||||||
|
测试逻辑:
|
||||||
|
1. 分别使用 2 年(730天)和 3 年(1095天)作为 LOOKBACK_DAYS
|
||||||
|
2. 计算同一预测日期范围(2025-2026)的因子值
|
||||||
|
3. 比较两者的因子值是否相同
|
||||||
|
|
||||||
|
预期结果:
|
||||||
|
- 如果回看窗口大于最大因子窗口,两种设置下的因子值应该完全一致
|
||||||
|
- 如果结果不同,说明因子计算使用了超出合理回看期的数据(数据泄露)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import polars as pl
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.factors import FactorEngine
|
||||||
|
from src.experiment.common import (
|
||||||
|
SELECTED_FACTORS,
|
||||||
|
FACTOR_DEFINITIONS,
|
||||||
|
get_label_factor,
|
||||||
|
register_factors,
|
||||||
|
prepare_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# 测试配置
|
||||||
|
# =============================================================================
|
||||||
|
PREDICT_START = "20250101"
|
||||||
|
PREDICT_END = "20250131" # 只测试1月份,加快测试速度
|
||||||
|
MODEL_SAVE_DIR = "models"
|
||||||
|
|
||||||
|
# 两种不同的回看窗口设置
|
||||||
|
LOOKBACK_2Y = 365 * 3 # 2年 = 730天
|
||||||
|
LOOKBACK_3Y = 365 * 4 # 3年 = 1095天
|
||||||
|
|
||||||
|
|
||||||
|
def get_lookback_start_date(start_date: str, lookback_days: int) -> str:
|
||||||
|
"""计算考虑回看窗口后的实际开始日期。"""
|
||||||
|
start_dt = datetime.strptime(start_date, "%Y%m%d")
|
||||||
|
lookback_dt = start_dt - timedelta(days=lookback_days)
|
||||||
|
return lookback_dt.strftime("%Y%m%d")
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_factors(
|
||||||
|
model_type: str, models_dir: str
|
||||||
|
) -> Tuple[Dict[str, Any], List[str]]:
|
||||||
|
"""加载模型的因子信息。"""
|
||||||
|
factors_path = os.path.join(models_dir, model_type, "factors.json")
|
||||||
|
|
||||||
|
if not os.path.exists(factors_path):
|
||||||
|
raise FileNotFoundError(f"因子信息文件不存在: {factors_path}")
|
||||||
|
|
||||||
|
with open(factors_path, "r", encoding="utf-8") as f:
|
||||||
|
factors_info = json.load(f)
|
||||||
|
|
||||||
|
selected_factors = SELECTED_FACTORS
|
||||||
|
factor_definitions = SELECTED_FACTORS
|
||||||
|
feature_cols = SELECTED_FACTORS
|
||||||
|
|
||||||
|
return factors_info, feature_cols
|
||||||
|
|
||||||
|
|
||||||
|
def compute_factors_with_lookback(
|
||||||
|
lookback_days: int,
|
||||||
|
feature_cols: List[str],
|
||||||
|
factors_info: Dict[str, Any],
|
||||||
|
) -> pl.DataFrame:
|
||||||
|
"""
|
||||||
|
使用指定的回看窗口计算因子。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lookback_days: 回看窗口天数
|
||||||
|
feature_cols: 特征列名称列表
|
||||||
|
factors_info: 因子信息字典
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含因子计算结果的数据框(已过滤到预测日期范围)
|
||||||
|
"""
|
||||||
|
# 计算实际开始日期
|
||||||
|
actual_start = get_lookback_start_date(PREDICT_START, lookback_days)
|
||||||
|
|
||||||
|
print(f"\n{'=' * 80}")
|
||||||
|
print(f"使用 LOOKBACK_DAYS = {lookback_days} ({lookback_days // 365}年)")
|
||||||
|
print(f"预测日期范围: {PREDICT_START} - {PREDICT_END}")
|
||||||
|
print(f"实际加载数据范围: {actual_start} - {PREDICT_END}")
|
||||||
|
print(f"{'=' * 80}")
|
||||||
|
|
||||||
|
# 创建 FactorEngine
|
||||||
|
engine = FactorEngine()
|
||||||
|
|
||||||
|
# 注册因子
|
||||||
|
selected_factors = factors_info.get("selected_factors", [])
|
||||||
|
factor_definitions = factors_info.get("factor_definitions", {})
|
||||||
|
label_name = "future_return_5"
|
||||||
|
label_factor = get_label_factor(label_name)
|
||||||
|
|
||||||
|
register_factors(
|
||||||
|
engine=engine,
|
||||||
|
selected_factors=selected_factors,
|
||||||
|
factor_definitions=factor_definitions,
|
||||||
|
label_factor=label_factor,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 计算因子
|
||||||
|
data = prepare_data(
|
||||||
|
engine=engine,
|
||||||
|
feature_cols=feature_cols,
|
||||||
|
start_date=actual_start,
|
||||||
|
end_date=PREDICT_END,
|
||||||
|
label_name=label_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 过滤回看数据,只保留预测日期范围内的数据
|
||||||
|
data = data.filter(data["trade_date"] >= PREDICT_START)
|
||||||
|
|
||||||
|
print(f"\n过滤后数据形状: {data.shape}")
|
||||||
|
print(f"过滤后日期范围: {data['trade_date'].min()} - {data['trade_date'].max()}")
|
||||||
|
print(f"过滤后股票数量: {data['ts_code'].n_unique()}")
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def compare_factor_values(
|
||||||
|
data_2y: pl.DataFrame,
|
||||||
|
data_3y: pl.DataFrame,
|
||||||
|
feature_cols: List[str],
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
比较两种回看窗口设置下的因子值。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_2y: 2年回看窗口的因子数据
|
||||||
|
data_3y: 3年回看窗口的因子数据
|
||||||
|
feature_cols: 特征列名称列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
比较结果字典
|
||||||
|
"""
|
||||||
|
print(f"\n{'=' * 80}")
|
||||||
|
print("比较因子计算结果")
|
||||||
|
print(f"{'=' * 80}")
|
||||||
|
|
||||||
|
# 确保两个数据集的行数和列数相同
|
||||||
|
print(f"\n数据集形状:")
|
||||||
|
print(f" 2Y 回看: {data_2y.shape}")
|
||||||
|
print(f" 3Y 回看: {data_3y.shape}")
|
||||||
|
|
||||||
|
if data_2y.shape != data_3y.shape:
|
||||||
|
print(f"[警告] 数据形状不一致!")
|
||||||
|
# 找出差异
|
||||||
|
dates_2y = set(data_2y["trade_date"].to_list())
|
||||||
|
dates_3y = set(data_3y["trade_date"].to_list())
|
||||||
|
stocks_2y = set(data_2y["ts_code"].to_list())
|
||||||
|
stocks_3y = set(data_3y["ts_code"].to_list())
|
||||||
|
|
||||||
|
print(f" 2Y 日期数: {len(dates_2y)}, 股票数: {len(stocks_2y)}")
|
||||||
|
print(f" 3Y 日期数: {len(dates_3y)}, 股票数: {len(stocks_3y)}")
|
||||||
|
|
||||||
|
# 使用交集进行后续比较
|
||||||
|
common_dates = dates_2y & dates_3y
|
||||||
|
common_stocks = stocks_2y & stocks_3y
|
||||||
|
|
||||||
|
print(f" 共同日期数: {len(common_dates)}")
|
||||||
|
print(f" 共同股票数: {len(common_stocks)}")
|
||||||
|
|
||||||
|
data_2y = data_2y.filter(
|
||||||
|
data_2y["trade_date"].is_in(list(common_dates))
|
||||||
|
& data_2y["ts_code"].is_in(list(common_stocks))
|
||||||
|
)
|
||||||
|
data_3y = data_3y.filter(
|
||||||
|
data_3y["trade_date"].is_in(list(common_dates))
|
||||||
|
& data_3y["ts_code"].is_in(list(common_stocks))
|
||||||
|
)
|
||||||
|
|
||||||
|
# 按日期和股票代码排序
|
||||||
|
data_2y = data_2y.sort(["trade_date", "ts_code"])
|
||||||
|
data_3y = data_3y.sort(["trade_date", "ts_code"])
|
||||||
|
|
||||||
|
# 比较每个因子的值
|
||||||
|
results = {
|
||||||
|
"total_factors": len(feature_cols),
|
||||||
|
"consistent_factors": 0,
|
||||||
|
"inconsistent_factors": 0,
|
||||||
|
"inconsistent_details": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
print(f"\n因子一致性检查:")
|
||||||
|
for factor_name in feature_cols:
|
||||||
|
if factor_name not in data_2y.columns or factor_name not in data_3y.columns:
|
||||||
|
print(f" [跳过] {factor_name}: 因子不存在于两个数据集中")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 获取因子值(转换为 numpy 数组)
|
||||||
|
values_2y = data_2y[factor_name].to_numpy()
|
||||||
|
values_3y = data_3y[factor_name].to_numpy()
|
||||||
|
|
||||||
|
# 处理 NaN 值 - 转换为 float 类型以确保兼容性
|
||||||
|
values_2y = np.asarray(values_2y, dtype=np.float64)
|
||||||
|
values_3y = np.asarray(values_3y, dtype=np.float64)
|
||||||
|
|
||||||
|
# 处理 NaN 值
|
||||||
|
mask_2y = ~np.isnan(values_2y)
|
||||||
|
mask_3y = ~np.isnan(values_3y)
|
||||||
|
|
||||||
|
# 检查 NaN 模式是否一致
|
||||||
|
nan_consistent = np.array_equal(mask_2y, mask_3y)
|
||||||
|
|
||||||
|
if not nan_consistent:
|
||||||
|
print(f" [警告] {factor_name}: NaN 模式不一致!")
|
||||||
|
print(f" 2Y NaN 数量: {np.sum(~mask_2y)}")
|
||||||
|
print(f" 3Y NaN 数量: {np.sum(~mask_3y)}")
|
||||||
|
|
||||||
|
# 只在两者都有有效值的位置进行比较
|
||||||
|
valid_mask = mask_2y & mask_3y
|
||||||
|
|
||||||
|
if np.sum(valid_mask) == 0:
|
||||||
|
print(f" [跳过] {factor_name}: 没有有效的共同数据点")
|
||||||
|
continue
|
||||||
|
|
||||||
|
valid_2y = values_2y[valid_mask]
|
||||||
|
valid_3y = values_3y[valid_mask]
|
||||||
|
|
||||||
|
# 检查数值是否一致(使用相对容差)
|
||||||
|
consistent = np.allclose(valid_2y, valid_3y, rtol=1e-10, atol=1e-10)
|
||||||
|
|
||||||
|
if consistent:
|
||||||
|
results["consistent_factors"] += 1
|
||||||
|
print(f" [一致] {factor_name}")
|
||||||
|
else:
|
||||||
|
results["inconsistent_factors"] += 1
|
||||||
|
|
||||||
|
# 计算差异统计
|
||||||
|
diff = np.abs(valid_2y - valid_3y)
|
||||||
|
max_diff = np.max(diff)
|
||||||
|
mean_diff = np.mean(diff)
|
||||||
|
|
||||||
|
results["inconsistent_details"].append(
|
||||||
|
{
|
||||||
|
"factor": factor_name,
|
||||||
|
"max_diff": max_diff,
|
||||||
|
"mean_diff": mean_diff,
|
||||||
|
"count_diff": np.sum(diff > 1e-10),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f" [不一致] {factor_name}:")
|
||||||
|
print(f" 最大差异: {max_diff:.10f}")
|
||||||
|
print(f" 平均差异: {mean_diff:.10f}")
|
||||||
|
print(f" 差异数据点数量: {np.sum(diff > 1e-10)}")
|
||||||
|
|
||||||
|
# 显示前几个差异
|
||||||
|
diff_indices = np.where(diff > 1e-10)[0][:5]
|
||||||
|
print(f" 前几个差异值:")
|
||||||
|
for idx in diff_indices:
|
||||||
|
print(
|
||||||
|
f" idx={idx}: 2Y={valid_2y[idx]:.10f}, 3Y={valid_3y[idx]:.10f}, diff={diff[idx]:.10f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def test_lookback_consistency():
|
||||||
|
"""
|
||||||
|
测试 LOOKBACK_DAYS 设置对因子计算结果的影响。
|
||||||
|
|
||||||
|
这个测试会:
|
||||||
|
1. 加载模型因子配置
|
||||||
|
2. 分别使用 2 年和 3 年回看窗口计算因子
|
||||||
|
3. 比较结果是否一致
|
||||||
|
|
||||||
|
如果结果不一致,说明存在数据泄露问题。
|
||||||
|
"""
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("LOOKBACK_DAYS 一致性测试")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
# 检测模型类型
|
||||||
|
available_types = []
|
||||||
|
if os.path.exists(os.path.join(MODEL_SAVE_DIR, "regression", "model.pkl")):
|
||||||
|
available_types.append("regression")
|
||||||
|
if os.path.exists(os.path.join(MODEL_SAVE_DIR, "rank", "model.pkl")):
|
||||||
|
available_types.append("rank")
|
||||||
|
|
||||||
|
if not available_types:
|
||||||
|
pytest.skip(f"未在 {MODEL_SAVE_DIR} 目录下找到任何模型,跳过测试")
|
||||||
|
|
||||||
|
model_type = available_types[0]
|
||||||
|
print(f"\n使用模型类型: {model_type}")
|
||||||
|
|
||||||
|
# 加载因子信息
|
||||||
|
try:
|
||||||
|
factors_info, feature_cols = load_model_factors(model_type, MODEL_SAVE_DIR)
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
pytest.skip(f"无法加载因子信息: {e}")
|
||||||
|
|
||||||
|
print(f"因子数量: {len(feature_cols)}")
|
||||||
|
|
||||||
|
# 使用 2 年回看窗口计算因子
|
||||||
|
data_2y = compute_factors_with_lookback(
|
||||||
|
lookback_days=LOOKBACK_2Y,
|
||||||
|
feature_cols=feature_cols,
|
||||||
|
factors_info=factors_info,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 使用 3 年回看窗口计算因子
|
||||||
|
data_3y = compute_factors_with_lookback(
|
||||||
|
lookback_days=LOOKBACK_3Y,
|
||||||
|
feature_cols=feature_cols,
|
||||||
|
factors_info=factors_info,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 比较结果
|
||||||
|
results = compare_factor_values(data_2y, data_3y, feature_cols)
|
||||||
|
|
||||||
|
# 打印总结
|
||||||
|
print(f"\n{'=' * 80}")
|
||||||
|
print("测试结果总结")
|
||||||
|
print(f"{'=' * 80}")
|
||||||
|
print(f"总因子数: {results['total_factors']}")
|
||||||
|
print(f"一致因子数: {results['consistent_factors']}")
|
||||||
|
print(f"不一致因子数: {results['inconsistent_factors']}")
|
||||||
|
|
||||||
|
if results["inconsistent_factors"] > 0:
|
||||||
|
print(f"\n不一致的因子:")
|
||||||
|
for detail in results["inconsistent_details"]:
|
||||||
|
print(f" - {detail['factor']}: 最大差异={detail['max_diff']:.10f}")
|
||||||
|
|
||||||
|
# 断言:如果有不一致的因子,测试失败
|
||||||
|
inconsistent_names = [d["factor"] for d in results["inconsistent_details"]]
|
||||||
|
pytest.fail(
|
||||||
|
f"发现 {results['inconsistent_factors']} 个因子在不同 LOOKBACK_DAYS 设置下结果不一致,"
|
||||||
|
f"可能存在数据泄露: {inconsistent_names[:5]}..."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print("\n[通过] 所有因子在不同 LOOKBACK_DAYS 设置下结果一致")
|
||||||
|
|
||||||
|
|
||||||
|
def test_simple_factor_consistency():
|
||||||
|
"""
|
||||||
|
使用简单的测试因子验证 LOOKBACK_DAYS 的影响。
|
||||||
|
|
||||||
|
这个测试不依赖模型文件,使用内置的简单因子进行验证。
|
||||||
|
"""
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("简单因子一致性测试(不依赖模型文件)")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
# 定义测试用的简单因子
|
||||||
|
test_factors = SELECTED_FACTORS
|
||||||
|
feature_cols = test_factors
|
||||||
|
|
||||||
|
def compute_simple_factors(lookback_days: int) -> pl.DataFrame:
|
||||||
|
"""计算简单因子。"""
|
||||||
|
actual_start = get_lookback_start_date(PREDICT_START, lookback_days)
|
||||||
|
|
||||||
|
print(f"\nLOOKBACK_DAYS = {lookback_days} ({lookback_days // 365}年)")
|
||||||
|
print(f"实际加载数据范围: {actual_start} - {PREDICT_END}")
|
||||||
|
|
||||||
|
engine = FactorEngine()
|
||||||
|
|
||||||
|
# 注册因子
|
||||||
|
for name in test_factors:
|
||||||
|
engine.add_factor(name)
|
||||||
|
|
||||||
|
# 计算因子
|
||||||
|
data = engine.compute(
|
||||||
|
factor_names=feature_cols,
|
||||||
|
start_date=actual_start,
|
||||||
|
end_date=PREDICT_END,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 过滤到预测日期范围
|
||||||
|
data = data.filter(data["trade_date"] >= PREDICT_START)
|
||||||
|
|
||||||
|
print(f"计算完成: {data.shape}")
|
||||||
|
return data
|
||||||
|
|
||||||
|
# 计算两种设置下的因子
|
||||||
|
data_2y = compute_simple_factors(LOOKBACK_2Y)
|
||||||
|
data_3y = compute_simple_factors(LOOKBACK_3Y)
|
||||||
|
|
||||||
|
# 比较结果
|
||||||
|
results = compare_factor_values(data_2y, data_3y, feature_cols)
|
||||||
|
|
||||||
|
# 断言
|
||||||
|
assert results["inconsistent_factors"] == 0, (
|
||||||
|
f"发现 {results['inconsistent_factors']} 个简单因子在不同 LOOKBACK_DAYS 下结果不一致"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n[通过] 所有简单因子在不同 LOOKBACK_DAYS 设置下结果一致")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 运行简单测试(不依赖模型文件)
|
||||||
|
test_simple_factor_consistency()
|
||||||
|
|
||||||
|
# 运行完整测试(需要模型文件)
|
||||||
|
# test_lookback_consistency()
|
||||||
Reference in New Issue
Block a user