Files
ProStock/tests/test_two_stocks_string_factors.py
liaozhaorun 53225b9443 feat(data): 添加每日指标接口并优化因子引擎
- 新增 api_daily_basic.py 封装 Tushare 每日指标接口
- 因子引擎移除 lookback_days,支持 daily_basic 表字段路由
- 将每日指标纳入自动同步流程
- 删除废弃的 training/main.py
2026-03-03 17:09:39 +08:00

510 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""两支股票因子计算测试 - 使用因子字符串架构。
测试目标:使用字符串表达式计算两支股票(601117.SH、000001.SZ)在2024年(20240101-20241231)的以下因子:
1. return_5: 5日收益率 (close / ts_delay(close, 5) - 1)
2. return_5_rank: 5日收益率在截面上的排名
3. ma5: 5日均线 (ts_mean(close, 5))
4. ma10: 10日均线 (ts_mean(close, 10))
5. market_cap_rank: 市值百分比排名 (cs_rank(total_mv))
特点:使用因子字符串架构(add_factor + 字符串表达式)
数据源: DuckDB 数据库中的真实日线数据
"""
from src.factors import FactorEngine
from src.factors.compiler import DependencyExtractor
def _banner(title, leading_blank=True):
    """Print *title* framed by two 80-column rules, optionally after a blank line."""
    rule = "=" * 80
    print("\n" + rule if leading_blank else rule)
    print(title)
    print(rule)


def _report(passed, ok_msg, warn_msg):
    """Print the success or warning line for one check and return *passed*."""
    print(ok_msg if passed else warn_msg)
    return passed


def _status(outcomes):
    """Map a list of check outcomes to the status label shown in the summary."""
    if not outcomes:
        # Nothing was actually checked (e.g. too few rows), so do not claim success.
        return "未验证"
    return "正常" if all(outcomes) else "异常"


def _check_ma_windows(code, head15):
    """Check the ma5/ma10 null counts in the first 15 rows of one stock."""
    ma5_nulls = head15["ma5"].null_count()
    ma10_nulls = head15["ma10"].null_count()
    print(f"\n{code} 前15行统计:")
    print(f" ma5 Null 数量: {ma5_nulls}/15 (预期: 4)")
    print(f" ma10 Null 数量: {ma10_nulls}/15 (预期: 9)")
    return _report(
        ma5_nulls == 4 and ma10_nulls == 9,
        f" [成功] {code} 滑动窗口验证通过!",
        f" [警告] {code} 滑动窗口验证异常,请检查数据",
    )


def _check_return_delay(code, head15):
    """Check the return_5 null count in the first 15 rows of one stock."""
    nulls = head15["return_5"].null_count()
    print(f"\n{code} 前15行统计:")
    print(f" return_5 Null 数量: {nulls}/15 (预期: 5)")
    return _report(
        nulls == 5,
        f" [成功] {code} return_5 延迟验证通过!",
        f" [警告] {code} return_5 延迟验证异常",
    )


def _check_manual_ma5(code, frame):
    """Recompute ma5 for row 10 (index 9) by hand and compare with the engine value.

    The caller must ensure *frame* has at least 10 rows.
    """
    row10 = frame.row(9, named=True)
    print(f"\n{code} 第10行数据:")
    print(f" trade_date: {row10['trade_date']}")
    print(f" close_price: {row10['close_price']:.4f}")
    print(f" ma5: {row10['ma5']:.4f}")
    print(f" ma10: {row10['ma10']:.4f}")
    # The 5-row window ending at row index 9 covers indices 5..9.
    window = frame.head(10)["close_price"].to_list()[5:10]
    manual_ma5 = sum(window) / 5
    print("\n手动计算验证 (第6-10天 close 均值):")
    print(f" close[5:10] = {[f'{c:.4f}' for c in window]}")
    print(f" 手动计算 ma5 = {manual_ma5:.4f}")
    print(f" 引擎计算 ma5 = {row10['ma5']:.4f}")
    return _report(
        abs(manual_ma5 - row10["ma5"]) < 0.01,
        f" [成功] {code} MA5 计算验证通过!",
        f" [警告] {code} MA5 计算结果不一致",
    )


def _check_manual_return5(code, frame):
    """Recompute return_5 for row 10 (index 9) by hand and compare with the engine value.

    The caller must ensure *frame* has at least 10 rows.
    """
    row10 = frame.row(9, named=True)
    closes = frame.head(10)["close_price"].to_list()
    close_day_10 = closes[9]  # close on the 10th row
    close_day_5 = closes[4]  # close five rows earlier
    manual_return_5 = (close_day_10 / close_day_5) - 1
    print(f"\n{code} 第10天 return_5 验证:")
    print(f" close[9] (第10天): {close_day_10:.4f}")
    print(f" close[4] (第5天): {close_day_5:.4f}")
    print(f" 手动计算 return_5 = {manual_return_5:.6f}")
    print(f" 引擎计算 return_5 = {row10['return_5']:.6f}")
    return _report(
        abs(manual_return_5 - row10["return_5"]) < 0.0001,
        f" [成功] {code} Return_5 计算验证通过!",
        f" [警告] {code} Return_5 计算结果不一致",
    )


def _check_cs_rank(result, column, label, sample_label):
    """Check one cross-sectional rank column.

    Prints the value range and, for the first date with non-null ranks, the
    per-stock ranks and their sum. Returns whether all values lie in [0, 1],
    or None when the column has no non-null values (nothing to check).
    The rank-sum comparison is printed only; it is not part of the return
    value because the summary line is scoped to the [0, 1] range.
    """
    valid = result.drop_nulls(subset=[column])
    if len(valid) == 0:
        return None
    min_rank = valid[column].min()
    max_rank = valid[column].max()
    print(f"\n{label}范围: [{min_rank:.4f}, {max_rank:.4f}]")
    in_range = _report(
        0 <= min_rank <= 1 and 0 <= max_rank <= 1,
        f" [成功] {label}值在 [0, 1] 区间内!",
        f" [警告] {label}值超出 [0, 1] 区间",
    )
    # Spot-check a single date: with exactly two stocks the ranks are expected to sum to ~1.
    sample_date = valid["trade_date"][0]
    day_data = valid.filter(valid["trade_date"] == sample_date)
    if len(day_data) == 2:
        rank_sum = day_data[column].sum()
        print(f"\n示例日期 {sample_date} 的{sample_label}验证:")
        for row in day_data.iter_rows(named=True):
            print(f" {row['ts_code']}: {row[column]:.4f}")
        print(f" 排名之和: {rank_sum:.4f} (两支股票应接近 1)")
        _report(
            abs(rank_sum - 1.0) < 0.01,
            f" [成功] {label}之和验证通过!",
            f" [警告] {label}之和不接近 1",
        )
    return in_range


def _print_stock_stats(code, stock_data, factor_cols):
    """Print row counts and per-factor null/mean/std/min/max statistics for one stock."""
    total = len(stock_data)
    print(f"\n{'-' * 60}")
    print(f"股票: {code}")
    print(f"总记录数: {total}")
    print(f"有效记录数 (去空值后): {len(stock_data.drop_nulls())}")
    for col in factor_cols:
        if col not in stock_data.columns:
            continue
        series = stock_data[col]
        null_count = series.null_count()
        non_null = series.drop_nulls()
        # Guard the percentage: a stock with no rows must not raise ZeroDivisionError.
        null_pct = null_count / total * 100 if total else 0.0
        print(f"\n {col}:")
        print(f" 空值数量: {null_count} ({null_pct:.2f}%)")
        if len(non_null) > 0:
            print(f" 均值: {non_null.mean():.6f}")
            print(f" 标准差: {non_null.std():.6f}")
            print(f" 最小值: {non_null.min():.6f}")
            print(f" 最大值: {non_null.max():.6f}")


def test_two_stocks_string_factors():
    """Compute string-expression factors for two stocks and print a verification report.

    Registers return_5, return_5_rank, ma5, ma10, market_cap_rank and a raw
    close_price factor on a FactorEngine via add_factor, computes them for
    601117.SH and 000001.SZ over 20240101-20241231, then prints previews,
    per-stock verification of the rolling/delay factors, cross-sectional rank
    checks and summary statistics.

    Verification mismatches are reported as "[警告]" lines rather than raised,
    as in the original; the final summary now reflects the recorded outcomes
    instead of unconditionally printing "正常".

    Returns:
        The DataFrame returned by ``engine.compute`` (the API used on it here
        matches polars - TODO confirm), so the ``__main__`` entry point can keep it.

    Raises:
        Exception: whatever ``engine.compute`` raises is printed and re-raised.
    """
    _banner("两支股票因子计算测试 - 使用因子字符串架构", leading_blank=False)

    # ------------------------------------------------------------------
    # 1. Define the factor expressions as strings
    # ------------------------------------------------------------------
    _banner("1. 定义因子表达式(字符串方式)")
    # (factor name, formula as displayed, expression string, registration note)
    factor_defs = [
        (
            "return_5",  # 5-day return
            "(close / ts_delay(close, 5)) - 1",
            "(close / ts_delay(close, 5)) - 1",
            "(字符串方式)",
        ),
        (
            "return_5_rank",  # cross-sectional rank of the 5-day return
            "cs_rank(return_5)",
            "cs_rank((close / ts_delay(close, 5)) - 1)",
            "(字符串方式)",
        ),
        ("ma5", "ts_mean(close, 5)", "ts_mean(close, 5)", "(字符串方式)"),
        ("ma10", "ts_mean(close, 10)", "ts_mean(close, 10)", "(字符串方式)"),
        (
            "market_cap_rank",  # cross-sectional rank of total market value
            "cs_rank(total_mv)",
            "cs_rank(total_mv)",
            "(市值百分比排名,字符串方式)",
        ),
    ]
    for idx, (name, shown, expr_str, _note) in enumerate(factor_defs, 1):
        print(f"\n[1.{idx}] {name} = {shown}")
        print(f" 字符串表达式: {expr_str}")

    # ------------------------------------------------------------------
    # 1.5 Show which base fields each expression depends on
    # ------------------------------------------------------------------
    _banner("1.5 数据来源分析(使用字符串表达式解析)")
    from src.factors.parser import FormulaParser
    from src.factors.registry import FunctionRegistry

    registry = FunctionRegistry()
    parser = FormulaParser(registry)
    for name, _shown, expr_str, _note in factor_defs:
        node = parser.parse(expr_str)
        extractor = DependencyExtractor()
        deps = extractor.extract_dependencies(node)
        print(f"\n因子: {name}")
        print(f" 字符串表达式: {expr_str}")
        print(f" 依赖字段: {deps}")
        print(" 字段说明:")
        for dep in sorted(deps):
            print(f" - {dep}: 基础字段 (将自动路由到对应数据表)")

    # ------------------------------------------------------------------
    # 2. Create the engine and register the factors via add_factor
    # ------------------------------------------------------------------
    _banner("2. 注册因子到 FactorEngine(使用 add_factor 字符串方式)")
    engine = FactorEngine()
    for idx, (name, _shown, expr_str, note) in enumerate(factor_defs, 1):
        engine.add_factor(name, expr_str)
        print(f"[2.{idx}] 注册 {name} {note}")
    # The raw close price is registered too so the manual checks can use it.
    engine.add_factor("close_price", "close")
    print("[2.6] 注册 close_price (原始收盘价,字符串方式)")
    print(f"\n已注册因子列表: {engine.list_registered()}")

    # ------------------------------------------------------------------
    # 2.5 Print the data specs of each execution plan
    # ------------------------------------------------------------------
    _banner("2.5 执行计划数据规格")
    for name in engine.list_registered():
        plan = engine.preview_plan(name)
        if not plan:
            continue
        print(f"\n因子: {name}")
        print(f" 输出名称: {plan.output_name}")
        print(f" 依赖字段: {plan.dependencies}")
        print(" 数据规格:")
        for i, spec in enumerate(plan.data_specs, 1):
            print(f" [{i}] 表名: {spec.table}")
            print(f" 字段: {spec.columns}")

    # ------------------------------------------------------------------
    # 3. Run the computation for the two stocks
    # ------------------------------------------------------------------
    start_date = "20240101"
    end_date = "20241231"
    stock_codes = ["601117.SH", "000001.SZ"]
    _banner(f"3. 执行因子计算 ({start_date} - {end_date}, 两支股票)")
    print(f"\n目标股票: {stock_codes}")
    # The original printed the two dates fused together; separate them.
    print(f"时间范围: {start_date} ~ {end_date}")
    try:
        result = engine.compute(
            factor_names=[
                "return_5",
                "return_5_rank",
                "ma5",
                "ma10",
                "close_price",
                "market_cap_rank",
            ],
            start_date=start_date,
            end_date=end_date,
            stock_codes=stock_codes,
        )
    except Exception as e:
        print(f"\n[错误] 计算失败: {e}")
        raise
    print("\n计算完成!")
    print(f"结果形状: {result.shape}")
    print(f"结果列: {result.columns}")

    # ------------------------------------------------------------------
    # 4. Show the results
    # ------------------------------------------------------------------
    _banner("4. 计算结果展示")
    print("\n[4.1] 前30行数据预览:")
    print(result.head(30))
    # One frame per requested stock, reused by every later section.
    frames = {code: result.filter(result["ts_code"] == code) for code in stock_codes}
    for idx, code in enumerate(stock_codes, 2):
        print(f"\n[4.{idx}] {code} 数据 (前15行):")
        print(frames[code].head(15))

    # ------------------------------------------------------------------
    # 5. Verify the factors (outcomes are recorded for the final summary)
    # ------------------------------------------------------------------
    _banner("5. 因子计算验证")
    checks = {"ma_window": [], "return_delay": [], "cs_rank": []}
    rule = "-" * 60

    # 5.1 Rolling-window warm-up of ma5 / ma10
    print("\n[5.1] 移动平均线滑动窗口验证:")
    print(rule)
    print("验证要点: ")
    print(" - ma5 前4行应为 Null (窗口未满5天)")
    print(" - ma5 第5行开始应有值")
    print(" - ma10 前9行应为 Null (窗口未满10天)")
    print(" - ma10 第10行开始应有值")
    print(rule)
    for code in stock_codes:
        checks["ma_window"].append(_check_ma_windows(code, frames[code].head(15)))

    # 5.2 Delay warm-up of return_5
    print("\n[5.2] 5日收益率验证:")
    print(rule)
    print("验证要点:")
    print(" - return_5 前5行应为 Null (无法计算5天前的收益)")
    print(" - return_5 第6行开始应有值")
    print(rule)
    for code in stock_codes:
        checks["return_delay"].append(
            _check_return_delay(code, frames[code].head(15))
        )

    # 5.3 Manual ma5 recomputation at row 10 (index 9)
    print("\n[5.3] MA5 手动计算验证:")
    print(rule)
    for code in stock_codes:
        if len(frames[code]) >= 10:
            checks["ma_window"].append(_check_manual_ma5(code, frames[code]))

    # 5.4 Manual return_5 recomputation at row 10 (index 9)
    print("\n[5.4] Return_5 手动计算验证:")
    print(rule)
    for code in stock_codes:
        if len(frames[code]) >= 10:
            checks["return_delay"].append(
                _check_manual_return5(code, frames[code])
            )

    # 5.5 Cross-sectional ranks
    print("\n[5.5] 截面排名 (cs_rank) 验证:")
    print(rule)
    print("验证要点: ")
    print(" - 截面排名应在 [0, 1] 区间内")
    print(" - 每天两支股票的排名之和应接近 1")
    print(rule)
    print("\n[5.5.1] return_5_rank 截面排名验证:")
    print(rule)
    outcome = _check_cs_rank(result, "return_5_rank", "截面排名", "排名")
    if outcome is not None:
        checks["cs_rank"].append(outcome)
    print("\n[5.5.2] market_cap_rank 市值百分比排名验证:")
    print(rule)
    outcome = _check_cs_rank(result, "market_cap_rank", "市值排名", "市值排名")
    if outcome is not None:
        checks["cs_rank"].append(outcome)

    # ------------------------------------------------------------------
    # 6. Summary statistics
    # ------------------------------------------------------------------
    _banner("6. 因子统计摘要")
    print(f"\n总记录数: {len(result)}")
    print(f"涉及股票数: {result['ts_code'].n_unique()}")
    print(f"涉及股票: {result['ts_code'].unique().to_list()}")
    factor_cols = ["return_5", "return_5_rank", "ma5", "ma10", "market_cap_rank"]
    for code in stock_codes:
        _print_stock_stats(code, frames[code], factor_cols)

    # ------------------------------------------------------------------
    # 7. Save results (nothing is persisted; only the separator is printed)
    # ------------------------------------------------------------------
    print("\n" + "=" * 80)

    # ------------------------------------------------------------------
    # 8. Test summary
    # ------------------------------------------------------------------
    _banner("8. 测试总结")
    print("\n[测试完成] 两支股票因子计算测试报告 (字符串架构):")
    print(rule)
    print(f"目标股票: {stock_codes}")
    print(f"时间范围: {start_date} ~ {end_date}")
    print(f"总记录数: {len(result)}")
    print()
    print("因子定义方式: 字符串表达式 (add_factor 方法)")
    print("计算因子:")
    print(
        " 1. return_5 - 5日收益率 (字符串: '(close / ts_delay(close, 5)) - 1')"
    )
    print(" 2. return_5_rank - 5日收益率截面排名 (字符串: 'cs_rank(...)')")
    print(" 3. ma5 - 5日均线 (字符串: 'ts_mean(close, 5)')")
    print(" 4. ma10 - 10日均线 (字符串: 'ts_mean(close, 10)')")
    print(" 5. market_cap_rank - 市值百分比排名 (字符串: 'cs_rank(total_mv)')")
    print()
    print("验证结果:")
    # Reaching this point means every expression parsed and computed without raising.
    print(" - 字符串表达式解析: 正常")
    print(f" - 移动平均线滑动窗口: {_status(checks['ma_window'])}")
    print(f" - 收益率延迟计算: {_status(checks['return_delay'])}")
    print(f" - 截面排名: {_status(checks['cs_rank'])} (0-1区间)")
    # Data is considered complete when every requested stock produced rows.
    integrity = [len(frames[code]) > 0 for code in stock_codes]
    print(f" - 数据完整性: {_status(integrity)}")
    print(rule)
    # NOTE(review): pytest warns when a test returns a value; the return is kept
    # because the __main__ entry point binds it.
    return result
if __name__ == "__main__":
    # When executed as a script, run the test once and bind its return value
    # to the module-level name `result`.
    result = test_two_stocks_string_factors()