# File: ProStock/tests/check_gtja.py
# (86 lines, 2.2 KiB, Python)

"""Check the GTJA_alpha factors."""
import polars as pl
from src.factors import FactorEngine
from src.training import FactorManager
from src.experiment.common import (
SELECTED_FACTORS,
FACTOR_DEFINITIONS,
LABEL_FACTOR,
)
# Factor names GTJA_alpha001 .. GTJA_alpha015 (zero-padded to three digits),
# handed to FactorManager as ``excluded_factors`` in main().
EXCLUDED_FACTORS = [f"GTJA_alpha{index:03d}" for index in range(1, 16)]
def _format_null_stat(null_count, total):
    """Format a null count as ``"<count>/<total> (<pct>%)"``.

    Args:
        null_count: Number of null entries in the column.
        total: Number of rows in the frame.

    Returns:
        The formatted string. When ``total`` is 0 the percentage is reported
        as 0.0 instead of raising ``ZeroDivisionError``.
    """
    pct = null_count / total * 100 if total else 0.0
    return f"{null_count}/{total} ({pct:.1f}%)"


def main():
    """Print null statistics for the first ten registered GTJA_alpha factors.

    Steps:
      1. Build a ``FactorEngine`` and a ``FactorManager`` (with
         ``EXCLUDED_FACTORS`` excluded) and register the factors.
      2. Keep the registered names starting with ``"GTJA_alpha"`` and compute
         the first ten of them plus ``"close"`` for 20200101..20200110.
      3. Print shape, columns, per-column null counts and a short preview.

    Missing columns and an empty result are reported rather than raising.
    """
    print("=" * 80)
    print("检查 GTJA_alpha 因子")
    print("=" * 80)

    engine = FactorEngine()
    factor_manager = FactorManager(
        selected_factors=SELECTED_FACTORS,
        factor_definitions=FACTOR_DEFINITIONS,
        label_factor=LABEL_FACTOR,
        excluded_factors=EXCLUDED_FACTORS,
    )

    # Register the factors on the engine; the return value is used as the
    # list of feature column names.
    feature_cols = factor_manager.register_to_engine(engine, verbose=False)

    # Pick out the GTJA_alpha factors and keep only the first ten for a
    # quick check.
    gtja_factors = [f for f in feature_cols if f.startswith("GTJA_alpha")]
    sample_factors = gtja_factors[:10]
    print(f"\nGTJA_alpha 因子数量: {len(gtja_factors)}")
    print(f"前10个: {sample_factors}")

    # Compute over a small date range: the sampled factors plus close.
    print("\n计算因子数据...")
    data = engine.compute(
        factor_names=sample_factors + ["close"],
        start_date="20200101",
        end_date="20200110",
    )
    print(f"\n数据形状: {data.shape}")
    print(f"列: {data.columns}")

    total = len(data)
    available = set(data.columns)

    # NOTE(review): assuming `data` is a polars DataFrame, null_count() counts
    # nulls (missing values) only; floating-point NaN values are not included
    # in these numbers -- confirm this is the intended statistic.
    print("\nGTJA_alpha 因子 NaN 统计:")
    for col in sample_factors:
        if col in available:
            print(f" {col}: {_format_null_stat(data[col].null_count(), total)}")
        else:
            print(f" {col}: 列不存在!")

    # The close column serves as a comparison baseline; report it the same
    # way as the factors if it is missing.
    if "close" in available:
        print(f"\n close: {_format_null_stat(data['close'].null_count(), total)}")
    else:
        print("\n close: 列不存在!")

    # Preview the actual data, selecting only columns that exist so that a
    # column already reported as missing does not abort the script here.
    print("\n实际数据预览:")
    wanted = ["trade_date", "ts_code"] + sample_factors[:3]
    preview_cols = [name for name in wanted if name in available]
    print(data.select(preview_cols).head(10))


if __name__ == "__main__":
    main()