- 添加 TabMModel、TabPFNModel 深度学习模型实现 - 新增 DataQualityAnalyzer 进行训练前数据质量诊断 - 改进数据处理器 NaN/null 双重处理,增强数据鲁棒性 - 支持 train_skip_days 参数跳过训练初期数据不足期 - Pipeline 自动清理标签为 NaN 的样本
86 lines
2.2 KiB
Python
86 lines
2.2 KiB
Python
"""检查 GTJA_alpha 因子"""
|
|
|
|
import polars as pl
|
|
|
|
from src.factors import FactorEngine
|
|
from src.training import FactorManager
|
|
from src.experiment.common import (
|
|
SELECTED_FACTORS,
|
|
FACTOR_DEFINITIONS,
|
|
LABEL_FACTOR,
|
|
)
|
|
|
|
# Factor names GTJA_alpha001 .. GTJA_alpha015 (zero-padded to three digits).
# Passed to FactorManager as ``excluded_factors`` in main(); presumably these
# are left out of the registered feature set, so the check only sees the
# remaining GTJA_alpha factors -- confirm against FactorManager.
EXCLUDED_FACTORS = [f"GTJA_alpha{idx:03d}" for idx in range(1, 16)]
|
|
|
|
|
|
def _missing_count(series: pl.Series) -> int:
    """Return the number of missing entries in *series*.

    Polars distinguishes ``null`` from floating-point ``NaN``, and
    ``Series.null_count()`` counts only nulls. For float columns the NaN
    count is added so both kinds of missing value are reported.
    """
    nulls = series.null_count()
    if series.dtype in (pl.Float32, pl.Float64):
        # is_nan() yields null for null entries and sum() skips them, so the
        # NaN count is disjoint from the null count and the two can be added.
        nans = series.is_nan().sum() or 0
        return nulls + int(nans)
    return nulls


def _format_missing(name: str, missing: int, total: int) -> str:
    """Format one `` name: missing/total (pct%)`` report line.

    The percentage is reported as 0.0 when *total* is 0 instead of dividing
    by zero.
    """
    pct = missing / total * 100 if total else 0.0
    return f" {name}: {missing}/{total} ({pct:.1f}%)"


def main(
    *,
    start_date: str = "20200101",
    end_date: str = "20200110",
    sample_size: int = 10,
    preview_size: int = 3,
) -> None:
    """Compute a sample of GTJA_alpha factors and report their missing values.

    Builds a FactorManager (passing ``EXCLUDED_FACTORS`` as
    ``excluded_factors``), registers its factors with a FactorEngine, keeps the
    registered names starting with ``GTJA_alpha``, computes the first
    ``sample_size`` of them together with ``close`` over the given date range,
    and prints per-column missing counts (null plus float NaN) followed by a
    short preview of the data.

    Args:
        start_date: First date passed to ``engine.compute`` (same string
            format as the default, e.g. ``"20200101"``).
        end_date: Last date passed to ``engine.compute``.
        sample_size: Number of GTJA_alpha factors to compute and report.
        preview_size: Number of GTJA_alpha columns included in the preview.
    """
    print("=" * 80)
    print("检查 GTJA_alpha 因子")
    print("=" * 80)

    engine = FactorEngine()
    factor_manager = FactorManager(
        selected_factors=SELECTED_FACTORS,
        factor_definitions=FACTOR_DEFINITIONS,
        label_factor=LABEL_FACTOR,
        excluded_factors=EXCLUDED_FACTORS,
    )

    # Register factors with the engine; the call returns the feature column names.
    feature_cols = factor_manager.register_to_engine(engine, verbose=False)

    # Pick out the GTJA_alpha factors.
    gtja_factors = [f for f in feature_cols if f.startswith("GTJA_alpha")]
    print(f"\nGTJA_alpha 因子数量: {len(gtja_factors)}")
    print(f"前{sample_size}个: {gtja_factors[:sample_size]}")

    if not gtja_factors:
        # Nothing to check; computing only `close` would be meaningless.
        print("\n未找到 GTJA_alpha 因子, 跳过计算。")
        return

    sample = gtja_factors[:sample_size]

    # Compute over a small date range only, to keep the check fast.
    print("\n计算因子数据...")
    data = engine.compute(
        factor_names=sample + ["close"],
        start_date=start_date,
        end_date=end_date,
    )

    print(f"\n数据形状: {data.shape}")
    print(f"列: {data.columns}")

    total = len(data)
    if total == 0:
        # Every ratio below divides by the row count; bail out on empty data.
        print("\n指定日期范围内没有数据, 无法统计缺失值。")
        return

    columns = set(data.columns)

    # Missing-value statistics for each sampled GTJA_alpha factor.
    print("\nGTJA_alpha 因子 NaN 统计:")
    for col in sample:
        if col in columns:
            print(_format_missing(col, _missing_count(data[col]), total))
        else:
            print(f" {col}: 列不存在!")

    # Report the close column as a baseline for comparison.
    print("\n" + _format_missing("close", _missing_count(data["close"]), total))

    # Preview the actual data; only select columns that exist, since a sampled
    # factor may be missing from the computed frame (handled above).
    print("\n实际数据预览:")
    preview_cols = [
        c for c in ["trade_date", "ts_code", *sample[:preview_size]] if c in columns
    ]
    print(data.select(preview_cols).head(10))
|
|
if __name__ == "__main__":
    # Run the GTJA_alpha check only when this file is executed as a script,
    # not when it is imported.
    main()
|