Files
ProStock/tests/test_two_stocks_string_factors.py
liaozhaorun 780284af7f feat(test): 添加两支股票因子字符串计算测试
测试基于 Formula Parser + DSL 的字符串因子表达式计算,
包含 return_5、ma5、ma10 等因子及截面排名验证
2026-03-03 00:15:16 +08:00

455 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""两支股票因子计算测试 - 使用因子字符串架构。
测试目标使用字符串表达式计算两支股票601117.SH、000001.SZ在2024-2025年的以下因子
1. return_5: 5日收益率 (close / ts_delay(close, 5) - 1)
2. return_5_rank: 5日收益率在截面上的排名
3. ma5: 5日均线 (ts_mean(close, 5))
4. ma10: 10日均线 (ts_mean(close, 10))
特点使用因子字符串架构add_factor + 字符串表达式)
数据源: DuckDB 数据库中的真实日线数据
"""
from src.factors import FactorEngine
from src.factors.compiler import DependencyExtractor
def _print_section(title):
    """Print ``title`` framed by two 80-column '=' rules, preceded by a blank line."""
    print("\n" + "=" * 80)
    print(title)
    print("=" * 80)


def _rows_for(result, ts_code):
    """Return the rows of ``result`` for one stock, ordered by trade_date.

    The window checks index rows positionally (10th row = index 9), so they
    are only meaningful in chronological order; sorting makes that explicit
    instead of relying on the engine's output order. Assumes trade_date sorts
    chronologically (e.g. 'YYYYMMDD' strings or a date dtype) — TODO confirm.
    """
    return result.filter(result["ts_code"] == ts_code).sort("trade_date")


def _status(outcomes):
    """Collapse check outcomes into the word printed in the summary.

    ``None`` marks a skipped check. Returns "跳过" when nothing ran, "正常" when
    every executed check passed, and "异常" otherwise.
    """
    executed = [outcome for outcome in outcomes if outcome is not None]
    if not executed:
        return "跳过"
    return "正常" if all(executed) else "异常"


def _check_window_nulls(first_rows, ts_code):
    """Compare ma5/ma10 null counts in the leading rows with 4 / 9.

    Heuristic only: the expectation assumes no history is preloaded before
    start_date. If the engine uses the plan's ``lookback_days`` to fill the
    windows, fewer nulls are correct. Returns True when both counts match.
    """
    n_rows = len(first_rows)
    ma5_nulls = first_rows["ma5"].null_count()
    ma10_nulls = first_rows["ma10"].null_count()
    print(f"\n{ts_code} 前15行统计:")
    print(f" ma5 Null 数量: {ma5_nulls}/{n_rows} (预期: 4)")
    print(f" ma10 Null 数量: {ma10_nulls}/{n_rows} (预期: 9)")
    ok = ma5_nulls == 4 and ma10_nulls == 9
    if ok:
        print(f" [成功] {ts_code} 滑动窗口验证通过!")
    else:
        print(f" [警告] {ts_code} 滑动窗口验证异常,请检查数据")
    return ok


def _check_return_5_nulls(first_rows, ts_code):
    """Compare the return_5 null count in the leading rows with 5.

    This is the same lookback-dependent heuristic as the moving-average check.
    """
    n_rows = len(first_rows)
    nulls = first_rows["return_5"].null_count()
    print(f"\n{ts_code} 前15行统计:")
    print(f" return_5 Null 数量: {nulls}/{n_rows} (预期: 5)")
    ok = nulls == 5
    if ok:
        print(f" [成功] {ts_code} return_5 延迟验证通过!")
    else:
        print(f" [警告] {ts_code} return_5 延迟验证异常")
    return ok


def _check_manual_ma5(stock_rows, ts_code):
    """Recompute ma5 on the 10th row as mean(close_price[5:10]) and compare.

    Returns True/False, or None (skipped) when there are fewer than 10 rows.
    """
    if len(stock_rows) < 10:
        return None
    row = stock_rows.row(9, named=True)
    print(f"\n{ts_code} 第10行数据:")
    print(f" trade_date: {row['trade_date']}")
    print(f" close_price: {row['close_price']:.4f}")
    print(f" ma5: {row['ma5']:.4f}")
    print(f" ma10: {row['ma10']:.4f}")
    # The window ending at row index 9 spans indices 5..9 (days 6-10).
    window = stock_rows.head(10)["close_price"].to_list()[5:10]
    manual_ma5 = sum(window) / 5
    print("\n手动计算验证 (第6-10天 close 均值):")
    print(f" close[5:10] = {[f'{c:.4f}' for c in window]}")
    print(f" 手动计算 ma5 = {manual_ma5:.4f}")
    print(f" 引擎计算 ma5 = {row['ma5']:.4f}")
    ok = abs(manual_ma5 - row["ma5"]) < 0.01
    if ok:
        print(f" [成功] {ts_code} MA5 计算验证通过!")
    else:
        print(f" [警告] {ts_code} MA5 计算结果不一致")
    return ok


def _check_manual_return_5(stock_rows, ts_code):
    """Recompute return_5 on the 10th row as close[9] / close[4] - 1 and compare.

    Returns True/False, or None (skipped) when there are fewer than 10 rows.
    """
    if len(stock_rows) < 10:
        return None
    row = stock_rows.row(9, named=True)
    closes = stock_rows.head(10)["close_price"].to_list()
    close_day_10 = closes[9]  # close on day 10
    close_day_5 = closes[4]  # close 5 rows earlier (ts_delay(close, 5))
    manual_return_5 = (close_day_10 / close_day_5) - 1
    print(f"\n{ts_code} 第10天 return_5 验证:")
    print(f" close[9] (第10天): {close_day_10:.4f}")
    print(f" close[4] (第5天): {close_day_5:.4f}")
    print(f" 手动计算 return_5 = {manual_return_5:.6f}")
    print(f" 引擎计算 return_5 = {row['return_5']:.6f}")
    ok = abs(manual_return_5 - row["return_5"]) < 0.0001
    if ok:
        print(f" [成功] {ts_code} Return_5 计算验证通过!")
    else:
        print(f" [警告] {ts_code} Return_5 计算结果不一致")
    return ok


def _check_cs_rank(result):
    """Check return_5_rank on rows where it is non-null.

    Returns a tuple ``(range_ok, sum_ok)``, where each entry is None if it
    could not be checked:

    - range_ok: whether all ranks lie in [0, 1].
    - sum_ok: whether the two ranks on the first sample date sum to about 1.

    The sum-to-1 expectation holds only if cs_rank normalises as
    (rank - 1) / (n - 1); a rank / n scheme would sum to 1.5 — TODO confirm.
    """
    result_valid = result.drop_nulls(subset=["return_5_rank"])
    if len(result_valid) == 0:
        return None, None

    min_rank = result_valid["return_5_rank"].min()
    max_rank = result_valid["return_5_rank"].max()
    print(f"\n截面排名范围: [{min_rank:.4f}, {max_rank:.4f}]")
    range_ok = 0 <= min_rank <= 1 and 0 <= max_rank <= 1
    if range_ok:
        print(" [成功] 截面排名值在 [0, 1] 区间内!")
    else:
        print(" [警告] 截面排名值超出 [0, 1] 区间")

    sum_ok = None
    sample_date = result_valid["trade_date"][0]
    day_data = result_valid.filter(result_valid["trade_date"] == sample_date)
    if len(day_data) == 2:
        rank_sum = day_data["return_5_rank"].sum()
        print(f"\n示例日期 {sample_date} 的排名验证:")
        for row in day_data.iter_rows(named=True):
            print(f" {row['ts_code']}: {row['return_5_rank']:.4f}")
        print(f" 排名之和: {rank_sum:.4f} (两支股票应接近 1)")
        sum_ok = abs(rank_sum - 1.0) < 0.01
        if sum_ok:
            print(" [成功] 截面排名之和验证通过!")
        else:
            print(" [警告] 截面排名之和不接近 1")
    return range_ok, sum_ok


def _print_stock_stats(stock_data, stock):
    """Print record counts and per-factor null/mean/std/min/max for one stock."""
    stock_valid = stock_data.drop_nulls()
    total = len(stock_data)
    print(f"\n{'-' * 60}")
    print(f"股票: {stock}")
    print(f"总记录数: {total}")
    print(f"有效记录数 (去空值后): {len(stock_valid)}")
    for col in ("return_5", "return_5_rank", "ma5", "ma10"):
        if col not in stock_data.columns:
            continue
        series = stock_data[col]
        null_count = series.null_count()
        non_null = series.drop_nulls()
        # Guard the percentage: a stock with no rows would divide by zero.
        null_pct = null_count / total * 100 if total else 0.0
        print(f"\n {col}:")
        print(f" 空值数量: {null_count} ({null_pct:.2f}%)")
        if len(non_null) > 0:
            print(f" 均值: {non_null.mean():.6f}")
            print(f" 标准差: {non_null.std():.6f}")
            print(f" 最小值: {non_null.min():.6f}")
            print(f" 最大值: {non_null.max():.6f}")


def test_two_stocks_string_factors():
    """Run the end-to-end check of string-defined factors for two stocks.

    Registers return_5, return_5_rank, ma5 and ma10 (plus the raw close as
    close_price) on a FactorEngine via ``add_factor`` string expressions. It
    then computes them for 601117.SH and 000001.SZ over 20240101-20241231,
    prints a detailed report, and verifies the results.

    Asserted checks follow directly from the formulas or the test's stated
    contract:

    - the manual MA5 recomputation,
    - the manual return_5 recomputation,
    - cs_rank lying within [0, 1],
    - every target stock being present.

    Some checks depend on engine details not visible here, such as leading
    null counts versus lookback preloading and the cs_rank normalisation
    behind "ranks sum to 1". These only print warnings, but they are
    reflected in the summary instead of it always reporting "正常".

    Returns:
        The computed result frame, used when the file runs as a script.
        pytest emits a warning for test functions that return a non-None
        value.

    Raises:
        AssertionError: if any asserted check fails.
    """
    print("=" * 80)
    print("两支股票因子计算测试 - 使用因子字符串架构")
    print("=" * 80)

    # 1. Factor expressions, defined as strings.
    _print_section("1. 定义因子表达式(字符串方式)")
    # return_5: 5-day return = close / close.shift(5) - 1
    return_5_str = "(close / ts_delay(close, 5)) - 1"
    print("\n[1.1] return_5 = (close / ts_delay(close, 5)) - 1")
    print(f" 字符串表达式: {return_5_str}")
    # return_5_rank: cross-sectional rank of return_5 (expression inlined).
    return_5_rank_str = "cs_rank((close / ts_delay(close, 5)) - 1)"
    print("\n[1.2] return_5_rank = cs_rank(return_5)")
    print(f" 字符串表达式: {return_5_rank_str}")
    # ma5: 5-day moving average.
    ma5_str = "ts_mean(close, 5)"
    print("\n[1.3] ma5 = ts_mean(close, 5)")
    print(f" 字符串表达式: {ma5_str}")
    # ma10: 10-day moving average.
    ma10_str = "ts_mean(close, 10)"
    print("\n[1.4] ma10 = ts_mean(close, 10)")
    print(f" 字符串表达式: {ma10_str}")

    expressions_str = {
        "return_5": return_5_str,
        "return_5_rank": return_5_rank_str,
        "ma5": ma5_str,
        "ma10": ma10_str,
    }

    # 1.5 Parse each expression and show the base fields it depends on.
    _print_section("1.5 数据来源分析(使用字符串表达式解析)")
    from src.factors.parser import FormulaParser
    from src.factors.registry import FunctionRegistry

    registry = FunctionRegistry()
    parser = FormulaParser(registry)
    for name, expr_str in expressions_str.items():
        node = parser.parse(expr_str)
        # Fresh extractor per expression, as before, in case it keeps state.
        deps = DependencyExtractor().extract_dependencies(node)
        print(f"\n因子: {name}")
        print(f" 字符串表达式: {expr_str}")
        print(f" 依赖字段: {deps}")
        print(" 字段说明:")
        for dep in sorted(deps):
            print(f" - {dep}: 基础字段 (将自动路由到对应数据表)")

    # 2. Register every factor on the engine via add_factor.
    _print_section("2. 注册因子到 FactorEngine (使用 add_factor 字符串方式)")
    engine = FactorEngine()
    for idx, (name, expr_str) in enumerate(expressions_str.items(), 1):
        engine.add_factor(name, expr_str)
        print(f"[2.{idx}] 注册 {name} (字符串方式)")
    # The raw close is registered too, so the manual checks can recompute factors.
    engine.add_factor("close_price", "close")
    print("[2.5] 注册 close_price (原始收盘价,字符串方式)")
    print(f"\n已注册因子列表: {engine.list_registered()}")

    # 2.5 Show each factor's execution plan and data specs.
    _print_section("2.5 执行计划数据规格")
    for name in engine.list_registered():
        plan = engine.preview_plan(name)
        if not plan:
            continue
        print(f"\n因子: {name}")
        print(f" 输出名称: {plan.output_name}")
        print(f" 依赖字段: {plan.dependencies}")
        print(" 数据规格:")
        for i, spec in enumerate(plan.data_specs, 1):
            print(f" [{i}] 表名: {spec.table}")
            print(f" 字段: {spec.columns}")
            print(f" 回看天数: {spec.lookback_days}")

    # 3. Compute all factors for both stocks.
    start_date = "20240101"
    end_date = "20241231"
    stock_codes = ["601117.SH", "000001.SZ"]
    _print_section(f"3. 执行因子计算 ({start_date} - {end_date}, 两支股票)")
    print(f"\n目标股票: {stock_codes}")
    print(f"时间范围: {start_date} - {end_date}")
    try:
        result = engine.compute(
            factor_names=list(expressions_str) + ["close_price"],
            start_date=start_date,
            end_date=end_date,
            stock_codes=stock_codes,
        )
    except Exception as e:
        print(f"\n[错误] 计算失败: {e}")
        raise
    print("\n计算完成!")
    print(f"结果形状: {result.shape}")
    print(f"结果列: {result.columns}")

    # 4. Show the results.
    _print_section("4. 计算结果展示")
    print("\n[4.1] 前30行数据预览:")
    print(result.head(30))
    per_stock = {code: _rows_for(result, code) for code in stock_codes}
    for idx, code in enumerate(stock_codes, 2):
        print(f"\n[4.{idx}] {code} 数据 (前15行):")
        print(per_stock[code].head(15))

    # 5. Validate the results.
    # Each entry is (topic, name, outcome, asserted). The outcome is
    # True/False, or None when the check was skipped.
    checks = []
    _print_section("5. 因子计算验证")

    print("\n[5.1] 移动平均线滑动窗口验证:")
    print("-" * 60)
    print("验证要点: ")
    print(" - ma5 前4行应为 Null (窗口未满5天)")
    print(" - ma5 第5行开始应有值")
    print(" - ma10 前9行应为 Null (窗口未满10天)")
    print(" - ma10 第10行开始应有值")
    print("-" * 60)
    for code in stock_codes:
        outcome = _check_window_nulls(per_stock[code].head(15), code)
        checks.append(("ma", f"{code}:window_nulls", outcome, False))

    print("\n[5.2] 5日收益率验证:")
    print("-" * 60)
    print("验证要点:")
    print(" - return_5 前5行应为 Null (无法计算5天前的收益)")
    print(" - return_5 第6行开始应有值")
    print("-" * 60)
    for code in stock_codes:
        outcome = _check_return_5_nulls(per_stock[code].head(15), code)
        checks.append(("return", f"{code}:return_5_nulls", outcome, False))

    print("\n[5.3] MA5 手动计算验证:")
    print("-" * 60)
    for code in stock_codes:
        outcome = _check_manual_ma5(per_stock[code], code)
        checks.append(("ma", f"{code}:ma5_manual", outcome, True))

    print("\n[5.4] Return_5 手动计算验证:")
    print("-" * 60)
    for code in stock_codes:
        outcome = _check_manual_return_5(per_stock[code], code)
        checks.append(("return", f"{code}:return_5_manual", outcome, True))

    print("\n[5.5] 截面排名 (cs_rank) 验证:")
    print("-" * 60)
    print("验证要点: ")
    print(" - 截面排名应在 [0, 1] 区间内")
    print(" - 每天两支股票的排名之和应接近 1")
    print("-" * 60)
    range_ok, sum_ok = _check_cs_rank(result)
    checks.append(("rank", "cs_rank:range", range_ok, True))
    checks.append(("rank", "cs_rank:sum", sum_ok, False))

    # 6. Summary statistics.
    _print_section("6. 因子统计摘要")
    print(f"\n总记录数: {len(result)}")
    print(f"涉及股票数: {result['ts_code'].n_unique()}")
    print(f"涉及股票: {result['ts_code'].unique().to_list()}")
    missing = [code for code in stock_codes if len(per_stock[code]) == 0]
    if missing:
        print(f"[警告] 缺少股票数据: {missing}")
    checks.append(("data", "stocks_present", not missing, True))
    for stock in stock_codes:
        _print_stock_stats(per_stock[stock], stock)

    # 7. Test summary: derived from the recorded outcomes, not hard-coded.
    def topic_outcomes(topic):
        return [outcome for t, _, outcome, _ in checks if t == topic]

    _print_section("7. 测试总结")
    print("\n[测试完成] 两支股票因子计算测试报告 (字符串架构):")
    print("-" * 60)
    print(f"目标股票: {stock_codes}")
    print(f"时间范围: {start_date} - {end_date}")
    print(f"总记录数: {len(result)}")
    print()
    print("因子定义方式: 字符串表达式 (add_factor 方法)")
    print("计算因子:")
    print(" 1. return_5 - 5日收益率 (字符串: '(close / ts_delay(close, 5)) - 1')")
    print(" 2. return_5_rank - 5日收益率截面排名 (字符串: 'cs_rank(...)')")
    print(" 3. ma5 - 5日均线 (字符串: 'ts_mean(close, 5)')")
    print(" 4. ma10 - 10日均线 (字符串: 'ts_mean(close, 10)')")
    print()
    print("验证结果:")
    # Parsing is reported as normal because a parse error would have raised in section 1.5.
    print(" - 字符串表达式解析: 正常")
    print(f" - 移动平均线滑动窗口: {_status(topic_outcomes('ma'))}")
    print(f" - 收益率延迟计算: {_status(topic_outcomes('return'))}")
    print(f" - 截面排名: {_status(topic_outcomes('rank'))} (0-1区间)")
    print(f" - 数据完整性: {_status(topic_outcomes('data'))}")
    print("-" * 60)

    failed = [name for _, name, outcome, asserted in checks if asserted and outcome is False]
    assert not failed, f"因子计算验证失败: {failed}"
    return result
if __name__ == "__main__":
result = test_two_stocks_string_factors()