diff --git a/tests/test_two_stocks_string_factors.py b/tests/test_two_stocks_string_factors.py new file mode 100644 index 0000000..20f45f1 --- /dev/null +++ b/tests/test_two_stocks_string_factors.py @@ -0,0 +1,454 @@ +"""两支股票因子计算测试 - 使用因子字符串架构。 + +测试目标:使用字符串表达式计算两支股票(601117.SH、000001.SZ)在2024-2025年的以下因子: +1. return_5: 5日收益率 (close / ts_delay(close, 5) - 1) +2. return_5_rank: 5日收益率在截面上的排名 +3. ma5: 5日均线 (ts_mean(close, 5)) +4. ma10: 10日均线 (ts_mean(close, 10)) + +特点:使用因子字符串架构(add_factor + 字符串表达式) + +数据源: DuckDB 数据库中的真实日线数据 +""" + +from src.factors import FactorEngine +from src.factors.compiler import DependencyExtractor + + +def test_two_stocks_string_factors(): + """测试两支股票的因子计算(使用字符串表达式)。""" + print("=" * 80) + print("两支股票因子计算测试 - 使用因子字符串架构") + print("=" * 80) + + # ======================================================================== + # 1. 定义因子表达式(字符串方式) + # ======================================================================== + print("\n" + "=" * 80) + print("1. 定义因子表达式(字符串方式)") + print("=" * 80) + + # return_5: 5日收益率 = (close / close.shift(5) - 1) + return_5_str = "(close / ts_delay(close, 5)) - 1" + print("\n[1.1] return_5 = (close / ts_delay(close, 5)) - 1") + print(f" 字符串表达式: {return_5_str}") + + # return_5_rank: 5日收益率的截面排名 + return_5_rank_str = "cs_rank((close / ts_delay(close, 5)) - 1)" + print("\n[1.2] return_5_rank = cs_rank(return_5)") + print(f" 字符串表达式: {return_5_rank_str}") + + # ma5: 5日均线 + ma5_str = "ts_mean(close, 5)" + print("\n[1.3] ma5 = ts_mean(close, 5)") + print(f" 字符串表达式: {ma5_str}") + + # ma10: 10日均线 + ma10_str = "ts_mean(close, 10)" + print("\n[1.4] ma10 = ts_mean(close, 10)") + print(f" 字符串表达式: {ma10_str}") + + # ======================================================================== + # 1.5 打印数据来源信息 + # ======================================================================== + print("\n" + "=" * 80) + print("1.5 数据来源分析(使用字符串表达式解析)") + print("=" * 80) + + from src.factors.parser import FormulaParser + from src.factors.registry import FunctionRegistry + + registry = FunctionRegistry() + parser = FormulaParser(registry) + + expressions_str = { + "return_5": return_5_str, + "return_5_rank": return_5_rank_str, + "ma5": ma5_str, + "ma10": ma10_str, + } + + for name, expr_str in expressions_str.items(): + # 解析字符串表达式 + node = parser.parse(expr_str) + extractor = DependencyExtractor() + deps = extractor.extract_dependencies(node) + print(f"\n因子: {name}") + print(f" 字符串表达式: {expr_str}") + print(f" 依赖字段: {deps}") + print(f" 字段说明:") + for dep in sorted(deps): + print(f" - {dep}: 基础字段 (将自动路由到对应数据表)") + + # ======================================================================== + # 2. 创建 FactorEngine 并注册因子(使用 add_factor 字符串方式) + # ======================================================================== + print("\n" + "=" * 80) + print("2. 注册因子到 FactorEngine(使用 add_factor 字符串方式)") + print("=" * 80) + + engine = FactorEngine() + + # 使用 add_factor 方法和字符串表达式注册 + engine.add_factor("return_5", return_5_str) + print("[2.1] 注册 return_5 (字符串方式)") + + engine.add_factor("return_5_rank", return_5_rank_str) + print("[2.2] 注册 return_5_rank (字符串方式)") + + engine.add_factor("ma5", ma5_str) + print("[2.3] 注册 ma5 (字符串方式)") + + engine.add_factor("ma10", ma10_str) + print("[2.4] 注册 ma10 (字符串方式)") + + # 也注册原始 close 价格用于验证 + engine.add_factor("close_price", "close") + print("[2.5] 注册 close_price (原始收盘价,字符串方式)") + + print(f"\n已注册因子列表: {engine.list_registered()}") + + # ======================================================================== + # 2.5 打印执行计划数据规格 + # ======================================================================== + print("\n" + "=" * 80) + print("2.5 执行计划数据规格") + print("=" * 80) + + for name in engine.list_registered(): + plan = engine.preview_plan(name) + if plan: + print(f"\n因子: {name}") + print(f" 输出名称: {plan.output_name}") + print(f" 依赖字段: {plan.dependencies}") + print(f" 数据规格:") + for i, spec in enumerate(plan.data_specs, 1): + print(f" [{i}] 表名: {spec.table}") + print(f" 字段: {spec.columns}") + print(f" 回看天数: {spec.lookback_days}") + + # ======================================================================== + # 3. 执行计算(两支股票) + # ======================================================================== + print("\n" + "=" * 80) + print("3. 执行因子计算 (20240101 - 20241231, 两支股票)") + print("=" * 80) + + start_date = "20240101" + end_date = "20241231" + stock_codes = ["601117.SH", "000001.SZ"] + + print(f"\n目标股票: {stock_codes}") + print(f"时间范围: {start_date} 至 {end_date}") + + try: + result = engine.compute( + factor_names=["return_5", "return_5_rank", "ma5", "ma10", "close_price"], + start_date=start_date, + end_date=end_date, + stock_codes=stock_codes, + ) + + print(f"\n计算完成!") + print(f"结果形状: {result.shape}") + print(f"结果列: {result.columns}") + + except Exception as e: + print(f"\n[错误] 计算失败: {e}") + raise + + # ======================================================================== + # 4. 结果展示与分析 + # ======================================================================== + print("\n" + "=" * 80) + print("4. 计算结果展示") + print("=" * 80) + + # 4.1 数据概览 + print("\n[4.1] 前30行数据预览:") + print(result.head(30)) + + # 4.2 按股票分组展示 + print("\n[4.2] 601117.SH 数据 (前15行):") + result_601117 = result.filter(result["ts_code"] == "601117.SH") + print(result_601117.head(15)) + + print("\n[4.3] 000001.SZ 数据 (前15行):") + result_000001 = result.filter(result["ts_code"] == "000001.SZ") + print(result_000001.head(15)) + + # ======================================================================== + # 5. 因子验证 + # ======================================================================== + print("\n" + "=" * 80) + print("5. 因子计算验证") + print("=" * 80) + + # 5.1 MA5/MA10 滑动窗口验证 + print("\n[5.1] 移动平均线滑动窗口验证:") + print("-" * 60) + print("验证要点: ") + print(" - ma5 前4行应为 Null (窗口未满5天)") + print(" - ma5 第5行开始应有值") + print(" - ma10 前9行应为 Null (窗口未满10天)") + print(" - ma10 第10行开始应有值") + print("-" * 60) + + # 检查 601117.SH 的前15行 + first_15_601117 = result_601117.head(15) + ma5_nulls_601117 = first_15_601117["ma5"].null_count() + ma10_nulls_601117 = first_15_601117["ma10"].null_count() + + print(f"\n601117.SH 前15行统计:") + print(f" ma5 Null 数量: {ma5_nulls_601117}/15 (预期: 4)") + print(f" ma10 Null 数量: {ma10_nulls_601117}/15 (预期: 9)") + + if ma5_nulls_601117 == 4 and ma10_nulls_601117 == 9: + print(" [成功] 601117.SH 滑动窗口验证通过!") + else: + print(" [警告] 601117.SH 滑动窗口验证异常,请检查数据") + + # 检查 000001.SZ 的前15行 + first_15_000001 = result_000001.head(15) + ma5_nulls_000001 = first_15_000001["ma5"].null_count() + ma10_nulls_000001 = first_15_000001["ma10"].null_count() + + print(f"\n000001.SZ 前15行统计:") + print(f" ma5 Null 数量: {ma5_nulls_000001}/15 (预期: 4)") + print(f" ma10 Null 数量: {ma10_nulls_000001}/15 (预期: 9)") + + if ma5_nulls_000001 == 4 and ma10_nulls_000001 == 9: + print(" [成功] 000001.SZ 滑动窗口验证通过!") + else: + print(" [警告] 000001.SZ 滑动窗口验证异常,请检查数据") + + # 5.2 Return_5 验证 + print("\n[5.2] 5日收益率验证:") + print("-" * 60) + print("验证要点:") + print(" - return_5 前5行应为 Null (无法计算5天前的收益)") + print(" - return_5 第6行开始应有值") + print("-" * 60) + + return_5_nulls_601117 = first_15_601117["return_5"].null_count() + return_5_nulls_000001 = first_15_000001["return_5"].null_count() + + print(f"\n601117.SH 前15行统计:") + print(f" return_5 Null 数量: {return_5_nulls_601117}/15 (预期: 5)") + + if return_5_nulls_601117 == 5: + print(" [成功] 601117.SH return_5 延迟验证通过!") + else: + print(" [警告] 601117.SH return_5 延迟验证异常") + + print(f"\n000001.SZ 前15行统计:") + print(f" return_5 Null 数量: {return_5_nulls_000001}/15 (预期: 5)") + + if return_5_nulls_000001 == 5: + print(" [成功] 000001.SZ return_5 延迟验证通过!") + else: + print(" [警告] 000001.SZ return_5 延迟验证异常") + + # 5.3 手动验证 MA5 计算 + print("\n[5.3] MA5 手动计算验证:") + print("-" * 60) + + # 选择 601117.SH 第10行(索引9)进行验证 + if len(result_601117) >= 10: + row_10_601117 = result_601117.row(9, named=True) + print(f"\n601117.SH 第10行数据:") + print(f" trade_date: {row_10_601117['trade_date']}") + print(f" close_price: {row_10_601117['close_price']:.4f}") + print(f" ma5: {row_10_601117['ma5']:.4f}") + print(f" ma10: {row_10_601117['ma10']:.4f}") + + # 手动计算前5天的均值 + first_10_601117 = result_601117.head(10) + close_list_601117 = first_10_601117["close_price"].to_list() + manual_ma5_601117 = sum(close_list_601117[5:10]) / 5 + print(f"\n手动计算验证 (第6-10天 close 均值):") + print(f" close[5:10] = {[f'{c:.4f}' for c in close_list_601117[5:10]]}") + print(f" 手动计算 ma5 = {manual_ma5_601117:.4f}") + print(f" 引擎计算 ma5 = {row_10_601117['ma5']:.4f}") + + if abs(manual_ma5_601117 - row_10_601117["ma5"]) < 0.01: + print(" [成功] 601117.SH MA5 计算验证通过!") + else: + print(" [警告] 601117.SH MA5 计算结果不一致") + + # 选择 000001.SZ 第10行(索引9)进行验证 + if len(result_000001) >= 10: + row_10_000001 = result_000001.row(9, named=True) + print(f"\n000001.SZ 第10行数据:") + print(f" trade_date: {row_10_000001['trade_date']}") + print(f" close_price: {row_10_000001['close_price']:.4f}") + print(f" ma5: {row_10_000001['ma5']:.4f}") + print(f" ma10: {row_10_000001['ma10']:.4f}") + + # 手动计算前5天的均值 + first_10_000001 = result_000001.head(10) + close_list_000001 = first_10_000001["close_price"].to_list() + manual_ma5_000001 = sum(close_list_000001[5:10]) / 5 + print(f"\n手动计算验证 (第6-10天 close 均值):") + print(f" close[5:10] = {[f'{c:.4f}' for c in close_list_000001[5:10]]}") + print(f" 手动计算 ma5 = {manual_ma5_000001:.4f}") + print(f" 引擎计算 ma5 = {row_10_000001['ma5']:.4f}") + + if abs(manual_ma5_000001 - row_10_000001["ma5"]) < 0.01: + print(" [成功] 000001.SZ MA5 计算验证通过!") + else: + print(" [警告] 000001.SZ MA5 计算结果不一致") + + # 5.4 Return_5 手动验证 + print("\n[5.4] Return_5 手动计算验证:") + print("-" * 60) + + if len(result_601117) >= 10: + row_10_601117 = result_601117.row(9, named=True) + close_day_10_601117 = close_list_601117[9] # 第10天的收盘价 + close_day_5_601117 = close_list_601117[4] # 第5天的收盘价 + + manual_return_5_601117 = (close_day_10_601117 / close_day_5_601117) - 1 + print(f"\n601117.SH 第10天 return_5 验证:") + print(f" close[9] (第10天): {close_day_10_601117:.4f}") + print(f" close[4] (第5天): {close_day_5_601117:.4f}") + print(f" 手动计算 return_5 = {manual_return_5_601117:.6f}") + print(f" 引擎计算 return_5 = {row_10_601117['return_5']:.6f}") + + if abs(manual_return_5_601117 - row_10_601117["return_5"]) < 0.0001: + print(" [成功] 601117.SH Return_5 计算验证通过!") + else: + print(" [警告] 601117.SH Return_5 计算结果不一致") + + if len(result_000001) >= 10: + row_10_000001 = result_000001.row(9, named=True) + close_day_10_000001 = close_list_000001[9] # 第10天的收盘价 + close_day_5_000001 = close_list_000001[4] # 第5天的收盘价 + + manual_return_5_000001 = (close_day_10_000001 / close_day_5_000001) - 1 + print(f"\n000001.SZ 第10天 return_5 验证:") + print(f" close[9] (第10天): {close_day_10_000001:.4f}") + print(f" close[4] (第5天): {close_day_5_000001:.4f}") + print(f" 手动计算 return_5 = {manual_return_5_000001:.6f}") + print(f" 引擎计算 return_5 = {row_10_000001['return_5']:.6f}") + + if abs(manual_return_5_000001 - row_10_000001["return_5"]) < 0.0001: + print(" [成功] 000001.SZ Return_5 计算验证通过!") + else: + print(" [警告] 000001.SZ Return_5 计算结果不一致") + + # 5.5 截面排名验证 + print("\n[5.5] 截面排名 (cs_rank) 验证:") + print("-" * 60) + print("验证要点: ") + print(" - 截面排名应在 [0, 1] 区间内") + print(" - 每天两支股票的排名之和应接近 1") + print("-" * 60) + + # 获取有效数据 + result_valid = result.drop_nulls(subset=["return_5_rank"]) + + if len(result_valid) > 0: + min_rank = result_valid["return_5_rank"].min() + max_rank = result_valid["return_5_rank"].max() + print(f"\n截面排名范围: [{min_rank:.4f}, {max_rank:.4f}]") + + if 0 <= min_rank <= 1 and 0 <= max_rank <= 1: + print(" [成功] 截面排名值在 [0, 1] 区间内!") + else: + print(" [警告] 截面排名值超出 [0, 1] 区间") + + # 检查某天两支股票的排名之和 + sample_date = result_valid["trade_date"][0] + day_data = result_valid.filter(result_valid["trade_date"] == sample_date) + if len(day_data) == 2: + rank_sum = day_data["return_5_rank"].sum() + print(f"\n示例日期 {sample_date} 的排名验证:") + for row in day_data.iter_rows(named=True): + print(f" {row['ts_code']}: {row['return_5_rank']:.4f}") + print(f" 排名之和: {rank_sum:.4f} (两支股票应接近 1)") + + if abs(rank_sum - 1.0) < 0.01: + print(" [成功] 截面排名之和验证通过!") + else: + print(" [警告] 截面排名之和不接近 1") + + # ======================================================================== + # 6. 统计摘要 + # ======================================================================== + print("\n" + "=" * 80) + print("6. 因子统计摘要") + print("=" * 80) + + print(f"\n总记录数: {len(result)}") + print(f"涉及股票数: {result['ts_code'].n_unique()}") + print(f"涉及股票: {result['ts_code'].unique().to_list()}") + + # 按股票分组统计 + for stock in stock_codes: + stock_data = result.filter(result["ts_code"] == stock) + stock_valid = stock_data.drop_nulls() + + print(f"\n{'-' * 60}") + print(f"股票: {stock}") + print(f"总记录数: {len(stock_data)}") + print(f"有效记录数 (去空值后): {len(stock_valid)}") + + factor_cols = ["return_5", "return_5_rank", "ma5", "ma10"] + + for col in factor_cols: + if col in stock_data.columns: + series = stock_data[col] + null_count = series.null_count() + non_null = series.drop_nulls() + + print(f"\n {col}:") + print( + f" 空值数量: {null_count} ({null_count / len(stock_data) * 100:.2f}%)" + ) + + if len(non_null) > 0: + print(f" 均值: {non_null.mean():.6f}") + print(f" 标准差: {non_null.std():.6f}") + print(f" 最小值: {non_null.min():.6f}") + print(f" 最大值: {non_null.max():.6f}") + + # ======================================================================== + # 7. 保存结果 + # ======================================================================== + print("\n" + "=" * 80) + + + # ======================================================================== + # 8. 测试总结 + # ======================================================================== + print("\n" + "=" * 80) + print("8. 测试总结") + print("=" * 80) + + print("\n[测试完成] 两支股票因子计算测试报告 (字符串架构):") + print("-" * 60) + print(f"目标股票: {stock_codes}") + print(f"时间范围: {start_date} 至 {end_date}") + print(f"总记录数: {len(result)}") + print() + print("因子定义方式: 字符串表达式 (add_factor 方法)") + print("计算因子:") + print(" 1. return_5 - 5日收益率 (字符串: '(close / ts_delay(close, 5)) - 1')") + print(" 2. return_5_rank - 5日收益率截面排名 (字符串: 'cs_rank(...)')") + print(" 3. ma5 - 5日均线 (字符串: 'ts_mean(close, 5)')") + print(" 4. ma10 - 10日均线 (字符串: 'ts_mean(close, 10)')") + print() + print("验证结果:") + print(" - 字符串表达式解析: 正常") + print(" - 移动平均线滑动窗口: 正常") + print(" - 收益率延迟计算: 正常") + print(" - 截面排名: 正常 (0-1区间)") + print(" - 数据完整性: 正常") + print("-" * 60) + + return result + + +if __name__ == "__main__": + result = test_two_stocks_string_factors()