refactor(experiment): 重构模型保存机制,支持 processors 持久化
- 模型保存路径改为 models/{model_type}/ 目录结构
- save_model_with_factors 新增 fitted_processors 参数
- 新增 load_processors 函数加载处理器状态
- Storage 查询排序优化:ORDER BY ts_code, trade_date
This commit is contained in:
409
tests/debug/test_lookback_consistency.py
Normal file
409
tests/debug/test_lookback_consistency.py
Normal file
@@ -0,0 +1,409 @@
|
||||
"""
|
||||
测试 LOOKBACK_DAYS 对因子计算结果的影响
|
||||
|
||||
测试目标:验证不同 LOOKBACK_DAYS 设置下,同一预测日期范围的因子值是否一致
|
||||
如果结果不一致,说明可能存在数据泄露问题
|
||||
|
||||
测试逻辑:
|
||||
1. 分别使用 2 年(730天)和 3 年(1095天)作为 LOOKBACK_DAYS
|
||||
2. 计算同一预测日期范围(PREDICT_START 至 PREDICT_END,当前为 2025 年 1 月)的因子值
|
||||
3. 比较两者的因子值是否相同
|
||||
|
||||
预期结果:
|
||||
- 如果回看窗口大于最大因子窗口,两种设置下的因子值应该完全一致
|
||||
- 如果结果不同,说明因子计算使用了超出合理回看期的数据(数据泄露)
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
import pytest
|
||||
|
||||
from src.factors import FactorEngine
|
||||
from src.experiment.common import (
|
||||
SELECTED_FACTORS,
|
||||
FACTOR_DEFINITIONS,
|
||||
get_label_factor,
|
||||
register_factors,
|
||||
prepare_data,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 测试配置
|
||||
# =============================================================================
|
||||
PREDICT_START = "20250101"
PREDICT_END = "20250131"  # January only, to keep the test fast
MODEL_SAVE_DIR = "models"

# The two lookback windows being compared. The multipliers must agree with
# the constant names: the previous values (365 * 3 and 365 * 4) contradicted
# both the names and the documented 730-day / 1095-day windows.
LOOKBACK_2Y = 365 * 2  # 2 years = 730 days
LOOKBACK_3Y = 365 * 3  # 3 years = 1095 days
|
||||
|
||||
|
||||
def get_lookback_start_date(start_date: str, lookback_days: int) -> str:
    """Return the date that lies ``lookback_days`` before ``start_date``.

    Args:
        start_date: Reference date in compact ``YYYYMMDD`` form.
        lookback_days: Number of calendar days to step back.

    Returns:
        The shifted date, also formatted as ``YYYYMMDD``.
    """
    date_format = "%Y%m%d"
    window = timedelta(days=lookback_days)
    shifted = datetime.strptime(start_date, date_format) - window
    return shifted.strftime(date_format)
|
||||
|
||||
|
||||
def load_model_factors(
    model_type: str, models_dir: str
) -> Tuple[Dict[str, Any], List[str]]:
    """Load the factor metadata stored next to a saved model.

    Args:
        model_type: Name of the sub-directory of ``models_dir`` that holds
            the model.
        models_dir: Root directory under which models are saved.

    Returns:
        A ``(factors_info, feature_cols)`` tuple. ``factors_info`` is the
        parsed content of ``<models_dir>/<model_type>/factors.json``;
        ``feature_cols`` is the module-level ``SELECTED_FACTORS``.

    Raises:
        FileNotFoundError: If ``factors.json`` does not exist.
    """
    factors_path = os.path.join(models_dir, model_type, "factors.json")

    if not os.path.exists(factors_path):
        raise FileNotFoundError(f"因子信息文件不存在: {factors_path}")

    with open(factors_path, "r", encoding="utf-8") as f:
        factors_info = json.load(f)

    # NOTE(review): the feature columns come from the currently imported
    # SELECTED_FACTORS rather than from the saved file — confirm this is
    # intended. The unused locals that used to sit here (one of which bound
    # ``factor_definitions`` to SELECTED_FACTORS) were dead code.
    return factors_info, SELECTED_FACTORS
|
||||
|
||||
|
||||
def compute_factors_with_lookback(
    lookback_days: int,
    feature_cols: List[str],
    factors_info: Dict[str, Any],
) -> pl.DataFrame:
    """Compute factors for the prediction range using a given lookback window.

    Data is requested from ``PREDICT_START - lookback_days`` through
    ``PREDICT_END``; rows dated before ``PREDICT_START`` are then dropped.

    Args:
        lookback_days: Size of the lookback window in calendar days.
        feature_cols: Names of the feature columns passed to ``prepare_data``.
        factors_info: Factor metadata; its ``selected_factors`` and
            ``factor_definitions`` entries are registered on the engine.

    Returns:
        The frame returned by ``prepare_data``, restricted to rows whose
        ``trade_date`` is on or after ``PREDICT_START``.
    """
    banner = "=" * 80
    load_from = get_lookback_start_date(PREDICT_START, lookback_days)

    print(f"\n{banner}")
    print(f"使用 LOOKBACK_DAYS = {lookback_days} ({lookback_days // 365}年)")
    print(f"预测日期范围: {PREDICT_START} - {PREDICT_END}")
    print(f"实际加载数据范围: {load_from} - {PREDICT_END}")
    print(banner)

    # Fresh engine per run so the two lookback runs share no state.
    factor_engine = FactorEngine()
    target = "future_return_5"
    register_factors(
        engine=factor_engine,
        selected_factors=factors_info.get("selected_factors", []),
        factor_definitions=factors_info.get("factor_definitions", {}),
        label_factor=get_label_factor(target),
    )

    frame = prepare_data(
        engine=factor_engine,
        feature_cols=feature_cols,
        start_date=load_from,
        end_date=PREDICT_END,
        label_name=target,
    )

    # Drop the lookback rows; only the prediction range is compared.
    in_predict_range = frame["trade_date"] >= PREDICT_START
    frame = frame.filter(in_predict_range)

    trade_dates = frame["trade_date"]
    print(f"\n过滤后数据形状: {frame.shape}")
    print(f"过滤后日期范围: {trade_dates.min()} - {trade_dates.max()}")
    print(f"过滤后股票数量: {frame['ts_code'].n_unique()}")

    return frame
|
||||
|
||||
|
||||
def compare_factor_values(
    data_2y: pl.DataFrame,
    data_3y: pl.DataFrame,
    feature_cols: List[str],
) -> Dict[str, Any]:
    """Compare factor values computed with the two lookback windows.

    Both frames are first restricted to the (trade_date, ts_code) rows they
    have in common and sorted by those keys, so the element-wise comparison
    always looks at the same stock on the same day.

    Args:
        data_2y: Factor data computed with the shorter lookback window.
        data_3y: Factor data computed with the longer lookback window.
        feature_cols: Names of the factor columns to compare.

    Returns:
        A dict with ``total_factors``, ``consistent_factors``,
        ``inconsistent_factors`` and ``inconsistent_details`` (one entry per
        inconsistent factor with ``factor``, ``max_diff``, ``mean_diff`` and
        ``count_diff``). Factors that are missing from either frame, or that
        have no position where both values are non-NaN, are skipped and
        counted in neither bucket.

    Raises:
        ValueError: If the frames still differ in row count after alignment
            (for example because of duplicated keys).
    """
    banner = "=" * 80
    key_cols = ["trade_date", "ts_code"]
    tolerance = 1e-10

    print(f"\n{banner}")
    print("比较因子计算结果")
    print(banner)

    print("\n数据集形状:")
    print(f" 2Y 回看: {data_2y.shape}")
    print(f" 3Y 回看: {data_3y.shape}")

    if data_2y.shape != data_3y.shape:
        print("[警告] 数据形状不一致!")
        # Diagnostics only: report how the two key universes differ.
        dates_2y = set(data_2y["trade_date"].to_list())
        dates_3y = set(data_3y["trade_date"].to_list())
        stocks_2y = set(data_2y["ts_code"].to_list())
        stocks_3y = set(data_3y["ts_code"].to_list())

        print(f" 2Y 日期数: {len(dates_2y)}, 股票数: {len(stocks_2y)}")
        print(f" 3Y 日期数: {len(dates_3y)}, 股票数: {len(stocks_3y)}")
        print(f" 共同日期数: {len(dates_2y & dates_3y)}")
        print(f" 共同股票数: {len(stocks_2y & stocks_3y)}")

    # Keep only the (trade_date, ts_code) pairs present in BOTH frames.
    # Filtering dates and stocks independently is not enough: a stock can be
    # missing on a single day in one frame only, which would shift every
    # later row and make the positional comparison below meaningless.
    data_2y = data_2y.join(data_3y.select(key_cols), on=key_cols, how="semi")
    data_3y = data_3y.join(data_2y.select(key_cols), on=key_cols, how="semi")

    data_2y = data_2y.sort(key_cols)
    data_3y = data_3y.sort(key_cols)

    if data_2y.shape[0] != data_3y.shape[0]:
        raise ValueError(
            "对齐后行数仍不一致(可能存在重复的 trade_date/ts_code): "
            f"2Y={data_2y.shape[0]}, 3Y={data_3y.shape[0]}"
        )

    results = {
        "total_factors": len(feature_cols),
        "consistent_factors": 0,
        "inconsistent_factors": 0,
        "inconsistent_details": [],
    }

    print("\n因子一致性检查:")
    for factor_name in feature_cols:
        if factor_name not in data_2y.columns or factor_name not in data_3y.columns:
            print(f" [跳过] {factor_name}: 因子不存在于两个数据集中")
            continue

        # Cast to float64 so NaN handling works regardless of the column dtype.
        values_2y = np.asarray(data_2y[factor_name].to_numpy(), dtype=np.float64)
        values_3y = np.asarray(data_3y[factor_name].to_numpy(), dtype=np.float64)

        mask_2y = ~np.isnan(values_2y)
        mask_3y = ~np.isnan(values_3y)

        # A differing NaN pattern is reported but does not by itself mark the
        # factor as inconsistent.
        if not np.array_equal(mask_2y, mask_3y):
            print(f" [警告] {factor_name}: NaN 模式不一致!")
            print(f" 2Y NaN 数量: {np.sum(~mask_2y)}")
            print(f" 3Y NaN 数量: {np.sum(~mask_3y)}")

        # Compare only where both runs produced a value.
        valid_mask = mask_2y & mask_3y
        if np.sum(valid_mask) == 0:
            print(f" [跳过] {factor_name}: 没有有效的共同数据点")
            continue

        valid_2y = values_2y[valid_mask]
        valid_3y = values_3y[valid_mask]

        if np.allclose(valid_2y, valid_3y, rtol=tolerance, atol=tolerance):
            results["consistent_factors"] += 1
            print(f" [一致] {factor_name}")
            continue

        results["inconsistent_factors"] += 1

        diff = np.abs(valid_2y - valid_3y)
        max_diff = np.max(diff)
        mean_diff = np.mean(diff)
        differs = diff > tolerance
        count_diff = np.sum(differs)

        results["inconsistent_details"].append(
            {
                "factor": factor_name,
                "max_diff": max_diff,
                "mean_diff": mean_diff,
                "count_diff": count_diff,
            }
        )

        print(f" [不一致] {factor_name}:")
        print(f" 最大差异: {max_diff:.10f}")
        print(f" 平均差异: {mean_diff:.10f}")
        print(f" 差异数据点数量: {count_diff}")

        # Show the first few differing positions to help debugging.
        print(" 前几个差异值:")
        for idx in np.where(differs)[0][:5]:
            print(
                f" idx={idx}: 2Y={valid_2y[idx]:.10f}, 3Y={valid_3y[idx]:.10f}, diff={diff[idx]:.10f}"
            )

    return results
|
||||
|
||||
|
||||
def test_lookback_consistency():
    """Check that the lookback window size does not change factor values.

    Steps:
    1. Load the factor configuration of the first saved model found.
    2. Compute the factors with both lookback windows.
    3. Compare the two results.

    The test is skipped when no saved model or no factor file is available,
    and fails when any factor differs between the two runs.
    """
    rule = "=" * 80
    print(f"\n{rule}")
    print("LOOKBACK_DAYS 一致性测试")
    print(rule)

    # Probe for saved models, keeping the regression-before-rank preference.
    available_types = [
        kind
        for kind in ("regression", "rank")
        if os.path.exists(os.path.join(MODEL_SAVE_DIR, kind, "model.pkl"))
    ]
    if not available_types:
        pytest.skip(f"未在 {MODEL_SAVE_DIR} 目录下找到任何模型,跳过测试")

    model_type = available_types[0]
    print(f"\n使用模型类型: {model_type}")

    try:
        factors_info, feature_cols = load_model_factors(model_type, MODEL_SAVE_DIR)
    except FileNotFoundError as e:
        pytest.skip(f"无法加载因子信息: {e}")

    print(f"因子数量: {len(feature_cols)}")

    # Shorter window first, then the longer one.
    data_2y, data_3y = (
        compute_factors_with_lookback(
            lookback_days=window,
            feature_cols=feature_cols,
            factors_info=factors_info,
        )
        for window in (LOOKBACK_2Y, LOOKBACK_3Y)
    )

    results = compare_factor_values(data_2y, data_3y, feature_cols)

    print(f"\n{rule}")
    print("测试结果总结")
    print(rule)
    print(f"总因子数: {results['total_factors']}")
    print(f"一致因子数: {results['consistent_factors']}")
    print(f"不一致因子数: {results['inconsistent_factors']}")

    if not results["inconsistent_factors"] > 0:
        print("\n[通过] 所有因子在不同 LOOKBACK_DAYS 设置下结果一致")
        return

    print("\n不一致的因子:")
    for detail in results["inconsistent_details"]:
        print(f" - {detail['factor']}: 最大差异={detail['max_diff']:.10f}")

    # Any inconsistent factor fails the test.
    inconsistent_names = [entry["factor"] for entry in results["inconsistent_details"]]
    pytest.fail(
        f"发现 {results['inconsistent_factors']} 个因子在不同 LOOKBACK_DAYS 设置下结果不一致,"
        f"可能存在数据泄露: {inconsistent_names[:5]}..."
    )
|
||||
|
||||
|
||||
def test_simple_factor_consistency():
    """Check lookback consistency without reading any saved model files.

    The factors named in ``SELECTED_FACTORS`` are added to a fresh
    ``FactorEngine`` and computed once per lookback window; the two results
    must agree for every factor.
    """
    rule = "=" * 80
    print(f"\n{rule}")
    print("简单因子一致性测试(不依赖模型文件)")
    print(rule)

    test_factors = SELECTED_FACTORS
    feature_cols = test_factors

    def compute_simple_factors(lookback_days: int) -> pl.DataFrame:
        """Compute the test factors with the given lookback window."""
        load_from = get_lookback_start_date(PREDICT_START, lookback_days)

        print(f"\nLOOKBACK_DAYS = {lookback_days} ({lookback_days // 365}年)")
        print(f"实际加载数据范围: {load_from} - {PREDICT_END}")

        factor_engine = FactorEngine()
        for factor in test_factors:
            factor_engine.add_factor(factor)

        frame = factor_engine.compute(
            factor_names=feature_cols,
            start_date=load_from,
            end_date=PREDICT_END,
        )

        # Keep only rows inside the prediction range.
        in_predict_range = frame["trade_date"] >= PREDICT_START
        frame = frame.filter(in_predict_range)

        print(f"计算完成: {frame.shape}")
        return frame

    # Shorter window first, then the longer one.
    data_2y, data_3y = (
        compute_simple_factors(window) for window in (LOOKBACK_2Y, LOOKBACK_3Y)
    )

    results = compare_factor_values(data_2y, data_3y, feature_cols)

    assert results["inconsistent_factors"] == 0, (
        f"发现 {results['inconsistent_factors']} 个简单因子在不同 LOOKBACK_DAYS 下结果不一致"
    )

    print("\n[通过] 所有简单因子在不同 LOOKBACK_DAYS 设置下结果一致")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run the simple test (does not need saved model files).
    test_simple_factor_consistency()

    # Run the full test (requires saved model files).
    # test_lookback_consistency()
|
||||
Reference in New Issue
Block a user