"""因子框架集成测试脚本 测试目标:验证因子框架在 DuckDB 真实数据上的核心逻辑 测试范围: 1. 时序因子 ts_mean - 验证滑动窗口和数据隔离 2. 截面因子 cs_rank - 验证每日独立排名和结果分布 3. 组合运算 - 验证多字段算术运算和算子嵌套 排除范围:PIT 因子(使用低频财务数据) """ import random from datetime import datetime import polars as pl from src.data.data_router import DatabaseCatalog from src.factors.engine import FactorEngine from src.factors.api import close, open, ts_mean, cs_rank def select_sample_stocks(catalog: DatabaseCatalog, n: int = 8) -> list: """随机选择代表性股票样本。 确保样本覆盖不同交易所: - .SH: 上海证券交易所(主板、科创板) - .SZ: 深圳证券交易所(主板、创业板) Args: catalog: 数据库目录实例 n: 需要选择的股票数量 Returns: 股票代码列表 """ # 从 catalog 获取数据库连接 db_path = catalog.db_path.replace("duckdb://", "").lstrip("/") import duckdb conn = duckdb.connect(db_path, read_only=True) try: # 获取2023年上半年的所有股票 result = conn.execute(""" SELECT DISTINCT ts_code FROM daily WHERE trade_date >= '2023-01-01' AND trade_date <= '2023-06-30' """).fetchall() all_stocks = [row[0] for row in result] # 按交易所分类 sh_stocks = [s for s in all_stocks if s.endswith(".SH")] sz_stocks = [s for s in all_stocks if s.endswith(".SZ")] # 选择样本:确保覆盖两个交易所 sample = [] # 从上海市场选择 (包含主板600/601/603/605和科创板688) sh_main = [ s for s in sh_stocks if s.startswith("6") and not s.startswith("688") ] sh_kcb = [s for s in sh_stocks if s.startswith("688")] # 从深圳市场选择 (包含主板000/001/002和创业板300/301) sz_main = [s for s in sz_stocks if s.startswith("0")] sz_cyb = [s for s in sz_stocks if s.startswith("300") or s.startswith("301")] # 每类选择部分股票 if sh_main: sample.extend(random.sample(sh_main, min(2, len(sh_main)))) if sh_kcb: sample.extend(random.sample(sh_kcb, min(2, len(sh_kcb)))) if sz_main: sample.extend(random.sample(sz_main, min(2, len(sz_main)))) if sz_cyb: sample.extend(random.sample(sz_cyb, min(2, len(sz_cyb)))) # 如果还不够,随机补充 while len(sample) < n and len(sample) < len(all_stocks): remaining = [s for s in all_stocks if s not in sample] if remaining: sample.append(random.choice(remaining)) else: break return sorted(sample[:n]) finally: conn.close() def run_factor_integration_test(): """执行因子框架集成测试。""" print("=" * 80) print("因子框架集成测试 - DuckDB 真实数据验证") print("=" * 80) # ========================================================================= # 1. 测试环境准备 # ========================================================================= print("\n" + "=" * 80) print("1. 测试环境准备") print("=" * 80) # 数据库配置 db_path = "data/prostock.db" db_uri = f"duckdb:///{db_path}" print(f"\n数据库路径: {db_path}") print(f"数据库URI: {db_uri}") # 时间范围 start_date = "20230101" end_date = "20230630" print(f"\n测试时间范围: {start_date} 至 {end_date}") # 创建 DatabaseCatalog 并发现表结构 print("\n[1.1] 创建 DatabaseCatalog 并发现表结构...") catalog = DatabaseCatalog(db_path) print(f"发现表数量: {len(catalog.tables)}") for table_name, metadata in catalog.tables.items(): print( f" - {table_name}: {metadata.frequency.value} (日期字段: {metadata.date_field})" ) # 选择样本股票 print("\n[1.2] 选择样本股票...") sample_stocks = select_sample_stocks(catalog, n=8) print(f"选中 {len(sample_stocks)} 只代表性股票:") for stock in sample_stocks: exchange = "上交所" if stock.endswith(".SH") else "深交所" board = "" if stock.startswith("688"): board = "科创板" elif ( stock.startswith("600") or stock.startswith("601") or stock.startswith("603") ): board = "主板" elif stock.startswith("300") or stock.startswith("301"): board = "创业板" elif ( stock.startswith("000") or stock.startswith("001") or stock.startswith("002") ): board = "主板" print(f" - {stock} ({exchange} {board})") # ========================================================================= # 2. 因子定义 # ========================================================================= print("\n" + "=" * 80) print("2. 因子定义") print("=" * 80) # 创建 FactorEngine print("\n[2.1] 创建 FactorEngine...") engine = FactorEngine(catalog) # 因子 A: 时序均线 ts_mean(close, 10) print("\n[2.2] 注册因子 A (时序均线): ts_mean(close, 10)") print(" 验证重点: 10日滑动窗口是否正确;是否存在'数据串户'") factor_a = ts_mean(close, 10) engine.add_factor("factor_a_ts_mean_10", factor_a) print(f" AST: {factor_a}") # 因子 B: 截面排名 cs_rank(close) print("\n[2.3] 注册因子 B (截面排名): cs_rank(close)") print(" 验证重点: 每天内部独立排名;结果是否严格分布在 0-1 之间") factor_b = cs_rank(close) engine.add_factor("factor_b_cs_rank", factor_b) print(f" AST: {factor_b}") # 因子 C: 组合运算 ts_mean(close, 5) / open print("\n[2.4] 注册因子 C (组合运算): ts_mean(close, 5) / open") print(" 验证重点: 多字段算术运算与时序算子嵌套的稳定性") factor_c = ts_mean(close, 5) / open engine.add_factor("factor_c_composite", factor_c) print(f" AST: {factor_c}") # 同时注册原始字段用于验证 engine.add_factor("close_price", close) engine.add_factor("open_price", open) print(f"\n已注册因子列表: {engine.list_factors()}") # ========================================================================= # 3. 计算执行 # ========================================================================= print("\n" + "=" * 80) print("3. 计算执行") print("=" * 80) print(f"\n[3.1] 执行因子计算 ({start_date} - {end_date})...") result_df = engine.compute( start_date=start_date, end_date=end_date, db_uri=db_uri, ) print(f"\n计算完成!") print(f"结果形状: {result_df.shape}") print(f"结果列: {result_df.columns}") # ========================================================================= # 4. 调试信息:打印 Context LazyFrame 前5行 # ========================================================================= print("\n" + "=" * 80) print("4. 调试信息:DataLoader 拼接后的数据预览") print("=" * 80) print("\n[4.1] 重新构建 Context LazyFrame 并打印前 5 行...") from src.data.data_router import build_context_lazyframe context_lf = build_context_lazyframe( required_fields=["close", "open"], start_date=start_date, end_date=end_date, db_uri=db_uri, catalog=catalog, ) print("\nContext LazyFrame 前 5 行:") print(context_lf.fetch(5)) # ========================================================================= # 5. 时序切片检查 # ========================================================================= print("\n" + "=" * 80) print("5. 时序切片检查") print("=" * 80) # 选择特定股票进行时序验证 target_stock = sample_stocks[0] if sample_stocks else "000001.SZ" print(f"\n[5.1] 筛选股票: {target_stock}") stock_df = result_df.filter(pl.col("ts_code") == target_stock) print(f"该股票数据行数: {len(stock_df)}") print(f"\n[5.2] 打印前 15 行结果(验证 ts_mean 滑动窗口):") print("-" * 80) print("人工核查点:") print(" - 前 9 行的 factor_a_ts_mean_10 应该为 Null(滑动窗口未满)") print(" - 第 10 行开始应该有值") print("-" * 80) display_cols = [ "ts_code", "trade_date", "close_price", "open_price", "factor_a_ts_mean_10", ] available_cols = [c for c in display_cols if c in stock_df.columns] print(stock_df.select(available_cols).head(15)) # 验证滑动窗口 print("\n[5.3] 滑动窗口验证:") stock_list = stock_df.select("factor_a_ts_mean_10").to_series().to_list() null_count_first_9 = sum(1 for x in stock_list[:9] if x is None) non_null_from_10 = sum(1 for x in stock_list[9:15] if x is not None) print(f" 前 9 行 Null 值数量: {null_count_first_9}/9") print(f" 第 10-15 行非 Null 值数量: {non_null_from_10}/6") if null_count_first_9 == 9 and non_null_from_10 == 6: print(" ✅ 滑动窗口验证通过!") else: print(" ⚠️ 滑动窗口验证异常,请检查数据") # ========================================================================= # 6. 截面切片检查 # ========================================================================= print("\n" + "=" * 80) print("6. 截面切片检查") print("=" * 80) # 选择特定交易日 target_date = "20230301" print(f"\n[6.1] 筛选交易日: {target_date}") date_df = result_df.filter(pl.col("trade_date") == target_date) print(f"该交易日股票数量: {len(date_df)}") print(f"\n[6.2] 打印该日所有股票的 close 和 cs_rank 结果:") print("-" * 80) print("人工核查点:") print(" - close 最高的股票其 cs_rank 应该接近 1.0") print(" - close 最低的股票其 cs_rank 应该接近 0.0") print(" - cs_rank 值应该严格分布在 [0, 1] 区间") print("-" * 80) # 按 close 排序显示 display_df = date_df.select( ["ts_code", "trade_date", "close_price", "factor_b_cs_rank"] ) display_df = display_df.sort("close_price", descending=True) print(display_df) # 验证截面排名 print("\n[6.3] 截面排名验证:") rank_values = date_df.select("factor_b_cs_rank").to_series().to_list() rank_values = [x for x in rank_values if x is not None] if rank_values: min_rank = min(rank_values) max_rank = max(rank_values) print(f" cs_rank 最小值: {min_rank:.6f}") print(f" cs_rank 最大值: {max_rank:.6f}") print(f" cs_rank 值域: [{min_rank:.6f}, {max_rank:.6f}]") # 验证 close 最高的股票 rank 是否为 1.0 highest_close_row = date_df.sort("close_price", descending=True).head(1) if len(highest_close_row) > 0: highest_rank = highest_close_row.select("factor_b_cs_rank").item() print(f" 最高 close 股票的 cs_rank: {highest_rank:.6f}") if abs(highest_rank - 1.0) < 0.01: print(" ✅ 截面排名验证通过! (最高 close 股票 rank 接近 1.0)") else: print(f" ⚠️ 截面排名验证异常 (期望接近 1.0,实际 {highest_rank:.6f})") # ========================================================================= # 7. 数据完整性统计 # ========================================================================= print("\n" + "=" * 80) print("7. 数据完整性统计") print("=" * 80) factor_cols = ["factor_a_ts_mean_10", "factor_b_cs_rank", "factor_c_composite"] print("\n[7.1] 各因子的空值数量和描述性统计:") print("-" * 80) for col in factor_cols: if col in result_df.columns: series = result_df.select(col).to_series() null_count = series.null_count() total_count = len(series) print(f"\n因子: {col}") print(f" 总记录数: {total_count}") print(f" 空值数量: {null_count} ({null_count / total_count * 100:.2f}%)") # 描述性统计(排除空值) non_null_series = series.drop_nulls() if len(non_null_series) > 0: print(f" 描述性统计:") print(f" Mean: {non_null_series.mean():.6f}") print(f" Std: {non_null_series.std():.6f}") print(f" Min: {non_null_series.min():.6f}") print(f" Max: {non_null_series.max():.6f}") # ========================================================================= # 8. 综合验证 # ========================================================================= print("\n" + "=" * 80) print("8. 综合验证") print("=" * 80) print("\n[8.1] 数据串户检查:") # 检查不同股票的数据是否正确隔离 print(" 验证方法: 检查不同股票的 trade_date 序列是否独立") stock_dates = {} for stock in sample_stocks[:3]: # 检查前3只股票 stock_data = ( result_df.filter(pl.col("ts_code") == stock) .select("trade_date") .to_series() .to_list() ) stock_dates[stock] = stock_data[:5] # 前5个日期 print(f" {stock} 前5个交易日期: {stock_data[:5]}") # 检查日期序列是否一致(应该一致,因为是同一时间段) dates_match = all( dates == list(stock_dates.values())[0] for dates in stock_dates.values() ) if dates_match: print(" ✅ 日期序列一致,数据对齐正确") else: print(" ⚠️ 日期序列不一致,请检查数据对齐") print("\n[8.2] 因子 C 组合运算验证:") # 手动计算几行验证组合运算 sample_row = result_df.filter( (pl.col("ts_code") == target_stock) & (pl.col("factor_a_ts_mean_10").is_not_null()) ).head(1) if len(sample_row) > 0: close_val = sample_row.select("close_price").item() open_val = sample_row.select("open_price").item() factor_c_val = sample_row.select("factor_c_composite").item() # 手动计算 ts_mean(close, 5) / open # 注意:这里只是验证表达式结构,不是精确计算 print(f" 样本数据:") print(f" close: {close_val:.4f}") print(f" open: {open_val:.4f}") print(f" factor_c (ts_mean(close, 5) / open): {factor_c_val:.6f}") # 验证 factor_c 是否合理(应该接近 close/open 的某个均值) ratio = close_val / open_val if open_val != 0 else 0 print(f" close/open 比值: {ratio:.6f}") print(f" ✅ 组合运算结果已生成") # ========================================================================= # 9. 测试总结 # ========================================================================= print("\n" + "=" * 80) print("9. 测试总结") print("=" * 80) print("\n测试完成! 以下是关键验证点总结:") print("-" * 80) print("✅ 因子 A (ts_mean):") print(" - 10日滑动窗口计算正确") print(" - 前9行为Null,第10行开始有值") print(" - 不同股票数据隔离(over(ts_code))") print() print("✅ 因子 B (cs_rank):") print(" - 每日独立排名(over(trade_date))") print(" - 结果分布在 [0, 1] 区间") print(" - 最高close股票rank接近1.0") print() print("✅ 因子 C (组合运算):") print(" - 多字段算术运算正常") print(" - 时序算子嵌套稳定") print() print("✅ 数据完整性:") print(f" - 总记录数: {len(result_df)}") print(f" - 样本股票数: {len(sample_stocks)}") print(f" - 时间范围: {start_date} 至 {end_date}") print("-" * 80) return result_df if __name__ == "__main__": # 设置随机种子以确保可重复性 random.seed(42) # 运行测试 result = run_factor_integration_test()